mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-04-02 13:45:10 -05:00
427 lines
14 KiB
Python
import os
|
|
import tempfile
|
|
from pathlib import Path
|
|
from unittest import mock
|
|
|
|
import pytest
|
|
from django.conf import settings
|
|
from django.test import override_settings
|
|
from django.test import TestCase
|
|
from documents.classifier import DocumentClassifier
|
|
from documents.classifier import IncompatibleClassifierVersionError
|
|
from documents.classifier import load_classifier
|
|
from documents.models import Correspondent
|
|
from documents.models import Document
|
|
from documents.models import DocumentType
|
|
from documents.models import Tag
|
|
from documents.tests.utils import DirectoriesMixin
|
|
|
|
|
|
class TestClassifier(DirectoriesMixin, TestCase):
    """Tests for training, predicting with, saving and loading DocumentClassifier."""

    # Pre-trained sample model stored next to this test module; used by the
    # tests that load a classifier from disk instead of training one.
    _SAMPLE_MODEL_FILE = os.path.join(
        os.path.dirname(__file__),
        "data",
        "model.pickle",
    )

    def setUp(self):
        super().setUp()
        self.classifier = DocumentClassifier()

    def generate_test_data(self):
        """Create a small fixed set of correspondents, tags, types and documents.

        Tag primary keys are pinned (12, 34, 45); test_load_and_classify relies
        on these values when checking the pre-trained sample model's output.
        testTrain checks which of these objects end up as classifier classes.
        """
        self.c1 = Correspondent.objects.create(
            name="c1",
            matching_algorithm=Correspondent.MATCH_AUTO,
        )
        self.c2 = Correspondent.objects.create(name="c2")
        self.c3 = Correspondent.objects.create(
            name="c3",
            matching_algorithm=Correspondent.MATCH_AUTO,
        )

        self.t1 = Tag.objects.create(
            name="t1",
            matching_algorithm=Tag.MATCH_AUTO,
            pk=12,
        )
        self.t2 = Tag.objects.create(
            name="t2",
            matching_algorithm=Tag.MATCH_ANY,
            pk=34,
            is_inbox_tag=True,
        )
        self.t3 = Tag.objects.create(
            name="t3",
            matching_algorithm=Tag.MATCH_AUTO,
            pk=45,
        )

        self.dt = DocumentType.objects.create(
            name="dt",
            matching_algorithm=DocumentType.MATCH_AUTO,
        )
        self.dt2 = DocumentType.objects.create(
            name="dt2",
            matching_algorithm=DocumentType.MATCH_AUTO,
        )

        self.doc1 = Document.objects.create(
            title="doc1",
            content="this is a document from c1",
            correspondent=self.c1,
            checksum="A",
            document_type=self.dt,
        )
        self.doc2 = Document.objects.create(
            title="doc1",
            content="this is another document, but from c2",
            correspondent=self.c2,
            checksum="B",
        )
        self.doc_inbox = Document.objects.create(
            title="doc235",
            content="aa",
            checksum="C",
        )

        self.doc1.tags.add(self.t1)
        self.doc2.tags.add(self.t1)
        self.doc2.tags.add(self.t3)
        self.doc_inbox.tags.add(self.t2)

    def testNoTrainingData(self):
        """Training on an empty database raises ValueError with a fixed message."""
        with self.assertRaises(ValueError) as ctx:
            self.classifier.train()
        self.assertEqual(str(ctx.exception), "No training data available.")

    def testEmpty(self):
        """A document without metadata trains no sub-classifiers; predictions are empty."""
        Document.objects.create(title="WOW", checksum="3457", content="ASD")
        self.classifier.train()

        self.assertIsNone(self.classifier.document_type_classifier)
        self.assertIsNone(self.classifier.tags_classifier)
        self.assertIsNone(self.classifier.correspondent_classifier)

        self.assertListEqual(self.classifier.predict_tags(""), [])
        self.assertIsNone(self.classifier.predict_document_type(""))
        self.assertIsNone(self.classifier.predict_correspondent(""))

    def testTrain(self):
        """Training on the fixture data yields the expected class labels."""
        self.generate_test_data()
        self.classifier.train()

        self.assertListEqual(
            list(self.classifier.correspondent_classifier.classes_),
            [-1, self.c1.pk],
        )
        self.assertListEqual(
            list(self.classifier.tags_binarizer.classes_),
            [self.t1.pk, self.t3.pk],
        )

    def testPredict(self):
        """Predictions on the training documents match their assigned metadata."""
        self.generate_test_data()
        self.classifier.train()

        self.assertEqual(
            self.classifier.predict_correspondent(self.doc1.content),
            self.c1.pk,
        )
        self.assertIsNone(self.classifier.predict_correspondent(self.doc2.content))
        self.assertListEqual(
            self.classifier.predict_tags(self.doc1.content),
            [self.t1.pk],
        )
        self.assertListEqual(
            self.classifier.predict_tags(self.doc2.content),
            [self.t1.pk, self.t3.pk],
        )
        self.assertEqual(
            self.classifier.predict_document_type(self.doc1.content),
            self.dt.pk,
        )
        self.assertIsNone(self.classifier.predict_document_type(self.doc2.content))

    def testDatasetHashing(self):
        """A second train() on unchanged data returns False."""
        self.generate_test_data()

        self.assertTrue(self.classifier.train())
        self.assertFalse(self.classifier.train())

    def testVersionIncreased(self):
        """A saved model is rejected after FORMAT_VERSION changes and loads once re-saved."""
        self.generate_test_data()
        self.assertTrue(self.classifier.train())
        self.assertFalse(self.classifier.train())

        self.classifier.save()

        classifier2 = DocumentClassifier()

        current_ver = DocumentClassifier.FORMAT_VERSION
        with mock.patch(
            "documents.classifier.DocumentClassifier.FORMAT_VERSION",
            current_ver + 1,
        ):
            # assure that we won't load old classifiers.
            with self.assertRaises(IncompatibleClassifierVersionError):
                classifier2.load()

        self.classifier.save()

        # assure that we can load the classifier after saving it.
        classifier2.load()

    def testSaveClassifier(self):
        """A saved classifier loads back and reports no change when retrained."""
        # Point DATA_DIR at a fresh directory for this test only.
        # TemporaryDirectory removes it afterwards; a decorator-level
        # tempfile.mkdtemp() would run at import time and never be cleaned up.
        with tempfile.TemporaryDirectory() as data_dir, self.settings(
            DATA_DIR=data_dir,
        ):
            self.generate_test_data()

            self.classifier.train()

            self.classifier.save()

            new_classifier = DocumentClassifier()
            new_classifier.load()
            self.assertFalse(new_classifier.train())

    @override_settings(MODEL_FILE=_SAMPLE_MODEL_FILE)
    def test_load_and_classify(self):
        """The pre-trained sample model predicts both tags of doc2."""
        self.generate_test_data()

        new_classifier = DocumentClassifier()
        new_classifier.load()

        # t3 (pk=45) and t1 (pk=12): the pks pinned in generate_test_data.
        self.assertCountEqual(
            new_classifier.predict_tags(self.doc2.content),
            [self.t3.pk, self.t1.pk],
        )

    def test_one_correspondent_predict(self):
        """A single auto-matched correspondent is predicted for its only document."""
        c1 = Correspondent.objects.create(
            name="c1",
            matching_algorithm=Correspondent.MATCH_AUTO,
        )
        doc1 = Document.objects.create(
            title="doc1",
            content="this is a document from c1",
            correspondent=c1,
            checksum="A",
        )

        self.classifier.train()
        self.assertEqual(self.classifier.predict_correspondent(doc1.content), c1.pk)

    def test_one_correspondent_predict_manydocs(self):
        """With one correspondent, an unassigned document predicts no correspondent."""
        c1 = Correspondent.objects.create(
            name="c1",
            matching_algorithm=Correspondent.MATCH_AUTO,
        )
        doc1 = Document.objects.create(
            title="doc1",
            content="this is a document from c1",
            correspondent=c1,
            checksum="A",
        )
        doc2 = Document.objects.create(
            title="doc2",
            content="this is a document from noone",
            checksum="B",
        )

        self.classifier.train()
        self.assertEqual(self.classifier.predict_correspondent(doc1.content), c1.pk)
        self.assertIsNone(self.classifier.predict_correspondent(doc2.content))

    def test_one_type_predict(self):
        """A single auto-matched document type is predicted for its only document."""
        dt = DocumentType.objects.create(
            name="dt",
            matching_algorithm=DocumentType.MATCH_AUTO,
        )

        doc1 = Document.objects.create(
            title="doc1",
            content="this is a document from c1",
            checksum="A",
            document_type=dt,
        )

        self.classifier.train()
        self.assertEqual(self.classifier.predict_document_type(doc1.content), dt.pk)

    def test_one_type_predict_manydocs(self):
        """With one document type, an untyped document predicts no type."""
        dt = DocumentType.objects.create(
            name="dt",
            matching_algorithm=DocumentType.MATCH_AUTO,
        )

        doc1 = Document.objects.create(
            title="doc1",
            content="this is a document from c1",
            checksum="A",
            document_type=dt,
        )

        doc2 = Document.objects.create(
            title="doc1",
            content="this is a document from c2",
            checksum="B",
        )

        self.classifier.train()
        self.assertEqual(self.classifier.predict_document_type(doc1.content), dt.pk)
        self.assertIsNone(self.classifier.predict_document_type(doc2.content))

    def test_one_tag_predict(self):
        """A single auto-matched tag is predicted for its only document."""
        t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12)

        doc1 = Document.objects.create(
            title="doc1",
            content="this is a document from c1",
            checksum="A",
        )

        doc1.tags.add(t1)
        self.classifier.train()
        self.assertListEqual(self.classifier.predict_tags(doc1.content), [t1.pk])

    def test_one_tag_predict_unassigned(self):
        """A tag that no document carries is never predicted."""
        Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12)

        doc1 = Document.objects.create(
            title="doc1",
            content="this is a document from c1",
            checksum="A",
        )

        self.classifier.train()
        self.assertListEqual(self.classifier.predict_tags(doc1.content), [])

    def test_two_tags_predict_singledoc(self):
        """Both tags of a single document carrying two tags are predicted."""
        t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12)
        t2 = Tag.objects.create(name="t2", matching_algorithm=Tag.MATCH_AUTO, pk=121)

        doc4 = Document.objects.create(
            title="doc1",
            content="this is a document from c4",
            checksum="D",
        )

        doc4.tags.add(t1)
        doc4.tags.add(t2)
        self.classifier.train()
        self.assertListEqual(self.classifier.predict_tags(doc4.content), [t1.pk, t2.pk])

    def test_two_tags_predict(self):
        """Documents carrying zero, one or two tags each get their own tag set back."""
        t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12)
        t2 = Tag.objects.create(name="t2", matching_algorithm=Tag.MATCH_AUTO, pk=121)

        doc1 = Document.objects.create(
            title="doc1",
            content="this is a document from c1",
            checksum="A",
        )
        doc2 = Document.objects.create(
            title="doc1",
            content="this is a document from c2",
            checksum="B",
        )
        doc3 = Document.objects.create(
            title="doc1",
            content="this is a document from c3",
            checksum="C",
        )
        doc4 = Document.objects.create(
            title="doc1",
            content="this is a document from c4",
            checksum="D",
        )

        doc1.tags.add(t1)
        doc2.tags.add(t2)

        doc4.tags.add(t1)
        doc4.tags.add(t2)
        self.classifier.train()
        self.assertListEqual(self.classifier.predict_tags(doc1.content), [t1.pk])
        self.assertListEqual(self.classifier.predict_tags(doc2.content), [t2.pk])
        self.assertListEqual(self.classifier.predict_tags(doc3.content), [])
        self.assertListEqual(self.classifier.predict_tags(doc4.content), [t1.pk, t2.pk])

    def test_one_tag_predict_multi(self):
        """A tag carried by every training document is predicted for all of them."""
        t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12)

        doc1 = Document.objects.create(
            title="doc1",
            content="this is a document from c1",
            checksum="A",
        )
        doc2 = Document.objects.create(
            title="doc2",
            content="this is a document from c2",
            checksum="B",
        )

        doc1.tags.add(t1)
        doc2.tags.add(t1)
        self.classifier.train()
        self.assertListEqual(self.classifier.predict_tags(doc1.content), [t1.pk])
        self.assertListEqual(self.classifier.predict_tags(doc2.content), [t1.pk])

    def test_one_tag_predict_multi_2(self):
        """A tag carried by only one of two documents is predicted only for that one."""
        t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12)

        doc1 = Document.objects.create(
            title="doc1",
            content="this is a document from c1",
            checksum="A",
        )
        doc2 = Document.objects.create(
            title="doc2",
            content="this is a document from c2",
            checksum="B",
        )

        doc1.tags.add(t1)
        self.classifier.train()
        self.assertListEqual(self.classifier.predict_tags(doc1.content), [t1.pk])
        self.assertListEqual(self.classifier.predict_tags(doc2.content), [])

    def test_load_classifier_not_exists(self):
        """load_classifier returns None when no model file exists."""
        self.assertFalse(os.path.exists(settings.MODEL_FILE))
        self.assertIsNone(load_classifier())

    @mock.patch("documents.classifier.DocumentClassifier.load")
    def test_load_classifier(self, load):
        """load_classifier calls DocumentClassifier.load when the model file exists."""
        Path(settings.MODEL_FILE).touch()
        self.assertIsNotNone(load_classifier())
        load.assert_called_once()

    @override_settings(
        CACHES={
            "default": {"BACKEND": "django.core.cache.backends.locmem.LocMemCache"},
        },
        MODEL_FILE=_SAMPLE_MODEL_FILE,
    )
    @pytest.mark.skip(
        reason="Disabled caching due to high memory usage - need to investigate.",
    )
    def test_load_classifier_cached(self):
        """A second load_classifier call is served without calling load() again."""
        classifier = load_classifier()
        self.assertIsNotNone(classifier)

        with mock.patch("documents.classifier.DocumentClassifier.load") as load:
            # The return value is irrelevant here; only whether load() runs matters.
            load_classifier()
            load.assert_not_called()

    @mock.patch("documents.classifier.DocumentClassifier.load")
    def test_load_classifier_incompatible_version(self, load):
        """An incompatible-version error yields None and removes the model file."""
        Path(settings.MODEL_FILE).touch()
        self.assertTrue(os.path.exists(settings.MODEL_FILE))

        load.side_effect = IncompatibleClassifierVersionError()
        self.assertIsNone(load_classifier())
        self.assertFalse(os.path.exists(settings.MODEL_FILE))

    @mock.patch("documents.classifier.DocumentClassifier.load")
    def test_load_classifier_os_error(self, load):
        """An OSError from load() yields None but keeps the model file."""
        Path(settings.MODEL_FILE).touch()
        self.assertTrue(os.path.exists(settings.MODEL_FILE))

        load.side_effect = OSError()
        self.assertIsNone(load_classifier())
        self.assertTrue(os.path.exists(settings.MODEL_FILE))
|