mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-09-01 01:46:16 +00:00
Try joblib
This commit is contained in:
@@ -103,7 +103,8 @@ class DocumentClassifier:
|
||||
# v7 - Updated scikit-learn package version
|
||||
# v8 - Added storage path classifier
|
||||
# v9 - Changed from hashing to time/ids for re-train check
|
||||
FORMAT_VERSION = 9
|
||||
# v10 - Switch persistence to joblib with memory-mapping to reduce load-time memory spikes
|
||||
FORMAT_VERSION = 10
|
||||
|
||||
def __init__(self) -> None:
|
||||
# last time a document changed and therefore training might be required
|
||||
@@ -135,32 +136,51 @@ class DocumentClassifier:
|
||||
).hexdigest()
|
||||
|
||||
def load(self) -> None:
|
||||
import joblib
|
||||
from sklearn.exceptions import InconsistentVersionWarning
|
||||
|
||||
# Catch warnings for processing
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
with Path(settings.MODEL_FILE).open("rb") as f:
|
||||
schema_version = pickle.load(f)
|
||||
try:
|
||||
state = joblib.load(settings.MODEL_FILE, mmap_mode="r")
|
||||
except Exception as err:
|
||||
# As a fallback, try to detect old pickle-based and mark incompatible
|
||||
try:
|
||||
with Path(settings.MODEL_FILE).open("rb") as f:
|
||||
_ = pickle.load(f)
|
||||
raise IncompatibleClassifierVersionError(
|
||||
"Cannot load classifier, incompatible versions.",
|
||||
) from err
|
||||
except IncompatibleClassifierVersionError:
|
||||
raise
|
||||
except Exception:
|
||||
# Not even a readable pickle header
|
||||
raise ClassifierModelCorruptError from err
|
||||
|
||||
if schema_version != self.FORMAT_VERSION:
|
||||
try:
|
||||
if (
|
||||
not isinstance(state, dict)
|
||||
or state.get("format_version") != self.FORMAT_VERSION
|
||||
):
|
||||
raise IncompatibleClassifierVersionError(
|
||||
"Cannot load classifier, incompatible versions.",
|
||||
)
|
||||
else:
|
||||
try:
|
||||
self.last_doc_change_time = pickle.load(f)
|
||||
self.last_auto_type_hash = pickle.load(f)
|
||||
|
||||
self.data_vectorizer = pickle.load(f)
|
||||
self._update_data_vectorizer_hash()
|
||||
self.tags_binarizer = pickle.load(f)
|
||||
self.last_doc_change_time = state.get("last_doc_change_time")
|
||||
self.last_auto_type_hash = state.get("last_auto_type_hash")
|
||||
|
||||
self.tags_classifier = pickle.load(f)
|
||||
self.correspondent_classifier = pickle.load(f)
|
||||
self.document_type_classifier = pickle.load(f)
|
||||
self.storage_path_classifier = pickle.load(f)
|
||||
except Exception as err:
|
||||
raise ClassifierModelCorruptError from err
|
||||
self.data_vectorizer = state.get("data_vectorizer")
|
||||
self._update_data_vectorizer_hash()
|
||||
self.tags_binarizer = state.get("tags_binarizer")
|
||||
|
||||
self.tags_classifier = state.get("tags_classifier")
|
||||
self.correspondent_classifier = state.get("correspondent_classifier")
|
||||
self.document_type_classifier = state.get("document_type_classifier")
|
||||
self.storage_path_classifier = state.get("storage_path_classifier")
|
||||
except IncompatibleClassifierVersionError:
|
||||
raise
|
||||
except Exception as err:
|
||||
raise ClassifierModelCorruptError from err
|
||||
|
||||
# Check for the warning about unpickling from differing versions
|
||||
# and consider it incompatible
|
||||
@@ -178,23 +198,24 @@ class DocumentClassifier:
|
||||
raise IncompatibleClassifierVersionError("sklearn version update")
|
||||
|
||||
def save(self) -> None:
|
||||
import joblib
|
||||
|
||||
target_file: Path = settings.MODEL_FILE
|
||||
target_file_temp: Path = target_file.with_suffix(".pickle.part")
|
||||
target_file_temp: Path = target_file.with_suffix(".joblib.part")
|
||||
|
||||
with target_file_temp.open("wb") as f:
|
||||
pickle.dump(self.FORMAT_VERSION, f)
|
||||
state = {
|
||||
"format_version": self.FORMAT_VERSION,
|
||||
"last_doc_change_time": self.last_doc_change_time,
|
||||
"last_auto_type_hash": self.last_auto_type_hash,
|
||||
"data_vectorizer": self.data_vectorizer,
|
||||
"tags_binarizer": self.tags_binarizer,
|
||||
"tags_classifier": self.tags_classifier,
|
||||
"correspondent_classifier": self.correspondent_classifier,
|
||||
"document_type_classifier": self.document_type_classifier,
|
||||
"storage_path_classifier": self.storage_path_classifier,
|
||||
}
|
||||
|
||||
pickle.dump(self.last_doc_change_time, f)
|
||||
pickle.dump(self.last_auto_type_hash, f)
|
||||
|
||||
pickle.dump(self.data_vectorizer, f)
|
||||
|
||||
pickle.dump(self.tags_binarizer, f)
|
||||
pickle.dump(self.tags_classifier, f)
|
||||
|
||||
pickle.dump(self.correspondent_classifier, f)
|
||||
pickle.dump(self.document_type_classifier, f)
|
||||
pickle.dump(self.storage_path_classifier, f)
|
||||
joblib.dump(state, target_file_temp, compress=3)
|
||||
|
||||
target_file_temp.rename(target_file)
|
||||
|
||||
|
Reference in New Issue
Block a user