mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-10-22 03:16:15 -05:00
Updates the classifier to catch warnings from scikit-learn and rebuild the model file when this happens
This commit is contained in:

committed by
Johann Bauer

parent
ba79aff89b
commit
6bd585a9a0
@@ -4,6 +4,8 @@ import os
|
|||||||
import pickle
|
import pickle
|
||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
|
import warnings
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from documents.models import Document
|
from documents.models import Document
|
||||||
@@ -21,13 +23,13 @@ class ClassifierModelCorruptError(Exception):
|
|||||||
logger = logging.getLogger("paperless.classifier")
|
logger = logging.getLogger("paperless.classifier")
|
||||||
|
|
||||||
|
|
||||||
def preprocess_content(content):
|
def preprocess_content(content: str) -> str:
|
||||||
content = content.lower().strip()
|
content = content.lower().strip()
|
||||||
content = re.sub(r"\s+", " ", content)
|
content = re.sub(r"\s+", " ", content)
|
||||||
return content
|
return content
|
||||||
|
|
||||||
|
|
||||||
def load_classifier():
|
def load_classifier() -> Optional["DocumentClassifier"]:
|
||||||
if not os.path.isfile(settings.MODEL_FILE):
|
if not os.path.isfile(settings.MODEL_FILE):
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Document classification model does not exist (yet), not "
|
"Document classification model does not exist (yet), not "
|
||||||
@@ -39,7 +41,11 @@ def load_classifier():
|
|||||||
try:
|
try:
|
||||||
classifier.load()
|
classifier.load()
|
||||||
|
|
||||||
except (ClassifierModelCorruptError, IncompatibleClassifierVersionError):
|
except IncompatibleClassifierVersionError:
|
||||||
|
logger.info("Classifier version updated, will re-train")
|
||||||
|
os.unlink(settings.MODEL_FILE)
|
||||||
|
classifier = None
|
||||||
|
except ClassifierModelCorruptError:
|
||||||
# there's something wrong with the model file.
|
# there's something wrong with the model file.
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"Unrecoverable error while loading document "
|
"Unrecoverable error while loading document "
|
||||||
@@ -59,13 +65,14 @@ def load_classifier():
|
|||||||
|
|
||||||
class DocumentClassifier:
|
class DocumentClassifier:
|
||||||
|
|
||||||
|
# v7 - Updated scikit-learn package version
|
||||||
# v8 - Added storage path classifier
|
# v8 - Added storage path classifier
|
||||||
FORMAT_VERSION = 8
|
FORMAT_VERSION = 8
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# hash of the training data. used to prevent re-training when the
|
# hash of the training data. used to prevent re-training when the
|
||||||
# training data has not changed.
|
# training data has not changed.
|
||||||
self.data_hash = None
|
self.data_hash: Optional[bytes] = None
|
||||||
|
|
||||||
self.data_vectorizer = None
|
self.data_vectorizer = None
|
||||||
self.tags_binarizer = None
|
self.tags_binarizer = None
|
||||||
@@ -75,6 +82,8 @@ class DocumentClassifier:
|
|||||||
self.storage_path_classifier = None
|
self.storage_path_classifier = None
|
||||||
|
|
||||||
def load(self):
|
def load(self):
|
||||||
|
# Catch warnings for processing
|
||||||
|
with warnings.catch_warnings(record=True) as w:
|
||||||
with open(settings.MODEL_FILE, "rb") as f:
|
with open(settings.MODEL_FILE, "rb") as f:
|
||||||
schema_version = pickle.load(f)
|
schema_version = pickle.load(f)
|
||||||
|
|
||||||
@@ -95,6 +104,20 @@ class DocumentClassifier:
|
|||||||
except Exception:
|
except Exception:
|
||||||
raise ClassifierModelCorruptError()
|
raise ClassifierModelCorruptError()
|
||||||
|
|
||||||
|
# Check for the warning about unpickling from differing versions
|
||||||
|
# and consider it incompatible
|
||||||
|
if len(w) > 0:
|
||||||
|
sk_learn_warning_url = (
|
||||||
|
"https://scikit-learn.org/stable/"
|
||||||
|
"model_persistence.html"
|
||||||
|
"#security-maintainability-limitations"
|
||||||
|
)
|
||||||
|
for warning in w:
|
||||||
|
if issubclass(warning.category, UserWarning):
|
||||||
|
w_msg = str(warning.message)
|
||||||
|
if sk_learn_warning_url in w_msg:
|
||||||
|
raise IncompatibleClassifierVersionError()
|
||||||
|
|
||||||
def save(self):
|
def save(self):
|
||||||
target_file = settings.MODEL_FILE
|
target_file = settings.MODEL_FILE
|
||||||
target_file_temp = settings.MODEL_FILE + ".part"
|
target_file_temp = settings.MODEL_FILE + ".part"
|
||||||
|
Binary file not shown.
BIN
src/documents/tests/data/v1.0.2.model.pickle
Normal file
BIN
src/documents/tests/data/v1.0.2.model.pickle
Normal file
Binary file not shown.
@@ -3,10 +3,12 @@ import tempfile
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
|
import documents
|
||||||
import pytest
|
import pytest
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.test import override_settings
|
from django.test import override_settings
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
|
from documents.classifier import ClassifierModelCorruptError
|
||||||
from documents.classifier import DocumentClassifier
|
from documents.classifier import DocumentClassifier
|
||||||
from documents.classifier import IncompatibleClassifierVersionError
|
from documents.classifier import IncompatibleClassifierVersionError
|
||||||
from documents.classifier import load_classifier
|
from documents.classifier import load_classifier
|
||||||
@@ -216,6 +218,45 @@ class TestClassifier(DirectoriesMixin, TestCase):
|
|||||||
|
|
||||||
self.assertCountEqual(new_classifier.predict_tags(self.doc2.content), [45, 12])
|
self.assertCountEqual(new_classifier.predict_tags(self.doc2.content), [45, 12])
|
||||||
|
|
||||||
|
@override_settings(
|
||||||
|
MODEL_FILE=os.path.join(os.path.dirname(__file__), "data", "model.pickle"),
|
||||||
|
)
|
||||||
|
@mock.patch("documents.classifier.pickle.load")
|
||||||
|
def test_load_corrupt_file(self, patched_pickle_load):
|
||||||
|
"""
|
||||||
|
GIVEN:
|
||||||
|
- Corrupted classifier pickle file
|
||||||
|
WHEN:
|
||||||
|
- An attempt is made to load the classifier
|
||||||
|
THEN:
|
||||||
|
- The ClassifierModelCorruptError is raised
|
||||||
|
"""
|
||||||
|
# First load is the schema version
|
||||||
|
patched_pickle_load.side_effect = [DocumentClassifier.FORMAT_VERSION, OSError()]
|
||||||
|
|
||||||
|
with self.assertRaises(ClassifierModelCorruptError):
|
||||||
|
self.classifier.load()
|
||||||
|
|
||||||
|
@override_settings(
|
||||||
|
MODEL_FILE=os.path.join(
|
||||||
|
os.path.dirname(__file__),
|
||||||
|
"data",
|
||||||
|
"v1.0.2.model.pickle",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
def test_load_new_scikit_learn_version(self):
|
||||||
|
"""
|
||||||
|
GIVEN:
|
||||||
|
- classifier pickle file created with a different scikit-learn version
|
||||||
|
WHEN:
|
||||||
|
- An attempt is made to load the classifier
|
||||||
|
THEN:
|
||||||
|
- The classifier reports the warning was captured and processed
|
||||||
|
"""
|
||||||
|
|
||||||
|
with self.assertRaises(IncompatibleClassifierVersionError):
|
||||||
|
self.classifier.load()
|
||||||
|
|
||||||
def test_one_correspondent_predict(self):
|
def test_one_correspondent_predict(self):
|
||||||
c1 = Correspondent.objects.create(
|
c1 = Correspondent.objects.create(
|
||||||
name="c1",
|
name="c1",
|
||||||
|
Reference in New Issue
Block a user