mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-09-03 01:56:16 +00:00
Try joblib
This commit is contained in:
@@ -103,7 +103,8 @@ class DocumentClassifier:
|
|||||||
# v7 - Updated scikit-learn package version
|
# v7 - Updated scikit-learn package version
|
||||||
# v8 - Added storage path classifier
|
# v8 - Added storage path classifier
|
||||||
# v9 - Changed from hashing to time/ids for re-train check
|
# v9 - Changed from hashing to time/ids for re-train check
|
||||||
FORMAT_VERSION = 9
|
# v10 - Switch persistence to joblib with memory-mapping to reduce load-time memory spikes
|
||||||
|
FORMAT_VERSION = 10
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
# last time a document changed and therefore training might be required
|
# last time a document changed and therefore training might be required
|
||||||
@@ -135,32 +136,51 @@ class DocumentClassifier:
|
|||||||
).hexdigest()
|
).hexdigest()
|
||||||
|
|
||||||
def load(self) -> None:
|
def load(self) -> None:
|
||||||
|
import joblib
|
||||||
from sklearn.exceptions import InconsistentVersionWarning
|
from sklearn.exceptions import InconsistentVersionWarning
|
||||||
|
|
||||||
# Catch warnings for processing
|
# Catch warnings for processing
|
||||||
with warnings.catch_warnings(record=True) as w:
|
with warnings.catch_warnings(record=True) as w:
|
||||||
with Path(settings.MODEL_FILE).open("rb") as f:
|
try:
|
||||||
schema_version = pickle.load(f)
|
state = joblib.load(settings.MODEL_FILE, mmap_mode="r")
|
||||||
|
except Exception as err:
|
||||||
|
# As a fallback, check whether this is an old pickle-based model file and, if so, report it as incompatible
|
||||||
|
try:
|
||||||
|
with Path(settings.MODEL_FILE).open("rb") as f:
|
||||||
|
_ = pickle.load(f)
|
||||||
|
raise IncompatibleClassifierVersionError(
|
||||||
|
"Cannot load classifier, incompatible versions.",
|
||||||
|
) from err
|
||||||
|
except IncompatibleClassifierVersionError:
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
# Not even a readable pickle header
|
||||||
|
raise ClassifierModelCorruptError from err
|
||||||
|
|
||||||
if schema_version != self.FORMAT_VERSION:
|
try:
|
||||||
|
if (
|
||||||
|
not isinstance(state, dict)
|
||||||
|
or state.get("format_version") != self.FORMAT_VERSION
|
||||||
|
):
|
||||||
raise IncompatibleClassifierVersionError(
|
raise IncompatibleClassifierVersionError(
|
||||||
"Cannot load classifier, incompatible versions.",
|
"Cannot load classifier, incompatible versions.",
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
try:
|
|
||||||
self.last_doc_change_time = pickle.load(f)
|
|
||||||
self.last_auto_type_hash = pickle.load(f)
|
|
||||||
|
|
||||||
self.data_vectorizer = pickle.load(f)
|
self.last_doc_change_time = state.get("last_doc_change_time")
|
||||||
self._update_data_vectorizer_hash()
|
self.last_auto_type_hash = state.get("last_auto_type_hash")
|
||||||
self.tags_binarizer = pickle.load(f)
|
|
||||||
|
|
||||||
self.tags_classifier = pickle.load(f)
|
self.data_vectorizer = state.get("data_vectorizer")
|
||||||
self.correspondent_classifier = pickle.load(f)
|
self._update_data_vectorizer_hash()
|
||||||
self.document_type_classifier = pickle.load(f)
|
self.tags_binarizer = state.get("tags_binarizer")
|
||||||
self.storage_path_classifier = pickle.load(f)
|
|
||||||
except Exception as err:
|
self.tags_classifier = state.get("tags_classifier")
|
||||||
raise ClassifierModelCorruptError from err
|
self.correspondent_classifier = state.get("correspondent_classifier")
|
||||||
|
self.document_type_classifier = state.get("document_type_classifier")
|
||||||
|
self.storage_path_classifier = state.get("storage_path_classifier")
|
||||||
|
except IncompatibleClassifierVersionError:
|
||||||
|
raise
|
||||||
|
except Exception as err:
|
||||||
|
raise ClassifierModelCorruptError from err
|
||||||
|
|
||||||
# Check for the warning about unpickling from differing versions
|
# Check for the warning about unpickling from differing versions
|
||||||
# and consider it incompatible
|
# and consider it incompatible
|
||||||
@@ -178,23 +198,24 @@ class DocumentClassifier:
|
|||||||
raise IncompatibleClassifierVersionError("sklearn version update")
|
raise IncompatibleClassifierVersionError("sklearn version update")
|
||||||
|
|
||||||
def save(self) -> None:
|
def save(self) -> None:
|
||||||
|
import joblib
|
||||||
|
|
||||||
target_file: Path = settings.MODEL_FILE
|
target_file: Path = settings.MODEL_FILE
|
||||||
target_file_temp: Path = target_file.with_suffix(".pickle.part")
|
target_file_temp: Path = target_file.with_suffix(".joblib.part")
|
||||||
|
|
||||||
with target_file_temp.open("wb") as f:
|
state = {
|
||||||
pickle.dump(self.FORMAT_VERSION, f)
|
"format_version": self.FORMAT_VERSION,
|
||||||
|
"last_doc_change_time": self.last_doc_change_time,
|
||||||
|
"last_auto_type_hash": self.last_auto_type_hash,
|
||||||
|
"data_vectorizer": self.data_vectorizer,
|
||||||
|
"tags_binarizer": self.tags_binarizer,
|
||||||
|
"tags_classifier": self.tags_classifier,
|
||||||
|
"correspondent_classifier": self.correspondent_classifier,
|
||||||
|
"document_type_classifier": self.document_type_classifier,
|
||||||
|
"storage_path_classifier": self.storage_path_classifier,
|
||||||
|
}
|
||||||
|
|
||||||
pickle.dump(self.last_doc_change_time, f)
|
joblib.dump(state, target_file_temp, compress=3)
|
||||||
pickle.dump(self.last_auto_type_hash, f)
|
|
||||||
|
|
||||||
pickle.dump(self.data_vectorizer, f)
|
|
||||||
|
|
||||||
pickle.dump(self.tags_binarizer, f)
|
|
||||||
pickle.dump(self.tags_classifier, f)
|
|
||||||
|
|
||||||
pickle.dump(self.correspondent_classifier, f)
|
|
||||||
pickle.dump(self.document_type_classifier, f)
|
|
||||||
pickle.dump(self.storage_path_classifier, f)
|
|
||||||
|
|
||||||
target_file_temp.rename(target_file)
|
target_file_temp.rename(target_file)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user