Adding more typing around the classification and matching

This commit is contained in:
Trenton Holmes 2023-07-23 16:49:20 -07:00 committed by Trenton H
parent 07e7bcd30b
commit d376f9e7a3
3 changed files with 30 additions and 24 deletions

View File

@ -5,6 +5,7 @@ import re
import warnings import warnings
from datetime import datetime from datetime import datetime
from hashlib import sha256 from hashlib import sha256
from pathlib import Path
from typing import Iterator from typing import Iterator
from typing import List from typing import List
from typing import Optional from typing import Optional
@ -81,7 +82,7 @@ class DocumentClassifier:
self._stemmer = None self._stemmer = None
self._stop_words = None self._stop_words = None
def load(self): def load(self) -> None:
# Catch warnings for processing # Catch warnings for processing
with warnings.catch_warnings(record=True) as w: with warnings.catch_warnings(record=True) as w:
with open(settings.MODEL_FILE, "rb") as f: with open(settings.MODEL_FILE, "rb") as f:
@ -120,19 +121,20 @@ class DocumentClassifier:
raise IncompatibleClassifierVersionError raise IncompatibleClassifierVersionError
def save(self): def save(self):
target_file = settings.MODEL_FILE target_file: Path = settings.MODEL_FILE
target_file_temp = settings.MODEL_FILE.with_suffix(".pickle.part") target_file_temp = target_file.with_suffix(".pickle.part")
with open(target_file_temp, "wb") as f: with open(target_file_temp, "wb") as f:
pickle.dump(self.FORMAT_VERSION, f) pickle.dump(self.FORMAT_VERSION, f)
pickle.dump(self.last_doc_change_time, f) pickle.dump(self.last_doc_change_time, f)
pickle.dump(self.last_auto_type_hash, f) pickle.dump(self.last_auto_type_hash, f)
pickle.dump(self.data_vectorizer, f) pickle.dump(self.data_vectorizer, f)
pickle.dump(self.tags_binarizer, f) pickle.dump(self.tags_binarizer, f)
pickle.dump(self.tags_classifier, f) pickle.dump(self.tags_classifier, f)
pickle.dump(self.correspondent_classifier, f) pickle.dump(self.correspondent_classifier, f)
pickle.dump(self.document_type_classifier, f) pickle.dump(self.document_type_classifier, f)
pickle.dump(self.storage_path_classifier, f) pickle.dump(self.storage_path_classifier, f)
@ -380,7 +382,7 @@ class DocumentClassifier:
return content return content
def predict_correspondent(self, content: str): def predict_correspondent(self, content: str) -> Optional[int]:
if self.correspondent_classifier: if self.correspondent_classifier:
X = self.data_vectorizer.transform([self.preprocess_content(content)]) X = self.data_vectorizer.transform([self.preprocess_content(content)])
correspondent_id = self.correspondent_classifier.predict(X) correspondent_id = self.correspondent_classifier.predict(X)
@ -391,7 +393,7 @@ class DocumentClassifier:
else: else:
return None return None
def predict_document_type(self, content: str): def predict_document_type(self, content: str) -> Optional[int]:
if self.document_type_classifier: if self.document_type_classifier:
X = self.data_vectorizer.transform([self.preprocess_content(content)]) X = self.data_vectorizer.transform([self.preprocess_content(content)])
document_type_id = self.document_type_classifier.predict(X) document_type_id = self.document_type_classifier.predict(X)
@ -402,7 +404,7 @@ class DocumentClassifier:
else: else:
return None return None
def predict_tags(self, content: str): def predict_tags(self, content: str) -> List[int]:
from sklearn.utils.multiclass import type_of_target from sklearn.utils.multiclass import type_of_target
if self.tags_classifier: if self.tags_classifier:
@ -423,7 +425,7 @@ class DocumentClassifier:
else: else:
return [] return []
def predict_storage_path(self, content: str): def predict_storage_path(self, content: str) -> Optional[int]:
if self.storage_path_classifier: if self.storage_path_classifier:
X = self.data_vectorizer.transform([self.preprocess_content(content)]) X = self.data_vectorizer.transform([self.preprocess_content(content)])
storage_path_id = self.storage_path_classifier.predict(X) storage_path_id = self.storage_path_classifier.predict(X)

View File

