Updates the classifier to catch warnings from scikit-learn and rebuild the model file when this happens

This commit is contained in:
Trenton Holmes
2022-06-02 13:58:38 -07:00
committed by Johann Bauer
parent ba79aff89b
commit 6bd585a9a0
4 changed files with 85 additions and 21 deletions

Binary file not shown.

View File

@@ -3,10 +3,12 @@ import tempfile
from pathlib import Path
from unittest import mock
import documents
import pytest
from django.conf import settings
from django.test import override_settings
from django.test import TestCase
from documents.classifier import ClassifierModelCorruptError
from documents.classifier import DocumentClassifier
from documents.classifier import IncompatibleClassifierVersionError
from documents.classifier import load_classifier
@@ -216,6 +218,45 @@ class TestClassifier(DirectoriesMixin, TestCase):
self.assertCountEqual(new_classifier.predict_tags(self.doc2.content), [45, 12])
@override_settings(
MODEL_FILE=os.path.join(os.path.dirname(__file__), "data", "model.pickle"),
)
@mock.patch("documents.classifier.pickle.load")
def test_load_corrupt_file(self, patched_pickle_load):
"""
GIVEN:
- Corrupted classifier pickle file
WHEN:
- An attempt is made to load the classifier
THEN:
- The ClassifierModelCorruptError is raised
"""
# First load is the schema version
patched_pickle_load.side_effect = [DocumentClassifier.FORMAT_VERSION, OSError()]
with self.assertRaises(ClassifierModelCorruptError):
self.classifier.load()
@override_settings(
MODEL_FILE=os.path.join(
os.path.dirname(__file__),
"data",
"v1.0.2.model.pickle",
),
)
def test_load_new_scikit_learn_version(self):
"""
GIVEN:
- classifier pickle file created with a different scikit-learn version
WHEN:
- An attempt is made to load the classifier
THEN:
- The classifier reports the warning was captured and processed
"""
with self.assertRaises(IncompatibleClassifierVersionError):
self.classifier.load()
def test_one_correspondent_predict(self):
c1 = Correspondent.objects.create(
name="c1",