tests for the classifier and fixes for edge cases with minimal data.

This commit is contained in:
Jonas Winkler
2020-11-26 14:18:10 +01:00
parent 2a4fe4dceb
commit 30acfdd3f1
2 changed files with 189 additions and 11 deletions

View File

@@ -6,7 +6,8 @@ import re
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
from sklearn.utils.multiclass import type_of_target
from documents.models import Document, MatchingModel
from paperless import settings
@@ -27,7 +28,7 @@ def preprocess_content(content):
class DocumentClassifier(object):
FORMAT_VERSION = 5
FORMAT_VERSION = 6
def __init__(self):
# mtime of the model file on disk. used to prevent reloading when
@@ -54,6 +55,8 @@ class DocumentClassifier(object):
"Cannor load classifier, incompatible versions.")
else:
if self.classifier_version > 0:
# Don't be confused by this check. It's simply here
# so that we won't log anything on initial reload.
logger.info("Classifier updated on disk, "
"reloading classifier models")
self.data_hash = pickle.load(f)
@@ -122,9 +125,14 @@ class DocumentClassifier(object):
labels_tags_unique = set([tag for tags in labels_tags for tag in tags])
num_tags = len(labels_tags_unique)
# subtract 1 since -1 (null) is also part of the classes.
num_correspondents = len(set(labels_correspondent)) - 1
num_document_types = len(set(labels_document_type)) - 1
# union with {-1} accounts for cases where all documents have
# correspondents and types assigned, so -1 isn't part of labels_x, which
# it usually is.
num_correspondents = len(set(labels_correspondent) | {-1}) - 1
num_document_types = len(set(labels_document_type) | {-1}) - 1
logging.getLogger(__name__).debug(
"{} documents, {} tag(s), {} correspondent(s), "
@@ -145,12 +153,23 @@ class DocumentClassifier(object):
)
data_vectorized = self.data_vectorizer.fit_transform(data)
self.tags_binarizer = MultiLabelBinarizer()
labels_tags_vectorized = self.tags_binarizer.fit_transform(labels_tags)
# Step 3: train the classifiers
if num_tags > 0:
logging.getLogger(__name__).debug("Training tags classifier...")
if num_tags == 1:
# Special case where only one tag has auto:
# Fallback to binary classification.
labels_tags = [label[0] if len(label) == 1 else -1
for label in labels_tags]
self.tags_binarizer = LabelBinarizer()
labels_tags_vectorized = self.tags_binarizer.fit_transform(
labels_tags).ravel()
else:
self.tags_binarizer = MultiLabelBinarizer()
labels_tags_vectorized = self.tags_binarizer.fit_transform(
labels_tags)
self.tags_classifier = MLPClassifier(tol=0.01)
self.tags_classifier.fit(data_vectorized, labels_tags_vectorized)
else:
@@ -222,6 +241,16 @@ class DocumentClassifier(object):
X = self.data_vectorizer.transform([preprocess_content(content)])
y = self.tags_classifier.predict(X)
tags_ids = self.tags_binarizer.inverse_transform(y)[0]
return tags_ids
if type_of_target(y).startswith('multilabel'):
# the usual case when there are multiple tags.
return list(tags_ids)
elif type_of_target(y) == 'binary' and tags_ids != -1:
# This is for when we have binary classification with only one
# tag and the result is to assign this tag.
return [tags_ids]
else:
# Usually binary as well with -1 as the result, but we're
# going to catch everything else here as well.
return []
else:
return []