classifier caching

This commit is contained in:
jonaswinkler 2021-02-06 20:54:58 +01:00
parent 0024c2aae4
commit ffe96c8fff
4 changed files with 68 additions and 76 deletions

View File

@ -5,6 +5,7 @@ import pickle
import re import re
from django.conf import settings from django.conf import settings
from django.core.cache import cache
from documents.models import Document, MatchingModel from documents.models import Document, MatchingModel
@ -30,22 +31,28 @@ def load_classifier():
) )
return None return None
try: version = os.stat(settings.MODEL_FILE).st_mtime
classifier = cache.get("paperless-classifier", version=version)
if not classifier:
classifier = DocumentClassifier() classifier = DocumentClassifier()
classifier.reload() try:
except (EOFError, IncompatibleClassifierVersionError) as e: classifier.load()
# there's something wrong with the model file. cache.set("paperless-classifier", classifier, version=version)
logger.error( except (EOFError, IncompatibleClassifierVersionError) as e:
f"Unrecoverable error while loading document " # there's something wrong with the model file.
f"classification model: {str(e)}, deleting model file." logger.error(
) f"Unrecoverable error while loading document "
os.unlink(settings.MODEL_FILE) f"classification model: {str(e)}, deleting model file."
classifier = None )
except OSError as e: os.unlink(settings.MODEL_FILE)
logger.error( classifier = None
f"Error while loading document classification model: {str(e)}" except OSError as e:
) logger.error(
classifier = None f"Error while loading document classification model: {str(e)}"
)
classifier = None
return classifier return classifier
@ -55,10 +62,6 @@ class DocumentClassifier(object):
FORMAT_VERSION = 6 FORMAT_VERSION = 6
def __init__(self): def __init__(self):
# mtime of the model file on disk. used to prevent reloading when
# nothing has changed.
self.classifier_version = 0
# hash of the training data. used to prevent re-training when the # hash of the training data. used to prevent re-training when the
# training data has not changed. # training data has not changed.
self.data_hash = None self.data_hash = None
@ -69,30 +72,23 @@ class DocumentClassifier(object):
self.correspondent_classifier = None self.correspondent_classifier = None
self.document_type_classifier = None self.document_type_classifier = None
def reload(self): def load(self):
if os.path.getmtime(settings.MODEL_FILE) > self.classifier_version: with open(settings.MODEL_FILE, "rb") as f:
with open(settings.MODEL_FILE, "rb") as f: schema_version = pickle.load(f)
schema_version = pickle.load(f)
if schema_version != self.FORMAT_VERSION: if schema_version != self.FORMAT_VERSION:
raise IncompatibleClassifierVersionError( raise IncompatibleClassifierVersionError(
"Cannor load classifier, incompatible versions.") "Cannor load classifier, incompatible versions.")
else: else:
if self.classifier_version > 0: self.data_hash = pickle.load(f)
# Don't be confused by this check. It's simply here self.data_vectorizer = pickle.load(f)
# so that we wont log anything on initial reload. self.tags_binarizer = pickle.load(f)
logger.info("Classifier updated on disk, "
"reloading classifier models")
self.data_hash = pickle.load(f)
self.data_vectorizer = pickle.load(f)
self.tags_binarizer = pickle.load(f)
self.tags_classifier = pickle.load(f) self.tags_classifier = pickle.load(f)
self.correspondent_classifier = pickle.load(f) self.correspondent_classifier = pickle.load(f)
self.document_type_classifier = pickle.load(f) self.document_type_classifier = pickle.load(f)
self.classifier_version = os.path.getmtime(settings.MODEL_FILE)
def save_classifier(self): def save(self):
with open(settings.MODEL_FILE, "wb") as f: with open(settings.MODEL_FILE, "wb") as f:
pickle.dump(self.FORMAT_VERSION, f) pickle.dump(self.FORMAT_VERSION, f)
pickle.dump(self.data_hash, f) pickle.dump(self.data_hash, f)

View File

