diff --git a/src/documents/classifier.py b/src/documents/classifier.py index 1eefe76b3..52c508655 100755 --- a/src/documents/classifier.py +++ b/src/documents/classifier.py @@ -34,7 +34,6 @@ class DocumentClassifier(object): self.tags_classifier = None self.correspondent_classifier = None self.document_type_classifier = None - self.X = None def reload(self): if os.path.getmtime(settings.MODEL_FILE) > self.classifier_version: @@ -167,14 +166,10 @@ class DocumentClassifier(object): "classifier." ) - def update(self, document): - self.X = self.data_vectorizer.transform( - [preprocess_content(document.content)] - ) - - def predict_correspondent(self): + def predict_correspondent(self, content): if self.correspondent_classifier: - y = self.correspondent_classifier.predict(self.X) + X = self.data_vectorizer.transform([preprocess_content(content)]) + y = self.correspondent_classifier.predict(X) correspondent_id = self.correspondent_binarizer.inverse_transform(y)[0] if correspondent_id != -1: return correspondent_id @@ -183,9 +178,10 @@ class DocumentClassifier(object): else: return None - def predict_document_type(self): + def predict_document_type(self, content): if self.document_type_classifier: - y = self.document_type_classifier.predict(self.X) + X = self.data_vectorizer.transform([preprocess_content(content)]) + y = self.document_type_classifier.predict(X) document_type_id = self.document_type_binarizer.inverse_transform(y)[0] if document_type_id != -1: return document_type_id @@ -194,9 +190,10 @@ class DocumentClassifier(object): else: return None - def predict_tags(self): + def predict_tags(self, content): if self.tags_classifier: - y = self.tags_classifier.predict(self.X) + X = self.data_vectorizer.transform([preprocess_content(content)]) + y = self.tags_classifier.predict(X) tags_ids = self.tags_binarizer.inverse_transform(y)[0] return tags_ids else: diff --git a/src/documents/consumer.py b/src/documents/consumer.py index 65889c6d9..401ef0ff0 100755 --- a/src/documents/consumer.py +++ b/src/documents/consumer.py @@ -182,7 +182,6 @@ class Consumer: try: self.classifier.reload() - self.classifier.update(document) classifier = self.classifier except FileNotFoundError: logging.getLogger(__name__).warning("Cannot classify documents, " diff --git a/src/documents/management/commands/document_retagger.py b/src/documents/management/commands/document_retagger.py index 040aa2b6e..007286935 100755 --- a/src/documents/management/commands/document_retagger.py +++ b/src/documents/management/commands/document_retagger.py @@ -84,9 +84,6 @@ class Command(Renderable, BaseCommand): "Processing document {}".format(document.title) ) - if classifier: - classifier.update(document) - if options['correspondent']: set_correspondent( sender=None, diff --git a/src/documents/matching.py b/src/documents/matching.py index 6fbb8efd6..a52f06a06 100644 --- a/src/documents/matching.py +++ b/src/documents/matching.py @@ -7,7 +7,7 @@ from documents.models import MatchingModel, Correspondent, DocumentType, Tag def match_correspondents(document_content, classifier): correspondents = Correspondent.objects.all() - predicted_correspondent_id = classifier.predict_correspondent() if classifier else None + predicted_correspondent_id = classifier.predict_correspondent(document_content) if classifier else None matched_correspondents = [o for o in correspondents if matches(o, document_content) or o.id == predicted_correspondent_id] return matched_correspondents @@ -15,7 +15,7 @@ def match_correspondents(document_content, classifier): def match_document_types(document_content, classifier): document_types = DocumentType.objects.all() - predicted_document_type_id = classifier.predict_document_type() if classifier else None + predicted_document_type_id = classifier.predict_document_type(document_content) if classifier else None matched_document_types = [o for o in document_types if matches(o, document_content) or o.id == predicted_document_type_id] return matched_document_types @@ -23,7 +23,7 @@ def match_document_types(document_content, classifier): def match_tags(document_content, classifier): objects = Tag.objects.all() - predicted_tag_ids = classifier.predict_tags() if classifier else [] + predicted_tag_ids = classifier.predict_tags(document_content) if classifier else [] matched_tags = [o for o in objects if matches(o, document_content) or o.id in predicted_tag_ids] return matched_tags