From a632b6b711ce11f456ae2a9dc30ab718f18d63d2 Mon Sep 17 00:00:00 2001
From: shamoon <4887959+shamoon@users.noreply.github.com>
Date: Fri, 13 Dec 2024 13:39:19 -0800
Subject: [PATCH] Add custom fields to classifier

---
Review notes (not part of the commit message):
- Training labels now use cf.field_id instead of cf.pk, and
  predict_custom_fields gained a docstring; the commit and blob hashes in
  this header still refer to the revision before those two edits.
- NOTE(review): load() reads two more pickled objects, so model files
  written before this change will raise ClassifierModelCorruptError.
  Confirm FORMAT_VERSION is bumped (it is not touched by this patch).

 src/documents/classifier.py | 90 +++++++++++++++++++++++++++++++++++--
 1 file changed, 86 insertions(+), 4 deletions(-)

diff --git a/src/documents/classifier.py b/src/documents/classifier.py
index 728c83228..58c2058b5 100644
--- a/src/documents/classifier.py
+++ b/src/documents/classifier.py
@@ -97,6 +97,8 @@ class DocumentClassifier:
         self.correspondent_classifier = None
         self.document_type_classifier = None
         self.storage_path_classifier = None
+        self.custom_fields_binarizer = None
+        self.custom_fields_classifier = None
 
         self._stemmer = None
         self._stop_words = None
@@ -120,11 +122,12 @@ class DocumentClassifier:
 
                         self.data_vectorizer = pickle.load(f)
                         self.tags_binarizer = pickle.load(f)
-
                         self.tags_classifier = pickle.load(f)
                         self.correspondent_classifier = pickle.load(f)
                         self.document_type_classifier = pickle.load(f)
                         self.storage_path_classifier = pickle.load(f)
+                        self.custom_fields_binarizer = pickle.load(f)
+                        self.custom_fields_classifier = pickle.load(f)
                     except Exception as err:
                         raise ClassifierModelCorruptError from err
 
@@ -162,6 +165,9 @@ class DocumentClassifier:
             pickle.dump(self.document_type_classifier, f)
             pickle.dump(self.storage_path_classifier, f)
 
+            pickle.dump(self.custom_fields_binarizer, f)
+            pickle.dump(self.custom_fields_classifier, f)
+
         target_file_temp.rename(target_file)
 
     def train(self) -> bool:
@@ -183,6 +189,7 @@ class DocumentClassifier:
         labels_correspondent = []
         labels_document_type = []
         labels_storage_path = []
+        labels_custom_fields = []
 
         # Step 1: Extract and preprocess training data from the database.
         logger.debug("Gathering data from database...")
@@ -218,13 +225,27 @@ class DocumentClassifier:
             hasher.update(y.to_bytes(4, "little", signed=True))
             labels_storage_path.append(y)
 
-        labels_tags_unique = {tag for tags in labels_tags for tag in tags}
+            # Label by the related field's id, not by the pk of the
+            # per-document row that doc.custom_fields yields.
+            custom_fields = sorted(
+                cf.field_id
+                for cf in doc.custom_fields.filter(
+                    field__matching_algorithm=MatchingModel.MATCH_AUTO,
+                )
+            )
+            for cf in custom_fields:
+                hasher.update(cf.to_bytes(4, "little", signed=True))
+            labels_custom_fields.append(custom_fields)
 
+        labels_tags_unique = {tag for tags in labels_tags for tag in tags}
         num_tags = len(labels_tags_unique)
 
+        labels_custom_fields_unique = {cf for cfs in labels_custom_fields for cf in cfs}
+        num_custom_fields = len(labels_custom_fields_unique)
+
         # Check if retraining is actually required.
         # A document has been updated since the classifier was trained
-        # New auto tags, types, correspondent, storage paths exist
+        # New auto tags, types, correspondent, storage paths or custom fields exist
         latest_doc_change = docs_queryset.latest("modified").modified
         if (
             self.last_doc_change_time is not None
@@ -253,7 +274,8 @@ class DocumentClassifier:
 
         logger.debug(
             f"{docs_queryset.count()} documents, {num_tags} tag(s), {num_correspondents} correspondent(s), "
-            f"{num_document_types} document type(s). {num_storage_paths} storage path(s)",
+            f"{num_document_types} document type(s), {num_storage_paths} storage path(s), "
+            f"{num_custom_fields} custom field(s)",
         )
 
         from sklearn.feature_extraction.text import CountVectorizer
@@ -345,6 +367,39 @@ class DocumentClassifier:
                 "There are no storage paths. Not training storage path classifier.",
             )
 
+        if num_custom_fields > 0:
+            logger.debug("Training custom fields classifier...")
+
+            if num_custom_fields == 1:
+                # Special case where only one custom field has auto:
+                # Fallback to binary classification.
+                labels_custom_fields = [
+                    label[0] if len(label) == 1 else -1
+                    for label in labels_custom_fields
+                ]
+                self.custom_fields_binarizer = LabelBinarizer()
+                labels_custom_fields_vectorized = (
+                    self.custom_fields_binarizer.fit_transform(
+                        labels_custom_fields,
+                    ).ravel()
+                )
+            else:
+                self.custom_fields_binarizer = MultiLabelBinarizer()
+                labels_custom_fields_vectorized = (
+                    self.custom_fields_binarizer.fit_transform(labels_custom_fields)
+                )
+
+            self.custom_fields_classifier = MLPClassifier(tol=0.01)
+            self.custom_fields_classifier.fit(
+                data_vectorized,
+                labels_custom_fields_vectorized,
+            )
+        else:
+            self.custom_fields_classifier = None
+            logger.debug(
+                "There are no custom fields. Not training custom fields classifier.",
+            )
+
         self.last_doc_change_time = latest_doc_change
         self.last_auto_type_hash = hasher.digest()
 
@@ -472,3 +527,30 @@ class DocumentClassifier:
                 return None
         else:
             return None
+
+    def predict_custom_fields(self, content: str) -> list[int]:
+        """
+        Predict which custom fields should be assigned to a document.
+
+        Returns the predicted custom field ids, or an empty list when
+        nothing is predicted or no custom fields classifier is trained.
+        """
+        from sklearn.utils.multiclass import type_of_target
+
+        if self.custom_fields_classifier:
+            X = self.data_vectorizer.transform([self.preprocess_content(content)])
+            y = self.custom_fields_classifier.predict(X)
+            custom_fields_ids = self.custom_fields_binarizer.inverse_transform(y)[0]
+            if type_of_target(y).startswith("multilabel"):
+                # the usual case when there are multiple custom fields.
+                return list(custom_fields_ids)
+            elif type_of_target(y) == "binary" and custom_fields_ids != -1:
+                # This is for when we have binary classification with only one
+                # custom field and the result is to assign this custom field.
+                return [custom_fields_ids]
+            else:
+                # Usually binary as well with -1 as the result, but we're
+                # going to catch everything else here as well.
+                return []
+        else:
+            return []