diff --git a/Pipfile b/Pipfile
index 25a159b6a..4c3cb76f2 100644
--- a/Pipfile
+++ b/Pipfile
@@ -56,6 +56,7 @@ mysqlclient = "*"
 celery = {extras = ["redis"], version = "*"}
 django-celery-results = "*"
 setproctitle = "*"
+nltk = "*"
 
 [dev-packages]
 coveralls = "*"
diff --git a/Pipfile.lock b/Pipfile.lock
index 8831f8db9..4e87ca968 100644
--- a/Pipfile.lock
+++ b/Pipfile.lock
@@ -889,6 +889,14 @@
             "index": "pypi",
             "version": "==2.1.1"
         },
+        "nltk": {
+            "hashes": [
+                "sha256:ba3de02490308b248f9b94c8bc1ac0683e9aa2ec49ee78536d8667afb5e3eec8",
+                "sha256:d6507d6460cec76d70afea4242a226a7542f85c669177b9c7f562b7cf1b05502"
+            ],
+            "index": "pypi",
+            "version": "==3.7"
+        },
         "numpy": {
             "hashes": [
                 "sha256:07a8c89a04997625236c5ecb7afe35a02af3896c8aa01890a849913a2309c676",
diff --git a/docker/docker-entrypoint.sh b/docker/docker-entrypoint.sh
index 14a0650f0..c97225007 100755
--- a/docker/docker-entrypoint.sh
+++ b/docker/docker-entrypoint.sh
@@ -53,6 +53,24 @@ map_folders() {
 	export CONSUME_DIR="${PAPERLESS_CONSUMPTION_DIR:-/usr/src/paperless/consume}"
 }
 
+nltk_data () {
+	# Store the NLTK data outside the Docker container
+	local nltk_data_dir="${DATA_DIR}/nltk"
+
+	# Download or update the snowball stemmer data
+	python3 -m nltk.downloader -d "${nltk_data_dir}" snowball_data
+
+	# Download or update the stopwords corpus
+	python3 -m nltk.downloader -d "${nltk_data_dir}" stopwords
+
+	# Download or update the punkt tokenizer data
+	python3 -m nltk.downloader -d "${nltk_data_dir}" punkt
+
+	# Set env so nltk can find the downloaded data
+	export NLTK_DATA="${nltk_data_dir}"
+
+}
+
 initialize() {
 
 	# Setup environment from secrets before anything else
@@ -105,6 +123,8 @@ initialize() {
 	done
 	set -e
 
+	nltk_data
+
 	"${gosu_cmd[@]}" /sbin/docker-prepare.sh
 }
 
diff --git a/src/documents/classifier.py b/src/documents/classifier.py
index 927151b1d..27964d0f8 100644
--- a/src/documents/classifier.py
+++ b/src/documents/classifier.py
@@ -5,12 +5,15 @@ import pickle
 import re
 import shutil
 import warnings
+from typing import List
 from typing import Optional
 
 from django.conf import settings
 from documents.models import Document
 from documents.models import MatchingModel
 
+logger = logging.getLogger("paperless.classifier")
+
 
 class IncompatibleClassifierVersionError(Exception):
     pass
@@ -20,15 +23,6 @@ class ClassifierModelCorruptError(Exception):
     pass
 
 
-logger = logging.getLogger("paperless.classifier")
-
-
-def preprocess_content(content: str) -> str:
-    content = content.lower().strip()
-    content = re.sub(r"\s+", " ", content)
-    return content
-
-
 def load_classifier() -> Optional["DocumentClassifier"]:
     if not os.path.isfile(settings.MODEL_FILE):
         logger.debug(
@@ -81,6 +75,8 @@ class DocumentClassifier:
         self.document_type_classifier = None
         self.storage_path_classifier = None
 
+        self.stemmer = None
+
     def load(self):
         # Catch warnings for processing
         with warnings.catch_warnings(record=True) as w:
@@ -139,11 +135,11 @@ class DocumentClassifier:
 
     def train(self):
 
-        data = list()
-        labels_tags = list()
-        labels_correspondent = list()
-        labels_document_type = list()
-        labels_storage_path = list()
+        data = []
+        labels_tags = []
+        labels_correspondent = []
+        labels_document_type = []
+        labels_storage_path = []
 
         # Step 1: Extract and preprocess training data from the database.
         logger.debug("Gathering data from database...")
@@ -151,7 +147,7 @@ class DocumentClassifier:
         for doc in Document.objects.order_by("pk").exclude(
             tags__is_inbox_tag=True,
         ):
-            preprocessed_content = preprocess_content(doc.content)
+            preprocessed_content = self.preprocess_content(doc.content)
             m.update(preprocessed_content.encode("utf-8"))
             data.append(preprocessed_content)
 
@@ -231,6 +227,11 @@ class DocumentClassifier:
         )
         data_vectorized = self.data_vectorizer.fit_transform(data)
 
+        # See the notes here:
+        # https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.CountVectorizer.html # noqa: 501
+        # This attribute isn't needed to function and can be large
+        self.data_vectorizer.stop_words_ = None
+
         # Step 3: train the classifiers
         if num_tags > 0:
             logger.debug("Training tags classifier...")
@@ -296,9 +297,36 @@ class DocumentClassifier:
 
         return True
 
+    def preprocess_content(self, content: str) -> str:
+        """
+        Process to contents of a document, distilling it down into
+        words which are meaningful to the content
+        """
+        from nltk.tokenize import word_tokenize
+        from nltk.corpus import stopwords
+        from nltk.stem import SnowballStemmer
+
+        if self.stemmer is None:
+            self.stemmer = SnowballStemmer("english")
+
+        # Lower case the document
+        content = content.lower().strip()
+        # Get only the letters (remove punctuation too)
+        content = re.sub(r"[^\w\s]", " ", content)
+        # Tokenize
+        # TODO configurable language
+        words: List[str] = word_tokenize(content, language="english")
+        # Remove stop words
+        stops = set(stopwords.words("english"))
+        meaningful_words = [w for w in words if w not in stops]
+        # Stem words
+        meaningful_words = [self.stemmer.stem(w) for w in meaningful_words]
+
+        return " ".join(meaningful_words)
+
     def predict_correspondent(self, content):
         if self.correspondent_classifier:
-            X = self.data_vectorizer.transform([preprocess_content(content)])
+            X = self.data_vectorizer.transform([self.preprocess_content(content)])
             correspondent_id = self.correspondent_classifier.predict(X)
             if correspondent_id != -1:
                 return correspondent_id
@@ -309,7 +337,7 @@ class DocumentClassifier:
 
     def predict_document_type(self, content):
         if self.document_type_classifier:
-            X = self.data_vectorizer.transform([preprocess_content(content)])
+            X = self.data_vectorizer.transform([self.preprocess_content(content)])
             document_type_id = self.document_type_classifier.predict(X)
             if document_type_id != -1:
                 return document_type_id
@@ -322,7 +350,7 @@ class DocumentClassifier:
         from sklearn.utils.multiclass import type_of_target
 
         if self.tags_classifier:
-            X = self.data_vectorizer.transform([preprocess_content(content)])
+            X = self.data_vectorizer.transform([self.preprocess_content(content)])
             y = self.tags_classifier.predict(X)
             tags_ids = self.tags_binarizer.inverse_transform(y)[0]
             if type_of_target(y).startswith("multilabel"):
@@ -341,7 +369,7 @@ class DocumentClassifier:
 
     def predict_storage_path(self, content):
         if self.storage_path_classifier:
-            X = self.data_vectorizer.transform([preprocess_content(content)])
+            X = self.data_vectorizer.transform([self.preprocess_content(content)])
            storage_path_id = self.storage_path_classifier.predict(X)
             if storage_path_id != -1:
                 return storage_path_id