mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-04-02 13:45:10 -05:00
tests for the classifier and fixes for edge cases with minimal data.
This commit is contained in:
parent
2a4fe4dceb
commit
30acfdd3f1
@ -6,7 +6,8 @@ import re
|
||||
|
||||
from sklearn.feature_extraction.text import CountVectorizer
|
||||
from sklearn.neural_network import MLPClassifier
|
||||
from sklearn.preprocessing import MultiLabelBinarizer
|
||||
from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
|
||||
from sklearn.utils.multiclass import type_of_target
|
||||
|
||||
from documents.models import Document, MatchingModel
|
||||
from paperless import settings
|
||||
@ -27,7 +28,7 @@ def preprocess_content(content):
|
||||
|
||||
class DocumentClassifier(object):
|
||||
|
||||
FORMAT_VERSION = 5
|
||||
FORMAT_VERSION = 6
|
||||
|
||||
def __init__(self):
|
||||
# mtime of the model file on disk. used to prevent reloading when
|
||||
@ -54,6 +55,8 @@ class DocumentClassifier(object):
|
||||
"Cannor load classifier, incompatible versions.")
|
||||
else:
|
||||
if self.classifier_version > 0:
|
||||
# Don't be confused by this check. It's simply here
|
||||
# so that we wont log anything on initial reload.
|
||||
logger.info("Classifier updated on disk, "
|
||||
"reloading classifier models")
|
||||
self.data_hash = pickle.load(f)
|
||||
@ -122,9 +125,14 @@ class DocumentClassifier(object):
|
||||
labels_tags_unique = set([tag for tags in labels_tags for tag in tags])
|
||||
|
||||
num_tags = len(labels_tags_unique)
|
||||
|
||||
# substract 1 since -1 (null) is also part of the classes.
|
||||
num_correspondents = len(set(labels_correspondent)) - 1
|
||||
num_document_types = len(set(labels_document_type)) - 1
|
||||
|
||||
# union with {-1} accounts for cases where all documents have
|
||||
# correspondents and types assigned, so -1 isnt part of labels_x, which
|
||||
# it usually is.
|
||||
num_correspondents = len(set(labels_correspondent) | {-1}) - 1
|
||||
num_document_types = len(set(labels_document_type) | {-1}) - 1
|
||||
|
||||
logging.getLogger(__name__).debug(
|
||||
"{} documents, {} tag(s), {} correspondent(s), "
|
||||
@ -145,12 +153,23 @@ class DocumentClassifier(object):
|
||||
)
|
||||
data_vectorized = self.data_vectorizer.fit_transform(data)
|
||||
|
||||
self.tags_binarizer = MultiLabelBinarizer()
|
||||
labels_tags_vectorized = self.tags_binarizer.fit_transform(labels_tags)
|
||||
|
||||
# Step 3: train the classifiers
|
||||
if num_tags > 0:
|
||||
logging.getLogger(__name__).debug("Training tags classifier...")
|
||||
|
||||
if num_tags == 1:
|
||||
# Special case where only one tag has auto:
|
||||
# Fallback to binary classification.
|
||||
labels_tags = [label[0] if len(label) == 1 else -1
|
||||
for label in labels_tags]
|
||||
self.tags_binarizer = LabelBinarizer()
|
||||
labels_tags_vectorized = self.tags_binarizer.fit_transform(
|
||||
labels_tags).ravel()
|
||||
else:
|
||||
self.tags_binarizer = MultiLabelBinarizer()
|
||||
labels_tags_vectorized = self.tags_binarizer.fit_transform(
|
||||
labels_tags)
|
||||
|
||||
self.tags_classifier = MLPClassifier(tol=0.01)
|
||||
self.tags_classifier.fit(data_vectorized, labels_tags_vectorized)
|
||||
else:
|
||||
@ -222,6 +241,16 @@ class DocumentClassifier(object):
|
||||
X = self.data_vectorizer.transform([preprocess_content(content)])
|
||||
y = self.tags_classifier.predict(X)
|
||||
tags_ids = self.tags_binarizer.inverse_transform(y)[0]
|
||||
return tags_ids
|
||||
if type_of_target(y).startswith('multilabel'):
|
||||
# the usual case when there are multiple tags.
|
||||
return list(tags_ids)
|
||||
elif type_of_target(y) == 'binary' and tags_ids != -1:
|
||||
# This is for when we have binary classification with only one
|
||||
# tag and the result is to assign this tag.
|
||||
return [tags_ids]
|
||||
else:
|
||||
# Usually binary as well with -1 as the result, but we're
|
||||
# going to catch everything else here as well.
|
||||
return []
|
||||
else:
|
||||
return []
|
||||
|
@ -1,8 +1,10 @@
|
||||
import tempfile
|
||||
from time import sleep
|
||||
from unittest import mock
|
||||
|
||||
from django.test import TestCase, override_settings
|
||||
|
||||
from documents.classifier import DocumentClassifier
|
||||
from documents.classifier import DocumentClassifier, IncompatibleClassifierVersionError
|
||||
from documents.models import Correspondent, Document, Tag, DocumentType
|
||||
|
||||
|
||||
@ -15,10 +17,12 @@ class TestClassifier(TestCase):
|
||||
def generate_test_data(self):
|
||||
self.c1 = Correspondent.objects.create(name="c1", matching_algorithm=Correspondent.MATCH_AUTO)
|
||||
self.c2 = Correspondent.objects.create(name="c2")
|
||||
self.c3 = Correspondent.objects.create(name="c3", matching_algorithm=Correspondent.MATCH_AUTO)
|
||||
self.t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12)
|
||||
self.t2 = Tag.objects.create(name="t2", matching_algorithm=Tag.MATCH_ANY, pk=34, is_inbox_tag=True)
|
||||
self.t3 = Tag.objects.create(name="t3", matching_algorithm=Tag.MATCH_AUTO, pk=45)
|
||||
self.dt = DocumentType.objects.create(name="dt", matching_algorithm=DocumentType.MATCH_AUTO)
|
||||
self.dt2 = DocumentType.objects.create(name="dt2", matching_algorithm=DocumentType.MATCH_AUTO)
|
||||
|
||||
self.doc1 = Document.objects.create(title="doc1", content="this is a document from c1", correspondent=self.c1, checksum="A", document_type=self.dt)
|
||||
self.doc2 = Document.objects.create(title="doc1", content="this is another document, but from c2", correspondent=self.c2, checksum="B")
|
||||
@ -59,8 +63,8 @@ class TestClassifier(TestCase):
|
||||
self.classifier.train()
|
||||
self.assertEqual(self.classifier.predict_correspondent(self.doc1.content), self.c1.pk)
|
||||
self.assertEqual(self.classifier.predict_correspondent(self.doc2.content), None)
|
||||
self.assertTupleEqual(self.classifier.predict_tags(self.doc1.content), (self.t1.pk,))
|
||||
self.assertTupleEqual(self.classifier.predict_tags(self.doc2.content), (self.t1.pk, self.t3.pk))
|
||||
self.assertListEqual(self.classifier.predict_tags(self.doc1.content), [self.t1.pk])
|
||||
self.assertListEqual(self.classifier.predict_tags(self.doc2.content), [self.t1.pk, self.t3.pk])
|
||||
self.assertEqual(self.classifier.predict_document_type(self.doc1.content), self.dt.pk)
|
||||
self.assertEqual(self.classifier.predict_document_type(self.doc2.content), None)
|
||||
|
||||
@ -71,6 +75,42 @@ class TestClassifier(TestCase):
|
||||
self.assertTrue(self.classifier.train())
|
||||
self.assertFalse(self.classifier.train())
|
||||
|
||||
def testVersionIncreased(self):
|
||||
|
||||
self.generate_test_data()
|
||||
self.assertTrue(self.classifier.train())
|
||||
self.assertFalse(self.classifier.train())
|
||||
|
||||
classifier2 = DocumentClassifier()
|
||||
|
||||
current_ver = DocumentClassifier.FORMAT_VERSION
|
||||
with mock.patch("documents.classifier.DocumentClassifier.FORMAT_VERSION", current_ver+1):
|
||||
# assure that we won't load old classifiers.
|
||||
self.assertRaises(IncompatibleClassifierVersionError, self.classifier.reload)
|
||||
|
||||
self.classifier.save_classifier()
|
||||
|
||||
# assure that we can load the classifier after saving it.
|
||||
classifier2.reload()
|
||||
|
||||
def testReload(self):
|
||||
|
||||
self.generate_test_data()
|
||||
self.assertTrue(self.classifier.train())
|
||||
self.classifier.save_classifier()
|
||||
|
||||
classifier2 = DocumentClassifier()
|
||||
classifier2.reload()
|
||||
v1 = classifier2.classifier_version
|
||||
|
||||
# change the classifier after some time.
|
||||
sleep(1)
|
||||
self.classifier.save_classifier()
|
||||
|
||||
classifier2.reload()
|
||||
v2 = classifier2.classifier_version
|
||||
self.assertNotEqual(v1, v2)
|
||||
|
||||
@override_settings(DATA_DIR=tempfile.mkdtemp())
|
||||
def testSaveClassifier(self):
|
||||
|
||||
@ -83,3 +123,112 @@ class TestClassifier(TestCase):
|
||||
new_classifier = DocumentClassifier()
|
||||
new_classifier.reload()
|
||||
self.assertFalse(new_classifier.train())
|
||||
|
||||
def test_one_correspondent_predict(self):
|
||||
c1 = Correspondent.objects.create(name="c1", matching_algorithm=Correspondent.MATCH_AUTO)
|
||||
doc1 = Document.objects.create(title="doc1", content="this is a document from c1", correspondent=c1, checksum="A")
|
||||
|
||||
self.classifier.train()
|
||||
self.assertEqual(self.classifier.predict_correspondent(doc1.content), c1.pk)
|
||||
|
||||
def test_one_correspondent_predict_manydocs(self):
|
||||
c1 = Correspondent.objects.create(name="c1", matching_algorithm=Correspondent.MATCH_AUTO)
|
||||
doc1 = Document.objects.create(title="doc1", content="this is a document from c1", correspondent=c1, checksum="A")
|
||||
doc2 = Document.objects.create(title="doc2", content="this is a document from noone", checksum="B")
|
||||
|
||||
self.classifier.train()
|
||||
self.assertEqual(self.classifier.predict_correspondent(doc1.content), c1.pk)
|
||||
self.assertIsNone(self.classifier.predict_correspondent(doc2.content))
|
||||
|
||||
def test_one_type_predict(self):
|
||||
dt = DocumentType.objects.create(name="dt", matching_algorithm=DocumentType.MATCH_AUTO)
|
||||
|
||||
doc1 = Document.objects.create(title="doc1", content="this is a document from c1",
|
||||
checksum="A", document_type=dt)
|
||||
|
||||
self.classifier.train()
|
||||
self.assertEqual(self.classifier.predict_document_type(doc1.content), dt.pk)
|
||||
|
||||
def test_one_type_predict_manydocs(self):
|
||||
dt = DocumentType.objects.create(name="dt", matching_algorithm=DocumentType.MATCH_AUTO)
|
||||
|
||||
doc1 = Document.objects.create(title="doc1", content="this is a document from c1",
|
||||
checksum="A", document_type=dt)
|
||||
|
||||
doc2 = Document.objects.create(title="doc1", content="this is a document from c2",
|
||||
checksum="B")
|
||||
|
||||
self.classifier.train()
|
||||
self.assertEqual(self.classifier.predict_document_type(doc1.content), dt.pk)
|
||||
self.assertIsNone(self.classifier.predict_document_type(doc2.content))
|
||||
|
||||
def test_one_tag_predict(self):
|
||||
t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12)
|
||||
|
||||
doc1 = Document.objects.create(title="doc1", content="this is a document from c1", checksum="A")
|
||||
|
||||
doc1.tags.add(t1)
|
||||
self.classifier.train()
|
||||
self.assertListEqual(self.classifier.predict_tags(doc1.content), [t1.pk])
|
||||
|
||||
def test_one_tag_predict_unassigned(self):
|
||||
t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12)
|
||||
|
||||
doc1 = Document.objects.create(title="doc1", content="this is a document from c1", checksum="A")
|
||||
|
||||
self.classifier.train()
|
||||
self.assertListEqual(self.classifier.predict_tags(doc1.content), [])
|
||||
|
||||
def test_two_tags_predict_singledoc(self):
|
||||
t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12)
|
||||
t2 = Tag.objects.create(name="t2", matching_algorithm=Tag.MATCH_AUTO, pk=121)
|
||||
|
||||
doc4 = Document.objects.create(title="doc1", content="this is a document from c4", checksum="D")
|
||||
|
||||
doc4.tags.add(t1)
|
||||
doc4.tags.add(t2)
|
||||
self.classifier.train()
|
||||
self.assertListEqual(self.classifier.predict_tags(doc4.content), [t1.pk, t2.pk])
|
||||
|
||||
def test_two_tags_predict(self):
|
||||
t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12)
|
||||
t2 = Tag.objects.create(name="t2", matching_algorithm=Tag.MATCH_AUTO, pk=121)
|
||||
|
||||
doc1 = Document.objects.create(title="doc1", content="this is a document from c1", checksum="A")
|
||||
doc2 = Document.objects.create(title="doc1", content="this is a document from c2", checksum="B")
|
||||
doc3 = Document.objects.create(title="doc1", content="this is a document from c3", checksum="C")
|
||||
doc4 = Document.objects.create(title="doc1", content="this is a document from c4", checksum="D")
|
||||
|
||||
doc1.tags.add(t1)
|
||||
doc2.tags.add(t2)
|
||||
|
||||
doc4.tags.add(t1)
|
||||
doc4.tags.add(t2)
|
||||
self.classifier.train()
|
||||
self.assertListEqual(self.classifier.predict_tags(doc1.content), [t1.pk])
|
||||
self.assertListEqual(self.classifier.predict_tags(doc2.content), [t2.pk])
|
||||
self.assertListEqual(self.classifier.predict_tags(doc3.content), [])
|
||||
self.assertListEqual(self.classifier.predict_tags(doc4.content), [t1.pk, t2.pk])
|
||||
|
||||
def test_one_tag_predict_multi(self):
|
||||
t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12)
|
||||
|
||||
doc1 = Document.objects.create(title="doc1", content="this is a document from c1", checksum="A")
|
||||
doc2 = Document.objects.create(title="doc2", content="this is a document from c2", checksum="B")
|
||||
|
||||
doc1.tags.add(t1)
|
||||
doc2.tags.add(t1)
|
||||
self.classifier.train()
|
||||
self.assertListEqual(self.classifier.predict_tags(doc1.content), [t1.pk])
|
||||
self.assertListEqual(self.classifier.predict_tags(doc2.content), [t1.pk])
|
||||
|
||||
def test_one_tag_predict_multi_2(self):
|
||||
t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12)
|
||||
|
||||
doc1 = Document.objects.create(title="doc1", content="this is a document from c1", checksum="A")
|
||||
doc2 = Document.objects.create(title="doc2", content="this is a document from c2", checksum="B")
|
||||
|
||||
doc1.tags.add(t1)
|
||||
self.classifier.train()
|
||||
self.assertListEqual(self.classifier.predict_tags(doc1.content), [t1.pk])
|
||||
self.assertListEqual(self.classifier.predict_tags(doc2.content), [])
|
||||
|
Loading…
x
Reference in New Issue
Block a user