Changes classifier training to stream document content to the vectorizer instead of collecting it in a list, so less data is held in memory at once

This commit is contained in:
Trenton H 2023-02-22 11:33:11 -08:00
parent 16d3041d7a
commit 8709ea4df0

View File

@ -5,6 +5,7 @@ import pickle
import re import re
import shutil import shutil
import warnings import warnings
from typing import Iterator
from typing import List from typing import List
from typing import Optional from typing import Optional
@ -136,21 +137,22 @@ class DocumentClassifier:
def train(self): def train(self):
data = []
labels_tags = [] labels_tags = []
labels_correspondent = [] labels_correspondent = []
labels_document_type = [] labels_document_type = []
labels_storage_path = [] labels_storage_path = []
docs_queryset = Document.objects.order_by("pk").exclude(tags__is_inbox_tag=True)
if docs_queryset.count() == 0:
raise ValueError("No training data available.")
# Step 1: Extract and preprocess training data from the database. # Step 1: Extract and preprocess training data from the database.
logger.debug("Gathering data from database...") logger.debug("Gathering data from database...")
m = hashlib.sha1() m = hashlib.sha1()
for doc in Document.objects.order_by("pk").exclude( for doc in docs_queryset:
tags__is_inbox_tag=True,
):
preprocessed_content = self.preprocess_content(doc.content) preprocessed_content = self.preprocess_content(doc.content)
m.update(preprocessed_content.encode("utf-8")) m.update(preprocessed_content.encode("utf-8"))
data.append(preprocessed_content)
y = -1 y = -1
dt = doc.document_type dt = doc.document_type
@ -183,9 +185,6 @@ class DocumentClassifier:
m.update(y.to_bytes(4, "little", signed=True)) m.update(y.to_bytes(4, "little", signed=True))
labels_storage_path.append(y) labels_storage_path.append(y)
if not data:
raise ValueError("No training data available.")
new_data_hash = m.digest() new_data_hash = m.digest()
if self.data_hash and new_data_hash == self.data_hash: if self.data_hash and new_data_hash == self.data_hash:
@ -207,7 +206,7 @@ class DocumentClassifier:
logger.debug( logger.debug(
"{} documents, {} tag(s), {} correspondent(s), " "{} documents, {} tag(s), {} correspondent(s), "
"{} document type(s). {} storage path(es)".format( "{} document type(s). {} storage path(es)".format(
len(data), docs_queryset.count(),
num_tags, num_tags,
num_correspondents, num_correspondents,
num_document_types, num_document_types,
@ -221,12 +220,18 @@ class DocumentClassifier:
# Step 2: vectorize data # Step 2: vectorize data
logger.debug("Vectorizing data...") logger.debug("Vectorizing data...")
def content_generator() -> Iterator[str]:
for doc in docs_queryset:
yield self.preprocess_content(doc.content)
self.data_vectorizer = CountVectorizer( self.data_vectorizer = CountVectorizer(
analyzer="word", analyzer="word",
ngram_range=(1, 2), ngram_range=(1, 2),
min_df=0.01, min_df=0.01,
) )
data_vectorized = self.data_vectorizer.fit_transform(data)
data_vectorized = self.data_vectorizer.fit_transform(content_generator())
# See the notes here: # See the notes here:
# https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.CountVectorizer.html # noqa: 501 # https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.CountVectorizer.html # noqa: 501
@ -341,7 +346,7 @@ class DocumentClassifier:
return content return content
def predict_correspondent(self, content): def predict_correspondent(self, content: str):
if self.correspondent_classifier: if self.correspondent_classifier:
X = self.data_vectorizer.transform([self.preprocess_content(content)]) X = self.data_vectorizer.transform([self.preprocess_content(content)])
correspondent_id = self.correspondent_classifier.predict(X) correspondent_id = self.correspondent_classifier.predict(X)
@ -352,7 +357,7 @@ class DocumentClassifier:
else: else:
return None return None
def predict_document_type(self, content): def predict_document_type(self, content: str):
if self.document_type_classifier: if self.document_type_classifier:
X = self.data_vectorizer.transform([self.preprocess_content(content)]) X = self.data_vectorizer.transform([self.preprocess_content(content)])
document_type_id = self.document_type_classifier.predict(X) document_type_id = self.document_type_classifier.predict(X)
@ -363,7 +368,7 @@ class DocumentClassifier:
else: else:
return None return None
def predict_tags(self, content): def predict_tags(self, content: str):
from sklearn.utils.multiclass import type_of_target from sklearn.utils.multiclass import type_of_target
if self.tags_classifier: if self.tags_classifier:
@ -384,7 +389,7 @@ class DocumentClassifier:
else: else:
return [] return []
def predict_storage_path(self, content): def predict_storage_path(self, content: str):
if self.storage_path_classifier: if self.storage_path_classifier:
X = self.data_vectorizer.transform([self.preprocess_content(content)]) X = self.data_vectorizer.transform([self.preprocess_content(content)])
storage_path_id = self.storage_path_classifier.predict(X) storage_path_id = self.storage_path_classifier.predict(X)