Updates the pre-processing of document content to be much more robust, with tokenization, stemming and stop word removal

This commit is contained in:
Trenton Holmes 2022-09-15 08:39:47 -07:00 committed by Trenton H
parent 14d82bd8ff
commit d856e48045
4 changed files with 76 additions and 19 deletions

View File

@ -56,6 +56,7 @@ mysqlclient = "*"
celery = {extras = ["redis"], version = "*"}
django-celery-results = "*"
setproctitle = "*"
nltk = "*"
[dev-packages]
coveralls = "*"

8
Pipfile.lock generated
View File

@ -889,6 +889,14 @@
"index": "pypi",
"version": "==2.1.1"
},
"nltk": {
"hashes": [
"sha256:ba3de02490308b248f9b94c8bc1ac0683e9aa2ec49ee78536d8667afb5e3eec8",
"sha256:d6507d6460cec76d70afea4242a226a7542f85c669177b9c7f562b7cf1b05502"
],
"index": "pypi",
"version": "==3.7"
},
"numpy": {
"hashes": [
"sha256:07a8c89a04997625236c5ecb7afe35a02af3896c8aa01890a849913a2309c676",

View File

@ -53,6 +53,24 @@ map_folders() {
export CONSUME_DIR="${PAPERLESS_CONSUMPTION_DIR:-/usr/src/paperless/consume}"
}
nltk_data () {
# Store the NLTK data outside the Docker container
local nltk_data_dir="${DATA_DIR}/nltk"
# Download or update the snowball stemmer data
python3 -m nltk.downloader -d "${nltk_data_dir}" snowball_data
# Download or update the stopwords corpus
python3 -m nltk.downloader -d "${nltk_data_dir}" stopwords
# Download or update the punkt tokenizer data
python3 -m nltk.downloader -d "${nltk_data_dir}" punkt
# Set env so nltk can find the downloaded data
export NLTK_DATA="${nltk_data_dir}"
}
initialize() {
# Setup environment from secrets before anything else
@ -105,6 +123,8 @@ initialize() {
done
set -e
nltk_data
"${gosu_cmd[@]}" /sbin/docker-prepare.sh
}

View File

@ -5,12 +5,15 @@ import pickle
import re
import shutil
import warnings
from typing import List
from typing import Optional
from django.conf import settings
from documents.models import Document
from documents.models import MatchingModel
logger = logging.getLogger("paperless.classifier")
class IncompatibleClassifierVersionError(Exception):
pass
@ -20,15 +23,6 @@ class ClassifierModelCorruptError(Exception):
pass
logger = logging.getLogger("paperless.classifier")
def preprocess_content(content: str) -> str:
content = content.lower().strip()
content = re.sub(r"\s+", " ", content)
return content
def load_classifier() -> Optional["DocumentClassifier"]:
if not os.path.isfile(settings.MODEL_FILE):
logger.debug(
@ -81,6 +75,8 @@ class DocumentClassifier:
self.document_type_classifier = None
self.storage_path_classifier = None
self.stemmer = None
def load(self):
# Catch warnings for processing
with warnings.catch_warnings(record=True) as w:
@ -139,11 +135,11 @@ class DocumentClassifier:
def train(self):
data = list()
labels_tags = list()
labels_correspondent = list()
labels_document_type = list()
labels_storage_path = list()
data = []
labels_tags = []
labels_correspondent = []
labels_document_type = []
labels_storage_path = []
# Step 1: Extract and preprocess training data from the database.
logger.debug("Gathering data from database...")
@ -151,7 +147,7 @@ class DocumentClassifier:
for doc in Document.objects.order_by("pk").exclude(
tags__is_inbox_tag=True,
):
preprocessed_content = preprocess_content(doc.content)
preprocessed_content = self.preprocess_content(doc.content)
m.update(preprocessed_content.encode("utf-8"))
data.append(preprocessed_content)
@ -231,6 +227,11 @@ class DocumentClassifier:
)
data_vectorized = self.data_vectorizer.fit_transform(data)
# See the notes here:
# https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.CountVectorizer.html # noqa: 501
# This attribute isn't needed to function and can be large
self.data_vectorizer.stop_words_ = None
# Step 3: train the classifiers
if num_tags > 0:
logger.debug("Training tags classifier...")
@ -296,9 +297,36 @@ class DocumentClassifier:
return True
def preprocess_content(self, content: str) -> str:
"""
Process to contents of a document, distilling it down into
words which are meaningful to the content
"""
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from nltk.stem import SnowballStemmer
if self.stemmer is None:
self.stemmer = SnowballStemmer("english")
# Lower case the document
content = content.lower().strip()
# Get only the letters (remove punctuation too)
content = re.sub(r"[^\w\s]", " ", content)
# Tokenize
# TODO configurable language
words: List[str] = word_tokenize(content, language="english")
# Remove stop words
stops = set(stopwords.words("english"))
meaningful_words = [w for w in words if w not in stops]
# Stem words
meaningful_words = [self.stemmer.stem(w) for w in meaningful_words]
return " ".join(meaningful_words)
def predict_correspondent(self, content):
if self.correspondent_classifier:
X = self.data_vectorizer.transform([preprocess_content(content)])
X = self.data_vectorizer.transform([self.preprocess_content(content)])
correspondent_id = self.correspondent_classifier.predict(X)
if correspondent_id != -1:
return correspondent_id
@ -309,7 +337,7 @@ class DocumentClassifier:
def predict_document_type(self, content):
if self.document_type_classifier:
X = self.data_vectorizer.transform([preprocess_content(content)])
X = self.data_vectorizer.transform([self.preprocess_content(content)])
document_type_id = self.document_type_classifier.predict(X)
if document_type_id != -1:
return document_type_id
@ -322,7 +350,7 @@ class DocumentClassifier:
from sklearn.utils.multiclass import type_of_target
if self.tags_classifier:
X = self.data_vectorizer.transform([preprocess_content(content)])
X = self.data_vectorizer.transform([self.preprocess_content(content)])
y = self.tags_classifier.predict(X)
tags_ids = self.tags_binarizer.inverse_transform(y)[0]
if type_of_target(y).startswith("multilabel"):
@ -341,7 +369,7 @@ class DocumentClassifier:
def predict_storage_path(self, content):
if self.storage_path_classifier:
X = self.data_vectorizer.transform([preprocess_content(content)])
X = self.data_vectorizer.transform([self.preprocess_content(content)])
storage_path_id = self.storage_path_classifier.predict(X)
if storage_path_id != -1:
return storage_path_id