Mirror of https://github.com/paperless-ngx/paperless-ngx.git

Commit: Performance: Classifier performance optimizations (#10363)
@@ -16,16 +16,29 @@ if TYPE_CHECKING:

 from django.conf import settings
 from django.core.cache import cache
+from django.core.cache import caches

+from documents.caching import CACHE_5_MINUTES
 from documents.caching import CACHE_50_MINUTES
 from documents.caching import CLASSIFIER_HASH_KEY
 from documents.caching import CLASSIFIER_MODIFIED_KEY
 from documents.caching import CLASSIFIER_VERSION_KEY
+from documents.caching import StoredLRUCache
 from documents.models import Document
 from documents.models import MatchingModel

 logger = logging.getLogger("paperless.classifier")

+ADVANCED_TEXT_PROCESSING_ENABLED = (
+    settings.NLTK_LANGUAGE is not None and settings.NLTK_ENABLED
+)
+
+read_cache = caches["read-cache"]
+
+RE_DIGIT = re.compile(r"\d")
+RE_WORD = re.compile(r"\b[\w]+\b")  # words that may contain digits


 class IncompatibleClassifierVersionError(Exception):
     def __init__(self, message: str, *args: object) -> None:
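Illustration, not from the commit: what the two new module-level patterns match on a made-up string. RE_WORD drives the simplified text normalization in the new preprocess_content() further down, and RE_DIGIT is used to pass digit-containing tokens through unstemmed.

    # Illustration only: behaviour of the new module-level regexes.
    import re

    RE_DIGIT = re.compile(r"\d")
    RE_WORD = re.compile(r"\b[\w]+\b")  # words that may contain digits

    text = "Invoice #2024-001: total 59,90 EUR"
    print([m.group().lower() for m in RE_WORD.finditer(text)])
    # -> ['invoice', '2024', '001', 'total', '59', '90', 'eur']
    print(bool(RE_DIGIT.search("a4")), bool(RE_DIGIT.search("paper")))
    # -> True False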
@@ -92,15 +105,28 @@ class DocumentClassifier:
         self.last_auto_type_hash: bytes | None = None

         self.data_vectorizer = None
+        self.data_vectorizer_hash = None
         self.tags_binarizer = None
         self.tags_classifier = None
         self.correspondent_classifier = None
         self.document_type_classifier = None
         self.storage_path_classifier = None

         self._stemmer = None
+        # 10,000 elements roughly use 200 to 500 KB per worker,
+        # and also in the shared Redis cache,
+        # Keep this cache small to minimize lookup and I/O latency.
+        if ADVANCED_TEXT_PROCESSING_ENABLED:
+            self._stem_cache = StoredLRUCache(
+                f"stem_cache_v{self.FORMAT_VERSION}",
+                capacity=10000,
+            )
         self._stop_words = None

+    def _update_data_vectorizer_hash(self):
+        self.data_vectorizer_hash = sha256(
+            pickle.dumps(self.data_vectorizer),
+        ).hexdigest()
+
     def load(self) -> None:
         from sklearn.exceptions import InconsistentVersionWarning
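StoredLRUCache comes from documents.caching, which is not part of this diff. Based only on how it is used here (a constructor taking a key and a capacity, plus the get/set/load/save calls in stem_and_skip_stop_words below), a minimal sketch of that interface could look like the following; the real implementation may differ.

    # Minimal sketch, NOT the real documents.caching.StoredLRUCache: a per-process
    # LRU dict that can be persisted to / restored from a shared Django cache backend.
    from collections import OrderedDict

    from django.core.cache import caches

    class StoredLRUCacheSketch:
        def __init__(self, backend_key: str, capacity: int = 10000, backend_alias: str = "read-cache"):
            self._data: OrderedDict = OrderedDict()
            self._capacity = capacity
            self._backend_key = backend_key        # key under which the whole LRU is stored
            self._backend = caches[backend_alias]  # shared (e.g. Redis-backed) cache

        def get(self, key):
            if key in self._data:
                self._data.move_to_end(key)        # mark as recently used
            return self._data.get(key)

        def set(self, key, value):
            self._data[key] = value
            self._data.move_to_end(key)
            while len(self._data) > self._capacity:
                self._data.popitem(last=False)     # evict the least recently used entry

        def load(self):
            stored = self._backend.get(self._backend_key)
            if stored:
                self._data = stored                # pull the shared copy before a batch of lookups

        def save(self):
            self._backend.set(self._backend_key, self._data)  # publish local updates for other workers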
@@ -119,6 +145,7 @@ class DocumentClassifier:
             self.last_auto_type_hash = pickle.load(f)

             self.data_vectorizer = pickle.load(f)
+            self._update_data_vectorizer_hash()
             self.tags_binarizer = pickle.load(f)

             self.tags_classifier = pickle.load(f)
@@ -269,7 +296,7 @@ class DocumentClassifier:
            Generates the content for documents, but once at a time
            """
            for doc in docs_queryset:
-                yield self.preprocess_content(doc.content)
+                yield self.preprocess_content(doc.content, shared_cache=False)

        self.data_vectorizer = CountVectorizer(
            analyzer="word",
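The new shared_cache=False argument makes preprocess_content() skip the load()/save() round-trip of the shared stem cache for every document; as the docstring of the new preprocess_content() below notes, sharing the stem cache across workers is unnecessary when training the classifier. Hypothetical call shapes after this change:

    # Names as in the diff; "classifier" is an instance of DocumentClassifier.
    classifier.preprocess_content(doc.content, shared_cache=False)  # training: local LRU only
    classifier.preprocess_content(doc.content)                      # prediction: syncs with the shared cache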
@@ -347,6 +374,7 @@ class DocumentClassifier:

         self.last_doc_change_time = latest_doc_change
         self.last_auto_type_hash = hasher.digest()
+        self._update_data_vectorizer_hash()

         # Set the classifier information into the cache
         # Caching for 50 minutes, so slightly less than the normal retrain time
@@ -356,30 +384,15 @@ class DocumentClassifier:

         return True

-    def preprocess_content(self, content: str) -> str:  # pragma: no cover
-        """
-        Process to contents of a document, distilling it down into
-        words which are meaningful to the content
-        """
-
-        # Lower case the document
-        content = content.lower().strip()
-        # Reduce spaces
-        content = re.sub(r"\s+", " ", content)
-        # Get only the letters
-        content = re.sub(r"[^\w\s]", " ", content)
-
-        # If the NLTK language is supported, do further processing
-        if settings.NLTK_LANGUAGE is not None and settings.NLTK_ENABLED:
+    def _init_advanced_text_processing(self):
+        if self._stop_words is None or self._stemmer is None:
             import nltk
             from nltk.corpus import stopwords
             from nltk.stem import SnowballStemmer
-            from nltk.tokenize import word_tokenize

             # Not really hacky, since it isn't private and is documented, but
             # set the search path for NLTK data to the single location it should be in
             nltk.data.path = [settings.NLTK_DIR]

             try:
                 # Preload the corpus early, to force the lazy loader to transform
                 stopwords.ensure_loaded()
@@ -387,41 +400,100 @@ class DocumentClassifier:
-                # Do some one time setup
-                # Sometimes, somehow, there's multiple threads loading the corpus
-                # and it's not thread safe, raising an AttributeError
-                if self._stemmer is None:
-                    self._stemmer = SnowballStemmer(settings.NLTK_LANGUAGE)
-                if self._stop_words is None:
-                    self._stop_words = set(stopwords.words(settings.NLTK_LANGUAGE))
-
-                # Tokenize
-                # This splits the content into tokens, roughly words
-                words: list[str] = word_tokenize(
-                    content,
-                    language=settings.NLTK_LANGUAGE,
-                )
-
-                meaningful_words = []
-                for word in words:
-                    # Skip stop words
-                    # These are words like "a", "and", "the" which add little meaning
-                    if word in self._stop_words:
-                        continue
-                    # Stem the words
-                    # This reduces the words to their stems.
-                    # "amazement" returns "amaz"
-                    # "amaze" returns "amaz
-                    # "amazed" returns "amaz"
-                    meaningful_words.append(self._stemmer.stem(word))
-
-                return " ".join(meaningful_words)
+                self._stemmer = SnowballStemmer(settings.NLTK_LANGUAGE)
+                self._stop_words = frozenset(stopwords.words(settings.NLTK_LANGUAGE))
             except AttributeError:
+                logger.debug("Could not initialize NLTK for advanced text processing.")
+                return False
+        return True
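For context, a standalone NLTK snippet (not from the commit) showing the behaviour _init_advanced_text_processing sets up; it assumes the NLTK stopwords corpus has already been downloaded, for example into settings.NLTK_DIR, and uses English as the configured language. The example stems are the ones named in the removed comments above.

    # Standalone illustration of the NLTK pieces used above.
    from nltk.corpus import stopwords
    from nltk.stem import SnowballStemmer

    stemmer = SnowballStemmer("english")
    stop_words = frozenset(stopwords.words("english"))

    print(stemmer.stem("amazement"), stemmer.stem("amaze"), stemmer.stem("amazed"))
    # -> amaz amaz amaz  (the three forms collapse to one stem)
    print("the" in stop_words, "invoice" in stop_words)
    # -> True False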
+
+    def stem_and_skip_stop_words(self, words: list[str], *, shared_cache=True):
+        """
+        Reduce a list of words to their stem. Stop words are converted to empty strings.
+        :param words: the list of words to stem
+        """
+
+        def _stem_and_skip_stop_word(word: str):
+            """
+            Reduce a given word to its stem. If it's a stop word, return an empty string.
+            E.g. "amazement", "amaze" and "amazed" all return "amaz".
+            """
+            cached = self._stem_cache.get(word)
+            if cached is not None:
+                return cached
+            elif word in self._stop_words:
+                return ""
+            # Assumption: words that contain numbers are never stemmed
+            elif RE_DIGIT.search(word):
+                return word
+            else:
+                result = self._stemmer.stem(word)
+                self._stem_cache.set(word, result)
+                return result
+
+        if shared_cache:
+            self._stem_cache.load()
+
+        # Stem the words and skip stop words
+        result = " ".join(
+            filter(None, (_stem_and_skip_stop_word(w) for w in words)),
+        )
+        if shared_cache:
+            self._stem_cache.save()
+        return result
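A worked illustration with made-up input, assuming an initialized classifier and English NLTK data; the stems are the ones stated in the docstring above. Stop words drop out, digit-containing tokens pass through unchanged, and every other word is looked up in the LRU cache before the stemmer runs.

    # Hypothetical call on an initialized DocumentClassifier.
    classifier.stem_and_skip_stop_words(["the", "amazed", "amazement", "2024"], shared_cache=False)
    # -> "amaz amaz 2024"
    # "the" is an English stop word (dropped); "2024" contains a digit and is passed through unstemmed.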
+
+    def preprocess_content(
+        self,
+        content: str,
+        *,
+        shared_cache=True,
+    ) -> str:
+        """
+        Process the contents of a document, distilling it down into
+        words which are meaningful to the content.
+
+        A stemmer cache is shared across workers with the parameter "shared_cache".
+        This is unnecessary when training the classifier.
+        """
+
+        # Lower case the document, reduce space,
+        # and keep only letters and digits.
+        content = " ".join(match.group().lower() for match in RE_WORD.finditer(content))
+
+        if ADVANCED_TEXT_PROCESSING_ENABLED:
+            from nltk.tokenize import word_tokenize
+
+            if not self._init_advanced_text_processing():
+                return content
+            # Tokenize
+            # This splits the content into tokens, roughly words
+            words = word_tokenize(content, language=settings.NLTK_LANGUAGE)
+            # Stem the words and skip stop words
+            content = self.stem_and_skip_stop_words(words, shared_cache=shared_cache)
+
+        return content
+
+    def _get_vectorizer_cache_key(self, content: str):
+        hash = sha256(content.encode())
+        hash.update(
+            f"|{self.FORMAT_VERSION}|{settings.NLTK_LANGUAGE}|{settings.NLTK_ENABLED}|{self.data_vectorizer_hash}".encode(),
+        )
+        return f"vectorized_content_{hash.hexdigest()}"
+
+    def _vectorize(self, content: str):
+        key = self._get_vectorizer_cache_key(content)
+        serialized_result = read_cache.get(key)
+        if serialized_result is None:
+            result = self.data_vectorizer.transform([self.preprocess_content(content)])
+            read_cache.set(key, pickle.dumps(result), CACHE_5_MINUTES)
+        else:
+            read_cache.touch(key, CACHE_5_MINUTES)
+            result = pickle.loads(serialized_result)
+        return result
+
     def predict_correspondent(self, content: str) -> int | None:
         if self.correspondent_classifier:
-            X = self.data_vectorizer.transform([self.preprocess_content(content)])
+            X = self._vectorize(content)
             correspondent_id = self.correspondent_classifier.predict(X)
             if correspondent_id != -1:
                 return correspondent_id
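The cache key ties a vectorized result to the exact content (sha256), the classifier format version, the NLTK settings and the hash of the fitted vectorizer, so entries stop matching automatically after a retrain changes the vectorizer. The practical effect, shown as a hypothetical call sequence (the tags/storage-path method names are assumed from the hunks below, and doc_text is a placeholder), is that repeated predictions over the same text reuse one cached vectorization instead of re-running preprocessing and CountVectorizer.transform each time.

    # Hypothetical sequence for one document's text; "classifier" is a loaded DocumentClassifier.
    X_first = classifier._vectorize(doc_text)   # cache miss: preprocess + transform, stored for 5 minutes
    X_again = classifier._vectorize(doc_text)   # cache hit: unpickled from the shared read-cache, TTL refreshed
    classifier.predict_correspondent(doc_text)  # the predict_* methods below all go through _vectorize()
    classifier.predict_document_type(doc_text)
    classifier.predict_tags(doc_text)
    classifier.predict_storage_path(doc_text)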
@@ -432,7 +504,7 @@ class DocumentClassifier:

     def predict_document_type(self, content: str) -> int | None:
         if self.document_type_classifier:
-            X = self.data_vectorizer.transform([self.preprocess_content(content)])
+            X = self._vectorize(content)
             document_type_id = self.document_type_classifier.predict(X)
             if document_type_id != -1:
                 return document_type_id
@@ -445,7 +517,7 @@ class DocumentClassifier:
         from sklearn.utils.multiclass import type_of_target

         if self.tags_classifier:
-            X = self.data_vectorizer.transform([self.preprocess_content(content)])
+            X = self._vectorize(content)
             y = self.tags_classifier.predict(X)
             tags_ids = self.tags_binarizer.inverse_transform(y)[0]
             if type_of_target(y).startswith("multilabel"):
@@ -464,7 +536,7 @@ class DocumentClassifier:

     def predict_storage_path(self, content: str) -> int | None:
         if self.storage_path_classifier:
-            X = self.data_vectorizer.transform([self.preprocess_content(content)])
+            X = self._vectorize(content)
             storage_path_id = self.storage_path_classifier.predict(X)
             if storage_path_id != -1:
                 return storage_path_id