classifier caching

This commit is contained in:
jonaswinkler 2021-02-06 20:54:58 +01:00
parent 0024c2aae4
commit ffe96c8fff
4 changed files with 68 additions and 76 deletions

View File

@ -5,6 +5,7 @@ import pickle
import re
from django.conf import settings
from django.core.cache import cache
from documents.models import Document, MatchingModel
@ -30,22 +31,28 @@ def load_classifier():
)
return None
try:
version = os.stat(settings.MODEL_FILE).st_mtime
classifier = cache.get("paperless-classifier", version=version)
if not classifier:
classifier = DocumentClassifier()
classifier.reload()
except (EOFError, IncompatibleClassifierVersionError) as e:
# there's something wrong with the model file.
logger.error(
f"Unrecoverable error while loading document "
f"classification model: {str(e)}, deleting model file."
)
os.unlink(settings.MODEL_FILE)
classifier = None
except OSError as e:
logger.error(
f"Error while loading document classification model: {str(e)}"
)
classifier = None
try:
classifier.load()
cache.set("paperless-classifier", classifier, version=version)
except (EOFError, IncompatibleClassifierVersionError) as e:
# there's something wrong with the model file.
logger.error(
f"Unrecoverable error while loading document "
f"classification model: {str(e)}, deleting model file."
)
os.unlink(settings.MODEL_FILE)
classifier = None
except OSError as e:
logger.error(
f"Error while loading document classification model: {str(e)}"
)
classifier = None
return classifier
@ -55,10 +62,6 @@ class DocumentClassifier(object):
FORMAT_VERSION = 6
def __init__(self):
# mtime of the model file on disk. used to prevent reloading when
# nothing has changed.
self.classifier_version = 0
# hash of the training data. used to prevent re-training when the
# training data has not changed.
self.data_hash = None
@ -69,30 +72,23 @@ class DocumentClassifier(object):
self.correspondent_classifier = None
self.document_type_classifier = None
def reload(self):
if os.path.getmtime(settings.MODEL_FILE) > self.classifier_version:
with open(settings.MODEL_FILE, "rb") as f:
schema_version = pickle.load(f)
def load(self):
with open(settings.MODEL_FILE, "rb") as f:
schema_version = pickle.load(f)
if schema_version != self.FORMAT_VERSION:
raise IncompatibleClassifierVersionError(
"Cannor load classifier, incompatible versions.")
else:
if self.classifier_version > 0:
# Don't be confused by this check. It's simply here
# so that we wont log anything on initial reload.
logger.info("Classifier updated on disk, "
"reloading classifier models")
self.data_hash = pickle.load(f)
self.data_vectorizer = pickle.load(f)
self.tags_binarizer = pickle.load(f)
if schema_version != self.FORMAT_VERSION:
raise IncompatibleClassifierVersionError(
"Cannor load classifier, incompatible versions.")
else:
self.data_hash = pickle.load(f)
self.data_vectorizer = pickle.load(f)
self.tags_binarizer = pickle.load(f)
self.tags_classifier = pickle.load(f)
self.correspondent_classifier = pickle.load(f)
self.document_type_classifier = pickle.load(f)
self.classifier_version = os.path.getmtime(settings.MODEL_FILE)
self.tags_classifier = pickle.load(f)
self.correspondent_classifier = pickle.load(f)
self.document_type_classifier = pickle.load(f)
def save_classifier(self):
def save(self):
with open(settings.MODEL_FILE, "wb") as f:
pickle.dump(self.FORMAT_VERSION, f)
pickle.dump(self.data_hash, f)

View File

