From ffe96c8fffa8ca22c5911ad82aea48aa77d13cc1 Mon Sep 17 00:00:00 2001 From: jonaswinkler Date: Sat, 6 Feb 2021 20:54:58 +0100 Subject: [PATCH] classifier caching --- src/documents/classifier.py | 76 ++++++++++++-------------- src/documents/tasks.py | 2 +- src/documents/tests/test_classifier.py | 63 ++++++++++----------- src/documents/tests/test_tasks.py | 3 +- 4 files changed, 68 insertions(+), 76 deletions(-) diff --git a/src/documents/classifier.py b/src/documents/classifier.py index 2acebe04a..5151d453c 100755 --- a/src/documents/classifier.py +++ b/src/documents/classifier.py @@ -5,6 +5,7 @@ import pickle import re from django.conf import settings +from django.core.cache import cache from documents.models import Document, MatchingModel @@ -30,22 +31,28 @@ def load_classifier(): ) return None - try: + version = os.stat(settings.MODEL_FILE).st_mtime + + classifier = cache.get("paperless-classifier", version=version) + + if not classifier: classifier = DocumentClassifier() - classifier.reload() - except (EOFError, IncompatibleClassifierVersionError) as e: - # there's something wrong with the model file. - logger.error( - f"Unrecoverable error while loading document " - f"classification model: {str(e)}, deleting model file." - ) - os.unlink(settings.MODEL_FILE) - classifier = None - except OSError as e: - logger.error( - f"Error while loading document classification model: {str(e)}" - ) - classifier = None + try: + classifier.load() + cache.set("paperless-classifier", classifier, version=version) + except (EOFError, IncompatibleClassifierVersionError) as e: + # there's something wrong with the model file. + logger.error( + f"Unrecoverable error while loading document " + f"classification model: {str(e)}, deleting model file." + ) + os.unlink(settings.MODEL_FILE) + classifier = None + except OSError as e: + logger.error( + f"Error while loading document classification model: {str(e)}" + ) + classifier = None return classifier @@ -55,10 +62,6 @@ class DocumentClassifier(object): FORMAT_VERSION = 6 def __init__(self): - # mtime of the model file on disk. used to prevent reloading when - # nothing has changed. - self.classifier_version = 0 - # hash of the training data. used to prevent re-training when the # training data has not changed. self.data_hash = None @@ -69,30 +72,23 @@ class DocumentClassifier(object): self.correspondent_classifier = None self.document_type_classifier = None - def reload(self): - if os.path.getmtime(settings.MODEL_FILE) > self.classifier_version: - with open(settings.MODEL_FILE, "rb") as f: - schema_version = pickle.load(f) + def load(self): + with open(settings.MODEL_FILE, "rb") as f: + schema_version = pickle.load(f) - if schema_version != self.FORMAT_VERSION: - raise IncompatibleClassifierVersionError( - "Cannor load classifier, incompatible versions.") - else: - if self.classifier_version > 0: - # Don't be confused by this check. It's simply here - # so that we wont log anything on initial reload. - logger.info("Classifier updated on disk, " - "reloading classifier models") - self.data_hash = pickle.load(f) - self.data_vectorizer = pickle.load(f) - self.tags_binarizer = pickle.load(f) + if schema_version != self.FORMAT_VERSION: + raise IncompatibleClassifierVersionError( + "Cannor load classifier, incompatible versions.") + else: + self.data_hash = pickle.load(f) + self.data_vectorizer = pickle.load(f) + self.tags_binarizer = pickle.load(f) - self.tags_classifier = pickle.load(f) - self.correspondent_classifier = pickle.load(f) - self.document_type_classifier = pickle.load(f) - self.classifier_version = os.path.getmtime(settings.MODEL_FILE) + self.tags_classifier = pickle.load(f) + self.correspondent_classifier = pickle.load(f) + self.document_type_classifier = pickle.load(f) - def save_classifier(self): + def save(self): with open(settings.MODEL_FILE, "wb") as f: pickle.dump(self.FORMAT_VERSION, f) pickle.dump(self.data_hash, f) diff --git a/src/documents/tasks.py b/src/documents/tasks.py index 8c7d91585..f74a3d420 100644 --- a/src/documents/tasks.py +++ b/src/documents/tasks.py @@ -52,7 +52,7 @@ def train_classifier(): "Saving updated classifier model to {}...".format( settings.MODEL_FILE) ) - classifier.save_classifier() + classifier.save() else: logger.debug( "Training data unchanged." diff --git a/src/documents/tests/test_classifier.py b/src/documents/tests/test_classifier.py index 14673ae65..1efe564d1 100644 --- a/src/documents/tests/test_classifier.py +++ b/src/documents/tests/test_classifier.py @@ -1,7 +1,6 @@ import os import tempfile from pathlib import Path -from time import sleep from unittest import mock from django.conf import settings @@ -85,37 +84,19 @@ class TestClassifier(DirectoriesMixin, TestCase): self.assertTrue(self.classifier.train()) self.assertFalse(self.classifier.train()) - self.classifier.save_classifier() + self.classifier.save() classifier2 = DocumentClassifier() current_ver = DocumentClassifier.FORMAT_VERSION with mock.patch("documents.classifier.DocumentClassifier.FORMAT_VERSION", current_ver+1): # assure that we won't load old classifiers. - self.assertRaises(IncompatibleClassifierVersionError, classifier2.reload) + self.assertRaises(IncompatibleClassifierVersionError, classifier2.load) - self.classifier.save_classifier() + self.classifier.save() # assure that we can load the classifier after saving it. - classifier2.reload() - - def testReload(self): - - self.generate_test_data() - self.assertTrue(self.classifier.train()) - self.classifier.save_classifier() - - classifier2 = DocumentClassifier() - classifier2.reload() - v1 = classifier2.classifier_version - - # change the classifier after some time. - sleep(1) - self.classifier.save_classifier() - - classifier2.reload() - v2 = classifier2.classifier_version - self.assertNotEqual(v1, v2) + classifier2.load() @override_settings(DATA_DIR=tempfile.mkdtemp()) def testSaveClassifier(self): @@ -124,10 +105,10 @@ class TestClassifier(DirectoriesMixin, TestCase): self.classifier.train() - self.classifier.save_classifier() + self.classifier.save() new_classifier = DocumentClassifier() - new_classifier.reload() + new_classifier.load() self.assertFalse(new_classifier.train()) @override_settings(MODEL_FILE=os.path.join(os.path.dirname(__file__), "data", "model.pickle")) @@ -135,7 +116,7 @@ class TestClassifier(DirectoriesMixin, TestCase): self.generate_test_data() new_classifier = DocumentClassifier() - new_classifier.reload() + new_classifier.load() self.assertCountEqual(new_classifier.predict_tags(self.doc2.content), [45, 12]) @@ -252,25 +233,39 @@ class TestClassifier(DirectoriesMixin, TestCase): self.assertFalse(os.path.exists(settings.MODEL_FILE)) self.assertIsNone(load_classifier()) - @mock.patch("documents.classifier.DocumentClassifier.reload") - def test_load_classifier(self, reload): + @override_settings(CACHES={'default': {'BACKEND': 'django.core.cache.backends.dummy.DummyCache'}}) + @mock.patch("documents.classifier.DocumentClassifier.load") + def test_load_classifier(self, load): Path(settings.MODEL_FILE).touch() self.assertIsNotNone(load_classifier()) + load.assert_called_once() - @mock.patch("documents.classifier.DocumentClassifier.reload") - def test_load_classifier_incompatible_version(self, reload): + @override_settings(CACHES={'default': {'BACKEND': 'django.core.cache.backends.locmem.LocMemCache'}}) + @override_settings(MODEL_FILE=os.path.join(os.path.dirname(__file__), "data", "model.pickle")) + def test_load_classifier_cached(self): + classifier = load_classifier() + self.assertIsNotNone(classifier) + + with mock.patch("documents.classifier.DocumentClassifier.load") as load: + classifier2 = load_classifier() + load.assert_not_called() + + @override_settings(CACHES={'default': {'BACKEND': 'django.core.cache.backends.dummy.DummyCache'}}) + @mock.patch("documents.classifier.DocumentClassifier.load") + def test_load_classifier_incompatible_version(self, load): Path(settings.MODEL_FILE).touch() self.assertTrue(os.path.exists(settings.MODEL_FILE)) - reload.side_effect = IncompatibleClassifierVersionError() + load.side_effect = IncompatibleClassifierVersionError() self.assertIsNone(load_classifier()) self.assertFalse(os.path.exists(settings.MODEL_FILE)) - @mock.patch("documents.classifier.DocumentClassifier.reload") - def test_load_classifier_os_error(self, reload): + @override_settings(CACHES={'default': {'BACKEND': 'django.core.cache.backends.dummy.DummyCache'}}) + @mock.patch("documents.classifier.DocumentClassifier.load") + def test_load_classifier_os_error(self, load): Path(settings.MODEL_FILE).touch() self.assertTrue(os.path.exists(settings.MODEL_FILE)) - reload.side_effect = OSError() + load.side_effect = OSError() self.assertIsNone(load_classifier()) self.assertTrue(os.path.exists(settings.MODEL_FILE)) diff --git a/src/documents/tests/test_tasks.py b/src/documents/tests/test_tasks.py index d008f995a..ed280441f 100644 --- a/src/documents/tests/test_tasks.py +++ b/src/documents/tests/test_tasks.py @@ -2,7 +2,7 @@ import os from unittest import mock from django.conf import settings -from django.test import TestCase +from django.test import TestCase, override_settings from django.utils import timezone from documents import tasks @@ -52,6 +52,7 @@ class TestTasks(DirectoriesMixin, TestCase): load_classifier.assert_called_once() self.assertFalse(os.path.isfile(settings.MODEL_FILE)) + @override_settings(CACHES={'default': {'BACKEND': 'django.core.cache.backends.dummy.DummyCache'}}) def test_train_classifier(self): c = Correspondent.objects.create(matching_algorithm=Tag.MATCH_AUTO, name="test") doc = Document.objects.create(correspondent=c, content="test", title="test")