Mirror of https://github.com/paperless-ngx/paperless-ngx.git (synced 2025-11-03 03:16:10 -06:00)

	Updates the pre-processing of document content to be much more robust, with tokenization, stemming and stop word removal
committed by Trenton H

parent 14d82bd8ff
commit d856e48045
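The change boils down to a tokenize, stop-word-filter, and stem pipeline that runs before vectorization. A minimal standalone sketch of that flow, for orientation only: it mirrors the DocumentClassifier.preprocess_content method added in the diff below, the helper name is mine, and it assumes the English NLTK data packages (punkt, stopwords, snowball_data) are already installed.

import re

from nltk.corpus import stopwords
from nltk.stem import SnowballStemmer
from nltk.tokenize import word_tokenize


def preprocess(content: str) -> str:
    # Lower-case and replace punctuation with spaces
    content = re.sub(r"[^\w\s]", " ", content.lower().strip())
    # Tokenize, drop English stop words, stem whatever remains
    words = word_tokenize(content, language="english")
    stops = set(stopwords.words("english"))
    stemmer = SnowballStemmer("english")
    return " ".join(stemmer.stem(w) for w in words if w not in stops)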
				
			
							
								
								
									
Pipfile (1 change)
							@@ -56,6 +56,7 @@ mysqlclient = "*"
 celery = {extras = ["redis"], version = "*"}
 django-celery-results = "*"
 setproctitle = "*"
+nltk = "*"
 
 [dev-packages]
 coveralls = "*"

Pipfile.lock (8 changes, generated)

							@@ -889,6 +889,14 @@
             "index": "pypi",
             "version": "==2.1.1"
         },
+        "nltk": {
+            "hashes": [
+                "sha256:ba3de02490308b248f9b94c8bc1ac0683e9aa2ec49ee78536d8667afb5e3eec8",
+                "sha256:d6507d6460cec76d70afea4242a226a7542f85c669177b9c7f562b7cf1b05502"
+            ],
+            "index": "pypi",
+            "version": "==3.7"
+        },
         "numpy": {
             "hashes": [
                 "sha256:07a8c89a04997625236c5ecb7afe35a02af3896c8aa01890a849913a2309c676",

@@ -53,6 +53,24 @@ map_folders() {
 	export CONSUME_DIR="${PAPERLESS_CONSUMPTION_DIR:-/usr/src/paperless/consume}"
 }
 
+nltk_data () {
+	# Store the NLTK data outside the Docker container
+	local nltk_data_dir="${DATA_DIR}/nltk"
+
+	# Download or update the snowball stemmer data
+	python3 -m nltk.downloader -d "${nltk_data_dir}" snowball_data
+
+	# Download or update the stopwords corpus
+	python3 -m nltk.downloader -d "${nltk_data_dir}" stopwords
+
+	# Download or update the punkt tokenizer data
+	python3 -m nltk.downloader -d "${nltk_data_dir}" punkt
+
+	# Set env so nltk can find the downloaded data
+	export NLTK_DATA="${nltk_data_dir}"
+
+}
+
 initialize() {
 
 	# Setup environment from secrets before anything else
@@ -105,6 +123,8 @@ initialize() {
 	done
 	set -e
 
+	nltk_data
+
 	"${gosu_cmd[@]}" /sbin/docker-prepare.sh
 }
 
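The entrypoint change above fetches the snowball_data, stopwords and punkt packages into ${DATA_DIR}/nltk and exports NLTK_DATA so the classifier can find them at runtime. The same downloads can also be done from Python; a sketch, where the target directory is an arbitrary example rather than the path the script uses:

import os

import nltk

# Example location only; the entrypoint uses "${DATA_DIR}/nltk" inside the container
nltk_data_dir = "/tmp/nltk_data"

# Download (or update) the same three packages the entrypoint fetches
for package in ("snowball_data", "stopwords", "punkt"):
    nltk.download(package, download_dir=nltk_data_dir)

# Mirror the script's `export NLTK_DATA=...` so nltk can locate the data
os.environ["NLTK_DATA"] = nltk_data_dir
nltk.data.path.append(nltk_data_dir)

The remaining hunks below are in the classifier module.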
@@ -5,12 +5,15 @@ import pickle
 import re
 import shutil
 import warnings
+from typing import List
 from typing import Optional
 
 from django.conf import settings
 from documents.models import Document
 from documents.models import MatchingModel
 
+logger = logging.getLogger("paperless.classifier")
+
 
 class IncompatibleClassifierVersionError(Exception):
     pass
@@ -20,15 +23,6 @@ class ClassifierModelCorruptError(Exception):
     pass
 
 
-logger = logging.getLogger("paperless.classifier")
-
-
-def preprocess_content(content: str) -> str:
-    content = content.lower().strip()
-    content = re.sub(r"\s+", " ", content)
-    return content
-
-
 def load_classifier() -> Optional["DocumentClassifier"]:
     if not os.path.isfile(settings.MODEL_FILE):
         logger.debug(
@@ -81,6 +75,8 @@ class DocumentClassifier:
         self.document_type_classifier = None
         self.storage_path_classifier = None
 
+        self.stemmer = None
+
     def load(self):
         # Catch warnings for processing
         with warnings.catch_warnings(record=True) as w:
@@ -139,11 +135,11 @@ class DocumentClassifier:
 
     def train(self):
 
-        data = list()
-        labels_tags = list()
-        labels_correspondent = list()
-        labels_document_type = list()
-        labels_storage_path = list()
+        data = []
+        labels_tags = []
+        labels_correspondent = []
+        labels_document_type = []
+        labels_storage_path = []
 
         # Step 1: Extract and preprocess training data from the database.
         logger.debug("Gathering data from database...")
@@ -151,7 +147,7 @@ class DocumentClassifier:
         for doc in Document.objects.order_by("pk").exclude(
             tags__is_inbox_tag=True,
         ):
-            preprocessed_content = preprocess_content(doc.content)
+            preprocessed_content = self.preprocess_content(doc.content)
             m.update(preprocessed_content.encode("utf-8"))
             data.append(preprocessed_content)
 
@@ -231,6 +227,11 @@ class DocumentClassifier:
         )
         data_vectorized = self.data_vectorizer.fit_transform(data)
 
+        # See the notes here:
+        # https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.CountVectorizer.html  # noqa: 501
+        # This attribute isn't needed to function and can be large
+        self.data_vectorizer.stop_words_ = None
+
         # Step 3: train the classifiers
         if num_tags > 0:
             logger.debug("Training tags classifier...")
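The new comment block refers to the scikit-learn CountVectorizer documentation: after fitting, the vectorizer keeps a stop_words_ attribute that exists only for introspection, so clearing it shrinks the pickled model without changing transform(). A toy illustration of the effect (the sample documents are made up, not Paperless data):

import pickle

from sklearn.feature_extraction.text import CountVectorizer

docs = ["invoice from acme corp", "acme corp invoice paid", "tax statement 2021"]

vectorizer = CountVectorizer(min_df=2)  # min_df pushes rare terms into stop_words_
vectorizer.fit_transform(docs)
print(len(pickle.dumps(vectorizer)))    # pickled size with stop_words_ populated

vectorizer.stop_words_ = None           # safe: transform() does not use it
print(len(pickle.dumps(vectorizer)))    # smaller pickle, same behaviour
print(vectorizer.transform(["acme corp invoice"]).toarray())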
@@ -296,9 +297,36 @@ class DocumentClassifier:
 
         return True
 
+    def preprocess_content(self, content: str) -> str:
+        """
+        Process to contents of a document, distilling it down into
+        words which are meaningful to the content
+        """
+        from nltk.tokenize import word_tokenize
+        from nltk.corpus import stopwords
+        from nltk.stem import SnowballStemmer
+
+        if self.stemmer is None:
+            self.stemmer = SnowballStemmer("english")
+
+        # Lower case the document
+        content = content.lower().strip()
+        # Get only the letters (remove punctuation too)
+        content = re.sub(r"[^\w\s]", " ", content)
+        # Tokenize
+        # TODO configurable language
+        words: List[str] = word_tokenize(content, language="english")
+        # Remove stop words
+        stops = set(stopwords.words("english"))
+        meaningful_words = [w for w in words if w not in stops]
+        # Stem words
+        meaningful_words = [self.stemmer.stem(w) for w in meaningful_words]
+
+        return " ".join(meaningful_words)
+
     def predict_correspondent(self, content):
         if self.correspondent_classifier:
-            X = self.data_vectorizer.transform([preprocess_content(content)])
+            X = self.data_vectorizer.transform([self.preprocess_content(content)])
             correspondent_id = self.correspondent_classifier.predict(X)
             if correspondent_id != -1:
                 return correspondent_id
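As a rough worked example of what the new method produces (the sample text is made up, and exact stems can vary slightly between NLTK versions):

from nltk.stem import SnowballStemmer

stemmer = SnowballStemmer("english")

# "The invoices were paid by ACME Corp on 2022-03-01."
# lower-case + punctuation strip -> "the invoices were paid by acme corp on 2022 03 01"
# tokenize + drop stop words     -> ["invoices", "paid", "acme", "corp", "2022", "03", "01"]
print([stemmer.stem(w) for w in ["invoices", "paid", "acme", "corp"]])
# roughly: ['invoic', 'paid', 'acme', 'corp']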
@@ -309,7 +337,7 @@ class DocumentClassifier:
 
     def predict_document_type(self, content):
         if self.document_type_classifier:
-            X = self.data_vectorizer.transform([preprocess_content(content)])
+            X = self.data_vectorizer.transform([self.preprocess_content(content)])
             document_type_id = self.document_type_classifier.predict(X)
             if document_type_id != -1:
                 return document_type_id
@@ -322,7 +350,7 @@ class DocumentClassifier:
         from sklearn.utils.multiclass import type_of_target
 
         if self.tags_classifier:
-            X = self.data_vectorizer.transform([preprocess_content(content)])
+            X = self.data_vectorizer.transform([self.preprocess_content(content)])
             y = self.tags_classifier.predict(X)
             tags_ids = self.tags_binarizer.inverse_transform(y)[0]
             if type_of_target(y).startswith("multilabel"):
@@ -341,7 +369,7 @@ class DocumentClassifier:
 
     def predict_storage_path(self, content):
         if self.storage_path_classifier:
-            X = self.data_vectorizer.transform([preprocess_content(content)])
+            X = self.data_vectorizer.transform([self.preprocess_content(content)])
             storage_path_id = self.storage_path_classifier.predict(X)
             if storage_path_id != -1:
                 return storage_path_id