from __future__ import annotations

import logging
import pickle
import re
import warnings
from hashlib import sha256
from pathlib import Path
from typing import TYPE_CHECKING

import joblib

if TYPE_CHECKING:
    from collections.abc import Iterator
    from datetime import datetime

    from numpy import ndarray

from django.conf import settings
from django.core.cache import cache
from django.core.cache import caches

from documents.caching import CACHE_5_MINUTES
from documents.caching import CACHE_50_MINUTES
from documents.caching import CLASSIFIER_HASH_KEY
from documents.caching import CLASSIFIER_MODIFIED_KEY
from documents.caching import CLASSIFIER_VERSION_KEY
from documents.caching import StoredLRUCache
from documents.models import Document
from documents.models import MatchingModel

logger = logging.getLogger("paperless.classifier")

ADVANCED_TEXT_PROCESSING_ENABLED = (
    settings.NLTK_LANGUAGE is not None and settings.NLTK_ENABLED
)

read_cache = caches["read-cache"]


RE_DIGIT = re.compile(r"\d")
RE_WORD = re.compile(r"\b[\w]+\b")  # words that may contain digits


class IncompatibleClassifierVersionError(Exception):
    def __init__(self, message: str, *args: object) -> None:
        self.message: str = message
        super().__init__(*args)


class ClassifierModelCorruptError(Exception):
    pass


def load_classifier(*, raise_exception: bool = False) -> DocumentClassifier | None:
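    """
    Load the trained classification model from settings.MODEL_FILE, if it exists.

    Returns None when the model file is missing or cannot be loaded. With
    raise_exception=True, load errors are re-raised after cleanup instead of
    being swallowed.
    """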
    if not settings.MODEL_FILE.is_file():
        logger.debug(
            "Document classification model does not exist (yet), not "
            "performing automatic matching.",
        )
        return None

    classifier = DocumentClassifier()
    try:
        classifier.load()

    except IncompatibleClassifierVersionError as e:
        logger.info(f"Classifier version incompatible: {e.message}, will re-train")
        Path(settings.MODEL_FILE).unlink()
        classifier = None
        if raise_exception:
            raise e
    except ClassifierModelCorruptError as e:
        # there's something wrong with the model file.
        logger.exception(
            "Unrecoverable error while loading document "
            "classification model, deleting model file.",
        )
        # Actually call unlink() so the corrupt model file is deleted
        Path(settings.MODEL_FILE).unlink()
        classifier = None
        if raise_exception:
            raise e
    except OSError as e:
        logger.exception("IO error while loading document classification model")
        classifier = None
        if raise_exception:
            raise e
    except Exception as e:  # pragma: no cover
        logger.exception("Unknown error while loading document classification model")
        classifier = None
        if raise_exception:
            raise e

    return classifier


class DocumentClassifier:
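    """
    Learns to suggest tags, correspondents, document types and storage paths
    from the text of existing documents, using scikit-learn models trained on
    a bag-of-words representation of each document's content.
    """
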
    # v7 - Updated scikit-learn package version
    # v8 - Added storage path classifier
    # v9 - Changed from hashing to time/ids for re-train check
    # v10 - Switch persistence to joblib with memory-mapping to reduce load-time memory spikes
    FORMAT_VERSION = 10

    def __init__(self) -> None:
        # Last time a document changed and therefore training might be required
        self.last_doc_change_time: datetime | None = None
        # Hash of primary keys of AUTO matching values last used in training
        self.last_auto_type_hash: bytes | None = None

        self.data_vectorizer = None
        self.data_vectorizer_hash = None
        self.tags_binarizer = None
        self.tags_classifier = None
        self.correspondent_classifier = None
        self.document_type_classifier = None
        self.storage_path_classifier = None
        self._stemmer = None
        # 10,000 entries use roughly 200 to 500 KB per worker, and about the
        # same again in the shared Redis cache. Keep this cache small to
        # minimize lookup and I/O latency.
        if ADVANCED_TEXT_PROCESSING_ENABLED:
            self._stem_cache = StoredLRUCache(
                f"stem_cache_v{self.FORMAT_VERSION}",
                capacity=10000,
            )
        self._stop_words = None

    def _update_data_vectorizer_hash(self):
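        """
        Store a SHA-256 fingerprint of the pickled vectorizer. The hash is part
        of the vectorization cache key, so cached results are not reused after
        the vectorizer changes.
        """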
        self.data_vectorizer_hash = sha256(
            pickle.dumps(self.data_vectorizer),
        ).hexdigest()

    def load(self) -> None:
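        """
        Load the classifier state from settings.MODEL_FILE.

        Raises IncompatibleClassifierVersionError when the file was written by
        a different format or scikit-learn version, and
        ClassifierModelCorruptError when the file cannot be read at all.
        """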
        from sklearn.exceptions import InconsistentVersionWarning

        # Catch warnings for processing
        with warnings.catch_warnings(record=True) as w:
            state = None
            try:
                state = joblib.load(settings.MODEL_FILE, mmap_mode="r")
            except ValueError:
                # Some environments may fail to mmap small files; fall back to normal load
                state = joblib.load(settings.MODEL_FILE, mmap_mode=None)
            except Exception as err:
                # Fallback to old pickle-based format. Try to read the version and a field to
                # distinguish truly corrupt files from incompatible versions.
                try:
                    with Path(settings.MODEL_FILE).open("rb") as f:
                        _version = pickle.load(f)
                        try:
                            _ = pickle.load(f)
                        except Exception as inner:
                            raise ClassifierModelCorruptError from inner
                    # Old, incompatible format
                    raise IncompatibleClassifierVersionError(
                        "Cannot load classifier, incompatible versions.",
                    ) from err
                except (
                    IncompatibleClassifierVersionError,
                    ClassifierModelCorruptError,
                ):
                    raise
                except Exception:
                    # Not even a readable pickle header
                    raise ClassifierModelCorruptError from err

            if (
                not isinstance(state, dict)
                or state.get("format_version") != self.FORMAT_VERSION
            ):
                raise IncompatibleClassifierVersionError(
                    "Cannot load classifier, incompatible versions.",
                )

            try:
                self.last_doc_change_time = state.get("last_doc_change_time")
                self.last_auto_type_hash = state.get("last_auto_type_hash")

                self.data_vectorizer = state.get("data_vectorizer")
                self._update_data_vectorizer_hash()
                self.tags_binarizer = state.get("tags_binarizer")

                self.tags_classifier = state.get("tags_classifier")
                self.correspondent_classifier = state.get("correspondent_classifier")
                self.document_type_classifier = state.get("document_type_classifier")
                self.storage_path_classifier = state.get("storage_path_classifier")
            except Exception as err:
                raise ClassifierModelCorruptError from err

            # Check for the warning about unpickling from differing versions
            # and consider it incompatible
            sk_learn_warning_url = (
                "https://scikit-learn.org/stable/"
                "model_persistence.html"
                "#security-maintainability-limitations"
            )
            for warning in w:
                # The warning type is inconsistent: MLPClassifier raises the specific
                # InconsistentVersionWarning, while other estimators still raise a plain
                # UserWarning that points at the scikit-learn persistence docs.
                if issubclass(warning.category, InconsistentVersionWarning) or (
                    issubclass(warning.category, UserWarning)
                    and sk_learn_warning_url in str(warning.message)
                ):
                    raise IncompatibleClassifierVersionError("sklearn version update")

    def save(self) -> None:
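        """
        Persist the classifier state with joblib. The state is written to a
        temporary ".joblib.part" file first and then renamed over
        settings.MODEL_FILE, so readers never see a partially written model.
        """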
        target_file: Path = settings.MODEL_FILE
        target_file_temp: Path = target_file.with_suffix(".joblib.part")

        state = {
            "format_version": self.FORMAT_VERSION,
            "last_doc_change_time": self.last_doc_change_time,
            "last_auto_type_hash": self.last_auto_type_hash,
            "data_vectorizer": self.data_vectorizer,
            "tags_binarizer": self.tags_binarizer,
            "tags_classifier": self.tags_classifier,
            "correspondent_classifier": self.correspondent_classifier,
            "document_type_classifier": self.document_type_classifier,
            "storage_path_classifier": self.storage_path_classifier,
        }

        joblib.dump(state, target_file_temp, compress=3)

        target_file_temp.rename(target_file)

    def train(self) -> bool:
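        """
        Train the classifiers on all non-inbox documents.

        Returns True when the models were (re)trained and False when nothing
        relevant changed since the last training run. Raises ValueError when
        there are no documents to train on.
        """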
        # Get non-inbox documents
        docs_queryset = (
            Document.objects.exclude(
                tags__is_inbox_tag=True,
            )
            .select_related("document_type", "correspondent", "storage_path")
            .prefetch_related("tags")
            .order_by("pk")
        )

        # No documents exist to train against
        if docs_queryset.count() == 0:
            raise ValueError("No training data available.")

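        # A label of -1 below means "no AUTO-matched object of that type for
        # this document"; tags are collected as a (possibly empty) list of
        # primary keys per document instead.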
        labels_tags = []
        labels_correspondent = []
        labels_document_type = []
        labels_storage_path = []

        # Step 1: Extract and preprocess training data from the database.
        logger.debug("Gathering data from database...")
        hasher = sha256()
        for doc in docs_queryset:
            y = -1
            dt = doc.document_type
            if dt and dt.matching_algorithm == MatchingModel.MATCH_AUTO:
                y = dt.pk
            hasher.update(y.to_bytes(4, "little", signed=True))
            labels_document_type.append(y)

            y = -1
            cor = doc.correspondent
            if cor and cor.matching_algorithm == MatchingModel.MATCH_AUTO:
                y = cor.pk
            hasher.update(y.to_bytes(4, "little", signed=True))
            labels_correspondent.append(y)

            tags: list[int] = list(
                doc.tags.filter(matching_algorithm=MatchingModel.MATCH_AUTO)
                .order_by("pk")
                .values_list("pk", flat=True),
            )
            for tag in tags:
                hasher.update(tag.to_bytes(4, "little", signed=True))
            labels_tags.append(tags)

            y = -1
            sp = doc.storage_path
            if sp and sp.matching_algorithm == MatchingModel.MATCH_AUTO:
                y = sp.pk
            hasher.update(y.to_bytes(4, "little", signed=True))
            labels_storage_path.append(y)

        labels_tags_unique = {tag for tags in labels_tags for tag in tags}

        num_tags = len(labels_tags_unique)

        # Check if retraining is actually required: a document has been updated
        # since the classifier was trained, or new AUTO tags, types,
        # correspondents or storage paths exist.
        latest_doc_change = docs_queryset.latest("modified").modified
        if (
            self.last_doc_change_time is not None
            and self.last_doc_change_time >= latest_doc_change
        ) and self.last_auto_type_hash == hasher.digest():
            logger.info("No updates since last training")
            # Set the classifier information into the cache
            # Caching for 50 minutes, so slightly less than the normal retrain time
            cache.set(
                CLASSIFIER_MODIFIED_KEY,
                self.last_doc_change_time,
                CACHE_50_MINUTES,
            )
            cache.set(CLASSIFIER_HASH_KEY, hasher.hexdigest(), CACHE_50_MINUTES)
            cache.set(CLASSIFIER_VERSION_KEY, self.FORMAT_VERSION, CACHE_50_MINUTES)
            return False

        # Subtract 1 since -1 (null) is also part of the classes.

        # The union with {-1} accounts for cases where all documents have
        # correspondents and types assigned, so -1 isn't part of labels_x
        # (which it usually is).
        num_correspondents: int = len(set(labels_correspondent) | {-1}) - 1
        num_document_types: int = len(set(labels_document_type) | {-1}) - 1
        num_storage_paths: int = len(set(labels_storage_path) | {-1}) - 1

        logger.debug(
            f"{docs_queryset.count()} documents, {num_tags} tag(s), {num_correspondents} correspondent(s), "
            f"{num_document_types} document type(s), {num_storage_paths} storage path(s)",
        )

        from sklearn.feature_extraction.text import CountVectorizer
        from sklearn.neural_network import MLPClassifier
        from sklearn.preprocessing import LabelBinarizer
        from sklearn.preprocessing import MultiLabelBinarizer

        # Step 2: vectorize data
        logger.debug("Vectorizing data...")

        def content_generator() -> Iterator[str]:
            """
            Generates the preprocessed content of the documents, one at a time.
            """
            for doc in docs_queryset:
                yield self.preprocess_content(doc.content, shared_cache=False)

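        # Bag-of-words of unigrams and bigrams; a term must appear in at least
        # 1% of the documents (min_df=0.01) to be kept in the vocabulary.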
        self.data_vectorizer = CountVectorizer(
            analyzer="word",
            ngram_range=(1, 2),
            min_df=0.01,
        )

        data_vectorized: ndarray = self.data_vectorizer.fit_transform(
            content_generator(),
        )

        # See the notes here:
        # https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.CountVectorizer.html
        # This attribute isn't needed to function and can be large
        self.data_vectorizer.stop_words_ = None

        # Step 3: train the classifiers
        if num_tags > 0:
            logger.debug("Training tags classifier...")

            if num_tags == 1:
                # Special case where only one tag has auto:
                # Fallback to binary classification.
                labels_tags = [
                    label[0] if len(label) == 1 else -1 for label in labels_tags
                ]
                self.tags_binarizer = LabelBinarizer()
                labels_tags_vectorized: ndarray = self.tags_binarizer.fit_transform(
                    labels_tags,
                ).ravel()
            else:
                self.tags_binarizer = MultiLabelBinarizer()
                labels_tags_vectorized = self.tags_binarizer.fit_transform(labels_tags)

            self.tags_classifier = MLPClassifier(tol=0.01)
            self.tags_classifier.fit(data_vectorized, labels_tags_vectorized)
        else:
            self.tags_classifier = None
            logger.debug("There are no tags. Not training tags classifier.")

        if num_correspondents > 0:
            logger.debug("Training correspondent classifier...")
            self.correspondent_classifier = MLPClassifier(tol=0.01)
            self.correspondent_classifier.fit(data_vectorized, labels_correspondent)
        else:
            self.correspondent_classifier = None
            logger.debug(
                "There are no correspondents. Not training correspondent classifier.",
            )

        if num_document_types > 0:
            logger.debug("Training document type classifier...")
            self.document_type_classifier = MLPClassifier(tol=0.01)
            self.document_type_classifier.fit(data_vectorized, labels_document_type)
        else:
            self.document_type_classifier = None
            logger.debug(
                "There are no document types. Not training document type classifier.",
            )

        if num_storage_paths > 0:
            logger.debug(
                "Training storage paths classifier...",
            )
            self.storage_path_classifier = MLPClassifier(tol=0.01)
            self.storage_path_classifier.fit(
                data_vectorized,
                labels_storage_path,
            )
        else:
            self.storage_path_classifier = None
            logger.debug(
                "There are no storage paths. Not training storage path classifier.",
            )

        self.last_doc_change_time = latest_doc_change
        self.last_auto_type_hash = hasher.digest()
        self._update_data_vectorizer_hash()

        # Set the classifier information into the cache
        # Caching for 50 minutes, so slightly less than the normal retrain time
        cache.set(CLASSIFIER_MODIFIED_KEY, self.last_doc_change_time, CACHE_50_MINUTES)
        cache.set(CLASSIFIER_HASH_KEY, hasher.hexdigest(), CACHE_50_MINUTES)
        cache.set(CLASSIFIER_VERSION_KEY, self.FORMAT_VERSION, CACHE_50_MINUTES)

        return True

    def _init_advanced_text_processing(self):
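        """
        Lazily set up the NLTK stemmer and stop-word list.

        Returns True when advanced text processing is ready to use and False
        when the NLTK data could not be loaded.
        """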
        if self._stop_words is None or self._stemmer is None:
            import nltk
            from nltk.corpus import stopwords
            from nltk.stem import SnowballStemmer

            # Not really hacky, since it isn't private and is documented, but
            # set the search path for NLTK data to the single location it should be in
            nltk.data.path = [settings.NLTK_DIR]
            try:
                # Preload the corpus early, to force the lazy loader to transform
                stopwords.ensure_loaded()

                # Do some one-time setup.
                # Sometimes, somehow, there are multiple threads loading the corpus,
                # and it's not thread safe, raising an AttributeError
                self._stemmer = SnowballStemmer(settings.NLTK_LANGUAGE)
                self._stop_words = frozenset(stopwords.words(settings.NLTK_LANGUAGE))
            except AttributeError:
                logger.debug("Could not initialize NLTK for advanced text processing.")
                return False
        return True

    def stem_and_skip_stop_words(self, words: list[str], *, shared_cache=True):
        """
        Reduce a list of words to their stem. Stop words are converted to empty strings.
        :param words: the list of words to stem
        """

        def _stem_and_skip_stop_word(word: str):
            """
            Reduce a given word to its stem. If it's a stop word, return an empty string.
            E.g. "amazement", "amaze" and "amazed" all return "amaz".
            """
            cached = self._stem_cache.get(word)
            if cached is not None:
                return cached
            elif word in self._stop_words:
                return ""
            # Assumption: words that contain numbers are never stemmed
            elif RE_DIGIT.search(word):
                return word
            else:
                result = self._stemmer.stem(word)
                self._stem_cache.set(word, result)
                return result

        if shared_cache:
            self._stem_cache.load()

        # Stem the words and skip stop words
        result = " ".join(
            filter(None, (_stem_and_skip_stop_word(w) for w in words)),
        )
        if shared_cache:
            self._stem_cache.save()
        return result

    def preprocess_content(
        self,
        content: str,
        *,
        shared_cache=True,
    ) -> str:
        """
        Process the contents of a document, distilling it down into
        words which are meaningful to the content.

        The stemmer cache is shared across workers when "shared_cache" is True.
        This is unnecessary when training the classifier.
        """

        # Lower-case the document, collapse whitespace,
        # and keep only letters and digits.
        content = " ".join(match.group().lower() for match in RE_WORD.finditer(content))

        if ADVANCED_TEXT_PROCESSING_ENABLED:
            from nltk.tokenize import word_tokenize

            if not self._init_advanced_text_processing():
                return content
            # Tokenize
            # This splits the content into tokens, roughly words
            words = word_tokenize(content, language=settings.NLTK_LANGUAGE)
            # Stem the words and skip stop words
            content = self.stem_and_skip_stop_words(words, shared_cache=shared_cache)

        return content

    def _get_vectorizer_cache_key(self, content: str):
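        """
        Build a cache key from the document content plus the classifier format
        version, the NLTK settings and the current vectorizer hash, so cached
        vectorizations are never reused across a retrain or a settings change.
        """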
        hash = sha256(content.encode())
        hash.update(
            f"|{self.FORMAT_VERSION}|{settings.NLTK_LANGUAGE}|{settings.NLTK_ENABLED}|{self.data_vectorizer_hash}".encode(),
        )
        return f"vectorized_content_{hash.hexdigest()}"

    def _vectorize(self, content: str):
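        """
        Vectorize the content, using the shared read cache to avoid
        re-vectorizing the same content within a short window: cache hits are
        refreshed for another five minutes, misses are computed and stored.
        """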
        key = self._get_vectorizer_cache_key(content)
        serialized_result = read_cache.get(key)
        if serialized_result is None:
            result = self.data_vectorizer.transform([self.preprocess_content(content)])
            read_cache.set(key, pickle.dumps(result), CACHE_5_MINUTES)
        else:
            read_cache.touch(key, CACHE_5_MINUTES)
            result = pickle.loads(serialized_result)
        return result

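    # The predict_* helpers below vectorize the content and return the primary
    # key of the suggested object (a list of primary keys for tags); -1 from
    # the underlying model means "no suggestion" and maps to None / [].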
    def predict_correspondent(self, content: str) -> int | None:
        if self.correspondent_classifier:
            X = self._vectorize(content)
            correspondent_id = self.correspondent_classifier.predict(X)
            if correspondent_id != -1:
                return correspondent_id
            else:
                return None
        else:
            return None

    def predict_document_type(self, content: str) -> int | None:
        if self.document_type_classifier:
            X = self._vectorize(content)
            document_type_id = self.document_type_classifier.predict(X)
            if document_type_id != -1:
                return document_type_id
            else:
                return None
        else:
            return None

    def predict_tags(self, content: str) -> list[int]:
        from sklearn.utils.multiclass import type_of_target

        if self.tags_classifier:
            X = self._vectorize(content)
            y = self.tags_classifier.predict(X)
            tags_ids = self.tags_binarizer.inverse_transform(y)[0]
            if type_of_target(y).startswith("multilabel"):
                # the usual case when there are multiple tags.
                return list(tags_ids)
            elif type_of_target(y) == "binary" and tags_ids != -1:
                # This is for when we have binary classification with only one
                # tag and the result is to assign this tag.
                return [tags_ids]
            else:
                # Usually binary as well with -1 as the result, but we're
                # going to catch everything else here as well.
                return []
        else:
            return []

    def predict_storage_path(self, content: str) -> int | None:
        if self.storage_path_classifier:
            X = self._vectorize(content)
            storage_path_id = self.storage_path_classifier.predict(X)
            if storage_path_id != -1:
                return storage_path_id
            else:
                return None
        else:
            return None
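
# Illustrative usage (a sketch, not executed from this module; the scheduling
# of training and the consumers of the predictions live elsewhere in
# paperless-ngx):
#
#   classifier = load_classifier() or DocumentClassifier()
#   if classifier.train():
#       classifier.save()
#
#   # for some Document instance `document`:
#   suggested_tag_ids = classifier.predict_tags(document.content)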