diff --git a/src/documents/classifier.py b/src/documents/classifier.py index cc8af5868..52af2733a 100644 --- a/src/documents/classifier.py +++ b/src/documents/classifier.py @@ -10,6 +10,7 @@ from pathlib import Path from typing import Optional from django.conf import settings +from sklearn.exceptions import InconsistentVersionWarning from documents.models import Document from documents.models import MatchingModel @@ -18,7 +19,9 @@ logger = logging.getLogger("paperless.classifier") class IncompatibleClassifierVersionError(Exception): - pass + def __init__(self, message: str, *args: object) -> None: + self.message = message + super().__init__(*args) class ClassifierModelCorruptError(Exception): @@ -37,8 +40,8 @@ def load_classifier() -> Optional["DocumentClassifier"]: try: classifier.load() - except IncompatibleClassifierVersionError: - logger.info("Classifier version updated, will re-train") + except IncompatibleClassifierVersionError as e: + logger.info(f"Classifier version incompatible: {e.message}, will re-train") os.unlink(settings.MODEL_FILE) classifier = None except ClassifierModelCorruptError: @@ -114,10 +117,12 @@ class DocumentClassifier: "#security-maintainability-limitations" ) for warning in w: - if issubclass(warning.category, UserWarning): - w_msg = str(warning.message) - if sk_learn_warning_url in w_msg: - raise IncompatibleClassifierVersionError + # The warning is inconsistent, the MLPClassifier is a specific warning, others have not updated yet + if issubclass(warning.category, InconsistentVersionWarning) or ( + issubclass(warning.category, UserWarning) + and sk_learn_warning_url in str(warning.message) + ): + raise IncompatibleClassifierVersionError("sklearn version update") def save(self): target_file: Path = settings.MODEL_FILE diff --git a/src/documents/tests/data/v1.0.2.model.pickle b/src/documents/tests/data/v1.0.2.model.pickle deleted file mode 100644 index 8a0e1829c..000000000 Binary files a/src/documents/tests/data/v1.0.2.model.pickle and /dev/null differ diff --git a/src/documents/tests/data/v1.17.4.model.pickle b/src/documents/tests/data/v1.17.4.model.pickle new file mode 100644 index 000000000..4b2734607 Binary files /dev/null and b/src/documents/tests/data/v1.17.4.model.pickle differ diff --git a/src/documents/tests/test_classifier.py b/src/documents/tests/test_classifier.py index 0b91e223f..cb1c5c8a3 100644 --- a/src/documents/tests/test_classifier.py +++ b/src/documents/tests/test_classifier.py @@ -1,5 +1,6 @@ import os import re +import shutil from pathlib import Path from unittest import mock @@ -649,7 +650,7 @@ class TestClassifier(DirectoriesMixin, TestCase): Path(settings.MODEL_FILE).touch() self.assertTrue(os.path.exists(settings.MODEL_FILE)) - load.side_effect = IncompatibleClassifierVersionError() + load.side_effect = IncompatibleClassifierVersionError("Dummey Error") self.assertIsNone(load_classifier()) self.assertFalse(os.path.exists(settings.MODEL_FILE)) @@ -661,3 +662,14 @@ class TestClassifier(DirectoriesMixin, TestCase): load.side_effect = OSError() self.assertIsNone(load_classifier()) self.assertTrue(os.path.exists(settings.MODEL_FILE)) + + def test_load_old_classifier_version(self): + shutil.copy( + os.path.join(os.path.dirname(__file__), "data", "v1.17.4.model.pickle"), + self.dirs.scratch_dir, + ) + with override_settings( + MODEL_FILE=self.dirs.scratch_dir / "v1.17.4.model.pickle", + ): + classifier = load_classifier() + self.assertIsNone(classifier)