mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-04-02 13:45:10 -05:00
Add custom fields to classifier
This commit is contained in:
parent
b8c618abbe
commit
a632b6b711
@ -97,6 +97,8 @@ class DocumentClassifier:
|
||||
self.correspondent_classifier = None
|
||||
self.document_type_classifier = None
|
||||
self.storage_path_classifier = None
|
||||
self.custom_fields_binarizer = None
|
||||
self.custom_fields_classifier = None
|
||||
|
||||
self._stemmer = None
|
||||
self._stop_words = None
|
||||
@ -120,11 +122,12 @@ class DocumentClassifier:
|
||||
|
||||
self.data_vectorizer = pickle.load(f)
|
||||
self.tags_binarizer = pickle.load(f)
|
||||
|
||||
self.tags_classifier = pickle.load(f)
|
||||
self.correspondent_classifier = pickle.load(f)
|
||||
self.document_type_classifier = pickle.load(f)
|
||||
self.storage_path_classifier = pickle.load(f)
|
||||
self.custom_fields_binarizer = pickle.load(f)
|
||||
self.custom_fields_classifier = pickle.load(f)
|
||||
except Exception as err:
|
||||
raise ClassifierModelCorruptError from err
|
||||
|
||||
@ -162,6 +165,9 @@ class DocumentClassifier:
|
||||
pickle.dump(self.document_type_classifier, f)
|
||||
pickle.dump(self.storage_path_classifier, f)
|
||||
|
||||
pickle.dump(self.custom_fields_binarizer, f)
|
||||
pickle.dump(self.custom_fields_classifier, f)
|
||||
|
||||
target_file_temp.rename(target_file)
|
||||
|
||||
def train(self) -> bool:
|
||||
@ -183,6 +189,7 @@ class DocumentClassifier:
|
||||
labels_correspondent = []
|
||||
labels_document_type = []
|
||||
labels_storage_path = []
|
||||
labels_custom_fields = []
|
||||
|
||||
# Step 1: Extract and preprocess training data from the database.
|
||||
logger.debug("Gathering data from database...")
|
||||
@ -218,13 +225,25 @@ class DocumentClassifier:
|
||||
hasher.update(y.to_bytes(4, "little", signed=True))
|
||||
labels_storage_path.append(y)
|
||||
|
||||
labels_tags_unique = {tag for tags in labels_tags for tag in tags}
|
||||
custom_fields = sorted(
|
||||
cf.pk
|
||||
for cf in doc.custom_fields.filter(
|
||||
field__matching_algorithm=MatchingModel.MATCH_AUTO,
|
||||
)
|
||||
)
|
||||
for cf in custom_fields:
|
||||
hasher.update(cf.to_bytes(4, "little", signed=True))
|
||||
labels_custom_fields.append(custom_fields)
|
||||
|
||||
labels_tags_unique = {tag for tags in labels_tags for tag in tags}
|
||||
num_tags = len(labels_tags_unique)
|
||||
|
||||
labels_custom_fields_unique = {cf for cfs in labels_custom_fields for cf in cfs}
|
||||
num_custom_fields = len(labels_custom_fields_unique)
|
||||
|
||||
# Check if retraining is actually required.
|
||||
# A document has been updated since the classifier was trained
|
||||
# New auto tags, types, correspondent, storage paths exist
|
||||
# New auto tags, types, correspondent, storage paths or custom fields exist
|
||||
latest_doc_change = docs_queryset.latest("modified").modified
|
||||
if (
|
||||
self.last_doc_change_time is not None
|
||||
@ -253,7 +272,8 @@ class DocumentClassifier:
|
||||
|
||||
logger.debug(
|
||||
f"{docs_queryset.count()} documents, {num_tags} tag(s), {num_correspondents} correspondent(s), "
|
||||
f"{num_document_types} document type(s). {num_storage_paths} storage path(s)",
|
||||
f"{num_document_types} document type(s), {num_storage_paths} storage path(s), "
|
||||
f"{num_custom_fields} custom field(s)",
|
||||
)
|
||||
|
||||
from sklearn.feature_extraction.text import CountVectorizer
|
||||
@ -345,6 +365,39 @@ class DocumentClassifier:
|
||||
"There are no storage paths. Not training storage path classifier.",
|
||||
)
|
||||
|
||||
if num_custom_fields > 0:
|
||||
logger.debug("Training custom fields classifier...")
|
||||
|
||||
if num_custom_fields == 1:
|
||||
# Special case where only one custom field has auto:
|
||||
# Fallback to binary classification.
|
||||
labels_custom_fields = [
|
||||
label[0] if len(label) == 1 else -1
|
||||
for label in labels_custom_fields
|
||||
]
|
||||
self.custom_fields_binarizer = LabelBinarizer()
|
||||
labels_custom_fields_vectorized = (
|
||||
self.custom_fields_binarizer.fit_transform(
|
||||
labels_custom_fields,
|
||||
).ravel()
|
||||
)
|
||||
else:
|
||||
self.custom_fields_binarizer = MultiLabelBinarizer()
|
||||
labels_custom_fields_vectorized = (
|
||||
self.custom_fields_binarizer.fit_transform(labels_custom_fields)
|
||||
)
|
||||
|
||||
self.custom_fields_classifier = MLPClassifier(tol=0.01)
|
||||
self.custom_fields_classifier.fit(
|
||||
data_vectorized,
|
||||
labels_custom_fields_vectorized,
|
||||
)
|
||||
else:
|
||||
self.custom_fields_classifier = None
|
||||
logger.debug(
|
||||
"There are no custom fields. Not training custom fields classifier.",
|
||||
)
|
||||
|
||||
self.last_doc_change_time = latest_doc_change
|
||||
self.last_auto_type_hash = hasher.digest()
|
||||
|
||||
@ -472,3 +525,24 @@ class DocumentClassifier:
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
|
||||
def predict_custom_fields(self, content: str) -> list[int]:
|
||||
from sklearn.utils.multiclass import type_of_target
|
||||
|
||||
if self.custom_fields_classifier:
|
||||
X = self.data_vectorizer.transform([self.preprocess_content(content)])
|
||||
y = self.custom_fields_classifier.predict(X)
|
||||
custom_fields_ids = self.custom_fields_binarizer.inverse_transform(y)[0]
|
||||
if type_of_target(y).startswith("multilabel"):
|
||||
# the usual case when there are multiple custom fields.
|
||||
return list(custom_fields_ids)
|
||||
elif type_of_target(y) == "binary" and custom_fields_ids != -1:
|
||||
# This is for when we have binary classification with only one
|
||||
# custom field and the result is to assign this custom field.
|
||||
return [custom_fields_ids]
|
||||
else:
|
||||
# Usually binary as well with -1 as the result, but we're
|
||||
# going to catch everything else here as well.
|
||||
return []
|
||||
else:
|
||||
return []
|
||||
|
Loading…
x
Reference in New Issue
Block a user