Mirror of https://github.com/paperless-ngx/paperless-ngx.git
Updates the pre-processing of document content to be much more robust, with tokenization, stemming and stop word removal
Committed by Trenton H

Parent: 77a3f8ed60
Commit: 66884ea035
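
The diff below swaps the old whitespace-only normalization for an NLTK based pipeline: lowercase the text, strip punctuation, tokenize, drop English stop words, and stem the remaining words. A rough standalone sketch of that pipeline, assuming NLTK is installed and its punkt and stopwords corpora have been downloaded (the module-level function here is only for illustration; the actual change adds this logic as a DocumentClassifier method):

import re

from nltk.corpus import stopwords
from nltk.stem import SnowballStemmer
from nltk.tokenize import word_tokenize

# Assumes a one-time nltk.download("punkt") and nltk.download("stopwords")
# so the tokenizer models and stop word lists are available locally.
_stemmer = SnowballStemmer("english")
_stops = set(stopwords.words("english"))

def preprocess(content: str) -> str:
    # Lowercase, then replace punctuation with whitespace
    content = re.sub(r"[^\w\s]", " ", content.lower().strip())
    # Tokenize, drop English stop words, stem what remains
    words = word_tokenize(content, language="english")
    return " ".join(_stemmer.stem(w) for w in words if w not in _stops)

In the actual change the stemmer is cached on the instance (self.stemmer) and the nltk imports live inside the method, so the dependency cost is only paid when classification is actually used.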
@@ -5,12 +5,15 @@ import pickle
 import re
 import shutil
 import warnings
+from typing import List
 from typing import Optional
 
 from django.conf import settings
 from documents.models import Document
 from documents.models import MatchingModel
 
+logger = logging.getLogger("paperless.classifier")
+
 
 class IncompatibleClassifierVersionError(Exception):
     pass
@@ -20,15 +23,6 @@ class ClassifierModelCorruptError(Exception):
     pass
 
 
-logger = logging.getLogger("paperless.classifier")
-
-
-def preprocess_content(content: str) -> str:
-    content = content.lower().strip()
-    content = re.sub(r"\s+", " ", content)
-    return content
-
-
 def load_classifier() -> Optional["DocumentClassifier"]:
     if not os.path.isfile(settings.MODEL_FILE):
         logger.debug(
@@ -81,6 +75,8 @@ class DocumentClassifier:
         self.document_type_classifier = None
         self.storage_path_classifier = None
 
+        self.stemmer = None
+
     def load(self):
         # Catch warnings for processing
         with warnings.catch_warnings(record=True) as w:
@@ -139,11 +135,11 @@ class DocumentClassifier:
 
     def train(self):
 
-        data = list()
-        labels_tags = list()
-        labels_correspondent = list()
-        labels_document_type = list()
-        labels_storage_path = list()
+        data = []
+        labels_tags = []
+        labels_correspondent = []
+        labels_document_type = []
+        labels_storage_path = []
 
         # Step 1: Extract and preprocess training data from the database.
         logger.debug("Gathering data from database...")
@@ -151,7 +147,7 @@ class DocumentClassifier:
         for doc in Document.objects.order_by("pk").exclude(
             tags__is_inbox_tag=True,
         ):
-            preprocessed_content = preprocess_content(doc.content)
+            preprocessed_content = self.preprocess_content(doc.content)
             m.update(preprocessed_content.encode("utf-8"))
             data.append(preprocessed_content)
 
@@ -231,6 +227,11 @@ class DocumentClassifier:
         )
         data_vectorized = self.data_vectorizer.fit_transform(data)
 
+        # See the notes here:
+        # https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.CountVectorizer.html # noqa: 501
+        # This attribute isn't needed to function and can be large
+        self.data_vectorizer.stop_words_ = None
+
         # Step 3: train the classifiers
         if num_tags > 0:
             logger.debug("Training tags classifier...")
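
A note on the data_vectorizer.stop_words_ = None line added in the hunk above: per the scikit-learn documentation linked in the comment, CountVectorizer.stop_words_ only records the terms that were ignored during fitting (because of max_df, min_df or max_features), exists purely for introspection, and can be set to None or deleted before pickling without affecting transform(). A minimal illustration, with made-up document strings:

from sklearn.feature_extraction.text import CountVectorizer

vectorizer = CountVectorizer(min_df=2)
vectorizer.fit_transform(
    ["water invoice", "electricity invoice", "final reminder"],
)

# Terms dropped by min_df were recorded in stop_words_; transform() never
# reads it, so clearing it shrinks the pickled model without changing output.
vectorizer.stop_words_ = None
X = vectorizer.transform(["another invoice"])
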
@@ -296,9 +297,36 @@ class DocumentClassifier:
 
         return True
 
+    def preprocess_content(self, content: str) -> str:
+        """
+        Process to contents of a document, distilling it down into
+        words which are meaningful to the content
+        """
+        from nltk.tokenize import word_tokenize
+        from nltk.corpus import stopwords
+        from nltk.stem import SnowballStemmer
+
+        if self.stemmer is None:
+            self.stemmer = SnowballStemmer("english")
+
+        # Lower case the document
+        content = content.lower().strip()
+        # Get only the letters (remove punctuation too)
+        content = re.sub(r"[^\w\s]", " ", content)
+        # Tokenize
+        # TODO configurable language
+        words: List[str] = word_tokenize(content, language="english")
+        # Remove stop words
+        stops = set(stopwords.words("english"))
+        meaningful_words = [w for w in words if w not in stops]
+        # Stem words
+        meaningful_words = [self.stemmer.stem(w) for w in meaningful_words]
+
+        return " ".join(meaningful_words)
+
     def predict_correspondent(self, content):
         if self.correspondent_classifier:
-            X = self.data_vectorizer.transform([preprocess_content(content)])
+            X = self.data_vectorizer.transform([self.preprocess_content(content)])
             correspondent_id = self.correspondent_classifier.predict(X)
             if correspondent_id != -1:
                 return correspondent_id
@@ -309,7 +337,7 @@ class DocumentClassifier:
 
     def predict_document_type(self, content):
         if self.document_type_classifier:
-            X = self.data_vectorizer.transform([preprocess_content(content)])
+            X = self.data_vectorizer.transform([self.preprocess_content(content)])
             document_type_id = self.document_type_classifier.predict(X)
             if document_type_id != -1:
                 return document_type_id
@@ -322,7 +350,7 @@ class DocumentClassifier:
         from sklearn.utils.multiclass import type_of_target
 
         if self.tags_classifier:
-            X = self.data_vectorizer.transform([preprocess_content(content)])
+            X = self.data_vectorizer.transform([self.preprocess_content(content)])
             y = self.tags_classifier.predict(X)
             tags_ids = self.tags_binarizer.inverse_transform(y)[0]
             if type_of_target(y).startswith("multilabel"):
@@ -341,7 +369,7 @@ class DocumentClassifier:
 
     def predict_storage_path(self, content):
         if self.storage_path_classifier:
-            X = self.data_vectorizer.transform([preprocess_content(content)])
+            X = self.data_vectorizer.transform([self.preprocess_content(content)])
             storage_path_id = self.storage_path_classifier.predict(X)
             if storage_path_id != -1:
                 return storage_path_id