updated the classifier. Its now much faster and does not retrain when data hasnt changed.

This commit is contained in:
Jonas Winkler 2020-11-06 14:46:06 +01:00
parent 9fa5eac9b9
commit 296c113b16
4 changed files with 109 additions and 75 deletions

View File

@ -1,61 +1,74 @@
import hashlib
import logging import logging
import os import os
import pickle import pickle
import re
import time
from sklearn.feature_extraction.text import CountVectorizer from sklearn.feature_extraction.text import CountVectorizer
from sklearn.neural_network import MLPClassifier from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer from sklearn.preprocessing import MultiLabelBinarizer
from documents.models import Document, MatchingModel from documents.models import Document, MatchingModel
from paperless import settings from paperless import settings
class IncompatibleClassifierVersionError(Exception):
pass
logger = logging.getLogger(__name__)
def preprocess_content(content): def preprocess_content(content):
content = content.lower() content = content.lower().strip()
content = content.strip() content = re.sub(r"\s+", " ", content)
content = content.replace("\n", " ")
content = content.replace("\r", " ")
while content.find(" ") > -1:
content = content.replace(" ", " ")
return content return content
class DocumentClassifier(object): class DocumentClassifier(object):
FORMAT_VERSION = 5
def __init__(self): def __init__(self):
# mtime of the model file on disk. used to prevent reloading when nothing has changed.
self.classifier_version = 0 self.classifier_version = 0
# hash of the training data. used to prevent re-training when the training data has not changed.
self.data_hash = None
self.data_vectorizer = None self.data_vectorizer = None
self.tags_binarizer = None self.tags_binarizer = None
self.correspondent_binarizer = None
self.document_type_binarizer = None
self.tags_classifier = None self.tags_classifier = None
self.correspondent_classifier = None self.correspondent_classifier = None
self.document_type_classifier = None self.document_type_classifier = None
def reload(self): def reload(self):
if os.path.getmtime(settings.MODEL_FILE) > self.classifier_version: if os.path.getmtime(settings.MODEL_FILE) > self.classifier_version:
logging.getLogger(__name__).info("Reloading classifier models")
with open(settings.MODEL_FILE, "rb") as f: with open(settings.MODEL_FILE, "rb") as f:
self.data_vectorizer = pickle.load(f) schema_version = pickle.load(f)
self.tags_binarizer = pickle.load(f)
self.correspondent_binarizer = pickle.load(f)
self.document_type_binarizer = pickle.load(f)
self.tags_classifier = pickle.load(f) if schema_version != self.FORMAT_VERSION:
self.correspondent_classifier = pickle.load(f) raise IncompatibleClassifierVersionError("Cannor load classifier, incompatible versions.")
self.document_type_classifier = pickle.load(f) else:
if self.classifier_version > 0:
logger.info("Classifier updated on disk, reloading classifier models")
self.data_hash = pickle.load(f)
self.data_vectorizer = pickle.load(f)
self.tags_binarizer = pickle.load(f)
self.tags_classifier = pickle.load(f)
self.correspondent_classifier = pickle.load(f)
self.document_type_classifier = pickle.load(f)
self.classifier_version = os.path.getmtime(settings.MODEL_FILE) self.classifier_version = os.path.getmtime(settings.MODEL_FILE)
def save_classifier(self): def save_classifier(self):
with open(settings.MODEL_FILE, "wb") as f: with open(settings.MODEL_FILE, "wb") as f:
pickle.dump(self.FORMAT_VERSION, f) # Version
pickle.dump(self.data_hash, f)
pickle.dump(self.data_vectorizer, f) pickle.dump(self.data_vectorizer, f)
pickle.dump(self.tags_binarizer, f) pickle.dump(self.tags_binarizer, f)
pickle.dump(self.correspondent_binarizer, f)
pickle.dump(self.document_type_binarizer, f)
pickle.dump(self.tags_classifier, f) pickle.dump(self.tags_classifier, f)
pickle.dump(self.correspondent_classifier, f) pickle.dump(self.correspondent_classifier, f)
@ -68,109 +81,121 @@ class DocumentClassifier(object):
labels_document_type = list() labels_document_type = list()
# Step 1: Extract and preprocess training data from the database. # Step 1: Extract and preprocess training data from the database.
logging.getLogger(__name__).info("Gathering data from database...") logging.getLogger(__name__).debug("Gathering data from database...")
for doc in Document.objects.exclude(tags__is_inbox_tag=True): m = hashlib.sha1()
data.append(preprocess_content(doc.content)) for doc in Document.objects.order_by('pk').exclude(tags__is_inbox_tag=True):
preprocessed_content = preprocess_content(doc.content)
m.update(preprocessed_content.encode('utf-8'))
data.append(preprocessed_content)
y = -1 y = -1
if doc.document_type: if doc.document_type:
if doc.document_type.matching_algorithm == MatchingModel.MATCH_AUTO: if doc.document_type.matching_algorithm == MatchingModel.MATCH_AUTO:
y = doc.document_type.pk y = doc.document_type.pk
m.update(y.to_bytes(4, 'little', signed=True))
labels_document_type.append(y) labels_document_type.append(y)
y = -1 y = -1
if doc.correspondent: if doc.correspondent:
if doc.correspondent.matching_algorithm == MatchingModel.MATCH_AUTO: if doc.correspondent.matching_algorithm == MatchingModel.MATCH_AUTO:
y = doc.correspondent.pk y = doc.correspondent.pk
m.update(y.to_bytes(4, 'little', signed=True))
labels_correspondent.append(y) labels_correspondent.append(y)
tags = [tag.pk for tag in doc.tags.filter( tags = [tag.pk for tag in doc.tags.filter(
matching_algorithm=MatchingModel.MATCH_AUTO matching_algorithm=MatchingModel.MATCH_AUTO
)] )]
m.update(bytearray(tags))
labels_tags.append(tags) labels_tags.append(tags)
if not data: if not data:
raise ValueError("No training data available.") raise ValueError("No training data available.")
new_data_hash = m.digest()
if self.data_hash and new_data_hash == self.data_hash:
return False
labels_tags_unique = set([tag for tags in labels_tags for tag in tags]) labels_tags_unique = set([tag for tags in labels_tags for tag in tags])
logging.getLogger(__name__).info(
num_tags = len(labels_tags_unique)
# substract 1 since -1 (null) is also part of the classes.
num_correspondents = len(labels_correspondent) - 1
num_document_types = len(labels_document_type) - 1
logging.getLogger(__name__).debug(
"{} documents, {} tag(s), {} correspondent(s), " "{} documents, {} tag(s), {} correspondent(s), "
"{} document type(s).".format( "{} document type(s).".format(
len(data), len(data),
len(labels_tags_unique), num_tags,
len(set(labels_correspondent)), num_correspondents,
len(set(labels_document_type)) num_document_types
) )
) )
# Step 2: vectorize data # Step 2: vectorize data
logging.getLogger(__name__).info("Vectorizing data...") logging.getLogger(__name__).debug("Vectorizing data...")
self.data_vectorizer = CountVectorizer( self.data_vectorizer = CountVectorizer(
analyzer="char", analyzer="word",
ngram_range=(3, 5), ngram_range=(1,2),
min_df=0.1 min_df=0.01
) )
data_vectorized = self.data_vectorizer.fit_transform(data) data_vectorized = self.data_vectorizer.fit_transform(data)
self.tags_binarizer = MultiLabelBinarizer() self.tags_binarizer = MultiLabelBinarizer()
labels_tags_vectorized = self.tags_binarizer.fit_transform(labels_tags) labels_tags_vectorized = self.tags_binarizer.fit_transform(labels_tags)
self.correspondent_binarizer = LabelBinarizer()
labels_correspondent_vectorized = \
self.correspondent_binarizer.fit_transform(labels_correspondent)
self.document_type_binarizer = LabelBinarizer()
labels_document_type_vectorized = \
self.document_type_binarizer.fit_transform(labels_document_type)
# Step 3: train the classifiers # Step 3: train the classifiers
if len(self.tags_binarizer.classes_) > 0: if num_tags > 0:
logging.getLogger(__name__).info("Training tags classifier...") logging.getLogger(__name__).debug("Training tags classifier...")
self.tags_classifier = MLPClassifier(verbose=True) self.tags_classifier = MLPClassifier(verbose=True, tol=0.01)
self.tags_classifier.fit(data_vectorized, labels_tags_vectorized) self.tags_classifier.fit(data_vectorized, labels_tags_vectorized)
else: else:
self.tags_classifier = None self.tags_classifier = None
logging.getLogger(__name__).info( logging.getLogger(__name__).debug(
"There are no tags. Not training tags classifier." "There are no tags. Not training tags classifier."
) )
if len(self.correspondent_binarizer.classes_) > 0: if num_correspondents > 0:
logging.getLogger(__name__).info( logging.getLogger(__name__).debug(
"Training correspondent classifier..." "Training correspondent classifier..."
) )
self.correspondent_classifier = MLPClassifier(verbose=True) self.correspondent_classifier = MLPClassifier(verbose=True, tol=0.01)
self.correspondent_classifier.fit( self.correspondent_classifier.fit(
data_vectorized, data_vectorized,
labels_correspondent_vectorized labels_correspondent
) )
else: else:
self.correspondent_classifier = None self.correspondent_classifier = None
logging.getLogger(__name__).info( logging.getLogger(__name__).debug(
"There are no correspondents. Not training correspondent " "There are no correspondents. Not training correspondent "
"classifier." "classifier."
) )
if len(self.document_type_binarizer.classes_) > 0: if num_document_types > 0:
logging.getLogger(__name__).info( logging.getLogger(__name__).debug(
"Training document type classifier..." "Training document type classifier..."
) )
self.document_type_classifier = MLPClassifier(verbose=True) self.document_type_classifier = MLPClassifier(verbose=True, tol=0.01)
self.document_type_classifier.fit( self.document_type_classifier.fit(
data_vectorized, data_vectorized,
labels_document_type_vectorized labels_document_type
) )
else: else:
self.document_type_classifier = None self.document_type_classifier = None
logging.getLogger(__name__).info( logging.getLogger(__name__).debug(
"There are no document types. Not training document type " "There are no document types. Not training document type "
"classifier." "classifier."
) )
self.data_hash = new_data_hash
return True
def predict_correspondent(self, content): def predict_correspondent(self, content):
if self.correspondent_classifier: if self.correspondent_classifier:
X = self.data_vectorizer.transform([preprocess_content(content)]) X = self.data_vectorizer.transform([preprocess_content(content)])
y = self.correspondent_classifier.predict(X) correspondent_id = self.correspondent_classifier.predict(X)
correspondent_id = self.correspondent_binarizer.inverse_transform(y)[0]
if correspondent_id != -1: if correspondent_id != -1:
return correspondent_id return correspondent_id
else: else:
@ -181,8 +206,7 @@ class DocumentClassifier(object):
def predict_document_type(self, content): def predict_document_type(self, content):
if self.document_type_classifier: if self.document_type_classifier:
X = self.data_vectorizer.transform([preprocess_content(content)]) X = self.data_vectorizer.transform([preprocess_content(content)])
y = self.document_type_classifier.predict(X) document_type_id = self.document_type_classifier.predict(X)
document_type_id = self.document_type_binarizer.inverse_transform(y)[0]
if document_type_id != -1: if document_type_id != -1:
return document_type_id return document_type_id
else: else:

