From d46ee11143c83684d8894812c2de7ec275b6b232 Mon Sep 17 00:00:00 2001 From: Jonas Winkler Date: Tue, 11 Sep 2018 14:30:18 +0200 Subject: [PATCH] The classifier works with ids now, not names. Minor changes. --- src/documents/actions.py | 4 +- src/documents/apps.py | 2 + src/documents/classifier.py | 76 +++++++++++-------- .../management/commands/document_retagger.py | 7 +- src/documents/signals/handlers.py | 7 +- 5 files changed, 58 insertions(+), 38 deletions(-) diff --git a/src/documents/actions.py b/src/documents/actions.py index 1ee05cecb..9ad691ca2 100755 --- a/src/documents/actions.py +++ b/src/documents/actions.py @@ -239,9 +239,9 @@ def run_document_classifier_on_selected(modeladmin, request, queryset): n = queryset.count() if n: for obj in queryset: - clf.classify_document(obj, classify_correspondent=True, classify_tags=True, classify_type=True, replace_tags=True) + clf.classify_document(obj, classify_correspondent=True, classify_tags=True, classify_document_type=True, replace_tags=True) modeladmin.log_change(request, obj, str(obj)) - modeladmin.message_user(request, "Successfully applied tags, correspondent and type to %(count)d %(items)s." % { + modeladmin.message_user(request, "Successfully applied tags, correspondent and document type to %(count)d %(items)s." % { "count": n, "items": model_ngettext(modeladmin.opts, n) }, messages.SUCCESS) diff --git a/src/documents/apps.py b/src/documents/apps.py index 16a50843b..1d6574e67 100755 --- a/src/documents/apps.py +++ b/src/documents/apps.py @@ -12,6 +12,7 @@ class DocumentsConfig(AppConfig): from .signals import document_consumption_finished from .signals.handlers import ( classify_document, + add_inbox_tags, run_pre_consume_script, run_post_consume_script, cleanup_document_deletion, @@ -21,6 +22,7 @@ class DocumentsConfig(AppConfig): document_consumption_started.connect(run_pre_consume_script) document_consumption_finished.connect(classify_document) + document_consumption_finished.connect(add_inbox_tags) document_consumption_finished.connect(set_log_entry) document_consumption_finished.connect(run_post_consume_script) diff --git a/src/documents/classifier.py b/src/documents/classifier.py index 99507d41b..1fc46a3f1 100755 --- a/src/documents/classifier.py +++ b/src/documents/classifier.py @@ -8,7 +8,6 @@ from documents.models import Correspondent, DocumentType, Tag, Document from paperless import settings from sklearn.feature_extraction.text import CountVectorizer -from sklearn.multiclass import OneVsRestClassifier from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer @@ -30,11 +29,11 @@ class DocumentClassifier(object): tags_binarizer = None correspondent_binarizer = None - type_binarizer = None + document_type_binarizer = None tags_classifier = None correspondent_classifier = None - type_classifier = None + document_type_classifier = None @staticmethod def load_classifier(): @@ -49,11 +48,11 @@ class DocumentClassifier(object): self.data_vectorizer = pickle.load(f) self.tags_binarizer = pickle.load(f) self.correspondent_binarizer = pickle.load(f) - self.type_binarizer = pickle.load(f) + self.document_type_binarizer = pickle.load(f) self.tags_classifier = pickle.load(f) self.correspondent_classifier = pickle.load(f) - self.type_classifier = pickle.load(f) + self.document_type_classifier = pickle.load(f) self.classifier_version = os.path.getmtime(settings.MODEL_FILE) def save_classifier(self): @@ -62,29 +61,29 @@ class DocumentClassifier(object): pickle.dump(self.tags_binarizer, f) pickle.dump(self.correspondent_binarizer, f) - pickle.dump(self.type_binarizer, f) + pickle.dump(self.document_type_binarizer, f) pickle.dump(self.tags_classifier, f) pickle.dump(self.correspondent_classifier, f) - pickle.dump(self.type_classifier, f) + pickle.dump(self.document_type_classifier, f) def train(self): data = list() labels_tags = list() labels_correspondent = list() - labels_type = list() + labels_document_type = list() # Step 1: Extract and preprocess training data from the database. logging.getLogger(__name__).info("Gathering data from database...") for doc in Document.objects.exclude(tags__is_inbox_tag=True): data.append(preprocess_content(doc.content)) - labels_type.append(doc.document_type.name if doc.document_type is not None and doc.document_type.automatic_classification else "-") - labels_correspondent.append(doc.correspondent.name if doc.correspondent is not None and doc.correspondent.automatic_classification else "-") - tags = [tag.name for tag in doc.tags.filter(automatic_classification=True)] + labels_document_type.append(doc.document_type.id if doc.document_type is not None and doc.document_type.automatic_classification else -1) + labels_correspondent.append(doc.correspondent.id if doc.correspondent is not None and doc.correspondent.automatic_classification else -1) + tags = [tag.id for tag in doc.tags.filter(automatic_classification=True)] labels_tags.append(tags) labels_tags_unique = set([tag for tags in labels_tags for tag in tags]) - logging.getLogger(__name__).info("{} documents, {} tag(s) {}, {} correspondent(s) {}, {} type(s) {}.".format(len(data), len(labels_tags_unique), labels_tags_unique, len(set(labels_correspondent)), set(labels_correspondent), len(set(labels_type)), set(labels_type))) + logging.getLogger(__name__).info("{} documents, {} tag(s), {} correspondent(s), {} document type(s).".format(len(data), len(labels_tags_unique), len(set(labels_correspondent)), len(set(labels_document_type)))) # Step 2: vectorize data logging.getLogger(__name__).info("Vectorizing data...") @@ -97,8 +96,8 @@ class DocumentClassifier(object): self.correspondent_binarizer = LabelBinarizer() labels_correspondent_vectorized = self.correspondent_binarizer.fit_transform(labels_correspondent) - self.type_binarizer = LabelBinarizer() - labels_type_vectorized = self.type_binarizer.fit_transform(labels_type) + self.document_type_binarizer = LabelBinarizer() + labels_document_type_vectorized = self.document_type_binarizer.fit_transform(labels_document_type) # Step 3: train the classifiers if len(self.tags_binarizer.classes_) > 0: @@ -117,39 +116,52 @@ class DocumentClassifier(object): self.correspondent_classifier = None logging.getLogger(__name__).info("There are no correspondents. Not training correspondent classifier.") - if len(self.type_binarizer.classes_) > 0: + if len(self.document_type_binarizer.classes_) > 0: logging.getLogger(__name__).info("Training document type classifier...") - self.type_classifier = MLPClassifier(verbose=True) - self.type_classifier.fit(data_vectorized, labels_type_vectorized) + self.document_type_classifier = MLPClassifier(verbose=True) + self.document_type_classifier.fit(data_vectorized, labels_document_type_vectorized) else: - self.type_classifier = None + self.document_type_classifier = None logging.getLogger(__name__).info("There are no document types. Not training document type classifier.") - def classify_document(self, document, classify_correspondent=False, classify_type=False, classify_tags=False, replace_tags=False): + def classify_document(self, document, classify_correspondent=False, classify_document_type=False, classify_tags=False, replace_tags=False): X = self.data_vectorizer.transform([preprocess_content(document.content)]) update_fields=() if classify_correspondent and self.correspondent_classifier is not None: y_correspondent = self.correspondent_classifier.predict(X) - correspondent = self.correspondent_binarizer.inverse_transform(y_correspondent)[0] - print("Detected correspondent:", correspondent) - document.correspondent = Correspondent.objects.filter(name=correspondent).first() - update_fields = update_fields + ("correspondent",) + correspondent_id = self.correspondent_binarizer.inverse_transform(y_correspondent)[0] + try: + correspondent = Correspondent.objects.get(id=correspondent_id) if correspondent_id != -1 else None + logging.getLogger(__name__).info("Detected correspondent: {}".format(correspondent.name if correspondent else "-")) + document.correspondent = correspondent + update_fields = update_fields + ("correspondent",) + except Correspondent.DoesNotExist: + logging.getLogger(__name__).warning("Detected correspondent with id {} does not exist anymore! Did you delete it?".format(correspondent_id)) - if classify_type and self.type_classifier is not None: - y_type = self.type_classifier.predict(X) - type = self.type_binarizer.inverse_transform(y_type)[0] - print("Detected document type:", type) - document.document_type = DocumentType.objects.filter(name=type).first() - update_fields = update_fields + ("document_type",) + if classify_document_type and self.document_type_classifier is not None: + y_type = self.document_type_classifier.predict(X) + type_id = self.document_type_binarizer.inverse_transform(y_type)[0] + try: + document_type = DocumentType.objects.get(id=type_id) if type_id != -1 else None + logging.getLogger(__name__).info("Detected document type: {}".format(document_type.name if document_type else "-")) + document.document_type = document_type + update_fields = update_fields + ("document_type",) + except DocumentType.DoesNotExist: + logging.getLogger(__name__).warning("Detected document type with id {} does not exist anymore! Did you delete it?".format(type_id)) if classify_tags and self.tags_classifier is not None: y_tags = self.tags_classifier.predict(X) - tags = self.tags_binarizer.inverse_transform(y_tags)[0] - print("Detected tags:", tags) + tags_ids = self.tags_binarizer.inverse_transform(y_tags)[0] if replace_tags: document.tags.clear() - document.tags.add(*[Tag.objects.filter(name=t).first() for t in tags]) + for tag_id in tags_ids: + try: + tag = Tag.objects.get(id=tag_id) + document.tags.add(tag) + logging.getLogger(__name__).info("Detected tag: {}".format(tag.name)) + except Tag.DoesNotExist: + logging.getLogger(__name__).warning("Detected tag with id {} does not exist anymore! Did you delete it?".format(tag_id)) document.save(update_fields=update_fields) diff --git a/src/documents/management/commands/document_retagger.py b/src/documents/management/commands/document_retagger.py index 5366cd193..6a1e44b4c 100755 --- a/src/documents/management/commands/document_retagger.py +++ b/src/documents/management/commands/document_retagger.py @@ -35,6 +35,10 @@ class Command(Renderable, BaseCommand): "-i", "--inbox-only", action="store_true" ) + parser.add_argument( + "-r", "--replace-tags", + action="store_true" + ) def handle(self, *args, **options): @@ -52,7 +56,6 @@ class Command(Renderable, BaseCommand): logging.getLogger(__name__).fatal("Cannot classify documents, classifier model file was not found.") return - for document in documents: logging.getLogger(__name__).info("Processing document {}".format(document.title)) - clf.classify_document(document, classify_type=options['type'], classify_tags=options['tags'], classify_correspondent=options['correspondent']) + clf.classify_document(document, classify_document_type=options['type'], classify_tags=options['tags'], classify_correspondent=options['correspondent'], replace_tags=options['replace_tags']) diff --git a/src/documents/signals/handlers.py b/src/documents/signals/handlers.py index 449e02200..15fa9e10d 100755 --- a/src/documents/signals/handlers.py +++ b/src/documents/signals/handlers.py @@ -9,7 +9,7 @@ from django.contrib.contenttypes.models import ContentType from django.utils import timezone from documents.classifier import DocumentClassifier -from ..models import Correspondent, Document, Tag, DocumentType +from ..models import Document, Tag def logger(message, group): @@ -23,11 +23,14 @@ def classify_document(sender, document=None, logging_group=None, **kwargs): global classifier try: classifier.reload() - classifier.classify_document(document, classify_correspondent=True, classify_tags=True, classify_type=True) + classifier.classify_document(document, classify_correspondent=True, classify_tags=True, classify_document_type=True) except FileNotFoundError: logging.getLogger(__name__).fatal("Cannot classify document, classifier model file was not found.") +def add_inbox_tags(sender, document=None, logging_group=None, **kwargs): + inbox_tags = Tag.objects.filter(is_inbox_tag=True) + document.tags.add(*inbox_tags) def run_pre_consume_script(sender, filename, **kwargs):