mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-07-28 18:24:38 -05:00
Updates the classifier to catch warnings from scikit-learn and rebuild the model file when this happens
This commit is contained in:

committed by
Johann Bauer

parent
ba79aff89b
commit
6bd585a9a0
Binary file not shown.
BIN
src/documents/tests/data/v1.0.2.model.pickle
Normal file
BIN
src/documents/tests/data/v1.0.2.model.pickle
Normal file
Binary file not shown.
@@ -3,10 +3,12 @@ import tempfile
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
import documents
|
||||
import pytest
|
||||
from django.conf import settings
|
||||
from django.test import override_settings
|
||||
from django.test import TestCase
|
||||
from documents.classifier import ClassifierModelCorruptError
|
||||
from documents.classifier import DocumentClassifier
|
||||
from documents.classifier import IncompatibleClassifierVersionError
|
||||
from documents.classifier import load_classifier
|
||||
@@ -216,6 +218,45 @@ class TestClassifier(DirectoriesMixin, TestCase):
|
||||
|
||||
self.assertCountEqual(new_classifier.predict_tags(self.doc2.content), [45, 12])
|
||||
|
||||
@override_settings(
|
||||
MODEL_FILE=os.path.join(os.path.dirname(__file__), "data", "model.pickle"),
|
||||
)
|
||||
@mock.patch("documents.classifier.pickle.load")
|
||||
def test_load_corrupt_file(self, patched_pickle_load):
|
||||
"""
|
||||
GIVEN:
|
||||
- Corrupted classifier pickle file
|
||||
WHEN:
|
||||
- An attempt is made to load the classifier
|
||||
THEN:
|
||||
- The ClassifierModelCorruptError is raised
|
||||
"""
|
||||
# First load is the schema version
|
||||
patched_pickle_load.side_effect = [DocumentClassifier.FORMAT_VERSION, OSError()]
|
||||
|
||||
with self.assertRaises(ClassifierModelCorruptError):
|
||||
self.classifier.load()
|
||||
|
||||
@override_settings(
|
||||
MODEL_FILE=os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
"data",
|
||||
"v1.0.2.model.pickle",
|
||||
),
|
||||
)
|
||||
def test_load_new_scikit_learn_version(self):
|
||||
"""
|
||||
GIVEN:
|
||||
- Classifier pickle file created with a different scikit-learn version
|
||||
WHEN:
|
||||
- An attempt is made to load the classifier
|
||||
THEN:
|
||||
- IncompatibleClassifierVersionError is raised, showing the scikit-learn warning was captured and processed
|
||||
"""
|
||||
|
||||
with self.assertRaises(IncompatibleClassifierVersionError):
|
||||
self.classifier.load()
|
||||
|
||||
def test_one_correspondent_predict(self):
|
||||
c1 = Correspondent.objects.create(
|
||||
name="c1",
|
||||
|
Reference in New Issue
Block a user