mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-04-09 09:58:20 -05:00
tests for the classifier and fixes for edge cases with minimal data.
This commit is contained in:
parent
2a4fe4dceb
commit
30acfdd3f1
@ -6,7 +6,8 @@ import re
|
|||||||
|
|
||||||
from sklearn.feature_extraction.text import CountVectorizer
|
from sklearn.feature_extraction.text import CountVectorizer
|
||||||
from sklearn.neural_network import MLPClassifier
|
from sklearn.neural_network import MLPClassifier
|
||||||
from sklearn.preprocessing import MultiLabelBinarizer
|
from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
|
||||||
|
from sklearn.utils.multiclass import type_of_target
|
||||||
|
|
||||||
from documents.models import Document, MatchingModel
|
from documents.models import Document, MatchingModel
|
||||||
from paperless import settings
|
from paperless import settings
|
||||||
@ -27,7 +28,7 @@ def preprocess_content(content):
|
|||||||
|
|
||||||
class DocumentClassifier(object):
|
class DocumentClassifier(object):
|
||||||
|
|
||||||
FORMAT_VERSION = 5
|
FORMAT_VERSION = 6
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# mtime of the model file on disk. used to prevent reloading when
|
# mtime of the model file on disk. used to prevent reloading when
|
||||||
@ -54,6 +55,8 @@ class DocumentClassifier(object):
|
|||||||
"Cannor load classifier, incompatible versions.")
|
"Cannor load classifier, incompatible versions.")
|
||||||
else:
|
else:
|
||||||
if self.classifier_version > 0:
|
if self.classifier_version > 0:
|
||||||
|
# Don't be confused by this check. It's simply here
|
||||||
|
# so that we wont log anything on initial reload.
|
||||||
logger.info("Classifier updated on disk, "
|
logger.info("Classifier updated on disk, "
|
||||||
"reloading classifier models")
|
"reloading classifier models")
|
||||||
self.data_hash = pickle.load(f)
|
self.data_hash = pickle.load(f)
|
||||||
@ -122,9 +125,14 @@ class DocumentClassifier(object):
|
|||||||
labels_tags_unique = set([tag for tags in labels_tags for tag in tags])
|
labels_tags_unique = set([tag for tags in labels_tags for tag in tags])
|
||||||
|
|
||||||
num_tags = len(labels_tags_unique)
|
num_tags = len(labels_tags_unique)
|
||||||
|
|
||||||
# substract 1 since -1 (null) is also part of the classes.
|
# substract 1 since -1 (null) is also part of the classes.
|
||||||
num_correspondents = len(set(labels_correspondent)) - 1
|
|
||||||
num_document_types = len(set(labels_document_type)) - 1
|
# union with {-1} accounts for cases where all documents have
|
||||||
|
# correspondents and types assigned, so -1 isnt part of labels_x, which
|
||||||
|
# it usually is.
|
||||||
|
num_correspondents = len(set(labels_correspondent) | {-1}) - 1
|
||||||
|
num_document_types = len(set(labels_document_type) | {-1}) - 1
|
||||||
|
|
||||||
logging.getLogger(__name__).debug(
|
logging.getLogger(__name__).debug(
|
||||||
"{} documents, {} tag(s), {} correspondent(s), "
|
"{} documents, {} tag(s), {} correspondent(s), "
|
||||||
@ -145,12 +153,23 @@ class DocumentClassifier(object):
|
|||||||
)
|
)
|
||||||
data_vectorized = self.data_vectorizer.fit_transform(data)
|
data_vectorized = self.data_vectorizer.fit_transform(data)
|
||||||
|
|
||||||
self.tags_binarizer = MultiLabelBinarizer()
|
|
||||||
labels_tags_vectorized = self.tags_binarizer.fit_transform(labels_tags)
|
|
||||||
|
|
||||||
# Step 3: train the classifiers
|
# Step 3: train the classifiers
|
||||||
if num_tags > 0:
|
if num_tags > 0:
|
||||||
logging.getLogger(__name__).debug("Training tags classifier...")
|
logging.getLogger(__name__).debug("Training tags classifier...")
|
||||||
|
|
||||||
|
if num_tags == 1:
|
||||||
|
# Special case where only one tag has auto:
|
||||||
|
# Fallback to binary classification.
|
||||||
|
labels_tags = [label[0] if len(label) == 1 else -1
|
||||||
|
for label in labels_tags]
|
||||||
|
self.tags_binarizer = LabelBinarizer()
|
||||||
|
labels_tags_vectorized = self.tags_binarizer.fit_transform(
|
||||||
|
labels_tags).ravel()
|
||||||
|
else:
|
||||||
|
self.tags_binarizer = MultiLabelBinarizer()
|
||||||
|
labels_tags_vectorized = self.tags_binarizer.fit_transform(
|
||||||
|
labels_tags)
|
||||||
|
|
||||||
self.tags_classifier = MLPClassifier(tol=0.01)
|
self.tags_classifier = MLPClassifier(tol=0.01)
|
||||||
self.tags_classifier.fit(data_vectorized, labels_tags_vectorized)
|
self.tags_classifier.fit(data_vectorized, labels_tags_vectorized)
|
||||||
else:
|
else:
|
||||||
@ -222,6 +241,16 @@ class DocumentClassifier(object):
|
|||||||
X = self.data_vectorizer.transform([preprocess_content(content)])
|
X = self.data_vectorizer.transform([preprocess_content(content)])
|
||||||
y = self.tags_classifier.predict(X)
|
y = self.tags_classifier.predict(X)
|
||||||
tags_ids = self.tags_binarizer.inverse_transform(y)[0]
|
tags_ids = self.tags_binarizer.inverse_transform(y)[0]
|
||||||
return tags_ids
|
if type_of_target(y).startswith('multilabel'):
|
||||||
|
# the usual case when there are multiple tags.
|
||||||
|
return list(tags_ids)
|
||||||
|
elif type_of_target(y) == 'binary' and tags_ids != -1:
|
||||||
|
# This is for when we have binary classification with only one
|
||||||
|
# tag and the result is to assign this tag.
|
||||||
|
return [tags_ids]
|
||||||
|
else:
|
||||||
|
# Usually binary as well with -1 as the result, but we're
|
||||||
|
# going to catch everything else here as well.
|
||||||
|
return []
|
||||||
else:
|
else:
|
||||||
return []
|
return []
|
||||||
|
@ -1,8 +1,10 @@
|
|||||||
import tempfile
|
import tempfile
|
||||||
|
from time import sleep
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
from django.test import TestCase, override_settings
|
from django.test import TestCase, override_settings
|
||||||
|
|
||||||
from documents.classifier import DocumentClassifier
|
from documents.classifier import DocumentClassifier, IncompatibleClassifierVersionError
|
||||||
from documents.models import Correspondent, Document, Tag, DocumentType
|
from documents.models import Correspondent, Document, Tag, DocumentType
|
||||||
|
|
||||||
|
|
||||||
@ -15,10 +17,12 @@ class TestClassifier(TestCase):
|
|||||||
def generate_test_data(self):
|
def generate_test_data(self):
|
||||||
self.c1 = Correspondent.objects.create(name="c1", matching_algorithm=Correspondent.MATCH_AUTO)
|
self.c1 = Correspondent.objects.create(name="c1", matching_algorithm=Correspondent.MATCH_AUTO)
|
||||||
self.c2 = Correspondent.objects.create(name="c2")
|
self.c2 = Correspondent.objects.create(name="c2")
|
||||||
|
self.c3 = Correspondent.objects.create(name="c3", matching_algorithm=Correspondent.MATCH_AUTO)
|
||||||
self.t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12)
|
self.t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12)
|
||||||
self.t2 = Tag.objects.create(name="t2", matching_algorithm=Tag.MATCH_ANY, pk=34, is_inbox_tag=True)
|
self.t2 = Tag.objects.create(name="t2", matching_algorithm=Tag.MATCH_ANY, pk=34, is_inbox_tag=True)
|
||||||
self.t3 = Tag.objects.create(name="t3", matching_algorithm=Tag.MATCH_AUTO, pk=45)
|
self.t3 = Tag.objects.create(name="t3", matching_algorithm=Tag.MATCH_AUTO, pk=45)
|
||||||
self.dt = DocumentType.objects.create(name="dt", matching_algorithm=DocumentType.MATCH_AUTO)
|
self.dt = DocumentType.objects.create(name="dt", matching_algorithm=DocumentType.MATCH_AUTO)
|
||||||
|
self.dt2 = DocumentType.objects.create(name="dt2", matching_algorithm=DocumentType.MATCH_AUTO)
|
||||||
|
|
||||||
self.doc1 = Document.objects.create(title="doc1", content="this is a document from c1", correspondent=self.c1, checksum="A", document_type=self.dt)
|
self.doc1 = Document.objects.create(title="doc1", content="this is a document from c1", correspondent=self.c1, checksum="A", document_type=self.dt)
|
||||||
self.doc2 = Document.objects.create(title="doc1", content="this is another document, but from c2", correspondent=self.c2, checksum="B")
|
self.doc2 = Document.objects.create(title="doc1", content="this is another document, but from c2", correspondent=self.c2, checksum="B")
|
||||||
@ -59,8 +63,8 @@ class TestClassifier(TestCase):
|
|||||||
self.classifier.train()
|
self.classifier.train()
|
||||||
self.assertEqual(self.classifier.predict_correspondent(self.doc1.content), self.c1.pk)
|
self.assertEqual(self.classifier.predict_correspondent(self.doc1.content), self.c1.pk)
|
||||||
self.assertEqual(self.classifier.predict_correspondent(self.doc2.content), None)
|
self.assertEqual(self.classifier.predict_correspondent(self.doc2.content), None)
|
||||||
self.assertTupleEqual(self.classifier.predict_tags(self.doc1.content), (self.t1.pk,))
|
self.assertListEqual(self.classifier.predict_tags(self.doc1.content), [self.t1.pk])
|
||||||
self.assertTupleEqual(self.classifier.predict_tags(self.doc2.content), (self.t1.pk, self.t3.pk))
|
self.assertListEqual(self.classifier.predict_tags(self.doc2.content), [self.t1.pk, self.t3.pk])
|
||||||
self.assertEqual(self.classifier.predict_document_type(self.doc1.content), self.dt.pk)
|
self.assertEqual(self.classifier.predict_document_type(self.doc1.content), self.dt.pk)
|
||||||
self.assertEqual(self.classifier.predict_document_type(self.doc2.content), None)
|
self.assertEqual(self.classifier.predict_document_type(self.doc2.content), None)
|
||||||
|
|
||||||
@ -71,6 +75,42 @@ class TestClassifier(TestCase):
|
|||||||
self.assertTrue(self.classifier.train())
|
self.assertTrue(self.classifier.train())
|
||||||
self.assertFalse(self.classifier.train())
|
self.assertFalse(self.classifier.train())
|
||||||
|
|
||||||
|
def testVersionIncreased(self):
|
||||||
|
|
||||||
|
self.generate_test_data()
|
||||||
|
self.assertTrue(self.classifier.train())
|
||||||
|
self.assertFalse(self.classifier.train())
|
||||||
|
|
||||||
|
classifier2 = DocumentClassifier()
|
||||||
|
|
||||||
|
current_ver = DocumentClassifier.FORMAT_VERSION
|
||||||
|
with mock.patch("documents.classifier.DocumentClassifier.FORMAT_VERSION", current_ver+1):
|
||||||
|
# assure that we won't load old classifiers.
|
||||||
|
self.assertRaises(IncompatibleClassifierVersionError, self.classifier.reload)
|
||||||
|
|
||||||
|
self.classifier.save_classifier()
|
||||||
|
|
||||||
|
# assure that we can load the classifier after saving it.
|
||||||
|
classifier2.reload()
|
||||||
|
|
||||||
|
def testReload(self):
|
||||||
|
|
||||||
|
self.generate_test_data()
|
||||||
|
self.assertTrue(self.classifier.train())
|
||||||
|
self.classifier.save_classifier()
|
||||||
|
|
||||||
|
classifier2 = DocumentClassifier()
|
||||||
|
classifier2.reload()
|
||||||
|
v1 = classifier2.classifier_version
|
||||||
|
|
||||||
|
# change the classifier after some time.
|
||||||
|
sleep(1)
|
||||||
|
self.classifier.save_classifier()
|
||||||
|
|
||||||
|
classifier2.reload()
|
||||||
|
v2 = classifier2.classifier_version
|
||||||
|
self.assertNotEqual(v1, v2)
|
||||||
|
|
||||||
@override_settings(DATA_DIR=tempfile.mkdtemp())
|
@override_settings(DATA_DIR=tempfile.mkdtemp())
|
||||||
def testSaveClassifier(self):
|
def testSaveClassifier(self):
|
||||||
|
|
||||||
@ -83,3 +123,112 @@ class TestClassifier(TestCase):
|
|||||||
new_classifier = DocumentClassifier()
|
new_classifier = DocumentClassifier()
|
||||||
new_classifier.reload()
|
new_classifier.reload()
|
||||||
self.assertFalse(new_classifier.train())
|
self.assertFalse(new_classifier.train())
|
||||||
|
|
||||||
|
def test_one_correspondent_predict(self):
|
||||||
|
c1 = Correspondent.objects.create(name="c1", matching_algorithm=Correspondent.MATCH_AUTO)
|
||||||
|
doc1 = Document.objects.create(title="doc1", content="this is a document from c1", correspondent=c1, checksum="A")
|
||||||
|
|
||||||
|
self.classifier.train()
|
||||||
|
self.assertEqual(self.classifier.predict_correspondent(doc1.content), c1.pk)
|
||||||
|
|
||||||
|
def test_one_correspondent_predict_manydocs(self):
|
||||||
|
c1 = Correspondent.objects.create(name="c1", matching_algorithm=Correspondent.MATCH_AUTO)
|
||||||
|
doc1 = Document.objects.create(title="doc1", content="this is a document from c1", correspondent=c1, checksum="A")
|
||||||
|
doc2 = Document.objects.create(title="doc2", content="this is a document from noone", checksum="B")
|
||||||
|
|
||||||
|
self.classifier.train()
|
||||||
|
self.assertEqual(self.classifier.predict_correspondent(doc1.content), c1.pk)
|
||||||
|
self.assertIsNone(self.classifier.predict_correspondent(doc2.content))
|
||||||
|
|
||||||
|
def test_one_type_predict(self):
|
||||||
|
dt = DocumentType.objects.create(name="dt", matching_algorithm=DocumentType.MATCH_AUTO)
|
||||||
|
|
||||||
|
doc1 = Document.objects.create(title="doc1", content="this is a document from c1",
|
||||||
|
checksum="A", document_type=dt)
|
||||||
|
|
||||||
|
self.classifier.train()
|
||||||
|
self.assertEqual(self.classifier.predict_document_type(doc1.content), dt.pk)
|
||||||
|
|
||||||
|
def test_one_type_predict_manydocs(self):
|
||||||
|
dt = DocumentType.objects.create(name="dt", matching_algorithm=DocumentType.MATCH_AUTO)
|
||||||
|
|
||||||
|
doc1 = Document.objects.create(title="doc1", content="this is a document from c1",
|
||||||
|
checksum="A", document_type=dt)
|
||||||
|
|
||||||
|
doc2 = Document.objects.create(title="doc1", content="this is a document from c2",
|
||||||
|
checksum="B")
|
||||||
|
|
||||||
|
self.classifier.train()
|
||||||
|
self.assertEqual(self.classifier.predict_document_type(doc1.content), dt.pk)
|
||||||
|
self.assertIsNone(self.classifier.predict_document_type(doc2.content))
|
||||||
|
|
||||||
|
def test_one_tag_predict(self):
|
||||||
|
t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12)
|
||||||
|
|
||||||
|
doc1 = Document.objects.create(title="doc1", content="this is a document from c1", checksum="A")
|
||||||
|
|
||||||
|
doc1.tags.add(t1)
|
||||||
|
self.classifier.train()
|
||||||
|
self.assertListEqual(self.classifier.predict_tags(doc1.content), [t1.pk])
|
||||||
|
|
||||||
|
def test_one_tag_predict_unassigned(self):
|
||||||
|
t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12)
|
||||||
|
|
||||||
|
doc1 = Document.objects.create(title="doc1", content="this is a document from c1", checksum="A")
|
||||||
|
|
||||||
|
self.classifier.train()
|
||||||
|
self.assertListEqual(self.classifier.predict_tags(doc1.content), [])
|
||||||
|
|
||||||
|
def test_two_tags_predict_singledoc(self):
|
||||||
|
t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12)
|
||||||
|
t2 = Tag.objects.create(name="t2", matching_algorithm=Tag.MATCH_AUTO, pk=121)
|
||||||
|
|
||||||
|
doc4 = Document.objects.create(title="doc1", content="this is a document from c4", checksum="D")
|
||||||
|
|
||||||
|
doc4.tags.add(t1)
|
||||||
|
doc4.tags.add(t2)
|
||||||
|
self.classifier.train()
|
||||||
|
self.assertListEqual(self.classifier.predict_tags(doc4.content), [t1.pk, t2.pk])
|
||||||
|
|
||||||
|
def test_two_tags_predict(self):
|
||||||
|
t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12)
|
||||||
|
t2 = Tag.objects.create(name="t2", matching_algorithm=Tag.MATCH_AUTO, pk=121)
|
||||||
|
|
||||||
|
doc1 = Document.objects.create(title="doc1", content="this is a document from c1", checksum="A")
|
||||||
|
doc2 = Document.objects.create(title="doc1", content="this is a document from c2", checksum="B")
|
||||||
|
doc3 = Document.objects.create(title="doc1", content="this is a document from c3", checksum="C")
|
||||||
|
doc4 = Document.objects.create(title="doc1", content="this is a document from c4", checksum="D")
|
||||||
|
|
||||||
|
doc1.tags.add(t1)
|
||||||
|
doc2.tags.add(t2)
|
||||||
|
|
||||||
|
doc4.tags.add(t1)
|
||||||
|
doc4.tags.add(t2)
|
||||||
|
self.classifier.train()
|
||||||
|
self.assertListEqual(self.classifier.predict_tags(doc1.content), [t1.pk])
|
||||||
|
self.assertListEqual(self.classifier.predict_tags(doc2.content), [t2.pk])
|
||||||
|
self.assertListEqual(self.classifier.predict_tags(doc3.content), [])
|
||||||
|
self.assertListEqual(self.classifier.predict_tags(doc4.content), [t1.pk, t2.pk])
|
||||||
|
|
||||||
|
def test_one_tag_predict_multi(self):
|
||||||
|
t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12)
|
||||||
|
|
||||||
|
doc1 = Document.objects.create(title="doc1", content="this is a document from c1", checksum="A")
|
||||||
|
doc2 = Document.objects.create(title="doc2", content="this is a document from c2", checksum="B")
|
||||||
|
|
||||||
|
doc1.tags.add(t1)
|
||||||
|
doc2.tags.add(t1)
|
||||||
|
self.classifier.train()
|
||||||
|
self.assertListEqual(self.classifier.predict_tags(doc1.content), [t1.pk])
|
||||||
|
self.assertListEqual(self.classifier.predict_tags(doc2.content), [t1.pk])
|
||||||
|
|
||||||
|
def test_one_tag_predict_multi_2(self):
|
||||||
|
t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12)
|
||||||
|
|
||||||
|
doc1 = Document.objects.create(title="doc1", content="this is a document from c1", checksum="A")
|
||||||
|
doc2 = Document.objects.create(title="doc2", content="this is a document from c2", checksum="B")
|
||||||
|
|
||||||
|
doc1.tags.add(t1)
|
||||||
|
self.classifier.train()
|
||||||
|
self.assertListEqual(self.classifier.predict_tags(doc1.content), [t1.pk])
|
||||||
|
self.assertListEqual(self.classifier.predict_tags(doc2.content), [])
|
||||||
|
Loading…
x
Reference in New Issue
Block a user