Adding more typing around the classification and matching

This commit is contained in:
Trenton Holmes
2023-07-23 16:49:20 -07:00
committed by Trenton H
parent 07e7bcd30b
commit d376f9e7a3
3 changed files with 30 additions and 24 deletions

View File

@@ -5,6 +5,7 @@ import re
import warnings
from datetime import datetime
from hashlib import sha256
from pathlib import Path
from typing import Iterator
from typing import List
from typing import Optional
@@ -81,7 +82,7 @@ class DocumentClassifier:
self._stemmer = None
self._stop_words = None
def load(self):
def load(self) -> None:
# Catch warnings for processing
with warnings.catch_warnings(record=True) as w:
with open(settings.MODEL_FILE, "rb") as f:
@@ -120,19 +121,20 @@ class DocumentClassifier:
raise IncompatibleClassifierVersionError
def save(self):
target_file = settings.MODEL_FILE
target_file_temp = settings.MODEL_FILE.with_suffix(".pickle.part")
target_file: Path = settings.MODEL_FILE
target_file_temp = target_file.with_suffix(".pickle.part")
with open(target_file_temp, "wb") as f:
pickle.dump(self.FORMAT_VERSION, f)
pickle.dump(self.last_doc_change_time, f)
pickle.dump(self.last_auto_type_hash, f)
pickle.dump(self.data_vectorizer, f)
pickle.dump(self.tags_binarizer, f)
pickle.dump(self.tags_classifier, f)
pickle.dump(self.correspondent_classifier, f)
pickle.dump(self.document_type_classifier, f)
pickle.dump(self.storage_path_classifier, f)
@@ -380,7 +382,7 @@ class DocumentClassifier:
return content
def predict_correspondent(self, content: str):
def predict_correspondent(self, content: str) -> Optional[int]:
if self.correspondent_classifier:
X = self.data_vectorizer.transform([self.preprocess_content(content)])
correspondent_id = self.correspondent_classifier.predict(X)
@@ -391,7 +393,7 @@ class DocumentClassifier:
else:
return None
def predict_document_type(self, content: str):
def predict_document_type(self, content: str) -> Optional[int]:
if self.document_type_classifier:
X = self.data_vectorizer.transform([self.preprocess_content(content)])
document_type_id = self.document_type_classifier.predict(X)
@@ -402,7 +404,7 @@ class DocumentClassifier:
else:
return None
def predict_tags(self, content: str):
def predict_tags(self, content: str) -> List[int]:
from sklearn.utils.multiclass import type_of_target
if self.tags_classifier:
@@ -423,7 +425,7 @@ class DocumentClassifier:
else:
return []
def predict_storage_path(self, content: str):
def predict_storage_path(self, content: str) -> Optional[int]:
if self.storage_path_classifier:
X = self.data_vectorizer.transform([self.preprocess_content(content)])
storage_path_id = self.storage_path_classifier.predict(X)