Updates the classifier to catch warnings from scikit-learn and rebuild the model file when this happens

This commit is contained in:
Trenton Holmes
2022-06-02 13:58:38 -07:00
committed by Johann Bauer
parent 1aeb95396b
commit 77fbbe95ff
4 changed files with 85 additions and 21 deletions

View File

@@ -4,6 +4,8 @@ import os
import pickle
import re
import shutil
import warnings
from typing import Optional
from django.conf import settings
from documents.models import Document
@@ -21,13 +23,13 @@ class ClassifierModelCorruptError(Exception):
logger = logging.getLogger("paperless.classifier")
def preprocess_content(content):
def preprocess_content(content: str) -> str:
content = content.lower().strip()
content = re.sub(r"\s+", " ", content)
return content
def load_classifier():
def load_classifier() -> Optional["DocumentClassifier"]:
if not os.path.isfile(settings.MODEL_FILE):
logger.debug(
"Document classification model does not exist (yet), not "
@@ -39,7 +41,11 @@ def load_classifier():
try:
classifier.load()
except (ClassifierModelCorruptError, IncompatibleClassifierVersionError):
except IncompatibleClassifierVersionError:
logger.info("Classifier version updated, will re-train")
os.unlink(settings.MODEL_FILE)
classifier = None
except ClassifierModelCorruptError:
# there's something wrong with the model file.
logger.exception(
"Unrecoverable error while loading document "
@@ -59,13 +65,14 @@ def load_classifier():
class DocumentClassifier:
# v7 - Updated scikit-learn package version
# v8 - Added storage path classifier
FORMAT_VERSION = 8
def __init__(self):
# hash of the training data. used to prevent re-training when the
# training data has not changed.
self.data_hash = None
self.data_hash: Optional[bytes] = None
self.data_vectorizer = None
self.tags_binarizer = None
@@ -75,25 +82,41 @@ class DocumentClassifier:
self.storage_path_classifier = None
def load(self):
with open(settings.MODEL_FILE, "rb") as f:
schema_version = pickle.load(f)
# Catch warnings for processing
with warnings.catch_warnings(record=True) as w:
with open(settings.MODEL_FILE, "rb") as f:
schema_version = pickle.load(f)
if schema_version != self.FORMAT_VERSION:
raise IncompatibleClassifierVersionError(
"Cannot load classifier, incompatible versions.",
if schema_version != self.FORMAT_VERSION:
raise IncompatibleClassifierVersionError(
"Cannot load classifier, incompatible versions.",
)
else:
try:
self.data_hash = pickle.load(f)
self.data_vectorizer = pickle.load(f)
self.tags_binarizer = pickle.load(f)
self.tags_classifier = pickle.load(f)
self.correspondent_classifier = pickle.load(f)
self.document_type_classifier = pickle.load(f)
self.storage_path_classifier = pickle.load(f)
except Exception:
raise ClassifierModelCorruptError()
# Check for the warning about unpickling from differing versions
# and consider it incompatible
if len(w) > 0:
sk_learn_warning_url = (
"https://scikit-learn.org/stable/"
"model_persistence.html"
"#security-maintainability-limitations"
)
else:
try:
self.data_hash = pickle.load(f)
self.data_vectorizer = pickle.load(f)
self.tags_binarizer = pickle.load(f)
self.tags_classifier = pickle.load(f)
self.correspondent_classifier = pickle.load(f)
self.document_type_classifier = pickle.load(f)
self.storage_path_classifier = pickle.load(f)
except Exception:
raise ClassifierModelCorruptError()
for warning in w:
if issubclass(warning.category, UserWarning):
w_msg = str(warning.message)
if sk_learn_warning_url in w_msg:
raise IncompatibleClassifierVersionError()
def save(self):
target_file = settings.MODEL_FILE