From b9afc9b65dbc43a8e91875ae297cfd2231493ece Mon Sep 17 00:00:00 2001
From: shamoon <4887959+shamoon@users.noreply.github.com>
Date: Sun, 31 Aug 2025 14:35:58 -0700
Subject: [PATCH] Performance fix: change classifier persistence to joblib

---
Review notes (ignored by git am; not part of the commit message):
 * save() now writes the model uncompressed. joblib.load() documents that
   mmap_mode has no effect on compressed files, so compress=3 silently
   disabled the memory-mapping this change is meant to introduce.
 * load() indexes the state dict instead of using .get(), so a v10 file
   with a missing entry raises ClassifierModelCorruptError rather than
   silently loading None.
 * The fallback comment in load() was reworded for grammar.

 src/documents/classifier.py | 82 +++++++++++++++++++++++-------------
 1 file changed, 51 insertions(+), 31 deletions(-)

diff --git a/src/documents/classifier.py b/src/documents/classifier.py
index 613c1d5ad..155ff2a56 100644
--- a/src/documents/classifier.py
+++ b/src/documents/classifier.py
@@ -8,6 +8,8 @@ from hashlib import sha256
 from pathlib import Path
 from typing import TYPE_CHECKING
 
+import joblib
+
 if TYPE_CHECKING:
     from collections.abc import Iterator
     from datetime import datetime
@@ -96,7 +98,8 @@ class DocumentClassifier:
     # v7 - Updated scikit-learn package version
     # v8 - Added storage path classifier
     # v9 - Changed from hashing to time/ids for re-train check
-    FORMAT_VERSION = 9
+    # v10 - Switch persistence to joblib with memory-mapping to reduce load-time memory spikes
+    FORMAT_VERSION = 10
 
     def __init__(self) -> None:
         # last time a document changed and therefore training might be required
@@ -132,28 +135,46 @@ class DocumentClassifier:
 
         # Catch warnings for processing
         with warnings.catch_warnings(record=True) as w:
-            with Path(settings.MODEL_FILE).open("rb") as f:
-                schema_version = pickle.load(f)
+            try:
+                state = joblib.load(settings.MODEL_FILE, mmap_mode="r")
+            except Exception as err:
+                # As a fallback, detect an old pickle-based model and mark it incompatible
+                try:
+                    with Path(settings.MODEL_FILE).open("rb") as f:
+                        _ = pickle.load(f)
+                    raise IncompatibleClassifierVersionError(
+                        "Cannot load classifier, incompatible versions.",
+                    ) from err
+                except IncompatibleClassifierVersionError:
+                    raise
+                except Exception:
+                    # Not even a readable pickle header
+                    raise ClassifierModelCorruptError from err
 
-                if schema_version != self.FORMAT_VERSION:
+            try:
+                if (
+                    not isinstance(state, dict)
+                    or state.get("format_version") != self.FORMAT_VERSION
+                ):
                     raise IncompatibleClassifierVersionError(
                         "Cannot load classifier, incompatible versions.",
                     )
-                else:
-                    try:
-                        self.last_doc_change_time = pickle.load(f)
-                        self.last_auto_type_hash = pickle.load(f)
 
-                        self.data_vectorizer = pickle.load(f)
-                        self._update_data_vectorizer_hash()
-                        self.tags_binarizer = pickle.load(f)
+                self.last_doc_change_time = state["last_doc_change_time"]
+                self.last_auto_type_hash = state["last_auto_type_hash"]
 
-                        self.tags_classifier = pickle.load(f)
-                        self.correspondent_classifier = pickle.load(f)
-                        self.document_type_classifier = pickle.load(f)
-                        self.storage_path_classifier = pickle.load(f)
-                    except Exception as err:
-                        raise ClassifierModelCorruptError from err
+                self.data_vectorizer = state["data_vectorizer"]
+                self._update_data_vectorizer_hash()
+                self.tags_binarizer = state["tags_binarizer"]
+
+                self.tags_classifier = state["tags_classifier"]
+                self.correspondent_classifier = state["correspondent_classifier"]
+                self.document_type_classifier = state["document_type_classifier"]
+                self.storage_path_classifier = state["storage_path_classifier"]
+            except IncompatibleClassifierVersionError:
+                raise
+            except Exception as err:
+                raise ClassifierModelCorruptError from err
 
             # Check for the warning about unpickling from differing versions
             # and consider it incompatible
@@ -172,22 +193,21 @@ class DocumentClassifier:
 
     def save(self) -> None:
         target_file: Path = settings.MODEL_FILE
-        target_file_temp: Path = target_file.with_suffix(".pickle.part")
+        target_file_temp: Path = target_file.with_suffix(".joblib.part")
 
-        with target_file_temp.open("wb") as f:
-            pickle.dump(self.FORMAT_VERSION, f)
+        state = {
+            "format_version": self.FORMAT_VERSION,
+            "last_doc_change_time": self.last_doc_change_time,
+            "last_auto_type_hash": self.last_auto_type_hash,
+            "data_vectorizer": self.data_vectorizer,
+            "tags_binarizer": self.tags_binarizer,
+            "tags_classifier": self.tags_classifier,
+            "correspondent_classifier": self.correspondent_classifier,
+            "document_type_classifier": self.document_type_classifier,
+            "storage_path_classifier": self.storage_path_classifier,
+        }
 
-            pickle.dump(self.last_doc_change_time, f)
-            pickle.dump(self.last_auto_type_hash, f)
-
-            pickle.dump(self.data_vectorizer, f)
-
-            pickle.dump(self.tags_binarizer, f)
-            pickle.dump(self.tags_classifier, f)
-
-            pickle.dump(self.correspondent_classifier, f)
-            pickle.dump(self.document_type_classifier, f)
-            pickle.dump(self.storage_path_classifier, f)
+        joblib.dump(state, target_file_temp)
 
         target_file_temp.rename(target_file)
 