Fix: Catch new warning when loading the classifier (#5395)

This commit is contained in:
Trenton H 2024-01-14 13:21:17 -08:00 committed by GitHub
parent 16cc7415c1
commit 41a3c7c89b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 25 additions and 8 deletions

View File

@ -10,6 +10,7 @@ from pathlib import Path
from typing import Optional from typing import Optional
from django.conf import settings from django.conf import settings
from sklearn.exceptions import InconsistentVersionWarning
from documents.models import Document from documents.models import Document
from documents.models import MatchingModel from documents.models import MatchingModel
@ -18,7 +19,9 @@ logger = logging.getLogger("paperless.classifier")
class IncompatibleClassifierVersionError(Exception): class IncompatibleClassifierVersionError(Exception):
pass def __init__(self, message: str, *args: object) -> None:
self.message = message
super().__init__(*args)
class ClassifierModelCorruptError(Exception): class ClassifierModelCorruptError(Exception):
@ -37,8 +40,8 @@ def load_classifier() -> Optional["DocumentClassifier"]:
try: try:
classifier.load() classifier.load()
except IncompatibleClassifierVersionError: except IncompatibleClassifierVersionError as e:
logger.info("Classifier version updated, will re-train") logger.info(f"Classifier version incompatible: {e.message}, will re-train")
os.unlink(settings.MODEL_FILE) os.unlink(settings.MODEL_FILE)
classifier = None classifier = None
except ClassifierModelCorruptError: except ClassifierModelCorruptError:
@ -114,10 +117,12 @@ class DocumentClassifier:
"#security-maintainability-limitations" "#security-maintainability-limitations"
) )
for warning in w: for warning in w:
if issubclass(warning.category, UserWarning): # The warning is inconsistent, the MLPClassifier is a specific warning, others have not updated yet
w_msg = str(warning.message) if issubclass(warning.category, InconsistentVersionWarning) or (
if sk_learn_warning_url in w_msg: issubclass(warning.category, UserWarning)
raise IncompatibleClassifierVersionError and sk_learn_warning_url in str(warning.message)
):
raise IncompatibleClassifierVersionError("sklearn version update")
def save(self): def save(self):
target_file: Path = settings.MODEL_FILE target_file: Path = settings.MODEL_FILE

Binary file not shown.

View File

@ -1,5 +1,6 @@
import os import os
import re import re
import shutil
from pathlib import Path from pathlib import Path
from unittest import mock from unittest import mock
@ -649,7 +650,7 @@ class TestClassifier(DirectoriesMixin, TestCase):
Path(settings.MODEL_FILE).touch() Path(settings.MODEL_FILE).touch()
self.assertTrue(os.path.exists(settings.MODEL_FILE)) self.assertTrue(os.path.exists(settings.MODEL_FILE))
load.side_effect = IncompatibleClassifierVersionError() load.side_effect = IncompatibleClassifierVersionError("Dummey Error")
self.assertIsNone(load_classifier()) self.assertIsNone(load_classifier())
self.assertFalse(os.path.exists(settings.MODEL_FILE)) self.assertFalse(os.path.exists(settings.MODEL_FILE))
@ -661,3 +662,14 @@ class TestClassifier(DirectoriesMixin, TestCase):
load.side_effect = OSError() load.side_effect = OSError()
self.assertIsNone(load_classifier()) self.assertIsNone(load_classifier())
self.assertTrue(os.path.exists(settings.MODEL_FILE)) self.assertTrue(os.path.exists(settings.MODEL_FILE))
def test_load_old_classifier_version(self):
shutil.copy(
os.path.join(os.path.dirname(__file__), "data", "v1.17.4.model.pickle"),
self.dirs.scratch_dir,
)
with override_settings(
MODEL_FILE=self.dirs.scratch_dir / "v1.17.4.model.pickle",
):
classifier = load_classifier()
self.assertIsNone(classifier)