diff --git a/src/documents/classifier.py b/src/documents/classifier.py index 66958087a..9a5728124 100644 --- a/src/documents/classifier.py +++ b/src/documents/classifier.py @@ -1,10 +1,10 @@ -import hashlib import logging import os import pickle import re import shutil import warnings +from datetime import datetime from typing import Iterator from typing import List from typing import Optional @@ -62,12 +62,13 @@ class DocumentClassifier: # v7 - Updated scikit-learn package version # v8 - Added storage path classifier - FORMAT_VERSION = 8 + # v9 - Changed from hash to time for training data check + FORMAT_VERSION = 9 def __init__(self): - # hash of the training data. used to prevent re-training when the + # last time training data was calculated. used to prevent re-training when the # training data has not changed. - self.data_hash: Optional[bytes] = None + self.last_data_change: Optional[datetime] = None self.data_vectorizer = None self.tags_binarizer = None @@ -91,7 +92,7 @@ class DocumentClassifier: ) else: try: - self.data_hash = pickle.load(f) + self.last_data_change = pickle.load(f) self.data_vectorizer = pickle.load(f) self.tags_binarizer = pickle.load(f) @@ -121,7 +122,7 @@ class DocumentClassifier: with open(target_file_temp, "wb") as f: pickle.dump(self.FORMAT_VERSION, f) - pickle.dump(self.data_hash, f) + pickle.dump(self.last_data_change, f) pickle.dump(self.data_vectorizer, f) pickle.dump(self.tags_binarizer, f) @@ -137,35 +138,40 @@ class DocumentClassifier: def train(self): + # Get non-inbox documents + docs_queryset = Document.objects.exclude(tags__is_inbox_tag=True) + + # No documents exit to train against + if docs_queryset.count() == 0: + raise ValueError("No training data available.") + + # No documents have changed since classifier was trained + latest_doc_change = docs_queryset.latest("modified").modified + if ( + self.last_data_change is not None + and self.last_data_change >= latest_doc_change + ): + return False + labels_tags = [] labels_correspondent = [] labels_document_type = [] labels_storage_path = [] - docs_queryset = Document.objects.order_by("pk").exclude(tags__is_inbox_tag=True) - - if docs_queryset.count() == 0: - raise ValueError("No training data available.") - # Step 1: Extract and preprocess training data from the database. logger.debug("Gathering data from database...") - m = hashlib.sha1() for doc in docs_queryset: - preprocessed_content = self.preprocess_content(doc.content) - m.update(preprocessed_content.encode("utf-8")) y = -1 dt = doc.document_type if dt and dt.matching_algorithm == MatchingModel.MATCH_AUTO: y = dt.pk - m.update(y.to_bytes(4, "little", signed=True)) labels_document_type.append(y) y = -1 cor = doc.correspondent if cor and cor.matching_algorithm == MatchingModel.MATCH_AUTO: y = cor.pk - m.update(y.to_bytes(4, "little", signed=True)) labels_correspondent.append(y) tags = sorted( @@ -174,22 +180,14 @@ class DocumentClassifier: matching_algorithm=MatchingModel.MATCH_AUTO, ) ) - for tag in tags: - m.update(tag.to_bytes(4, "little", signed=True)) labels_tags.append(tags) y = -1 sd = doc.storage_path if sd and sd.matching_algorithm == MatchingModel.MATCH_AUTO: y = sd.pk - m.update(y.to_bytes(4, "little", signed=True)) labels_storage_path.append(y) - new_data_hash = m.digest() - - if self.data_hash and new_data_hash == self.data_hash: - return False - labels_tags_unique = {tag for tags in labels_tags for tag in tags} num_tags = len(labels_tags_unique) @@ -216,12 +214,16 @@ class DocumentClassifier: from sklearn.feature_extraction.text import CountVectorizer from sklearn.neural_network import MLPClassifier - from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer + from sklearn.preprocessing import LabelBinarizer + from sklearn.preprocessing import MultiLabelBinarizer # Step 2: vectorize data logger.debug("Vectorizing data...") def content_generator() -> Iterator[str]: + """ + Generates the content for documents, but once at a time + """ for doc in docs_queryset: yield self.preprocess_content(doc.content) @@ -299,7 +301,7 @@ class DocumentClassifier: "There are no storage paths. Not training storage path classifier.", ) - self.data_hash = new_data_hash + self.last_data_change = latest_doc_change return True diff --git a/src/documents/tests/data/model.pickle b/src/documents/tests/data/model.pickle deleted file mode 100644 index ff88b8894..000000000 Binary files a/src/documents/tests/data/model.pickle and /dev/null differ diff --git a/src/documents/tests/test_classifier.py b/src/documents/tests/test_classifier.py index 057f67204..7653fd861 100644 --- a/src/documents/tests/test_classifier.py +++ b/src/documents/tests/test_classifier.py @@ -1,7 +1,5 @@ import os import re -import shutil -import tempfile from pathlib import Path from unittest import mock @@ -22,15 +20,15 @@ from documents.tests.utils import DirectoriesMixin def dummy_preprocess(content: str): + """ + Simpler, faster pre-processing for testing purposes + """ content = content.lower().strip() content = re.sub(r"\s+", " ", content) return content class TestClassifier(DirectoriesMixin, TestCase): - - SAMPLE_MODEL_FILE = os.path.join(os.path.dirname(__file__), "data", "model.pickle") - def setUp(self): super().setUp() self.classifier = DocumentClassifier() @@ -111,17 +109,68 @@ class TestClassifier(DirectoriesMixin, TestCase): self.doc1.storage_path = self.sp1 - def testNoTrainingData(self): - try: + def generate_train_and_save(self): + """ + Generates the training data, trains and saves the updated pickle + file. This ensures the test is using the same scikit learn version + and eliminates a warning from the test suite + """ + self.generate_test_data() + self.classifier.train() + self.classifier.save() + + def test_no_training_data(self): + """ + GIVEN: + - No documents exist to train + WHEN: + - Classifier training is requested + THEN: + - Exception is raised + """ + with self.assertRaisesMessage(ValueError, "No training data available."): + self.classifier.train() + + def test_no_non_inbox_tags(self): + """ + GIVEN: + - No documents without an inbox tag exist + WHEN: + - Classifier training is requested + THEN: + - Exception is raised + """ + + t1 = Tag.objects.create( + name="t1", + matching_algorithm=Tag.MATCH_ANY, + pk=34, + is_inbox_tag=True, + ) + + doc1 = Document.objects.create( + title="doc1", + content="this is a document from c1", + checksum="A", + ) + doc1.tags.add(t1) + + with self.assertRaisesMessage(ValueError, "No training data available."): self.classifier.train() - except ValueError as e: - self.assertEqual(str(e), "No training data available.") - else: - self.fail("Should raise exception") def testEmpty(self): + """ + GIVEN: + - A document exists + - No tags/not enough data to predict + WHEN: + - Classifier prediction is requested + THEN: + - Classifier returns no predictions + """ Document.objects.create(title="WOW", checksum="3457", content="ASD") self.classifier.train() + self.assertIsNone(self.classifier.document_type_classifier) self.assertIsNone(self.classifier.tags_classifier) self.assertIsNone(self.classifier.correspondent_classifier) @@ -131,8 +180,18 @@ class TestClassifier(DirectoriesMixin, TestCase): self.assertIsNone(self.classifier.predict_correspondent("")) def testTrain(self): + """ + GIVEN: + - Test data + WHEN: + - Classifier is trained + THEN: + - Classifier uses correct values for correspondent learning + - Classifier uses correct values for tags learning + """ self.generate_test_data() self.classifier.train() + self.assertListEqual( list(self.classifier.correspondent_classifier.classes_), [-1, self.c1.pk], @@ -143,8 +202,17 @@ class TestClassifier(DirectoriesMixin, TestCase): ) def testPredict(self): + """ + GIVEN: + - Classifier trained against test data + WHEN: + - Prediction requested for correspondent, tags, type + THEN: + - Expected predictions based on training set + """ self.generate_test_data() self.classifier.train() + self.assertEqual( self.classifier.predict_correspondent(self.doc1.content), self.c1.pk, @@ -164,20 +232,51 @@ class TestClassifier(DirectoriesMixin, TestCase): ) self.assertEqual(self.classifier.predict_document_type(self.doc2.content), None) - def testDatasetHashing(self): + def test_no_retrain_if_no_change(self): + """ + GIVEN: + - Classifier trained with current data + WHEN: + - Classifier training is requested again + THEN: + - Classifier does not redo training + """ self.generate_test_data() self.assertTrue(self.classifier.train()) self.assertFalse(self.classifier.train()) + def test_retrain_if_change(self): + """ + GIVEN: + - Classifier trained with current data + WHEN: + - Classifier training is requested again + - Documents have changed + THEN: + - Classifier does not redo training + """ + + self.generate_test_data() + + self.assertTrue(self.classifier.train()) + + self.doc1.correspondent = self.c2 + self.doc1.save() + + self.assertTrue(self.classifier.train()) + def testVersionIncreased(self): - - self.generate_test_data() - self.assertTrue(self.classifier.train()) - self.assertFalse(self.classifier.train()) - - self.classifier.save() + """ + GIVEN: + - Existing classifier model saved at a version + WHEN: + - Attempt to load classifier file from newer version + THEN: + - Exception is raised + """ + self.generate_train_and_save() classifier2 = DocumentClassifier() @@ -194,14 +293,9 @@ class TestClassifier(DirectoriesMixin, TestCase): # assure that we can load the classifier after saving it. classifier2.load() - @override_settings(DATA_DIR=tempfile.mkdtemp()) def testSaveClassifier(self): - self.generate_test_data() - - self.classifier.train() - - self.classifier.save() + self.generate_train_and_save() new_classifier = DocumentClassifier() new_classifier.load() @@ -209,25 +303,9 @@ class TestClassifier(DirectoriesMixin, TestCase): self.assertFalse(new_classifier.train()) - # @override_settings( - # MODEL_FILE=os.path.join(os.path.dirname(__file__), "data", "model.pickle"), - # ) - # def test_create_test_load_and_classify(self): - # self.generate_test_data() - # self.classifier.train() - # self.classifier.save() - def test_load_and_classify(self): - # Generate test data, train and save to the model file - # This ensures the model file sklearn version matches - # and eliminates a warning - shutil.copy( - self.SAMPLE_MODEL_FILE, - os.path.join(self.dirs.data_dir, "classification_model.pickle"), - ) - self.generate_test_data() - self.classifier.train() - self.classifier.save() + + self.generate_train_and_save() new_classifier = DocumentClassifier() new_classifier.load() @@ -245,11 +323,9 @@ class TestClassifier(DirectoriesMixin, TestCase): THEN: - The ClassifierModelCorruptError is raised """ - shutil.copy( - self.SAMPLE_MODEL_FILE, - os.path.join(self.dirs.data_dir, "classification_model.pickle"), - ) - # First load is the schema version + self.generate_train_and_save() + + # First load is the schema version,allow it patched_pickle_load.side_effect = [DocumentClassifier.FORMAT_VERSION, OSError()] with self.assertRaises(ClassifierModelCorruptError):