mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-07-28 18:24:38 -05:00
classifier caching
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from time import sleep
|
||||
from unittest import mock
|
||||
|
||||
from django.conf import settings
|
||||
@@ -85,37 +84,19 @@ class TestClassifier(DirectoriesMixin, TestCase):
|
||||
self.assertTrue(self.classifier.train())
|
||||
self.assertFalse(self.classifier.train())
|
||||
|
||||
self.classifier.save_classifier()
|
||||
self.classifier.save()
|
||||
|
||||
classifier2 = DocumentClassifier()
|
||||
|
||||
current_ver = DocumentClassifier.FORMAT_VERSION
|
||||
with mock.patch("documents.classifier.DocumentClassifier.FORMAT_VERSION", current_ver+1):
|
||||
# assure that we won't load old classifiers.
|
||||
self.assertRaises(IncompatibleClassifierVersionError, classifier2.reload)
|
||||
self.assertRaises(IncompatibleClassifierVersionError, classifier2.load)
|
||||
|
||||
self.classifier.save_classifier()
|
||||
self.classifier.save()
|
||||
|
||||
# assure that we can load the classifier after saving it.
|
||||
classifier2.reload()
|
||||
|
||||
def testReload(self):
|
||||
|
||||
self.generate_test_data()
|
||||
self.assertTrue(self.classifier.train())
|
||||
self.classifier.save_classifier()
|
||||
|
||||
classifier2 = DocumentClassifier()
|
||||
classifier2.reload()
|
||||
v1 = classifier2.classifier_version
|
||||
|
||||
# change the classifier after some time.
|
||||
sleep(1)
|
||||
self.classifier.save_classifier()
|
||||
|
||||
classifier2.reload()
|
||||
v2 = classifier2.classifier_version
|
||||
self.assertNotEqual(v1, v2)
|
||||
classifier2.load()
|
||||
|
||||
@override_settings(DATA_DIR=tempfile.mkdtemp())
|
||||
def testSaveClassifier(self):
|
||||
@@ -124,10 +105,10 @@ class TestClassifier(DirectoriesMixin, TestCase):
|
||||
|
||||
self.classifier.train()
|
||||
|
||||
self.classifier.save_classifier()
|
||||
self.classifier.save()
|
||||
|
||||
new_classifier = DocumentClassifier()
|
||||
new_classifier.reload()
|
||||
new_classifier.load()
|
||||
self.assertFalse(new_classifier.train())
|
||||
|
||||
@override_settings(MODEL_FILE=os.path.join(os.path.dirname(__file__), "data", "model.pickle"))
|
||||
@@ -135,7 +116,7 @@ class TestClassifier(DirectoriesMixin, TestCase):
|
||||
self.generate_test_data()
|
||||
|
||||
new_classifier = DocumentClassifier()
|
||||
new_classifier.reload()
|
||||
new_classifier.load()
|
||||
|
||||
self.assertCountEqual(new_classifier.predict_tags(self.doc2.content), [45, 12])
|
||||
|
||||
@@ -252,25 +233,39 @@ class TestClassifier(DirectoriesMixin, TestCase):
|
||||
self.assertFalse(os.path.exists(settings.MODEL_FILE))
|
||||
self.assertIsNone(load_classifier())
|
||||
|
||||
@mock.patch("documents.classifier.DocumentClassifier.reload")
|
||||
def test_load_classifier(self, reload):
|
||||
@override_settings(CACHES={'default': {'BACKEND': 'django.core.cache.backends.dummy.DummyCache'}})
|
||||
@mock.patch("documents.classifier.DocumentClassifier.load")
|
||||
def test_load_classifier(self, load):
|
||||
Path(settings.MODEL_FILE).touch()
|
||||
self.assertIsNotNone(load_classifier())
|
||||
load.assert_called_once()
|
||||
|
||||
@mock.patch("documents.classifier.DocumentClassifier.reload")
|
||||
def test_load_classifier_incompatible_version(self, reload):
|
||||
@override_settings(CACHES={'default': {'BACKEND': 'django.core.cache.backends.locmem.LocMemCache'}})
|
||||
@override_settings(MODEL_FILE=os.path.join(os.path.dirname(__file__), "data", "model.pickle"))
|
||||
def test_load_classifier_cached(self):
|
||||
classifier = load_classifier()
|
||||
self.assertIsNotNone(classifier)
|
||||
|
||||
with mock.patch("documents.classifier.DocumentClassifier.load") as load:
|
||||
classifier2 = load_classifier()
|
||||
load.assert_not_called()
|
||||
|
||||
@override_settings(CACHES={'default': {'BACKEND': 'django.core.cache.backends.dummy.DummyCache'}})
|
||||
@mock.patch("documents.classifier.DocumentClassifier.load")
|
||||
def test_load_classifier_incompatible_version(self, load):
|
||||
Path(settings.MODEL_FILE).touch()
|
||||
self.assertTrue(os.path.exists(settings.MODEL_FILE))
|
||||
|
||||
reload.side_effect = IncompatibleClassifierVersionError()
|
||||
load.side_effect = IncompatibleClassifierVersionError()
|
||||
self.assertIsNone(load_classifier())
|
||||
self.assertFalse(os.path.exists(settings.MODEL_FILE))
|
||||
|
||||
@mock.patch("documents.classifier.DocumentClassifier.reload")
|
||||
def test_load_classifier_os_error(self, reload):
|
||||
@override_settings(CACHES={'default': {'BACKEND': 'django.core.cache.backends.dummy.DummyCache'}})
|
||||
@mock.patch("documents.classifier.DocumentClassifier.load")
|
||||
def test_load_classifier_os_error(self, load):
|
||||
Path(settings.MODEL_FILE).touch()
|
||||
self.assertTrue(os.path.exists(settings.MODEL_FILE))
|
||||
|
||||
reload.side_effect = OSError()
|
||||
load.side_effect = OSError()
|
||||
self.assertIsNone(load_classifier())
|
||||
self.assertTrue(os.path.exists(settings.MODEL_FILE))
|
||||
|
@@ -2,7 +2,7 @@ import os
|
||||
from unittest import mock
|
||||
|
||||
from django.conf import settings
|
||||
from django.test import TestCase
|
||||
from django.test import TestCase, override_settings
|
||||
from django.utils import timezone
|
||||
|
||||
from documents import tasks
|
||||
@@ -52,6 +52,7 @@ class TestTasks(DirectoriesMixin, TestCase):
|
||||
load_classifier.assert_called_once()
|
||||
self.assertFalse(os.path.isfile(settings.MODEL_FILE))
|
||||
|
||||
@override_settings(CACHES={'default': {'BACKEND': 'django.core.cache.backends.dummy.DummyCache'}})
|
||||
def test_train_classifier(self):
|
||||
c = Correspondent.objects.create(matching_algorithm=Tag.MATCH_AUTO, name="test")
|
||||
doc = Document.objects.create(correspondent=c, content="test", title="test")
|
||||
|
Reference in New Issue
Block a user