updated the classifier. Its now much faster and does not retrain when data hasnt changed.

This commit is contained in:
Jonas Winkler 2020-11-06 14:46:06 +01:00
parent 9fa5eac9b9
commit 296c113b16
4 changed files with 109 additions and 75 deletions

View File

@ -1,48 +1,61 @@
import hashlib
import logging
import os
import pickle
import re
import time
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
from sklearn.preprocessing import MultiLabelBinarizer
from documents.models import Document, MatchingModel
from paperless import settings
class IncompatibleClassifierVersionError(Exception):
pass
logger = logging.getLogger(__name__)
def preprocess_content(content):
content = content.lower()
content = content.strip()
content = content.replace("\n", " ")
content = content.replace("\r", " ")
while content.find(" ") > -1:
content = content.replace(" ", " ")
content = content.lower().strip()
content = re.sub(r"\s+", " ", content)
return content
class DocumentClassifier(object):
FORMAT_VERSION = 5
def __init__(self):
# mtime of the model file on disk. used to prevent reloading when nothing has changed.
self.classifier_version = 0
# hash of the training data. used to prevent re-training when the training data has not changed.
self.data_hash = None
self.data_vectorizer = None
self.tags_binarizer = None
self.correspondent_binarizer = None
self.document_type_binarizer = None
self.tags_classifier = None
self.correspondent_classifier = None
self.document_type_classifier = None
def reload(self):
if os.path.getmtime(settings.MODEL_FILE) > self.classifier_version:
logging.getLogger(__name__).info("Reloading classifier models")
with open(settings.MODEL_FILE, "rb") as f:
schema_version = pickle.load(f)
if schema_version != self.FORMAT_VERSION:
raise IncompatibleClassifierVersionError("Cannor load classifier, incompatible versions.")
else:
if self.classifier_version > 0:
logger.info("Classifier updated on disk, reloading classifier models")
self.data_hash = pickle.load(f)
self.data_vectorizer = pickle.load(f)
self.tags_binarizer = pickle.load(f)
self.correspondent_binarizer = pickle.load(f)
self.document_type_binarizer = pickle.load(f)
self.tags_classifier = pickle.load(f)
self.correspondent_classifier = pickle.load(f)
@ -51,11 +64,11 @@ class DocumentClassifier(object):
def save_classifier(self):
with open(settings.MODEL_FILE, "wb") as f:
pickle.dump(self.FORMAT_VERSION, f) # Version
pickle.dump(self.data_hash, f)
pickle.dump(self.data_vectorizer, f)
pickle.dump(self.tags_binarizer, f)
pickle.dump(self.correspondent_binarizer, f)
pickle.dump(self.document_type_binarizer, f)
pickle.dump(self.tags_classifier, f)
pickle.dump(self.correspondent_classifier, f)
@ -68,109 +81,121 @@ class DocumentClassifier(object):
labels_document_type = list()
# Step 1: Extract and preprocess training data from the database.
logging.getLogger(__name__).info("Gathering data from database...")
for doc in Document.objects.exclude(tags__is_inbox_tag=True):
data.append(preprocess_content(doc.content))
logging.getLogger(__name__).debug("Gathering data from database...")
m = hashlib.sha1()
for doc in Document.objects.order_by('pk').exclude(tags__is_inbox_tag=True):
preprocessed_content = preprocess_content(doc.content)
m.update(preprocessed_content.encode('utf-8'))
data.append(preprocessed_content)
y = -1
if doc.document_type:
if doc.document_type.matching_algorithm == MatchingModel.MATCH_AUTO:
y = doc.document_type.pk
m.update(y.to_bytes(4, 'little', signed=True))
labels_document_type.append(y)
y = -1
if doc.correspondent:
if doc.correspondent.matching_algorithm == MatchingModel.MATCH_AUTO:
y = doc.correspondent.pk
m.update(y.to_bytes(4, 'little', signed=True))
labels_correspondent.append(y)
tags = [tag.pk for tag in doc.tags.filter(
matching_algorithm=MatchingModel.MATCH_AUTO
)]
m.update(bytearray(tags))
labels_tags.append(tags)
if not data:
raise ValueError("No training data available.")
new_data_hash = m.digest()
if self.data_hash and new_data_hash == self.data_hash:
return False
labels_tags_unique = set([tag for tags in labels_tags for tag in tags])
logging.getLogger(__name__).info(
num_tags = len(labels_tags_unique)
# substract 1 since -1 (null) is also part of the classes.
num_correspondents = len(labels_correspondent) - 1
num_document_types = len(labels_document_type) - 1
logging.getLogger(__name__).debug(
"{} documents, {} tag(s), {} correspondent(s), "
"{} document type(s).".format(
len(data),
len(labels_tags_unique),
len(set(labels_correspondent)),
len(set(labels_document_type))
num_tags,
num_correspondents,
num_document_types
)
)
# Step 2: vectorize data
logging.getLogger(__name__).info("Vectorizing data...")
logging.getLogger(__name__).debug("Vectorizing data...")
self.data_vectorizer = CountVectorizer(
analyzer="char",
ngram_range=(3, 5),
min_df=0.1
analyzer="word",
ngram_range=(1,2),
min_df=0.01
)
data_vectorized = self.data_vectorizer.fit_transform(data)
self.tags_binarizer = MultiLabelBinarizer()
labels_tags_vectorized = self.tags_binarizer.fit_transform(labels_tags)
self.correspondent_binarizer = LabelBinarizer()
labels_correspondent_vectorized = \
self.correspondent_binarizer.fit_transform(labels_correspondent)
self.document_type_binarizer = LabelBinarizer()
labels_document_type_vectorized = \
self.document_type_binarizer.fit_transform(labels_document_type)
# Step 3: train the classifiers
if len(self.tags_binarizer.classes_) > 0:
logging.getLogger(__name__).info("Training tags classifier...")
self.tags_classifier = MLPClassifier(verbose=True)
if num_tags > 0:
logging.getLogger(__name__).debug("Training tags classifier...")
self.tags_classifier = MLPClassifier(verbose=True, tol=0.01)
self.tags_classifier.fit(data_vectorized, labels_tags_vectorized)
else:
self.tags_classifier = None
logging.getLogger(__name__).info(
logging.getLogger(__name__).debug(
"There are no tags. Not training tags classifier."
)
if len(self.correspondent_binarizer.classes_) > 0:
logging.getLogger(__name__).info(
if num_correspondents > 0:
logging.getLogger(__name__).debug(
"Training correspondent classifier..."
)
self.correspondent_classifier = MLPClassifier(verbose=True)
self.correspondent_classifier = MLPClassifier(verbose=True, tol=0.01)
self.correspondent_classifier.fit(
data_vectorized,
labels_correspondent_vectorized
labels_correspondent
)
else:
self.correspondent_classifier = None
logging.getLogger(__name__).info(
logging.getLogger(__name__).debug(
"There are no correspondents. Not training correspondent "
"classifier."
)
if len(self.document_type_binarizer.classes_) > 0:
logging.getLogger(__name__).info(
if num_document_types > 0:
logging.getLogger(__name__).debug(
"Training document type classifier..."
)
self.document_type_classifier = MLPClassifier(verbose=True)
self.document_type_classifier = MLPClassifier(verbose=True, tol=0.01)
self.document_type_classifier.fit(
data_vectorized,
labels_document_type_vectorized
labels_document_type
)
else:
self.document_type_classifier = None
logging.getLogger(__name__).info(
logging.getLogger(__name__).debug(
"There are no document types. Not training document type "
"classifier."
)
self.data_hash = new_data_hash
return True
def predict_correspondent(self, content):
if self.correspondent_classifier:
X = self.data_vectorizer.transform([preprocess_content(content)])
y = self.correspondent_classifier.predict(X)
correspondent_id = self.correspondent_binarizer.inverse_transform(y)[0]
correspondent_id = self.correspondent_classifier.predict(X)
if correspondent_id != -1:
return correspondent_id
else:
@ -181,8 +206,7 @@ class DocumentClassifier(object):
def predict_document_type(self, content):
if self.document_type_classifier:
X = self.data_vectorizer.transform([preprocess_content(content)])
y = self.document_type_classifier.predict(X)
document_type_id = self.document_type_binarizer.inverse_transform(y)[0]
document_type_id = self.document_type_classifier.predict(X)
if document_type_id != -1:
return document_type_id
else:

View File

@ -10,7 +10,7 @@ from django.db import transaction
from django.utils import timezone
from paperless.db import GnuPG
from .classifier import DocumentClassifier
from .classifier import DocumentClassifier, IncompatibleClassifierVersionError
from .models import Document, FileInfo
from .parsers import ParseError, get_parser_class
from .signals import (
@ -133,11 +133,8 @@ class Consumer:
try:
self.classifier.reload()
classifier = self.classifier
except FileNotFoundError:
self.log("warning", "Cannot classify documents, classifier "
"model file was not found. Consider "
"running python manage.py "
"document_create_classifier.")
except (FileNotFoundError, IncompatibleClassifierVersionError) as e:
logging.getLogger(__name__).warning("Cannot classify documents: {}.".format(e))
document_consumption_finished.send(
sender=self.__class__,

View File

@ -1,7 +1,8 @@
import logging
from django.core.management.base import BaseCommand
from documents.classifier import DocumentClassifier
from documents.classifier import DocumentClassifier, \
IncompatibleClassifierVersionError
from paperless import settings
from ...mixins import Renderable
@ -18,12 +19,25 @@ class Command(Renderable, BaseCommand):
def handle(self, *args, **options):
classifier = DocumentClassifier()
try:
classifier.train()
# load the classifier, since we might not have to train it again.
classifier.reload()
except (FileNotFoundError, IncompatibleClassifierVersionError):
# This is what we're going to fix here.
pass
try:
if classifier.train():
logging.getLogger(__name__).info(
"Saving models to {}...".format(settings.MODEL_FILE)
"Saving updated classifier model to {}...".format(settings.MODEL_FILE)
)
classifier.save_classifier()
else:
logging.getLogger(__name__).debug(
"Training data unchanged."
)
except Exception as e:
logging.getLogger(__name__).error(
"Classifier error: " + str(e)

View File

@ -2,7 +2,8 @@ import logging
from django.core.management.base import BaseCommand
from documents.classifier import DocumentClassifier
from documents.classifier import DocumentClassifier, \
IncompatibleClassifierVersionError
from documents.models import Document
from ...mixins import Renderable
from ...signals.handlers import set_correspondent, set_document_type, set_tags
@ -72,10 +73,8 @@ class Command(Renderable, BaseCommand):
classifier = DocumentClassifier()
try:
classifier.reload()
except FileNotFoundError:
logging.getLogger(__name__).warning("Cannot classify documents, "
"classifier model file was not "
"found.")
except (FileNotFoundError, IncompatibleClassifierVersionError) as e:
logging.getLogger(__name__).warning("Cannot classify documents: {}.".format(e))
classifier = None
for document in documents: