mirror of
				https://github.com/paperless-ngx/paperless-ngx.git
				synced 2025-11-03 03:16:10 -06:00 
			
		
		
		
	Try joblib
This commit is contained in:
		@@ -103,7 +103,8 @@ class DocumentClassifier:
 | 
			
		||||
    # v7 - Updated scikit-learn package version
 | 
			
		||||
    # v8 - Added storage path classifier
 | 
			
		||||
    # v9 - Changed from hashing to time/ids for re-train check
 | 
			
		||||
    FORMAT_VERSION = 9
 | 
			
		||||
    # v10 - Switch persistence to joblib with memory-mapping to reduce load-time memory spikes
 | 
			
		||||
    FORMAT_VERSION = 10
 | 
			
		||||
 | 
			
		||||
    def __init__(self) -> None:
 | 
			
		||||
        # last time a document changed and therefore training might be required
 | 
			
		||||
@@ -135,32 +136,51 @@ class DocumentClassifier:
 | 
			
		||||
        ).hexdigest()
 | 
			
		||||
 | 
			
		||||
    def load(self) -> None:
 | 
			
		||||
        import joblib
 | 
			
		||||
        from sklearn.exceptions import InconsistentVersionWarning
 | 
			
		||||
 | 
			
		||||
        # Catch warnings for processing
 | 
			
		||||
        with warnings.catch_warnings(record=True) as w:
 | 
			
		||||
            with Path(settings.MODEL_FILE).open("rb") as f:
 | 
			
		||||
                schema_version = pickle.load(f)
 | 
			
		||||
            try:
 | 
			
		||||
                state = joblib.load(settings.MODEL_FILE, mmap_mode="r")
 | 
			
		||||
            except Exception as err:
 | 
			
		||||
                # As a fallback, try to detect an old pickle-based model file and mark it incompatible
 | 
			
		||||
                try:
 | 
			
		||||
                    with Path(settings.MODEL_FILE).open("rb") as f:
 | 
			
		||||
                        _ = pickle.load(f)
 | 
			
		||||
                    raise IncompatibleClassifierVersionError(
 | 
			
		||||
                        "Cannot load classifier, incompatible versions.",
 | 
			
		||||
                    ) from err
 | 
			
		||||
                except IncompatibleClassifierVersionError:
 | 
			
		||||
                    raise
 | 
			
		||||
                except Exception:
 | 
			
		||||
                    # Not even a readable pickle header
 | 
			
		||||
                    raise ClassifierModelCorruptError from err
 | 
			
		||||
 | 
			
		||||
                if schema_version != self.FORMAT_VERSION:
 | 
			
		||||
            try:
 | 
			
		||||
                if (
 | 
			
		||||
                    not isinstance(state, dict)
 | 
			
		||||
                    or state.get("format_version") != self.FORMAT_VERSION
 | 
			
		||||
                ):
 | 
			
		||||
                    raise IncompatibleClassifierVersionError(
 | 
			
		||||
                        "Cannot load classifier, incompatible versions.",
 | 
			
		||||
                    )
 | 
			
		||||
                else:
 | 
			
		||||
                    try:
 | 
			
		||||
                        self.last_doc_change_time = pickle.load(f)
 | 
			
		||||
                        self.last_auto_type_hash = pickle.load(f)
 | 
			
		||||
 | 
			
		||||
                        self.data_vectorizer = pickle.load(f)
 | 
			
		||||
                        self._update_data_vectorizer_hash()
 | 
			
		||||
                        self.tags_binarizer = pickle.load(f)
 | 
			
		||||
                self.last_doc_change_time = state.get("last_doc_change_time")
 | 
			
		||||
                self.last_auto_type_hash = state.get("last_auto_type_hash")
 | 
			
		||||
 | 
			
		||||
                        self.tags_classifier = pickle.load(f)
 | 
			
		||||
                        self.correspondent_classifier = pickle.load(f)
 | 
			
		||||
                        self.document_type_classifier = pickle.load(f)
 | 
			
		||||
                        self.storage_path_classifier = pickle.load(f)
 | 
			
		||||
                    except Exception as err:
 | 
			
		||||
                        raise ClassifierModelCorruptError from err
 | 
			
		||||
                self.data_vectorizer = state.get("data_vectorizer")
 | 
			
		||||
                self._update_data_vectorizer_hash()
 | 
			
		||||
                self.tags_binarizer = state.get("tags_binarizer")
 | 
			
		||||
 | 
			
		||||
                self.tags_classifier = state.get("tags_classifier")
 | 
			
		||||
                self.correspondent_classifier = state.get("correspondent_classifier")
 | 
			
		||||
                self.document_type_classifier = state.get("document_type_classifier")
 | 
			
		||||
                self.storage_path_classifier = state.get("storage_path_classifier")
 | 
			
		||||
            except IncompatibleClassifierVersionError:
 | 
			
		||||
                raise
 | 
			
		||||
            except Exception as err:
 | 
			
		||||
                raise ClassifierModelCorruptError from err
 | 
			
		||||
 | 
			
		||||
            # Check for the warning about unpickling from differing versions
 | 
			
		||||
            # and consider it incompatible
 | 
			
		||||
@@ -178,23 +198,24 @@ class DocumentClassifier:
 | 
			
		||||
                    raise IncompatibleClassifierVersionError("sklearn version update")
 | 
			
		||||
 | 
			
		||||
    def save(self) -> None:
        """
        Persist the classifier state to ``settings.MODEL_FILE`` using joblib.

        All persisted attributes are collected into a single dict tagged with
        ``FORMAT_VERSION``. The dict is first written to a temporary
        ``.joblib.part`` file in the same directory and then renamed onto the
        target path, so a partially written model is never left at the final
        location.
        """
        import joblib

        target_file: Path = settings.MODEL_FILE
        target_file_temp: Path = target_file.with_suffix(".joblib.part")

        # Single state mapping; load() validates "format_version" and reads
        # the remaining entries by key.
        state = {
            "format_version": self.FORMAT_VERSION,
            "last_doc_change_time": self.last_doc_change_time,
            "last_auto_type_hash": self.last_auto_type_hash,
            "data_vectorizer": self.data_vectorizer,
            "tags_binarizer": self.tags_binarizer,
            "tags_classifier": self.tags_classifier,
            "correspondent_classifier": self.correspondent_classifier,
            "document_type_classifier": self.document_type_classifier,
            "storage_path_classifier": self.storage_path_classifier,
        }

        # Written uncompressed on purpose: joblib ignores ``mmap_mode`` for
        # compressed files, and load() opens this file with mmap_mode="r".
        joblib.dump(state, target_file_temp, compress=0)

        # Move the finished file into place only after the dump succeeded.
        target_file_temp.rename(target_file)
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user