paperless-ngx/src/documents/tests/test_classifier.py

553 lines
18 KiB
Python

import os
import re
import shutil
import tempfile
from pathlib import Path
from unittest import mock
import pytest
from django.conf import settings
from django.test import override_settings
from django.test import TestCase
from documents.classifier import ClassifierModelCorruptError
from documents.classifier import DocumentClassifier
from documents.classifier import IncompatibleClassifierVersionError
from documents.classifier import load_classifier
from documents.models import Correspondent
from documents.models import Document
from documents.models import DocumentType
from documents.models import StoragePath
from documents.models import Tag
from documents.tests.utils import DirectoriesMixin
def dummy_preprocess(content: str):
content = content.lower().strip()
content = re.sub(r"\s+", " ", content)
return content
class TestClassifier(DirectoriesMixin, TestCase):
SAMPLE_MODEL_FILE = os.path.join(os.path.dirname(__file__), "data", "model.pickle")
def setUp(self):
super().setUp()
self.classifier = DocumentClassifier()
self.classifier.preprocess_content = mock.MagicMock(
side_effect=dummy_preprocess,
)
def generate_test_data(self):
self.c1 = Correspondent.objects.create(
name="c1",
matching_algorithm=Correspondent.MATCH_AUTO,
)
self.c2 = Correspondent.objects.create(name="c2")
self.c3 = Correspondent.objects.create(
name="c3",
matching_algorithm=Correspondent.MATCH_AUTO,
)
self.t1 = Tag.objects.create(
name="t1",
matching_algorithm=Tag.MATCH_AUTO,
pk=12,
)
self.t2 = Tag.objects.create(
name="t2",
matching_algorithm=Tag.MATCH_ANY,
pk=34,
is_inbox_tag=True,
)
self.t3 = Tag.objects.create(
name="t3",
matching_algorithm=Tag.MATCH_AUTO,
pk=45,
)
self.dt = DocumentType.objects.create(
name="dt",
matching_algorithm=DocumentType.MATCH_AUTO,
)
self.dt2 = DocumentType.objects.create(
name="dt2",
matching_algorithm=DocumentType.MATCH_AUTO,
)
self.sp1 = StoragePath.objects.create(
name="sp1",
path="path1",
matching_algorithm=DocumentType.MATCH_AUTO,
)
self.sp2 = StoragePath.objects.create(
name="sp2",
path="path2",
matching_algorithm=DocumentType.MATCH_AUTO,
)
self.doc1 = Document.objects.create(
title="doc1",
content="this is a document from c1",
correspondent=self.c1,
checksum="A",
document_type=self.dt,
)
self.doc2 = Document.objects.create(
title="doc1",
content="this is another document, but from c2",
correspondent=self.c2,
checksum="B",
)
self.doc_inbox = Document.objects.create(
title="doc235",
content="aa",
checksum="C",
)
self.doc1.tags.add(self.t1)
self.doc2.tags.add(self.t1)
self.doc2.tags.add(self.t3)
self.doc_inbox.tags.add(self.t2)
self.doc1.storage_path = self.sp1
def testNoTrainingData(self):
try:
self.classifier.train()
except ValueError as e:
self.assertEqual(str(e), "No training data available.")
else:
self.fail("Should raise exception")
def testEmpty(self):
Document.objects.create(title="WOW", checksum="3457", content="ASD")
self.classifier.train()
self.assertIsNone(self.classifier.document_type_classifier)
self.assertIsNone(self.classifier.tags_classifier)
self.assertIsNone(self.classifier.correspondent_classifier)
self.assertListEqual(self.classifier.predict_tags(""), [])
self.assertIsNone(self.classifier.predict_document_type(""))
self.assertIsNone(self.classifier.predict_correspondent(""))
def testTrain(self):
self.generate_test_data()
self.classifier.train()
self.assertListEqual(
list(self.classifier.correspondent_classifier.classes_),
[-1, self.c1.pk],
)
self.assertListEqual(
list(self.classifier.tags_binarizer.classes_),
[self.t1.pk, self.t3.pk],
)
def testPredict(self):
self.generate_test_data()
self.classifier.train()
self.assertEqual(
self.classifier.predict_correspondent(self.doc1.content),
self.c1.pk,
)
self.assertEqual(self.classifier.predict_correspondent(self.doc2.content), None)
self.assertListEqual(
self.classifier.predict_tags(self.doc1.content),
[self.t1.pk],
)
self.assertListEqual(
self.classifier.predict_tags(self.doc2.content),
[self.t1.pk, self.t3.pk],
)
self.assertEqual(
self.classifier.predict_document_type(self.doc1.content),
self.dt.pk,
)
self.assertEqual(self.classifier.predict_document_type(self.doc2.content), None)
def testDatasetHashing(self):
self.generate_test_data()
self.assertTrue(self.classifier.train())
self.assertFalse(self.classifier.train())
def testVersionIncreased(self):
self.generate_test_data()
self.assertTrue(self.classifier.train())
self.assertFalse(self.classifier.train())
self.classifier.save()
classifier2 = DocumentClassifier()
current_ver = DocumentClassifier.FORMAT_VERSION
with mock.patch(
"documents.classifier.DocumentClassifier.FORMAT_VERSION",
current_ver + 1,
):
# assure that we won't load old classifiers.
self.assertRaises(IncompatibleClassifierVersionError, classifier2.load)
self.classifier.save()
# assure that we can load the classifier after saving it.
classifier2.load()
@override_settings(DATA_DIR=tempfile.mkdtemp())
def testSaveClassifier(self):
self.generate_test_data()
self.classifier.train()
self.classifier.save()
new_classifier = DocumentClassifier()
new_classifier.load()
new_classifier.preprocess_content = mock.MagicMock(side_effect=dummy_preprocess)
self.assertFalse(new_classifier.train())
# @override_settings(
# MODEL_FILE=os.path.join(os.path.dirname(__file__), "data", "model.pickle"),
# )
# def test_create_test_load_and_classify(self):
# self.generate_test_data()
# self.classifier.train()
# self.classifier.save()
def test_load_and_classify(self):
# Generate test data, train and save to the model file
# This ensures the model file sklearn version matches
# and eliminates a warning
shutil.copy(
self.SAMPLE_MODEL_FILE,
os.path.join(self.dirs.data_dir, "classification_model.pickle"),
)
self.generate_test_data()
self.classifier.train()
self.classifier.save()
new_classifier = DocumentClassifier()
new_classifier.load()
new_classifier.preprocess_content = mock.MagicMock(side_effect=dummy_preprocess)
self.assertCountEqual(new_classifier.predict_tags(self.doc2.content), [45, 12])
@mock.patch("documents.classifier.pickle.load")
def test_load_corrupt_file(self, patched_pickle_load):
"""
GIVEN:
- Corrupted classifier pickle file
WHEN:
- An attempt is made to load the classifier
THEN:
- The ClassifierModelCorruptError is raised
"""
shutil.copy(
self.SAMPLE_MODEL_FILE,
os.path.join(self.dirs.data_dir, "classification_model.pickle"),
)
# First load is the schema version
patched_pickle_load.side_effect = [DocumentClassifier.FORMAT_VERSION, OSError()]
with self.assertRaises(ClassifierModelCorruptError):
self.classifier.load()
@override_settings(
MODEL_FILE=os.path.join(
os.path.dirname(__file__),
"data",
"v1.0.2.model.pickle",
),
)
def test_load_new_scikit_learn_version(self):
"""
GIVEN:
- classifier pickle file created with a different scikit-learn version
WHEN:
- An attempt is made to load the classifier
THEN:
- The classifier reports the warning was captured and processed
"""
with self.assertRaises(IncompatibleClassifierVersionError):
self.classifier.load()
def test_one_correspondent_predict(self):
c1 = Correspondent.objects.create(
name="c1",
matching_algorithm=Correspondent.MATCH_AUTO,
)
doc1 = Document.objects.create(
title="doc1",
content="this is a document from c1",
correspondent=c1,
checksum="A",
)
self.classifier.train()
self.assertEqual(self.classifier.predict_correspondent(doc1.content), c1.pk)
def test_one_correspondent_predict_manydocs(self):
c1 = Correspondent.objects.create(
name="c1",
matching_algorithm=Correspondent.MATCH_AUTO,
)
doc1 = Document.objects.create(
title="doc1",
content="this is a document from c1",
correspondent=c1,
checksum="A",
)
doc2 = Document.objects.create(
title="doc2",
content="this is a document from noone",
checksum="B",
)
self.classifier.train()
self.assertEqual(self.classifier.predict_correspondent(doc1.content), c1.pk)
self.assertIsNone(self.classifier.predict_correspondent(doc2.content))
def test_one_type_predict(self):
dt = DocumentType.objects.create(
name="dt",
matching_algorithm=DocumentType.MATCH_AUTO,
)
doc1 = Document.objects.create(
title="doc1",
content="this is a document from c1",
checksum="A",
document_type=dt,
)
self.classifier.train()
self.assertEqual(self.classifier.predict_document_type(doc1.content), dt.pk)
def test_one_type_predict_manydocs(self):
dt = DocumentType.objects.create(
name="dt",
matching_algorithm=DocumentType.MATCH_AUTO,
)
doc1 = Document.objects.create(
title="doc1",
content="this is a document from c1",
checksum="A",
document_type=dt,
)
doc2 = Document.objects.create(
title="doc1",
content="this is a document from c2",
checksum="B",
)
self.classifier.train()
self.assertEqual(self.classifier.predict_document_type(doc1.content), dt.pk)
self.assertIsNone(self.classifier.predict_document_type(doc2.content))
def test_one_path_predict(self):
sp = StoragePath.objects.create(
name="sp",
matching_algorithm=StoragePath.MATCH_AUTO,
)
doc1 = Document.objects.create(
title="doc1",
content="this is a document from c1",
checksum="A",
storage_path=sp,
)
self.classifier.train()
self.assertEqual(self.classifier.predict_storage_path(doc1.content), sp.pk)
def test_one_path_predict_manydocs(self):
sp = StoragePath.objects.create(
name="sp",
matching_algorithm=StoragePath.MATCH_AUTO,
)
doc1 = Document.objects.create(
title="doc1",
content="this is a document from c1",
checksum="A",
storage_path=sp,
)
doc2 = Document.objects.create(
title="doc1",
content="this is a document from c2",
checksum="B",
)
self.classifier.train()
self.assertEqual(self.classifier.predict_storage_path(doc1.content), sp.pk)
self.assertIsNone(self.classifier.predict_storage_path(doc2.content))
def test_one_tag_predict(self):
t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12)
doc1 = Document.objects.create(
title="doc1",
content="this is a document from c1",
checksum="A",
)
doc1.tags.add(t1)
self.classifier.train()
self.assertListEqual(self.classifier.predict_tags(doc1.content), [t1.pk])
def test_one_tag_predict_unassigned(self):
t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12)
doc1 = Document.objects.create(
title="doc1",
content="this is a document from c1",
checksum="A",
)
self.classifier.train()
self.assertListEqual(self.classifier.predict_tags(doc1.content), [])
def test_two_tags_predict_singledoc(self):
t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12)
t2 = Tag.objects.create(name="t2", matching_algorithm=Tag.MATCH_AUTO, pk=121)
doc4 = Document.objects.create(
title="doc1",
content="this is a document from c4",
checksum="D",
)
doc4.tags.add(t1)
doc4.tags.add(t2)
self.classifier.train()
self.assertListEqual(self.classifier.predict_tags(doc4.content), [t1.pk, t2.pk])
def test_two_tags_predict(self):
t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12)
t2 = Tag.objects.create(name="t2", matching_algorithm=Tag.MATCH_AUTO, pk=121)
doc1 = Document.objects.create(
title="doc1",
content="this is a document from c1",
checksum="A",
)
doc2 = Document.objects.create(
title="doc1",
content="this is a document from c2",
checksum="B",
)
doc3 = Document.objects.create(
title="doc1",
content="this is a document from c3",
checksum="C",
)
doc4 = Document.objects.create(
title="doc1",
content="this is a document from c4",
checksum="D",
)
doc1.tags.add(t1)
doc2.tags.add(t2)
doc4.tags.add(t1)
doc4.tags.add(t2)
self.classifier.train()
self.assertListEqual(self.classifier.predict_tags(doc1.content), [t1.pk])
self.assertListEqual(self.classifier.predict_tags(doc2.content), [t2.pk])
self.assertListEqual(self.classifier.predict_tags(doc3.content), [])
self.assertListEqual(self.classifier.predict_tags(doc4.content), [t1.pk, t2.pk])
def test_one_tag_predict_multi(self):
t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12)
doc1 = Document.objects.create(
title="doc1",
content="this is a document from c1",
checksum="A",
)
doc2 = Document.objects.create(
title="doc2",
content="this is a document from c2",
checksum="B",
)
doc1.tags.add(t1)
doc2.tags.add(t1)
self.classifier.train()
self.assertListEqual(self.classifier.predict_tags(doc1.content), [t1.pk])
self.assertListEqual(self.classifier.predict_tags(doc2.content), [t1.pk])
def test_one_tag_predict_multi_2(self):
t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12)
doc1 = Document.objects.create(
title="doc1",
content="this is a document from c1",
checksum="A",
)
doc2 = Document.objects.create(
title="doc2",
content="this is a document from c2",
checksum="B",
)
doc1.tags.add(t1)
self.classifier.train()
self.assertListEqual(self.classifier.predict_tags(doc1.content), [t1.pk])
self.assertListEqual(self.classifier.predict_tags(doc2.content), [])
def test_load_classifier_not_exists(self):
self.assertFalse(os.path.exists(settings.MODEL_FILE))
self.assertIsNone(load_classifier())
@mock.patch("documents.classifier.DocumentClassifier.load")
def test_load_classifier(self, load):
Path(settings.MODEL_FILE).touch()
self.assertIsNotNone(load_classifier())
load.assert_called_once()
@override_settings(
CACHES={
"default": {"BACKEND": "django.core.cache.backends.locmem.LocMemCache"},
},
)
@override_settings(
MODEL_FILE=os.path.join(os.path.dirname(__file__), "data", "model.pickle"),
)
@pytest.mark.skip(
reason="Disabled caching due to high memory usage - need to investigate.",
)
def test_load_classifier_cached(self):
classifier = load_classifier()
self.assertIsNotNone(classifier)
with mock.patch("documents.classifier.DocumentClassifier.load") as load:
classifier2 = load_classifier()
load.assert_not_called()
@mock.patch("documents.classifier.DocumentClassifier.load")
def test_load_classifier_incompatible_version(self, load):
Path(settings.MODEL_FILE).touch()
self.assertTrue(os.path.exists(settings.MODEL_FILE))
load.side_effect = IncompatibleClassifierVersionError()
self.assertIsNone(load_classifier())
self.assertFalse(os.path.exists(settings.MODEL_FILE))
@mock.patch("documents.classifier.DocumentClassifier.load")
def test_load_classifier_os_error(self, load):
Path(settings.MODEL_FILE).touch()
self.assertTrue(os.path.exists(settings.MODEL_FILE))
load.side_effect = OSError()
self.assertIsNone(load_classifier())
self.assertTrue(os.path.exists(settings.MODEL_FILE))