Updates the pre-processing of document content to be much more robust, with tokenization, stemming and stop word removal
This commit is contained in:
parent 14d82bd8ff
commit d856e48045
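
For orientation before the diff: the new preprocessing lower-cases the text, strips punctuation, tokenizes it, drops English stop words, and stems what remains. Below is a minimal standalone sketch of that pipeline, not the committed method itself; it assumes the NLTK punkt, stopwords and snowball_data packages are already downloaded.

import re

from nltk.corpus import stopwords
from nltk.stem import SnowballStemmer
from nltk.tokenize import word_tokenize

stemmer = SnowballStemmer("english")
stops = set(stopwords.words("english"))


def preprocess(content: str) -> str:
    # Lower-case and collapse punctuation to spaces, mirroring the commit
    content = re.sub(r"[^\w\s]", " ", content.lower().strip())
    # Tokenize, drop English stop words, then stem each remaining word
    words = word_tokenize(content, language="english")
    meaningful = [stemmer.stem(w) for w in words if w not in stops]
    return " ".join(meaningful)


print(preprocess("The invoices were paid on 2022-05-01."))
# -> roughly "invoic paid 2022 05 01" (exact stems depend on the Snowball data)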
Pipfile | 1 +

@@ -56,6 +56,7 @@ mysqlclient = "*"
 celery = {extras = ["redis"], version = "*"}
 django-celery-results = "*"
 setproctitle = "*"
+nltk = "*"
 
 [dev-packages]
 coveralls = "*"
Pipfile.lock (generated) | 8 +

@@ -889,6 +889,14 @@
             "index": "pypi",
             "version": "==2.1.1"
         },
+        "nltk": {
+            "hashes": [
+                "sha256:ba3de02490308b248f9b94c8bc1ac0683e9aa2ec49ee78536d8667afb5e3eec8",
+                "sha256:d6507d6460cec76d70afea4242a226a7542f85c669177b9c7f562b7cf1b05502"
+            ],
+            "index": "pypi",
+            "version": "==3.7"
+        },
         "numpy": {
             "hashes": [
                 "sha256:07a8c89a04997625236c5ecb7afe35a02af3896c8aa01890a849913a2309c676",
@@ -53,6 +53,24 @@ map_folders() {
     export CONSUME_DIR="${PAPERLESS_CONSUMPTION_DIR:-/usr/src/paperless/consume}"
 }
 
+nltk_data () {
+    # Store the NLTK data outside the Docker container
+    local nltk_data_dir="${DATA_DIR}/nltk"
+
+    # Download or update the snowball stemmer data
+    python3 -m nltk.downloader -d "${nltk_data_dir}" snowball_data
+
+    # Download or update the stopwords corpus
+    python3 -m nltk.downloader -d "${nltk_data_dir}" stopwords
+
+    # Download or update the punkt tokenizer data
+    python3 -m nltk.downloader -d "${nltk_data_dir}" punkt
+
+    # Set env so nltk can find the downloaded data
+    export NLTK_DATA="${nltk_data_dir}"
+
+}
+
 initialize() {
 
     # Setup environment from secrets before anything else
 
@@ -105,6 +123,8 @@ initialize() {
     done
     set -e
 
+    nltk_data
+
     "${gosu_cmd[@]}" /sbin/docker-prepare.sh
 }
 
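
The nltk_data() helper above only covers the Docker entrypoint. As a rough equivalent for a bare-metal install (the data directory below is illustrative and not part of this commit), the same packages can be fetched from Python and exposed via NLTK_DATA:

import os

import nltk

# Hypothetical data directory for a non-Docker setup
nltk_data_dir = os.path.expanduser("~/paperless-data/nltk")

# Download or update the same packages the entrypoint fetches
for package in ("snowball_data", "stopwords", "punkt"):
    nltk.download(package, download_dir=nltk_data_dir)

# Point NLTK at the directory before the classifier runs
os.environ["NLTK_DATA"] = nltk_data_dir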
@@ -5,12 +5,15 @@ import pickle
 import re
 import shutil
 import warnings
+from typing import List
 from typing import Optional
 
 from django.conf import settings
 from documents.models import Document
 from documents.models import MatchingModel
 
+logger = logging.getLogger("paperless.classifier")
+
 
 class IncompatibleClassifierVersionError(Exception):
     pass
@@ -20,15 +23,6 @@ class ClassifierModelCorruptError(Exception):
     pass
 
 
-logger = logging.getLogger("paperless.classifier")
-
-
-def preprocess_content(content: str) -> str:
-    content = content.lower().strip()
-    content = re.sub(r"\s+", " ", content)
-    return content
-
-
 def load_classifier() -> Optional["DocumentClassifier"]:
     if not os.path.isfile(settings.MODEL_FILE):
         logger.debug(
@@ -81,6 +75,8 @@ class DocumentClassifier:
         self.document_type_classifier = None
         self.storage_path_classifier = None
 
+        self.stemmer = None
+
     def load(self):
         # Catch warnings for processing
         with warnings.catch_warnings(record=True) as w:
@@ -139,11 +135,11 @@ class DocumentClassifier:
 
     def train(self):
 
-        data = list()
-        labels_tags = list()
-        labels_correspondent = list()
-        labels_document_type = list()
-        labels_storage_path = list()
+        data = []
+        labels_tags = []
+        labels_correspondent = []
+        labels_document_type = []
+        labels_storage_path = []
 
         # Step 1: Extract and preprocess training data from the database.
         logger.debug("Gathering data from database...")
@@ -151,7 +147,7 @@ class DocumentClassifier:
         for doc in Document.objects.order_by("pk").exclude(
             tags__is_inbox_tag=True,
         ):
-            preprocessed_content = preprocess_content(doc.content)
+            preprocessed_content = self.preprocess_content(doc.content)
             m.update(preprocessed_content.encode("utf-8"))
             data.append(preprocessed_content)
 
@@ -231,6 +227,11 @@ class DocumentClassifier:
         )
         data_vectorized = self.data_vectorizer.fit_transform(data)
 
+        # See the notes here:
+        # https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.CountVectorizer.html  # noqa: 501
+        # This attribute isn't needed to function and can be large
+        self.data_vectorizer.stop_words_ = None
+
         # Step 3: train the classifiers
         if num_tags > 0:
             logger.debug("Training tags classifier...")
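
On the stop_words_ line above: scikit-learn's CountVectorizer stores the terms it discarded (via max_df, min_df, or max_features) in the fitted stop_words_ attribute, which exists only for introspection and can be large when the model is pickled. A small sketch of the effect, assuming scikit-learn is installed:

from sklearn.feature_extraction.text import CountVectorizer

vectorizer = CountVectorizer(max_features=3)
vectorizer.fit(["the invoice was paid", "the invoice is overdue"])
print(sorted(vectorizer.stop_words_))  # terms dropped because of max_features
vectorizer.stop_words_ = None          # safe to clear before pickling the model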
@@ -296,9 +297,36 @@ class DocumentClassifier:
 
         return True
 
+    def preprocess_content(self, content: str) -> str:
+        """
+        Process the contents of a document, distilling it down into
+        words which are meaningful to the content
+        """
+        from nltk.tokenize import word_tokenize
+        from nltk.corpus import stopwords
+        from nltk.stem import SnowballStemmer
+
+        if self.stemmer is None:
+            self.stemmer = SnowballStemmer("english")
+
+        # Lower case the document
+        content = content.lower().strip()
+        # Get only the letters (remove punctuation too)
+        content = re.sub(r"[^\w\s]", " ", content)
+        # Tokenize
+        # TODO configurable language
+        words: List[str] = word_tokenize(content, language="english")
+        # Remove stop words
+        stops = set(stopwords.words("english"))
+        meaningful_words = [w for w in words if w not in stops]
+        # Stem words
+        meaningful_words = [self.stemmer.stem(w) for w in meaningful_words]
+
+        return " ".join(meaningful_words)
+
     def predict_correspondent(self, content):
         if self.correspondent_classifier:
-            X = self.data_vectorizer.transform([preprocess_content(content)])
+            X = self.data_vectorizer.transform([self.preprocess_content(content)])
             correspondent_id = self.correspondent_classifier.predict(X)
             if correspondent_id != -1:
                 return correspondent_id
@@ -309,7 +337,7 @@ class DocumentClassifier:
 
     def predict_document_type(self, content):
         if self.document_type_classifier:
-            X = self.data_vectorizer.transform([preprocess_content(content)])
+            X = self.data_vectorizer.transform([self.preprocess_content(content)])
             document_type_id = self.document_type_classifier.predict(X)
             if document_type_id != -1:
                 return document_type_id
@@ -322,7 +350,7 @@ class DocumentClassifier:
         from sklearn.utils.multiclass import type_of_target
 
         if self.tags_classifier:
-            X = self.data_vectorizer.transform([preprocess_content(content)])
+            X = self.data_vectorizer.transform([self.preprocess_content(content)])
             y = self.tags_classifier.predict(X)
             tags_ids = self.tags_binarizer.inverse_transform(y)[0]
             if type_of_target(y).startswith("multilabel"):
@@ -341,7 +369,7 @@ class DocumentClassifier:
 
     def predict_storage_path(self, content):
         if self.storage_path_classifier:
-            X = self.data_vectorizer.transform([preprocess_content(content)])
+            X = self.data_vectorizer.transform([self.preprocess_content(content)])
            storage_path_id = self.storage_path_classifier.predict(X)
             if storage_path_id != -1:
                 return storage_path_id
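
With these changes, every predict_* path vectorizes the same stemmed, stop-word-filtered text the model was trained on. A hypothetical caller inside the paperless application context (Django settings configured, a trained model on disk):

from documents.classifier import load_classifier

classifier = load_classifier()  # returns None when no usable model file exists
if classifier is not None:
    text = "Your invoice is overdue, please arrange payment."
    print(classifier.predict_document_type(text))
    print(classifier.predict_tags(text))
    print(classifier.predict_correspondent(text))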