paperless-ngx/src/documents/tests/test_classifier.py
2023-04-26 09:35:27 -07:00

664 lines
21 KiB
Python

import os
import re
from pathlib import Path
from unittest import mock
import pytest
from django.conf import settings
from django.test import TestCase
from django.test import override_settings
from documents.classifier import ClassifierModelCorruptError
from documents.classifier import DocumentClassifier
from documents.classifier import IncompatibleClassifierVersionError
from documents.classifier import load_classifier
from documents.models import Correspondent
from documents.models import Document
from documents.models import DocumentType
from documents.models import MatchingModel
from documents.models import StoragePath
from documents.models import Tag
from documents.tests.utils import DirectoriesMixin
def dummy_preprocess(content: str):
    """
    Cheap stand-in for the classifier's real content pre-processing.

    Lower-cases the text, trims leading/trailing whitespace and squeezes
    every run of whitespace down to a single space.
    """
    trimmed = content.lower().strip()
    return re.sub(r"\s+", " ", trimmed)
class TestClassifier(DirectoriesMixin, TestCase):
    """
    Tests for DocumentClassifier training, prediction and persistence, and
    for the load_classifier() helper.
    """

    def setUp(self):
        super().setUp()
        self.classifier = DocumentClassifier()
        # Route content pre-processing through the cheaper test helper
        self.classifier.preprocess_content = mock.MagicMock(
            side_effect=dummy_preprocess,
        )

    def generate_test_data(self):
        """
        Creates the shared fixture set used by most tests:
            - Correspondents c1 and c3 (auto matching) and c2 (default matching)
            - Tags t1 and t3 (auto matching), t2 (any matching, inbox tag)
              and t4 (any matching)
            - Two auto matching document types and two auto matching
              storage paths
            - doc1 (c1, dt, sp1, tag t1), doc2 (c2, tags t1 + t3) and
              doc_inbox (tag t2 only)
        """
        self.c1 = Correspondent.objects.create(
            name="c1",
            matching_algorithm=Correspondent.MATCH_AUTO,
        )
        self.c2 = Correspondent.objects.create(name="c2")
        self.c3 = Correspondent.objects.create(
            name="c3",
            matching_algorithm=Correspondent.MATCH_AUTO,
        )

        self.t1 = Tag.objects.create(
            name="t1",
            matching_algorithm=Tag.MATCH_AUTO,
            pk=12,
        )
        self.t2 = Tag.objects.create(
            name="t2",
            matching_algorithm=Tag.MATCH_ANY,
            pk=34,
            is_inbox_tag=True,
        )
        self.t3 = Tag.objects.create(
            name="t3",
            matching_algorithm=Tag.MATCH_AUTO,
            pk=45,
        )
        self.t4 = Tag.objects.create(
            name="t4",
            matching_algorithm=Tag.MATCH_ANY,
            pk=46,
        )

        self.dt = DocumentType.objects.create(
            name="dt",
            matching_algorithm=DocumentType.MATCH_AUTO,
        )
        self.dt2 = DocumentType.objects.create(
            name="dt2",
            matching_algorithm=DocumentType.MATCH_AUTO,
        )

        self.sp1 = StoragePath.objects.create(
            name="sp1",
            path="path1",
            matching_algorithm=StoragePath.MATCH_AUTO,
        )
        self.sp2 = StoragePath.objects.create(
            name="sp2",
            path="path2",
            matching_algorithm=StoragePath.MATCH_AUTO,
        )
        self.store_paths = [self.sp1, self.sp2]

        self.doc1 = Document.objects.create(
            title="doc1",
            content="this is a document from c1",
            correspondent=self.c1,
            checksum="A",
            document_type=self.dt,
            storage_path=self.sp1,
        )

        self.doc2 = Document.objects.create(
            title="doc1",
            content="this is another document, but from c2",
            correspondent=self.c2,
            checksum="B",
        )

        self.doc_inbox = Document.objects.create(
            title="doc235",
            content="aa",
            checksum="C",
        )

        self.doc1.tags.add(self.t1)
        self.doc2.tags.add(self.t1)
        self.doc2.tags.add(self.t3)
        self.doc_inbox.tags.add(self.t2)

    def generate_train_and_save(self):
        """
        Generates the training data, trains and saves the updated pickle
        file. This ensures the test is using the same scikit learn version
        and eliminates a warning from the test suite
        """
        self.generate_test_data()
        self.classifier.train()
        self.classifier.save()

    def test_no_training_data(self):
        """
        GIVEN:
            - No documents exist to train
        WHEN:
            - Classifier training is requested
        THEN:
            - Exception is raised
        """
        with self.assertRaisesMessage(ValueError, "No training data available."):
            self.classifier.train()

    def test_no_non_inbox_tags(self):
        """
        GIVEN:
            - No documents without an inbox tag exist
        WHEN:
            - Classifier training is requested
        THEN:
            - Exception is raised
        """
        t1 = Tag.objects.create(
            name="t1",
            matching_algorithm=Tag.MATCH_ANY,
            pk=34,
            is_inbox_tag=True,
        )

        doc1 = Document.objects.create(
            title="doc1",
            content="this is a document from c1",
            checksum="A",
        )
        doc1.tags.add(t1)

        with self.assertRaisesMessage(ValueError, "No training data available."):
            self.classifier.train()

    def testEmpty(self):
        """
        GIVEN:
            - A document exists
            - No tags/not enough data to predict
        WHEN:
            - Classifier prediction is requested
        THEN:
            - Classifier returns no predictions
        """
        Document.objects.create(title="WOW", checksum="3457", content="ASD")
        self.classifier.train()

        self.assertIsNone(self.classifier.document_type_classifier)
        self.assertIsNone(self.classifier.tags_classifier)
        self.assertIsNone(self.classifier.correspondent_classifier)

        self.assertListEqual(self.classifier.predict_tags(""), [])
        self.assertIsNone(self.classifier.predict_document_type(""))
        self.assertIsNone(self.classifier.predict_correspondent(""))

    def testTrain(self):
        """
        GIVEN:
            - Test data
        WHEN:
            - Classifier is trained
        THEN:
            - Classifier uses correct values for correspondent learning
            - Classifier uses correct values for tags learning
        """
        self.generate_test_data()
        self.classifier.train()

        self.assertListEqual(
            list(self.classifier.correspondent_classifier.classes_),
            [-1, self.c1.pk],
        )
        self.assertListEqual(
            list(self.classifier.tags_binarizer.classes_),
            [self.t1.pk, self.t3.pk],
        )

    def testPredict(self):
        """
        GIVEN:
            - Classifier trained against test data
        WHEN:
            - Prediction requested for correspondent, tags, type
        THEN:
            - Expected predictions based on training set
        """
        self.generate_test_data()
        self.classifier.train()

        self.assertEqual(
            self.classifier.predict_correspondent(self.doc1.content),
            self.c1.pk,
        )
        self.assertIsNone(self.classifier.predict_correspondent(self.doc2.content))
        self.assertListEqual(
            self.classifier.predict_tags(self.doc1.content),
            [self.t1.pk],
        )
        self.assertListEqual(
            self.classifier.predict_tags(self.doc2.content),
            [self.t1.pk, self.t3.pk],
        )
        self.assertEqual(
            self.classifier.predict_document_type(self.doc1.content),
            self.dt.pk,
        )
        self.assertIsNone(self.classifier.predict_document_type(self.doc2.content))

    def test_no_retrain_if_no_change(self):
        """
        GIVEN:
            - Classifier trained with current data
        WHEN:
            - Classifier training is requested again
        THEN:
            - Classifier does not redo training
        """
        self.generate_test_data()

        self.assertTrue(self.classifier.train())
        self.assertFalse(self.classifier.train())

    def test_retrain_if_change(self):
        """
        GIVEN:
            - Classifier trained with current data
        WHEN:
            - Classifier training is requested again
            - Documents have changed
        THEN:
            - Classifier does redo training
        """
        self.generate_test_data()

        self.assertTrue(self.classifier.train())

        self.doc1.correspondent = self.c2
        self.doc1.save()

        self.assertTrue(self.classifier.train())

    def test_retrain_if_auto_match_set_changed(self):
        """
        GIVEN:
            - Classifier trained with current data
        WHEN:
            - Classifier training is requested again
            - Some new AUTO match object exists
        THEN:
            - Classifier does redo training
        """
        self.generate_test_data()

        # Add the ANY type
        self.doc1.tags.add(self.t4)

        self.assertTrue(self.classifier.train())

        # Change the matching type
        self.t4.matching_algorithm = MatchingModel.MATCH_AUTO
        self.t4.save()

        self.assertTrue(self.classifier.train())

    def testVersionIncreased(self):
        """
        GIVEN:
            - Existing classifier model saved at the current format version
        WHEN:
            - The expected format version is increased by one
            - An attempt is made to load the previously saved file
        THEN:
            - IncompatibleClassifierVersionError is raised
            - After saving again, the file can be loaded
        """
        self.generate_train_and_save()

        classifier2 = DocumentClassifier()

        current_ver = DocumentClassifier.FORMAT_VERSION
        with mock.patch(
            "documents.classifier.DocumentClassifier.FORMAT_VERSION",
            current_ver + 1,
        ):
            # assure that we won't load old classifiers.
            self.assertRaises(IncompatibleClassifierVersionError, classifier2.load)

            self.classifier.save()

            # assure that we can load the classifier after saving it.
            classifier2.load()

    def testSaveClassifier(self):
        """
        GIVEN:
            - Classifier trained against test data and saved
        WHEN:
            - A new classifier instance loads the saved file
            - Training is requested on the new instance
        THEN:
            - No retraining happens
        """
        self.generate_train_and_save()

        new_classifier = DocumentClassifier()
        new_classifier.load()
        new_classifier.preprocess_content = mock.MagicMock(side_effect=dummy_preprocess)

        self.assertFalse(new_classifier.train())

    def test_load_and_classify(self):
        """
        GIVEN:
            - Classifier trained against test data and saved
        WHEN:
            - A new classifier instance loads the saved file
            - Tag prediction is requested for doc2's content
        THEN:
            - Both tags assigned to doc2 are predicted (in any order)
        """
        self.generate_train_and_save()

        new_classifier = DocumentClassifier()
        new_classifier.load()
        new_classifier.preprocess_content = mock.MagicMock(side_effect=dummy_preprocess)

        self.assertCountEqual(
            new_classifier.predict_tags(self.doc2.content),
            [self.t3.pk, self.t1.pk],
        )

    @mock.patch("documents.classifier.pickle.load")
    def test_load_corrupt_file(self, patched_pickle_load: mock.MagicMock):
        """
        GIVEN:
            - Corrupted classifier pickle file
        WHEN:
            - An attempt is made to load the classifier
        THEN:
            - The ClassifierModelCorruptError is raised
        """
        self.generate_train_and_save()

        # First load is the schema version, allow it
        patched_pickle_load.side_effect = [DocumentClassifier.FORMAT_VERSION, OSError()]

        with self.assertRaises(ClassifierModelCorruptError):
            self.classifier.load()
        # Checked after the context manager: anything following the raising
        # call inside the "with" body would never be executed
        patched_pickle_load.assert_called()

        patched_pickle_load.reset_mock()
        patched_pickle_load.side_effect = [
            DocumentClassifier.FORMAT_VERSION,
            ClassifierModelCorruptError(),
        ]

        self.assertIsNone(load_classifier())
        patched_pickle_load.assert_called()

    def test_load_new_scikit_learn_version(self):
        """
        GIVEN:
            - classifier pickle file created with a different scikit-learn version
        WHEN:
            - An attempt is made to load the classifier
        THEN:
            - The classifier reports the warning was captured and processed
        """
        # TODO: This wasn't testing the warning anymore, as the schema changed
        # but as it was implemented, it would require installing an old version
        # rebuilding the file and committing that.  Not developer friendly
        # Need to rethink how to pass the load through to a file with a single
        # old model?

    def test_one_correspondent_predict(self):
        """
        GIVEN:
            - One auto matching correspondent
            - One document assigned to it
        WHEN:
            - Classifier is trained and asked for the document's correspondent
        THEN:
            - The correspondent is predicted
        """
        c1 = Correspondent.objects.create(
            name="c1",
            matching_algorithm=Correspondent.MATCH_AUTO,
        )
        doc1 = Document.objects.create(
            title="doc1",
            content="this is a document from c1",
            correspondent=c1,
            checksum="A",
        )

        self.classifier.train()
        self.assertEqual(self.classifier.predict_correspondent(doc1.content), c1.pk)

    def test_one_correspondent_predict_manydocs(self):
        """
        GIVEN:
            - One auto matching correspondent
            - One document assigned to it, one document without correspondent
        WHEN:
            - Classifier is trained and asked for each document's correspondent
        THEN:
            - The correspondent is predicted for the assigned document only
        """
        c1 = Correspondent.objects.create(
            name="c1",
            matching_algorithm=Correspondent.MATCH_AUTO,
        )
        doc1 = Document.objects.create(
            title="doc1",
            content="this is a document from c1",
            correspondent=c1,
            checksum="A",
        )
        doc2 = Document.objects.create(
            title="doc2",
            content="this is a document from noone",
            checksum="B",
        )

        self.classifier.train()
        self.assertEqual(self.classifier.predict_correspondent(doc1.content), c1.pk)
        self.assertIsNone(self.classifier.predict_correspondent(doc2.content))

    def test_one_type_predict(self):
        """
        GIVEN:
            - One auto matching document type
            - One document assigned to it
        WHEN:
            - Classifier is trained and asked for the document's type
        THEN:
            - The document type is predicted
        """
        dt = DocumentType.objects.create(
            name="dt",
            matching_algorithm=DocumentType.MATCH_AUTO,
        )

        doc1 = Document.objects.create(
            title="doc1",
            content="this is a document from c1",
            checksum="A",
            document_type=dt,
        )

        self.classifier.train()
        self.assertEqual(self.classifier.predict_document_type(doc1.content), dt.pk)

    def test_one_type_predict_manydocs(self):
        """
        GIVEN:
            - One auto matching document type
            - One document assigned to it, one document without a type
        WHEN:
            - Classifier is trained and asked for each document's type
        THEN:
            - The document type is predicted for the assigned document only
        """
        dt = DocumentType.objects.create(
            name="dt",
            matching_algorithm=DocumentType.MATCH_AUTO,
        )

        doc1 = Document.objects.create(
            title="doc1",
            content="this is a document from c1",
            checksum="A",
            document_type=dt,
        )

        doc2 = Document.objects.create(
            title="doc1",
            content="this is a document from c2",
            checksum="B",
        )

        self.classifier.train()
        self.assertEqual(self.classifier.predict_document_type(doc1.content), dt.pk)
        self.assertIsNone(self.classifier.predict_document_type(doc2.content))

    def test_one_path_predict(self):
        """
        GIVEN:
            - One auto matching storage path
            - One document assigned to it
        WHEN:
            - Classifier is trained and asked for the document's storage path
        THEN:
            - The storage path is predicted
        """
        sp = StoragePath.objects.create(
            name="sp",
            matching_algorithm=StoragePath.MATCH_AUTO,
        )

        doc1 = Document.objects.create(
            title="doc1",
            content="this is a document from c1",
            checksum="A",
            storage_path=sp,
        )

        self.classifier.train()
        self.assertEqual(self.classifier.predict_storage_path(doc1.content), sp.pk)

    def test_one_path_predict_manydocs(self):
        """
        GIVEN:
            - One auto matching storage path
            - One document assigned to it, one document without a storage path
        WHEN:
            - Classifier is trained and asked for each document's storage path
        THEN:
            - The storage path is predicted for the assigned document only
        """
        sp = StoragePath.objects.create(
            name="sp",
            matching_algorithm=StoragePath.MATCH_AUTO,
        )

        doc1 = Document.objects.create(
            title="doc1",
            content="this is a document from c1",
            checksum="A",
            storage_path=sp,
        )

        doc2 = Document.objects.create(
            title="doc1",
            content="this is a document from c2",
            checksum="B",
        )

        self.classifier.train()
        self.assertEqual(self.classifier.predict_storage_path(doc1.content), sp.pk)
        self.assertIsNone(self.classifier.predict_storage_path(doc2.content))

    def test_one_tag_predict(self):
        """
        GIVEN:
            - One auto matching tag
            - One document carrying that tag
        WHEN:
            - Classifier is trained and asked for the document's tags
        THEN:
            - The tag is predicted
        """
        t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12)

        doc1 = Document.objects.create(
            title="doc1",
            content="this is a document from c1",
            checksum="A",
        )

        doc1.tags.add(t1)
        self.classifier.train()
        self.assertListEqual(self.classifier.predict_tags(doc1.content), [t1.pk])

    def test_one_tag_predict_unassigned(self):
        """
        GIVEN:
            - One auto matching tag
            - One document which does not carry that tag
        WHEN:
            - Classifier is trained and asked for the document's tags
        THEN:
            - No tags are predicted
        """
        Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12)

        doc1 = Document.objects.create(
            title="doc1",
            content="this is a document from c1",
            checksum="A",
        )

        self.classifier.train()
        self.assertListEqual(self.classifier.predict_tags(doc1.content), [])

    def test_two_tags_predict_singledoc(self):
        """
        GIVEN:
            - Two auto matching tags
            - One document carrying both tags
        WHEN:
            - Classifier is trained and asked for the document's tags
        THEN:
            - Both tags are predicted
        """
        t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12)
        t2 = Tag.objects.create(name="t2", matching_algorithm=Tag.MATCH_AUTO, pk=121)

        doc4 = Document.objects.create(
            title="doc1",
            content="this is a document from c4",
            checksum="D",
        )

        doc4.tags.add(t1)
        doc4.tags.add(t2)
        self.classifier.train()
        self.assertListEqual(self.classifier.predict_tags(doc4.content), [t1.pk, t2.pk])

    def test_two_tags_predict(self):
        """
        GIVEN:
            - Two auto matching tags
            - Four documents: one with the first tag, one with the second,
              one with neither and one with both
        WHEN:
            - Classifier is trained and asked for each document's tags
        THEN:
            - Each document gets exactly the tags it was assigned
        """
        t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12)
        t2 = Tag.objects.create(name="t2", matching_algorithm=Tag.MATCH_AUTO, pk=121)

        doc1 = Document.objects.create(
            title="doc1",
            content="this is a document from c1",
            checksum="A",
        )
        doc2 = Document.objects.create(
            title="doc1",
            content="this is a document from c2",
            checksum="B",
        )
        doc3 = Document.objects.create(
            title="doc1",
            content="this is a document from c3",
            checksum="C",
        )
        doc4 = Document.objects.create(
            title="doc1",
            content="this is a document from c4",
            checksum="D",
        )

        doc1.tags.add(t1)
        doc2.tags.add(t2)

        doc4.tags.add(t1)
        doc4.tags.add(t2)
        self.classifier.train()
        self.assertListEqual(self.classifier.predict_tags(doc1.content), [t1.pk])
        self.assertListEqual(self.classifier.predict_tags(doc2.content), [t2.pk])
        self.assertListEqual(self.classifier.predict_tags(doc3.content), [])
        self.assertListEqual(self.classifier.predict_tags(doc4.content), [t1.pk, t2.pk])

    def test_one_tag_predict_multi(self):
        """
        GIVEN:
            - One auto matching tag
            - Two documents, both carrying that tag
        WHEN:
            - Classifier is trained and asked for each document's tags
        THEN:
            - The tag is predicted for both documents
        """
        t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12)

        doc1 = Document.objects.create(
            title="doc1",
            content="this is a document from c1",
            checksum="A",
        )
        doc2 = Document.objects.create(
            title="doc2",
            content="this is a document from c2",
            checksum="B",
        )

        doc1.tags.add(t1)
        doc2.tags.add(t1)
        self.classifier.train()
        self.assertListEqual(self.classifier.predict_tags(doc1.content), [t1.pk])
        self.assertListEqual(self.classifier.predict_tags(doc2.content), [t1.pk])

    def test_one_tag_predict_multi_2(self):
        """
        GIVEN:
            - One auto matching tag
            - Two documents, only the first carrying that tag
        WHEN:
            - Classifier is trained and asked for each document's tags
        THEN:
            - The tag is predicted for the first document only
        """
        t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12)

        doc1 = Document.objects.create(
            title="doc1",
            content="this is a document from c1",
            checksum="A",
        )
        doc2 = Document.objects.create(
            title="doc2",
            content="this is a document from c2",
            checksum="B",
        )

        doc1.tags.add(t1)
        self.classifier.train()
        self.assertListEqual(self.classifier.predict_tags(doc1.content), [t1.pk])
        self.assertListEqual(self.classifier.predict_tags(doc2.content), [])

    def test_load_classifier_not_exists(self):
        """
        GIVEN:
            - No model file exists
        WHEN:
            - load_classifier() is called
        THEN:
            - None is returned
        """
        self.assertFalse(Path(settings.MODEL_FILE).exists())
        self.assertIsNone(load_classifier())

    @mock.patch("documents.classifier.DocumentClassifier.load")
    def test_load_classifier(self, load):
        """
        GIVEN:
            - A model file exists
            - DocumentClassifier.load is mocked out
        WHEN:
            - load_classifier() is called
        THEN:
            - A classifier is returned
            - DocumentClassifier.load was called exactly once
        """
        Path(settings.MODEL_FILE).touch()
        self.assertIsNotNone(load_classifier())
        load.assert_called_once()

    @override_settings(
        CACHES={
            "default": {"BACKEND": "django.core.cache.backends.locmem.LocMemCache"},
        },
    )
    @override_settings(
        MODEL_FILE=os.path.join(os.path.dirname(__file__), "data", "model.pickle"),
    )
    @pytest.mark.skip(
        reason="Disabled caching due to high memory usage - need to investigate.",
    )
    def test_load_classifier_cached(self):
        """
        GIVEN:
            - Local memory cache backend is configured
            - MODEL_FILE points at the model file in the tests' data directory
        WHEN:
            - load_classifier() is called twice
        THEN:
            - The first call returns a classifier
            - The second call does not invoke DocumentClassifier.load
        """
        classifier = load_classifier()
        self.assertIsNotNone(classifier)

        with mock.patch("documents.classifier.DocumentClassifier.load") as load:
            load_classifier()
            load.assert_not_called()

    @mock.patch("documents.classifier.DocumentClassifier.load")
    def test_load_classifier_incompatible_version(self, load):
        """
        GIVEN:
            - A model file exists
            - Loading it raises IncompatibleClassifierVersionError
        WHEN:
            - load_classifier() is called
        THEN:
            - None is returned
            - The model file no longer exists
        """
        Path(settings.MODEL_FILE).touch()
        self.assertTrue(Path(settings.MODEL_FILE).exists())

        load.side_effect = IncompatibleClassifierVersionError()
        self.assertIsNone(load_classifier())
        self.assertFalse(Path(settings.MODEL_FILE).exists())

    @mock.patch("documents.classifier.DocumentClassifier.load")
    def test_load_classifier_os_error(self, load):
        """
        GIVEN:
            - A model file exists
            - Loading it raises OSError
        WHEN:
            - load_classifier() is called
        THEN:
            - None is returned
            - The model file still exists
        """
        Path(settings.MODEL_FILE).touch()
        self.assertTrue(Path(settings.MODEL_FILE).exists())

        load.side_effect = OSError()
        self.assertIsNone(load_classifier())
        self.assertTrue(Path(settings.MODEL_FILE).exists())