Updates the pre-processing of document content to be much more robust, with tokenization, stemming and stop word removal

This commit is contained in:
Trenton Holmes
2022-09-15 08:39:47 -07:00
committed by Trenton H
parent 77a3f8ed60
commit 66884ea035
4 changed files with 76 additions and 19 deletions

View File

@@ -5,12 +5,15 @@ import pickle
import re
import shutil
import warnings
from typing import List
from typing import Optional
from django.conf import settings
from documents.models import Document
from documents.models import MatchingModel
logger = logging.getLogger("paperless.classifier")
class IncompatibleClassifierVersionError(Exception):
pass
@@ -20,15 +23,6 @@ class ClassifierModelCorruptError(Exception):
pass
logger = logging.getLogger("paperless.classifier")
def preprocess_content(content: str) -> str:
content = content.lower().strip()
content = re.sub(r"\s+", " ", content)
return content
def load_classifier() -> Optional["DocumentClassifier"]:
if not os.path.isfile(settings.MODEL_FILE):
logger.debug(
@@ -81,6 +75,8 @@ class DocumentClassifier:
self.document_type_classifier = None
self.storage_path_classifier = None
self.stemmer = None
def load(self):
# Catch warnings for processing
with warnings.catch_warnings(record=True) as w:
@@ -139,11 +135,11 @@ class DocumentClassifier:
def train(self):
data = list()
labels_tags = list()
labels_correspondent = list()
labels_document_type = list()
labels_storage_path = list()
data = []
labels_tags = []
labels_correspondent = []
labels_document_type = []
labels_storage_path = []
# Step 1: Extract and preprocess training data from the database.
logger.debug("Gathering data from database...")
@@ -151,7 +147,7 @@ class DocumentClassifier:
for doc in Document.objects.order_by("pk").exclude(
tags__is_inbox_tag=True,
):
preprocessed_content = preprocess_content(doc.content)
preprocessed_content = self.preprocess_content(doc.content)
m.update(preprocessed_content.encode("utf-8"))
data.append(preprocessed_content)
@@ -231,6 +227,11 @@ class DocumentClassifier:
)
data_vectorized = self.data_vectorizer.fit_transform(data)
# See the notes here:
# https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.CountVectorizer.html # noqa: 501
# This attribute isn't needed to function and can be large
self.data_vectorizer.stop_words_ = None
# Step 3: train the classifiers
if num_tags > 0:
logger.debug("Training tags classifier...")
@@ -296,9 +297,36 @@ class DocumentClassifier:
return True
def preprocess_content(self, content: str) -> str:
"""
Process to contents of a document, distilling it down into
words which are meaningful to the content
"""
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from nltk.stem import SnowballStemmer
if self.stemmer is None:
self.stemmer = SnowballStemmer("english")
# Lower case the document
content = content.lower().strip()
# Get only the letters (remove punctuation too)
content = re.sub(r"[^\w\s]", " ", content)
# Tokenize
# TODO configurable language
words: List[str] = word_tokenize(content, language="english")
# Remove stop words
stops = set(stopwords.words("english"))
meaningful_words = [w for w in words if w not in stops]
# Stem words
meaningful_words = [self.stemmer.stem(w) for w in meaningful_words]
return " ".join(meaningful_words)
def predict_correspondent(self, content):
if self.correspondent_classifier:
X = self.data_vectorizer.transform([preprocess_content(content)])
X = self.data_vectorizer.transform([self.preprocess_content(content)])
correspondent_id = self.correspondent_classifier.predict(X)
if correspondent_id != -1:
return correspondent_id
@@ -309,7 +337,7 @@ class DocumentClassifier:
def predict_document_type(self, content):
if self.document_type_classifier:
X = self.data_vectorizer.transform([preprocess_content(content)])
X = self.data_vectorizer.transform([self.preprocess_content(content)])
document_type_id = self.document_type_classifier.predict(X)
if document_type_id != -1:
return document_type_id
@@ -322,7 +350,7 @@ class DocumentClassifier:
from sklearn.utils.multiclass import type_of_target
if self.tags_classifier:
X = self.data_vectorizer.transform([preprocess_content(content)])
X = self.data_vectorizer.transform([self.preprocess_content(content)])
y = self.tags_classifier.predict(X)
tags_ids = self.tags_binarizer.inverse_transform(y)[0]
if type_of_target(y).startswith("multilabel"):
@@ -341,7 +369,7 @@ class DocumentClassifier:
def predict_storage_path(self, content):
if self.storage_path_classifier:
X = self.data_vectorizer.transform([preprocess_content(content)])
X = self.data_vectorizer.transform([self.preprocess_content(content)])
storage_path_id = self.storage_path_classifier.predict(X)
if storage_path_id != -1:
return storage_path_id