mirror of
				https://github.com/paperless-ngx/paperless-ngx.git
				synced 2025-10-30 03:56:23 -05:00 
			
		
		
		
	Code style changes
This commit is contained in:
		| @@ -165,8 +165,9 @@ def remove_document_type_from_selected(modeladmin, request, queryset): | |||||||
|  |  | ||||||
|  |  | ||||||
| def run_document_classifier_on_selected(modeladmin, request, queryset): | def run_document_classifier_on_selected(modeladmin, request, queryset): | ||||||
|  |     clf = DocumentClassifier() | ||||||
|     try: |     try: | ||||||
|         clf = DocumentClassifier.load_classifier() |         clf.reload() | ||||||
|         return simple_action( |         return simple_action( | ||||||
|             modeladmin=modeladmin, |             modeladmin=modeladmin, | ||||||
|             request=request, |             request=request, | ||||||
| @@ -201,4 +202,3 @@ remove_document_type_from_selected.short_description = \ | |||||||
|     "Remove document type from selected documents" |     "Remove document type from selected documents" | ||||||
| run_document_classifier_on_selected.short_description = \ | run_document_classifier_on_selected.short_description = \ | ||||||
|     "Run document classifier on selected" |     "Run document classifier on selected" | ||||||
|  |  | ||||||
|   | |||||||
| @@ -2,14 +2,13 @@ import logging | |||||||
| import os | import os | ||||||
| import pickle | import pickle | ||||||
|  |  | ||||||
|  | from sklearn.feature_extraction.text import CountVectorizer | ||||||
| from sklearn.neural_network import MLPClassifier | from sklearn.neural_network import MLPClassifier | ||||||
|  | from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer | ||||||
|  |  | ||||||
| from documents.models import Correspondent, DocumentType, Tag, Document | from documents.models import Correspondent, DocumentType, Tag, Document | ||||||
| from paperless import settings | from paperless import settings | ||||||
|  |  | ||||||
| from sklearn.feature_extraction.text import CountVectorizer |  | ||||||
| from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def preprocess_content(content): | def preprocess_content(content): | ||||||
|     content = content.lower() |     content = content.lower() | ||||||
| @@ -23,26 +22,21 @@ def preprocess_content(content): | |||||||
|  |  | ||||||
| class DocumentClassifier(object): | class DocumentClassifier(object): | ||||||
|  |  | ||||||
|     classifier_version = None |     def __init__(self): | ||||||
|  |         self.classifier_version = 0 | ||||||
|  |  | ||||||
|     data_vectorizer = None |         self.data_vectorizer = None | ||||||
|  |  | ||||||
|     tags_binarizer = None |         self.tags_binarizer = None | ||||||
|     correspondent_binarizer = None |         self.correspondent_binarizer = None | ||||||
|     document_type_binarizer = None |         self.document_type_binarizer = None | ||||||
|  |  | ||||||
|     tags_classifier = None |         self.tags_classifier = None | ||||||
|     correspondent_classifier = None |         self.correspondent_classifier = None | ||||||
|     document_type_classifier = None |         self.document_type_classifier = None | ||||||
|  |  | ||||||
|     @staticmethod |  | ||||||
|     def load_classifier(): |  | ||||||
|         clf = DocumentClassifier() |  | ||||||
|         clf.reload() |  | ||||||
|         return clf |  | ||||||
|  |  | ||||||
|     def reload(self): |     def reload(self): | ||||||
|         if self.classifier_version is None or os.path.getmtime(settings.MODEL_FILE) > self.classifier_version: |         if os.path.getmtime(settings.MODEL_FILE) > self.classifier_version: | ||||||
|             logging.getLogger(__name__).info("Reloading classifier models") |             logging.getLogger(__name__).info("Reloading classifier models") | ||||||
|             with open(settings.MODEL_FILE, "rb") as f: |             with open(settings.MODEL_FILE, "rb") as f: | ||||||
|                 self.data_vectorizer = pickle.load(f) |                 self.data_vectorizer = pickle.load(f) | ||||||
| @@ -77,27 +71,54 @@ class DocumentClassifier(object): | |||||||
|         logging.getLogger(__name__).info("Gathering data from database...") |         logging.getLogger(__name__).info("Gathering data from database...") | ||||||
|         for doc in Document.objects.exclude(tags__is_inbox_tag=True): |         for doc in Document.objects.exclude(tags__is_inbox_tag=True): | ||||||
|             data.append(preprocess_content(doc.content)) |             data.append(preprocess_content(doc.content)) | ||||||
|             labels_document_type.append(doc.document_type.id if doc.document_type is not None and doc.document_type.automatic_classification else -1) |  | ||||||
|             labels_correspondent.append(doc.correspondent.id if doc.correspondent is not None and doc.correspondent.automatic_classification else -1) |             y = -1 | ||||||
|             tags = [tag.id for tag in doc.tags.filter(automatic_classification=True)] |             if doc.document_type: | ||||||
|  |                 if doc.document_type.automatic_classification: | ||||||
|  |                     y = doc.document_type.id | ||||||
|  |             labels_document_type.append(y) | ||||||
|  |  | ||||||
|  |             y = -1 | ||||||
|  |             if doc.correspondent: | ||||||
|  |                 if doc.correspondent.automatic_classification: | ||||||
|  |                     y = doc.correspondent.id | ||||||
|  |             labels_correspondent.append(y) | ||||||
|  |  | ||||||
|  |             tags = [tag.id for tag in doc.tags.filter( | ||||||
|  |                 automatic_classification=True | ||||||
|  |             )] | ||||||
|             labels_tags.append(tags) |             labels_tags.append(tags) | ||||||
|  |  | ||||||
|         labels_tags_unique = set([tag for tags in labels_tags for tag in tags]) |         labels_tags_unique = set([tag for tags in labels_tags for tag in tags]) | ||||||
|         logging.getLogger(__name__).info("{} documents, {} tag(s), {} correspondent(s), {} document type(s).".format(len(data), len(labels_tags_unique), len(set(labels_correspondent)), len(set(labels_document_type)))) |         logging.getLogger(__name__).info( | ||||||
|  |             "{} documents, {} tag(s), {} correspondent(s), " | ||||||
|  |             "{} document type(s).".format( | ||||||
|  |                 len(data), | ||||||
|  |                 len(labels_tags_unique), | ||||||
|  |                 len(set(labels_correspondent)), | ||||||
|  |                 len(set(labels_document_type)) | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|         # Step 2: vectorize data |         # Step 2: vectorize data | ||||||
|         logging.getLogger(__name__).info("Vectorizing data...") |         logging.getLogger(__name__).info("Vectorizing data...") | ||||||
|         self.data_vectorizer = CountVectorizer(analyzer="char", ngram_range=(3, 5), min_df=0.1) |         self.data_vectorizer = CountVectorizer( | ||||||
|  |             analyzer="char", | ||||||
|  |             ngram_range=(3, 5), | ||||||
|  |             min_df=0.1 | ||||||
|  |         ) | ||||||
|         data_vectorized = self.data_vectorizer.fit_transform(data) |         data_vectorized = self.data_vectorizer.fit_transform(data) | ||||||
|  |  | ||||||
|         self.tags_binarizer = MultiLabelBinarizer() |         self.tags_binarizer = MultiLabelBinarizer() | ||||||
|         labels_tags_vectorized = self.tags_binarizer.fit_transform(labels_tags) |         labels_tags_vectorized = self.tags_binarizer.fit_transform(labels_tags) | ||||||
|  |  | ||||||
|         self.correspondent_binarizer = LabelBinarizer() |         self.correspondent_binarizer = LabelBinarizer() | ||||||
|         labels_correspondent_vectorized = self.correspondent_binarizer.fit_transform(labels_correspondent) |         labels_correspondent_vectorized = \ | ||||||
|  |             self.correspondent_binarizer.fit_transform(labels_correspondent) | ||||||
|  |  | ||||||
|         self.document_type_binarizer = LabelBinarizer() |         self.document_type_binarizer = LabelBinarizer() | ||||||
|         labels_document_type_vectorized = self.document_type_binarizer.fit_transform(labels_document_type) |         labels_document_type_vectorized = \ | ||||||
|  |             self.document_type_binarizer.fit_transform(labels_document_type) | ||||||
|  |  | ||||||
|         # Step 3: train the classifiers |         # Step 3: train the classifiers | ||||||
|         if len(self.tags_binarizer.classes_) > 0: |         if len(self.tags_binarizer.classes_) > 0: | ||||||
| @@ -106,62 +127,114 @@ class DocumentClassifier(object): | |||||||
|             self.tags_classifier.fit(data_vectorized, labels_tags_vectorized) |             self.tags_classifier.fit(data_vectorized, labels_tags_vectorized) | ||||||
|         else: |         else: | ||||||
|             self.tags_classifier = None |             self.tags_classifier = None | ||||||
|             logging.getLogger(__name__).info("There are no tags. Not training tags classifier.") |             logging.getLogger(__name__).info( | ||||||
|  |                 "There are no tags. Not training tags classifier." | ||||||
|  |             ) | ||||||
|  |  | ||||||
|         if len(self.correspondent_binarizer.classes_) > 0: |         if len(self.correspondent_binarizer.classes_) > 0: | ||||||
|             logging.getLogger(__name__).info("Training correspondent classifier...") |             logging.getLogger(__name__).info( | ||||||
|  |                 "Training correspondent classifier..." | ||||||
|  |             ) | ||||||
|             self.correspondent_classifier = MLPClassifier(verbose=True) |             self.correspondent_classifier = MLPClassifier(verbose=True) | ||||||
|             self.correspondent_classifier.fit(data_vectorized, labels_correspondent_vectorized) |             self.correspondent_classifier.fit( | ||||||
|  |                 data_vectorized, | ||||||
|  |                 labels_correspondent_vectorized | ||||||
|  |             ) | ||||||
|         else: |         else: | ||||||
|             self.correspondent_classifier = None |             self.correspondent_classifier = None | ||||||
|             logging.getLogger(__name__).info("There are no correspondents. Not training correspondent classifier.") |             logging.getLogger(__name__).info( | ||||||
|  |                 "There are no correspondents. Not training correspondent " | ||||||
|  |                 "classifier." | ||||||
|  |             ) | ||||||
|  |  | ||||||
|         if len(self.document_type_binarizer.classes_) > 0: |         if len(self.document_type_binarizer.classes_) > 0: | ||||||
|             logging.getLogger(__name__).info("Training document type classifier...") |             logging.getLogger(__name__).info( | ||||||
|  |                 "Training document type classifier..." | ||||||
|  |             ) | ||||||
|             self.document_type_classifier = MLPClassifier(verbose=True) |             self.document_type_classifier = MLPClassifier(verbose=True) | ||||||
|             self.document_type_classifier.fit(data_vectorized, labels_document_type_vectorized) |             self.document_type_classifier.fit( | ||||||
|  |                 data_vectorized, | ||||||
|  |                 labels_document_type_vectorized | ||||||
|  |             ) | ||||||
|         else: |         else: | ||||||
|             self.document_type_classifier = None |             self.document_type_classifier = None | ||||||
|             logging.getLogger(__name__).info("There are no document types. Not training document type classifier.") |             logging.getLogger(__name__).info( | ||||||
|  |                 "There are no document types. Not training document type " | ||||||
|  |                 "classifier." | ||||||
|  |             ) | ||||||
|  |  | ||||||
|     def classify_document(self, document, classify_correspondent=False, classify_document_type=False, classify_tags=False, replace_tags=False): |     def classify_document( | ||||||
|         X = self.data_vectorizer.transform([preprocess_content(document.content)]) |             self, document, classify_correspondent=False, | ||||||
|  |             classify_document_type=False, classify_tags=False, | ||||||
|  |             replace_tags=False): | ||||||
|  |  | ||||||
|         update_fields = () |         X = self.data_vectorizer.transform( | ||||||
|  |             [preprocess_content(document.content)] | ||||||
|  |         ) | ||||||
|  |  | ||||||
|         if classify_correspondent and self.correspondent_classifier is not None: |         if classify_correspondent and self.correspondent_classifier: | ||||||
|             y_correspondent = self.correspondent_classifier.predict(X) |             self._classify_correspondent(X, document) | ||||||
|             correspondent_id = self.correspondent_binarizer.inverse_transform(y_correspondent)[0] |  | ||||||
|  |         if classify_document_type and self.document_type_classifier: | ||||||
|  |             self._classify_document_type(X, document) | ||||||
|  |  | ||||||
|  |         if classify_tags and self.tags_classifier: | ||||||
|  |             self._classify_tags(X, document, replace_tags) | ||||||
|  |  | ||||||
|  |         document.save(update_fields=("correspondent", "document_type")) | ||||||
|  |  | ||||||
|  |     def _classify_correspondent(self, X, document): | ||||||
|  |         y = self.correspondent_classifier.predict(X) | ||||||
|  |         correspondent_id = self.correspondent_binarizer.inverse_transform(y)[0] | ||||||
|  |         try: | ||||||
|  |             correspondent = None | ||||||
|  |             if correspondent_id != -1: | ||||||
|  |                 correspondent = Correspondent.objects.get(id=correspondent_id) | ||||||
|  |                 logging.getLogger(__name__).info( | ||||||
|  |                     "Detected correspondent: {}".format(correspondent.name) | ||||||
|  |                 ) | ||||||
|  |             else: | ||||||
|  |                 logging.getLogger(__name__).info("Detected correspondent: -") | ||||||
|  |             document.correspondent = correspondent | ||||||
|  |         except Correspondent.DoesNotExist: | ||||||
|  |             logging.getLogger(__name__).warning( | ||||||
|  |                 "Detected correspondent with id {} does not exist " | ||||||
|  |                 "anymore! Did you delete it?".format(correspondent_id) | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |     def _classify_document_type(self, X, document): | ||||||
|  |         y = self.document_type_classifier.predict(X) | ||||||
|  |         document_type_id = self.document_type_binarizer.inverse_transform(y)[0] | ||||||
|  |         try: | ||||||
|  |             document_type = None | ||||||
|  |             if document_type_id != -1: | ||||||
|  |                 document_type = DocumentType.objects.get(id=document_type_id) | ||||||
|  |                 logging.getLogger(__name__).info( | ||||||
|  |                     "Detected document type: {}".format(document_type.name) | ||||||
|  |                 ) | ||||||
|  |             else: | ||||||
|  |                 logging.getLogger(__name__).info("Detected document type: -") | ||||||
|  |             document.document_type = document_type | ||||||
|  |         except DocumentType.DoesNotExist: | ||||||
|  |             logging.getLogger(__name__).warning( | ||||||
|  |                 "Detected document type with id {} does not exist " | ||||||
|  |                 "anymore! Did you delete it?".format(document_type_id) | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |     def _classify_tags(self, X, document, replace_tags): | ||||||
|  |         y = self.tags_classifier.predict(X) | ||||||
|  |         tags_ids = self.tags_binarizer.inverse_transform(y)[0] | ||||||
|  |         if replace_tags: | ||||||
|  |             document.tags.clear() | ||||||
|  |         for tag_id in tags_ids: | ||||||
|             try: |             try: | ||||||
|                 correspondent = Correspondent.objects.get(id=correspondent_id) if correspondent_id != -1 else None |                 tag = Tag.objects.get(id=tag_id) | ||||||
|                 logging.getLogger(__name__).info("Detected correspondent: {}".format(correspondent.name if correspondent else "-")) |                 logging.getLogger(__name__).info( | ||||||
|                 document.correspondent = correspondent |                     "Detected tag: {}".format(tag.name) | ||||||
|                 update_fields = update_fields + ("correspondent",) |                 ) | ||||||
|             except Correspondent.DoesNotExist: |                 document.tags.add(tag) | ||||||
|                 logging.getLogger(__name__).warning("Detected correspondent with id {} does not exist anymore! Did you delete it?".format(correspondent_id)) |             except Tag.DoesNotExist: | ||||||
|  |                 logging.getLogger(__name__).warning( | ||||||
|         if classify_document_type and self.document_type_classifier is not None: |                     "Detected tag with id {} does not exist anymore! Did " | ||||||
|             y_type = self.document_type_classifier.predict(X) |                     "you delete it?".format(tag_id) | ||||||
|             type_id = self.document_type_binarizer.inverse_transform(y_type)[0] |                 ) | ||||||
|             try: |  | ||||||
|                 document_type = DocumentType.objects.get(id=type_id) if type_id != -1 else None |  | ||||||
|                 logging.getLogger(__name__).info("Detected document type: {}".format(document_type.name if document_type else "-")) |  | ||||||
|                 document.document_type = document_type |  | ||||||
|                 update_fields = update_fields + ("document_type",) |  | ||||||
|             except DocumentType.DoesNotExist: |  | ||||||
|                 logging.getLogger(__name__).warning("Detected document type with id {} does not exist anymore! Did you delete it?".format(type_id)) |  | ||||||
|  |  | ||||||
|         if classify_tags and self.tags_classifier is not None: |  | ||||||
|             y_tags = self.tags_classifier.predict(X) |  | ||||||
|             tags_ids = self.tags_binarizer.inverse_transform(y_tags)[0] |  | ||||||
|             if replace_tags: |  | ||||||
|                 document.tags.clear() |  | ||||||
|             for tag_id in tags_ids: |  | ||||||
|                 try: |  | ||||||
|                     tag = Tag.objects.get(id=tag_id) |  | ||||||
|                     document.tags.add(tag) |  | ||||||
|                     logging.getLogger(__name__).info("Detected tag: {}".format(tag.name)) |  | ||||||
|                 except Tag.DoesNotExist: |  | ||||||
|                     logging.getLogger(__name__).warning("Detected tag with id {} does not exist anymore! Did you delete it?".format(tag_id)) |  | ||||||
|  |  | ||||||
|         document.save(update_fields=update_fields) |  | ||||||
|   | |||||||
| @@ -54,8 +54,9 @@ class Command(Renderable, BaseCommand): | |||||||
|         documents = queryset.distinct() |         documents = queryset.distinct() | ||||||
|  |  | ||||||
|         logging.getLogger(__name__).info("Loading classifier") |         logging.getLogger(__name__).info("Loading classifier") | ||||||
|  |         clf = DocumentClassifier() | ||||||
|         try: |         try: | ||||||
|             clf = DocumentClassifier.load_classifier() |             clf.reload() | ||||||
|         except FileNotFoundError: |         except FileNotFoundError: | ||||||
|             logging.getLogger(__name__).fatal("Cannot classify documents, " |             logging.getLogger(__name__).fatal("Cannot classify documents, " | ||||||
|                                               "classifier model file was not " |                                               "classifier model file was not " | ||||||
|   | |||||||
| @@ -20,7 +20,13 @@ from rest_framework.viewsets import ( | |||||||
|     ReadOnlyModelViewSet |     ReadOnlyModelViewSet | ||||||
| ) | ) | ||||||
|  |  | ||||||
| from .filters import CorrespondentFilterSet, DocumentFilterSet, TagFilterSet, DocumentTypeFilterSet | from .filters import ( | ||||||
|  |     CorrespondentFilterSet, | ||||||
|  |     DocumentFilterSet, | ||||||
|  |     TagFilterSet, | ||||||
|  |     DocumentTypeFilterSet | ||||||
|  | ) | ||||||
|  |  | ||||||
| from .forms import UploadForm | from .forms import UploadForm | ||||||
| from .models import Correspondent, Document, Log, Tag, DocumentType | from .models import Correspondent, Document, Log, Tag, DocumentType | ||||||
| from .serialisers import ( | from .serialisers import ( | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Jonas Winkler
					Jonas Winkler