@ -52,7 +52,7 @@ def train_classifier():
"Saving updated classifier model to {}...".format( "Saving updated classifier model to {}...".format(
settings.MODEL_FILE) settings.MODEL_FILE)
) )
classifier.save_classifier() classifier.save()
else: else:
logger.debug( logger.debug(
"Training data unchanged." "Training data unchanged."

View File

@ -1,7 +1,6 @@
import os import os
import tempfile import tempfile
from pathlib import Path from pathlib import Path
from time import sleep
from unittest import mock from unittest import mock
from django.conf import settings from django.conf import settings
@ -85,37 +84,19 @@ class TestClassifier(DirectoriesMixin, TestCase):
self.assertTrue(self.classifier.train()) self.assertTrue(self.classifier.train())
self.assertFalse(self.classifier.train()) self.assertFalse(self.classifier.train())
self.classifier.save_classifier() self.classifier.save()
classifier2 = DocumentClassifier() classifier2 = DocumentClassifier()
current_ver = DocumentClassifier.FORMAT_VERSION current_ver = DocumentClassifier.FORMAT_VERSION
with mock.patch("documents.classifier.DocumentClassifier.FORMAT_VERSION", current_ver+1): with mock.patch("documents.classifier.DocumentClassifier.FORMAT_VERSION", current_ver+1):
# assure that we won't load old classifiers. # assure that we won't load old classifiers.
self.assertRaises(IncompatibleClassifierVersionError, classifier2.reload) self.assertRaises(IncompatibleClassifierVersionError, classifier2.load)
self.classifier.save_classifier() self.classifier.save()
# assure that we can load the classifier after saving it. # assure that we can load the classifier after saving it.
classifier2.reload() classifier2.load()
def testReload(self):
self.generate_test_data()
self.assertTrue(self.classifier.train())
self.classifier.save_classifier()
classifier2 = DocumentClassifier()
classifier2.reload()
v1 = classifier2.classifier_version
# change the classifier after some time.
sleep(1)
self.classifier.save_classifier()
classifier2.reload()
v2 = classifier2.classifier_version
self.assertNotEqual(v1, v2)
@override_settings(DATA_DIR=tempfile.mkdtemp()) @override_settings(DATA_DIR=tempfile.mkdtemp())
def testSaveClassifier(self): def testSaveClassifier(self):
@ -124,10 +105,10 @@ class TestClassifier(DirectoriesMixin, TestCase):
self.classifier.train() self.classifier.train()
self.classifier.save_classifier() self.classifier.save()
new_classifier = DocumentClassifier() new_classifier = DocumentClassifier()
new_classifier.reload() new_classifier.load()
self.assertFalse(new_classifier.train()) self.assertFalse(new_classifier.train())
@override_settings(MODEL_FILE=os.path.join(os.path.dirname(__file__), "data", "model.pickle")) @override_settings(MODEL_FILE=os.path.join(os.path.dirname(__file__), "data", "model.pickle"))
@ -135,7 +116,7 @@ class TestClassifier(DirectoriesMixin, TestCase):
self.generate_test_data() self.generate_test_data()
new_classifier = DocumentClassifier() new_classifier = DocumentClassifier()
new_classifier.reload() new_classifier.load()
self.assertCountEqual(new_classifier.predict_tags(self.doc2.content), [45, 12]) self.assertCountEqual(new_classifier.predict_tags(self.doc2.content), [45, 12])
@ -252,25 +233,39 @@ class TestClassifier(DirectoriesMixin, TestCase):
self.assertFalse(os.path.exists(settings.MODEL_FILE)) self.assertFalse(os.path.exists(settings.MODEL_FILE))
self.assertIsNone(load_classifier()) self.assertIsNone(load_classifier())
@mock.patch("documents.classifier.DocumentClassifier.reload") @override_settings(CACHES={'default': {'BACKEND': 'django.core.cache.backends.dummy.DummyCache'}})
def test_load_classifier(self, reload): @mock.patch("documents.classifier.DocumentClassifier.load")
def test_load_classifier(self, load):
Path(settings.MODEL_FILE).touch() Path(settings.MODEL_FILE).touch()
self.assertIsNotNone(load_classifier()) self.assertIsNotNone(load_classifier())
load.assert_called_once()
@mock.patch("documents.classifier.DocumentClassifier.reload") @override_settings(CACHES={'default': {'BACKEND': 'django.core.cache.backends.locmem.LocMemCache'}})
def test_load_classifier_incompatible_version(self, reload): @override_settings(MODEL_FILE=os.path.join(os.path.dirname(__file__), "data", "model.pickle"))
def test_load_classifier_cached(self):
classifier = load_classifier()
self.assertIsNotNone(classifier)
with mock.patch("documents.classifier.DocumentClassifier.load") as load:
classifier2 = load_classifier()
load.assert_not_called()
@override_settings(CACHES={'default': {'BACKEND': 'django.core.cache.backends.dummy.DummyCache'}})
@mock.patch("documents.classifier.DocumentClassifier.load")
def test_load_classifier_incompatible_version(self, load):
Path(settings.MODEL_FILE).touch() Path(settings.MODEL_FILE).touch()
self.assertTrue(os.path.exists(settings.MODEL_FILE)) self.assertTrue(os.path.exists(settings.MODEL_FILE))
reload.side_effect = IncompatibleClassifierVersionError() load.side_effect = IncompatibleClassifierVersionError()
self.assertIsNone(load_classifier()) self.assertIsNone(load_classifier())
self.assertFalse(os.path.exists(settings.MODEL_FILE)) self.assertFalse(os.path.exists(settings.MODEL_FILE))
@mock.patch("documents.classifier.DocumentClassifier.reload") @override_settings(CACHES={'default': {'BACKEND': 'django.core.cache.backends.dummy.DummyCache'}})
def test_load_classifier_os_error(self, reload): @mock.patch("documents.classifier.DocumentClassifier.load")
def test_load_classifier_os_error(self, load):
Path(settings.MODEL_FILE).touch() Path(settings.MODEL_FILE).touch()
self.assertTrue(os.path.exists(settings.MODEL_FILE)) self.assertTrue(os.path.exists(settings.MODEL_FILE))
reload.side_effect = OSError() load.side_effect = OSError()
self.assertIsNone(load_classifier()) self.assertIsNone(load_classifier())
self.assertTrue(os.path.exists(settings.MODEL_FILE)) self.assertTrue(os.path.exists(settings.MODEL_FILE))

View File

@ -2,7 +2,7 @@ import os
from unittest import mock from unittest import mock
from django.conf import settings from django.conf import settings
from django.test import TestCase from django.test import TestCase, override_settings
from django.utils import timezone from django.utils import timezone
from documents import tasks from documents import tasks
@ -52,6 +52,7 @@ class TestTasks(DirectoriesMixin, TestCase):
load_classifier.assert_called_once() load_classifier.assert_called_once()
self.assertFalse(os.path.isfile(settings.MODEL_FILE)) self.assertFalse(os.path.isfile(settings.MODEL_FILE))
@override_settings(CACHES={'default': {'BACKEND': 'django.core.cache.backends.dummy.DummyCache'}})
def test_train_classifier(self): def test_train_classifier(self):
c = Correspondent.objects.create(matching_algorithm=Tag.MATCH_AUTO, name="test") c = Correspondent.objects.create(matching_algorithm=Tag.MATCH_AUTO, name="test")
doc = Document.objects.create(correspondent=c, content="test", title="test") doc = Document.objects.create(correspondent=c, content="test", title="test")