Add custom fields to classifier

This commit is contained in:
shamoon 2024-12-13 13:39:19 -08:00
parent b8c618abbe
commit a632b6b711
No known key found for this signature in database

View File

@ -97,6 +97,8 @@ class DocumentClassifier:
self.correspondent_classifier = None
self.document_type_classifier = None
self.storage_path_classifier = None
self.custom_fields_binarizer = None
self.custom_fields_classifier = None
self._stemmer = None
self._stop_words = None
@ -120,11 +122,12 @@ class DocumentClassifier:
self.data_vectorizer = pickle.load(f)
self.tags_binarizer = pickle.load(f)
self.tags_classifier = pickle.load(f)
self.correspondent_classifier = pickle.load(f)
self.document_type_classifier = pickle.load(f)
self.storage_path_classifier = pickle.load(f)
self.custom_fields_binarizer = pickle.load(f)
self.custom_fields_classifier = pickle.load(f)
except Exception as err:
raise ClassifierModelCorruptError from err
@ -162,6 +165,9 @@ class DocumentClassifier:
pickle.dump(self.document_type_classifier, f)
pickle.dump(self.storage_path_classifier, f)
pickle.dump(self.custom_fields_binarizer, f)
pickle.dump(self.custom_fields_classifier, f)
target_file_temp.rename(target_file)
def train(self) -> bool:
@ -183,6 +189,7 @@ class DocumentClassifier:
labels_correspondent = []
labels_document_type = []
labels_storage_path = []
labels_custom_fields = []
# Step 1: Extract and preprocess training data from the database.
logger.debug("Gathering data from database...")
@ -218,13 +225,25 @@ class DocumentClassifier:
hasher.update(y.to_bytes(4, "little", signed=True))
labels_storage_path.append(y)
labels_tags_unique = {tag for tags in labels_tags for tag in tags}
custom_fields = sorted(
cf.pk
for cf in doc.custom_fields.filter(
field__matching_algorithm=MatchingModel.MATCH_AUTO,
)
)
for cf in custom_fields:
hasher.update(cf.to_bytes(4, "little", signed=True))
labels_custom_fields.append(custom_fields)
labels_tags_unique = {tag for tags in labels_tags for tag in tags}
num_tags = len(labels_tags_unique)
labels_custom_fields_unique = {cf for cfs in labels_custom_fields for cf in cfs}
num_custom_fields = len(labels_custom_fields_unique)
# Check if retraining is actually required.
# A document has been updated since the classifier was trained
# New auto tags, types, correspondent, storage paths exist
# New auto tags, types, correspondent, storage paths or custom fields exist
latest_doc_change = docs_queryset.latest("modified").modified
if (
self.last_doc_change_time is not None
@ -253,7 +272,8 @@ class DocumentClassifier:
logger.debug(
f"{docs_queryset.count()} documents, {num_tags} tag(s), {num_correspondents} correspondent(s), "
f"{num_document_types} document type(s). {num_storage_paths} storage path(s)",
f"{num_document_types} document type(s), {num_storage_paths} storage path(s), "
f"{num_custom_fields} custom field(s)",
)
from sklearn.feature_extraction.text import CountVectorizer
@ -345,6 +365,39 @@ class DocumentClassifier:
"There are no storage paths. Not training storage path classifier.",
)
if num_custom_fields > 0:
logger.debug("Training custom fields classifier...")
if num_custom_fields == 1:
# Special case where only one custom field has auto:
# Fallback to binary classification.
labels_custom_fields = [
label[0] if len(label) == 1 else -1
for label in labels_custom_fields
]
self.custom_fields_binarizer = LabelBinarizer()
labels_custom_fields_vectorized = (
self.custom_fields_binarizer.fit_transform(
labels_custom_fields,
).ravel()
)
else:
self.custom_fields_binarizer = MultiLabelBinarizer()
labels_custom_fields_vectorized = (
self.custom_fields_binarizer.fit_transform(labels_custom_fields)
)
self.custom_fields_classifier = MLPClassifier(tol=0.01)
self.custom_fields_classifier.fit(
data_vectorized,
labels_custom_fields_vectorized,
)
else:
self.custom_fields_classifier = None
logger.debug(
"There are no custom fields. Not training custom fields classifier.",
)
self.last_doc_change_time = latest_doc_change
self.last_auto_type_hash = hasher.digest()
@ -472,3 +525,24 @@ class DocumentClassifier:
return None
else:
return None
def predict_custom_fields(self, content: str) -> list[int]:
from sklearn.utils.multiclass import type_of_target
if self.custom_fields_classifier:
X = self.data_vectorizer.transform([self.preprocess_content(content)])
y = self.custom_fields_classifier.predict(X)
custom_fields_ids = self.custom_fields_binarizer.inverse_transform(y)[0]
if type_of_target(y).startswith("multilabel"):
# the usual case when there are multiple custom fields.
return list(custom_fields_ids)
elif type_of_target(y) == "binary" and custom_fields_ids != -1:
# This is for when we have binary classification with only one
# custom field and the result is to assign this custom field.
return [custom_fields_ids]
else:
# Usually binary as well with -1 as the result, but we're
# going to catch everything else here as well.
return []
else:
return []