mirror of
				https://github.com/paperless-ngx/paperless-ngx.git
				synced 2025-10-30 03:56:23 -05:00 
			
		
		
		
	the document classifier is now stateless
This commit is contained in:
		| @@ -34,7 +34,6 @@ class DocumentClassifier(object): | |||||||
|         self.tags_classifier = None |         self.tags_classifier = None | ||||||
|         self.correspondent_classifier = None |         self.correspondent_classifier = None | ||||||
|         self.document_type_classifier = None |         self.document_type_classifier = None | ||||||
|         self.X = None |  | ||||||
|  |  | ||||||
|     def reload(self): |     def reload(self): | ||||||
|         if os.path.getmtime(settings.MODEL_FILE) > self.classifier_version: |         if os.path.getmtime(settings.MODEL_FILE) > self.classifier_version: | ||||||
| @@ -167,14 +166,10 @@ class DocumentClassifier(object): | |||||||
|                 "classifier." |                 "classifier." | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|     def update(self, document): |     def predict_correspondent(self, content): | ||||||
|         self.X = self.data_vectorizer.transform( |  | ||||||
|             [preprocess_content(document.content)] |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def predict_correspondent(self): |  | ||||||
|         if self.correspondent_classifier: |         if self.correspondent_classifier: | ||||||
|             y = self.correspondent_classifier.predict(self.X) |             X = self.data_vectorizer.transform([preprocess_content(content)]) | ||||||
|  |             y = self.correspondent_classifier.predict(X) | ||||||
|             correspondent_id = self.correspondent_binarizer.inverse_transform(y)[0] |             correspondent_id = self.correspondent_binarizer.inverse_transform(y)[0] | ||||||
|             if correspondent_id != -1: |             if correspondent_id != -1: | ||||||
|                 return correspondent_id |                 return correspondent_id | ||||||
| @@ -183,9 +178,10 @@ class DocumentClassifier(object): | |||||||
|         else: |         else: | ||||||
|             return None |             return None | ||||||
|  |  | ||||||
|     def predict_document_type(self): |     def predict_document_type(self, content): | ||||||
|         if self.document_type_classifier: |         if self.document_type_classifier: | ||||||
|             y = self.document_type_classifier.predict(self.X) |             X = self.data_vectorizer.transform([preprocess_content(content)]) | ||||||
|  |             y = self.document_type_classifier.predict(X) | ||||||
|             document_type_id = self.document_type_binarizer.inverse_transform(y)[0] |             document_type_id = self.document_type_binarizer.inverse_transform(y)[0] | ||||||
|             if document_type_id != -1: |             if document_type_id != -1: | ||||||
|                 return document_type_id |                 return document_type_id | ||||||
| @@ -194,9 +190,10 @@ class DocumentClassifier(object): | |||||||
|         else: |         else: | ||||||
|             return None |             return None | ||||||
|  |  | ||||||
|     def predict_tags(self): |     def predict_tags(self, content): | ||||||
|         if self.tags_classifier: |         if self.tags_classifier: | ||||||
|             y = self.tags_classifier.predict(self.X) |             X = self.data_vectorizer.transform([preprocess_content(content)]) | ||||||
|  |             y = self.tags_classifier.predict(X) | ||||||
|             tags_ids = self.tags_binarizer.inverse_transform(y)[0] |             tags_ids = self.tags_binarizer.inverse_transform(y)[0] | ||||||
|             return tags_ids |             return tags_ids | ||||||
|         else: |         else: | ||||||
|   | |||||||
| @@ -182,7 +182,6 @@ class Consumer: | |||||||
|  |  | ||||||
|             try: |             try: | ||||||
|                 self.classifier.reload() |                 self.classifier.reload() | ||||||
|                 self.classifier.update(document) |  | ||||||
|                 classifier = self.classifier |                 classifier = self.classifier | ||||||
|             except FileNotFoundError: |             except FileNotFoundError: | ||||||
|                 logging.getLogger(__name__).warning("Cannot classify documents, " |                 logging.getLogger(__name__).warning("Cannot classify documents, " | ||||||
|   | |||||||
| @@ -84,9 +84,6 @@ class Command(Renderable, BaseCommand): | |||||||
|                 "Processing document {}".format(document.title) |                 "Processing document {}".format(document.title) | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|             if classifier: |  | ||||||
|                 classifier.update(document) |  | ||||||
|  |  | ||||||
|             if options['correspondent']: |             if options['correspondent']: | ||||||
|                 set_correspondent( |                 set_correspondent( | ||||||
|                     sender=None, |                     sender=None, | ||||||
|   | |||||||
| @@ -7,7 +7,7 @@ from documents.models import MatchingModel, Correspondent, DocumentType, Tag | |||||||
|  |  | ||||||
| def match_correspondents(document_content, classifier): | def match_correspondents(document_content, classifier): | ||||||
|     correspondents = Correspondent.objects.all() |     correspondents = Correspondent.objects.all() | ||||||
|     predicted_correspondent_id = classifier.predict_correspondent() if classifier else None |     predicted_correspondent_id = classifier.predict_correspondent(document_content) if classifier else None | ||||||
|  |  | ||||||
|     matched_correspondents = [o for o in correspondents if matches(o, document_content) or o.id == predicted_correspondent_id] |     matched_correspondents = [o for o in correspondents if matches(o, document_content) or o.id == predicted_correspondent_id] | ||||||
|     return matched_correspondents |     return matched_correspondents | ||||||
| @@ -15,7 +15,7 @@ def match_correspondents(document_content, classifier): | |||||||
|  |  | ||||||
| def match_document_types(document_content, classifier): | def match_document_types(document_content, classifier): | ||||||
|     document_types = DocumentType.objects.all() |     document_types = DocumentType.objects.all() | ||||||
|     predicted_document_type_id = classifier.predict_document_type() if classifier else None |     predicted_document_type_id = classifier.predict_document_type(document_content) if classifier else None | ||||||
|  |  | ||||||
|     matched_document_types = [o for o in document_types if matches(o, document_content) or o.id == predicted_document_type_id] |     matched_document_types = [o for o in document_types if matches(o, document_content) or o.id == predicted_document_type_id] | ||||||
|     return matched_document_types |     return matched_document_types | ||||||
| @@ -23,7 +23,7 @@ def match_document_types(document_content, classifier): | |||||||
|  |  | ||||||
| def match_tags(document_content, classifier): | def match_tags(document_content, classifier): | ||||||
|     objects = Tag.objects.all() |     objects = Tag.objects.all() | ||||||
|     predicted_tag_ids = classifier.predict_tags() if classifier else [] |     predicted_tag_ids = classifier.predict_tags(document_content) if classifier else [] | ||||||
|  |  | ||||||
|     matched_tags = [o for o in objects if matches(o, document_content) or o.id in predicted_tag_ids] |     matched_tags = [o for o in objects if matches(o, document_content) or o.id in predicted_tag_ids] | ||||||
|     return matched_tags |     return matched_tags | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Jonas Winkler
					Jonas Winkler