Mirror of https://github.com/paperless-ngx/paperless-ngx.git
Changes from a hash based system to a time based system to prevent extra retrains

committed by Trenton H
parent 21cd76a181
commit 303e81eb79
Binary file not shown.
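For orientation, the commit title above is the whole design change: instead of hashing the training data on every scheduled run to decide whether the classifier must be retrained, the decision is made by comparing timestamps. The snippet below is only a minimal sketch of that idea under assumed names (TrainingState, needs_retrain, mark_trained, latest_doc_change); it is not the paperless-ngx implementation.

# A minimal sketch of the time based "should we retrain?" decision, assuming
# hypothetical names throughout; this is not the paperless-ngx implementation.
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Optional


@dataclass
class TrainingState:
    # Timestamp of the last completed training run, None if never trained.
    last_trained: Optional[datetime] = None

    def needs_retrain(self, latest_doc_change: Optional[datetime]) -> bool:
        """Return True when documents changed after the last training run."""
        if latest_doc_change is None:
            return False  # nothing to train on
        if self.last_trained is None:
            return True  # never trained before
        return latest_doc_change > self.last_trained

    def mark_trained(self) -> None:
        self.last_trained = datetime.now(timezone.utc)


state = TrainingState()
changed_at = datetime(2023, 1, 1, tzinfo=timezone.utc)
print(state.needs_retrain(changed_at))  # True: never trained before
state.mark_trained()
print(state.needs_retrain(changed_at))  # False: no changes since last training

The new tests test_no_retrain_if_no_change and test_retrain_if_change in the diff below pin down exactly this behaviour: train() reports False when nothing changed and True after a document is modified.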
@@ -1,7 +1,5 @@
import os
import re
import shutil
import tempfile
from pathlib import Path
from unittest import mock
@@ -22,15 +20,15 @@ from documents.tests.utils import DirectoriesMixin


def dummy_preprocess(content: str):
    """
    Simpler, faster pre-processing for testing purposes
    """
    content = content.lower().strip()
    content = re.sub(r"\s+", " ", content)
    return content


class TestClassifier(DirectoriesMixin, TestCase):

    SAMPLE_MODEL_FILE = os.path.join(os.path.dirname(__file__), "data", "model.pickle")

    def setUp(self):
        super().setUp()
        self.classifier = DocumentClassifier()
@@ -111,17 +109,68 @@ class TestClassifier(DirectoriesMixin, TestCase):

        self.doc1.storage_path = self.sp1

    def testNoTrainingData(self):
        try:

    def generate_train_and_save(self):
        """
        Generates the training data, trains and saves the updated pickle
        file. This ensures the test is using the same scikit learn version
        and eliminates a warning from the test suite
        """
        self.generate_test_data()
        self.classifier.train()
        self.classifier.save()

    def test_no_training_data(self):
        """
        GIVEN:
            - No documents exist to train
        WHEN:
            - Classifier training is requested
        THEN:
            - Exception is raised
        """
        with self.assertRaisesMessage(ValueError, "No training data available."):
            self.classifier.train()

    def test_no_non_inbox_tags(self):
        """
        GIVEN:
            - No documents without an inbox tag exist
        WHEN:
            - Classifier training is requested
        THEN:
            - Exception is raised
        """

        t1 = Tag.objects.create(
            name="t1",
            matching_algorithm=Tag.MATCH_ANY,
            pk=34,
            is_inbox_tag=True,
        )

        doc1 = Document.objects.create(
            title="doc1",
            content="this is a document from c1",
            checksum="A",
        )
        doc1.tags.add(t1)

        with self.assertRaisesMessage(ValueError, "No training data available."):
            self.classifier.train()
        except ValueError as e:
            self.assertEqual(str(e), "No training data available.")
        else:
            self.fail("Should raise exception")

    def testEmpty(self):
        """
        GIVEN:
            - A document exists
            - No tags/not enough data to predict
        WHEN:
            - Classifier prediction is requested
        THEN:
            - Classifier returns no predictions
        """
        Document.objects.create(title="WOW", checksum="3457", content="ASD")
        self.classifier.train()

        self.assertIsNone(self.classifier.document_type_classifier)
        self.assertIsNone(self.classifier.tags_classifier)
        self.assertIsNone(self.classifier.correspondent_classifier)
@@ -131,8 +180,18 @@ class TestClassifier(DirectoriesMixin, TestCase):

        self.assertIsNone(self.classifier.predict_correspondent(""))

    def testTrain(self):
        """
        GIVEN:
            - Test data
        WHEN:
            - Classifier is trained
        THEN:
            - Classifier uses correct values for correspondent learning
            - Classifier uses correct values for tags learning
        """
        self.generate_test_data()
        self.classifier.train()

        self.assertListEqual(
            list(self.classifier.correspondent_classifier.classes_),
            [-1, self.c1.pk],
@@ -143,8 +202,17 @@ class TestClassifier(DirectoriesMixin, TestCase):

        )

    def testPredict(self):
        """
        GIVEN:
            - Classifier trained against test data
        WHEN:
            - Prediction requested for correspondent, tags, type
        THEN:
            - Expected predictions based on training set
        """
        self.generate_test_data()
        self.classifier.train()

        self.assertEqual(
            self.classifier.predict_correspondent(self.doc1.content),
            self.c1.pk,
@@ -164,20 +232,51 @@ class TestClassifier(DirectoriesMixin, TestCase):

        )
        self.assertEqual(self.classifier.predict_document_type(self.doc2.content), None)

    def testDatasetHashing(self):
    def test_no_retrain_if_no_change(self):
        """
        GIVEN:
            - Classifier trained with current data
        WHEN:
            - Classifier training is requested again
        THEN:
            - Classifier does not redo training
        """

        self.generate_test_data()

        self.assertTrue(self.classifier.train())
        self.assertFalse(self.classifier.train())

    def test_retrain_if_change(self):
        """
        GIVEN:
            - Classifier trained with current data
        WHEN:
            - Classifier training is requested again
            - Documents have changed
        THEN:
            - Classifier does redo training
        """

        self.generate_test_data()

        self.assertTrue(self.classifier.train())

        self.doc1.correspondent = self.c2
        self.doc1.save()

        self.assertTrue(self.classifier.train())

    def testVersionIncreased(self):

        self.generate_test_data()
        self.assertTrue(self.classifier.train())
        self.assertFalse(self.classifier.train())

        self.classifier.save()
        """
        GIVEN:
            - Existing classifier model saved at a version
        WHEN:
            - Attempt to load classifier file from newer version
        THEN:
            - Exception is raised
        """
        self.generate_train_and_save()

        classifier2 = DocumentClassifier()
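The two tests above rely on train() returning a boolean: True when the model was actually retrained, False when nothing changed since the last run. A hedged sketch of how calling code might use that return value follows; the task function, its logging, and the import path are assumptions for illustration, while the ValueError message comes from the tests above.

# Illustrative caller for the boolean train() contract exercised above.
# train_classifier_task, the logger wiring and the import path are assumptions;
# DocumentClassifier.train()/save() and the ValueError appear in the diff.
import logging

from documents.classifier import DocumentClassifier

logger = logging.getLogger(__name__)


def train_classifier_task() -> None:
    classifier = DocumentClassifier()
    try:
        if classifier.train():
            logger.info("Classifier retrained, saving updated model")
            classifier.save()
        else:
            logger.debug("No document changes since last training, skipping save")
    except ValueError:
        # Raised as "No training data available." when nothing can be trained on.
        logger.warning("Classifier training skipped: no training data available")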
@@ -194,14 +293,9 @@ class TestClassifier(DirectoriesMixin, TestCase):

        # assure that we can load the classifier after saving it.
        classifier2.load()

    @override_settings(DATA_DIR=tempfile.mkdtemp())
    def testSaveClassifier(self):

        self.generate_test_data()

        self.classifier.train()

        self.classifier.save()
        self.generate_train_and_save()

        new_classifier = DocumentClassifier()
        new_classifier.load()
@@ -209,25 +303,9 @@ class TestClassifier(DirectoriesMixin, TestCase):

        self.assertFalse(new_classifier.train())

    # @override_settings(
    #     MODEL_FILE=os.path.join(os.path.dirname(__file__), "data", "model.pickle"),
    # )
    # def test_create_test_load_and_classify(self):
    #     self.generate_test_data()
    #     self.classifier.train()
    #     self.classifier.save()

    def test_load_and_classify(self):
        # Generate test data, train and save to the model file
        # This ensures the model file sklearn version matches
        # and eliminates a warning
        shutil.copy(
            self.SAMPLE_MODEL_FILE,
            os.path.join(self.dirs.data_dir, "classification_model.pickle"),
        )
        self.generate_test_data()
        self.classifier.train()
        self.classifier.save()

        self.generate_train_and_save()

        new_classifier = DocumentClassifier()
        new_classifier.load()
@@ -245,11 +323,9 @@ class TestClassifier(DirectoriesMixin, TestCase):

        THEN:
            - The ClassifierModelCorruptError is raised
        """
        shutil.copy(
            self.SAMPLE_MODEL_FILE,
            os.path.join(self.dirs.data_dir, "classification_model.pickle"),
        )
        # First load is the schema version
        self.generate_train_and_save()

        # First load is the schema version, allow it
        patched_pickle_load.side_effect = [DocumentClassifier.FORMAT_VERSION, OSError()]

        with self.assertRaises(ClassifierModelCorruptError):
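The mocked pickle.load side effect above (first the format version, then an OSError) documents the expected load sequence: the first object in the pickle file is the schema version, and any failure while reading the rest is surfaced as ClassifierModelCorruptError. The sketch below illustrates that loading pattern only; load_model, VersionMismatchError, the FORMAT_VERSION value and the exception classes defined here are local stand-ins, not paperless-ngx's actual DocumentClassifier.load().

# Sketch of a version-checked pickle load that maps unpickling failures to a
# corrupt-model error. FORMAT_VERSION and ClassifierModelCorruptError appear in
# the diff above, but everything defined here is a local stand-in for
# illustration, not the paperless-ngx definitions.
import pickle


class ClassifierModelCorruptError(Exception):
    """The model file exists but its payload cannot be unpickled."""


class VersionMismatchError(Exception):
    """The stored format version does not match the expected one."""


FORMAT_VERSION = 1  # placeholder value for this sketch


def load_model(path: str):
    with open(path, "rb") as f:
        # First load is the schema version; reject anything unexpected.
        version = pickle.load(f)
        if version != FORMAT_VERSION:
            raise VersionMismatchError(f"Cannot load model with version {version}")
        try:
            # The remaining pickled object is the actual model data.
            return pickle.load(f)
        except (pickle.UnpicklingError, EOFError, OSError) as err:
            raise ClassifierModelCorruptError("Classifier model file is corrupt") from err

In the test, the second pickle.load call is patched to raise OSError, which is exactly the path that should end in ClassifierModelCorruptError.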