Fix: Catch new warning when loading the classifier (#5395)

This commit is contained in:
Trenton H 2024-01-14 13:21:17 -08:00 committed by GitHub
parent 16cc7415c1
commit 41a3c7c89b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 25 additions and 8 deletions

View File

@ -10,6 +10,7 @@ from pathlib import Path
from typing import Optional
from django.conf import settings
from sklearn.exceptions import InconsistentVersionWarning
from documents.models import Document
from documents.models import MatchingModel
@ -18,7 +19,9 @@ logger = logging.getLogger("paperless.classifier")
class IncompatibleClassifierVersionError(Exception):
pass
def __init__(self, message: str, *args: object) -> None:
self.message = message
super().__init__(*args)
class ClassifierModelCorruptError(Exception):
@ -37,8 +40,8 @@ def load_classifier() -> Optional["DocumentClassifier"]:
try:
classifier.load()
except IncompatibleClassifierVersionError:
logger.info("Classifier version updated, will re-train")
except IncompatibleClassifierVersionError as e:
logger.info(f"Classifier version incompatible: {e.message}, will re-train")
os.unlink(settings.MODEL_FILE)
classifier = None
except ClassifierModelCorruptError:
@ -114,10 +117,12 @@ class DocumentClassifier:
"#security-maintainability-limitations"
)
for warning in w:
if issubclass(warning.category, UserWarning):
w_msg = str(warning.message)
if sk_learn_warning_url in w_msg:
raise IncompatibleClassifierVersionError
# The warning is inconsistent, the MLPClassifier is a specific warning, others have not updated yet
if issubclass(warning.category, InconsistentVersionWarning) or (
issubclass(warning.category, UserWarning)
and sk_learn_warning_url in str(warning.message)
):
raise IncompatibleClassifierVersionError("sklearn version update")
def save(self):
target_file: Path = settings.MODEL_FILE

Binary file not shown.

View File

@ -1,5 +1,6 @@
import os
import re
import shutil
from pathlib import Path
from unittest import mock
@ -649,7 +650,7 @@ class TestClassifier(DirectoriesMixin, TestCase):
Path(settings.MODEL_FILE).touch()
self.assertTrue(os.path.exists(settings.MODEL_FILE))
load.side_effect = IncompatibleClassifierVersionError()
load.side_effect = IncompatibleClassifierVersionError("Dummey Error")
self.assertIsNone(load_classifier())
self.assertFalse(os.path.exists(settings.MODEL_FILE))
@ -661,3 +662,14 @@ class TestClassifier(DirectoriesMixin, TestCase):
load.side_effect = OSError()
self.assertIsNone(load_classifier())
self.assertTrue(os.path.exists(settings.MODEL_FILE))
def test_load_old_classifier_version(self):
shutil.copy(
os.path.join(os.path.dirname(__file__), "data", "v1.17.4.model.pickle"),
self.dirs.scratch_dir,
)
with override_settings(
MODEL_FILE=self.dirs.scratch_dir / "v1.17.4.model.pickle",
):
classifier = load_classifier()
self.assertIsNone(classifier)