Mirror of https://github.com/paperless-ngx/paperless-ngx.git (synced 2025-04-02 13:45:10 -05:00)
Changes classifier training to hold less data in memory at the same time

parent 16d3041d7a
commit 8709ea4df0
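
The gist of the change: instead of accumulating every preprocessed document in an in-memory data list, training now feeds CountVectorizer.fit_transform a generator, so only one document's text is materialized at a time during vectorization. A minimal standalone sketch of the pattern, where the documents list and the preprocess helper are hypothetical stand-ins for the Django queryset and DocumentClassifier.preprocess_content:

from typing import Iterator

from sklearn.feature_extraction.text import CountVectorizer

documents = ["Invoice No. 42 ...", "Receipt for ..."]  # stand-in for the queryset

def preprocess(content: str) -> str:
    # Hypothetical stand-in for DocumentClassifier.preprocess_content.
    return content.lower().strip()

def content_generator() -> Iterator[str]:
    # Yield one preprocessed document at a time; nothing is accumulated.
    for content in documents:
        yield preprocess(content)

vectorizer = CountVectorizer(analyzer="word", ngram_range=(1, 2), min_df=0.01)
# fit_transform accepts any iterable of strings, a generator included.
data_vectorized = vectorizer.fit_transform(content_generator())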
@@ -5,6 +5,7 @@ import pickle
 import re
 import shutil
 import warnings
+from typing import Iterator
 from typing import List
 from typing import Optional
 
@@ -136,21 +137,22 @@ class DocumentClassifier:
 
     def train(self):
 
-        data = []
         labels_tags = []
         labels_correspondent = []
         labels_document_type = []
         labels_storage_path = []
 
+        docs_queryset = Document.objects.order_by("pk").exclude(tags__is_inbox_tag=True)
+
+        if docs_queryset.count() == 0:
+            raise ValueError("No training data available.")
+
         # Step 1: Extract and preprocess training data from the database.
         logger.debug("Gathering data from database...")
         m = hashlib.sha1()
-        for doc in Document.objects.order_by("pk").exclude(
-            tags__is_inbox_tag=True,
-        ):
+        for doc in docs_queryset:
             preprocessed_content = self.preprocess_content(doc.content)
             m.update(preprocessed_content.encode("utf-8"))
-            data.append(preprocessed_content)
 
             y = -1
             dt = doc.document_type
@@ -183,9 +185,6 @@ class DocumentClassifier:
             m.update(y.to_bytes(4, "little", signed=True))
             labels_storage_path.append(y)
 
-        if not data:
-            raise ValueError("No training data available.")
-
         new_data_hash = m.digest()
 
         if self.data_hash and new_data_hash == self.data_hash:
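
The removed "if not data" guard is superseded by the docs_queryset.count() == 0 check added in the previous hunk, and change detection still works because every document's preprocessed text and label keep flowing through the running SHA-1. A minimal sketch of that fingerprinting idea, with a hypothetical fingerprint helper rather than the project's API:

import hashlib

def fingerprint(contents: list[str], labels: list[int]) -> bytes:
    # Feed each document's text and integer label into one running SHA-1;
    # identical training data yields an identical digest, so a matching
    # stored digest means retraining can be skipped.
    m = hashlib.sha1()
    for content, y in zip(contents, labels):
        m.update(content.encode("utf-8"))
        m.update(y.to_bytes(4, "little", signed=True))  # signed: y may be -1
    return m.digest()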
@@ -207,7 +206,7 @@ class DocumentClassifier:
         logger.debug(
             "{} documents, {} tag(s), {} correspondent(s), "
             "{} document type(s). {} storage path(es)".format(
-                len(data),
+                docs_queryset.count(),
                 num_tags,
                 num_correspondents,
                 num_document_types,
@@ -221,12 +220,18 @@ class DocumentClassifier:
 
         # Step 2: vectorize data
         logger.debug("Vectorizing data...")
+
+        def content_generator() -> Iterator[str]:
+            for doc in docs_queryset:
+                yield self.preprocess_content(doc.content)
+
         self.data_vectorizer = CountVectorizer(
             analyzer="word",
             ngram_range=(1, 2),
             min_df=0.01,
         )
-        data_vectorized = self.data_vectorizer.fit_transform(data)
+
+        data_vectorized = self.data_vectorizer.fit_transform(content_generator())
 
         # See the notes here:
         # https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.CountVectorizer.html  # noqa: 501
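
Worth noting on the hunk above: a generator can only be consumed once, and CountVectorizer.fit_transform makes exactly one pass over its input, so the diff defines content_generator() as a function and calls it at the point of use; any later pass over the data would need a fresh generator from a new call.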
@@ -341,7 +346,7 @@ class DocumentClassifier:
 
         return content
 
-    def predict_correspondent(self, content):
+    def predict_correspondent(self, content: str):
         if self.correspondent_classifier:
             X = self.data_vectorizer.transform([self.preprocess_content(content)])
             correspondent_id = self.correspondent_classifier.predict(X)
@@ -352,7 +357,7 @@ class DocumentClassifier:
         else:
             return None
 
-    def predict_document_type(self, content):
+    def predict_document_type(self, content: str):
         if self.document_type_classifier:
             X = self.data_vectorizer.transform([self.preprocess_content(content)])
             document_type_id = self.document_type_classifier.predict(X)
@@ -363,7 +368,7 @@ class DocumentClassifier:
         else:
             return None
 
-    def predict_tags(self, content):
+    def predict_tags(self, content: str):
         from sklearn.utils.multiclass import type_of_target
 
         if self.tags_classifier:
@@ -384,7 +389,7 @@ class DocumentClassifier:
         else:
             return []
 
-    def predict_storage_path(self, content):
+    def predict_storage_path(self, content: str):
         if self.storage_path_classifier:
             X = self.data_vectorizer.transform([self.preprocess_content(content)])
             storage_path_id = self.storage_path_classifier.predict(X)
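
The four predict_* hunks above only tighten the signatures to content: str, but they share one shape: preprocess the text, vectorize it, and ask the matching classifier for a label. A hedged sketch of that shared pattern, with a hypothetical predict_id helper and stand-in preprocessing, not the project's API:

def predict_id(vectorizer, classifier, content: str):
    # Shared shape of predict_correspondent, predict_document_type and
    # predict_storage_path: no trained classifier means no prediction.
    if classifier is None:
        return None
    X = vectorizer.transform([content.lower().strip()])  # stand-in preprocessing
    return classifier.predict(X)[0]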