mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-04-02 13:45:10 -05:00
Changes from a hash based system to a time based system to prevent extra retrains
This commit is contained in:
parent
8709ea4df0
commit
c958a7c593
@ -1,10 +1,10 @@
|
|||||||
import hashlib
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
import warnings
|
import warnings
|
||||||
|
from datetime import datetime
|
||||||
from typing import Iterator
|
from typing import Iterator
|
||||||
from typing import List
|
from typing import List
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
@ -62,12 +62,13 @@ class DocumentClassifier:
|
|||||||
|
|
||||||
# v7 - Updated scikit-learn package version
|
# v7 - Updated scikit-learn package version
|
||||||
# v8 - Added storage path classifier
|
# v8 - Added storage path classifier
|
||||||
FORMAT_VERSION = 8
|
# v9 - Changed from hash to time for training data check
|
||||||
|
FORMAT_VERSION = 9
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# hash of the training data. used to prevent re-training when the
|
# last time training data was calculated. used to prevent re-training when the
|
||||||
# training data has not changed.
|
# training data has not changed.
|
||||||
self.data_hash: Optional[bytes] = None
|
self.last_data_change: Optional[datetime] = None
|
||||||
|
|
||||||
self.data_vectorizer = None
|
self.data_vectorizer = None
|
||||||
self.tags_binarizer = None
|
self.tags_binarizer = None
|
||||||
@ -91,7 +92,7 @@ class DocumentClassifier:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
self.data_hash = pickle.load(f)
|
self.last_data_change = pickle.load(f)
|
||||||
self.data_vectorizer = pickle.load(f)
|
self.data_vectorizer = pickle.load(f)
|
||||||
self.tags_binarizer = pickle.load(f)
|
self.tags_binarizer = pickle.load(f)
|
||||||
|
|
||||||
@ -121,7 +122,7 @@ class DocumentClassifier:
|
|||||||
|
|
||||||
with open(target_file_temp, "wb") as f:
|
with open(target_file_temp, "wb") as f:
|
||||||
pickle.dump(self.FORMAT_VERSION, f)
|
pickle.dump(self.FORMAT_VERSION, f)
|
||||||
pickle.dump(self.data_hash, f)
|
pickle.dump(self.last_data_change, f)
|
||||||
pickle.dump(self.data_vectorizer, f)
|
pickle.dump(self.data_vectorizer, f)
|
||||||
|
|
||||||
pickle.dump(self.tags_binarizer, f)
|
pickle.dump(self.tags_binarizer, f)
|
||||||
@ -137,35 +138,40 @@ class DocumentClassifier:
|
|||||||
|
|
||||||
def train(self):
|
def train(self):
|
||||||
|
|
||||||
|
# Get non-inbox documents
|
||||||
|
docs_queryset = Document.objects.exclude(tags__is_inbox_tag=True)
|
||||||
|
|
||||||
|
# No documents exist to train against
|
||||||
|
if docs_queryset.count() == 0:
|
||||||
|
raise ValueError("No training data available.")
|
||||||
|
|
||||||
|
# No documents have changed since classifier was trained
|
||||||
|
latest_doc_change = docs_queryset.latest("modified").modified
|
||||||
|
if (
|
||||||
|
self.last_data_change is not None
|
||||||
|
and self.last_data_change >= latest_doc_change
|
||||||
|
):
|
||||||
|
return False
|
||||||
|
|
||||||
labels_tags = []
|
labels_tags = []
|
||||||
labels_correspondent = []
|
labels_correspondent = []
|
||||||
labels_document_type = []
|
labels_document_type = []
|
||||||
labels_storage_path = []
|
labels_storage_path = []
|
||||||
|
|
||||||
docs_queryset = Document.objects.order_by("pk").exclude(tags__is_inbox_tag=True)
|
|
||||||
|
|
||||||
if docs_queryset.count() == 0:
|
|
||||||
raise ValueError("No training data available.")
|
|
||||||
|
|
||||||
# Step 1: Extract and preprocess training data from the database.
|
# Step 1: Extract and preprocess training data from the database.
|
||||||
logger.debug("Gathering data from database...")
|
logger.debug("Gathering data from database...")
|
||||||
m = hashlib.sha1()
|
|
||||||
for doc in docs_queryset:
|
for doc in docs_queryset:
|
||||||
preprocessed_content = self.preprocess_content(doc.content)
|
|
||||||
m.update(preprocessed_content.encode("utf-8"))
|
|
||||||
|
|
||||||
y = -1
|
y = -1
|
||||||
dt = doc.document_type
|
dt = doc.document_type
|
||||||
if dt and dt.matching_algorithm == MatchingModel.MATCH_AUTO:
|
if dt and dt.matching_algorithm == MatchingModel.MATCH_AUTO:
|
||||||
y = dt.pk
|
y = dt.pk
|
||||||
m.update(y.to_bytes(4, "little", signed=True))
|
|
||||||
labels_document_type.append(y)
|
labels_document_type.append(y)
|
||||||
|
|
||||||
y = -1
|
y = -1
|
||||||
cor = doc.correspondent
|
cor = doc.correspondent
|
||||||
if cor and cor.matching_algorithm == MatchingModel.MATCH_AUTO:
|
if cor and cor.matching_algorithm == MatchingModel.MATCH_AUTO:
|
||||||
y = cor.pk
|
y = cor.pk
|
||||||
m.update(y.to_bytes(4, "little", signed=True))
|
|
||||||
labels_correspondent.append(y)
|
labels_correspondent.append(y)
|
||||||
|
|
||||||
tags = sorted(
|
tags = sorted(
|
||||||
@ -174,22 +180,14 @@ class DocumentClassifier:
|
|||||||
matching_algorithm=MatchingModel.MATCH_AUTO,
|
matching_algorithm=MatchingModel.MATCH_AUTO,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
for tag in tags:
|
|
||||||
m.update(tag.to_bytes(4, "little", signed=True))
|
|
||||||
labels_tags.append(tags)
|
labels_tags.append(tags)
|
||||||
|
|
||||||
y = -1
|
y = -1
|
||||||
sd = doc.storage_path
|
sd = doc.storage_path
|
||||||
if sd and sd.matching_algorithm == MatchingModel.MATCH_AUTO:
|
if sd and sd.matching_algorithm == MatchingModel.MATCH_AUTO:
|
||||||
y = sd.pk
|
y = sd.pk
|
||||||
m.update(y.to_bytes(4, "little", signed=True))
|
|
||||||
labels_storage_path.append(y)
|
labels_storage_path.append(y)
|
||||||
|
|
||||||
new_data_hash = m.digest()
|
|
||||||
|
|
||||||
if self.data_hash and new_data_hash == self.data_hash:
|
|
||||||
return False
|
|
||||||
|
|
||||||
labels_tags_unique = {tag for tags in labels_tags for tag in tags}
|
labels_tags_unique = {tag for tags in labels_tags for tag in tags}
|
||||||
|
|
||||||
num_tags = len(labels_tags_unique)
|
num_tags = len(labels_tags_unique)
|
||||||
@ -216,12 +214,16 @@ class DocumentClassifier:
|
|||||||
|
|
||||||
from sklearn.feature_extraction.text import CountVectorizer
|
from sklearn.feature_extraction.text import CountVectorizer
|
||||||
from sklearn.neural_network import MLPClassifier
|
from sklearn.neural_network import MLPClassifier
|
||||||
from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
|
from sklearn.preprocessing import LabelBinarizer
|
||||||
|
from sklearn.preprocessing import MultiLabelBinarizer
|
||||||
|
|
||||||
# Step 2: vectorize data
|
# Step 2: vectorize data
|
||||||
logger.debug("Vectorizing data...")
|
logger.debug("Vectorizing data...")
|
||||||
|
|
||||||
def content_generator() -> Iterator[str]:
|
def content_generator() -> Iterator[str]:
|
||||||
|
"""
|
||||||
|
Generates the content for documents, but one at a time
|
||||||
|
"""
|
||||||
for doc in docs_queryset:
|
for doc in docs_queryset:
|
||||||
yield self.preprocess_content(doc.content)
|
yield self.preprocess_content(doc.content)
|
||||||
|
|
||||||
@ -299,7 +301,7 @@ class DocumentClassifier:
|
|||||||
"There are no storage paths. Not training storage path classifier.",
|
"There are no storage paths. Not training storage path classifier.",
|
||||||
)
|
)
|
||||||
|
|
||||||
self.data_hash = new_data_hash
|
self.last_data_change = latest_doc_change
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
Binary file not shown.
@ -1,7 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import shutil
|
|
||||||
import tempfile
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
@ -22,15 +20,15 @@ from documents.tests.utils import DirectoriesMixin
|
|||||||
|
|
||||||
|
|
||||||
def dummy_preprocess(content: str):
|
def dummy_preprocess(content: str):
|
||||||
|
"""
|
||||||
|
Simpler, faster pre-processing for testing purposes
|
||||||
|
"""
|
||||||
content = content.lower().strip()
|
content = content.lower().strip()
|
||||||
content = re.sub(r"\s+", " ", content)
|
content = re.sub(r"\s+", " ", content)
|
||||||
return content
|
return content
|
||||||
|
|
||||||
|
|
||||||
class TestClassifier(DirectoriesMixin, TestCase):
|
class TestClassifier(DirectoriesMixin, TestCase):
|
||||||
|
|
||||||
SAMPLE_MODEL_FILE = os.path.join(os.path.dirname(__file__), "data", "model.pickle")
|
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
self.classifier = DocumentClassifier()
|
self.classifier = DocumentClassifier()
|
||||||
@ -111,17 +109,68 @@ class TestClassifier(DirectoriesMixin, TestCase):
|
|||||||
|
|
||||||
self.doc1.storage_path = self.sp1
|
self.doc1.storage_path = self.sp1
|
||||||
|
|
||||||
def testNoTrainingData(self):
|
def generate_train_and_save(self):
|
||||||
try:
|
"""
|
||||||
|
Generates the training data, trains and saves the updated pickle
|
||||||
|
file. This ensures the test is using the same scikit-learn version
|
||||||
|
and eliminates a warning from the test suite
|
||||||
|
"""
|
||||||
|
self.generate_test_data()
|
||||||
|
self.classifier.train()
|
||||||
|
self.classifier.save()
|
||||||
|
|
||||||
|
def test_no_training_data(self):
|
||||||
|
"""
|
||||||
|
GIVEN:
|
||||||
|
- No documents exist to train
|
||||||
|
WHEN:
|
||||||
|
- Classifier training is requested
|
||||||
|
THEN:
|
||||||
|
- Exception is raised
|
||||||
|
"""
|
||||||
|
with self.assertRaisesMessage(ValueError, "No training data available."):
|
||||||
|
self.classifier.train()
|
||||||
|
|
||||||
|
def test_no_non_inbox_tags(self):
|
||||||
|
"""
|
||||||
|
GIVEN:
|
||||||
|
- No documents without an inbox tag exist
|
||||||
|
WHEN:
|
||||||
|
- Classifier training is requested
|
||||||
|
THEN:
|
||||||
|
- Exception is raised
|
||||||
|
"""
|
||||||
|
|
||||||
|
t1 = Tag.objects.create(
|
||||||
|
name="t1",
|
||||||
|
matching_algorithm=Tag.MATCH_ANY,
|
||||||
|
pk=34,
|
||||||
|
is_inbox_tag=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
doc1 = Document.objects.create(
|
||||||
|
title="doc1",
|
||||||
|
content="this is a document from c1",
|
||||||
|
checksum="A",
|
||||||
|
)
|
||||||
|
doc1.tags.add(t1)
|
||||||
|
|
||||||
|
with self.assertRaisesMessage(ValueError, "No training data available."):
|
||||||
self.classifier.train()
|
self.classifier.train()
|
||||||
except ValueError as e:
|
|
||||||
self.assertEqual(str(e), "No training data available.")
|
|
||||||
else:
|
|
||||||
self.fail("Should raise exception")
|
|
||||||
|
|
||||||
def testEmpty(self):
|
def testEmpty(self):
|
||||||
|
"""
|
||||||
|
GIVEN:
|
||||||
|
- A document exists
|
||||||
|
- No tags/not enough data to predict
|
||||||
|
WHEN:
|
||||||
|
- Classifier prediction is requested
|
||||||
|
THEN:
|
||||||
|
- Classifier returns no predictions
|
||||||
|
"""
|
||||||
Document.objects.create(title="WOW", checksum="3457", content="ASD")
|
Document.objects.create(title="WOW", checksum="3457", content="ASD")
|
||||||
self.classifier.train()
|
self.classifier.train()
|
||||||
|
|
||||||
self.assertIsNone(self.classifier.document_type_classifier)
|
self.assertIsNone(self.classifier.document_type_classifier)
|
||||||
self.assertIsNone(self.classifier.tags_classifier)
|
self.assertIsNone(self.classifier.tags_classifier)
|
||||||
self.assertIsNone(self.classifier.correspondent_classifier)
|
self.assertIsNone(self.classifier.correspondent_classifier)
|
||||||
@ -131,8 +180,18 @@ class TestClassifier(DirectoriesMixin, TestCase):
|
|||||||
self.assertIsNone(self.classifier.predict_correspondent(""))
|
self.assertIsNone(self.classifier.predict_correspondent(""))
|
||||||
|
|
||||||
def testTrain(self):
|
def testTrain(self):
|
||||||
|
"""
|
||||||
|
GIVEN:
|
||||||
|
- Test data
|
||||||
|
WHEN:
|
||||||
|
- Classifier is trained
|
||||||
|
THEN:
|
||||||
|
- Classifier uses correct values for correspondent learning
|
||||||
|
- Classifier uses correct values for tags learning
|
||||||
|
"""
|
||||||
self.generate_test_data()
|
self.generate_test_data()
|
||||||
self.classifier.train()
|
self.classifier.train()
|
||||||
|
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
list(self.classifier.correspondent_classifier.classes_),
|
list(self.classifier.correspondent_classifier.classes_),
|
||||||
[-1, self.c1.pk],
|
[-1, self.c1.pk],
|
||||||
@ -143,8 +202,17 @@ class TestClassifier(DirectoriesMixin, TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def testPredict(self):
|
def testPredict(self):
|
||||||
|
"""
|
||||||
|
GIVEN:
|
||||||
|
- Classifier trained against test data
|
||||||
|
WHEN:
|
||||||
|
- Prediction requested for correspondent, tags, type
|
||||||
|
THEN:
|
||||||
|
- Expected predictions based on training set
|
||||||
|
"""
|
||||||
self.generate_test_data()
|
self.generate_test_data()
|
||||||
self.classifier.train()
|
self.classifier.train()
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
self.classifier.predict_correspondent(self.doc1.content),
|
self.classifier.predict_correspondent(self.doc1.content),
|
||||||
self.c1.pk,
|
self.c1.pk,
|
||||||
@ -164,20 +232,51 @@ class TestClassifier(DirectoriesMixin, TestCase):
|
|||||||
)
|
)
|
||||||
self.assertEqual(self.classifier.predict_document_type(self.doc2.content), None)
|
self.assertEqual(self.classifier.predict_document_type(self.doc2.content), None)
|
||||||
|
|
||||||
def testDatasetHashing(self):
|
def test_no_retrain_if_no_change(self):
|
||||||
|
"""
|
||||||
|
GIVEN:
|
||||||
|
- Classifier trained with current data
|
||||||
|
WHEN:
|
||||||
|
- Classifier training is requested again
|
||||||
|
THEN:
|
||||||
|
- Classifier does not redo training
|
||||||
|
"""
|
||||||
|
|
||||||
self.generate_test_data()
|
self.generate_test_data()
|
||||||
|
|
||||||
self.assertTrue(self.classifier.train())
|
self.assertTrue(self.classifier.train())
|
||||||
self.assertFalse(self.classifier.train())
|
self.assertFalse(self.classifier.train())
|
||||||
|
|
||||||
|
def test_retrain_if_change(self):
|
||||||
|
"""
|
||||||
|
GIVEN:
|
||||||
|
- Classifier trained with current data
|
||||||
|
WHEN:
|
||||||
|
- Classifier training is requested again
|
||||||
|
- Documents have changed
|
||||||
|
THEN:
|
||||||
|
- Classifier redoes training
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.generate_test_data()
|
||||||
|
|
||||||
|
self.assertTrue(self.classifier.train())
|
||||||
|
|
||||||
|
self.doc1.correspondent = self.c2
|
||||||
|
self.doc1.save()
|
||||||
|
|
||||||
|
self.assertTrue(self.classifier.train())
|
||||||
|
|
||||||
def testVersionIncreased(self):
|
def testVersionIncreased(self):
|
||||||
|
"""
|
||||||
self.generate_test_data()
|
GIVEN:
|
||||||
self.assertTrue(self.classifier.train())
|
- Existing classifier model saved at a version
|
||||||
self.assertFalse(self.classifier.train())
|
WHEN:
|
||||||
|
- Attempt to load classifier file from newer version
|
||||||
self.classifier.save()
|
THEN:
|
||||||
|
- Exception is raised
|
||||||
|
"""
|
||||||
|
self.generate_train_and_save()
|
||||||
|
|
||||||
classifier2 = DocumentClassifier()
|
classifier2 = DocumentClassifier()
|
||||||
|
|
||||||
@ -194,14 +293,9 @@ class TestClassifier(DirectoriesMixin, TestCase):
|
|||||||
# assure that we can load the classifier after saving it.
|
# assure that we can load the classifier after saving it.
|
||||||
classifier2.load()
|
classifier2.load()
|
||||||
|
|
||||||
@override_settings(DATA_DIR=tempfile.mkdtemp())
|
|
||||||
def testSaveClassifier(self):
|
def testSaveClassifier(self):
|
||||||
|
|
||||||
self.generate_test_data()
|
self.generate_train_and_save()
|
||||||
|
|
||||||
self.classifier.train()
|
|
||||||
|
|
||||||
self.classifier.save()
|
|
||||||
|
|
||||||
new_classifier = DocumentClassifier()
|
new_classifier = DocumentClassifier()
|
||||||
new_classifier.load()
|
new_classifier.load()
|
||||||
@ -209,25 +303,9 @@ class TestClassifier(DirectoriesMixin, TestCase):
|
|||||||
|
|
||||||
self.assertFalse(new_classifier.train())
|
self.assertFalse(new_classifier.train())
|
||||||
|
|
||||||
# @override_settings(
|
|
||||||
# MODEL_FILE=os.path.join(os.path.dirname(__file__), "data", "model.pickle"),
|
|
||||||
# )
|
|
||||||
# def test_create_test_load_and_classify(self):
|
|
||||||
# self.generate_test_data()
|
|
||||||
# self.classifier.train()
|
|
||||||
# self.classifier.save()
|
|
||||||
|
|
||||||
def test_load_and_classify(self):
|
def test_load_and_classify(self):
|
||||||
# Generate test data, train and save to the model file
|
|
||||||
# This ensures the model file sklearn version matches
|
self.generate_train_and_save()
|
||||||
# and eliminates a warning
|
|
||||||
shutil.copy(
|
|
||||||
self.SAMPLE_MODEL_FILE,
|
|
||||||
os.path.join(self.dirs.data_dir, "classification_model.pickle"),
|
|
||||||
)
|
|
||||||
self.generate_test_data()
|
|
||||||
self.classifier.train()
|
|
||||||
self.classifier.save()
|
|
||||||
|
|
||||||
new_classifier = DocumentClassifier()
|
new_classifier = DocumentClassifier()
|
||||||
new_classifier.load()
|
new_classifier.load()
|
||||||
@ -245,11 +323,9 @@ class TestClassifier(DirectoriesMixin, TestCase):
|
|||||||
THEN:
|
THEN:
|
||||||
- The ClassifierModelCorruptError is raised
|
- The ClassifierModelCorruptError is raised
|
||||||
"""
|
"""
|
||||||
shutil.copy(
|
self.generate_train_and_save()
|
||||||
self.SAMPLE_MODEL_FILE,
|
|
||||||
os.path.join(self.dirs.data_dir, "classification_model.pickle"),
|
# First load is the schema version, allow it
|
||||||
)
|
|
||||||
# First load is the schema version
|
|
||||||
patched_pickle_load.side_effect = [DocumentClassifier.FORMAT_VERSION, OSError()]
|
patched_pickle_load.side_effect = [DocumentClassifier.FORMAT_VERSION, OSError()]
|
||||||
|
|
||||||
with self.assertRaises(ClassifierModelCorruptError):
|
with self.assertRaises(ClassifierModelCorruptError):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user