Mirror of https://github.com/paperless-ngx/paperless-ngx.git (synced 2025-04-02 13:45:10 -05:00)
Updates the pre-processing of document content to be much more robust, with tokenization, stemming and stop word removal
commit d856e48045 (parent 14d82bd8ff)
Pipfile (1 line changed)
@@ -56,6 +56,7 @@ mysqlclient = "*"
 celery = {extras = ["redis"], version = "*"}
 django-celery-results = "*"
 setproctitle = "*"
+nltk = "*"
 
 [dev-packages]
 coveralls = "*"
Pipfile.lock (generated, 8 lines changed)
@@ -889,6 +889,14 @@
             "index": "pypi",
             "version": "==2.1.1"
         },
+        "nltk": {
+            "hashes": [
+                "sha256:ba3de02490308b248f9b94c8bc1ac0683e9aa2ec49ee78536d8667afb5e3eec8",
+                "sha256:d6507d6460cec76d70afea4242a226a7542f85c669177b9c7f562b7cf1b05502"
+            ],
+            "index": "pypi",
+            "version": "==3.7"
+        },
         "numpy": {
             "hashes": [
                 "sha256:07a8c89a04997625236c5ecb7afe35a02af3896c8aa01890a849913a2309c676",
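The two lock-file hashes pin nltk to 3.7. As a purely illustrative sanity check (not part of the commit), the installed version can be confirmed at runtime:

    import nltk

    # The lock file above pins nltk to 3.7; this is only an illustrative check.
    print(nltk.__version__)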
@@ -53,6 +53,24 @@ map_folders() {
     export CONSUME_DIR="${PAPERLESS_CONSUMPTION_DIR:-/usr/src/paperless/consume}"
 }
 
+nltk_data () {
+    # Store the NLTK data outside the Docker container
+    local nltk_data_dir="${DATA_DIR}/nltk"
+
+    # Download or update the snowball stemmer data
+    python3 -m nltk.downloader -d "${nltk_data_dir}" snowball_data
+
+    # Download or update the stopwords corpus
+    python3 -m nltk.downloader -d "${nltk_data_dir}" stopwords
+
+    # Download or update the punkt tokenizer data
+    python3 -m nltk.downloader -d "${nltk_data_dir}" punkt
+
+    # Set env so nltk can find the downloaded data
+    export NLTK_DATA="${nltk_data_dir}"
+
+}
+
 initialize() {
 
     # Setup environment from secrets before anything else
@@ -105,6 +123,8 @@ initialize() {
     done
     set -e
 
+    nltk_data
+
     "${gosu_cmd[@]}" /sbin/docker-prepare.sh
 }
 
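For context, a minimal Python sketch (not part of the commit) of how the downloaded data is found at runtime: NLTK reads the NLTK_DATA environment variable exported above and adds that directory to its data search path.

    import nltk.data

    # NLTK_DATA="${DATA_DIR}/nltk" (exported by nltk_data above) is picked up when
    # nltk is imported and appears in the data search path.
    print(nltk.data.path)

    # Each find() raises LookupError if the corresponding download is missing.
    nltk.data.find("corpora/stopwords")
    nltk.data.find("tokenizers/punkt")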
@@ -5,12 +5,15 @@ import pickle
 import re
 import shutil
 import warnings
+from typing import List
 from typing import Optional
 
 from django.conf import settings
 from documents.models import Document
 from documents.models import MatchingModel
 
+logger = logging.getLogger("paperless.classifier")
+
 
 class IncompatibleClassifierVersionError(Exception):
     pass
@@ -20,15 +23,6 @@ class ClassifierModelCorruptError(Exception):
     pass
 
 
-logger = logging.getLogger("paperless.classifier")
-
-
-def preprocess_content(content: str) -> str:
-    content = content.lower().strip()
-    content = re.sub(r"\s+", " ", content)
-    return content
-
-
 def load_classifier() -> Optional["DocumentClassifier"]:
     if not os.path.isfile(settings.MODEL_FILE):
         logger.debug(
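For comparison, the module-level helper removed above only lowercased the text and collapsed whitespace; a standalone sketch with an invented input:

    import re

    def old_preprocess_content(content: str) -> str:
        # Behaviour of the removed helper: lowercase and collapse whitespace only.
        content = content.lower().strip()
        content = re.sub(r"\s+", " ", content)
        return content

    print(old_preprocess_content("The  Invoices\nwere PAID!"))  # "the invoices were paid!"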
@@ -81,6 +75,8 @@ class DocumentClassifier:
         self.document_type_classifier = None
         self.storage_path_classifier = None
 
+        self.stemmer = None
+
     def load(self):
         # Catch warnings for processing
         with warnings.catch_warnings(record=True) as w:
@@ -139,11 +135,11 @@
 
     def train(self):
 
-        data = list()
-        labels_tags = list()
-        labels_correspondent = list()
-        labels_document_type = list()
-        labels_storage_path = list()
+        data = []
+        labels_tags = []
+        labels_correspondent = []
+        labels_document_type = []
+        labels_storage_path = []
 
         # Step 1: Extract and preprocess training data from the database.
         logger.debug("Gathering data from database...")
@@ -151,7 +147,7 @@
         for doc in Document.objects.order_by("pk").exclude(
             tags__is_inbox_tag=True,
         ):
-            preprocessed_content = preprocess_content(doc.content)
+            preprocessed_content = self.preprocess_content(doc.content)
             m.update(preprocessed_content.encode("utf-8"))
             data.append(preprocessed_content)
 
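A hedged sketch of the hashing pattern in this loop: m's construction is not shown in the hunk (hashlib.sha1 is an assumption here), but each document's preprocessed text feeds a running digest, so any change in the normalized training content changes the final hash.

    import hashlib

    # Illustration only: a running digest over preprocessed training documents.
    m = hashlib.sha1()
    for text in ("first preprocessed document", "second preprocessed document"):
        m.update(text.encode("utf-8"))
    print(m.hexdigest())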
@@ -231,6 +227,11 @@
         )
         data_vectorized = self.data_vectorizer.fit_transform(data)
 
+        # See the notes here:
+        # https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.CountVectorizer.html  # noqa: 501
+        # This attribute isn't needed to function and can be large
+        self.data_vectorizer.stop_words_ = None
+
         # Step 3: train the classifiers
         if num_tags > 0:
             logger.debug("Training tags classifier...")
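The attribute being cleared is scikit-learn's CountVectorizer.stop_words_, which records terms dropped by min_df/max_df/max_features and exists only for introspection, so setting it to None before pickling shrinks the saved model without changing transform(). A small standalone sketch with an invented corpus:

    from sklearn.feature_extraction.text import CountVectorizer

    vectorizer = CountVectorizer(min_df=2)
    vectorizer.fit_transform(["some documents here", "more documents here"])
    print(vectorizer.stop_words_)   # {'some', 'more'}: terms cut by min_df, kept only for introspection
    vectorizer.stop_words_ = None   # safe to drop before pickling; transform() does not use it
    print(vectorizer.transform(["more documents"]).toarray())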
@@ -296,9 +297,36 @@
 
         return True
 
+    def preprocess_content(self, content: str) -> str:
+        """
+        Process the contents of a document, distilling it down into
+        words which are meaningful to the content.
+        """
+        from nltk.tokenize import word_tokenize
+        from nltk.corpus import stopwords
+        from nltk.stem import SnowballStemmer
+
+        if self.stemmer is None:
+            self.stemmer = SnowballStemmer("english")
+
+        # Lower case the document
+        content = content.lower().strip()
+        # Get only the letters (remove punctuation too)
+        content = re.sub(r"[^\w\s]", " ", content)
+        # Tokenize
+        # TODO: configurable language
+        words: List[str] = word_tokenize(content, language="english")
+        # Remove stop words
+        stops = set(stopwords.words("english"))
+        meaningful_words = [w for w in words if w not in stops]
+        # Stem words
+        meaningful_words = [self.stemmer.stem(w) for w in meaningful_words]
+
+        return " ".join(meaningful_words)
+
     def predict_correspondent(self, content):
         if self.correspondent_classifier:
-            X = self.data_vectorizer.transform([preprocess_content(content)])
+            X = self.data_vectorizer.transform([self.preprocess_content(content)])
             correspondent_id = self.correspondent_classifier.predict(X)
             if correspondent_id != -1:
                 return correspondent_id
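An illustrative sketch of what the new preprocessing yields (the input string is invented; the NLTK pieces are the same ones the method imports):

    import re

    from nltk.corpus import stopwords
    from nltk.stem import SnowballStemmer
    from nltk.tokenize import word_tokenize

    text = "The invoices were paid on 2021-03-01!"
    text = re.sub(r"[^\w\s]", " ", text.lower().strip())
    words = word_tokenize(text, language="english")
    stops = set(stopwords.words("english"))
    stemmer = SnowballStemmer("english")
    print(" ".join(stemmer.stem(w) for w in words if w not in stops))
    # roughly: "invoic paid 2021 03 01" (stop words dropped, remaining words stemmed)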
@@ -309,7 +337,7 @@
 
     def predict_document_type(self, content):
         if self.document_type_classifier:
-            X = self.data_vectorizer.transform([preprocess_content(content)])
+            X = self.data_vectorizer.transform([self.preprocess_content(content)])
             document_type_id = self.document_type_classifier.predict(X)
             if document_type_id != -1:
                 return document_type_id
@@ -322,7 +350,7 @@
         from sklearn.utils.multiclass import type_of_target
 
         if self.tags_classifier:
-            X = self.data_vectorizer.transform([preprocess_content(content)])
+            X = self.data_vectorizer.transform([self.preprocess_content(content)])
             y = self.tags_classifier.predict(X)
             tags_ids = self.tags_binarizer.inverse_transform(y)[0]
             if type_of_target(y).startswith("multilabel"):
@@ -341,7 +369,7 @@
 
     def predict_storage_path(self, content):
         if self.storage_path_classifier:
-            X = self.data_vectorizer.transform([preprocess_content(content)])
+            X = self.data_vectorizer.transform([self.preprocess_content(content)])
             storage_path_id = self.storage_path_classifier.predict(X)
             if storage_path_id != -1:
                 return storage_path_id
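Taken together, a hedged sketch of the prediction path after this change (requires a configured Django environment and a trained model file; the document text is invented):

    from documents.classifier import load_classifier

    classifier = load_classifier()
    if classifier is not None:
        # Every predict_* call now runs the same tokenize / stop-word / stem
        # preprocessing before the vectorizer sees the text.
        document_type_id = classifier.predict_document_type("ACME Corp invoice no. 42 ...")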