Mirror of https://github.com/paperless-ngx/paperless-ngx.git (synced 2025-10-30 03:56:23 -05:00)
			
		
		
		
	classifier caching
This commit is contained in:
		| @@ -5,6 +5,7 @@ import pickle | ||||
| import re | ||||
|  | ||||
| from django.conf import settings | ||||
| from django.core.cache import cache | ||||
|  | ||||
| from documents.models import Document, MatchingModel | ||||
|  | ||||
| @@ -30,22 +31,28 @@ def load_classifier(): | ||||
|         ) | ||||
|         return None | ||||
|  | ||||
|     try: | ||||
|     version = os.stat(settings.MODEL_FILE).st_mtime | ||||
|  | ||||
|     classifier = cache.get("paperless-classifier", version=version) | ||||
|  | ||||
|     if not classifier: | ||||
|         classifier = DocumentClassifier() | ||||
|         classifier.reload() | ||||
|     except (EOFError, IncompatibleClassifierVersionError) as e: | ||||
|         # there's something wrong with the model file. | ||||
|         logger.error( | ||||
|             f"Unrecoverable error while loading document " | ||||
|             f"classification model: {str(e)}, deleting model file." | ||||
|         ) | ||||
|         os.unlink(settings.MODEL_FILE) | ||||
|         classifier = None | ||||
|     except OSError as e: | ||||
|         logger.error( | ||||
|             f"Error while loading document classification model: {str(e)}" | ||||
|         ) | ||||
|         classifier = None | ||||
|         try: | ||||
|             classifier.load() | ||||
|             cache.set("paperless-classifier", classifier, version=version) | ||||
|         except (EOFError, IncompatibleClassifierVersionError) as e: | ||||
|             # there's something wrong with the model file. | ||||
|             logger.error( | ||||
|                 f"Unrecoverable error while loading document " | ||||
|                 f"classification model: {str(e)}, deleting model file." | ||||
|             ) | ||||
|             os.unlink(settings.MODEL_FILE) | ||||
|             classifier = None | ||||
|         except OSError as e: | ||||
|             logger.error( | ||||
|                 f"Error while loading document classification model: {str(e)}" | ||||
|             ) | ||||
|             classifier = None | ||||
|  | ||||
|     return classifier | ||||
|  | ||||
| @@ -55,10 +62,6 @@ class DocumentClassifier(object): | ||||
|     FORMAT_VERSION = 6 | ||||
|  | ||||
|     def __init__(self): | ||||
|         # mtime of the model file on disk. used to prevent reloading when | ||||
|         # nothing has changed. | ||||
|         self.classifier_version = 0 | ||||
|  | ||||
|         # hash of the training data. used to prevent re-training when the | ||||
|         # training data has not changed. | ||||
|         self.data_hash = None | ||||
| @@ -69,30 +72,23 @@ class DocumentClassifier(object): | ||||
|         self.correspondent_classifier = None | ||||
|         self.document_type_classifier = None | ||||
|  | ||||
|     def reload(self): | ||||
|         if os.path.getmtime(settings.MODEL_FILE) > self.classifier_version: | ||||
|             with open(settings.MODEL_FILE, "rb") as f: | ||||
|                 schema_version = pickle.load(f) | ||||
|     def load(self): | ||||
|         with open(settings.MODEL_FILE, "rb") as f: | ||||
|             schema_version = pickle.load(f) | ||||
|  | ||||
|                 if schema_version != self.FORMAT_VERSION: | ||||
|                     raise IncompatibleClassifierVersionError( | ||||
|                         "Cannor load classifier, incompatible versions.") | ||||
|                 else: | ||||
|                     if self.classifier_version > 0: | ||||
|                         # Don't be confused by this check. It's simply here | ||||
|                         # so that we wont log anything on initial reload. | ||||
|                         logger.info("Classifier updated on disk, " | ||||
|                                     "reloading classifier models") | ||||
|                     self.data_hash = pickle.load(f) | ||||
|                     self.data_vectorizer = pickle.load(f) | ||||
|                     self.tags_binarizer = pickle.load(f) | ||||
|             if schema_version != self.FORMAT_VERSION: | ||||
|                 raise IncompatibleClassifierVersionError( | ||||
|                     "Cannor load classifier, incompatible versions.") | ||||
|             else: | ||||
|                 self.data_hash = pickle.load(f) | ||||
|                 self.data_vectorizer = pickle.load(f) | ||||
|                 self.tags_binarizer = pickle.load(f) | ||||
|  | ||||
|                     self.tags_classifier = pickle.load(f) | ||||
|                     self.correspondent_classifier = pickle.load(f) | ||||
|                     self.document_type_classifier = pickle.load(f) | ||||
|             self.classifier_version = os.path.getmtime(settings.MODEL_FILE) | ||||
|                 self.tags_classifier = pickle.load(f) | ||||
|                 self.correspondent_classifier = pickle.load(f) | ||||
|                 self.document_type_classifier = pickle.load(f) | ||||
|  | ||||
|     def save_classifier(self): | ||||
|     def save(self): | ||||
|         with open(settings.MODEL_FILE, "wb") as f: | ||||
|             pickle.dump(self.FORMAT_VERSION, f) | ||||
|             pickle.dump(self.data_hash, f) | ||||
|   | ||||
| @@ -52,7 +52,7 @@ def train_classifier(): | ||||
|                 "Saving updated classifier model to {}...".format( | ||||
|                     settings.MODEL_FILE) | ||||
|             ) | ||||
|             classifier.save_classifier() | ||||
|             classifier.save() | ||||
|         else: | ||||
|             logger.debug( | ||||
|                 "Training data unchanged." | ||||
|   | ||||
| @@ -1,7 +1,6 @@ | ||||
| import os | ||||
| import tempfile | ||||
| from pathlib import Path | ||||
| from time import sleep | ||||
| from unittest import mock | ||||
|  | ||||
| from django.conf import settings | ||||
| @@ -85,37 +84,19 @@ class TestClassifier(DirectoriesMixin, TestCase): | ||||
|         self.assertTrue(self.classifier.train()) | ||||
|         self.assertFalse(self.classifier.train()) | ||||
|  | ||||
|         self.classifier.save_classifier() | ||||
|         self.classifier.save() | ||||
|  | ||||
|         classifier2 = DocumentClassifier() | ||||
|  | ||||
|         current_ver = DocumentClassifier.FORMAT_VERSION | ||||
|         with mock.patch("documents.classifier.DocumentClassifier.FORMAT_VERSION", current_ver+1): | ||||
|             # assure that we won't load old classifiers. | ||||
|             self.assertRaises(IncompatibleClassifierVersionError, classifier2.reload) | ||||
|             self.assertRaises(IncompatibleClassifierVersionError, classifier2.load) | ||||
|  | ||||
|             self.classifier.save_classifier() | ||||
|             self.classifier.save() | ||||
|  | ||||
|             # assure that we can load the classifier after saving it. | ||||
|             classifier2.reload() | ||||
|  | ||||
|     def testReload(self): | ||||
|  | ||||
|         self.generate_test_data() | ||||
|         self.assertTrue(self.classifier.train()) | ||||
|         self.classifier.save_classifier() | ||||
|  | ||||
|         classifier2 = DocumentClassifier() | ||||
|         classifier2.reload() | ||||
|         v1 = classifier2.classifier_version | ||||
|  | ||||
|         # change the classifier after some time. | ||||
|         sleep(1) | ||||
|         self.classifier.save_classifier() | ||||
|  | ||||
|         classifier2.reload() | ||||
|         v2 = classifier2.classifier_version | ||||
|         self.assertNotEqual(v1, v2) | ||||
|             classifier2.load() | ||||
|  | ||||
|     @override_settings(DATA_DIR=tempfile.mkdtemp()) | ||||
|     def testSaveClassifier(self): | ||||
| @@ -124,10 +105,10 @@ class TestClassifier(DirectoriesMixin, TestCase): | ||||
|  | ||||
|         self.classifier.train() | ||||
|  | ||||
|         self.classifier.save_classifier() | ||||
|         self.classifier.save() | ||||
|  | ||||
|         new_classifier = DocumentClassifier() | ||||
|         new_classifier.reload() | ||||
|         new_classifier.load() | ||||
|         self.assertFalse(new_classifier.train()) | ||||
|  | ||||
|     @override_settings(MODEL_FILE=os.path.join(os.path.dirname(__file__), "data", "model.pickle")) | ||||
| @@ -135,7 +116,7 @@ class TestClassifier(DirectoriesMixin, TestCase): | ||||
|         self.generate_test_data() | ||||
|  | ||||
|         new_classifier = DocumentClassifier() | ||||
|         new_classifier.reload() | ||||
|         new_classifier.load() | ||||
|  | ||||
|         self.assertCountEqual(new_classifier.predict_tags(self.doc2.content), [45, 12]) | ||||
|  | ||||
| @@ -252,25 +233,39 @@ class TestClassifier(DirectoriesMixin, TestCase): | ||||
|         self.assertFalse(os.path.exists(settings.MODEL_FILE)) | ||||
|         self.assertIsNone(load_classifier()) | ||||
|  | ||||
|     @mock.patch("documents.classifier.DocumentClassifier.reload") | ||||
|     def test_load_classifier(self, reload): | ||||
|     @override_settings(CACHES={'default': {'BACKEND': 'django.core.cache.backends.dummy.DummyCache'}}) | ||||
|     @mock.patch("documents.classifier.DocumentClassifier.load") | ||||
|     def test_load_classifier(self, load): | ||||
|         Path(settings.MODEL_FILE).touch() | ||||
|         self.assertIsNotNone(load_classifier()) | ||||
|         load.assert_called_once() | ||||
|  | ||||
|     @mock.patch("documents.classifier.DocumentClassifier.reload") | ||||
|     def test_load_classifier_incompatible_version(self, reload): | ||||
|     @override_settings(CACHES={'default': {'BACKEND': 'django.core.cache.backends.locmem.LocMemCache'}}) | ||||
|     @override_settings(MODEL_FILE=os.path.join(os.path.dirname(__file__), "data", "model.pickle")) | ||||
|     def test_load_classifier_cached(self): | ||||
|         classifier = load_classifier() | ||||
|         self.assertIsNotNone(classifier) | ||||
|  | ||||
|         with mock.patch("documents.classifier.DocumentClassifier.load") as load: | ||||
|             classifier2 = load_classifier() | ||||
|             load.assert_not_called() | ||||
|  | ||||
|     @override_settings(CACHES={'default': {'BACKEND': 'django.core.cache.backends.dummy.DummyCache'}}) | ||||
|     @mock.patch("documents.classifier.DocumentClassifier.load") | ||||
|     def test_load_classifier_incompatible_version(self, load): | ||||
|         Path(settings.MODEL_FILE).touch() | ||||
|         self.assertTrue(os.path.exists(settings.MODEL_FILE)) | ||||
|  | ||||
|         reload.side_effect = IncompatibleClassifierVersionError() | ||||
|         load.side_effect = IncompatibleClassifierVersionError() | ||||
|         self.assertIsNone(load_classifier()) | ||||
|         self.assertFalse(os.path.exists(settings.MODEL_FILE)) | ||||
|  | ||||
|     @mock.patch("documents.classifier.DocumentClassifier.reload") | ||||
|     def test_load_classifier_os_error(self, reload): | ||||
|     @override_settings(CACHES={'default': {'BACKEND': 'django.core.cache.backends.dummy.DummyCache'}}) | ||||
|     @mock.patch("documents.classifier.DocumentClassifier.load") | ||||
|     def test_load_classifier_os_error(self, load): | ||||
|         Path(settings.MODEL_FILE).touch() | ||||
|         self.assertTrue(os.path.exists(settings.MODEL_FILE)) | ||||
|  | ||||
|         reload.side_effect = OSError() | ||||
|         load.side_effect = OSError() | ||||
|         self.assertIsNone(load_classifier()) | ||||
|         self.assertTrue(os.path.exists(settings.MODEL_FILE)) | ||||
|   | ||||
| @@ -2,7 +2,7 @@ import os | ||||
| from unittest import mock | ||||
|  | ||||
| from django.conf import settings | ||||
| from django.test import TestCase | ||||
| from django.test import TestCase, override_settings | ||||
| from django.utils import timezone | ||||
|  | ||||
| from documents import tasks | ||||
| @@ -52,6 +52,7 @@ class TestTasks(DirectoriesMixin, TestCase): | ||||
|         load_classifier.assert_called_once() | ||||
|         self.assertFalse(os.path.isfile(settings.MODEL_FILE)) | ||||
|  | ||||
|     @override_settings(CACHES={'default': {'BACKEND': 'django.core.cache.backends.dummy.DummyCache'}}) | ||||
|     def test_train_classifier(self): | ||||
|         c = Correspondent.objects.create(matching_algorithm=Tag.MATCH_AUTO, name="test") | ||||
|         doc = Document.objects.create(correspondent=c, content="test", title="test") | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 jonaswinkler
					jonaswinkler