mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-04-02 13:45:10 -05:00
Changes classifier training to hold less data in memory at the same time
This commit is contained in:
parent
16d3041d7a
commit
8709ea4df0
@ -5,6 +5,7 @@ import pickle
|
||||
import re
|
||||
import shutil
|
||||
import warnings
|
||||
from typing import Iterator
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
@ -136,21 +137,22 @@ class DocumentClassifier:
|
||||
|
||||
def train(self):
|
||||
|
||||
data = []
|
||||
labels_tags = []
|
||||
labels_correspondent = []
|
||||
labels_document_type = []
|
||||
labels_storage_path = []
|
||||
|
||||
docs_queryset = Document.objects.order_by("pk").exclude(tags__is_inbox_tag=True)
|
||||
|
||||
if docs_queryset.count() == 0:
|
||||
raise ValueError("No training data available.")
|
||||
|
||||
# Step 1: Extract and preprocess training data from the database.
|
||||
logger.debug("Gathering data from database...")
|
||||
m = hashlib.sha1()
|
||||
for doc in Document.objects.order_by("pk").exclude(
|
||||
tags__is_inbox_tag=True,
|
||||
):
|
||||
for doc in docs_queryset:
|
||||
preprocessed_content = self.preprocess_content(doc.content)
|
||||
m.update(preprocessed_content.encode("utf-8"))
|
||||
data.append(preprocessed_content)
|
||||
|
||||
y = -1
|
||||
dt = doc.document_type
|
||||
@ -183,9 +185,6 @@ class DocumentClassifier:
|
||||
m.update(y.to_bytes(4, "little", signed=True))
|
||||
labels_storage_path.append(y)
|
||||
|
||||
if not data:
|
||||
raise ValueError("No training data available.")
|
||||
|
||||
new_data_hash = m.digest()
|
||||
|
||||
if self.data_hash and new_data_hash == self.data_hash:
|
||||
@ -207,7 +206,7 @@ class DocumentClassifier:
|
||||
logger.debug(
|
||||
"{} documents, {} tag(s), {} correspondent(s), "
|
||||
"{} document type(s). {} storage path(es)".format(
|
||||
len(data),
|
||||
docs_queryset.count(),
|
||||
num_tags,
|
||||
num_correspondents,
|
||||
num_document_types,
|
||||
@ -221,12 +220,18 @@ class DocumentClassifier:
|
||||
|
||||
# Step 2: vectorize data
|
||||
logger.debug("Vectorizing data...")
|
||||
|
||||
def content_generator() -> Iterator[str]:
|
||||
for doc in docs_queryset:
|
||||
yield self.preprocess_content(doc.content)
|
||||
|
||||
self.data_vectorizer = CountVectorizer(
|
||||
analyzer="word",
|
||||
ngram_range=(1, 2),
|
||||
min_df=0.01,
|
||||
)
|
||||
data_vectorized = self.data_vectorizer.fit_transform(data)
|
||||
|
||||
data_vectorized = self.data_vectorizer.fit_transform(content_generator())
|
||||
|
||||
# See the notes here:
|
||||
# https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.CountVectorizer.html # noqa: 501
|
||||
@ -341,7 +346,7 @@ class DocumentClassifier:
|
||||
|
||||
return content
|
||||
|
||||
def predict_correspondent(self, content):
|
||||
def predict_correspondent(self, content: str):
|
||||
if self.correspondent_classifier:
|
||||
X = self.data_vectorizer.transform([self.preprocess_content(content)])
|
||||
correspondent_id = self.correspondent_classifier.predict(X)
|
||||
@ -352,7 +357,7 @@ class DocumentClassifier:
|
||||
else:
|
||||
return None
|
||||
|
||||
def predict_document_type(self, content):
|
||||
def predict_document_type(self, content: str):
|
||||
if self.document_type_classifier:
|
||||
X = self.data_vectorizer.transform([self.preprocess_content(content)])
|
||||
document_type_id = self.document_type_classifier.predict(X)
|
||||
@ -363,7 +368,7 @@ class DocumentClassifier:
|
||||
else:
|
||||
return None
|
||||
|
||||
def predict_tags(self, content):
|
||||
def predict_tags(self, content: str):
|
||||
from sklearn.utils.multiclass import type_of_target
|
||||
|
||||
if self.tags_classifier:
|
||||
@ -384,7 +389,7 @@ class DocumentClassifier:
|
||||
else:
|
||||
return []
|
||||
|
||||
def predict_storage_path(self, content):
|
||||
def predict_storage_path(self, content: str):
|
||||
if self.storage_path_classifier:
|
||||
X = self.data_vectorizer.transform([self.preprocess_content(content)])
|
||||
storage_path_id = self.storage_path_classifier.predict(X)
|
||||
|
Loading…
x
Reference in New Issue
Block a user