mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-04-02 13:45:10 -05:00
Adding more typing around the classification and matching
This commit is contained in:
parent
07e7bcd30b
commit
d376f9e7a3
@ -5,6 +5,7 @@ import re
|
|||||||
import warnings
|
import warnings
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
|
from pathlib import Path
|
||||||
from typing import Iterator
|
from typing import Iterator
|
||||||
from typing import List
|
from typing import List
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
@ -81,7 +82,7 @@ class DocumentClassifier:
|
|||||||
self._stemmer = None
|
self._stemmer = None
|
||||||
self._stop_words = None
|
self._stop_words = None
|
||||||
|
|
||||||
def load(self):
|
def load(self) -> None:
|
||||||
# Catch warnings for processing
|
# Catch warnings for processing
|
||||||
with warnings.catch_warnings(record=True) as w:
|
with warnings.catch_warnings(record=True) as w:
|
||||||
with open(settings.MODEL_FILE, "rb") as f:
|
with open(settings.MODEL_FILE, "rb") as f:
|
||||||
@ -120,19 +121,20 @@ class DocumentClassifier:
|
|||||||
raise IncompatibleClassifierVersionError
|
raise IncompatibleClassifierVersionError
|
||||||
|
|
||||||
def save(self):
|
def save(self):
|
||||||
target_file = settings.MODEL_FILE
|
target_file: Path = settings.MODEL_FILE
|
||||||
target_file_temp = settings.MODEL_FILE.with_suffix(".pickle.part")
|
target_file_temp = target_file.with_suffix(".pickle.part")
|
||||||
|
|
||||||
with open(target_file_temp, "wb") as f:
|
with open(target_file_temp, "wb") as f:
|
||||||
pickle.dump(self.FORMAT_VERSION, f)
|
pickle.dump(self.FORMAT_VERSION, f)
|
||||||
|
|
||||||
pickle.dump(self.last_doc_change_time, f)
|
pickle.dump(self.last_doc_change_time, f)
|
||||||
pickle.dump(self.last_auto_type_hash, f)
|
pickle.dump(self.last_auto_type_hash, f)
|
||||||
|
|
||||||
pickle.dump(self.data_vectorizer, f)
|
pickle.dump(self.data_vectorizer, f)
|
||||||
|
|
||||||
pickle.dump(self.tags_binarizer, f)
|
pickle.dump(self.tags_binarizer, f)
|
||||||
|
|
||||||
pickle.dump(self.tags_classifier, f)
|
pickle.dump(self.tags_classifier, f)
|
||||||
|
|
||||||
pickle.dump(self.correspondent_classifier, f)
|
pickle.dump(self.correspondent_classifier, f)
|
||||||
pickle.dump(self.document_type_classifier, f)
|
pickle.dump(self.document_type_classifier, f)
|
||||||
pickle.dump(self.storage_path_classifier, f)
|
pickle.dump(self.storage_path_classifier, f)
|
||||||
@ -380,7 +382,7 @@ class DocumentClassifier:
|
|||||||
|
|
||||||
return content
|
return content
|
||||||
|
|
||||||
def predict_correspondent(self, content: str):
|
def predict_correspondent(self, content: str) -> Optional[int]:
|
||||||
if self.correspondent_classifier:
|
if self.correspondent_classifier:
|
||||||
X = self.data_vectorizer.transform([self.preprocess_content(content)])
|
X = self.data_vectorizer.transform([self.preprocess_content(content)])
|
||||||
correspondent_id = self.correspondent_classifier.predict(X)
|
correspondent_id = self.correspondent_classifier.predict(X)
|
||||||
@ -391,7 +393,7 @@ class DocumentClassifier:
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def predict_document_type(self, content: str):
|
def predict_document_type(self, content: str) -> Optional[int]:
|
||||||
if self.document_type_classifier:
|
if self.document_type_classifier:
|
||||||
X = self.data_vectorizer.transform([self.preprocess_content(content)])
|
X = self.data_vectorizer.transform([self.preprocess_content(content)])
|
||||||
document_type_id = self.document_type_classifier.predict(X)
|
document_type_id = self.document_type_classifier.predict(X)
|
||||||
@ -402,7 +404,7 @@ class DocumentClassifier:
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def predict_tags(self, content: str):
|
def predict_tags(self, content: str) -> List[int]:
|
||||||
from sklearn.utils.multiclass import type_of_target
|
from sklearn.utils.multiclass import type_of_target
|
||||||
|
|
||||||
if self.tags_classifier:
|
if self.tags_classifier:
|
||||||
@ -423,7 +425,7 @@ class DocumentClassifier:
|
|||||||
else:
|
else:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def predict_storage_path(self, content: str):
|
def predict_storage_path(self, content: str) -> Optional[int]:
|
||||||
if self.storage_path_classifier:
|
if self.storage_path_classifier:
|
||||||
X = self.data_vectorizer.transform([self.preprocess_content(content)])
|
X = self.data_vectorizer.transform([self.preprocess_content(content)])
|
||||||
storage_path_id = self.storage_path_classifier.predict(X)
|
storage_path_id = self.storage_path_classifier.predict(X)
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
from documents.classifier import DocumentClassifier
|
||||||
from documents.models import Correspondent
|
from documents.models import Correspondent
|
||||||
|
from documents.models import Document
|
||||||
from documents.models import DocumentType
|
from documents.models import DocumentType
|
||||||
from documents.models import MatchingModel
|
from documents.models import MatchingModel
|
||||||
from documents.models import StoragePath
|
from documents.models import StoragePath
|
||||||
@ -11,7 +13,7 @@ from documents.permissions import get_objects_for_user_owner_aware
|
|||||||
logger = logging.getLogger("paperless.matching")
|
logger = logging.getLogger("paperless.matching")
|
||||||
|
|
||||||
|
|
||||||
def log_reason(matching_model, document, reason):
|
def log_reason(matching_model: MatchingModel, document: Document, reason: str):
|
||||||
class_name = type(matching_model).__name__
|
class_name = type(matching_model).__name__
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"{class_name} {matching_model.name} matched on document "
|
f"{class_name} {matching_model.name} matched on document "
|
||||||
@ -19,7 +21,7 @@ def log_reason(matching_model, document, reason):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def match_correspondents(document, classifier, user=None):
|
def match_correspondents(document: Document, classifier: DocumentClassifier, user=None):
|
||||||
pred_id = classifier.predict_correspondent(document.content) if classifier else None
|
pred_id = classifier.predict_correspondent(document.content) if classifier else None
|
||||||
|
|
||||||
if user is None and document.owner is not None:
|
if user is None and document.owner is not None:
|
||||||
@ -43,7 +45,7 @@ def match_correspondents(document, classifier, user=None):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def match_document_types(document, classifier, user=None):
|
def match_document_types(document: Document, classifier: DocumentClassifier, user=None):
|
||||||
pred_id = classifier.predict_document_type(document.content) if classifier else None
|
pred_id = classifier.predict_document_type(document.content) if classifier else None
|
||||||
|
|
||||||
if user is None and document.owner is not None:
|
if user is None and document.owner is not None:
|
||||||
@ -67,7 +69,7 @@ def match_document_types(document, classifier, user=None):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def match_tags(document, classifier, user=None):
|
def match_tags(document: Document, classifier: DocumentClassifier, user=None):
|
||||||
predicted_tag_ids = classifier.predict_tags(document.content) if classifier else []
|
predicted_tag_ids = classifier.predict_tags(document.content) if classifier else []
|
||||||
|
|
||||||
if user is None and document.owner is not None:
|
if user is None and document.owner is not None:
|
||||||
@ -90,7 +92,7 @@ def match_tags(document, classifier, user=None):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def match_storage_paths(document, classifier, user=None):
|
def match_storage_paths(document: Document, classifier: DocumentClassifier, user=None):
|
||||||
pred_id = classifier.predict_storage_path(document.content) if classifier else None
|
pred_id = classifier.predict_storage_path(document.content) if classifier else None
|
||||||
|
|
||||||
if user is None and document.owner is not None:
|
if user is None and document.owner is not None:
|
||||||
@ -114,7 +116,7 @@ def match_storage_paths(document, classifier, user=None):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def matches(matching_model, document):
|
def matches(matching_model: MatchingModel, document: Document):
|
||||||
search_kwargs = {}
|
search_kwargs = {}
|
||||||
|
|
||||||
document_content = document.content
|
document_content = document.content
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from celery import states
|
from celery import states
|
||||||
from celery.signals import before_task_publish
|
from celery.signals import before_task_publish
|
||||||
@ -21,6 +22,7 @@ from django.utils import timezone
|
|||||||
from filelock import FileLock
|
from filelock import FileLock
|
||||||
|
|
||||||
from documents import matching
|
from documents import matching
|
||||||
|
from documents.classifier import DocumentClassifier
|
||||||
from documents.file_handling import create_source_path_directory
|
from documents.file_handling import create_source_path_directory
|
||||||
from documents.file_handling import delete_empty_directories
|
from documents.file_handling import delete_empty_directories
|
||||||
from documents.file_handling import generate_unique_filename
|
from documents.file_handling import generate_unique_filename
|
||||||
@ -33,7 +35,7 @@ from documents.permissions import get_objects_for_user_owner_aware
|
|||||||
logger = logging.getLogger("paperless.handlers")
|
logger = logging.getLogger("paperless.handlers")
|
||||||
|
|
||||||
|
|
||||||
def add_inbox_tags(sender, document=None, logging_group=None, **kwargs):
|
def add_inbox_tags(sender, document: Document, logging_group=None, **kwargs):
|
||||||
if document.owner is not None:
|
if document.owner is not None:
|
||||||
tags = get_objects_for_user_owner_aware(
|
tags = get_objects_for_user_owner_aware(
|
||||||
document.owner,
|
document.owner,
|
||||||
@ -48,9 +50,9 @@ def add_inbox_tags(sender, document=None, logging_group=None, **kwargs):
|
|||||||
|
|
||||||
def set_correspondent(
|
def set_correspondent(
|
||||||
sender,
|
sender,
|
||||||
document=None,
|
document: Document,
|
||||||
logging_group=None,
|
logging_group=None,
|
||||||
classifier=None,
|
classifier: Optional[DocumentClassifier] = None,
|
||||||
replace=False,
|
replace=False,
|
||||||
use_first=True,
|
use_first=True,
|
||||||
suggest=False,
|
suggest=False,
|
||||||
@ -111,9 +113,9 @@ def set_correspondent(
|
|||||||
|
|
||||||
def set_document_type(
|
def set_document_type(
|
||||||
sender,
|
sender,
|
||||||
document=None,
|
document: Document,
|
||||||
logging_group=None,
|
logging_group=None,
|
||||||
classifier=None,
|
classifier: Optional[DocumentClassifier] = None,
|
||||||
replace=False,
|
replace=False,
|
||||||
use_first=True,
|
use_first=True,
|
||||||
suggest=False,
|
suggest=False,
|
||||||
@ -175,9 +177,9 @@ def set_document_type(
|
|||||||
|
|
||||||
def set_tags(
|
def set_tags(
|
||||||
sender,
|
sender,
|
||||||
document=None,
|
document: Document,
|
||||||
logging_group=None,
|
logging_group=None,
|
||||||
classifier=None,
|
classifier: Optional[DocumentClassifier] = None,
|
||||||
replace=False,
|
replace=False,
|
||||||
suggest=False,
|
suggest=False,
|
||||||
base_url=None,
|
base_url=None,
|
||||||
@ -239,9 +241,9 @@ def set_tags(
|
|||||||
|
|
||||||
def set_storage_path(
|
def set_storage_path(
|
||||||
sender,
|
sender,
|
||||||
document=None,
|
document: Document,
|
||||||
logging_group=None,
|
logging_group=None,
|
||||||
classifier=None,
|
classifier: Optional[DocumentClassifier] = None,
|
||||||
replace=False,
|
replace=False,
|
||||||
use_first=True,
|
use_first=True,
|
||||||
suggest=False,
|
suggest=False,
|
||||||
@ -491,7 +493,7 @@ def update_filename_and_move_files(sender, instance: Document, **kwargs):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def set_log_entry(sender, document=None, logging_group=None, **kwargs):
|
def set_log_entry(sender, document: Document, logging_group=None, **kwargs):
|
||||||
ct = ContentType.objects.get(model="document")
|
ct = ContentType.objects.get(model="document")
|
||||||
user = User.objects.get(username="consumer")
|
user = User.objects.get(username="consumer")
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user