Changes from a hash-based system to a time-based system to prevent extra retrains

Trenton Holmes 2023-02-22 20:03:23 -08:00 committed by Trenton H
parent 8709ea4df0
commit c958a7c593
3 changed files with 151 additions and 73 deletions
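
In short: instead of hashing every document's content and labels on each training request, the classifier now stores the timestamp of the most recently modified training document and skips retraining when nothing newer exists. A minimal sketch of the idea, assuming a Django queryset whose model has a `modified` DateTimeField (the `needs_training` helper is an illustrative name, not part of the project):

    from datetime import datetime
    from typing import Optional

    def needs_training(
        last_data_change: Optional[datetime],
        docs_queryset,
    ) -> bool:
        # latest("modified") is a single ORDER BY ... DESC LIMIT 1 query,
        # far cheaper than hashing the content of every document.
        latest_doc_change = docs_queryset.latest("modified").modified
        return last_data_change is None or last_data_change < latest_doc_change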

View File

@@ -1,10 +1,10 @@
-import hashlib
 import logging
 import os
 import pickle
 import re
 import shutil
 import warnings
+from datetime import datetime
 from typing import Iterator
 from typing import List
 from typing import Optional
@@ -62,12 +62,13 @@ class DocumentClassifier:
     # v7 - Updated scikit-learn package version
     # v8 - Added storage path classifier
-    FORMAT_VERSION = 8
+    # v9 - Changed from hash to time for training data check
+    FORMAT_VERSION = 9
 
     def __init__(self):
-        # hash of the training data. used to prevent re-training when the
+        # last time training data was calculated. used to prevent re-training when the
         # training data has not changed.
-        self.data_hash: Optional[bytes] = None
+        self.last_data_change: Optional[datetime] = None
 
         self.data_vectorizer = None
         self.tags_binarizer = None
@@ -91,7 +92,7 @@ class DocumentClassifier:
                 )
             else:
                 try:
-                    self.data_hash = pickle.load(f)
+                    self.last_data_change = pickle.load(f)
                     self.data_vectorizer = pickle.load(f)
                     self.tags_binarizer = pickle.load(f)
@@ -121,7 +122,7 @@ class DocumentClassifier:
         with open(target_file_temp, "wb") as f:
             pickle.dump(self.FORMAT_VERSION, f)
-            pickle.dump(self.data_hash, f)
+            pickle.dump(self.last_data_change, f)
             pickle.dump(self.data_vectorizer, f)
             pickle.dump(self.tags_binarizer, f)
@@ -137,35 +138,40 @@ class DocumentClassifier:
     def train(self):
+        # Get non-inbox documents
+        docs_queryset = Document.objects.exclude(tags__is_inbox_tag=True)
+
+        # No documents exist to train against
+        if docs_queryset.count() == 0:
+            raise ValueError("No training data available.")
+
+        # No documents have changed since classifier was trained
+        latest_doc_change = docs_queryset.latest("modified").modified
+        if (
+            self.last_data_change is not None
+            and self.last_data_change >= latest_doc_change
+        ):
+            return False
+
         labels_tags = []
         labels_correspondent = []
         labels_document_type = []
         labels_storage_path = []
 
-        docs_queryset = Document.objects.order_by("pk").exclude(tags__is_inbox_tag=True)
-
-        if docs_queryset.count() == 0:
-            raise ValueError("No training data available.")
-
         # Step 1: Extract and preprocess training data from the database.
         logger.debug("Gathering data from database...")
 
-        m = hashlib.sha1()
-
         for doc in docs_queryset:
-            preprocessed_content = self.preprocess_content(doc.content)
-            m.update(preprocessed_content.encode("utf-8"))
-
             y = -1
             dt = doc.document_type
             if dt and dt.matching_algorithm == MatchingModel.MATCH_AUTO:
                 y = dt.pk
-                m.update(y.to_bytes(4, "little", signed=True))
             labels_document_type.append(y)
 
             y = -1
             cor = doc.correspondent
             if cor and cor.matching_algorithm == MatchingModel.MATCH_AUTO:
                 y = cor.pk
-                m.update(y.to_bytes(4, "little", signed=True))
             labels_correspondent.append(y)
 
             tags = sorted(
@@ -174,22 +180,14 @@ class DocumentClassifier:
                     matching_algorithm=MatchingModel.MATCH_AUTO,
                 )
             )
-            for tag in tags:
-                m.update(tag.to_bytes(4, "little", signed=True))
             labels_tags.append(tags)
 
             y = -1
             sd = doc.storage_path
             if sd and sd.matching_algorithm == MatchingModel.MATCH_AUTO:
                 y = sd.pk
-                m.update(y.to_bytes(4, "little", signed=True))
             labels_storage_path.append(y)
 
-        new_data_hash = m.digest()
-
-        if self.data_hash and new_data_hash == self.data_hash:
-            return False
-
         labels_tags_unique = {tag for tags in labels_tags for tag in tags}
         num_tags = len(labels_tags_unique)
@@ -216,12 +214,16 @@ class DocumentClassifier:
         from sklearn.feature_extraction.text import CountVectorizer
         from sklearn.neural_network import MLPClassifier
-        from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
+        from sklearn.preprocessing import LabelBinarizer
+        from sklearn.preprocessing import MultiLabelBinarizer
 
         # Step 2: vectorize data
         logger.debug("Vectorizing data...")
 
         def content_generator() -> Iterator[str]:
+            """
+            Generates the content for documents, one at a time
+            """
             for doc in docs_queryset:
                 yield self.preprocess_content(doc.content)
@@ -299,7 +301,7 @@ class DocumentClassifier:
                 "There are no storage paths. Not training storage path classifier.",
             )
 
-        self.data_hash = new_data_hash
+        self.last_data_change = latest_doc_change
 
         return True
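
The model file stores its fields in a fixed pickle order, with the schema version first; replacing the stored hash (bytes) with a datetime is why FORMAT_VERSION is bumped to 9 above. A hypothetical sketch of the save/load symmetry (function names and the ValueError are illustrative, not the project's API):

    import pickle

    def save_model(path, version, last_data_change, vectorizer, binarizer):
        with open(path, "wb") as f:
            pickle.dump(version, f)           # schema version, always first
            pickle.dump(last_data_change, f)  # datetime in v9 (bytes hash in v8)
            pickle.dump(vectorizer, f)
            pickle.dump(binarizer, f)

    def load_model(path, expected_version):
        with open(path, "rb") as f:
            if pickle.load(f) != expected_version:
                raise ValueError("Incompatible classifier format")
            # remaining fields must be read in the same order they were written
            return tuple(pickle.load(f) for _ in range(3))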

View File

@@ -1,7 +1,5 @@
 import os
 import re
-import shutil
-import tempfile
 from pathlib import Path
 from unittest import mock
@@ -22,15 +20,15 @@ from documents.tests.utils import DirectoriesMixin
 
 def dummy_preprocess(content: str):
+    """
+    Simpler, faster pre-processing for testing purposes
+    """
     content = content.lower().strip()
     content = re.sub(r"\s+", " ", content)
     return content
 
 
 class TestClassifier(DirectoriesMixin, TestCase):
-    SAMPLE_MODEL_FILE = os.path.join(os.path.dirname(__file__), "data", "model.pickle")
-
     def setUp(self):
         super().setUp()
         self.classifier = DocumentClassifier()
@@ -111,17 +109,68 @@ class TestClassifier(DirectoriesMixin, TestCase):
         self.doc1.storage_path = self.sp1
 
-    def testNoTrainingData(self):
-        try:
+    def generate_train_and_save(self):
+        """
+        Generates the training data, trains and saves the updated pickle
+        file. This ensures the test is using the same scikit-learn version
+        and eliminates a warning from the test suite
+        """
+        self.generate_test_data()
+        self.classifier.train()
+        self.classifier.save()
+
+    def test_no_training_data(self):
+        """
+        GIVEN:
+            - No documents exist to train
+        WHEN:
+            - Classifier training is requested
+        THEN:
+            - Exception is raised
+        """
+        with self.assertRaisesMessage(ValueError, "No training data available."):
+            self.classifier.train()
+
+    def test_no_non_inbox_tags(self):
+        """
+        GIVEN:
+            - No documents without an inbox tag exist
+        WHEN:
+            - Classifier training is requested
+        THEN:
+            - Exception is raised
+        """
+        t1 = Tag.objects.create(
+            name="t1",
+            matching_algorithm=Tag.MATCH_ANY,
+            pk=34,
+            is_inbox_tag=True,
+        )
+
+        doc1 = Document.objects.create(
+            title="doc1",
+            content="this is a document from c1",
+            checksum="A",
+        )
+        doc1.tags.add(t1)
+
+        with self.assertRaisesMessage(ValueError, "No training data available."):
             self.classifier.train()
-        except ValueError as e:
-            self.assertEqual(str(e), "No training data available.")
-        else:
-            self.fail("Should raise exception")
 
     def testEmpty(self):
+        """
+        GIVEN:
+            - A document exists
+            - No tags/not enough data to predict
+        WHEN:
+            - Classifier prediction is requested
+        THEN:
+            - Classifier returns no predictions
+        """
         Document.objects.create(title="WOW", checksum="3457", content="ASD")
         self.classifier.train()
 
         self.assertIsNone(self.classifier.document_type_classifier)
         self.assertIsNone(self.classifier.tags_classifier)
         self.assertIsNone(self.classifier.correspondent_classifier)
@@ -131,8 +180,18 @@ class TestClassifier(DirectoriesMixin, TestCase):
         self.assertIsNone(self.classifier.predict_correspondent(""))
 
     def testTrain(self):
+        """
+        GIVEN:
+            - Test data
+        WHEN:
+            - Classifier is trained
+        THEN:
+            - Classifier uses correct values for correspondent learning
+            - Classifier uses correct values for tags learning
+        """
         self.generate_test_data()
         self.classifier.train()
+
         self.assertListEqual(
             list(self.classifier.correspondent_classifier.classes_),
             [-1, self.c1.pk],
@@ -143,8 +202,17 @@ class TestClassifier(DirectoriesMixin, TestCase):
         )
 
     def testPredict(self):
+        """
+        GIVEN:
+            - Classifier trained against test data
+        WHEN:
+            - Prediction requested for correspondent, tags, type
+        THEN:
+            - Expected predictions based on training set
+        """
         self.generate_test_data()
         self.classifier.train()
         self.assertEqual(
             self.classifier.predict_correspondent(self.doc1.content),
             self.c1.pk,
@@ -164,20 +232,51 @@ class TestClassifier(DirectoriesMixin, TestCase):
         )
         self.assertEqual(self.classifier.predict_document_type(self.doc2.content), None)
 
-    def testDatasetHashing(self):
+    def test_no_retrain_if_no_change(self):
+        """
+        GIVEN:
+            - Classifier trained with current data
+        WHEN:
+            - Classifier training is requested again
+        THEN:
+            - Classifier does not redo training
+        """
         self.generate_test_data()
         self.assertTrue(self.classifier.train())
         self.assertFalse(self.classifier.train())
 
+    def test_retrain_if_change(self):
+        """
+        GIVEN:
+            - Classifier trained with current data
+        WHEN:
+            - Classifier training is requested again
+            - Documents have changed
+        THEN:
+            - Classifier redoes training
+        """
+        self.generate_test_data()
+        self.assertTrue(self.classifier.train())
+
+        self.doc1.correspondent = self.c2
+        self.doc1.save()
+
+        self.assertTrue(self.classifier.train())
+
     def testVersionIncreased(self):
-        self.generate_test_data()
-        self.assertTrue(self.classifier.train())
-        self.assertFalse(self.classifier.train())
-
-        self.classifier.save()
+        """
+        GIVEN:
+            - Existing classifier model saved at a version
+        WHEN:
+            - Attempt to load classifier file from newer version
+        THEN:
+            - Exception is raised
+        """
+        self.generate_train_and_save()
 
         classifier2 = DocumentClassifier()
@@ -194,14 +293,9 @@ class TestClassifier(DirectoriesMixin, TestCase):
         # assure that we can load the classifier after saving it.
         classifier2.load()
 
-    @override_settings(DATA_DIR=tempfile.mkdtemp())
     def testSaveClassifier(self):
-        self.generate_test_data()
-
-        self.classifier.train()
-
-        self.classifier.save()
+        self.generate_train_and_save()
 
         new_classifier = DocumentClassifier()
         new_classifier.load()
@@ -209,25 +303,9 @@ class TestClassifier(DirectoriesMixin, TestCase):
         self.assertFalse(new_classifier.train())
 
-    # @override_settings(
-    #     MODEL_FILE=os.path.join(os.path.dirname(__file__), "data", "model.pickle"),
-    # )
-    # def test_create_test_load_and_classify(self):
-    #     self.generate_test_data()
-    #     self.classifier.train()
-    #     self.classifier.save()
-
     def test_load_and_classify(self):
-        # Generate test data, train and save to the model file
-        # This ensures the model file sklearn version matches
-        # and eliminates a warning
-        shutil.copy(
-            self.SAMPLE_MODEL_FILE,
-            os.path.join(self.dirs.data_dir, "classification_model.pickle"),
-        )
-
-        self.generate_test_data()
-        self.classifier.train()
-        self.classifier.save()
+        self.generate_train_and_save()
 
         new_classifier = DocumentClassifier()
         new_classifier.load()
@@ -245,11 +323,9 @@ class TestClassifier(DirectoriesMixin, TestCase):
         THEN:
             - The ClassifierModelCorruptError is raised
         """
-        shutil.copy(
-            self.SAMPLE_MODEL_FILE,
-            os.path.join(self.dirs.data_dir, "classification_model.pickle"),
-        )
+        self.generate_train_and_save()
 
-        # First load is the schema version
+        # First load is the schema version, allow it
         patched_pickle_load.side_effect = [DocumentClassifier.FORMAT_VERSION, OSError()]
 
         with self.assertRaises(ClassifierModelCorruptError):