diff --git a/src/documents/actions.py b/src/documents/actions.py index 6e1cad45c..6cf9d1e99 100755 --- a/src/documents/actions.py +++ b/src/documents/actions.py @@ -165,8 +165,9 @@ def remove_document_type_from_selected(modeladmin, request, queryset): def run_document_classifier_on_selected(modeladmin, request, queryset): + clf = DocumentClassifier() try: - clf = DocumentClassifier.load_classifier() + clf.reload() return simple_action( modeladmin=modeladmin, request=request, @@ -201,4 +202,3 @@ remove_document_type_from_selected.short_description = \ "Remove document type from selected documents" run_document_classifier_on_selected.short_description = \ "Run document classifier on selected" - diff --git a/src/documents/classifier.py b/src/documents/classifier.py index bcfc1feb0..f71551455 100755 --- a/src/documents/classifier.py +++ b/src/documents/classifier.py @@ -2,14 +2,13 @@ import logging import os import pickle +from sklearn.feature_extraction.text import CountVectorizer from sklearn.neural_network import MLPClassifier +from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer from documents.models import Correspondent, DocumentType, Tag, Document from paperless import settings -from sklearn.feature_extraction.text import CountVectorizer -from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer - def preprocess_content(content): content = content.lower() @@ -23,26 +22,21 @@ def preprocess_content(content): class DocumentClassifier(object): - classifier_version = None + def __init__(self): + self.classifier_version = 0 - data_vectorizer = None + self.data_vectorizer = None - tags_binarizer = None - correspondent_binarizer = None - document_type_binarizer = None + self.tags_binarizer = None + self.correspondent_binarizer = None + self.document_type_binarizer = None - tags_classifier = None - correspondent_classifier = None - document_type_classifier = None - - @staticmethod - def load_classifier(): - clf = DocumentClassifier() - clf.reload() - return clf + self.tags_classifier = None + self.correspondent_classifier = None + self.document_type_classifier = None def reload(self): - if self.classifier_version is None or os.path.getmtime(settings.MODEL_FILE) > self.classifier_version: + if os.path.getmtime(settings.MODEL_FILE) > self.classifier_version: logging.getLogger(__name__).info("Reloading classifier models") with open(settings.MODEL_FILE, "rb") as f: self.data_vectorizer = pickle.load(f) @@ -77,27 +71,54 @@ class DocumentClassifier(object): logging.getLogger(__name__).info("Gathering data from database...") for doc in Document.objects.exclude(tags__is_inbox_tag=True): data.append(preprocess_content(doc.content)) - labels_document_type.append(doc.document_type.id if doc.document_type is not None and doc.document_type.automatic_classification else -1) - labels_correspondent.append(doc.correspondent.id if doc.correspondent is not None and doc.correspondent.automatic_classification else -1) - tags = [tag.id for tag in doc.tags.filter(automatic_classification=True)] + + y = -1 + if doc.document_type: + if doc.document_type.automatic_classification: + y = doc.document_type.id + labels_document_type.append(y) + + y = -1 + if doc.correspondent: + if doc.correspondent.automatic_classification: + y = doc.correspondent.id + labels_correspondent.append(y) + + tags = [tag.id for tag in doc.tags.filter( + automatic_classification=True + )] labels_tags.append(tags) labels_tags_unique = set([tag for tags in labels_tags for tag in tags]) - logging.getLogger(__name__).info("{} documents, {} tag(s), {} correspondent(s), {} document type(s).".format(len(data), len(labels_tags_unique), len(set(labels_correspondent)), len(set(labels_document_type)))) + logging.getLogger(__name__).info( + "{} documents, {} tag(s), {} correspondent(s), " + "{} document type(s).".format( + len(data), + len(labels_tags_unique), + len(set(labels_correspondent)), + len(set(labels_document_type)) + ) + ) # Step 2: vectorize data logging.getLogger(__name__).info("Vectorizing data...") - self.data_vectorizer = CountVectorizer(analyzer="char", ngram_range=(3, 5), min_df=0.1) + self.data_vectorizer = CountVectorizer( + analyzer="char", + ngram_range=(3, 5), + min_df=0.1 + ) data_vectorized = self.data_vectorizer.fit_transform(data) self.tags_binarizer = MultiLabelBinarizer() labels_tags_vectorized = self.tags_binarizer.fit_transform(labels_tags) self.correspondent_binarizer = LabelBinarizer() - labels_correspondent_vectorized = self.correspondent_binarizer.fit_transform(labels_correspondent) + labels_correspondent_vectorized = \ + self.correspondent_binarizer.fit_transform(labels_correspondent) self.document_type_binarizer = LabelBinarizer() - labels_document_type_vectorized = self.document_type_binarizer.fit_transform(labels_document_type) + labels_document_type_vectorized = \ + self.document_type_binarizer.fit_transform(labels_document_type) # Step 3: train the classifiers if len(self.tags_binarizer.classes_) > 0: @@ -106,62 +127,114 @@ class DocumentClassifier(object): self.tags_classifier.fit(data_vectorized, labels_tags_vectorized) else: self.tags_classifier = None - logging.getLogger(__name__).info("There are no tags. Not training tags classifier.") + logging.getLogger(__name__).info( + "There are no tags. Not training tags classifier." + ) if len(self.correspondent_binarizer.classes_) > 0: - logging.getLogger(__name__).info("Training correspondent classifier...") + logging.getLogger(__name__).info( + "Training correspondent classifier..." + ) self.correspondent_classifier = MLPClassifier(verbose=True) - self.correspondent_classifier.fit(data_vectorized, labels_correspondent_vectorized) + self.correspondent_classifier.fit( + data_vectorized, + labels_correspondent_vectorized + ) else: self.correspondent_classifier = None - logging.getLogger(__name__).info("There are no correspondents. Not training correspondent classifier.") + logging.getLogger(__name__).info( + "There are no correspondents. Not training correspondent " + "classifier." + ) if len(self.document_type_binarizer.classes_) > 0: - logging.getLogger(__name__).info("Training document type classifier...") + logging.getLogger(__name__).info( + "Training document type classifier..." + ) self.document_type_classifier = MLPClassifier(verbose=True) - self.document_type_classifier.fit(data_vectorized, labels_document_type_vectorized) + self.document_type_classifier.fit( + data_vectorized, + labels_document_type_vectorized + ) else: self.document_type_classifier = None - logging.getLogger(__name__).info("There are no document types. Not training document type classifier.") + logging.getLogger(__name__).info( + "There are no document types. Not training document type " + "classifier." + ) - def classify_document(self, document, classify_correspondent=False, classify_document_type=False, classify_tags=False, replace_tags=False): - X = self.data_vectorizer.transform([preprocess_content(document.content)]) + def classify_document( + self, document, classify_correspondent=False, + classify_document_type=False, classify_tags=False, + replace_tags=False): - update_fields = () + X = self.data_vectorizer.transform( + [preprocess_content(document.content)] + ) - if classify_correspondent and self.correspondent_classifier is not None: - y_correspondent = self.correspondent_classifier.predict(X) - correspondent_id = self.correspondent_binarizer.inverse_transform(y_correspondent)[0] + if classify_correspondent and self.correspondent_classifier: + self._classify_correspondent(X, document) + + if classify_document_type and self.document_type_classifier: + self._classify_document_type(X, document) + + if classify_tags and self.tags_classifier: + self._classify_tags(X, document, replace_tags) + + document.save(update_fields=("correspondent", "document_type")) + + def _classify_correspondent(self, X, document): + y = self.correspondent_classifier.predict(X) + correspondent_id = self.correspondent_binarizer.inverse_transform(y)[0] + try: + correspondent = None + if correspondent_id != -1: + correspondent = Correspondent.objects.get(id=correspondent_id) + logging.getLogger(__name__).info( + "Detected correspondent: {}".format(correspondent.name) + ) + else: + logging.getLogger(__name__).info("Detected correspondent: -") + document.correspondent = correspondent + except Correspondent.DoesNotExist: + logging.getLogger(__name__).warning( + "Detected correspondent with id {} does not exist " + "anymore! Did you delete it?".format(correspondent_id) + ) + + def _classify_document_type(self, X, document): + y = self.document_type_classifier.predict(X) + document_type_id = self.document_type_binarizer.inverse_transform(y)[0] + try: + document_type = None + if document_type_id != -1: + document_type = DocumentType.objects.get(id=document_type_id) + logging.getLogger(__name__).info( + "Detected document type: {}".format(document_type.name) + ) + else: + logging.getLogger(__name__).info("Detected document type: -") + document.document_type = document_type + except DocumentType.DoesNotExist: + logging.getLogger(__name__).warning( + "Detected document type with id {} does not exist " + "anymore! Did you delete it?".format(document_type_id) + ) + + def _classify_tags(self, X, document, replace_tags): + y = self.tags_classifier.predict(X) + tags_ids = self.tags_binarizer.inverse_transform(y)[0] + if replace_tags: + document.tags.clear() + for tag_id in tags_ids: try: - correspondent = Correspondent.objects.get(id=correspondent_id) if correspondent_id != -1 else None - logging.getLogger(__name__).info("Detected correspondent: {}".format(correspondent.name if correspondent else "-")) - document.correspondent = correspondent - update_fields = update_fields + ("correspondent",) - except Correspondent.DoesNotExist: - logging.getLogger(__name__).warning("Detected correspondent with id {} does not exist anymore! Did you delete it?".format(correspondent_id)) - - if classify_document_type and self.document_type_classifier is not None: - y_type = self.document_type_classifier.predict(X) - type_id = self.document_type_binarizer.inverse_transform(y_type)[0] - try: - document_type = DocumentType.objects.get(id=type_id) if type_id != -1 else None - logging.getLogger(__name__).info("Detected document type: {}".format(document_type.name if document_type else "-")) - document.document_type = document_type - update_fields = update_fields + ("document_type",) - except DocumentType.DoesNotExist: - logging.getLogger(__name__).warning("Detected document type with id {} does not exist anymore! Did you delete it?".format(type_id)) - - if classify_tags and self.tags_classifier is not None: - y_tags = self.tags_classifier.predict(X) - tags_ids = self.tags_binarizer.inverse_transform(y_tags)[0] - if replace_tags: - document.tags.clear() - for tag_id in tags_ids: - try: - tag = Tag.objects.get(id=tag_id) - document.tags.add(tag) - logging.getLogger(__name__).info("Detected tag: {}".format(tag.name)) - except Tag.DoesNotExist: - logging.getLogger(__name__).warning("Detected tag with id {} does not exist anymore! Did you delete it?".format(tag_id)) - - document.save(update_fields=update_fields) + tag = Tag.objects.get(id=tag_id) + logging.getLogger(__name__).info( + "Detected tag: {}".format(tag.name) + ) + document.tags.add(tag) + except Tag.DoesNotExist: + logging.getLogger(__name__).warning( + "Detected tag with id {} does not exist anymore! Did " + "you delete it?".format(tag_id) + ) diff --git a/src/documents/management/commands/document_retagger.py b/src/documents/management/commands/document_retagger.py index 2a9f32c13..d81789629 100755 --- a/src/documents/management/commands/document_retagger.py +++ b/src/documents/management/commands/document_retagger.py @@ -54,8 +54,9 @@ class Command(Renderable, BaseCommand): documents = queryset.distinct() logging.getLogger(__name__).info("Loading classifier") + clf = DocumentClassifier() try: - clf = DocumentClassifier.load_classifier() + clf.reload() except FileNotFoundError: logging.getLogger(__name__).fatal("Cannot classify documents, " "classifier model file was not " diff --git a/src/documents/views.py b/src/documents/views.py index 05f8f742c..fe02d6e98 100755 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -20,7 +20,13 @@ from rest_framework.viewsets import ( ReadOnlyModelViewSet ) -from .filters import CorrespondentFilterSet, DocumentFilterSet, TagFilterSet, DocumentTypeFilterSet +from .filters import ( + CorrespondentFilterSet, + DocumentFilterSet, + TagFilterSet, + DocumentTypeFilterSet +) + from .forms import UploadForm from .models import Correspondent, Document, Log, Tag, DocumentType from .serialisers import ( diff --git a/src/paperless/settings.py b/src/paperless/settings.py index d39e1cf5d..03edef88e 100755 --- a/src/paperless/settings.py +++ b/src/paperless/settings.py @@ -58,7 +58,7 @@ if _allowed_hosts: ALLOWED_HOSTS = _allowed_hosts.split(",") FORCE_SCRIPT_NAME = os.getenv("PAPERLESS_FORCE_SCRIPT_NAME") - + # Application definition INSTALLED_APPS = [