Compare commits

...

5 Commits

Author   SHA1         Message                                                     Date
shamoon  bc22a282d6   Fix caching, maybe                                          2025-08-31 16:08:32 -07:00
shamoon  fc5b9bdf59   Cache classifier loading with lru_cache                     2025-08-31 15:41:33 -07:00
shamoon  569cc46a43   Fix test                                                    2025-08-31 15:16:12 -07:00
shamoon  887b314744   Fix loading / error handling                                2025-08-31 15:08:23 -07:00
shamoon  b9afc9b65d   Performance fix: change classifier persistence to joblib   2025-08-31 15:08:22 -07:00
2 changed files with 110 additions and 47 deletions
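The changes below do two things: classifier state is persisted with joblib instead of a sequence of pickle.dump calls, and load_classifier() is cached with functools.lru_cache, keyed by a token built from the model file's path, modification time and size so the cache invalidates itself whenever the file on disk changes. As a rough, standalone sketch of that caching pattern (the function names and the /tmp path are placeholders, not the project's actual module):

from functools import lru_cache
from pathlib import Path

MODEL_FILE = Path("/tmp/classifier_model.joblib")  # placeholder path for this sketch

def _cache_token(path: Path) -> tuple[str, int, int]:
    # (path, mtime, size) identifies one concrete on-disk version of the file
    try:
        st = path.stat()
        return (str(path), int(st.st_mtime), int(st.st_size))
    except OSError:
        return (str(path), 0, 0)

@lru_cache(maxsize=1)
def _load_cached(token: tuple[str, int, int]) -> object:
    # the token is only a cache key; the expensive deserialization runs once per file version
    print(f"cache miss, loading model for {token}")
    return object()  # stand-in for the deserialized classifier

def load_model() -> object:
    return _load_cached(_cache_token(MODEL_FILE))

Repeated calls while the file is unchanged hit the cache; rewriting or touching the file changes the token, which misses the single-entry cache and forces a reload.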


@@ -4,10 +4,13 @@ import logging
 import pickle
 import re
 import warnings
+from functools import lru_cache
 from hashlib import sha256
 from pathlib import Path
 from typing import TYPE_CHECKING
+import joblib
 if TYPE_CHECKING:
     from collections.abc import Iterator
     from datetime import datetime
@@ -50,7 +53,24 @@ class ClassifierModelCorruptError(Exception):
     pass
-def load_classifier(*, raise_exception: bool = False) -> DocumentClassifier | None:
+def _model_cache_token() -> tuple[str, int, int]:
+    p = Path(settings.MODEL_FILE)
+    if p.exists():
+        try:
+            st = p.stat()
+            return (str(p), int(st.st_mtime), int(st.st_size))
+        except OSError:
+            return (str(p), 0, 0)
+    return (str(p), 0, 0)
+@lru_cache(maxsize=1)
+def _load_classifier_cached(
+    token: tuple[str, int, int],
+    *,
+    raise_exception: bool = False,
+) -> DocumentClassifier | None:
+    # token used only for cache key; logic depends on current settings
     if not settings.MODEL_FILE.is_file():
         logger.debug(
             "Document classification model does not exist (yet), not "
@@ -61,20 +81,23 @@ def load_classifier(*, raise_exception: bool = False) -> DocumentClassifier | None:
     classifier = DocumentClassifier()
     try:
         classifier.load()
     except IncompatibleClassifierVersionError as e:
         logger.info(f"Classifier version incompatible: {e.message}, will re-train")
-        Path(settings.MODEL_FILE).unlink()
+        try:
+            Path(settings.MODEL_FILE).unlink()
+        except Exception:
+            pass
         classifier = None
         if raise_exception:
             raise e
     except ClassifierModelCorruptError as e:
         # there's something wrong with the model file.
         logger.exception(
-            "Unrecoverable error while loading document "
-            "classification model, deleting model file.",
+            "Unrecoverable error while loading document classification model, deleting model file.",
         )
-        Path(settings.MODEL_FILE).unlink
+        try:
+            Path(settings.MODEL_FILE).unlink()
+        except Exception:
+            pass
         classifier = None
         if raise_exception:
             raise e
@@ -92,11 +115,17 @@ def load_classifier(*, raise_exception: bool = False) -> DocumentClassifier | None:
     return classifier
+def load_classifier(*, raise_exception: bool = False) -> DocumentClassifier | None:
+    token = _model_cache_token()
+    return _load_classifier_cached(token, raise_exception=raise_exception)
 class DocumentClassifier:
     # v7 - Updated scikit-learn package version
     # v8 - Added storage path classifier
     # v9 - Changed from hashing to time/ids for re-train check
-    FORMAT_VERSION = 9
+    # v10 - Switch persistence to joblib with memory-mapping to reduce load-time memory spikes
+    FORMAT_VERSION = 10
     def __init__(self) -> None:
         # last time a document changed and therefore training might be required
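A side note on the wrapper above: functools.lru_cache includes keyword arguments in its cache key, so with maxsize=1 two calls that differ only in raise_exception do not share an entry, and each switch evicts the previously cached classifier. A small self-contained demonstration of that behavior (illustrative names only):

from functools import lru_cache

@lru_cache(maxsize=1)
def cached_load(token: tuple, *, raise_exception: bool = False) -> str:
    print("cache miss:", token, raise_exception)
    return f"model-for-{token}"

cached_load((1,), raise_exception=False)  # miss
cached_load((1,), raise_exception=False)  # hit
cached_load((1,), raise_exception=True)   # miss, and evicts the maxsize=1 entry
cached_load((1,), raise_exception=False)  # miss again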
@@ -132,28 +161,57 @@ class DocumentClassifier:
         # Catch warnings for processing
         with warnings.catch_warnings(record=True) as w:
-            with Path(settings.MODEL_FILE).open("rb") as f:
-                schema_version = pickle.load(f)
+            state = None
+            try:
+                state = joblib.load(settings.MODEL_FILE, mmap_mode="r")
+            except ValueError:
+                # Some environments may fail to mmap small files; fall back to normal load
+                state = joblib.load(settings.MODEL_FILE, mmap_mode=None)
+            except Exception as err:
+                # Fallback to old pickle-based format. Try to read the version and a field to
+                # distinguish truly corrupt files from incompatible versions.
+                try:
+                    with Path(settings.MODEL_FILE).open("rb") as f:
+                        _version = pickle.load(f)
+                        try:
+                            _ = pickle.load(f)
+                        except Exception as inner:
+                            raise ClassifierModelCorruptError from inner
+                    # Old, incompatible format
+                    raise IncompatibleClassifierVersionError(
+                        "Cannot load classifier, incompatible versions.",
+                    ) from err
+                except (
+                    IncompatibleClassifierVersionError,
+                    ClassifierModelCorruptError,
+                ):
+                    raise
+                except Exception:
+                    # Not even a readable pickle header
+                    raise ClassifierModelCorruptError from err
-                if schema_version != self.FORMAT_VERSION:
-                    raise IncompatibleClassifierVersionError(
-                        "Cannot load classifier, incompatible versions.",
-                    )
-                else:
-                    try:
-                        self.last_doc_change_time = pickle.load(f)
-                        self.last_auto_type_hash = pickle.load(f)
+            if (
+                not isinstance(state, dict)
+                or state.get("format_version") != self.FORMAT_VERSION
+            ):
+                raise IncompatibleClassifierVersionError(
+                    "Cannot load classifier, incompatible versions.",
+                )
-                        self.data_vectorizer = pickle.load(f)
-                        self._update_data_vectorizer_hash()
-                        self.tags_binarizer = pickle.load(f)
+            try:
+                self.last_doc_change_time = state.get("last_doc_change_time")
+                self.last_auto_type_hash = state.get("last_auto_type_hash")
-                        self.tags_classifier = pickle.load(f)
-                        self.correspondent_classifier = pickle.load(f)
-                        self.document_type_classifier = pickle.load(f)
-                        self.storage_path_classifier = pickle.load(f)
-                    except Exception as err:
-                        raise ClassifierModelCorruptError from err
+                self.data_vectorizer = state.get("data_vectorizer")
+                self._update_data_vectorizer_hash()
+                self.tags_binarizer = state.get("tags_binarizer")
+                self.tags_classifier = state.get("tags_classifier")
+                self.correspondent_classifier = state.get("correspondent_classifier")
+                self.document_type_classifier = state.get("document_type_classifier")
+                self.storage_path_classifier = state.get("storage_path_classifier")
+            except Exception as err:
+                raise ClassifierModelCorruptError from err
             # Check for the warning about unpickling from differing versions
             # and consider it incompatible
@@ -172,24 +230,28 @@ class DocumentClassifier:
     def save(self) -> None:
         target_file: Path = settings.MODEL_FILE
-        target_file_temp: Path = target_file.with_suffix(".pickle.part")
+        target_file_temp: Path = target_file.with_suffix(".joblib.part")
-        with target_file_temp.open("wb") as f:
-            pickle.dump(self.FORMAT_VERSION, f)
+        state = {
+            "format_version": self.FORMAT_VERSION,
+            "last_doc_change_time": self.last_doc_change_time,
+            "last_auto_type_hash": self.last_auto_type_hash,
+            "data_vectorizer": self.data_vectorizer,
+            "tags_binarizer": self.tags_binarizer,
+            "tags_classifier": self.tags_classifier,
+            "correspondent_classifier": self.correspondent_classifier,
+            "document_type_classifier": self.document_type_classifier,
+            "storage_path_classifier": self.storage_path_classifier,
+        }
-            pickle.dump(self.last_doc_change_time, f)
-            pickle.dump(self.last_auto_type_hash, f)
-            pickle.dump(self.data_vectorizer, f)
-            pickle.dump(self.tags_binarizer, f)
-            pickle.dump(self.tags_classifier, f)
-            pickle.dump(self.correspondent_classifier, f)
-            pickle.dump(self.document_type_classifier, f)
-            pickle.dump(self.storage_path_classifier, f)
+        joblib.dump(state, target_file_temp, compress=3)
         target_file_temp.rename(target_file)
+        # Invalidate cached classifier loader so subsequent calls see the new file
+        try:
+            _load_classifier_cached.cache_clear()
+        except Exception:
+            pass
     def train(self) -> bool:
         # Get non-inbox documents
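The save() hunk above serializes one state dict with joblib, writes it to a temporary .joblib.part file, renames it into place, and clears the loader cache; load() reads it back with mmap_mode="r" so large arrays are memory-mapped rather than copied into RAM. A rough sketch of that persistence scheme with placeholder paths and contents (this sketch skips compression, since joblib can only memory-map arrays stored uncompressed):

import numpy as np
import joblib
from pathlib import Path

target = Path("/tmp/classifier_model.joblib")  # placeholder path for this sketch
state = {
    "format_version": 10,
    "data_vectorizer": np.zeros((1000, 50)),   # stand-in for the fitted objects
}

tmp = target.with_suffix(".joblib.part")
joblib.dump(state, tmp)   # uncompressed, so numpy arrays can be memory-mapped on load
tmp.rename(target)        # on POSIX this atomically replaces any existing model file

loaded = joblib.load(target, mmap_mode="r")
print(type(loaded["data_vectorizer"]))  # numpy.memmap backed by the file, not a RAM copy

Loading the whole state as a single dict also lets the loader reject foreign or outdated files by checking state.get("format_version") before touching any of the estimators. The second changed file, below, updates the corresponding tests.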


@@ -370,7 +370,7 @@ class TestClassifier(DirectoriesMixin, TestCase):
     def test_load_corrupt_file(self, patched_pickle_load: mock.MagicMock):
         """
         GIVEN:
-            - Corrupted classifier pickle file
+            - Corrupted legacy classifier pickle file
         WHEN:
            - An attempt is made to load the classifier
         THEN:
@@ -381,9 +381,10 @@ class TestClassifier(DirectoriesMixin, TestCase):
         # First load is the schema version,allow it
         patched_pickle_load.side_effect = [DocumentClassifier.FORMAT_VERSION, OSError()]
-        with self.assertRaises(ClassifierModelCorruptError):
-            self.classifier.load()
-            patched_pickle_load.assert_called()
+        # Force the loader down the legacy path by making joblib.load fail
+        with mock.patch("joblib.load", side_effect=Exception("bad joblib")):
+            with self.assertRaises(ClassifierModelCorruptError):
+                self.classifier.load()
         patched_pickle_load.reset_mock()
         patched_pickle_load.side_effect = [
@@ -391,8 +392,8 @@ class TestClassifier(DirectoriesMixin, TestCase):
             ClassifierModelCorruptError(),
         ]
-        self.assertIsNone(load_classifier())
-        patched_pickle_load.assert_called()
+        with mock.patch("joblib.load", side_effect=Exception("bad joblib")):
+            self.assertIsNone(load_classifier())
     def test_load_new_scikit_learn_version(self):
         """