Mirror of https://github.com/paperless-ngx/paperless-ngx.git
(synced 2025-08-16 00:36:22 +00:00)
Adding more typing around the classification and matching
This commit is contained in:

Committed by: Trenton H
Parent: 07e7bcd30b
Commit: d376f9e7a3
@@ -5,6 +5,7 @@ import re
|
||||
import warnings
|
||||
from datetime import datetime
|
||||
from hashlib import sha256
|
||||
from pathlib import Path
|
||||
from typing import Iterator
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
@@ -81,7 +82,7 @@ class DocumentClassifier:
|
||||
self._stemmer = None
|
||||
self._stop_words = None
|
||||
|
||||
- def load(self):
+ def load(self) -> None:
|
||||
# Catch warnings for processing
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
with open(settings.MODEL_FILE, "rb") as f:
|
||||
@@ -120,19 +121,20 @@ class DocumentClassifier:
|
||||
raise IncompatibleClassifierVersionError
|
||||
|
||||
def save(self):
|
||||
- target_file = settings.MODEL_FILE
- target_file_temp = settings.MODEL_FILE.with_suffix(".pickle.part")
+ target_file: Path = settings.MODEL_FILE
+ target_file_temp = target_file.with_suffix(".pickle.part")
|
||||
|
||||
with open(target_file_temp, "wb") as f:
|
||||
pickle.dump(self.FORMAT_VERSION, f)
|
||||
|
||||
pickle.dump(self.last_doc_change_time, f)
|
||||
pickle.dump(self.last_auto_type_hash, f)
|
||||
|
||||
pickle.dump(self.data_vectorizer, f)
|
||||
|
||||
pickle.dump(self.tags_binarizer, f)
|
||||
|
||||
pickle.dump(self.tags_classifier, f)
|
||||
|
||||
pickle.dump(self.correspondent_classifier, f)
|
||||
pickle.dump(self.document_type_classifier, f)
|
||||
pickle.dump(self.storage_path_classifier, f)
|
||||
@@ -380,7 +382,7 @@ class DocumentClassifier:
|
||||
|
||||
return content
|
||||
|
||||
- def predict_correspondent(self, content: str):
+ def predict_correspondent(self, content: str) -> Optional[int]:
|
||||
if self.correspondent_classifier:
|
||||
X = self.data_vectorizer.transform([self.preprocess_content(content)])
|
||||
correspondent_id = self.correspondent_classifier.predict(X)
|
||||
@@ -391,7 +393,7 @@ class DocumentClassifier:
|
||||
else:
|
||||
return None
|
||||
|
||||
- def predict_document_type(self, content: str):
+ def predict_document_type(self, content: str) -> Optional[int]:
|
||||
if self.document_type_classifier:
|
||||
X = self.data_vectorizer.transform([self.preprocess_content(content)])
|
||||
document_type_id = self.document_type_classifier.predict(X)
|
||||
@@ -402,7 +404,7 @@ class DocumentClassifier:
|
||||
else:
|
||||
return None
|
||||
|
||||
- def predict_tags(self, content: str):
+ def predict_tags(self, content: str) -> List[int]:
|
||||
from sklearn.utils.multiclass import type_of_target
|
||||
|
||||
if self.tags_classifier:
|
||||
@@ -423,7 +425,7 @@ class DocumentClassifier:
|
||||
else:
|
||||
return []
|
||||
|
||||
- def predict_storage_path(self, content: str):
+ def predict_storage_path(self, content: str) -> Optional[int]:
|
||||
if self.storage_path_classifier:
|
||||
X = self.data_vectorizer.transform([self.preprocess_content(content)])
|
||||
storage_path_id = self.storage_path_classifier.predict(X)
|
||||
|
Reference in New Issue
Block a user