mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-07-28 18:24:38 -05:00
centralized classifier loading, better error handling, no error messages when auto matching is not used
This commit is contained in:
@@ -1,10 +1,13 @@
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from time import sleep
|
||||
from unittest import mock
|
||||
|
||||
from django.conf import settings
|
||||
from django.test import TestCase, override_settings
|
||||
|
||||
from documents.classifier import DocumentClassifier, IncompatibleClassifierVersionError
|
||||
from documents.classifier import DocumentClassifier, IncompatibleClassifierVersionError, load_classifier
|
||||
from documents.models import Correspondent, Document, Tag, DocumentType
|
||||
from documents.tests.utils import DirectoriesMixin
|
||||
|
||||
@@ -235,3 +238,30 @@ class TestClassifier(DirectoriesMixin, TestCase):
|
||||
self.classifier.train()
|
||||
self.assertListEqual(self.classifier.predict_tags(doc1.content), [t1.pk])
|
||||
self.assertListEqual(self.classifier.predict_tags(doc2.content), [])
|
||||
|
||||
def test_load_classifier_not_exists(self):
    # Precondition: no model file is present for this test.
    model_present = os.path.exists(settings.MODEL_FILE)
    self.assertFalse(model_present)

    # Without a model file the loader is expected to hand back None.
    loaded = load_classifier()
    self.assertIsNone(loaded)
|
||||
|
||||
@mock.patch("documents.classifier.DocumentClassifier.reload")
def test_load_classifier(self, reload):
    # DocumentClassifier.reload is patched, so an empty placeholder file at
    # MODEL_FILE is all this test provides before calling the loader.
    model_file = Path(settings.MODEL_FILE)
    model_file.touch()

    loaded = load_classifier()
    self.assertIsNotNone(loaded)
|
||||
|
||||
@mock.patch("documents.classifier.DocumentClassifier.reload")
def test_load_classifier_incompatible_version(self, reload):
    # Arrange: a placeholder model file exists on disk.
    Path(settings.MODEL_FILE).touch()
    exists_before = os.path.exists(settings.MODEL_FILE)
    self.assertTrue(exists_before)

    # Act: the patched reload raises the incompatible-version error.
    reload.side_effect = IncompatibleClassifierVersionError()
    loaded = load_classifier()

    # Assert: nothing is returned and the model file is gone afterwards.
    self.assertIsNone(loaded)
    exists_after = os.path.exists(settings.MODEL_FILE)
    self.assertFalse(exists_after)
|
||||
|
||||
@mock.patch("documents.classifier.DocumentClassifier.reload")
def test_load_classifier_os_error(self, reload):
    # Arrange: a placeholder model file exists on disk.
    Path(settings.MODEL_FILE).touch()
    exists_before = os.path.exists(settings.MODEL_FILE)
    self.assertTrue(exists_before)

    # Act: the patched reload raises a plain OSError.
    reload.side_effect = OSError()
    loaded = load_classifier()

    # Assert: nothing is returned, but the model file is still present.
    self.assertIsNone(loaded)
    exists_after = os.path.exists(settings.MODEL_FILE)
    self.assertTrue(exists_after)
|
||||
|
@@ -420,7 +420,7 @@ class TestConsumer(DirectoriesMixin, TestCase):
|
||||
self.assertIsNotNone(os.path.isfile(document.title))
|
||||
self.assertTrue(os.path.isfile(document.source_path))
|
||||
|
||||
@mock.patch("documents.consumer.DocumentClassifier")
|
||||
@mock.patch("documents.consumer.load_classifier")
|
||||
def testClassifyDocument(self, m):
|
||||
correspondent = Correspondent.objects.create(name="test")
|
||||
dtype = DocumentType.objects.create(name="test")
|
||||
|
@@ -1,11 +1,12 @@
|
||||
from datetime import datetime
|
||||
import os
|
||||
from unittest import mock
|
||||
|
||||
from django.conf import settings
|
||||
from django.test import TestCase
|
||||
from django.utils import timezone
|
||||
|
||||
from documents import tasks
|
||||
from documents.models import Document
|
||||
from documents.models import Document, Tag, Correspondent, DocumentType
|
||||
from documents.sanity_checker import SanityError, SanityFailedError
|
||||
from documents.tests.utils import DirectoriesMixin
|
||||
|
||||
@@ -22,8 +23,55 @@ class TestTasks(DirectoriesMixin, TestCase):
|
||||
|
||||
tasks.index_optimize()
|
||||
|
||||
def test_train_classifier(self):
|
||||
@mock.patch("documents.tasks.load_classifier")
def test_train_classifier_no_auto_matching(self, load_classifier):
    # This test creates no Tag, DocumentType or Correspondent, so nothing
    # is set up for automatic matching before the task runs.
    tasks.train_classifier()

    # The task must not have asked for a classifier at all.
    load_classifier.assert_not_called()
