Updates the classifier to catch the scikit-learn version-mismatch warnings emitted while unpickling the model, and to rebuild the model file when such a warning occurs

This commit is contained in:
Trenton Holmes 2022-06-02 13:58:38 -07:00 committed by Johann Bauer
parent 1aeb95396b
commit 77fbbe95ff
4 changed files with 85 additions and 21 deletions

View File

@ -4,6 +4,8 @@ import os
import pickle import pickle
import re import re
import shutil import shutil
import warnings
from typing import Optional
from django.conf import settings from django.conf import settings
from documents.models import Document from documents.models import Document
@ -21,13 +23,13 @@ class ClassifierModelCorruptError(Exception):
logger = logging.getLogger("paperless.classifier") logger = logging.getLogger("paperless.classifier")
def preprocess_content(content): def preprocess_content(content: str) -> str:
content = content.lower().strip() content = content.lower().strip()
content = re.sub(r"\s+", " ", content) content = re.sub(r"\s+", " ", content)
return content return content
def load_classifier(): def load_classifier() -> Optional["DocumentClassifier"]:
if not os.path.isfile(settings.MODEL_FILE): if not os.path.isfile(settings.MODEL_FILE):
logger.debug( logger.debug(
"Document classification model does not exist (yet), not " "Document classification model does not exist (yet), not "
@ -39,7 +41,11 @@ def load_classifier():
try: try:
classifier.load() classifier.load()
except (ClassifierModelCorruptError, IncompatibleClassifierVersionError): except IncompatibleClassifierVersionError:
logger.info("Classifier version updated, will re-train")
os.unlink(settings.MODEL_FILE)
classifier = None
except ClassifierModelCorruptError:
# there's something wrong with the model file. # there's something wrong with the model file.
logger.exception( logger.exception(
"Unrecoverable error while loading document " "Unrecoverable error while loading document "
@ -59,13 +65,14 @@ def load_classifier():
class DocumentClassifier: class DocumentClassifier:
# v7 - Updated scikit-learn package version
# v8 - Added storage path classifier # v8 - Added storage path classifier
FORMAT_VERSION = 8 FORMAT_VERSION = 8
def __init__(self): def __init__(self):
# hash of the training data. used to prevent re-training when the # hash of the training data. used to prevent re-training when the
# training data has not changed. # training data has not changed.
self.data_hash = None self.data_hash: Optional[bytes] = None
self.data_vectorizer = None self.data_vectorizer = None
self.tags_binarizer = None self.tags_binarizer = None
@ -75,25 +82,41 @@ class DocumentClassifier:
self.storage_path_classifier = None self.storage_path_classifier = None
def load(self): def load(self):
with open(settings.MODEL_FILE, "rb") as f: # Catch warnings for processing
schema_version = pickle.load(f) with warnings.catch_warnings(record=True) as w:
with open(settings.MODEL_FILE, "rb") as f:
schema_version = pickle.load(f)
if schema_version != self.FORMAT_VERSION: if schema_version != self.FORMAT_VERSION:
raise IncompatibleClassifierVersionError( raise IncompatibleClassifierVersionError(
"Cannot load classifier, incompatible versions.", "Cannot load classifier, incompatible versions.",
)
else:
try:
self.data_hash = pickle.load(f)
self.data_vectorizer = pickle.load(f)
self.tags_binarizer = pickle.load(f)
self.tags_classifier = pickle.load(f)
self.correspondent_classifier = pickle.load(f)
self.document_type_classifier = pickle.load(f)
self.storage_path_classifier = pickle.load(f)
except Exception:
raise ClassifierModelCorruptError()
# Check for the warning about unpickling from differing versions
# and consider it incompatible
if len(w) > 0:
sk_learn_warning_url = (
"https://scikit-learn.org/stable/"
"model_persistence.html"
"#security-maintainability-limitations"
) )
else: for warning in w:
try: if issubclass(warning.category, UserWarning):
self.data_hash = pickle.load(f) w_msg = str(warning.message)
self.data_vectorizer = pickle.load(f) if sk_learn_warning_url in w_msg:
self.tags_binarizer = pickle.load(f) raise IncompatibleClassifierVersionError()
self.tags_classifier = pickle.load(f)
self.correspondent_classifier = pickle.load(f)
self.document_type_classifier = pickle.load(f)
self.storage_path_classifier = pickle.load(f)
except Exception:
raise ClassifierModelCorruptError()
def save(self): def save(self):
target_file = settings.MODEL_FILE target_file = settings.MODEL_FILE

Binary file not shown.

View File

@ -3,10 +3,12 @@ import tempfile
from pathlib import Path from pathlib import Path
from unittest import mock from unittest import mock
import documents
import pytest import pytest
from django.conf import settings from django.conf import settings
from django.test import override_settings from django.test import override_settings
from django.test import TestCase from django.test import TestCase
from documents.classifier import ClassifierModelCorruptError
from documents.classifier import DocumentClassifier from documents.classifier import DocumentClassifier
from documents.classifier import IncompatibleClassifierVersionError from documents.classifier import IncompatibleClassifierVersionError
from documents.classifier import load_classifier from documents.classifier import load_classifier
@ -216,6 +218,45 @@ class TestClassifier(DirectoriesMixin, TestCase):
self.assertCountEqual(new_classifier.predict_tags(self.doc2.content), [45, 12]) self.assertCountEqual(new_classifier.predict_tags(self.doc2.content), [45, 12])
@override_settings(
MODEL_FILE=os.path.join(os.path.dirname(__file__), "data", "model.pickle"),
)
@mock.patch("documents.classifier.pickle.load")
def test_load_corrupt_file(self, patched_pickle_load):
"""
GIVEN:
- Corrupted classifier pickle file
WHEN:
- An attempt is made to load the classifier
THEN:
- The ClassifierModelCorruptError is raised
"""
# First load is the schema version
patched_pickle_load.side_effect = [DocumentClassifier.FORMAT_VERSION, OSError()]
with self.assertRaises(ClassifierModelCorruptError):
self.classifier.load()
@override_settings(
MODEL_FILE=os.path.join(
os.path.dirname(__file__),
"data",
"v1.0.2.model.pickle",
),
)
def test_load_new_scikit_learn_version(self):
"""
GIVEN:
- classifier pickle file created with a different scikit-learn version
WHEN:
- An attempt is made to load the classifier
THEN:
- The warning is captured and processed, and IncompatibleClassifierVersionError is raised
"""
with self.assertRaises(IncompatibleClassifierVersionError):
self.classifier.load()
def test_one_correspondent_predict(self): def test_one_correspondent_predict(self):
c1 = Correspondent.objects.create( c1 = Correspondent.objects.create(
name="c1", name="c1",