Mirror of https://github.com/paperless-ngx/paperless-ngx.git, synced 2025-08-16 00:36:22 +00:00
Changes from a hash based system to a time based system to prevent extra retrains
Committed by: Trenton H
Parent: 8709ea4df0
Commit: c958a7c593
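In short: the v8 classifier decided whether to retrain by hashing every document's preprocessed content and auto-matched labels on each run, while v9 simply remembers when training data was last gathered and compares that against the newest modified timestamp among the non-inbox documents. A minimal sketch of the two checks; the helper names are illustrative and not part of the paperless-ngx code:

import hashlib
from datetime import datetime
from typing import List, Optional

# v8 approach: hash every document's content on every run just to see if anything changed.
def needs_retrain_hash(contents: List[str], stored_hash: Optional[bytes]) -> bool:
    m = hashlib.sha1()
    for text in contents:               # touches every document every time
        m.update(text.encode("utf-8"))
    return stored_hash is None or m.digest() != stored_hash

# v9 approach: remember when training data was last gathered and compare one timestamp.
def needs_retrain_time(latest_doc_change: datetime, last_data_change: Optional[datetime]) -> bool:
    return last_data_change is None or latest_doc_change > last_data_change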
@@ -1,10 +1,10 @@
-import hashlib
 import logging
 import os
 import pickle
 import re
 import shutil
 import warnings
+from datetime import datetime
 from typing import Iterator
 from typing import List
 from typing import Optional
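The only import changes are dropping hashlib (no digest is computed any more) and adding datetime for the new last_data_change field. That field is later compared with >= against Django DateTimeField values, which are timezone-aware when USE_TZ is enabled; the toy comparison below, with made-up values, just illustrates that aware datetimes compare the way the new guard expects:

from datetime import datetime, timezone

# Made-up, timezone-aware timestamps standing in for Django DateTimeField values.
last_data_change = datetime(2022, 10, 1, 12, 0, tzinfo=timezone.utc)
latest_doc_change = datetime(2022, 10, 2, 8, 30, tzinfo=timezone.utc)

# The guard skips retraining only when nothing was modified after the last training run.
print(last_data_change >= latest_doc_change)   # False -> retrain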
@@ -62,12 +62,13 @@ class DocumentClassifier:
 
     # v7 - Updated scikit-learn package version
     # v8 - Added storage path classifier
-    FORMAT_VERSION = 8
+    # v9 - Changed from hash to time for training data check
+    FORMAT_VERSION = 9
 
     def __init__(self):
-        # hash of the training data. used to prevent re-training when the
+        # last time training data was calculated. used to prevent re-training when the
        # training data has not changed.
-        self.data_hash: Optional[bytes] = None
+        self.last_data_change: Optional[datetime] = None
 
         self.data_vectorizer = None
         self.tags_binarizer = None
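Bumping FORMAT_VERSION matters because the model file is a sequence of pickled objects: a v8 file holds a SHA-1 digest (bytes) in the slot where v9 now stores a datetime. Writing the version first lets the loader refuse old files instead of silently unpickling the wrong type into last_data_change. A rough, self-contained sketch of that gate; the exception name and file layout are illustrative, not the project's actual definitions:

import pickle
from datetime import datetime
from typing import Optional

FORMAT_VERSION = 9

class IncompatibleModelError(Exception):
    """Illustrative stand-in for the project's incompatible-version error."""

def load_last_data_change(path: str) -> Optional[datetime]:
    with open(path, "rb") as f:
        schema_version = pickle.load(f)           # save() always writes this first
        if schema_version != FORMAT_VERSION:
            # A v8 file would hold bytes here instead of a datetime -> refuse it.
            raise IncompatibleModelError(f"expected v{FORMAT_VERSION}, got v{schema_version}")
        return pickle.load(f)                     # the last_data_change timestamp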
@@ -91,7 +92,7 @@
                 )
             else:
                 try:
-                    self.data_hash = pickle.load(f)
+                    self.last_data_change = pickle.load(f)
                     self.data_vectorizer = pickle.load(f)
                     self.tags_binarizer = pickle.load(f)
 
@@ -121,7 +122,7 @@
 
         with open(target_file_temp, "wb") as f:
             pickle.dump(self.FORMAT_VERSION, f)
-            pickle.dump(self.data_hash, f)
+            pickle.dump(self.last_data_change, f)
             pickle.dump(self.data_vectorizer, f)
 
             pickle.dump(self.tags_binarizer, f)
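save() and load() form an implicit contract: fields are written with pickle.dump in a fixed order and read back with pickle.load in the same order, so replacing data_hash with last_data_change has to touch both methods in lockstep, which is exactly what the two hunks above do. A small self-contained sketch of that contract with illustrative field names:

import io
import pickle

# Illustrative fields written in a fixed order, as the classifier's save() does.
fields = {"version": 9, "last_data_change": None, "vectorizer": "v", "binarizer": "b"}

buf = io.BytesIO()
for value in fields.values():                           # "save": one dump per field
    pickle.dump(value, buf)

buf.seek(0)
loaded = {name: pickle.load(buf) for name in fields}    # "load": same order, same count
assert loaded == fields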
@@ -137,35 +138,40 @@
 
     def train(self):
 
+        # Get non-inbox documents
+        docs_queryset = Document.objects.exclude(tags__is_inbox_tag=True)
+
+        # No documents exit to train against
+        if docs_queryset.count() == 0:
+            raise ValueError("No training data available.")
+
+        # No documents have changed since classifier was trained
+        latest_doc_change = docs_queryset.latest("modified").modified
+        if (
+            self.last_data_change is not None
+            and self.last_data_change >= latest_doc_change
+        ):
+            return False
+
         labels_tags = []
         labels_correspondent = []
         labels_document_type = []
         labels_storage_path = []
 
-        docs_queryset = Document.objects.order_by("pk").exclude(tags__is_inbox_tag=True)
-
-        if docs_queryset.count() == 0:
-            raise ValueError("No training data available.")
-
         # Step 1: Extract and preprocess training data from the database.
         logger.debug("Gathering data from database...")
-        m = hashlib.sha1()
         for doc in docs_queryset:
-            preprocessed_content = self.preprocess_content(doc.content)
-            m.update(preprocessed_content.encode("utf-8"))
 
             y = -1
             dt = doc.document_type
             if dt and dt.matching_algorithm == MatchingModel.MATCH_AUTO:
                 y = dt.pk
-                m.update(y.to_bytes(4, "little", signed=True))
             labels_document_type.append(y)
 
             y = -1
             cor = doc.correspondent
             if cor and cor.matching_algorithm == MatchingModel.MATCH_AUTO:
                 y = cor.pk
-                m.update(y.to_bytes(4, "little", signed=True))
             labels_correspondent.append(y)
 
             tags = sorted(
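Apart from the new early return, Step 1 keeps its labelling convention: one entry per document in each labels list, holding the related object's pk when that object uses the auto matching algorithm and -1 otherwise, so the lists stay aligned with the documents by position. A toy version of that convention (the MATCH_AUTO check is collapsed to a simple None test here):

# Toy documents as (document_type_pk, correspondent_pk); None means "no auto-matched object".
docs = [(3, 7), (None, 7), (3, None)]

labels_document_type = []
labels_correspondent = []
for dt_pk, cor_pk in docs:
    labels_document_type.append(dt_pk if dt_pk is not None else -1)
    labels_correspondent.append(cor_pk if cor_pk is not None else -1)

assert labels_document_type == [3, -1, 3]
assert labels_correspondent == [7, 7, -1]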
@@ -174,22 +180,14 @@
                     matching_algorithm=MatchingModel.MATCH_AUTO,
                 )
             )
-            for tag in tags:
-                m.update(tag.to_bytes(4, "little", signed=True))
             labels_tags.append(tags)
 
             y = -1
             sd = doc.storage_path
             if sd and sd.matching_algorithm == MatchingModel.MATCH_AUTO:
                 y = sd.pk
-                m.update(y.to_bytes(4, "little", signed=True))
             labels_storage_path.append(y)
 
-        new_data_hash = m.digest()
-
-        if self.data_hash and new_data_hash == self.data_hash:
-            return False
-
         labels_tags_unique = {tag for tags in labels_tags for tag in tags}
 
         num_tags = len(labels_tags_unique)
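A document can carry several tags, so labels_tags is a list of per-document tag-pk lists; that structure is what scikit-learn's MultiLabelBinarizer (imported in the next hunk) is typically fed to produce a binary indicator matrix for a multi-label classifier. A minimal standalone example with made-up tag pks:

from sklearn.preprocessing import MultiLabelBinarizer

labels_tags = [[2, 5], [], [5]]        # tag pks per document, like the list built above

binarizer = MultiLabelBinarizer()
matrix = binarizer.fit_transform(labels_tags)

print(binarizer.classes_)   # [2 5]
print(matrix)               # [[1 1]
                            #  [0 0]
                            #  [0 1]]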
@@ -216,12 +214,16 @@
 
         from sklearn.feature_extraction.text import CountVectorizer
         from sklearn.neural_network import MLPClassifier
-        from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
+        from sklearn.preprocessing import LabelBinarizer
+        from sklearn.preprocessing import MultiLabelBinarizer
 
         # Step 2: vectorize data
         logger.debug("Vectorizing data...")
 
         def content_generator() -> Iterator[str]:
             """
             Generates the content for documents, but once at a time
             """
             for doc in docs_queryset:
                 yield self.preprocess_content(doc.content)
 
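content_generator yields each document's preprocessed text one at a time rather than materialising all content up front, and CountVectorizer happily consumes any iterable of strings. A standalone illustration with made-up texts and illustrative vectorizer parameters:

from sklearn.feature_extraction.text import CountVectorizer

def content_generator():
    # Stand-in for yielding preprocessed document content straight from the database.
    for text in ("invoice electric utility", "tax letter from the city", "invoice water utility"):
        yield text

vectorizer = CountVectorizer(analyzer="word", ngram_range=(1, 2), min_df=1)
data_vectorized = vectorizer.fit_transform(content_generator())

print(data_vectorized.shape)   # (3, <number of 1- and 2-gram features>)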
@@ -299,7 +301,7 @@
                 "There are no storage paths. Not training storage path classifier.",
             )
 
-        self.data_hash = new_data_hash
+        self.last_data_change = latest_doc_change
 
         return True
 