|
||||
|
||||
@mock.patch("documents.tasks.load_classifier")
def test_train_classifier_with_auto_tag(self, load_classifier):
    # Arrange: one auto-matching tag, and a patched loader that returns None.
    Tag.objects.create(name="test", matching_algorithm=Tag.MATCH_AUTO)
    load_classifier.return_value = None

    tasks.train_classifier()

    # The loader is consulted exactly once and no model file is written.
    load_classifier.assert_called_once()
    model_written = os.path.isfile(settings.MODEL_FILE)
    self.assertFalse(model_written)
|
||||
|
||||
@mock.patch("documents.tasks.load_classifier")
def test_train_classifier_with_auto_type(self, load_classifier):
    # Arrange: one auto-matching document type, and a patched loader that
    # returns None.
    DocumentType.objects.create(name="test", matching_algorithm=Tag.MATCH_AUTO)
    load_classifier.return_value = None

    tasks.train_classifier()

    # The loader is consulted exactly once and no model file is written.
    load_classifier.assert_called_once()
    model_written = os.path.isfile(settings.MODEL_FILE)
    self.assertFalse(model_written)
|
||||
|
||||
@mock.patch("documents.tasks.load_classifier")
def test_train_classifier_with_auto_correspondent(self, load_classifier):
    # Arrange: one auto-matching correspondent, and a patched loader that
    # returns None.
    Correspondent.objects.create(name="test", matching_algorithm=Tag.MATCH_AUTO)
    load_classifier.return_value = None

    tasks.train_classifier()

    # The loader is consulted exactly once and no model file is written.
    load_classifier.assert_called_once()
    model_written = os.path.isfile(settings.MODEL_FILE)
    self.assertFalse(model_written)
|
||||
|
||||
def test_train_classifier(self):
    # Arrange: an auto-matching correspondent with a single document.
    correspondent = Correspondent.objects.create(
        matching_algorithm=Tag.MATCH_AUTO,
        name="test",
    )
    document = Document.objects.create(
        correspondent=correspondent,
        content="test",
        title="test",
    )
    self.assertFalse(os.path.isfile(settings.MODEL_FILE))

    # First run: the model file must exist afterwards.
    tasks.train_classifier()
    self.assertTrue(os.path.isfile(settings.MODEL_FILE))
    first_mtime = os.stat(settings.MODEL_FILE).st_mtime

    # Second run with unchanged data: the file's mtime must stay the same.
    tasks.train_classifier()
    self.assertTrue(os.path.isfile(settings.MODEL_FILE))
    second_mtime = os.stat(settings.MODEL_FILE).st_mtime
    self.assertEqual(first_mtime, second_mtime)

    # Third run after editing the document: the file's mtime must differ.
    document.content = "test2"
    document.save()
    tasks.train_classifier()
    self.assertTrue(os.path.isfile(settings.MODEL_FILE))
    third_mtime = os.stat(settings.MODEL_FILE).st_mtime
    self.assertNotEqual(second_mtime, third_mtime)
|
||||
|
||||
@mock.patch("documents.tasks.sanity_checker.check_sanity")
|
||||
def test_sanity_check(self, m):
|
||||
|
Reference in New Issue
Block a user