mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-04-02 13:45:10 -05:00
Add custom fields to classifier
This commit is contained in:
parent
b8c618abbe
commit
a632b6b711
@ -97,6 +97,8 @@ class DocumentClassifier:
|
|||||||
self.correspondent_classifier = None
|
self.correspondent_classifier = None
|
||||||
self.document_type_classifier = None
|
self.document_type_classifier = None
|
||||||
self.storage_path_classifier = None
|
self.storage_path_classifier = None
|
||||||
|
self.custom_fields_binarizer = None
|
||||||
|
self.custom_fields_classifier = None
|
||||||
|
|
||||||
self._stemmer = None
|
self._stemmer = None
|
||||||
self._stop_words = None
|
self._stop_words = None
|
||||||
@ -120,11 +122,12 @@ class DocumentClassifier:
|
|||||||
|
|
||||||
self.data_vectorizer = pickle.load(f)
|
self.data_vectorizer = pickle.load(f)
|
||||||
self.tags_binarizer = pickle.load(f)
|
self.tags_binarizer = pickle.load(f)
|
||||||
|
|
||||||
self.tags_classifier = pickle.load(f)
|
self.tags_classifier = pickle.load(f)
|
||||||
self.correspondent_classifier = pickle.load(f)
|
self.correspondent_classifier = pickle.load(f)
|
||||||
self.document_type_classifier = pickle.load(f)
|
self.document_type_classifier = pickle.load(f)
|
||||||
self.storage_path_classifier = pickle.load(f)
|
self.storage_path_classifier = pickle.load(f)
|
||||||
|
self.custom_fields_binarizer = pickle.load(f)
|
||||||
|
self.custom_fields_classifier = pickle.load(f)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
raise ClassifierModelCorruptError from err
|
raise ClassifierModelCorruptError from err
|
||||||
|
|
||||||
@ -162,6 +165,9 @@ class DocumentClassifier:
|
|||||||
pickle.dump(self.document_type_classifier, f)
|
pickle.dump(self.document_type_classifier, f)
|
||||||
pickle.dump(self.storage_path_classifier, f)
|
pickle.dump(self.storage_path_classifier, f)
|
||||||
|
|
||||||
|
pickle.dump(self.custom_fields_binarizer, f)
|
||||||
|
pickle.dump(self.custom_fields_classifier, f)
|
||||||
|
|
||||||
target_file_temp.rename(target_file)
|
target_file_temp.rename(target_file)
|
||||||
|
|
||||||
def train(self) -> bool:
|
def train(self) -> bool:
|
||||||
@ -183,6 +189,7 @@ class DocumentClassifier:
|
|||||||
labels_correspondent = []
|
labels_correspondent = []
|
||||||
labels_document_type = []
|
labels_document_type = []
|
||||||
labels_storage_path = []
|
labels_storage_path = []
|
||||||
|
labels_custom_fields = []
|
||||||
|
|
||||||
# Step 1: Extract and preprocess training data from the database.
|
# Step 1: Extract and preprocess training data from the database.
|
||||||
logger.debug("Gathering data from database...")
|
logger.debug("Gathering data from database...")
|
||||||
@ -218,13 +225,25 @@ class DocumentClassifier:
|
|||||||
hasher.update(y.to_bytes(4, "little", signed=True))
|
hasher.update(y.to_bytes(4, "little", signed=True))
|
||||||
labels_storage_path.append(y)
|
labels_storage_path.append(y)
|
||||||
|
|
||||||
labels_tags_unique = {tag for tags in labels_tags for tag in tags}
|
custom_fields = sorted(
|
||||||
|
cf.pk
|
||||||
|
for cf in doc.custom_fields.filter(
|
||||||
|
field__matching_algorithm=MatchingModel.MATCH_AUTO,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for cf in custom_fields:
|
||||||
|
hasher.update(cf.to_bytes(4, "little", signed=True))
|
||||||
|
labels_custom_fields.append(custom_fields)
|
||||||
|
|
||||||
|
labels_tags_unique = {tag for tags in labels_tags for tag in tags}
|
||||||
num_tags = len(labels_tags_unique)
|
num_tags = len(labels_tags_unique)
|
||||||
|
|
||||||
|
labels_custom_fields_unique = {cf for cfs in labels_custom_fields for cf in cfs}
|
||||||
|
num_custom_fields = len(labels_custom_fields_unique)
|
||||||
|
|
||||||
# Check if retraining is actually required.
|
# Check if retraining is actually required.
|
||||||
# A document has been updated since the classifier was trained
|
# A document has been updated since the classifier was trained
|
||||||
# New auto tags, types, correspondent, storage paths exist
|
# New auto tags, types, correspondent, storage paths or custom fields exist
|
||||||
latest_doc_change = docs_queryset.latest("modified").modified
|
latest_doc_change = docs_queryset.latest("modified").modified
|
||||||
if (
|
if (
|
||||||
self.last_doc_change_time is not None
|
self.last_doc_change_time is not None
|
||||||
@ -253,7 +272,8 @@ class DocumentClassifier:
|
|||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"{docs_queryset.count()} documents, {num_tags} tag(s), {num_correspondents} correspondent(s), "
|
f"{docs_queryset.count()} documents, {num_tags} tag(s), {num_correspondents} correspondent(s), "
|
||||||
f"{num_document_types} document type(s). {num_storage_paths} storage path(s)",
|
f"{num_document_types} document type(s), {num_storage_paths} storage path(s), "
|
||||||
|
f"{num_custom_fields} custom field(s)",
|
||||||
)
|
)
|
||||||
|
|
||||||
from sklearn.feature_extraction.text import CountVectorizer
|
from sklearn.feature_extraction.text import CountVectorizer
|
||||||
@ -345,6 +365,39 @@ class DocumentClassifier:
|
|||||||
"There are no storage paths. Not training storage path classifier.",
|
"There are no storage paths. Not training storage path classifier.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if num_custom_fields > 0:
|
||||||
|
logger.debug("Training custom fields classifier...")
|
||||||
|
|
||||||
|
if num_custom_fields == 1:
|
||||||
|
# Special case where only one custom field has auto:
|
||||||
|
# Fallback to binary classification.
|
||||||
|
labels_custom_fields = [
|
||||||
|
label[0] if len(label) == 1 else -1
|
||||||
|
for label in labels_custom_fields
|
||||||
|
]
|
||||||
|
self.custom_fields_binarizer = LabelBinarizer()
|
||||||
|
labels_custom_fields_vectorized = (
|
||||||
|
self.custom_fields_binarizer.fit_transform(
|
||||||
|
labels_custom_fields,
|
||||||
|
).ravel()
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.custom_fields_binarizer = MultiLabelBinarizer()
|
||||||
|
labels_custom_fields_vectorized = (
|
||||||
|
self.custom_fields_binarizer.fit_transform(labels_custom_fields)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.custom_fields_classifier = MLPClassifier(tol=0.01)
|
||||||
|
self.custom_fields_classifier.fit(
|
||||||
|
data_vectorized,
|
||||||
|
labels_custom_fields_vectorized,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.custom_fields_classifier = None
|
||||||
|
logger.debug(
|
||||||
|
"There are no custom fields. Not training custom fields classifier.",
|
||||||
|
)
|
||||||
|
|
||||||
self.last_doc_change_time = latest_doc_change
|
self.last_doc_change_time = latest_doc_change
|
||||||
self.last_auto_type_hash = hasher.digest()
|
self.last_auto_type_hash = hasher.digest()
|
||||||
|
|
||||||
@ -472,3 +525,24 @@ class DocumentClassifier:
|
|||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def predict_custom_fields(self, content: str) -> list[int]:
    """
    Predict the custom fields that should be assigned to a document.

    The content is preprocessed, vectorized with ``self.data_vectorizer``
    and passed to ``self.custom_fields_classifier``; the prediction is
    mapped back to label values with ``self.custom_fields_binarizer``.

    :param content: the raw text content of the document
    :return: the predicted custom field label values (primary keys, as
        gathered during training) as plain ``int``s; an empty list if no
        custom fields classifier is available or nothing is predicted
    """
    if not self.custom_fields_classifier:
        # No classifier trained or loaded: nothing to predict, and no need
        # to pay for the scikit-learn import below.
        return []

    # Imported lazily so scikit-learn is only loaded when a prediction is
    # actually going to be made.
    from sklearn.utils.multiclass import type_of_target

    X = self.data_vectorizer.transform([self.preprocess_content(content)])
    y = self.custom_fields_classifier.predict(X)
    custom_fields_ids = self.custom_fields_binarizer.inverse_transform(y)[0]
    target_type = type_of_target(y)

    if target_type.startswith("multilabel"):
        # The usual case when there are multiple custom fields.
        return [int(pk) for pk in custom_fields_ids]

    if target_type == "binary" and custom_fields_ids != -1:
        # Binary classification with only one custom field, and the
        # result is to assign this custom field. Training uses -1 as the
        # "no custom field" label in this mode.
        return [int(custom_fields_ids)]

    # Usually binary as well with -1 as the result, but every other target
    # type is deliberately caught here too.
    return []
|
||||||
|
Loading…
x
Reference in New Issue
Block a user