Updates the pre-processing of document content to be much more robust, with tokenization, stemming, and stop word removal

Trenton Holmes 2022-09-15 08:39:47 -07:00 committed by Trenton H
parent 14d82bd8ff
commit d856e48045
4 changed files with 76 additions and 19 deletions

Pipfile

@@ -56,6 +56,7 @@ mysqlclient = "*"
 celery = {extras = ["redis"], version = "*"}
 django-celery-results = "*"
 setproctitle = "*"
+nltk = "*"
 
 [dev-packages]
 coveralls = "*"

Pipfile.lock (generated)

@@ -889,6 +889,14 @@
             "index": "pypi",
             "version": "==2.1.1"
         },
+        "nltk": {
+            "hashes": [
+                "sha256:ba3de02490308b248f9b94c8bc1ac0683e9aa2ec49ee78536d8667afb5e3eec8",
+                "sha256:d6507d6460cec76d70afea4242a226a7542f85c669177b9c7f562b7cf1b05502"
+            ],
+            "index": "pypi",
+            "version": "==3.7"
+        },
         "numpy": {
             "hashes": [
                 "sha256:07a8c89a04997625236c5ecb7afe35a02af3896c8aa01890a849913a2309c676",

docker/docker-entrypoint.sh

@@ -53,6 +53,24 @@ map_folders() {
     export CONSUME_DIR="${PAPERLESS_CONSUMPTION_DIR:-/usr/src/paperless/consume}"
 }
 
+nltk_data () {
+    # Store the NLTK data outside the Docker container
+    local nltk_data_dir="${DATA_DIR}/nltk"
+
+    # Download or update the snowball stemmer data
+    python3 -m nltk.downloader -d "${nltk_data_dir}" snowball_data
+
+    # Download or update the stopwords corpus
+    python3 -m nltk.downloader -d "${nltk_data_dir}" stopwords
+
+    # Download or update the punkt tokenizer data
+    python3 -m nltk.downloader -d "${nltk_data_dir}" punkt
+
+    # Set env so nltk can find the downloaded data
+    export NLTK_DATA="${nltk_data_dir}"
+}
+
 initialize() {
     # Setup environment from secrets before anything else
 
@@ -105,6 +123,8 @@ initialize() {
     done
     set -e
 
+    nltk_data
+
     "${gosu_cmd[@]}" /sbin/docker-prepare.sh
 }
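
Aside (not part of the commit): once the entrypoint has exported NLTK_DATA, the three downloads should resolve through NLTK's normal search path. A minimal Python sketch for checking that, assuming the standard on-disk layout NLTK uses for these resources (stemmers/, corpora/, tokenizers/):

    import nltk

    # nltk.data.find() raises LookupError when a resource is absent from
    # every directory on the search path (which includes $NLTK_DATA).
    for resource in ("stemmers/snowball_data", "corpora/stopwords", "tokenizers/punkt"):
        try:
            print(resource, "->", nltk.data.find(resource))
        except LookupError:
            print(resource, "is missing; re-run the downloader")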

src/documents/classifier.py

@@ -5,12 +5,15 @@ import pickle
 import re
 import shutil
 import warnings
+from typing import List
 from typing import Optional
 
 from django.conf import settings
 
 from documents.models import Document
 from documents.models import MatchingModel
 
+logger = logging.getLogger("paperless.classifier")
+
 
 class IncompatibleClassifierVersionError(Exception):
     pass
 
@@ -20,15 +23,6 @@ class ClassifierModelCorruptError(Exception):
     pass
 
 
-logger = logging.getLogger("paperless.classifier")
-
-
-def preprocess_content(content: str) -> str:
-    content = content.lower().strip()
-    content = re.sub(r"\s+", " ", content)
-    return content
-
-
 def load_classifier() -> Optional["DocumentClassifier"]:
     if not os.path.isfile(settings.MODEL_FILE):
         logger.debug(
 
@@ -81,6 +75,8 @@ class DocumentClassifier:
         self.document_type_classifier = None
         self.storage_path_classifier = None
 
+        self.stemmer = None
+
     def load(self):
         # Catch warnings for processing
         with warnings.catch_warnings(record=True) as w:
 
@@ -139,11 +135,11 @@ class DocumentClassifier:
     def train(self):
-        data = list()
-        labels_tags = list()
-        labels_correspondent = list()
-        labels_document_type = list()
-        labels_storage_path = list()
+        data = []
+        labels_tags = []
+        labels_correspondent = []
+        labels_document_type = []
+        labels_storage_path = []
 
         # Step 1: Extract and preprocess training data from the database.
         logger.debug("Gathering data from database...")
 
@@ -151,7 +147,7 @@ class DocumentClassifier:
         for doc in Document.objects.order_by("pk").exclude(
             tags__is_inbox_tag=True,
         ):
-            preprocessed_content = preprocess_content(doc.content)
+            preprocessed_content = self.preprocess_content(doc.content)
             m.update(preprocessed_content.encode("utf-8"))
             data.append(preprocessed_content)
 
@@ -231,6 +227,11 @@ class DocumentClassifier:
         )
         data_vectorized = self.data_vectorizer.fit_transform(data)
 
+        # See the notes here:
+        # https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.CountVectorizer.html # noqa: 501
+        # This attribute isn't needed to function and can be large
+        self.data_vectorizer.stop_words_ = None
+
         # Step 3: train the classifiers
         if num_tags > 0:
             logger.debug("Training tags classifier...")
@@ -296,9 +297,36 @@ class DocumentClassifier:
 
         return True
 
+    def preprocess_content(self, content: str) -> str:
+        """
+        Processes the content of a document, distilling it down into
+        words which are meaningful to the content
+        """
+        from nltk.tokenize import word_tokenize
+        from nltk.corpus import stopwords
+        from nltk.stem import SnowballStemmer
+
+        if self.stemmer is None:
+            self.stemmer = SnowballStemmer("english")
+
+        # Lower case the document
+        content = content.lower().strip()
+        # Get only the letters (remove punctuation too)
+        content = re.sub(r"[^\w\s]", " ", content)
+        # Tokenize
+        # TODO: configurable language
+        words: List[str] = word_tokenize(content, language="english")
+        # Remove stop words
+        stops = set(stopwords.words("english"))
+        meaningful_words = [w for w in words if w not in stops]
+        # Stem words
+        meaningful_words = [self.stemmer.stem(w) for w in meaningful_words]
+
+        return " ".join(meaningful_words)
+
     def predict_correspondent(self, content):
         if self.correspondent_classifier:
-            X = self.data_vectorizer.transform([preprocess_content(content)])
+            X = self.data_vectorizer.transform([self.preprocess_content(content)])
             correspondent_id = self.correspondent_classifier.predict(X)
             if correspondent_id != -1:
                 return correspondent_id
 
@@ -309,7 +337,7 @@ class DocumentClassifier:
     def predict_document_type(self, content):
         if self.document_type_classifier:
-            X = self.data_vectorizer.transform([preprocess_content(content)])
+            X = self.data_vectorizer.transform([self.preprocess_content(content)])
             document_type_id = self.document_type_classifier.predict(X)
             if document_type_id != -1:
                 return document_type_id
 
@@ -322,7 +350,7 @@ class DocumentClassifier:
         from sklearn.utils.multiclass import type_of_target
 
         if self.tags_classifier:
-            X = self.data_vectorizer.transform([preprocess_content(content)])
+            X = self.data_vectorizer.transform([self.preprocess_content(content)])
             y = self.tags_classifier.predict(X)
             tags_ids = self.tags_binarizer.inverse_transform(y)[0]
             if type_of_target(y).startswith("multilabel"):
 
@@ -341,7 +369,7 @@ class DocumentClassifier:
     def predict_storage_path(self, content):
         if self.storage_path_classifier:
-            X = self.data_vectorizer.transform([preprocess_content(content)])
+            X = self.data_vectorizer.transform([self.preprocess_content(content)])
            storage_path_id = self.storage_path_classifier.predict(X)
             if storage_path_id != -1:
                 return storage_path_id
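
Taken together, preprocess_content now lower-cases, strips punctuation, tokenizes, drops English stop words, and stems before anything reaches the vectorizer. A quick illustration on a made-up string (assumes the NLTK data downloaded by the entrypoint is available; exact tokens can vary with the corpus version):

    classifier = DocumentClassifier()
    print(classifier.preprocess_content("The quick brown foxes were running!"))
    # -> "quick brown fox run"
    # "the" and "were" are removed as stop words; the Snowball stemmer
    # reduces "foxes" to "fox" and "running" to "run".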