Mirror of https://github.com/paperless-ngx/paperless-ngx.git
	Updates the pre-processing of document content to be much more robust, with tokenization, stemming and stop word removal
Author:       Trenton Holmes
Committed by: Trenton H
Parent:       14d82bd8ff
Commit:       d856e48045
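For context on the diff below: the commit replaces the module-level preprocess_content, which only lower-cased text and collapsed whitespace, with an instance method that strips punctuation, tokenizes, removes stop words, and stems using NLTK. A minimal standalone sketch of that pipeline, assuming English text (the sample input and its output are illustrative, not taken from the project):

import re

from nltk.corpus import stopwords
from nltk.stem import SnowballStemmer
from nltk.tokenize import word_tokenize

stemmer = SnowballStemmer("english")
stops = set(stopwords.words("english"))


def preprocess(content: str) -> str:
    # Lower-case, strip punctuation, tokenize, drop stop words, stem
    content = re.sub(r"[^\w\s]", " ", content.lower().strip())
    words = word_tokenize(content, language="english")
    return " ".join(stemmer.stem(w) for w in words if w not in stops)


# "The invoices were processed yesterday." -> "invoic process yesterday"
print(preprocess("The invoices were processed yesterday."))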
				
@@ -5,12 +5,15 @@ import pickle
 import re
 import shutil
 import warnings
+from typing import List
 from typing import Optional
 
 from django.conf import settings
 from documents.models import Document
 from documents.models import MatchingModel
 
+logger = logging.getLogger("paperless.classifier")
+
 
 class IncompatibleClassifierVersionError(Exception):
     pass
@@ -20,15 +23,6 @@ class ClassifierModelCorruptError(Exception):
     pass
 
 
-logger = logging.getLogger("paperless.classifier")
-
-
-def preprocess_content(content: str) -> str:
-    content = content.lower().strip()
-    content = re.sub(r"\s+", " ", content)
-    return content
-
-
 def load_classifier() -> Optional["DocumentClassifier"]:
     if not os.path.isfile(settings.MODEL_FILE):
         logger.debug(
@@ -81,6 +75,8 @@ class DocumentClassifier:
         self.document_type_classifier = None
         self.storage_path_classifier = None
 
+        self.stemmer = None
+
     def load(self):
         # Catch warnings for processing
         with warnings.catch_warnings(record=True) as w:
@@ -139,11 +135,11 @@ class DocumentClassifier:
 
     def train(self):
 
-        data = list()
-        labels_tags = list()
-        labels_correspondent = list()
-        labels_document_type = list()
-        labels_storage_path = list()
+        data = []
+        labels_tags = []
+        labels_correspondent = []
+        labels_document_type = []
+        labels_storage_path = []
 
         # Step 1: Extract and preprocess training data from the database.
         logger.debug("Gathering data from database...")
@@ -151,7 +147,7 @@ class DocumentClassifier:
         for doc in Document.objects.order_by("pk").exclude(
             tags__is_inbox_tag=True,
         ):
-            preprocessed_content = preprocess_content(doc.content)
+            preprocessed_content = self.preprocess_content(doc.content)
             m.update(preprocessed_content.encode("utf-8"))
             data.append(preprocessed_content)
 
@@ -231,6 +227,11 @@ class DocumentClassifier:
         )
         data_vectorized = self.data_vectorizer.fit_transform(data)
 
+        # See the notes here:
+        # https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.CountVectorizer.html  # noqa: E501
+        # This attribute isn't needed to function and can be large
+        self.data_vectorizer.stop_words_ = None
+
         # Step 3: train the classifiers
         if num_tags > 0:
             logger.debug("Training tags classifier...")
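The stop_words_ assignment above follows the scikit-learn note linked in the comment: the fitted attribute only records which terms were pruned during fit() and is never read by transform(), so clearing it shrinks the pickled model at no cost. A toy sketch with made-up data:

from sklearn.feature_extraction.text import CountVectorizer

vectorizer = CountVectorizer(min_df=2)
vectorizer.fit(["paid invoice", "unpaid invoice", "receipt"])
print(vectorizer.stop_words_)   # terms pruned by min_df, e.g. {'paid', 'unpaid', 'receipt'}
vectorizer.stop_words_ = None   # safe to drop before pickling
print(vectorizer.transform(["an invoice"]).toarray())  # still works: [[1]]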
@@ -296,9 +297,36 @@ class DocumentClassifier:
 
         return True
 
+    def preprocess_content(self, content: str) -> str:
+        """
+        Processes the contents of a document, distilling it down into
+        words which are meaningful to the content.
+        """
+        from nltk.corpus import stopwords
+        from nltk.stem import SnowballStemmer
+        from nltk.tokenize import word_tokenize
+
+        if self.stemmer is None:
+            self.stemmer = SnowballStemmer("english")
+
+        # Lower case the document
+        content = content.lower().strip()
+        # Keep only word characters (this also removes punctuation)
+        content = re.sub(r"[^\w\s]", " ", content)
+        # Tokenize
+        # TODO: make the language configurable
+        words: List[str] = word_tokenize(content, language="english")
+        # Remove stop words
+        stops = set(stopwords.words("english"))
+        meaningful_words = [w for w in words if w not in stops]
+        # Stem words
+        meaningful_words = [self.stemmer.stem(w) for w in meaningful_words]
+
+        return " ".join(meaningful_words)
+
     def predict_correspondent(self, content):
         if self.correspondent_classifier:
-            X = self.data_vectorizer.transform([preprocess_content(content)])
+            X = self.data_vectorizer.transform([self.preprocess_content(content)])
             correspondent_id = self.correspondent_classifier.predict(X)
             if correspondent_id != -1:
                 return correspondent_id
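Routing every predict_* path through self.preprocess_content keeps prediction on exactly the pipeline that built the vectorizer's vocabulary. That vocabulary now holds stemmed terms, so raw, unstemmed input would silently match nothing, as this toy sketch (illustrative data only) shows:

from sklearn.feature_extraction.text import CountVectorizer

vectorizer = CountVectorizer()
vectorizer.fit(["invoic payment"])               # vocabulary holds stemmed terms
print(vectorizer.transform(["invoices"]).sum())  # 0 -- raw input matches nothing
print(vectorizer.transform(["invoic"]).sum())    # 1 -- preprocessed input matches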
@@ -309,7 +337,7 @@ class DocumentClassifier:
 
     def predict_document_type(self, content):
         if self.document_type_classifier:
-            X = self.data_vectorizer.transform([preprocess_content(content)])
+            X = self.data_vectorizer.transform([self.preprocess_content(content)])
             document_type_id = self.document_type_classifier.predict(X)
             if document_type_id != -1:
                 return document_type_id
@@ -322,7 +350,7 @@ class DocumentClassifier:
         from sklearn.utils.multiclass import type_of_target
 
         if self.tags_classifier:
-            X = self.data_vectorizer.transform([preprocess_content(content)])
+            X = self.data_vectorizer.transform([self.preprocess_content(content)])
             y = self.tags_classifier.predict(X)
             tags_ids = self.tags_binarizer.inverse_transform(y)[0]
             if type_of_target(y).startswith("multilabel"):
@@ -341,7 +369,7 @@ class DocumentClassifier:
 
     def predict_storage_path(self, content):
         if self.storage_path_classifier:
-            X = self.data_vectorizer.transform([preprocess_content(content)])
+            X = self.data_vectorizer.transform([self.preprocess_content(content)])
            storage_path_id = self.storage_path_classifier.predict(X)
             if storage_path_id != -1:
                 return storage_path_id
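One deployment note not visible in this diff: word_tokenize and stopwords.words depend on NLTK data packages that ship separately from the nltk library itself. A setup sketch (how paperless-ngx provisions these is outside this commit):

import nltk

nltk.download("punkt")      # tokenizer data used by word_tokenize
nltk.download("stopwords")  # corpora used by stopwords.words("english")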