From c091eba26eeaf73716ad56413bc510905cfb4da3 Mon Sep 17 00:00:00 2001 From: Jonas Winkler Date: Tue, 4 Sep 2018 14:39:55 +0200 Subject: [PATCH] Implemented the classifier model, including automatic tagging of new documents --- .gitignore | 3 + src/documents/admin.py | 23 --- src/documents/apps.py | 8 +- src/documents/classifier.py | 67 +++++++ .../commands/document_correspondents.py | 82 -------- .../commands/document_create_classifier.py | 184 ++++++++---------- .../commands/document_create_dataset.py | 89 ++++----- .../management/commands/document_retagger.py | 42 +++- src/documents/signals/handlers.py | 76 +------- src/paperless/settings.py | 5 + 10 files changed, 240 insertions(+), 339 deletions(-) create mode 100755 src/documents/classifier.py delete mode 100755 src/documents/management/commands/document_correspondents.py diff --git a/.gitignore b/.gitignore index 439d9df4b..bcfe5510a 100644 --- a/.gitignore +++ b/.gitignore @@ -83,3 +83,6 @@ scripts/nuke # Static files collected by the collectstatic command static/ + +# Classification Models +models/ diff --git a/src/documents/admin.py b/src/documents/admin.py index 65548d91f..7c79ebaa0 100755 --- a/src/documents/admin.py +++ b/src/documents/admin.py @@ -106,14 +106,6 @@ class CorrespondentAdmin(CommonAdmin): list_filter = ("matching_algorithm",) list_editable = ("match", "matching_algorithm") - def save_model(self, request, obj, form, change): - super().save_model(request, obj, form, change) - - for document in Document.objects.filter(correspondent__isnull=True).exclude(tags__is_archived_tag=True): - if obj.matches(document.content): - document.correspondent = obj - document.save(update_fields=("correspondent",)) - def get_queryset(self, request): qs = super(CorrespondentAdmin, self).get_queryset(request) qs = qs.annotate(document_count=models.Count("documents"), last_correspondence=models.Max("documents__created")) @@ -135,13 +127,6 @@ class TagAdmin(CommonAdmin): list_filter = ("colour", "matching_algorithm") list_editable = ("colour", "match", "matching_algorithm") - def save_model(self, request, obj, form, change): - super().save_model(request, obj, form, change) - - for document in Document.objects.all().exclude(tags__is_archived_tag=True): - if obj.matches(document.content): - document.tags.add(obj) - def get_queryset(self, request): qs = super(TagAdmin, self).get_queryset(request) qs = qs.annotate(document_count=models.Count("documents")) @@ -158,14 +143,6 @@ class DocumentTypeAdmin(CommonAdmin): list_filter = ("matching_algorithm",) list_editable = ("match", "matching_algorithm") - def save_model(self, request, obj, form, change): - super().save_model(request, obj, form, change) - - for document in Document.objects.filter(document_type__isnull=True).exclude(tags__is_archived_tag=True): - if obj.matches(document.content): - document.document_type = obj - document.save(update_fields=("document_type",)) - def get_queryset(self, request): qs = super(DocumentTypeAdmin, self).get_queryset(request) qs = qs.annotate(document_count=models.Count("documents")) diff --git a/src/documents/apps.py b/src/documents/apps.py index 7b2d50f31..16a50843b 100755 --- a/src/documents/apps.py +++ b/src/documents/apps.py @@ -11,9 +11,7 @@ class DocumentsConfig(AppConfig): from .signals import document_consumption_started from .signals import document_consumption_finished from .signals.handlers import ( - set_correspondent, - set_tags, - set_document_type, + classify_document, run_pre_consume_script, run_post_consume_script, cleanup_document_deletion, @@ -22,9 +20,7 @@ class DocumentsConfig(AppConfig): document_consumption_started.connect(run_pre_consume_script) - document_consumption_finished.connect(set_tags) - document_consumption_finished.connect(set_correspondent) - document_consumption_finished.connect(set_document_type) + document_consumption_finished.connect(classify_document) document_consumption_finished.connect(set_log_entry) document_consumption_finished.connect(run_post_consume_script) diff --git a/src/documents/classifier.py b/src/documents/classifier.py new file mode 100755 index 000000000..727c4af3a --- /dev/null +++ b/src/documents/classifier.py @@ -0,0 +1,67 @@ +import pickle + +from documents.models import Correspondent, DocumentType, Tag +from paperless import settings + + +def preprocess_content(content): + content = content.lower() + content = content.strip() + content = content.replace("\n", " ") + content = content.replace("\r", " ") + while content.find(" ") > -1: + content = content.replace(" ", " ") + return content + + +class DocumentClassifier(object): + + @staticmethod + def load_classifier(): + clf = DocumentClassifier() + clf.reload() + return clf + + def reload(self): + with open(settings.MODEL_FILE, "rb") as f: + self.data_vectorizer = pickle.load(f) + self.tags_binarizer = pickle.load(f) + self.correspondent_binarizer = pickle.load(f) + self.type_binarizer = pickle.load(f) + + self.tags_classifier = pickle.load(f) + self.correspondent_classifier = pickle.load(f) + self.type_classifier = pickle.load(f) + + def save_classifier(self): + with open(settings.MODEL_FILE, "wb") as f: + pickle.dump(self.data_vectorizer, f) + + pickle.dump(self.tags_binarizer, f) + pickle.dump(self.correspondent_binarizer, f) + pickle.dump(self.type_binarizer, f) + + pickle.dump(self.tags_classifier, f) + pickle.dump(self.correspondent_classifier, f) + pickle.dump(self.type_classifier, f) + + def classify_document(self, document, classify_correspondent=False, classify_type=False, classify_tags=False): + X = self.data_vectorizer.transform([preprocess_content(document.content)]) + + if classify_correspondent: + y_correspondent = self.correspondent_classifier.predict(X) + correspondent = self.correspondent_binarizer.inverse_transform(y_correspondent)[0] + print("Detected correspondent:", correspondent) + document.correspondent = Correspondent.objects.filter(name=correspondent).first() + + if classify_type: + y_type = self.type_classifier.predict(X) + type = self.type_binarizer.inverse_transform(y_type)[0] + print("Detected document type:", type) + document.type = DocumentType.objects.filter(name=type).first() + + if classify_tags: + y_tags = self.tags_classifier.predict(X) + tags = self.tags_binarizer.inverse_transform(y_tags)[0] + print("Detected tags:", tags) + document.tags.add(*[Tag.objects.filter(name=t).first() for t in tags]) diff --git a/src/documents/management/commands/document_correspondents.py b/src/documents/management/commands/document_correspondents.py deleted file mode 100755 index d3b324bb1..000000000 --- a/src/documents/management/commands/document_correspondents.py +++ /dev/null @@ -1,82 +0,0 @@ -import sys - -from django.core.management.base import BaseCommand - -from documents.models import Correspondent, Document - -from ...mixins import Renderable - - -class Command(Renderable, BaseCommand): - - help = """ - Using the current set of correspondent rules, apply said rules to all - documents in the database, effectively allowing you to back-tag all - previously indexed documents with correspondent created (or modified) - after their initial import. - """.replace(" ", "") - - TOO_MANY_CONTINUE = ( - "Detected {} potential correspondents for {}, so we've opted for {}") - TOO_MANY_SKIP = ( - "Detected {} potential correspondents for {}, so we're skipping it") - CHANGE_MESSAGE = ( - 'Document {}: "{}" was given the correspondent id {}: "{}"') - - def __init__(self, *args, **kwargs): - self.verbosity = 0 - BaseCommand.__init__(self, *args, **kwargs) - - def add_arguments(self, parser): - parser.add_argument( - "--use-first", - default=False, - action="store_true", - help="By default this command won't try to assign a correspondent " - "if more than one matches the document. Use this flag if " - "you'd rather it just pick the first one it finds." - ) - - def handle(self, *args, **options): - - self.verbosity = options["verbosity"] - - for document in Document.objects.filter(correspondent__isnull=True).exclude(tags__is_archived_tag=True): - - potential_correspondents = list( - Correspondent.match_all(document.content)) - - if not potential_correspondents: - continue - - potential_count = len(potential_correspondents) - correspondent = potential_correspondents[0] - - if potential_count > 1: - if not options["use_first"]: - print( - self.TOO_MANY_SKIP.format(potential_count, document), - file=sys.stderr - ) - continue - print( - self.TOO_MANY_CONTINUE.format( - potential_count, - document, - correspondent - ), - file=sys.stderr - ) - - document.correspondent = correspondent - document.save(update_fields=("correspondent",)) - - print( - self.CHANGE_MESSAGE.format( - document.pk, - document.title, - correspondent.pk, - correspondent.name - ), - file=sys.stderr - ) diff --git a/src/documents/management/commands/document_create_classifier.py b/src/documents/management/commands/document_create_classifier.py index 68bb746d7..0549709dd 100755 --- a/src/documents/management/commands/document_create_classifier.py +++ b/src/documents/management/commands/document_create_classifier.py @@ -1,100 +1,84 @@ -import logging -import os.path -import pickle - -from django.core.management.base import BaseCommand -from sklearn.feature_extraction.text import CountVectorizer -from sklearn.multiclass import OneVsRestClassifier -from sklearn.naive_bayes import MultinomialNB -from sklearn.preprocessing import MultiLabelBinarizer, LabelEncoder - -from documents.models import Document -from ...mixins import Renderable - - -def preprocess_content(content): - content = content.lower() - content = content.strip() - content = content.replace("\n", " ") - content = content.replace("\r", " ") - while content.find(" ") > -1: - content = content.replace(" ", " ") - return content - - -class Command(Renderable, BaseCommand): - - help = """ - There is no help. - """.replace(" ", "") - - def __init__(self, *args, **kwargs): - BaseCommand.__init__(self, *args, **kwargs) - - def handle(self, *args, **options): - data = list() - labels_tags = list() - labels_correspondent = list() - labels_type = list() - - # Step 1: Extract and preprocess training data from the database. - logging.getLogger(__name__).info("Gathering data from database...") - for doc in Document.objects.exclude(tags__is_inbox_tag=True): - data.append(preprocess_content(doc.content)) - labels_type.append(doc.document_type.name if doc.document_type is not None else "-") - labels_correspondent.append(doc.correspondent.name if doc.correspondent is not None else "-") - tags = [tag.name for tag in doc.tags.all()] - labels_tags.append(tags) - - # Step 2: vectorize data - logging.getLogger(__name__).info("Vectorizing data...") - data_vectorizer = CountVectorizer(analyzer='char', ngram_range=(1, 5), min_df=0.05) - data_vectorized = data_vectorizer.fit_transform(data) - - tags_binarizer = MultiLabelBinarizer() - labels_tags_vectorized = tags_binarizer.fit_transform(labels_tags) - - correspondent_binarizer = LabelEncoder() - labels_correspondent_vectorized = correspondent_binarizer.fit_transform(labels_correspondent) - - type_binarizer = LabelEncoder() - labels_type_vectorized = type_binarizer.fit_transform(labels_type) - - # Step 3: train the classifiers - if len(tags_binarizer.classes_) > 0: - logging.getLogger(__name__).info("Training tags classifier") - tags_classifier = OneVsRestClassifier(MultinomialNB()) - tags_classifier.fit(data_vectorized, labels_tags_vectorized) - else: - tags_classifier = None - logging.getLogger(__name__).info("There are no tags. Not training tags classifier.") - - if len(correspondent_binarizer.classes_) > 0: - logging.getLogger(__name__).info("Training correspondent classifier") - correspondent_classifier = MultinomialNB() - correspondent_classifier.fit(data_vectorized, labels_correspondent_vectorized) - else: - correspondent_classifier = None - logging.getLogger(__name__).info("There are no correspondents. Not training correspondent classifier.") - - if len(type_binarizer.classes_) > 0: - logging.getLogger(__name__).info("Training document type classifier") - type_classifier = MultinomialNB() - type_classifier.fit(data_vectorized, labels_type_vectorized) - else: - type_classifier = None - logging.getLogger(__name__).info("There are no document types. Not training document type classifier.") - - models_root = os.path.abspath(os.path.join(os.path.dirname(__name__), "..", "models", "models.pickle")) - logging.getLogger(__name__).info("Saving models to " + models_root + "...") - - with open(models_root, "wb") as f: - pickle.dump(data_vectorizer, f) - - pickle.dump(tags_binarizer, f) - pickle.dump(correspondent_binarizer, f) - pickle.dump(type_binarizer, f) - - pickle.dump(tags_classifier, f) - pickle.dump(correspondent_classifier, f) - pickle.dump(type_classifier, f) +import logging +import os.path +import pickle + +from django.core.management.base import BaseCommand +from sklearn.feature_extraction.text import CountVectorizer +from sklearn.multiclass import OneVsRestClassifier +from sklearn.naive_bayes import MultinomialNB +from sklearn.preprocessing import MultiLabelBinarizer, LabelEncoder + +from documents.classifier import preprocess_content, DocumentClassifier +from documents.models import Document +from paperless import settings +from ...mixins import Renderable + + +class Command(Renderable, BaseCommand): + + help = """ + There is no help. + """.replace(" ", "") + + def __init__(self, *args, **kwargs): + BaseCommand.__init__(self, *args, **kwargs) + + def handle(self, *args, **options): + clf = DocumentClassifier() + + data = list() + labels_tags = list() + labels_correspondent = list() + labels_type = list() + + # Step 1: Extract and preprocess training data from the database. + logging.getLogger(__name__).info("Gathering data from database...") + for doc in Document.objects.exclude(tags__is_inbox_tag=True): + data.append(preprocess_content(doc.content)) + labels_type.append(doc.document_type.name if doc.document_type is not None else "-") + labels_correspondent.append(doc.correspondent.name if doc.correspondent is not None else "-") + tags = [tag.name for tag in doc.tags.all()] + labels_tags.append(tags) + + # Step 2: vectorize data + logging.getLogger(__name__).info("Vectorizing data...") + clf.data_vectorizer = CountVectorizer(analyzer='char', ngram_range=(1, 5), min_df=0.05) + data_vectorized = clf.data_vectorizer.fit_transform(data) + + clf.tags_binarizer = MultiLabelBinarizer() + labels_tags_vectorized = clf.tags_binarizer.fit_transform(labels_tags) + + clf.correspondent_binarizer = LabelEncoder() + labels_correspondent_vectorized = clf.correspondent_binarizer.fit_transform(labels_correspondent) + + clf.type_binarizer = LabelEncoder() + labels_type_vectorized = clf.type_binarizer.fit_transform(labels_type) + + # Step 3: train the classifiers + if len(clf.tags_binarizer.classes_) > 0: + logging.getLogger(__name__).info("Training tags classifier") + clf.tags_classifier = OneVsRestClassifier(MultinomialNB()) + clf.tags_classifier.fit(data_vectorized, labels_tags_vectorized) + else: + clf.tags_classifier = None + logging.getLogger(__name__).info("There are no tags. Not training tags classifier.") + + if len(clf.correspondent_binarizer.classes_) > 0: + logging.getLogger(__name__).info("Training correspondent classifier") + clf.correspondent_classifier = MultinomialNB() + clf.correspondent_classifier.fit(data_vectorized, labels_correspondent_vectorized) + else: + clf.correspondent_classifier = None + logging.getLogger(__name__).info("There are no correspondents. Not training correspondent classifier.") + + if len(clf.type_binarizer.classes_) > 0: + logging.getLogger(__name__).info("Training document type classifier") + clf.type_classifier = MultinomialNB() + clf.type_classifier.fit(data_vectorized, labels_type_vectorized) + else: + clf.type_classifier = None + logging.getLogger(__name__).info("There are no document types. Not training document type classifier.") + + logging.getLogger(__name__).info("Saving models to " + settings.MODEL_FILE + "...") + + clf.save_classifier() \ No newline at end of file diff --git a/src/documents/management/commands/document_create_dataset.py b/src/documents/management/commands/document_create_dataset.py index 843211677..4b30eff35 100755 --- a/src/documents/management/commands/document_create_dataset.py +++ b/src/documents/management/commands/document_create_dataset.py @@ -1,49 +1,40 @@ -from django.core.management.base import BaseCommand - -from documents.models import Document -from ...mixins import Renderable - - -def preprocess_content(content): - content = content.lower() - content = content.strip() - content = content.replace("\n", " ") - content = content.replace("\r", " ") - while content.find(" ") > -1: - content = content.replace(" ", " ") - return content - - -class Command(Renderable, BaseCommand): - - help = """ - There is no help. - """.replace(" ", "") - - def __init__(self, *args, **kwargs): - BaseCommand.__init__(self, *args, **kwargs) - - def handle(self, *args, **options): - with open("dataset_tags.txt", "w") as f: - for doc in Document.objects.exclude(tags__is_inbox_tag=True): - labels = [] - for tag in doc.tags.all(): - labels.append(tag.name) - f.write(",".join(labels)) - f.write(";") - f.write(preprocess_content(doc.content)) - f.write("\n") - - with open("dataset_types.txt", "w") as f: - for doc in Document.objects.exclude(tags__is_inbox_tag=True): - f.write(doc.document_type.name if doc.document_type is not None else "None") - f.write(";") - f.write(preprocess_content(doc.content)) - f.write("\n") - - with open("dataset_correspondents.txt", "w") as f: - for doc in Document.objects.exclude(tags__is_inbox_tag=True): - f.write(doc.correspondent.name if doc.correspondent is not None else "None") - f.write(";") - f.write(preprocess_content(doc.content)) - f.write("\n") +from django.core.management.base import BaseCommand + +from documents.classifier import preprocess_content +from documents.models import Document +from ...mixins import Renderable + + +class Command(Renderable, BaseCommand): + + help = """ + There is no help. + """.replace(" ", "") + + def __init__(self, *args, **kwargs): + BaseCommand.__init__(self, *args, **kwargs) + + def handle(self, *args, **options): + with open("dataset_tags.txt", "w") as f: + for doc in Document.objects.exclude(tags__is_inbox_tag=True): + labels = [] + for tag in doc.tags.all(): + labels.append(tag.name) + f.write(",".join(labels)) + f.write(";") + f.write(preprocess_content(doc.content)) + f.write("\n") + + with open("dataset_types.txt", "w") as f: + for doc in Document.objects.exclude(tags__is_inbox_tag=True): + f.write(doc.document_type.name if doc.document_type is not None else "None") + f.write(";") + f.write(preprocess_content(doc.content)) + f.write("\n") + + with open("dataset_correspondents.txt", "w") as f: + for doc in Document.objects.exclude(tags__is_inbox_tag=True): + f.write(doc.correspondent.name if doc.correspondent is not None else "None") + f.write(";") + f.write(preprocess_content(doc.content)) + f.write("\n") diff --git a/src/documents/management/commands/document_retagger.py b/src/documents/management/commands/document_retagger.py index d3fd83962..7367f8057 100755 --- a/src/documents/management/commands/document_retagger.py +++ b/src/documents/management/commands/document_retagger.py @@ -1,5 +1,8 @@ +import logging + from django.core.management.base import BaseCommand +from documents.classifier import DocumentClassifier from documents.models import Document, Tag from ...mixins import Renderable @@ -8,25 +11,44 @@ from ...mixins import Renderable class Command(Renderable, BaseCommand): help = """ - Using the current set of tagging rules, apply said rules to all - documents in the database, effectively allowing you to back-tag all - previously indexed documents with tags created (or modified) after - their initial import. + There is no help. #TODO """.replace(" ", "") def __init__(self, *args, **kwargs): self.verbosity = 0 BaseCommand.__init__(self, *args, **kwargs) + def add_arguments(self, parser): + parser.add_argument( + "-c", "--correspondent", + action="store_true" + ) + parser.add_argument( + "-T", "--tags", + action="store_true" + ) + parser.add_argument( + "-t", "--type", + action="store_true" + ) + parser.add_argument( + "-i", "--inbox-only", + action="store_true" + ) + def handle(self, *args, **options): self.verbosity = options["verbosity"] - for document in Document.objects.all().exclude(tags__is_archived_tag=True): + if options['inbox_only']: + documents = Document.objects.filter(tags__is_inbox_tag=True).distinct() + else: + documents = Document.objects.all().exclude(tags__is_archived_tag=True).distinct() - tags = Tag.objects.exclude( - pk__in=document.tags.values_list("pk", flat=True)) + logging.getLogger(__name__).info("Loading classifier") + clf = DocumentClassifier.load_classifier() - for tag in Tag.match_all(document.content, tags): - print('Tagging {} with "{}"'.format(document, tag)) - document.tags.add(tag) + + for document in documents: + logging.getLogger(__name__).info("Processing document {}".format(document.title)) + clf.classify_document(document, classify_type=options['type'], classify_tags=options['tags'], classify_correspondent=options['correspondent']) diff --git a/src/documents/signals/handlers.py b/src/documents/signals/handlers.py index 77713333e..48c6db952 100755 --- a/src/documents/signals/handlers.py +++ b/src/documents/signals/handlers.py @@ -8,6 +8,7 @@ from django.contrib.auth.models import User from django.contrib.contenttypes.models import ContentType from django.utils import timezone +from documents.classifier import DocumentClassifier from ..models import Correspondent, Document, Tag, DocumentType @@ -15,79 +16,16 @@ def logger(message, group): logging.getLogger(__name__).debug(message, extra={"group": group}) -def set_correspondent(sender, document=None, logging_group=None, **kwargs): - - # No sense in assigning a correspondent when one is already set. - if document.correspondent: - return - - # No matching correspondents, so no need to continue - potential_correspondents = list(Correspondent.match_all(document.content)) - if not potential_correspondents: - return - - potential_count = len(potential_correspondents) - selected = potential_correspondents[0] - if potential_count > 1: - message = "Detected {} potential correspondents, so we've opted for {}" - logger( - message.format(potential_count, selected), - logging_group - ) - - logger( - 'Assigning correspondent "{}" to "{}" '.format(selected, document), - logging_group - ) - - document.correspondent = selected - document.save(update_fields=("correspondent",)) +classifier = None -def set_document_type(sender, document=None, logging_group=None, **kwargs): +def classify_document(sender, document=None, logging_group=None, **kwargs): + global classifier + if classifier is None: + classifier = DocumentClassifier.load_classifier() - # No sense in assigning a correspondent when one is already set. - if document.document_type: - return + classifier.classify_document(document, classify_correspondent=True, classify_tags=True, classify_type=True) - # No matching document types, so no need to continue - potential_document_types = list(DocumentType.match_all(document.content)) - if not potential_document_types: - return - - potential_count = len(potential_document_types) - selected = potential_document_types[0] - if potential_count > 1: - message = "Detected {} potential document types, so we've opted for {}" - logger( - message.format(potential_count, selected), - logging_group - ) - - logger( - 'Assigning document type "{}" to "{}" '.format(selected, document), - logging_group - ) - - document.document_type = selected - document.save(update_fields=("document_type",)) - - -def set_tags(sender, document=None, logging_group=None, **kwargs): - - current_tags = set(document.tags.all()) - relevant_tags = (set(Tag.match_all(document.content)) | set(Tag.objects.filter(is_inbox_tag=True))) - current_tags - - if not relevant_tags: - return - - message = 'Tagging "{}" with "{}"' - logger( - message.format(document, ", ".join([t.slug for t in relevant_tags])), - logging_group - ) - - document.tags.add(*relevant_tags) def run_pre_consume_script(sender, filename, **kwargs): diff --git a/src/paperless/settings.py b/src/paperless/settings.py index f354d5abf..c1e044f2e 100755 --- a/src/paperless/settings.py +++ b/src/paperless/settings.py @@ -187,6 +187,11 @@ STATIC_URL = os.getenv("PAPERLESS_STATIC_URL", "/static/") MEDIA_URL = os.getenv("PAPERLESS_MEDIA_URL", "/media/") +# Document classification models location +MODEL_FILE = os.getenv( + "PAPERLESS_STATICDIR", os.path.join(BASE_DIR, "..", "models", "model.pickle")) + + # Paperless-specific stuff # You shouldn't have to edit any of these values. Rather, you can set these # values in /etc/paperless.conf instead.