Changes classifier training to hold less data in memory at the same time

This commit is contained in:
Trenton H 2023-02-22 11:33:11 -08:00
parent 16d3041d7a
commit 8709ea4df0

View File

@ -5,6 +5,7 @@ import pickle
import re
import shutil
import warnings
from typing import Iterator
from typing import List
from typing import Optional
@ -136,21 +137,22 @@ class DocumentClassifier:
def train(self):
data = []
labels_tags = []
labels_correspondent = []
labels_document_type = []
labels_storage_path = []
docs_queryset = Document.objects.order_by("pk").exclude(tags__is_inbox_tag=True)
if docs_queryset.count() == 0:
raise ValueError("No training data available.")
# Step 1: Extract and preprocess training data from the database.
logger.debug("Gathering data from database...")
m = hashlib.sha1()
for doc in Document.objects.order_by("pk").exclude(
tags__is_inbox_tag=True,
):
for doc in docs_queryset:
preprocessed_content = self.preprocess_content(doc.content)
m.update(preprocessed_content.encode("utf-8"))
data.append(preprocessed_content)
y = -1
dt = doc.document_type
@ -183,9 +185,6 @@ class DocumentClassifier:
m.update(y.to_bytes(4, "little", signed=True))
labels_storage_path.append(y)
if not data:
raise ValueError("No training data available.")
new_data_hash = m.digest()
if self.data_hash and new_data_hash == self.data_hash:
@ -207,7 +206,7 @@ class DocumentClassifier:
logger.debug(
"{} documents, {} tag(s), {} correspondent(s), "
"{} document type(s). {} storage path(es)".format(
len(data),
docs_queryset.count(),
num_tags,
num_correspondents,
num_document_types,
@ -221,12 +220,18 @@ class DocumentClassifier:
# Step 2: vectorize data
logger.debug("Vectorizing data...")
def content_generator() -> Iterator[str]:
for doc in docs_queryset:
yield self.preprocess_content(doc.content)
self.data_vectorizer = CountVectorizer(
analyzer="word",
ngram_range=(1, 2),
min_df=0.01,
)
data_vectorized = self.data_vectorizer.fit_transform(data)
data_vectorized = self.data_vectorizer.fit_transform(content_generator())
# See the notes here:
# https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.CountVectorizer.html # noqa: 501
@ -341,7 +346,7 @@ class DocumentClassifier:
return content
def predict_correspondent(self, content):
def predict_correspondent(self, content: str):
if self.correspondent_classifier:
X = self.data_vectorizer.transform([self.preprocess_content(content)])
correspondent_id = self.correspondent_classifier.predict(X)
@ -352,7 +357,7 @@ class DocumentClassifier:
else:
return None
def predict_document_type(self, content):
def predict_document_type(self, content: str):
if self.document_type_classifier:
X = self.data_vectorizer.transform([self.preprocess_content(content)])
document_type_id = self.document_type_classifier.predict(X)
@ -363,7 +368,7 @@ class DocumentClassifier:
else:
return None
def predict_tags(self, content):
def predict_tags(self, content: str):
from sklearn.utils.multiclass import type_of_target
if self.tags_classifier:
@ -384,7 +389,7 @@ class DocumentClassifier:
else:
return []
def predict_storage_path(self, content):
def predict_storage_path(self, content: str):
if self.storage_path_classifier:
X = self.data_vectorizer.transform([self.preprocess_content(content)])
storage_path_id = self.storage_path_classifier.predict(X)