the document classifier is now stateless

This commit is contained in:
Jonas Winkler 2020-10-29 14:33:42 +01:00
parent 3e50e51b8a
commit 05f20c19c3
4 changed files with 12 additions and 19 deletions

View File

@ -34,7 +34,6 @@ class DocumentClassifier(object):
self.tags_classifier = None self.tags_classifier = None
self.correspondent_classifier = None self.correspondent_classifier = None
self.document_type_classifier = None self.document_type_classifier = None
self.X = None
def reload(self): def reload(self):
if os.path.getmtime(settings.MODEL_FILE) > self.classifier_version: if os.path.getmtime(settings.MODEL_FILE) > self.classifier_version:
@ -167,14 +166,10 @@ class DocumentClassifier(object):
"classifier." "classifier."
) )
def update(self, document): def predict_correspondent(self, content):
self.X = self.data_vectorizer.transform(
[preprocess_content(document.content)]
)
def predict_correspondent(self):
if self.correspondent_classifier: if self.correspondent_classifier:
y = self.correspondent_classifier.predict(self.X) X = self.data_vectorizer.transform([preprocess_content(content)])
y = self.correspondent_classifier.predict(X)
correspondent_id = self.correspondent_binarizer.inverse_transform(y)[0] correspondent_id = self.correspondent_binarizer.inverse_transform(y)[0]
if correspondent_id != -1: if correspondent_id != -1:
return correspondent_id return correspondent_id
@ -183,9 +178,10 @@ class DocumentClassifier(object):
else: else:
return None return None
def predict_document_type(self): def predict_document_type(self, content):
if self.document_type_classifier: if self.document_type_classifier:
y = self.document_type_classifier.predict(self.X) X = self.data_vectorizer.transform([preprocess_content(content)])
y = self.document_type_classifier.predict(X)
document_type_id = self.document_type_binarizer.inverse_transform(y)[0] document_type_id = self.document_type_binarizer.inverse_transform(y)[0]
if document_type_id != -1: if document_type_id != -1:
return document_type_id return document_type_id
@ -194,9 +190,10 @@ class DocumentClassifier(object):
else: else:
return None return None
def predict_tags(self): def predict_tags(self, content):
if self.tags_classifier: if self.tags_classifier:
y = self.tags_classifier.predict(self.X) X = self.data_vectorizer.transform([preprocess_content(content)])
y = self.tags_classifier.predict(X)
tags_ids = self.tags_binarizer.inverse_transform(y)[0] tags_ids = self.tags_binarizer.inverse_transform(y)[0]
return tags_ids return tags_ids
else: else:

View File

@ -182,7 +182,6 @@ class Consumer:
try: try:
self.classifier.reload() self.classifier.reload()
self.classifier.update(document)
classifier = self.classifier classifier = self.classifier
except FileNotFoundError: except FileNotFoundError:
logging.getLogger(__name__).warning("Cannot classify documents, " logging.getLogger(__name__).warning("Cannot classify documents, "

View File

@ -84,9 +84,6 @@ class Command(Renderable, BaseCommand):
"Processing document {}".format(document.title) "Processing document {}".format(document.title)
) )
if classifier:
classifier.update(document)
if options['correspondent']: if options['correspondent']:
set_correspondent( set_correspondent(
sender=None, sender=None,

View File

@ -7,7 +7,7 @@ from documents.models import MatchingModel, Correspondent, DocumentType, Tag
def match_correspondents(document_content, classifier): def match_correspondents(document_content, classifier):
correspondents = Correspondent.objects.all() correspondents = Correspondent.objects.all()
predicted_correspondent_id = classifier.predict_correspondent() if classifier else None predicted_correspondent_id = classifier.predict_correspondent(document_content) if classifier else None
matched_correspondents = [o for o in correspondents if matches(o, document_content) or o.id == predicted_correspondent_id] matched_correspondents = [o for o in correspondents if matches(o, document_content) or o.id == predicted_correspondent_id]
return matched_correspondents return matched_correspondents
@ -15,7 +15,7 @@ def match_correspondents(document_content, classifier):
def match_document_types(document_content, classifier): def match_document_types(document_content, classifier):
document_types = DocumentType.objects.all() document_types = DocumentType.objects.all()
predicted_document_type_id = classifier.predict_document_type() if classifier else None predicted_document_type_id = classifier.predict_document_type(document_content) if classifier else None
matched_document_types = [o for o in document_types if matches(o, document_content) or o.id == predicted_document_type_id] matched_document_types = [o for o in document_types if matches(o, document_content) or o.id == predicted_document_type_id]
return matched_document_types return matched_document_types
@ -23,7 +23,7 @@ def match_document_types(document_content, classifier):
def match_tags(document_content, classifier): def match_tags(document_content, classifier):
objects = Tag.objects.all() objects = Tag.objects.all()
predicted_tag_ids = classifier.predict_tags() if classifier else [] predicted_tag_ids = classifier.predict_tags(document_content) if classifier else []
matched_tags = [o for o in objects if matches(o, document_content) or o.id in predicted_tag_ids] matched_tags = [o for o in objects if matches(o, document_content) or o.id in predicted_tag_ids]
return matched_tags return matched_tags