@ -1,7 +1,9 @@
import logging import logging
import re import re
from documents.classifier import DocumentClassifier
from documents.models import Correspondent from documents.models import Correspondent
from documents.models import Document
from documents.models import DocumentType from documents.models import DocumentType
from documents.models import MatchingModel from documents.models import MatchingModel
from documents.models import StoragePath from documents.models import StoragePath
@ -11,7 +13,7 @@ from documents.permissions import get_objects_for_user_owner_aware
logger = logging.getLogger("paperless.matching") logger = logging.getLogger("paperless.matching")
def log_reason(matching_model, document, reason): def log_reason(matching_model: MatchingModel, document: Document, reason: str):
class_name = type(matching_model).__name__ class_name = type(matching_model).__name__
logger.debug( logger.debug(
f"{class_name} {matching_model.name} matched on document " f"{class_name} {matching_model.name} matched on document "
@ -19,7 +21,7 @@ def log_reason(matching_model, document, reason):
) )
def match_correspondents(document, classifier, user=None): def match_correspondents(document: Document, classifier: DocumentClassifier, user=None):
pred_id = classifier.predict_correspondent(document.content) if classifier else None pred_id = classifier.predict_correspondent(document.content) if classifier else None
if user is None and document.owner is not None: if user is None and document.owner is not None:
@ -43,7 +45,7 @@ def match_correspondents(document, classifier, user=None):
) )
def match_document_types(document, classifier, user=None): def match_document_types(document: Document, classifier: DocumentClassifier, user=None):
pred_id = classifier.predict_document_type(document.content) if classifier else None pred_id = classifier.predict_document_type(document.content) if classifier else None
if user is None and document.owner is not None: if user is None and document.owner is not None:
@ -67,7 +69,7 @@ def match_document_types(document, classifier, user=None):
) )
def match_tags(document, classifier, user=None): def match_tags(document: Document, classifier: DocumentClassifier, user=None):
predicted_tag_ids = classifier.predict_tags(document.content) if classifier else [] predicted_tag_ids = classifier.predict_tags(document.content) if classifier else []
if user is None and document.owner is not None: if user is None and document.owner is not None:
@ -90,7 +92,7 @@ def match_tags(document, classifier, user=None):
) )
def match_storage_paths(document, classifier, user=None): def match_storage_paths(document: Document, classifier: DocumentClassifier, user=None):
pred_id = classifier.predict_storage_path(document.content) if classifier else None pred_id = classifier.predict_storage_path(document.content) if classifier else None
if user is None and document.owner is not None: if user is None and document.owner is not None:
@ -114,7 +116,7 @@ def match_storage_paths(document, classifier, user=None):
) )
def matches(matching_model, document): def matches(matching_model: MatchingModel, document: Document):
search_kwargs = {} search_kwargs = {}
document_content = document.content document_content = document.content

View File

@ -1,6 +1,7 @@
import logging import logging
import os import os
import shutil import shutil
from typing import Optional
from celery import states from celery import states
from celery.signals import before_task_publish from celery.signals import before_task_publish
@ -21,6 +22,7 @@ from django.utils import timezone
from filelock import FileLock from filelock import FileLock
from documents import matching from documents import matching
from documents.classifier import DocumentClassifier
from documents.file_handling import create_source_path_directory from documents.file_handling import create_source_path_directory
from documents.file_handling import delete_empty_directories from documents.file_handling import delete_empty_directories
from documents.file_handling import generate_unique_filename from documents.file_handling import generate_unique_filename
@ -33,7 +35,7 @@ from documents.permissions import get_objects_for_user_owner_aware
logger = logging.getLogger("paperless.handlers") logger = logging.getLogger("paperless.handlers")
def add_inbox_tags(sender, document=None, logging_group=None, **kwargs): def add_inbox_tags(sender, document: Document, logging_group=None, **kwargs):
if document.owner is not None: if document.owner is not None:
tags = get_objects_for_user_owner_aware( tags = get_objects_for_user_owner_aware(
document.owner, document.owner,
@ -48,9 +50,9 @@ def add_inbox_tags(sender, document=None, logging_group=None, **kwargs):
def set_correspondent( def set_correspondent(
sender, sender,
document=None, document: Document,
logging_group=None, logging_group=None,
classifier=None, classifier: Optional[DocumentClassifier] = None,
replace=False, replace=False,
use_first=True, use_first=True,
suggest=False, suggest=False,
@ -111,9 +113,9 @@ def set_correspondent(
def set_document_type( def set_document_type(
sender, sender,
document=None, document: Document,
logging_group=None, logging_group=None,
classifier=None, classifier: Optional[DocumentClassifier] = None,
replace=False, replace=False,
use_first=True, use_first=True,
suggest=False, suggest=False,
@ -175,9 +177,9 @@ def set_document_type(
def set_tags( def set_tags(
sender, sender,
document=None, document: Document,
logging_group=None, logging_group=None,
classifier=None, classifier: Optional[DocumentClassifier] = None,
replace=False, replace=False,
suggest=False, suggest=False,
base_url=None, base_url=None,
@ -239,9 +241,9 @@ def set_tags(
def set_storage_path( def set_storage_path(
sender, sender,
document=None, document: Document,
logging_group=None, logging_group=None,
classifier=None, classifier: Optional[DocumentClassifier] = None,
replace=False, replace=False,
use_first=True, use_first=True,
suggest=False, suggest=False,
@ -491,7 +493,7 @@ def update_filename_and_move_files(sender, instance: Document, **kwargs):
) )
def set_log_entry(sender, document=None, logging_group=None, **kwargs): def set_log_entry(sender, document: Document, logging_group=None, **kwargs):
ct = ContentType.objects.get(model="document") ct = ContentType.objects.get(model="document")
user = User.objects.get(username="consumer") user = User.objects.get(username="consumer")