mirror of
				https://github.com/paperless-ngx/paperless-ngx.git
				synced 2025-10-30 03:56:23 -05:00 
			
		
		
		
	Add custom fields to classifier
This commit is contained in:
		| @@ -97,6 +97,8 @@ class DocumentClassifier: | ||||
|         self.correspondent_classifier = None | ||||
|         self.document_type_classifier = None | ||||
|         self.storage_path_classifier = None | ||||
|         self.custom_fields_binarizer = None | ||||
|         self.custom_fields_classifier = None | ||||
|  | ||||
|         self._stemmer = None | ||||
|         self._stop_words = None | ||||
| @@ -120,11 +122,12 @@ class DocumentClassifier: | ||||
|  | ||||
|                         self.data_vectorizer = pickle.load(f) | ||||
|                         self.tags_binarizer = pickle.load(f) | ||||
|  | ||||
|                         self.tags_classifier = pickle.load(f) | ||||
|                         self.correspondent_classifier = pickle.load(f) | ||||
|                         self.document_type_classifier = pickle.load(f) | ||||
|                         self.storage_path_classifier = pickle.load(f) | ||||
|                         self.custom_fields_binarizer = pickle.load(f) | ||||
|                         self.custom_fields_classifier = pickle.load(f) | ||||
|                     except Exception as err: | ||||
|                         raise ClassifierModelCorruptError from err | ||||
|  | ||||
| @@ -162,6 +165,9 @@ class DocumentClassifier: | ||||
|             pickle.dump(self.document_type_classifier, f) | ||||
|             pickle.dump(self.storage_path_classifier, f) | ||||
|  | ||||
|             pickle.dump(self.custom_fields_binarizer, f) | ||||
|             pickle.dump(self.custom_fields_classifier, f) | ||||
|  | ||||
|         target_file_temp.rename(target_file) | ||||
|  | ||||
|     def train(self) -> bool: | ||||
| @@ -183,6 +189,7 @@ class DocumentClassifier: | ||||
|         labels_correspondent = [] | ||||
|         labels_document_type = [] | ||||
|         labels_storage_path = [] | ||||
|         labels_custom_fields = [] | ||||
|  | ||||
|         # Step 1: Extract and preprocess training data from the database. | ||||
|         logger.debug("Gathering data from database...") | ||||
| @@ -218,13 +225,25 @@ class DocumentClassifier: | ||||
|             hasher.update(y.to_bytes(4, "little", signed=True)) | ||||
|             labels_storage_path.append(y) | ||||
|  | ||||
|         labels_tags_unique = {tag for tags in labels_tags for tag in tags} | ||||
|             custom_fields = sorted( | ||||
|                 cf.pk | ||||
|                 for cf in doc.custom_fields.filter( | ||||
|                     field__matching_algorithm=MatchingModel.MATCH_AUTO, | ||||
|                 ) | ||||
|             ) | ||||
|             for cf in custom_fields: | ||||
|                 hasher.update(cf.to_bytes(4, "little", signed=True)) | ||||
|             labels_custom_fields.append(custom_fields) | ||||
|  | ||||
|         labels_tags_unique = {tag for tags in labels_tags for tag in tags} | ||||
|         num_tags = len(labels_tags_unique) | ||||
|  | ||||
|         labels_custom_fields_unique = {cf for cfs in labels_custom_fields for cf in cfs} | ||||
|         num_custom_fields = len(labels_custom_fields_unique) | ||||
|  | ||||
|         # Check if retraining is actually required. | ||||
|         # A document has been updated since the classifier was trained | ||||
|         # New auto tags, types, correspondent, storage paths exist | ||||
|         # New auto tags, types, correspondent, storage paths or custom fields exist | ||||
|         latest_doc_change = docs_queryset.latest("modified").modified | ||||
|         if ( | ||||
|             self.last_doc_change_time is not None | ||||
| @@ -253,7 +272,8 @@ class DocumentClassifier: | ||||
|  | ||||
|         logger.debug( | ||||
|             f"{docs_queryset.count()} documents, {num_tags} tag(s), {num_correspondents} correspondent(s), " | ||||
|             f"{num_document_types} document type(s). {num_storage_paths} storage path(s)", | ||||
|             f"{num_document_types} document type(s), {num_storage_paths} storage path(s), " | ||||
|             f"{num_custom_fields} custom field(s)", | ||||
|         ) | ||||
|  | ||||
|         from sklearn.feature_extraction.text import CountVectorizer | ||||
| @@ -345,6 +365,39 @@ class DocumentClassifier: | ||||
|                 "There are no storage paths. Not training storage path classifier.", | ||||
|             ) | ||||
|  | ||||
|         if num_custom_fields > 0: | ||||
|             logger.debug("Training custom fields classifier...") | ||||
|  | ||||
|             if num_custom_fields == 1: | ||||
|                 # Special case where only one custom field has auto: | ||||
|                 # Fallback to binary classification. | ||||
|                 labels_custom_fields = [ | ||||
|                     label[0] if len(label) == 1 else -1 | ||||
|                     for label in labels_custom_fields | ||||
|                 ] | ||||
|                 self.custom_fields_binarizer = LabelBinarizer() | ||||
|                 labels_custom_fields_vectorized = ( | ||||
|                     self.custom_fields_binarizer.fit_transform( | ||||
|                         labels_custom_fields, | ||||
|                     ).ravel() | ||||
|                 ) | ||||
|             else: | ||||
|                 self.custom_fields_binarizer = MultiLabelBinarizer() | ||||
|                 labels_custom_fields_vectorized = ( | ||||
|                     self.custom_fields_binarizer.fit_transform(labels_custom_fields) | ||||
|                 ) | ||||
|  | ||||
|             self.custom_fields_classifier = MLPClassifier(tol=0.01) | ||||
|             self.custom_fields_classifier.fit( | ||||
|                 data_vectorized, | ||||
|                 labels_custom_fields_vectorized, | ||||
|             ) | ||||
|         else: | ||||
|             self.custom_fields_classifier = None | ||||
|             logger.debug( | ||||
|                 "There are no custom fields. Not training custom fields classifier.", | ||||
|             ) | ||||
|  | ||||
|         self.last_doc_change_time = latest_doc_change | ||||
|         self.last_auto_type_hash = hasher.digest() | ||||
|  | ||||
| @@ -472,3 +525,24 @@ class DocumentClassifier: | ||||
|                 return None | ||||
|         else: | ||||
|             return None | ||||
|  | ||||
|     def predict_custom_fields(self, content: str) -> list[int]: | ||||
|         from sklearn.utils.multiclass import type_of_target | ||||
|  | ||||
|         if self.custom_fields_classifier: | ||||
|             X = self.data_vectorizer.transform([self.preprocess_content(content)]) | ||||
|             y = self.custom_fields_classifier.predict(X) | ||||
|             custom_fields_ids = self.custom_fields_binarizer.inverse_transform(y)[0] | ||||
|             if type_of_target(y).startswith("multilabel"): | ||||
|                 # the usual case when there are multiple custom fields. | ||||
|                 return list(custom_fields_ids) | ||||
|             elif type_of_target(y) == "binary" and custom_fields_ids != -1: | ||||
|                 # This is for when we have binary classification with only one | ||||
|                 # custom field and the result is to assign this custom field. | ||||
|                 return [custom_fields_ids] | ||||
|             else: | ||||
|                 # Usually binary as well with -1 as the result, but we're | ||||
|                 # going to catch everything else here as well. | ||||
|                 return [] | ||||
|         else: | ||||
|             return [] | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 shamoon
					shamoon