	Updates the pre-processing of document content to be much more robust, with tokenization, stemming and stop word removal
Author: Trenton Holmes
Committed by: Trenton H
Parent: 14d82bd8ff
Commit: d856e48045

Pipfile (1 addition):
@@ -56,6 +56,7 @@ mysqlclient = "*"
celery = {extras = ["redis"], version = "*"}
django-celery-results = "*"
setproctitle = "*"
nltk = "*"

[dev-packages]
coveralls = "*"

Pipfile.lock (8 additions, generated):
@@ -889,6 +889,14 @@
            "index": "pypi",
            "version": "==2.1.1"
        },
        "nltk": {
            "hashes": [
                "sha256:ba3de02490308b248f9b94c8bc1ac0683e9aa2ec49ee78536d8667afb5e3eec8",
                "sha256:d6507d6460cec76d70afea4242a226a7542f85c669177b9c7f562b7cf1b05502"
            ],
            "index": "pypi",
            "version": "==3.7"
        },
        "numpy": {
            "hashes": [
                "sha256:07a8c89a04997625236c5ecb7afe35a02af3896c8aa01890a849913a2309c676",
@@ -53,6 +53,24 @@ map_folders() {
	export CONSUME_DIR="${PAPERLESS_CONSUMPTION_DIR:-/usr/src/paperless/consume}"
}

nltk_data () {
	# Store the NLTK data outside the Docker container
	local nltk_data_dir="${DATA_DIR}/nltk"

	# Download or update the snowball stemmer data
	python3 -m nltk.downloader -d "${nltk_data_dir}" snowball_data

	# Download or update the stopwords corpus
	python3 -m nltk.downloader -d "${nltk_data_dir}" stopwords

	# Download or update the punkt tokenizer data
	python3 -m nltk.downloader -d "${nltk_data_dir}" punkt

	# Set env so nltk can find the downloaded data
	export NLTK_DATA="${nltk_data_dir}"

}

initialize() {

	# Setup environment from secrets before anything else
@@ -105,6 +123,8 @@ initialize() {
	done
	set -e

	nltk_data

	"${gosu_cmd[@]}" /sbin/docker-prepare.sh
}
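The shell changes above download the snowball_data, stopwords and punkt packages into "${DATA_DIR}/nltk" at container start and export NLTK_DATA so the classifier can find them. Below is a minimal Python sketch of the same setup for a non-Docker install; the data directory path is a hypothetical example, and because nltk builds its data search path when it is imported, the directory is appended to nltk.data.path here rather than relying on NLTK_DATA.

import nltk

# Hypothetical example path; the entrypoint above uses "${DATA_DIR}/nltk"
nltk_data_dir = "/opt/paperless/data/nltk"

# Fetch the same three resources the entrypoint downloads
for resource in ("snowball_data", "stopwords", "punkt"):
    nltk.download(resource, download_dir=nltk_data_dir)

# Make the directory visible to this interpreter
nltk.data.path.append(nltk_data_dir)

# nltk.data.find() raises LookupError if a resource is missing
nltk.data.find("corpora/stopwords")
nltk.data.find("tokenizers/punkt")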
@@ -5,12 +5,15 @@ import pickle
import re
import shutil
import warnings
from typing import List
from typing import Optional

from django.conf import settings
from documents.models import Document
from documents.models import MatchingModel

logger = logging.getLogger("paperless.classifier")


class IncompatibleClassifierVersionError(Exception):
    pass
@@ -20,15 +23,6 @@ class ClassifierModelCorruptError(Exception):
    pass


logger = logging.getLogger("paperless.classifier")


def preprocess_content(content: str) -> str:
    content = content.lower().strip()
    content = re.sub(r"\s+", " ", content)
    return content


def load_classifier() -> Optional["DocumentClassifier"]:
    if not os.path.isfile(settings.MODEL_FILE):
        logger.debug(
@@ -81,6 +75,8 @@ class DocumentClassifier:
        self.document_type_classifier = None
        self.storage_path_classifier = None

        self.stemmer = None

    def load(self):
        # Catch warnings for processing
        with warnings.catch_warnings(record=True) as w:
@@ -139,11 +135,11 @@ class DocumentClassifier:

    def train(self):

        data = list()
        labels_tags = list()
        labels_correspondent = list()
        labels_document_type = list()
        labels_storage_path = list()
        data = []
        labels_tags = []
        labels_correspondent = []
        labels_document_type = []
        labels_storage_path = []

        # Step 1: Extract and preprocess training data from the database.
        logger.debug("Gathering data from database...")
@@ -151,7 +147,7 @@ class DocumentClassifier:
        for doc in Document.objects.order_by("pk").exclude(
            tags__is_inbox_tag=True,
        ):
            preprocessed_content = preprocess_content(doc.content)
            preprocessed_content = self.preprocess_content(doc.content)
            m.update(preprocessed_content.encode("utf-8"))
            data.append(preprocessed_content)

@@ -231,6 +227,11 @@ class DocumentClassifier:
        )
        data_vectorized = self.data_vectorizer.fit_transform(data)

        # See the notes here:
        # https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.CountVectorizer.html  # noqa: 501
        # This attribute isn't needed to function and can be large
        self.data_vectorizer.stop_words_ = None

        # Step 3: train the classifiers
        if num_tags > 0:
            logger.debug("Training tags classifier...")
@@ -296,9 +297,36 @@ class DocumentClassifier:

        return True

    def preprocess_content(self, content: str) -> str:
        """
        Process the contents of a document, distilling it down into
        words which are meaningful to the content
        """
        from nltk.tokenize import word_tokenize
        from nltk.corpus import stopwords
        from nltk.stem import SnowballStemmer

        if self.stemmer is None:
            self.stemmer = SnowballStemmer("english")

        # Lower case the document
        content = content.lower().strip()
        # Get only the letters (remove punctuation too)
        content = re.sub(r"[^\w\s]", " ", content)
        # Tokenize
        # TODO configurable language
        words: List[str] = word_tokenize(content, language="english")
        # Remove stop words
        stops = set(stopwords.words("english"))
        meaningful_words = [w for w in words if w not in stops]
        # Stem words
        meaningful_words = [self.stemmer.stem(w) for w in meaningful_words]

        return " ".join(meaningful_words)

    def predict_correspondent(self, content):
        if self.correspondent_classifier:
            X = self.data_vectorizer.transform([preprocess_content(content)])
            X = self.data_vectorizer.transform([self.preprocess_content(content)])
            correspondent_id = self.correspondent_classifier.predict(X)
            if correspondent_id != -1:
                return correspondent_id
@@ -309,7 +337,7 @@ class DocumentClassifier:

    def predict_document_type(self, content):
        if self.document_type_classifier:
            X = self.data_vectorizer.transform([preprocess_content(content)])
            X = self.data_vectorizer.transform([self.preprocess_content(content)])
            document_type_id = self.document_type_classifier.predict(X)
            if document_type_id != -1:
                return document_type_id
@@ -322,7 +350,7 @@ class DocumentClassifier:
        from sklearn.utils.multiclass import type_of_target

        if self.tags_classifier:
            X = self.data_vectorizer.transform([preprocess_content(content)])
            X = self.data_vectorizer.transform([self.preprocess_content(content)])
            y = self.tags_classifier.predict(X)
            tags_ids = self.tags_binarizer.inverse_transform(y)[0]
            if type_of_target(y).startswith("multilabel"):
@@ -341,7 +369,7 @@ class DocumentClassifier:

    def predict_storage_path(self, content):
        if self.storage_path_classifier:
            X = self.data_vectorizer.transform([preprocess_content(content)])
            X = self.data_vectorizer.transform([self.preprocess_content(content)])
            storage_path_id = self.storage_path_classifier.predict(X)
            if storage_path_id != -1:
                return storage_path_id
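For reference, a minimal standalone sketch of the pipeline the new DocumentClassifier.preprocess_content() performs: lower-casing, punctuation stripping, tokenization, stop word removal and stemming. It assumes the stopwords and punkt data downloaded by the entrypoint change above are available; the sample sentence and the shown output are illustrative only.

import re

from nltk.corpus import stopwords
from nltk.stem import SnowballStemmer
from nltk.tokenize import word_tokenize

stemmer = SnowballStemmer("english")
stops = set(stopwords.words("english"))


def preprocess(content: str) -> str:
    # Lower case, then keep only word characters and whitespace
    content = content.lower().strip()
    content = re.sub(r"[^\w\s]", " ", content)
    # Tokenize, drop English stop words, stem what remains
    words = word_tokenize(content, language="english")
    return " ".join(stemmer.stem(w) for w in words if w not in stops)


print(preprocess("The invoices were paid by the customer."))
# Prints something like: "invoic paid custom"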