Mirror of https://github.com/paperless-ngx/paperless-ngx.git (synced 2025-04-02 13:45:10 -05:00)
Returns to using hashing against primary keys, at least for fields. Improves testing coverage.

parent c958a7c593
commit 6b939f7567
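The core idea of the change: the classifier fingerprints the set of AUTO-matching objects by hashing their primary keys, so training is skipped only when neither the documents nor that set has changed. A minimal standalone sketch of the fingerprinting idea (function name and sample values are illustrative, not taken from the patch):

```python
from hashlib import sha256

def auto_set_fingerprint(pks):
    """Digest a sequence of primary keys; -1 is the 'no AUTO match' sentinel."""
    hasher = sha256()
    for pk in pks:
        # signed=True so the -1 sentinel fits in 4 little-endian bytes,
        # mirroring the to_bytes(4, "little", signed=True) calls in the diff
        hasher.update(pk.to_bytes(4, "little", signed=True))
    return hasher.digest()

# Same labels produce the same digest...
assert auto_set_fingerprint([3, -1, 7]) == auto_set_fingerprint([3, -1, 7])
# ...while a new or changed AUTO-matching object changes it.
assert auto_set_fingerprint([3, -1, 7]) != auto_set_fingerprint([3, 5, 7])
```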
src/documents/classifier.py

@@ -5,6 +5,7 @@ import re
 import shutil
 import warnings
 from datetime import datetime
+from hashlib import sha256
 from typing import Iterator
 from typing import List
 from typing import Optional
@@ -51,7 +52,7 @@ def load_classifier() -> Optional["DocumentClassifier"]:
     except OSError:
         logger.exception("IO error while loading document classification model")
         classifier = None
-    except Exception:
+    except Exception:  # pragma: nocover
         logger.exception("Unknown error while loading document classification model")
         classifier = None

@@ -62,13 +63,14 @@ class DocumentClassifier:

     # v7 - Updated scikit-learn package version
     # v8 - Added storage path classifier
-    # v9 - Changed from hash to time for training data check
+    # v9 - Changed from hashing to time/ids for re-train check
     FORMAT_VERSION = 9

     def __init__(self):
-        # last time training data was calculated. used to prevent re-training when the
-        # training data has not changed.
-        self.last_data_change: Optional[datetime] = None
+        # last time a document changed and therefore training might be required
+        self.last_doc_change_time: Optional[datetime] = None
+        # Hash of primary keys of AUTO matching values last used in training
+        self.last_auto_type_hash: Optional[bytes] = None

         self.data_vectorizer = None
         self.tags_binarizer = None
@@ -92,7 +94,9 @@ class DocumentClassifier:
                 )
             else:
                 try:
-                    self.last_data_change = pickle.load(f)
+                    self.last_doc_change_time = pickle.load(f)
+                    self.last_auto_type_hash = pickle.load(f)
+
                     self.data_vectorizer = pickle.load(f)
                     self.tags_binarizer = pickle.load(f)

@@ -122,7 +126,9 @@ class DocumentClassifier:

         with open(target_file_temp, "wb") as f:
             pickle.dump(self.FORMAT_VERSION, f)
-            pickle.dump(self.last_data_change, f)
+            pickle.dump(self.last_doc_change_time, f)
+            pickle.dump(self.last_auto_type_hash, f)
+
             pickle.dump(self.data_vectorizer, f)

             pickle.dump(self.tags_binarizer, f)
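Note the load/save symmetry above: `load()` must unpickle fields in exactly the order `save()` dumped them (format version, then the two freshness markers, then the model objects). A hedged round-trip sketch of that ordering, with stand-in values rather than real models:

```python
import pickle
from io import BytesIO

# Hypothetical stand-ins, to show ordering only
buf = BytesIO()
for obj in (9, None, b"\x00" * 32, "vectorizer", "binarizer"):
    pickle.dump(obj, buf)

buf.seek(0)
version = pickle.load(buf)                # FORMAT_VERSION
last_doc_change_time = pickle.load(buf)   # may be None before first training
last_auto_type_hash = pickle.load(buf)    # sha256 digest bytes
data_vectorizer = pickle.load(buf)
tags_binarizer = pickle.load(buf)
assert version == 9 and last_doc_change_time is None
```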
@@ -139,20 +145,14 @@ class DocumentClassifier:
     def train(self):

-        docs_queryset = Document.objects.exclude(tags__is_inbox_tag=True)
+        # Get non-inbox documents
+        docs_queryset = Document.objects.exclude(
+            tags__is_inbox_tag=True,
+        )

+        # No documents exist to train against
         if docs_queryset.count() == 0:
             raise ValueError("No training data available.")

-        # No documents have changed since classifier was trained
-        latest_doc_change = docs_queryset.latest("modified").modified
-        if (
-            self.last_data_change is not None
-            and self.last_data_change >= latest_doc_change
-        ):
-            return False
-
         labels_tags = []
         labels_correspondent = []
         labels_document_type = []
@@ -160,18 +160,21 @@ class DocumentClassifier:

         # Step 1: Extract and preprocess training data from the database.
         logger.debug("Gathering data from database...")
+        hasher = sha256()
         for doc in docs_queryset:

             y = -1
             dt = doc.document_type
             if dt and dt.matching_algorithm == MatchingModel.MATCH_AUTO:
                 y = dt.pk
+            hasher.update(y.to_bytes(4, "little", signed=True))
             labels_document_type.append(y)

             y = -1
             cor = doc.correspondent
             if cor and cor.matching_algorithm == MatchingModel.MATCH_AUTO:
                 y = cor.pk
+            hasher.update(y.to_bytes(4, "little", signed=True))
             labels_correspondent.append(y)

             tags = sorted(
@@ -180,18 +183,31 @@ class DocumentClassifier:
                     matching_algorithm=MatchingModel.MATCH_AUTO,
                 )
             )
+            for tag in tags:
+                hasher.update(tag.to_bytes(4, "little", signed=True))
             labels_tags.append(tags)

             y = -1
-            sd = doc.storage_path
-            if sd and sd.matching_algorithm == MatchingModel.MATCH_AUTO:
-                y = sd.pk
+            sp = doc.storage_path
+            if sp and sp.matching_algorithm == MatchingModel.MATCH_AUTO:
+                y = sp.pk
+            hasher.update(y.to_bytes(4, "little", signed=True))
             labels_storage_path.append(y)

         labels_tags_unique = {tag for tags in labels_tags for tag in tags}

         num_tags = len(labels_tags_unique)

+        # Check if retraining is actually required, i.e.:
+        # A document has been updated since the classifier was trained, or
+        # new AUTO tags, types, correspondents or storage paths exist
+        latest_doc_change = docs_queryset.latest("modified").modified
+        if (
+            self.last_doc_change_time is not None
+            and self.last_doc_change_time >= latest_doc_change
+        ) and self.last_auto_type_hash == hasher.digest():
+            return False
+
         # subtract 1 since -1 (null) is also part of the classes.

         # union with {-1} accounts for cases where all documents have
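Taken together, the gate now reads: skip retraining only when no document has been modified since the last run *and* the AUTO-set digest is unchanged. Restated as a small standalone predicate (assumed names, for illustration only):

```python
from datetime import datetime
from typing import Optional

def should_skip_training(
    last_doc_change_time: Optional[datetime],
    latest_doc_change: datetime,
    last_auto_type_hash: Optional[bytes],
    current_hash: bytes,
) -> bool:
    # Both conditions must hold: a brand-new classifier (None markers)
    # always trains, and any digest difference forces a retrain.
    docs_unchanged = (
        last_doc_change_time is not None
        and last_doc_change_time >= latest_doc_change
    )
    return docs_unchanged and last_auto_type_hash == current_hash
```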
@@ -301,11 +317,12 @@ class DocumentClassifier:
                 "There are no storage paths. Not training storage path classifier.",
             )

-        self.last_data_change = latest_doc_change
+        self.last_doc_change_time = latest_doc_change
+        self.last_auto_type_hash = hasher.digest()

         return True

-    def preprocess_content(self, content: str) -> str:
+    def preprocess_content(self, content: str) -> str:  # pragma: nocover
         """
         Process the contents of a document, distilling it down into
         words which are meaningful to the content
src/documents/tests/test_classifier.py

@@ -14,6 +14,7 @@ from documents.classifier import load_classifier
 from documents.models import Correspondent
 from documents.models import Document
 from documents.models import DocumentType
+from documents.models import MatchingModel
 from documents.models import StoragePath
 from documents.models import Tag
 from documents.tests.utils import DirectoriesMixin
@@ -46,6 +47,7 @@ class TestClassifier(DirectoriesMixin, TestCase):
             name="c3",
             matching_algorithm=Correspondent.MATCH_AUTO,
         )
+
         self.t1 = Tag.objects.create(
             name="t1",
             matching_algorithm=Tag.MATCH_AUTO,
@@ -62,6 +64,12 @@ class TestClassifier(DirectoriesMixin, TestCase):
             matching_algorithm=Tag.MATCH_AUTO,
             pk=45,
         )
+        self.t4 = Tag.objects.create(
+            name="t4",
+            matching_algorithm=Tag.MATCH_ANY,
+            pk=46,
+        )
+
         self.dt = DocumentType.objects.create(
             name="dt",
             matching_algorithm=DocumentType.MATCH_AUTO,
@@ -70,6 +78,7 @@ class TestClassifier(DirectoriesMixin, TestCase):
             name="dt2",
             matching_algorithm=DocumentType.MATCH_AUTO,
         )
+
         self.sp1 = StoragePath.objects.create(
             name="sp1",
             path="path1",
@@ -80,6 +89,7 @@ class TestClassifier(DirectoriesMixin, TestCase):
             path="path2",
             matching_algorithm=DocumentType.MATCH_AUTO,
         )
+        self.store_paths = [self.sp1, self.sp2]

         self.doc1 = Document.objects.create(
             title="doc1",
@@ -87,6 +97,7 @@ class TestClassifier(DirectoriesMixin, TestCase):
             correspondent=self.c1,
             checksum="A",
             document_type=self.dt,
+            storage_path=self.sp1,
         )

         self.doc2 = Document.objects.create(
@@ -107,8 +118,6 @@ class TestClassifier(DirectoriesMixin, TestCase):
         self.doc2.tags.add(self.t3)
         self.doc_inbox.tags.add(self.t2)

-        self.doc1.storage_path = self.sp1
-
     def generate_train_and_save(self):
         """
         Generates the training data, trains and saves the updated pickle
@@ -267,6 +276,28 @@ class TestClassifier(DirectoriesMixin, TestCase):

         self.assertTrue(self.classifier.train())

+    def test_retrain_if_auto_match_set_changed(self):
+        """
+        GIVEN:
+            - Classifier trained with current data
+        WHEN:
+            - Classifier training is requested again
+            - Some new AUTO match object exists
+        THEN:
+            - Classifier does redo training
+        """
+        self.generate_test_data()
+        # Add the ANY type
+        self.doc1.tags.add(self.t4)
+
+        self.assertTrue(self.classifier.train())
+
+        # Change the matching type
+        self.t4.matching_algorithm = MatchingModel.MATCH_AUTO
+        self.t4.save()
+
+        self.assertTrue(self.classifier.train())
+
     def testVersionIncreased(self):
         """
         GIVEN:
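The new test leans directly on the hashing behaviour: while `t4` is MATCH_ANY its primary key never enters the digest, so flipping it to MATCH_AUTO changes the fingerprint and forces a retrain even though no document was modified. A toy illustration of why the second `train()` call cannot be skipped (helper name and values are illustrative):

```python
from hashlib import sha256

def tag_digest(auto_tag_pks):
    h = sha256()
    for pk in sorted(auto_tag_pks):
        h.update(pk.to_bytes(4, "little", signed=True))
    return h.digest()

# t3 (pk=45) is already MATCH_AUTO; t4 (pk=46) joins the digest input
# only after its matching_algorithm is switched to MATCH_AUTO.
assert tag_digest([45]) != tag_digest([45, 46])
```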
@@ -314,7 +345,7 @@ class TestClassifier(DirectoriesMixin, TestCase):
         self.assertCountEqual(new_classifier.predict_tags(self.doc2.content), [45, 12])

     @mock.patch("documents.classifier.pickle.load")
-    def test_load_corrupt_file(self, patched_pickle_load):
+    def test_load_corrupt_file(self, patched_pickle_load: mock.MagicMock):
         """
         GIVEN:
             - Corrupted classifier pickle file
@@ -330,14 +361,17 @@ class TestClassifier(DirectoriesMixin, TestCase):

         with self.assertRaises(ClassifierModelCorruptError):
             self.classifier.load()
+        patched_pickle_load.assert_called()
+
+        patched_pickle_load.reset_mock()
+        patched_pickle_load.side_effect = [
+            DocumentClassifier.FORMAT_VERSION,
+            ClassifierModelCorruptError(),
+        ]

         self.assertIsNone(load_classifier())
+        patched_pickle_load.assert_called()

     @override_settings(
         MODEL_FILE=os.path.join(
             os.path.dirname(__file__),
             "data",
             "v1.0.2.model.pickle",
         ),
     )
     def test_load_new_scikit_learn_version(self):
         """
         GIVEN:
@@ -347,9 +381,12 @@ class TestClassifier(DirectoriesMixin, TestCase):
         THEN:
             - The classifier reports the warning was captured and processed
         """
+
+        with self.assertRaises(IncompatibleClassifierVersionError):
+            self.classifier.load()
         # TODO: This wasn't testing the warning anymore, as the schema changed
         # but as it was implemented, it would require installing an old version
         # rebuilding the file and committing that. Not developer friendly
         # Need to rethink how to pass the load through to a file with a single
         # old model?
-        pass

     def test_one_correspondent_predict(self):
         c1 = Correspondent.objects.create(
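The reworked `test_load_corrupt_file` above relies on `side_effect` being a list, which makes successive calls to the mocked `pickle.load` return (or raise) successive items: first a valid format-version record, then the corruption error. A minimal sketch of that mock pattern, with invented values:

```python
from unittest import mock

# Each call to the mock consumes the next side_effect item;
# exception instances are raised instead of returned.
fake_load = mock.Mock(side_effect=[9, OSError("truncated pickle")])

assert fake_load() == 9      # first load: the schema/format version
try:
    fake_load()              # second load: simulated corrupt payload
except OSError as err:
    print(f"caught: {err}")
```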