@ -52,7 +52,7 @@ def train_classifier():
"Saving updated classifier model to {}...".format(
settings.MODEL_FILE)
)
classifier.save_classifier()
classifier.save()
else:
logger.debug(
"Training data unchanged."

View File

@ -1,7 +1,6 @@
import os
import tempfile
from pathlib import Path
from time import sleep
from unittest import mock
from django.conf import settings
@ -85,37 +84,19 @@ class TestClassifier(DirectoriesMixin, TestCase):
self.assertTrue(self.classifier.train())
self.assertFalse(self.classifier.train())
self.classifier.save_classifier()
self.classifier.save()
classifier2 = DocumentClassifier()
current_ver = DocumentClassifier.FORMAT_VERSION
with mock.patch("documents.classifier.DocumentClassifier.FORMAT_VERSION", current_ver+1):
# assure that we won't load old classifiers.
self.assertRaises(IncompatibleClassifierVersionError, classifier2.reload)
self.assertRaises(IncompatibleClassifierVersionError, classifier2.load)
self.classifier.save_classifier()
self.classifier.save()
# assure that we can load the classifier after saving it.
classifier2.reload()
def testReload(self):
self.generate_test_data()
self.assertTrue(self.classifier.train())
self.classifier.save_classifier()
classifier2 = DocumentClassifier()
classifier2.reload()
v1 = classifier2.classifier_version
# change the classifier after some time.
sleep(1)
self.classifier.save_classifier()
classifier2.reload()
v2 = classifier2.classifier_version
self.assertNotEqual(v1, v2)
classifier2.load()
@override_settings(DATA_DIR=tempfile.mkdtemp())
def testSaveClassifier(self):
@ -124,10 +105,10 @@ class TestClassifier(DirectoriesMixin, TestCase):
self.classifier.train()
self.classifier.save_classifier()
self.classifier.save()
new_classifier = DocumentClassifier()
new_classifier.reload()
new_classifier.load()
self.assertFalse(new_classifier.train())
@override_settings(MODEL_FILE=os.path.join(os.path.dirname(__file__), "data", "model.pickle"))
@ -135,7 +116,7 @@ class TestClassifier(DirectoriesMixin, TestCase):
self.generate_test_data()
new_classifier = DocumentClassifier()
new_classifier.reload()
new_classifier.load()
self.assertCountEqual(new_classifier.predict_tags(self.doc2.content), [45, 12])
@ -252,25 +233,39 @@ class TestClassifier(DirectoriesMixin, TestCase):
self.assertFalse(os.path.exists(settings.MODEL_FILE))
self.assertIsNone(load_classifier())
@mock.patch("documents.classifier.DocumentClassifier.reload")
def test_load_classifier(self, reload):
@override_settings(CACHES={'default': {'BACKEND': 'django.core.cache.backends.dummy.DummyCache'}})
@mock.patch("documents.classifier.DocumentClassifier.load")
def test_load_classifier(self, load):
Path(settings.MODEL_FILE).touch()
self.assertIsNotNone(load_classifier())
load.assert_called_once()
@mock.patch("documents.classifier.DocumentClassifier.reload")
def test_load_classifier_incompatible_version(self, reload):
@override_settings(CACHES={'default': {'BACKEND': 'django.core.cache.backends.locmem.LocMemCache'}})
@override_settings(MODEL_FILE=os.path.join(os.path.dirname(__file__), "data", "model.pickle"))
def test_load_classifier_cached(self):
classifier = load_classifier()
self.assertIsNotNone(classifier)
with mock.patch("documents.classifier.DocumentClassifier.load") as load:
classifier2 = load_classifier()
load.assert_not_called()
@override_settings(CACHES={'default': {'BACKEND': 'django.core.cache.backends.dummy.DummyCache'}})
@mock.patch("documents.classifier.DocumentClassifier.load")
def test_load_classifier_incompatible_version(self, load):
Path(settings.MODEL_FILE).touch()
self.assertTrue(os.path.exists(settings.MODEL_FILE))
reload.side_effect = IncompatibleClassifierVersionError()
load.side_effect = IncompatibleClassifierVersionError()
self.assertIsNone(load_classifier())
self.assertFalse(os.path.exists(settings.MODEL_FILE))
@mock.patch("documents.classifier.DocumentClassifier.reload")
def test_load_classifier_os_error(self, reload):
@override_settings(CACHES={'default': {'BACKEND': 'django.core.cache.backends.dummy.DummyCache'}})
@mock.patch("documents.classifier.DocumentClassifier.load")
def test_load_classifier_os_error(self, load):
Path(settings.MODEL_FILE).touch()
self.assertTrue(os.path.exists(settings.MODEL_FILE))
reload.side_effect = OSError()
load.side_effect = OSError()
self.assertIsNone(load_classifier())
self.assertTrue(os.path.exists(settings.MODEL_FILE))

View File

@ -2,7 +2,7 @@ import os
from unittest import mock
from django.conf import settings
from django.test import TestCase
from django.test import TestCase, override_settings
from django.utils import timezone
from documents import tasks
@ -52,6 +52,7 @@ class TestTasks(DirectoriesMixin, TestCase):
load_classifier.assert_called_once()
self.assertFalse(os.path.isfile(settings.MODEL_FILE))
@override_settings(CACHES={'default': {'BACKEND': 'django.core.cache.backends.dummy.DummyCache'}})
def test_train_classifier(self):
c = Correspondent.objects.create(matching_algorithm=Tag.MATCH_AUTO, name="test")
doc = Document.objects.create(correspondent=c, content="test", title="test")