View File

@ -10,7 +10,7 @@ from django.db import transaction
from django.utils import timezone from django.utils import timezone
from paperless.db import GnuPG from paperless.db import GnuPG
from .classifier import DocumentClassifier from .classifier import DocumentClassifier, IncompatibleClassifierVersionError
from .models import Document, FileInfo from .models import Document, FileInfo
from .parsers import ParseError, get_parser_class from .parsers import ParseError, get_parser_class
from .signals import ( from .signals import (
@ -133,11 +133,8 @@ class Consumer:
try: try:
self.classifier.reload() self.classifier.reload()
classifier = self.classifier classifier = self.classifier
except FileNotFoundError: except (FileNotFoundError, IncompatibleClassifierVersionError) as e:
self.log("warning", "Cannot classify documents, classifier " logging.getLogger(__name__).warning("Cannot classify documents: {}.".format(e))
"model file was not found. Consider "
"running python manage.py "
"document_create_classifier.")
document_consumption_finished.send( document_consumption_finished.send(
sender=self.__class__, sender=self.__class__,

View File

@ -1,7 +1,8 @@
import logging import logging
from django.core.management.base import BaseCommand from django.core.management.base import BaseCommand
from documents.classifier import DocumentClassifier from documents.classifier import DocumentClassifier, \
IncompatibleClassifierVersionError
from paperless import settings from paperless import settings
from ...mixins import Renderable from ...mixins import Renderable
@ -18,12 +19,25 @@ class Command(Renderable, BaseCommand):
def handle(self, *args, **options): def handle(self, *args, **options):
classifier = DocumentClassifier() classifier = DocumentClassifier()
try: try:
classifier.train() # load the classifier, since we might not have to train it again.
logging.getLogger(__name__).info( classifier.reload()
"Saving models to {}...".format(settings.MODEL_FILE) except (FileNotFoundError, IncompatibleClassifierVersionError):
) # This is what we're going to fix here.
classifier.save_classifier() pass
try:
if classifier.train():
logging.getLogger(__name__).info(
"Saving updated classifier model to {}...".format(settings.MODEL_FILE)
)
classifier.save_classifier()
else:
logging.getLogger(__name__).debug(
"Training data unchanged."
)
except Exception as e: except Exception as e:
logging.getLogger(__name__).error( logging.getLogger(__name__).error(
"Classifier error: " + str(e) "Classifier error: " + str(e)

View File

@ -2,7 +2,8 @@ import logging
from django.core.management.base import BaseCommand from django.core.management.base import BaseCommand
from documents.classifier import DocumentClassifier from documents.classifier import DocumentClassifier, \
IncompatibleClassifierVersionError
from documents.models import Document from documents.models import Document
from ...mixins import Renderable from ...mixins import Renderable
from ...signals.handlers import set_correspondent, set_document_type, set_tags from ...signals.handlers import set_correspondent, set_document_type, set_tags
@ -72,10 +73,8 @@ class Command(Renderable, BaseCommand):
classifier = DocumentClassifier() classifier = DocumentClassifier()
try: try:
classifier.reload() classifier.reload()
except FileNotFoundError: except (FileNotFoundError, IncompatibleClassifierVersionError) as e:
logging.getLogger(__name__).warning("Cannot classify documents, " logging.getLogger(__name__).warning("Cannot classify documents: {}.".format(e))
"classifier model file was not "
"found.")
classifier = None classifier = None
for document in documents: for document in documents: