mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-04-02 13:45:10 -05:00
Fix: Catch new warning when loading the classifier (#5395)
This commit is contained in:
parent
16cc7415c1
commit
41a3c7c89b
@ -10,6 +10,7 @@ from pathlib import Path
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
|
from sklearn.exceptions import InconsistentVersionWarning
|
||||||
|
|
||||||
from documents.models import Document
|
from documents.models import Document
|
||||||
from documents.models import MatchingModel
|
from documents.models import MatchingModel
|
||||||
@ -18,7 +19,9 @@ logger = logging.getLogger("paperless.classifier")
|
|||||||
|
|
||||||
|
|
||||||
class IncompatibleClassifierVersionError(Exception):
    """Signal that a stored classifier model cannot be used as-is.

    The reason is kept on ``self.message`` so handlers can include it in
    log output, and it is also forwarded to ``Exception`` so that
    ``str(exc)``, ``exc.args`` and tracebacks carry the same text.

    Args:
        message: Short human-readable reason for the incompatibility.
        *args: Any additional values to store on ``exc.args`` after the
            message.
    """

    def __init__(self, message: str, *args: object) -> None:
        self.message = message
        # Pass the message through to the base class. Without it, str(exc)
        # is empty and exc.args lacks the required positional argument,
        # which also breaks copy/pickle reconstruction via cls(*exc.args).
        super().__init__(message, *args)
|
||||||
|
|
||||||
|
|
||||||
class ClassifierModelCorruptError(Exception):
|
class ClassifierModelCorruptError(Exception):
|
||||||
@ -37,8 +40,8 @@ def load_classifier() -> Optional["DocumentClassifier"]:
|
|||||||
try:
|
try:
|
||||||
classifier.load()
|
classifier.load()
|
||||||
|
|
||||||
except IncompatibleClassifierVersionError:
|
except IncompatibleClassifierVersionError as e:
|
||||||
logger.info("Classifier version updated, will re-train")
|
logger.info(f"Classifier version incompatible: {e.message}, will re-train")
|
||||||
os.unlink(settings.MODEL_FILE)
|
os.unlink(settings.MODEL_FILE)
|
||||||
classifier = None
|
classifier = None
|
||||||
except ClassifierModelCorruptError:
|
except ClassifierModelCorruptError:
|
||||||
@ -114,10 +117,12 @@ class DocumentClassifier:
|
|||||||
"#security-maintainability-limitations"
|
"#security-maintainability-limitations"
|
||||||
)
|
)
|
||||||
for warning in w:
|
for warning in w:
|
||||||
if issubclass(warning.category, UserWarning):
|
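# sklearn is inconsistent about the warning category: MLPClassifier emits the specific InconsistentVersionWarning, while others have not been updated yet and are matched as a UserWarning whose message contains the sklearn warning URL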
|
||||||
w_msg = str(warning.message)
|
if issubclass(warning.category, InconsistentVersionWarning) or (
|
||||||
if sk_learn_warning_url in w_msg:
|
issubclass(warning.category, UserWarning)
|
||||||
raise IncompatibleClassifierVersionError
|
and sk_learn_warning_url in str(warning.message)
|
||||||
|
):
|
||||||
|
raise IncompatibleClassifierVersionError("sklearn version update")
|
||||||
|
|
||||||
def save(self):
|
def save(self):
|
||||||
target_file: Path = settings.MODEL_FILE
|
target_file: Path = settings.MODEL_FILE
|
||||||
|
Binary file not shown.
BIN
src/documents/tests/data/v1.17.4.model.pickle
Normal file
BIN
src/documents/tests/data/v1.17.4.model.pickle
Normal file
Binary file not shown.
@ -1,5 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
@ -649,7 +650,7 @@ class TestClassifier(DirectoriesMixin, TestCase):
|
|||||||
Path(settings.MODEL_FILE).touch()
|
Path(settings.MODEL_FILE).touch()
|
||||||
self.assertTrue(os.path.exists(settings.MODEL_FILE))
|
self.assertTrue(os.path.exists(settings.MODEL_FILE))
|
||||||
|
|
||||||
load.side_effect = IncompatibleClassifierVersionError()
|
load.side_effect = IncompatibleClassifierVersionError("Dummey Error")
|
||||||
self.assertIsNone(load_classifier())
|
self.assertIsNone(load_classifier())
|
||||||
self.assertFalse(os.path.exists(settings.MODEL_FILE))
|
self.assertFalse(os.path.exists(settings.MODEL_FILE))
|
||||||
|
|
||||||
@ -661,3 +662,14 @@ class TestClassifier(DirectoriesMixin, TestCase):
|
|||||||
load.side_effect = OSError()
|
load.side_effect = OSError()
|
||||||
self.assertIsNone(load_classifier())
|
self.assertIsNone(load_classifier())
|
||||||
self.assertTrue(os.path.exists(settings.MODEL_FILE))
|
self.assertTrue(os.path.exists(settings.MODEL_FILE))
|
||||||
|
|
||||||
|
def test_load_old_classifier_version(self):
    """
    GIVEN:
        - The bundled "v1.17.4.model.pickle" fixture copied into the
          scratch directory
        - MODEL_FILE overridden to point at that copy
    WHEN:
        - The classifier is loaded via load_classifier()
    THEN:
        - No classifier is returned (the result is None)
    """
    model_name = "v1.17.4.model.pickle"
    fixture_path = os.path.join(os.path.dirname(__file__), "data", model_name)

    # Work on a copy in the scratch directory rather than on the fixture
    # in the source tree.
    shutil.copy(fixture_path, self.dirs.scratch_dir)

    with override_settings(MODEL_FILE=self.dirs.scratch_dir / model_name):
        self.assertIsNone(load_classifier())
|
||||||
|
Loading…
x
Reference in New Issue
Block a user