mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-09-03 01:56:16 +00:00
Compare commits
5 Commits
v2.18.3
...
fix-joblib
Author | SHA1 | Date | |
---|---|---|---|
![]() |
bc22a282d6 | ||
![]() |
fc5b9bdf59 | ||
![]() |
569cc46a43 | ||
![]() |
887b314744 | ||
![]() |
b9afc9b65d |
@@ -4,10 +4,13 @@ import logging
|
||||
import pickle
|
||||
import re
|
||||
import warnings
|
||||
from functools import lru_cache
|
||||
from hashlib import sha256
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import joblib
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
@@ -50,7 +53,24 @@ class ClassifierModelCorruptError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def load_classifier(*, raise_exception: bool = False) -> DocumentClassifier | None:
|
||||
def _model_cache_token() -> tuple[str, int, int]:
    """
    Build the cache key identifying the current on-disk classifier model.

    Returns:
        A ``(path, mtime_ns, size)`` tuple for ``settings.MODEL_FILE``.
        When the file is missing or cannot be stat'ed, ``(path, 0, 0)``
        is returned instead.
    """
    p = Path(settings.MODEL_FILE)
    try:
        # Stat directly instead of checking exists() first: a single call,
        # and no window in which the file can vanish between check and use.
        st = p.stat()
    except OSError:
        # Covers a missing file (FileNotFoundError) as well as other
        # filesystem errors; all of them map to the "no model" token.
        return (str(p), 0, 0)
    # Use the nanosecond timestamp: truncating st_mtime to whole seconds
    # would give an identical token for a model rewritten within the same
    # second with the same size, leaving a stale entry in the cache.
    return (str(p), st.st_mtime_ns, st.st_size)
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _load_classifier_cached(
|
||||
token: tuple[str, int, int],
|
||||
*,
|
||||
raise_exception: bool = False,
|
||||
) -> DocumentClassifier | None:
|
||||
# token used only for cache key; logic depends on current settings
|
||||
if not settings.MODEL_FILE.is_file():
|
||||
logger.debug(
|
||||
"Document classification model does not exist (yet), not "
|
||||
@@ -61,20 +81,23 @@ def load_classifier(*, raise_exception: bool = False) -> DocumentClassifier | No
|
||||
classifier = DocumentClassifier()
|
||||
try:
|
||||
classifier.load()
|
||||
|
||||
except IncompatibleClassifierVersionError as e:
|
||||
logger.info(f"Classifier version incompatible: {e.message}, will re-train")
|
||||
Path(settings.MODEL_FILE).unlink()
|
||||
try:
|
||||
Path(settings.MODEL_FILE).unlink()
|
||||
except Exception:
|
||||
pass
|
||||
classifier = None
|
||||
if raise_exception:
|
||||
raise e
|
||||
except ClassifierModelCorruptError as e:
|
||||
# there's something wrong with the model file.
|
||||
logger.exception(
|
||||
"Unrecoverable error while loading document "
|
||||
"classification model, deleting model file.",
|
||||
"Unrecoverable error while loading document classification model, deleting model file.",
|
||||
)
|
||||
Path(settings.MODEL_FILE).unlink
|
||||
try:
|
||||
Path(settings.MODEL_FILE).unlink()
|
||||
except Exception:
|
||||
pass
|
||||
classifier = None
|
||||
if raise_exception:
|
||||
raise e
|
||||
@@ -92,11 +115,17 @@ def load_classifier(*, raise_exception: bool = False) -> DocumentClassifier | No
|
||||
return classifier
|
||||
|
||||
|
||||
def load_classifier(*, raise_exception: bool = False) -> DocumentClassifier | None:
    """
    Return the document classifier via the cached loader.

    A token describing the current model file is computed on every call and
    passed to ``_load_classifier_cached`` together with ``raise_exception``;
    the cached loader's result is returned unchanged.
    """
    return _load_classifier_cached(
        _model_cache_token(),
        raise_exception=raise_exception,
    )
|
||||
|
||||
|
||||
class DocumentClassifier:
|
||||
# v7 - Updated scikit-learn package version
|
||||
# v8 - Added storage path classifier
|
||||
# v9 - Changed from hashing to time/ids for re-train check
|
||||
FORMAT_VERSION = 9
|
||||
# v10 - Switch persistence to joblib with memory-mapping to reduce load-time memory spikes
|
||||
FORMAT_VERSION = 10
|
||||
|
||||
def __init__(self) -> None:
|
||||
# last time a document changed and therefore training might be required
|
||||
@@ -132,28 +161,57 @@ class DocumentClassifier:
|
||||
|
||||
# Catch warnings for processing
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
with Path(settings.MODEL_FILE).open("rb") as f:
|
||||
schema_version = pickle.load(f)
|
||||
state = None
|
||||
try:
|
||||
state = joblib.load(settings.MODEL_FILE, mmap_mode="r")
|
||||
except ValueError:
|
||||
# Some environments may fail to mmap small files; fall back to normal load
|
||||
state = joblib.load(settings.MODEL_FILE, mmap_mode=None)
|
||||
except Exception as err:
|
||||
# Fallback to old pickle-based format. Try to read the version and a field to
|
||||
# distinguish truly corrupt files from incompatible versions.
|
||||
try:
|
||||
with Path(settings.MODEL_FILE).open("rb") as f:
|
||||
_version = pickle.load(f)
|
||||
try:
|
||||
_ = pickle.load(f)
|
||||
except Exception as inner:
|
||||
raise ClassifierModelCorruptError from inner
|
||||
# Old, incompatible format
|
||||
raise IncompatibleClassifierVersionError(
|
||||
"Cannot load classifier, incompatible versions.",
|
||||
) from err
|
||||
except (
|
||||
IncompatibleClassifierVersionError,
|
||||
ClassifierModelCorruptError,
|
||||
):
|
||||
raise
|
||||
except Exception:
|
||||
# Not even a readable pickle header
|
||||
raise ClassifierModelCorruptError from err
|
||||
|
||||
if schema_version != self.FORMAT_VERSION:
|
||||
raise IncompatibleClassifierVersionError(
|
||||
"Cannot load classifier, incompatible versions.",
|
||||
)
|
||||
else:
|
||||
try:
|
||||
self.last_doc_change_time = pickle.load(f)
|
||||
self.last_auto_type_hash = pickle.load(f)
|
||||
if (
|
||||
not isinstance(state, dict)
|
||||
or state.get("format_version") != self.FORMAT_VERSION
|
||||
):
|
||||
raise IncompatibleClassifierVersionError(
|
||||
"Cannot load classifier, incompatible versions.",
|
||||
)
|
||||
|
||||
self.data_vectorizer = pickle.load(f)
|
||||
self._update_data_vectorizer_hash()
|
||||
self.tags_binarizer = pickle.load(f)
|
||||
try:
|
||||
self.last_doc_change_time = state.get("last_doc_change_time")
|
||||
self.last_auto_type_hash = state.get("last_auto_type_hash")
|
||||
|
||||
self.tags_classifier = pickle.load(f)
|
||||
self.correspondent_classifier = pickle.load(f)
|
||||
self.document_type_classifier = pickle.load(f)
|
||||
self.storage_path_classifier = pickle.load(f)
|
||||
except Exception as err:
|
||||
raise ClassifierModelCorruptError from err
|
||||
self.data_vectorizer = state.get("data_vectorizer")
|
||||
self._update_data_vectorizer_hash()
|
||||
self.tags_binarizer = state.get("tags_binarizer")
|
||||
|
||||
self.tags_classifier = state.get("tags_classifier")
|
||||
self.correspondent_classifier = state.get("correspondent_classifier")
|
||||
self.document_type_classifier = state.get("document_type_classifier")
|
||||
self.storage_path_classifier = state.get("storage_path_classifier")
|
||||
except Exception as err:
|
||||
raise ClassifierModelCorruptError from err
|
||||
|
||||
# Check for the warning about unpickling from differing versions
|
||||
# and consider it incompatible
|
||||
@@ -172,24 +230,28 @@ class DocumentClassifier:
|
||||
|
||||
def save(self) -> None:
|
||||
target_file: Path = settings.MODEL_FILE
|
||||
target_file_temp: Path = target_file.with_suffix(".pickle.part")
|
||||
target_file_temp: Path = target_file.with_suffix(".joblib.part")
|
||||
|
||||
with target_file_temp.open("wb") as f:
|
||||
pickle.dump(self.FORMAT_VERSION, f)
|
||||
state = {
|
||||
"format_version": self.FORMAT_VERSION,
|
||||
"last_doc_change_time": self.last_doc_change_time,
|
||||
"last_auto_type_hash": self.last_auto_type_hash,
|
||||
"data_vectorizer": self.data_vectorizer,
|
||||
"tags_binarizer": self.tags_binarizer,
|
||||
"tags_classifier": self.tags_classifier,
|
||||
"correspondent_classifier": self.correspondent_classifier,
|
||||
"document_type_classifier": self.document_type_classifier,
|
||||
"storage_path_classifier": self.storage_path_classifier,
|
||||
}
|
||||
|
||||
pickle.dump(self.last_doc_change_time, f)
|
||||
pickle.dump(self.last_auto_type_hash, f)
|
||||
|
||||
pickle.dump(self.data_vectorizer, f)
|
||||
|
||||
pickle.dump(self.tags_binarizer, f)
|
||||
pickle.dump(self.tags_classifier, f)
|
||||
|
||||
pickle.dump(self.correspondent_classifier, f)
|
||||
pickle.dump(self.document_type_classifier, f)
|
||||
pickle.dump(self.storage_path_classifier, f)
|
||||
joblib.dump(state, target_file_temp, compress=3)
|
||||
|
||||
target_file_temp.rename(target_file)
|
||||
# Invalidate cached classifier loader so subsequent calls see the new file
|
||||
try:
|
||||
_load_classifier_cached.cache_clear()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def train(self) -> bool:
|
||||
# Get non-inbox documents
|
||||
|
@@ -370,7 +370,7 @@ class TestClassifier(DirectoriesMixin, TestCase):
|
||||
def test_load_corrupt_file(self, patched_pickle_load: mock.MagicMock):
|
||||
"""
|
||||
GIVEN:
|
||||
- Corrupted classifier pickle file
|
||||
- Corrupted legacy classifier pickle file
|
||||
WHEN:
|
||||
- An attempt is made to load the classifier
|
||||
THEN:
|
||||
@@ -381,9 +381,10 @@ class TestClassifier(DirectoriesMixin, TestCase):
|
||||
# First load is the schema version, allow it
|
||||
patched_pickle_load.side_effect = [DocumentClassifier.FORMAT_VERSION, OSError()]
|
||||
|
||||
with self.assertRaises(ClassifierModelCorruptError):
|
||||
self.classifier.load()
|
||||
patched_pickle_load.assert_called()
|
||||
# Force the loader down the legacy path by making joblib.load fail
|
||||
with mock.patch("joblib.load", side_effect=Exception("bad joblib")):
|
||||
with self.assertRaises(ClassifierModelCorruptError):
|
||||
self.classifier.load()
|
||||
|
||||
patched_pickle_load.reset_mock()
|
||||
patched_pickle_load.side_effect = [
|
||||
@@ -391,8 +392,8 @@ class TestClassifier(DirectoriesMixin, TestCase):
|
||||
ClassifierModelCorruptError(),
|
||||
]
|
||||
|
||||
self.assertIsNone(load_classifier())
|
||||
patched_pickle_load.assert_called()
|
||||
with mock.patch("joblib.load", side_effect=Exception("bad joblib")):
|
||||
self.assertIsNone(load_classifier())
|
||||
|
||||
def test_load_new_scikit_learn_version(self):
|
||||
"""
|
||||
|
Reference in New Issue
Block a user