mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-08-16 00:36:22 +00:00
Updates the classifier to catch scikit-learn's version-mismatch warnings raised while unpickling the model, and to rebuild the model file when such a warning occurs.
This commit is contained in:

committed by
Johann Bauer

parent
1aeb95396b
commit
77fbbe95ff
@@ -4,6 +4,8 @@ import os
|
||||
import pickle
|
||||
import re
|
||||
import shutil
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
from django.conf import settings
|
||||
from documents.models import Document
|
||||
@@ -21,13 +23,13 @@ class ClassifierModelCorruptError(Exception):
|
||||
logger = logging.getLogger("paperless.classifier")
|
||||
|
||||
|
||||
def preprocess_content(content):
|
||||
def preprocess_content(content: str) -> str:
|
||||
content = content.lower().strip()
|
||||
content = re.sub(r"\s+", " ", content)
|
||||
return content
|
||||
|
||||
|
||||
def load_classifier():
|
||||
def load_classifier() -> Optional["DocumentClassifier"]:
|
||||
if not os.path.isfile(settings.MODEL_FILE):
|
||||
logger.debug(
|
||||
"Document classification model does not exist (yet), not "
|
||||
@@ -39,7 +41,11 @@ def load_classifier():
|
||||
try:
|
||||
classifier.load()
|
||||
|
||||
except (ClassifierModelCorruptError, IncompatibleClassifierVersionError):
|
||||
except IncompatibleClassifierVersionError:
|
||||
logger.info("Classifier version updated, will re-train")
|
||||
os.unlink(settings.MODEL_FILE)
|
||||
classifier = None
|
||||
except ClassifierModelCorruptError:
|
||||
# there's something wrong with the model file.
|
||||
logger.exception(
|
||||
"Unrecoverable error while loading document "
|
||||
@@ -59,13 +65,14 @@ def load_classifier():
|
||||
|
||||
class DocumentClassifier:
|
||||
|
||||
# v7 - Updated scikit-learn package version
|
||||
# v8 - Added storage path classifier
|
||||
FORMAT_VERSION = 8
|
||||
|
||||
def __init__(self):
|
||||
# hash of the training data. used to prevent re-training when the
|
||||
# training data has not changed.
|
||||
self.data_hash = None
|
||||
self.data_hash: Optional[bytes] = None
|
||||
|
||||
self.data_vectorizer = None
|
||||
self.tags_binarizer = None
|
||||
@@ -75,25 +82,41 @@ class DocumentClassifier:
|
||||
self.storage_path_classifier = None
|
||||
|
||||
def load(self):
|
||||
with open(settings.MODEL_FILE, "rb") as f:
|
||||
schema_version = pickle.load(f)
|
||||
# Catch warnings for processing
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
with open(settings.MODEL_FILE, "rb") as f:
|
||||
schema_version = pickle.load(f)
|
||||
|
||||
if schema_version != self.FORMAT_VERSION:
|
||||
raise IncompatibleClassifierVersionError(
|
||||
"Cannot load classifier, incompatible versions.",
|
||||
if schema_version != self.FORMAT_VERSION:
|
||||
raise IncompatibleClassifierVersionError(
|
||||
"Cannot load classifier, incompatible versions.",
|
||||
)
|
||||
else:
|
||||
try:
|
||||
self.data_hash = pickle.load(f)
|
||||
self.data_vectorizer = pickle.load(f)
|
||||
self.tags_binarizer = pickle.load(f)
|
||||
|
||||
self.tags_classifier = pickle.load(f)
|
||||
self.correspondent_classifier = pickle.load(f)
|
||||
self.document_type_classifier = pickle.load(f)
|
||||
self.storage_path_classifier = pickle.load(f)
|
||||
except Exception:
|
||||
raise ClassifierModelCorruptError()
|
||||
|
||||
# Check for the warning about unpickling from differing versions
|
||||
# and consider it incompatible
|
||||
if len(w) > 0:
|
||||
sk_learn_warning_url = (
|
||||
"https://scikit-learn.org/stable/"
|
||||
"model_persistence.html"
|
||||
"#security-maintainability-limitations"
|
||||
)
|
||||
else:
|
||||
try:
|
||||
self.data_hash = pickle.load(f)
|
||||
self.data_vectorizer = pickle.load(f)
|
||||
self.tags_binarizer = pickle.load(f)
|
||||
|
||||
self.tags_classifier = pickle.load(f)
|
||||
self.correspondent_classifier = pickle.load(f)
|
||||
self.document_type_classifier = pickle.load(f)
|
||||
self.storage_path_classifier = pickle.load(f)
|
||||
except Exception:
|
||||
raise ClassifierModelCorruptError()
|
||||
for warning in w:
|
||||
if issubclass(warning.category, UserWarning):
|
||||
w_msg = str(warning.message)
|
||||
if sk_learn_warning_url in w_msg:
|
||||
raise IncompatibleClassifierVersionError()
|
||||
|
||||
def save(self):
|
||||
target_file = settings.MODEL_FILE
|
||||
|
Reference in New Issue
Block a user