mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-04-02 13:45:10 -05:00
classifier caching
This commit is contained in:
parent
0024c2aae4
commit
ffe96c8fff
@ -5,6 +5,7 @@ import pickle
|
|||||||
import re
|
import re
|
||||||
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
|
from django.core.cache import cache
|
||||||
|
|
||||||
from documents.models import Document, MatchingModel
|
from documents.models import Document, MatchingModel
|
||||||
|
|
||||||
@ -30,22 +31,28 @@ def load_classifier():
|
|||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
version = os.stat(settings.MODEL_FILE).st_mtime
|
||||||
|
|
||||||
|
classifier = cache.get("paperless-classifier", version=version)
|
||||||
|
|
||||||
|
if not classifier:
|
||||||
classifier = DocumentClassifier()
|
classifier = DocumentClassifier()
|
||||||
classifier.reload()
|
try:
|
||||||
except (EOFError, IncompatibleClassifierVersionError) as e:
|
classifier.load()
|
||||||
# there's something wrong with the model file.
|
cache.set("paperless-classifier", classifier, version=version)
|
||||||
logger.error(
|
except (EOFError, IncompatibleClassifierVersionError) as e:
|
||||||
f"Unrecoverable error while loading document "
|
# there's something wrong with the model file.
|
||||||
f"classification model: {str(e)}, deleting model file."
|
logger.error(
|
||||||
)
|
f"Unrecoverable error while loading document "
|
||||||
os.unlink(settings.MODEL_FILE)
|
f"classification model: {str(e)}, deleting model file."
|
||||||
classifier = None
|
)
|
||||||
except OSError as e:
|
os.unlink(settings.MODEL_FILE)
|
||||||
logger.error(
|
classifier = None
|
||||||
f"Error while loading document classification model: {str(e)}"
|
except OSError as e:
|
||||||
)
|
logger.error(
|
||||||
classifier = None
|
f"Error while loading document classification model: {str(e)}"
|
||||||
|
)
|
||||||
|
classifier = None
|
||||||
|
|
||||||
return classifier
|
return classifier
|
||||||
|
|
||||||
@ -55,10 +62,6 @@ class DocumentClassifier(object):
|
|||||||
FORMAT_VERSION = 6
|
FORMAT_VERSION = 6
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# mtime of the model file on disk. used to prevent reloading when
|
|
||||||
# nothing has changed.
|
|
||||||
self.classifier_version = 0
|
|
||||||
|
|
||||||
# hash of the training data. used to prevent re-training when the
|
# hash of the training data. used to prevent re-training when the
|
||||||
# training data has not changed.
|
# training data has not changed.
|
||||||
self.data_hash = None
|
self.data_hash = None
|
||||||
@ -69,30 +72,23 @@ class DocumentClassifier(object):
|
|||||||
self.correspondent_classifier = None
|
self.correspondent_classifier = None
|
||||||
self.document_type_classifier = None
|
self.document_type_classifier = None
|
||||||
|
|
||||||
def reload(self):
|
def load(self):
|
||||||
if os.path.getmtime(settings.MODEL_FILE) > self.classifier_version:
|
with open(settings.MODEL_FILE, "rb") as f:
|
||||||
with open(settings.MODEL_FILE, "rb") as f:
|
schema_version = pickle.load(f)
|
||||||
schema_version = pickle.load(f)
|
|
||||||
|
|
||||||
if schema_version != self.FORMAT_VERSION:
|
if schema_version != self.FORMAT_VERSION:
|
||||||
raise IncompatibleClassifierVersionError(
|
raise IncompatibleClassifierVersionError(
|
||||||
"Cannor load classifier, incompatible versions.")
|
"Cannor load classifier, incompatible versions.")
|
||||||
else:
|
else:
|
||||||
if self.classifier_version > 0:
|
self.data_hash = pickle.load(f)
|
||||||
# Don't be confused by this check. It's simply here
|
self.data_vectorizer = pickle.load(f)
|
||||||
# so that we wont log anything on initial reload.
|
self.tags_binarizer = pickle.load(f)
|
||||||
logger.info("Classifier updated on disk, "
|
|
||||||
"reloading classifier models")
|
|
||||||
self.data_hash = pickle.load(f)
|
|
||||||
self.data_vectorizer = pickle.load(f)
|
|
||||||
self.tags_binarizer = pickle.load(f)
|
|
||||||
|
|
||||||
self.tags_classifier = pickle.load(f)
|
self.tags_classifier = pickle.load(f)
|
||||||
self.correspondent_classifier = pickle.load(f)
|
self.correspondent_classifier = pickle.load(f)
|
||||||
self.document_type_classifier = pickle.load(f)
|
self.document_type_classifier = pickle.load(f)
|
||||||
self.classifier_version = os.path.getmtime(settings.MODEL_FILE)
|
|
||||||
|
|
||||||
def save_classifier(self):
|
def save(self):
|
||||||
with open(settings.MODEL_FILE, "wb") as f:
|
with open(settings.MODEL_FILE, "wb") as f:
|
||||||
pickle.dump(self.FORMAT_VERSION, f)
|
pickle.dump(self.FORMAT_VERSION, f)
|
||||||
pickle.dump(self.data_hash, f)
|
pickle.dump(self.data_hash, f)
|
||||||
|
@ -52,7 +52,7 @@ def train_classifier():
|
|||||||
"Saving updated classifier model to {}...".format(
|
"Saving updated classifier model to {}...".format(
|
||||||
settings.MODEL_FILE)
|
settings.MODEL_FILE)
|
||||||
)
|
)
|
||||||
classifier.save_classifier()
|
classifier.save()
|
||||||
else:
|
else:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Training data unchanged."
|
"Training data unchanged."
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from time import sleep
|
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
@ -85,37 +84,19 @@ class TestClassifier(DirectoriesMixin, TestCase):
|
|||||||
self.assertTrue(self.classifier.train())
|
self.assertTrue(self.classifier.train())
|
||||||
self.assertFalse(self.classifier.train())
|
self.assertFalse(self.classifier.train())
|
||||||
|
|
||||||
self.classifier.save_classifier()
|
self.classifier.save()
|
||||||
|
|
||||||
classifier2 = DocumentClassifier()
|
classifier2 = DocumentClassifier()
|
||||||
|
|
||||||
current_ver = DocumentClassifier.FORMAT_VERSION
|
current_ver = DocumentClassifier.FORMAT_VERSION
|
||||||
with mock.patch("documents.classifier.DocumentClassifier.FORMAT_VERSION", current_ver+1):
|
with mock.patch("documents.classifier.DocumentClassifier.FORMAT_VERSION", current_ver+1):
|
||||||
# assure that we won't load old classifiers.
|
# assure that we won't load old classifiers.
|
||||||
self.assertRaises(IncompatibleClassifierVersionError, classifier2.reload)
|
self.assertRaises(IncompatibleClassifierVersionError, classifier2.load)
|
||||||
|
|
||||||
self.classifier.save_classifier()
|
self.classifier.save()
|
||||||
|
|
||||||
# assure that we can load the classifier after saving it.
|
# assure that we can load the classifier after saving it.
|
||||||
classifier2.reload()
|
classifier2.load()
|
||||||
|
|
||||||
def testReload(self):
|
|
||||||
|
|
||||||
self.generate_test_data()
|
|
||||||
self.assertTrue(self.classifier.train())
|
|
||||||
self.classifier.save_classifier()
|
|
||||||
|
|
||||||
classifier2 = DocumentClassifier()
|
|
||||||
classifier2.reload()
|
|
||||||
v1 = classifier2.classifier_version
|
|
||||||
|
|
||||||
# change the classifier after some time.
|
|
||||||
sleep(1)
|
|
||||||
self.classifier.save_classifier()
|
|
||||||
|
|
||||||
classifier2.reload()
|
|
||||||
v2 = classifier2.classifier_version
|
|
||||||
self.assertNotEqual(v1, v2)
|
|
||||||
|
|
||||||
@override_settings(DATA_DIR=tempfile.mkdtemp())
|
@override_settings(DATA_DIR=tempfile.mkdtemp())
|
||||||
def testSaveClassifier(self):
|
def testSaveClassifier(self):
|
||||||
@ -124,10 +105,10 @@ class TestClassifier(DirectoriesMixin, TestCase):
|
|||||||
|
|
||||||
self.classifier.train()
|
self.classifier.train()
|
||||||
|
|
||||||
self.classifier.save_classifier()
|
self.classifier.save()
|
||||||
|
|
||||||
new_classifier = DocumentClassifier()
|
new_classifier = DocumentClassifier()
|
||||||
new_classifier.reload()
|
new_classifier.load()
|
||||||
self.assertFalse(new_classifier.train())
|
self.assertFalse(new_classifier.train())
|
||||||
|
|
||||||
@override_settings(MODEL_FILE=os.path.join(os.path.dirname(__file__), "data", "model.pickle"))
|
@override_settings(MODEL_FILE=os.path.join(os.path.dirname(__file__), "data", "model.pickle"))
|
||||||
@ -135,7 +116,7 @@ class TestClassifier(DirectoriesMixin, TestCase):
|
|||||||
self.generate_test_data()
|
self.generate_test_data()
|
||||||
|
|
||||||
new_classifier = DocumentClassifier()
|
new_classifier = DocumentClassifier()
|
||||||
new_classifier.reload()
|
new_classifier.load()
|
||||||
|
|
||||||
self.assertCountEqual(new_classifier.predict_tags(self.doc2.content), [45, 12])
|
self.assertCountEqual(new_classifier.predict_tags(self.doc2.content), [45, 12])
|
||||||
|
|
||||||
@ -252,25 +233,39 @@ class TestClassifier(DirectoriesMixin, TestCase):
|
|||||||
self.assertFalse(os.path.exists(settings.MODEL_FILE))
|
self.assertFalse(os.path.exists(settings.MODEL_FILE))
|
||||||
self.assertIsNone(load_classifier())
|
self.assertIsNone(load_classifier())
|
||||||
|
|
||||||
@mock.patch("documents.classifier.DocumentClassifier.reload")
|
@override_settings(CACHES={'default': {'BACKEND': 'django.core.cache.backends.dummy.DummyCache'}})
|
||||||
def test_load_classifier(self, reload):
|
@mock.patch("documents.classifier.DocumentClassifier.load")
|
||||||
|
def test_load_classifier(self, load):
|
||||||
Path(settings.MODEL_FILE).touch()
|
Path(settings.MODEL_FILE).touch()
|
||||||
self.assertIsNotNone(load_classifier())
|
self.assertIsNotNone(load_classifier())
|
||||||
|
load.assert_called_once()
|
||||||
|
|
||||||
@mock.patch("documents.classifier.DocumentClassifier.reload")
|
@override_settings(CACHES={'default': {'BACKEND': 'django.core.cache.backends.locmem.LocMemCache'}})
|
||||||
def test_load_classifier_incompatible_version(self, reload):
|
@override_settings(MODEL_FILE=os.path.join(os.path.dirname(__file__), "data", "model.pickle"))
|
||||||
|
def test_load_classifier_cached(self):
|
||||||
|
classifier = load_classifier()
|
||||||
|
self.assertIsNotNone(classifier)
|
||||||
|
|
||||||
|
with mock.patch("documents.classifier.DocumentClassifier.load") as load:
|
||||||
|
classifier2 = load_classifier()
|
||||||
|
load.assert_not_called()
|
||||||
|
|
||||||
|
@override_settings(CACHES={'default': {'BACKEND': 'django.core.cache.backends.dummy.DummyCache'}})
|
||||||
|
@mock.patch("documents.classifier.DocumentClassifier.load")
|
||||||
|
def test_load_classifier_incompatible_version(self, load):
|
||||||
Path(settings.MODEL_FILE).touch()
|
Path(settings.MODEL_FILE).touch()
|
||||||
self.assertTrue(os.path.exists(settings.MODEL_FILE))
|
self.assertTrue(os.path.exists(settings.MODEL_FILE))
|
||||||
|
|
||||||
reload.side_effect = IncompatibleClassifierVersionError()
|
load.side_effect = IncompatibleClassifierVersionError()
|
||||||
self.assertIsNone(load_classifier())
|
self.assertIsNone(load_classifier())
|
||||||
self.assertFalse(os.path.exists(settings.MODEL_FILE))
|
self.assertFalse(os.path.exists(settings.MODEL_FILE))
|
||||||
|
|
||||||
@mock.patch("documents.classifier.DocumentClassifier.reload")
|
@override_settings(CACHES={'default': {'BACKEND': 'django.core.cache.backends.dummy.DummyCache'}})
|
||||||
def test_load_classifier_os_error(self, reload):
|
@mock.patch("documents.classifier.DocumentClassifier.load")
|
||||||
|
def test_load_classifier_os_error(self, load):
|
||||||
Path(settings.MODEL_FILE).touch()
|
Path(settings.MODEL_FILE).touch()
|
||||||
self.assertTrue(os.path.exists(settings.MODEL_FILE))
|
self.assertTrue(os.path.exists(settings.MODEL_FILE))
|
||||||
|
|
||||||
reload.side_effect = OSError()
|
load.side_effect = OSError()
|
||||||
self.assertIsNone(load_classifier())
|
self.assertIsNone(load_classifier())
|
||||||
self.assertTrue(os.path.exists(settings.MODEL_FILE))
|
self.assertTrue(os.path.exists(settings.MODEL_FILE))
|
||||||
|
@ -2,7 +2,7 @@ import os
|
|||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.test import TestCase
|
from django.test import TestCase, override_settings
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
|
|
||||||
from documents import tasks
|
from documents import tasks
|
||||||
@ -52,6 +52,7 @@ class TestTasks(DirectoriesMixin, TestCase):
|
|||||||
load_classifier.assert_called_once()
|
load_classifier.assert_called_once()
|
||||||
self.assertFalse(os.path.isfile(settings.MODEL_FILE))
|
self.assertFalse(os.path.isfile(settings.MODEL_FILE))
|
||||||
|
|
||||||
|
@override_settings(CACHES={'default': {'BACKEND': 'django.core.cache.backends.dummy.DummyCache'}})
|
||||||
def test_train_classifier(self):
|
def test_train_classifier(self):
|
||||||
c = Correspondent.objects.create(matching_algorithm=Tag.MATCH_AUTO, name="test")
|
c = Correspondent.objects.create(matching_algorithm=Tag.MATCH_AUTO, name="test")
|
||||||
doc = Document.objects.create(correspondent=c, content="test", title="test")
|
doc = Document.objects.create(correspondent=c, content="test", title="test")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user