From 82bc0e3368e0396a6844536ef812beb361def475 Mon Sep 17 00:00:00 2001 From: Jonas Winkler Date: Wed, 5 Sep 2018 12:43:11 +0200 Subject: [PATCH] Fixed a few things --- src/documents/actions.py | 26 +++++++ src/documents/admin.py | 5 +- src/documents/classifier.py | 72 +++++++++++++++++-- .../commands/document_create_classifier.py | 67 +---------------- 4 files changed, 99 insertions(+), 71 deletions(-) diff --git a/src/documents/actions.py b/src/documents/actions.py index 144d3df03..1ee05cecb 100755 --- a/src/documents/actions.py +++ b/src/documents/actions.py @@ -4,6 +4,7 @@ from django.contrib.admin.utils import model_ngettext from django.core.exceptions import PermissionDenied from django.template.response import TemplateResponse +from documents.classifier import DocumentClassifier from documents.models import Tag, Correspondent, DocumentType @@ -223,3 +224,28 @@ def remove_document_type_from_selected(modeladmin, request, queryset): remove_document_type_from_selected.short_description = "Remove document type from selected documents" + + +def run_document_classifier_on_selected(modeladmin, request, queryset): + if not modeladmin.has_change_permission(request): + raise PermissionDenied + + try: + clf = DocumentClassifier.load_classifier() + except FileNotFoundError: + modeladmin.message_user(request, "Classifier model file not found.", messages.ERROR) + return None + + n = queryset.count() + if n: + for obj in queryset: + clf.classify_document(obj, classify_correspondent=True, classify_tags=True, classify_type=True, replace_tags=True) + modeladmin.log_change(request, obj, str(obj)) + modeladmin.message_user(request, "Successfully applied tags, correspondent and type to %(count)d %(items)s." % { + "count": n, "items": model_ngettext(modeladmin.opts, n) + }, messages.SUCCESS) + + return None + + +run_document_classifier_on_selected.short_description = "Run document classifier on selected" diff --git a/src/documents/admin.py b/src/documents/admin.py index 277158459..21716a7d2 100755 --- a/src/documents/admin.py +++ b/src/documents/admin.py @@ -13,7 +13,8 @@ from django.utils.safestring import mark_safe from django.db import models from documents.actions import add_tag_to_selected, remove_tag_from_selected, set_correspondent_on_selected, \ - remove_correspondent_from_selected, set_document_type_on_selected, remove_document_type_from_selected + remove_correspondent_from_selected, set_document_type_on_selected, remove_document_type_from_selected, \ + run_document_classifier_on_selected from .models import Correspondent, Tag, Document, Log, DocumentType @@ -165,7 +166,7 @@ class DocumentAdmin(CommonAdmin): ordering = ["-created", "correspondent"] - actions = [add_tag_to_selected, remove_tag_from_selected, set_correspondent_on_selected, remove_correspondent_from_selected, set_document_type_on_selected, remove_document_type_from_selected] + actions = [add_tag_to_selected, remove_tag_from_selected, set_correspondent_on_selected, remove_correspondent_from_selected, set_document_type_on_selected, remove_document_type_from_selected, run_document_classifier_on_selected] date_hierarchy = 'created' diff --git a/src/documents/classifier.py b/src/documents/classifier.py index d925a73a9..6a73ce84f 100755 --- a/src/documents/classifier.py +++ b/src/documents/classifier.py @@ -1,9 +1,15 @@ +import logging import os import pickle -from documents.models import Correspondent, DocumentType, Tag +from documents.models import Correspondent, DocumentType, Tag, Document from paperless import settings +from sklearn.feature_extraction.text import CountVectorizer +from sklearn.multiclass import OneVsRestClassifier +from sklearn.naive_bayes import MultinomialNB +from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer + def preprocess_content(content): content = content.lower() @@ -61,29 +67,85 @@ class DocumentClassifier(object): pickle.dump(self.correspondent_classifier, f) pickle.dump(self.type_classifier, f) - def classify_document(self, document, classify_correspondent=False, classify_type=False, classify_tags=False): + def train(self): + data = list() + labels_tags = list() + labels_correspondent = list() + labels_type = list() + + # Step 1: Extract and preprocess training data from the database. + logging.getLogger(__name__).info("Gathering data from database...") + for doc in Document.objects.exclude(tags__is_inbox_tag=True): + data.append(preprocess_content(doc.content)) + labels_type.append(doc.document_type.name if doc.document_type is not None else "-") + labels_correspondent.append(doc.correspondent.name if doc.correspondent is not None else "-") + tags = [tag.name for tag in doc.tags.all()] + labels_tags.append(tags) + + # Step 2: vectorize data + logging.getLogger(__name__).info("Vectorizing data...") + self.data_vectorizer = CountVectorizer(analyzer='char', ngram_range=(2, 6), min_df=0.1) + data_vectorized = self.data_vectorizer.fit_transform(data) + + self.tags_binarizer = MultiLabelBinarizer() + labels_tags_vectorized = self.tags_binarizer.fit_transform(labels_tags) + + self.correspondent_binarizer = LabelBinarizer() + labels_correspondent_vectorized = self.correspondent_binarizer.fit_transform(labels_correspondent) + + self.type_binarizer = LabelBinarizer() + labels_type_vectorized = self.type_binarizer.fit_transform(labels_type) + + # Step 3: train the classifiers + if len(self.tags_binarizer.classes_) > 0: + logging.getLogger(__name__).info("Training tags classifier...") + self.tags_classifier = OneVsRestClassifier(MultinomialNB()) + self.tags_classifier.fit(data_vectorized, labels_tags_vectorized) + else: + self.tags_classifier = None + logging.getLogger(__name__).info("There are no tags. Not training tags classifier.") + + if len(self.correspondent_binarizer.classes_) > 0: + logging.getLogger(__name__).info("Training correspondent classifier...") + self.correspondent_classifier = OneVsRestClassifier(MultinomialNB()) + self.correspondent_classifier.fit(data_vectorized, labels_correspondent_vectorized) + else: + self.correspondent_classifier = None + logging.getLogger(__name__).info("There are no correspondents. Not training correspondent classifier.") + + if len(self.type_binarizer.classes_) > 0: + logging.getLogger(__name__).info("Training document type classifier...") + self.type_classifier = OneVsRestClassifier(MultinomialNB()) + self.type_classifier.fit(data_vectorized, labels_type_vectorized) + else: + self.type_classifier = None + logging.getLogger(__name__).info("There are no document types. Not training document type classifier.") + + def classify_document(self, document, classify_correspondent=False, classify_type=False, classify_tags=False, replace_tags=False): X = self.data_vectorizer.transform([preprocess_content(document.content)]) update_fields=() - if classify_correspondent: + if classify_correspondent and self.correspondent_classifier is not None: y_correspondent = self.correspondent_classifier.predict(X) correspondent = self.correspondent_binarizer.inverse_transform(y_correspondent)[0] print("Detected correspondent:", correspondent) document.correspondent = Correspondent.objects.filter(name=correspondent).first() update_fields = update_fields + ("correspondent",) - if classify_type: + if classify_type and self.type_classifier is not None: y_type = self.type_classifier.predict(X) type = self.type_binarizer.inverse_transform(y_type)[0] print("Detected document type:", type) document.document_type = DocumentType.objects.filter(name=type).first() update_fields = update_fields + ("document_type",) - if classify_tags: + if classify_tags and self.tags_classifier is not None: y_tags = self.tags_classifier.predict(X) tags = self.tags_binarizer.inverse_transform(y_tags)[0] print("Detected tags:", tags) + if replace_tags: + document.tags.clear() document.tags.add(*[Tag.objects.filter(name=t).first() for t in tags]) document.save(update_fields=update_fields) diff --git a/src/documents/management/commands/document_create_classifier.py b/src/documents/management/commands/document_create_classifier.py index 66aca0d60..bc6ea737e 100755 --- a/src/documents/management/commands/document_create_classifier.py +++ b/src/documents/management/commands/document_create_classifier.py @@ -3,13 +3,7 @@ import os.path import pickle from django.core.management.base import BaseCommand -from sklearn.feature_extraction.text import CountVectorizer -from sklearn.multiclass import OneVsRestClassifier -from sklearn.naive_bayes import MultinomialNB -from sklearn.preprocessing import MultiLabelBinarizer, LabelEncoder - -from documents.classifier import preprocess_content, DocumentClassifier -from documents.models import Document +from documents.classifier import DocumentClassifier from paperless import settings from ...mixins import Renderable @@ -26,64 +20,9 @@ class Command(Renderable, BaseCommand): def handle(self, *args, **options): clf = DocumentClassifier() - data = list() - labels_tags = list() - labels_correspondent = list() - labels_type = list() + clf.train() - # Step 1: Extract and preprocess training data from the database. - logging.getLogger(__name__).info("Gathering data from database...") - for doc in Document.objects.exclude(tags__is_inbox_tag=True): - data.append(preprocess_content(doc.content)) - labels_type.append(doc.document_type.name if doc.document_type is not None else "-") - labels_correspondent.append(doc.correspondent.name if doc.correspondent is not None else "-") - tags = [tag.name for tag in doc.tags.all()] - labels_tags.append(tags) - - # Step 2: vectorize data - logging.getLogger(__name__).info("Vectorizing data...") - clf.data_vectorizer = CountVectorizer(analyzer='char', ngram_range=(2, 6), min_df=0.1) - data_vectorized = clf.data_vectorizer.fit_transform(data) - - print(clf.data_vectorizer.vocabulary_) - - logging.getLogger(__name__).info("Shape of vectorized data: {}".format(data_vectorized.shape)) - - - clf.tags_binarizer = MultiLabelBinarizer() - labels_tags_vectorized = clf.tags_binarizer.fit_transform(labels_tags) - - clf.correspondent_binarizer = LabelEncoder() - labels_correspondent_vectorized = clf.correspondent_binarizer.fit_transform(labels_correspondent) - - clf.type_binarizer = LabelEncoder() - labels_type_vectorized = clf.type_binarizer.fit_transform(labels_type) - - # Step 3: train the classifiers - if len(clf.tags_binarizer.classes_) > 0: - logging.getLogger(__name__).info("Training tags classifier") - clf.tags_classifier = OneVsRestClassifier(MultinomialNB()) - clf.tags_classifier.fit(data_vectorized, labels_tags_vectorized) - else: - clf.tags_classifier = None - logging.getLogger(__name__).info("There are no tags. Not training tags classifier.") - - if len(clf.correspondent_binarizer.classes_) > 0: - logging.getLogger(__name__).info("Training correspondent classifier") - clf.correspondent_classifier = MultinomialNB() - clf.correspondent_classifier.fit(data_vectorized, labels_correspondent_vectorized) - else: - clf.correspondent_classifier = None - logging.getLogger(__name__).info("There are no correspondents. Not training correspondent classifier.") - - if len(clf.type_binarizer.classes_) > 0: - logging.getLogger(__name__).info("Training document type classifier") - clf.type_classifier = MultinomialNB() - clf.type_classifier.fit(data_vectorized, labels_type_vectorized) - else: - clf.type_classifier = None - logging.getLogger(__name__).info("There are no document types. Not training document type classifier.") logging.getLogger(__name__).info("Saving models to " + settings.MODEL_FILE + "...") - clf.save_classifier() \ No newline at end of file + clf.save_classifier()