diff --git a/src/documents/classifier.py b/src/documents/classifier.py index 2779fad7b..66958087a 100644 --- a/src/documents/classifier.py +++ b/src/documents/classifier.py @@ -5,6 +5,7 @@ import pickle import re import shutil import warnings +from typing import Iterator from typing import List from typing import Optional @@ -136,21 +137,22 @@ class DocumentClassifier: def train(self): - data = [] labels_tags = [] labels_correspondent = [] labels_document_type = [] labels_storage_path = [] + docs_queryset = Document.objects.order_by("pk").exclude(tags__is_inbox_tag=True) + + if docs_queryset.count() == 0: + raise ValueError("No training data available.") + # Step 1: Extract and preprocess training data from the database. logger.debug("Gathering data from database...") m = hashlib.sha1() - for doc in Document.objects.order_by("pk").exclude( - tags__is_inbox_tag=True, - ): + for doc in docs_queryset: preprocessed_content = self.preprocess_content(doc.content) m.update(preprocessed_content.encode("utf-8")) - data.append(preprocessed_content) y = -1 dt = doc.document_type @@ -183,9 +185,6 @@ class DocumentClassifier: m.update(y.to_bytes(4, "little", signed=True)) labels_storage_path.append(y) - if not data: - raise ValueError("No training data available.") - new_data_hash = m.digest() if self.data_hash and new_data_hash == self.data_hash: @@ -207,7 +206,7 @@ class DocumentClassifier: logger.debug( "{} documents, {} tag(s), {} correspondent(s), " "{} document type(s). {} storage path(es)".format( - len(data), + docs_queryset.count(), num_tags, num_correspondents, num_document_types, @@ -221,12 +220,18 @@ class DocumentClassifier: # Step 2: vectorize data logger.debug("Vectorizing data...") + + def content_generator() -> Iterator[str]: + for doc in docs_queryset: + yield self.preprocess_content(doc.content) + self.data_vectorizer = CountVectorizer( analyzer="word", ngram_range=(1, 2), min_df=0.01, ) - data_vectorized = self.data_vectorizer.fit_transform(data) + + data_vectorized = self.data_vectorizer.fit_transform(content_generator()) # See the notes here: # https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.CountVectorizer.html # noqa: 501 @@ -341,7 +346,7 @@ class DocumentClassifier: return content - def predict_correspondent(self, content): + def predict_correspondent(self, content: str): if self.correspondent_classifier: X = self.data_vectorizer.transform([self.preprocess_content(content)]) correspondent_id = self.correspondent_classifier.predict(X) @@ -352,7 +357,7 @@ class DocumentClassifier: else: return None - def predict_document_type(self, content): + def predict_document_type(self, content: str): if self.document_type_classifier: X = self.data_vectorizer.transform([self.preprocess_content(content)]) document_type_id = self.document_type_classifier.predict(X) @@ -363,7 +368,7 @@ class DocumentClassifier: else: return None - def predict_tags(self, content): + def predict_tags(self, content: str): from sklearn.utils.multiclass import type_of_target if self.tags_classifier: @@ -384,7 +389,7 @@ class DocumentClassifier: else: return [] - def predict_storage_path(self, content): + def predict_storage_path(self, content: str): if self.storage_path_classifier: X = self.data_vectorizer.transform([self.preprocess_content(content)]) storage_path_id = self.storage_path_classifier.predict(X)