From d376f9e7a3b2ad2fb3b913545a7ee5a2828edd99 Mon Sep 17 00:00:00 2001 From: Trenton Holmes <797416+stumpylog@users.noreply.github.com> Date: Sun, 23 Jul 2023 16:49:20 -0700 Subject: [PATCH] Adding more typing around the classification and matching --- src/documents/classifier.py | 18 ++++++++++-------- src/documents/matching.py | 14 ++++++++------ src/documents/signals/handlers.py | 22 ++++++++++++---------- 3 files changed, 30 insertions(+), 24 deletions(-) diff --git a/src/documents/classifier.py b/src/documents/classifier.py index 3dd1a60aa..5ed203934 100644 --- a/src/documents/classifier.py +++ b/src/documents/classifier.py @@ -5,6 +5,7 @@ import re import warnings from datetime import datetime from hashlib import sha256 +from pathlib import Path from typing import Iterator from typing import List from typing import Optional @@ -81,7 +82,7 @@ class DocumentClassifier: self._stemmer = None self._stop_words = None - def load(self): + def load(self) -> None: # Catch warnings for processing with warnings.catch_warnings(record=True) as w: with open(settings.MODEL_FILE, "rb") as f: @@ -120,19 +121,20 @@ class DocumentClassifier: raise IncompatibleClassifierVersionError def save(self): - target_file = settings.MODEL_FILE - target_file_temp = settings.MODEL_FILE.with_suffix(".pickle.part") + target_file: Path = settings.MODEL_FILE + target_file_temp = target_file.with_suffix(".pickle.part") with open(target_file_temp, "wb") as f: pickle.dump(self.FORMAT_VERSION, f) + pickle.dump(self.last_doc_change_time, f) pickle.dump(self.last_auto_type_hash, f) pickle.dump(self.data_vectorizer, f) pickle.dump(self.tags_binarizer, f) - pickle.dump(self.tags_classifier, f) + pickle.dump(self.correspondent_classifier, f) pickle.dump(self.document_type_classifier, f) pickle.dump(self.storage_path_classifier, f) @@ -380,7 +382,7 @@ class DocumentClassifier: return content - def predict_correspondent(self, content: str): + def predict_correspondent(self, content: str) -> Optional[int]: if self.correspondent_classifier: X = self.data_vectorizer.transform([self.preprocess_content(content)]) correspondent_id = self.correspondent_classifier.predict(X) @@ -391,7 +393,7 @@ class DocumentClassifier: else: return None - def predict_document_type(self, content: str): + def predict_document_type(self, content: str) -> Optional[int]: if self.document_type_classifier: X = self.data_vectorizer.transform([self.preprocess_content(content)]) document_type_id = self.document_type_classifier.predict(X) @@ -402,7 +404,7 @@ class DocumentClassifier: else: return None - def predict_tags(self, content: str): + def predict_tags(self, content: str) -> List[int]: from sklearn.utils.multiclass import type_of_target if self.tags_classifier: @@ -423,7 +425,7 @@ class DocumentClassifier: else: return [] - def predict_storage_path(self, content: str): + def predict_storage_path(self, content: str) -> Optional[int]: if self.storage_path_classifier: X = self.data_vectorizer.transform([self.preprocess_content(content)]) storage_path_id = self.storage_path_classifier.predict(X) diff --git a/src/documents/matching.py b/src/documents/matching.py index a7ceb5a5a..eb0f4f8b5 100644 --- a/src/documents/matching.py +++ b/src/documents/matching.py @@ -1,7 +1,9 @@ import logging import re +from documents.classifier import DocumentClassifier from documents.models import Correspondent +from documents.models import Document from documents.models import DocumentType from documents.models import MatchingModel from documents.models import StoragePath @@ -11,7 +13,7 @@ from documents.permissions import get_objects_for_user_owner_aware logger = logging.getLogger("paperless.matching") -def log_reason(matching_model, document, reason): +def log_reason(matching_model: MatchingModel, document: Document, reason: str): class_name = type(matching_model).__name__ logger.debug( f"{class_name} {matching_model.name} matched on document " @@ -19,7 +21,7 @@ def log_reason(matching_model, document, reason): ) -def match_correspondents(document, classifier, user=None): +def match_correspondents(document: Document, classifier: DocumentClassifier, user=None): pred_id = classifier.predict_correspondent(document.content) if classifier else None if user is None and document.owner is not None: @@ -43,7 +45,7 @@ def match_correspondents(document, classifier, user=None): ) -def match_document_types(document, classifier, user=None): +def match_document_types(document: Document, classifier: DocumentClassifier, user=None): pred_id = classifier.predict_document_type(document.content) if classifier else None if user is None and document.owner is not None: @@ -67,7 +69,7 @@ def match_document_types(document, classifier, user=None): ) -def match_tags(document, classifier, user=None): +def match_tags(document: Document, classifier: DocumentClassifier, user=None): predicted_tag_ids = classifier.predict_tags(document.content) if classifier else [] if user is None and document.owner is not None: @@ -90,7 +92,7 @@ def match_tags(document, classifier, user=None): ) -def match_storage_paths(document, classifier, user=None): +def match_storage_paths(document: Document, classifier: DocumentClassifier, user=None): pred_id = classifier.predict_storage_path(document.content) if classifier else None if user is None and document.owner is not None: @@ -114,7 +116,7 @@ def match_storage_paths(document, classifier, user=None): ) -def matches(matching_model, document): +def matches(matching_model: MatchingModel, document: Document): search_kwargs = {} document_content = document.content diff --git a/src/documents/signals/handlers.py b/src/documents/signals/handlers.py index 4a39d98ea..4e0d13c20 100644 --- a/src/documents/signals/handlers.py +++ b/src/documents/signals/handlers.py @@ -1,6 +1,7 @@ import logging import os import shutil +from typing import Optional from celery import states from celery.signals import before_task_publish @@ -21,6 +22,7 @@ from django.utils import timezone from filelock import FileLock from documents import matching +from documents.classifier import DocumentClassifier from documents.file_handling import create_source_path_directory from documents.file_handling import delete_empty_directories from documents.file_handling import generate_unique_filename @@ -33,7 +35,7 @@ from documents.permissions import get_objects_for_user_owner_aware logger = logging.getLogger("paperless.handlers") -def add_inbox_tags(sender, document=None, logging_group=None, **kwargs): +def add_inbox_tags(sender, document: Document, logging_group=None, **kwargs): if document.owner is not None: tags = get_objects_for_user_owner_aware( document.owner, @@ -48,9 +50,9 @@ def add_inbox_tags(sender, document=None, logging_group=None, **kwargs): def set_correspondent( sender, - document=None, + document: Document, logging_group=None, - classifier=None, + classifier: Optional[DocumentClassifier] = None, replace=False, use_first=True, suggest=False, @@ -111,9 +113,9 @@ def set_correspondent( def set_document_type( sender, - document=None, + document: Document, logging_group=None, - classifier=None, + classifier: Optional[DocumentClassifier] = None, replace=False, use_first=True, suggest=False, @@ -175,9 +177,9 @@ def set_document_type( def set_tags( sender, - document=None, + document: Document, logging_group=None, - classifier=None, + classifier: Optional[DocumentClassifier] = None, replace=False, suggest=False, base_url=None, @@ -239,9 +241,9 @@ def set_tags( def set_storage_path( sender, - document=None, + document: Document, logging_group=None, - classifier=None, + classifier: Optional[DocumentClassifier] = None, replace=False, use_first=True, suggest=False, @@ -491,7 +493,7 @@ def update_filename_and_move_files(sender, instance: Document, **kwargs): ) -def set_log_entry(sender, document=None, logging_group=None, **kwargs): +def set_log_entry(sender, document: Document, logging_group=None, **kwargs): ct = ContentType.objects.get(model="document") user = User.objects.get(username="consumer")