From 30acfdd3f12c5709189b2c302ed3861497f16ba9 Mon Sep 17 00:00:00 2001 From: Jonas Winkler Date: Thu, 26 Nov 2020 14:18:10 +0100 Subject: [PATCH] tests for the classifier and fixes for edge cases with minimal data. --- src/documents/classifier.py | 45 +++++-- src/documents/tests/test_classifier.py | 155 ++++++++++++++++++++++++- 2 files changed, 189 insertions(+), 11 deletions(-) diff --git a/src/documents/classifier.py b/src/documents/classifier.py index 6e0d6f946..b0d7d87bb 100755 --- a/src/documents/classifier.py +++ b/src/documents/classifier.py @@ -6,7 +6,8 @@ import re from sklearn.feature_extraction.text import CountVectorizer from sklearn.neural_network import MLPClassifier -from sklearn.preprocessing import MultiLabelBinarizer +from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer +from sklearn.utils.multiclass import type_of_target from documents.models import Document, MatchingModel from paperless import settings @@ -27,7 +28,7 @@ def preprocess_content(content): class DocumentClassifier(object): - FORMAT_VERSION = 5 + FORMAT_VERSION = 6 def __init__(self): # mtime of the model file on disk. used to prevent reloading when @@ -54,6 +55,8 @@ class DocumentClassifier(object): "Cannor load classifier, incompatible versions.") else: if self.classifier_version > 0: + # Don't be confused by this check. It's simply here + # so that we won't log anything on initial reload. logger.info("Classifier updated on disk, " "reloading classifier models") self.data_hash = pickle.load(f) @@ -122,9 +125,14 @@ class DocumentClassifier(object): labels_tags_unique = set([tag for tags in labels_tags for tag in tags]) num_tags = len(labels_tags_unique) + # substract 1 since -1 (null) is also part of the classes. 
- num_correspondents = len(set(labels_correspondent)) - 1 - num_document_types = len(set(labels_document_type)) - 1 + + # union with {-1} accounts for cases where all documents have + # correspondents and types assigned, so -1 isn't part of labels_x, which + # it usually is. + num_correspondents = len(set(labels_correspondent) | {-1}) - 1 + num_document_types = len(set(labels_document_type) | {-1}) - 1 logging.getLogger(__name__).debug( "{} documents, {} tag(s), {} correspondent(s), " @@ -145,12 +153,23 @@ class DocumentClassifier(object): ) data_vectorized = self.data_vectorizer.fit_transform(data) - self.tags_binarizer = MultiLabelBinarizer() - labels_tags_vectorized = self.tags_binarizer.fit_transform(labels_tags) - # Step 3: train the classifiers if num_tags > 0: logging.getLogger(__name__).debug("Training tags classifier...") + + if num_tags == 1: + # Special case where only one tag has auto: + # Fall back to binary classification. + labels_tags = [label[0] if len(label) == 1 else -1 + for label in labels_tags] + self.tags_binarizer = LabelBinarizer() + labels_tags_vectorized = self.tags_binarizer.fit_transform( + labels_tags).ravel() + else: + self.tags_binarizer = MultiLabelBinarizer() + labels_tags_vectorized = self.tags_binarizer.fit_transform( + labels_tags) + self.tags_classifier = MLPClassifier(tol=0.01) self.tags_classifier.fit(data_vectorized, labels_tags_vectorized) else: @@ -222,6 +241,16 @@ class DocumentClassifier(object): X = self.data_vectorizer.transform([preprocess_content(content)]) y = self.tags_classifier.predict(X) tags_ids = self.tags_binarizer.inverse_transform(y)[0] - return tags_ids + if type_of_target(y).startswith('multilabel'): + # the usual case when there are multiple tags. + return list(tags_ids) + elif type_of_target(y) == 'binary' and tags_ids != -1: + # This is for when we have binary classification with only one + # tag and the result is to assign this tag. 
+ return [tags_ids] + else: + # Usually binary as well with -1 as the result, but we're + # going to catch everything else here as well. + return [] else: return [] diff --git a/src/documents/tests/test_classifier.py b/src/documents/tests/test_classifier.py index 4ae672ac2..0f421bb32 100644 --- a/src/documents/tests/test_classifier.py +++ b/src/documents/tests/test_classifier.py @@ -1,8 +1,10 @@ import tempfile +from time import sleep +from unittest import mock from django.test import TestCase, override_settings -from documents.classifier import DocumentClassifier +from documents.classifier import DocumentClassifier, IncompatibleClassifierVersionError from documents.models import Correspondent, Document, Tag, DocumentType @@ -15,10 +17,12 @@ class TestClassifier(TestCase): def generate_test_data(self): self.c1 = Correspondent.objects.create(name="c1", matching_algorithm=Correspondent.MATCH_AUTO) self.c2 = Correspondent.objects.create(name="c2") + self.c3 = Correspondent.objects.create(name="c3", matching_algorithm=Correspondent.MATCH_AUTO) self.t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12) self.t2 = Tag.objects.create(name="t2", matching_algorithm=Tag.MATCH_ANY, pk=34, is_inbox_tag=True) self.t3 = Tag.objects.create(name="t3", matching_algorithm=Tag.MATCH_AUTO, pk=45) self.dt = DocumentType.objects.create(name="dt", matching_algorithm=DocumentType.MATCH_AUTO) + self.dt2 = DocumentType.objects.create(name="dt2", matching_algorithm=DocumentType.MATCH_AUTO) self.doc1 = Document.objects.create(title="doc1", content="this is a document from c1", correspondent=self.c1, checksum="A", document_type=self.dt) self.doc2 = Document.objects.create(title="doc1", content="this is another document, but from c2", correspondent=self.c2, checksum="B") @@ -59,8 +63,8 @@ class TestClassifier(TestCase): self.classifier.train() self.assertEqual(self.classifier.predict_correspondent(self.doc1.content), self.c1.pk) 
self.assertEqual(self.classifier.predict_correspondent(self.doc2.content), None) - self.assertTupleEqual(self.classifier.predict_tags(self.doc1.content), (self.t1.pk,)) - self.assertTupleEqual(self.classifier.predict_tags(self.doc2.content), (self.t1.pk, self.t3.pk)) + self.assertListEqual(self.classifier.predict_tags(self.doc1.content), [self.t1.pk]) + self.assertListEqual(self.classifier.predict_tags(self.doc2.content), [self.t1.pk, self.t3.pk]) self.assertEqual(self.classifier.predict_document_type(self.doc1.content), self.dt.pk) self.assertEqual(self.classifier.predict_document_type(self.doc2.content), None) @@ -71,6 +75,42 @@ class TestClassifier(TestCase): self.assertTrue(self.classifier.train()) self.assertFalse(self.classifier.train()) + def testVersionIncreased(self): + + self.generate_test_data() + self.assertTrue(self.classifier.train()) + self.assertFalse(self.classifier.train()) + + classifier2 = DocumentClassifier() + + current_ver = DocumentClassifier.FORMAT_VERSION + with mock.patch("documents.classifier.DocumentClassifier.FORMAT_VERSION", current_ver+1): + # ensure that we won't load old classifiers. + self.assertRaises(IncompatibleClassifierVersionError, self.classifier.reload) + + self.classifier.save_classifier() + + # ensure that we can load the classifier after saving it. + classifier2.reload() + + def testReload(self): + + self.generate_test_data() + self.assertTrue(self.classifier.train()) + self.classifier.save_classifier() + + classifier2 = DocumentClassifier() + classifier2.reload() + v1 = classifier2.classifier_version + + # change the classifier after some time. 
+ sleep(1) + self.classifier.save_classifier() + + classifier2.reload() + v2 = classifier2.classifier_version + self.assertNotEqual(v1, v2) + @override_settings(DATA_DIR=tempfile.mkdtemp()) def testSaveClassifier(self): @@ -83,3 +123,112 @@ class TestClassifier(TestCase): new_classifier = DocumentClassifier() new_classifier.reload() self.assertFalse(new_classifier.train()) + + def test_one_correspondent_predict(self): + c1 = Correspondent.objects.create(name="c1", matching_algorithm=Correspondent.MATCH_AUTO) + doc1 = Document.objects.create(title="doc1", content="this is a document from c1", correspondent=c1, checksum="A") + + self.classifier.train() + self.assertEqual(self.classifier.predict_correspondent(doc1.content), c1.pk) + + def test_one_correspondent_predict_manydocs(self): + c1 = Correspondent.objects.create(name="c1", matching_algorithm=Correspondent.MATCH_AUTO) + doc1 = Document.objects.create(title="doc1", content="this is a document from c1", correspondent=c1, checksum="A") + doc2 = Document.objects.create(title="doc2", content="this is a document from noone", checksum="B") + + self.classifier.train() + self.assertEqual(self.classifier.predict_correspondent(doc1.content), c1.pk) + self.assertIsNone(self.classifier.predict_correspondent(doc2.content)) + + def test_one_type_predict(self): + dt = DocumentType.objects.create(name="dt", matching_algorithm=DocumentType.MATCH_AUTO) + + doc1 = Document.objects.create(title="doc1", content="this is a document from c1", + checksum="A", document_type=dt) + + self.classifier.train() + self.assertEqual(self.classifier.predict_document_type(doc1.content), dt.pk) + + def test_one_type_predict_manydocs(self): + dt = DocumentType.objects.create(name="dt", matching_algorithm=DocumentType.MATCH_AUTO) + + doc1 = Document.objects.create(title="doc1", content="this is a document from c1", + checksum="A", document_type=dt) + + doc2 = Document.objects.create(title="doc1", content="this is a document from c2", + checksum="B") 
+ + self.classifier.train() + self.assertEqual(self.classifier.predict_document_type(doc1.content), dt.pk) + self.assertIsNone(self.classifier.predict_document_type(doc2.content)) + + def test_one_tag_predict(self): + t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12) + + doc1 = Document.objects.create(title="doc1", content="this is a document from c1", checksum="A") + + doc1.tags.add(t1) + self.classifier.train() + self.assertListEqual(self.classifier.predict_tags(doc1.content), [t1.pk]) + + def test_one_tag_predict_unassigned(self): + t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12) + + doc1 = Document.objects.create(title="doc1", content="this is a document from c1", checksum="A") + + self.classifier.train() + self.assertListEqual(self.classifier.predict_tags(doc1.content), []) + + def test_two_tags_predict_singledoc(self): + t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12) + t2 = Tag.objects.create(name="t2", matching_algorithm=Tag.MATCH_AUTO, pk=121) + + doc4 = Document.objects.create(title="doc1", content="this is a document from c4", checksum="D") + + doc4.tags.add(t1) + doc4.tags.add(t2) + self.classifier.train() + self.assertListEqual(self.classifier.predict_tags(doc4.content), [t1.pk, t2.pk]) + + def test_two_tags_predict(self): + t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12) + t2 = Tag.objects.create(name="t2", matching_algorithm=Tag.MATCH_AUTO, pk=121) + + doc1 = Document.objects.create(title="doc1", content="this is a document from c1", checksum="A") + doc2 = Document.objects.create(title="doc1", content="this is a document from c2", checksum="B") + doc3 = Document.objects.create(title="doc1", content="this is a document from c3", checksum="C") + doc4 = Document.objects.create(title="doc1", content="this is a document from c4", checksum="D") + + doc1.tags.add(t1) + doc2.tags.add(t2) + + doc4.tags.add(t1) + doc4.tags.add(t2) + 
self.classifier.train() + self.assertListEqual(self.classifier.predict_tags(doc1.content), [t1.pk]) + self.assertListEqual(self.classifier.predict_tags(doc2.content), [t2.pk]) + self.assertListEqual(self.classifier.predict_tags(doc3.content), []) + self.assertListEqual(self.classifier.predict_tags(doc4.content), [t1.pk, t2.pk]) + + def test_one_tag_predict_multi(self): + t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12) + + doc1 = Document.objects.create(title="doc1", content="this is a document from c1", checksum="A") + doc2 = Document.objects.create(title="doc2", content="this is a document from c2", checksum="B") + + doc1.tags.add(t1) + doc2.tags.add(t1) + self.classifier.train() + self.assertListEqual(self.classifier.predict_tags(doc1.content), [t1.pk]) + self.assertListEqual(self.classifier.predict_tags(doc2.content), [t1.pk]) + + def test_one_tag_predict_multi_2(self): + t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12) + + doc1 = Document.objects.create(title="doc1", content="this is a document from c1", checksum="A") + doc2 = Document.objects.create(title="doc2", content="this is a document from c2", checksum="B") + + doc1.tags.add(t1) + self.classifier.train() + self.assertListEqual(self.classifier.predict_tags(doc1.content), [t1.pk]) + self.assertListEqual(self.classifier.predict_tags(doc2.content), [])