diff --git a/src/documents/classifier.py b/src/documents/classifier.py index 851a75899..6c90536b0 100755 --- a/src/documents/classifier.py +++ b/src/documents/classifier.py @@ -1,61 +1,74 @@ +import hashlib import logging import os import pickle +import re +import time from sklearn.feature_extraction.text import CountVectorizer from sklearn.neural_network import MLPClassifier -from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer +from sklearn.preprocessing import MultiLabelBinarizer from documents.models import Document, MatchingModel from paperless import settings +class IncompatibleClassifierVersionError(Exception): + pass + + +logger = logging.getLogger(__name__) + + def preprocess_content(content): - content = content.lower() - content = content.strip() - content = content.replace("\n", " ") - content = content.replace("\r", " ") - while content.find(" ") > -1: - content = content.replace(" ", " ") + content = content.lower().strip() + content = re.sub(r"\s+", " ", content) return content class DocumentClassifier(object): + FORMAT_VERSION = 5 + def __init__(self): + # mtime of the model file on disk. used to prevent reloading when nothing has changed. self.classifier_version = 0 + # hash of the training data. used to prevent re-training when the training data has not changed. + self.data_hash = None + self.data_vectorizer = None - self.tags_binarizer = None - self.correspondent_binarizer = None - self.document_type_binarizer = None - self.tags_classifier = None self.correspondent_classifier = None self.document_type_classifier = None def reload(self): if os.path.getmtime(settings.MODEL_FILE) > self.classifier_version: - logging.getLogger(__name__).info("Reloading classifier models") with open(settings.MODEL_FILE, "rb") as f: - self.data_vectorizer = pickle.load(f) - self.tags_binarizer = pickle.load(f) - self.correspondent_binarizer = pickle.load(f) - self.document_type_binarizer = pickle.load(f) + schema_version = pickle.load(f) - self.tags_classifier = pickle.load(f) - self.correspondent_classifier = pickle.load(f) - self.document_type_classifier = pickle.load(f) + if schema_version != self.FORMAT_VERSION: + raise IncompatibleClassifierVersionError("Cannor load classifier, incompatible versions.") + else: + if self.classifier_version > 0: + logger.info("Classifier updated on disk, reloading classifier models") + self.data_hash = pickle.load(f) + self.data_vectorizer = pickle.load(f) + self.tags_binarizer = pickle.load(f) + + self.tags_classifier = pickle.load(f) + self.correspondent_classifier = pickle.load(f) + self.document_type_classifier = pickle.load(f) self.classifier_version = os.path.getmtime(settings.MODEL_FILE) def save_classifier(self): with open(settings.MODEL_FILE, "wb") as f: + pickle.dump(self.FORMAT_VERSION, f) # Version + pickle.dump(self.data_hash, f) pickle.dump(self.data_vectorizer, f) pickle.dump(self.tags_binarizer, f) - pickle.dump(self.correspondent_binarizer, f) - pickle.dump(self.document_type_binarizer, f) pickle.dump(self.tags_classifier, f) pickle.dump(self.correspondent_classifier, f) @@ -68,109 +81,121 @@ class DocumentClassifier(object): labels_document_type = list() # Step 1: Extract and preprocess training data from the database. - logging.getLogger(__name__).info("Gathering data from database...") - for doc in Document.objects.exclude(tags__is_inbox_tag=True): - data.append(preprocess_content(doc.content)) + logging.getLogger(__name__).debug("Gathering data from database...") + m = hashlib.sha1() + for doc in Document.objects.order_by('pk').exclude(tags__is_inbox_tag=True): + preprocessed_content = preprocess_content(doc.content) + m.update(preprocessed_content.encode('utf-8')) + data.append(preprocessed_content) y = -1 if doc.document_type: if doc.document_type.matching_algorithm == MatchingModel.MATCH_AUTO: y = doc.document_type.pk + m.update(y.to_bytes(4, 'little', signed=True)) labels_document_type.append(y) y = -1 if doc.correspondent: if doc.correspondent.matching_algorithm == MatchingModel.MATCH_AUTO: y = doc.correspondent.pk + m.update(y.to_bytes(4, 'little', signed=True)) labels_correspondent.append(y) tags = [tag.pk for tag in doc.tags.filter( matching_algorithm=MatchingModel.MATCH_AUTO )] + m.update(bytearray(tags)) labels_tags.append(tags) if not data: raise ValueError("No training data available.") + new_data_hash = m.digest() + + if self.data_hash and new_data_hash == self.data_hash: + return False + labels_tags_unique = set([tag for tags in labels_tags for tag in tags]) - logging.getLogger(__name__).info( + + num_tags = len(labels_tags_unique) + # substract 1 since -1 (null) is also part of the classes. + num_correspondents = len(labels_correspondent) - 1 + num_document_types = len(labels_document_type) - 1 + + logging.getLogger(__name__).debug( "{} documents, {} tag(s), {} correspondent(s), " "{} document type(s).".format( len(data), - len(labels_tags_unique), - len(set(labels_correspondent)), - len(set(labels_document_type)) + num_tags, + num_correspondents, + num_document_types ) ) # Step 2: vectorize data - logging.getLogger(__name__).info("Vectorizing data...") + logging.getLogger(__name__).debug("Vectorizing data...") self.data_vectorizer = CountVectorizer( - analyzer="char", - ngram_range=(3, 5), - min_df=0.1 + analyzer="word", + ngram_range=(1,2), + min_df=0.01 ) data_vectorized = self.data_vectorizer.fit_transform(data) self.tags_binarizer = MultiLabelBinarizer() labels_tags_vectorized = self.tags_binarizer.fit_transform(labels_tags) - self.correspondent_binarizer = LabelBinarizer() - labels_correspondent_vectorized = \ - self.correspondent_binarizer.fit_transform(labels_correspondent) - - self.document_type_binarizer = LabelBinarizer() - labels_document_type_vectorized = \ - self.document_type_binarizer.fit_transform(labels_document_type) - # Step 3: train the classifiers - if len(self.tags_binarizer.classes_) > 0: - logging.getLogger(__name__).info("Training tags classifier...") - self.tags_classifier = MLPClassifier(verbose=True) + if num_tags > 0: + logging.getLogger(__name__).debug("Training tags classifier...") + self.tags_classifier = MLPClassifier(verbose=True, tol=0.01) self.tags_classifier.fit(data_vectorized, labels_tags_vectorized) else: self.tags_classifier = None - logging.getLogger(__name__).info( + logging.getLogger(__name__).debug( "There are no tags. Not training tags classifier." ) - if len(self.correspondent_binarizer.classes_) > 0: - logging.getLogger(__name__).info( + if num_correspondents > 0: + logging.getLogger(__name__).debug( "Training correspondent classifier..." ) - self.correspondent_classifier = MLPClassifier(verbose=True) + self.correspondent_classifier = MLPClassifier(verbose=True, tol=0.01) self.correspondent_classifier.fit( data_vectorized, - labels_correspondent_vectorized + labels_correspondent ) else: self.correspondent_classifier = None - logging.getLogger(__name__).info( + logging.getLogger(__name__).debug( "There are no correspondents. Not training correspondent " "classifier." ) - if len(self.document_type_binarizer.classes_) > 0: - logging.getLogger(__name__).info( + if num_document_types > 0: + logging.getLogger(__name__).debug( "Training document type classifier..." ) - self.document_type_classifier = MLPClassifier(verbose=True) + self.document_type_classifier = MLPClassifier(verbose=True, tol=0.01) self.document_type_classifier.fit( data_vectorized, - labels_document_type_vectorized + labels_document_type ) else: self.document_type_classifier = None - logging.getLogger(__name__).info( + logging.getLogger(__name__).debug( "There are no document types. Not training document type " "classifier." ) + self.data_hash = new_data_hash + + return True + def predict_correspondent(self, content): if self.correspondent_classifier: X = self.data_vectorizer.transform([preprocess_content(content)]) - y = self.correspondent_classifier.predict(X) - correspondent_id = self.correspondent_binarizer.inverse_transform(y)[0] + correspondent_id = self.correspondent_classifier.predict(X) if correspondent_id != -1: return correspondent_id else: @@ -181,8 +206,7 @@ class DocumentClassifier(object): def predict_document_type(self, content): if self.document_type_classifier: X = self.data_vectorizer.transform([preprocess_content(content)]) - y = self.document_type_classifier.predict(X) - document_type_id = self.document_type_binarizer.inverse_transform(y)[0] + document_type_id = self.document_type_classifier.predict(X) if document_type_id != -1: return document_type_id else: diff --git a/src/documents/consumer.py b/src/documents/consumer.py index 3920f2942..f61d11136 100755 --- a/src/documents/consumer.py +++ b/src/documents/consumer.py @@ -10,7 +10,7 @@ from django.db import transaction from django.utils import timezone from paperless.db import GnuPG -from .classifier import DocumentClassifier +from .classifier import DocumentClassifier, IncompatibleClassifierVersionError from .models import Document, FileInfo from .parsers import ParseError, get_parser_class from .signals import ( @@ -133,11 +133,8 @@ class Consumer: try: self.classifier.reload() classifier = self.classifier - except FileNotFoundError: - self.log("warning", "Cannot classify documents, classifier " - "model file was not found. Consider " - "running python manage.py " - "document_create_classifier.") + except (FileNotFoundError, IncompatibleClassifierVersionError) as e: + logging.getLogger(__name__).warning("Cannot classify documents: {}.".format(e)) document_consumption_finished.send( sender=self.__class__, diff --git a/src/documents/management/commands/document_create_classifier.py b/src/documents/management/commands/document_create_classifier.py index 9b8f28615..85cb3b446 100755 --- a/src/documents/management/commands/document_create_classifier.py +++ b/src/documents/management/commands/document_create_classifier.py @@ -1,7 +1,8 @@ import logging from django.core.management.base import BaseCommand -from documents.classifier import DocumentClassifier +from documents.classifier import DocumentClassifier, \ + IncompatibleClassifierVersionError from paperless import settings from ...mixins import Renderable @@ -18,12 +19,25 @@ class Command(Renderable, BaseCommand): def handle(self, *args, **options): classifier = DocumentClassifier() + try: - classifier.train() - logging.getLogger(__name__).info( - "Saving models to {}...".format(settings.MODEL_FILE) - ) - classifier.save_classifier() + # load the classifier, since we might not have to train it again. + classifier.reload() + except (FileNotFoundError, IncompatibleClassifierVersionError): + # This is what we're going to fix here. + pass + + try: + if classifier.train(): + logging.getLogger(__name__).info( + "Saving updated classifier model to {}...".format(settings.MODEL_FILE) + ) + classifier.save_classifier() + else: + logging.getLogger(__name__).debug( + "Training data unchanged." + ) + except Exception as e: logging.getLogger(__name__).error( "Classifier error: " + str(e) diff --git a/src/documents/management/commands/document_retagger.py b/src/documents/management/commands/document_retagger.py index 9238bea71..e48b8802c 100755 --- a/src/documents/management/commands/document_retagger.py +++ b/src/documents/management/commands/document_retagger.py @@ -2,7 +2,8 @@ import logging from django.core.management.base import BaseCommand -from documents.classifier import DocumentClassifier +from documents.classifier import DocumentClassifier, \ + IncompatibleClassifierVersionError from documents.models import Document from ...mixins import Renderable from ...signals.handlers import set_correspondent, set_document_type, set_tags @@ -72,10 +73,8 @@ class Command(Renderable, BaseCommand): classifier = DocumentClassifier() try: classifier.reload() - except FileNotFoundError: - logging.getLogger(__name__).warning("Cannot classify documents, " - "classifier model file was not " - "found.") + except (FileNotFoundError, IncompatibleClassifierVersionError) as e: + logging.getLogger(__name__).warning("Cannot classify documents: {}.".format(e)) classifier = None for document in documents: