diff --git a/src/documents/classifier.py b/src/documents/classifier.py index 4bae1830b..1c2ccea07 100644 --- a/src/documents/classifier.py +++ b/src/documents/classifier.py @@ -4,6 +4,8 @@ import os import pickle import re import shutil +import warnings +from typing import Optional from django.conf import settings from documents.models import Document @@ -21,13 +23,13 @@ class ClassifierModelCorruptError(Exception): logger = logging.getLogger("paperless.classifier") -def preprocess_content(content): +def preprocess_content(content: str) -> str: content = content.lower().strip() content = re.sub(r"\s+", " ", content) return content -def load_classifier(): +def load_classifier() -> Optional["DocumentClassifier"]: if not os.path.isfile(settings.MODEL_FILE): logger.debug( "Document classification model does not exist (yet), not " @@ -39,7 +41,11 @@ def load_classifier(): try: classifier.load() - except (ClassifierModelCorruptError, IncompatibleClassifierVersionError): + except IncompatibleClassifierVersionError: + logger.info("Classifier version updated, will re-train") + os.unlink(settings.MODEL_FILE) + classifier = None + except ClassifierModelCorruptError: # there's something wrong with the model file. logger.exception( "Unrecoverable error while loading document " @@ -59,13 +65,14 @@ def load_classifier(): class DocumentClassifier: + # v7 - Updated scikit-learn package version # v8 - Added storage path classifier FORMAT_VERSION = 8 def __init__(self): # hash of the training data. used to prevent re-training when the # training data has not changed. - self.data_hash = None + self.data_hash: Optional[bytes] = None self.data_vectorizer = None self.tags_binarizer = None @@ -75,25 +82,41 @@ class DocumentClassifier: self.storage_path_classifier = None def load(self): - with open(settings.MODEL_FILE, "rb") as f: - schema_version = pickle.load(f) + # Catch warnings for processing + with warnings.catch_warnings(record=True) as w: + with open(settings.MODEL_FILE, "rb") as f: + schema_version = pickle.load(f) - if schema_version != self.FORMAT_VERSION: - raise IncompatibleClassifierVersionError( - "Cannot load classifier, incompatible versions.", + if schema_version != self.FORMAT_VERSION: + raise IncompatibleClassifierVersionError( + "Cannot load classifier, incompatible versions.", + ) + else: + try: + self.data_hash = pickle.load(f) + self.data_vectorizer = pickle.load(f) + self.tags_binarizer = pickle.load(f) + + self.tags_classifier = pickle.load(f) + self.correspondent_classifier = pickle.load(f) + self.document_type_classifier = pickle.load(f) + self.storage_path_classifier = pickle.load(f) + except Exception: + raise ClassifierModelCorruptError() + + # Check for the warning about unpickling from differing versions + # and consider it incompatible + if len(w) > 0: + sk_learn_warning_url = ( + "https://scikit-learn.org/stable/" + "model_persistence.html" + "#security-maintainability-limitations" ) - else: - try: - self.data_hash = pickle.load(f) - self.data_vectorizer = pickle.load(f) - self.tags_binarizer = pickle.load(f) - - self.tags_classifier = pickle.load(f) - self.correspondent_classifier = pickle.load(f) - self.document_type_classifier = pickle.load(f) - self.storage_path_classifier = pickle.load(f) - except Exception: - raise ClassifierModelCorruptError() + for warning in w: + if issubclass(warning.category, UserWarning): + w_msg = str(warning.message) + if sk_learn_warning_url in w_msg: + raise IncompatibleClassifierVersionError() def save(self): target_file = settings.MODEL_FILE diff --git a/src/documents/tests/data/model.pickle b/src/documents/tests/data/model.pickle index 8a0e1829c..ff88b8894 100644 Binary files a/src/documents/tests/data/model.pickle and b/src/documents/tests/data/model.pickle differ diff --git a/src/documents/tests/data/v1.0.2.model.pickle b/src/documents/tests/data/v1.0.2.model.pickle new file mode 100644 index 000000000..8a0e1829c Binary files /dev/null and b/src/documents/tests/data/v1.0.2.model.pickle differ diff --git a/src/documents/tests/test_classifier.py b/src/documents/tests/test_classifier.py index dcc503f97..cfa662c02 100644 --- a/src/documents/tests/test_classifier.py +++ b/src/documents/tests/test_classifier.py @@ -3,10 +3,12 @@ import tempfile from pathlib import Path from unittest import mock +import documents import pytest from django.conf import settings from django.test import override_settings from django.test import TestCase +from documents.classifier import ClassifierModelCorruptError from documents.classifier import DocumentClassifier from documents.classifier import IncompatibleClassifierVersionError from documents.classifier import load_classifier @@ -216,6 +218,45 @@ class TestClassifier(DirectoriesMixin, TestCase): self.assertCountEqual(new_classifier.predict_tags(self.doc2.content), [45, 12]) + @override_settings( + MODEL_FILE=os.path.join(os.path.dirname(__file__), "data", "model.pickle"), + ) + @mock.patch("documents.classifier.pickle.load") + def test_load_corrupt_file(self, patched_pickle_load): + """ + GIVEN: + - Corrupted classifier pickle file + WHEN: + - An attempt is made to load the classifier + THEN: + - The ClassifierModelCorruptError is raised + """ + # First load is the schema version + patched_pickle_load.side_effect = [DocumentClassifier.FORMAT_VERSION, OSError()] + + with self.assertRaises(ClassifierModelCorruptError): + self.classifier.load() + + @override_settings( + MODEL_FILE=os.path.join( + os.path.dirname(__file__), + "data", + "v1.0.2.model.pickle", + ), + ) + def test_load_new_scikit_learn_version(self): + """ + GIVEN: + - classifier pickle file created with a different scikit-learn version + WHEN: + - An attempt is made to load the classifier + THEN: + - The classifier reports the warning was captured and processed + """ + + with self.assertRaises(IncompatibleClassifierVersionError): + self.classifier.load() + def test_one_correspondent_predict(self): c1 = Correspondent.objects.create( name="c1",