paperless-ngx (mirror of https://github.com/paperless-ngx/paperless-ngx.git)

Commit: Changes classifier training to hold less data in memory at the same time
Author: Trenton H
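The commit's core move: instead of collecting every preprocessed document into a `data` list and passing that list to the vectorizer, training now keeps only the labels and the rolling hash, and hands `CountVectorizer.fit_transform` a generator that yields one preprocessed string at a time. scikit-learn's vectorizers accept any iterable of strings, so the whole corpus never has to exist in memory at once. A minimal standalone sketch of the pattern, with `iter_texts` as a hypothetical stand-in for walking `docs_queryset` and calling `self.preprocess_content(doc.content)`:

from typing import Iterator

from sklearn.feature_extraction.text import CountVectorizer


def iter_texts() -> Iterator[str]:
    # Hypothetical stand-in for the commit's content_generator(),
    # which iterates the Django queryset and yields preprocessed text.
    for raw in ("Invoice 42 due May", "Receipt for coffee", "Contract renewal terms"):
        yield raw.lower()


# Same vectorizer settings as the diff. fit_transform() consumes the
# generator in a single pass, so each document string can be freed as
# soon as it has been tokenized; only the vocabulary and counts persist.
vectorizer = CountVectorizer(analyzer="word", ngram_range=(1, 2), min_df=0.01)
matrix = vectorizer.fit_transform(iter_texts())
print(matrix.shape)  # (3, <number of distinct 1- and 2-grams>)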
@@ -5,6 +5,7 @@ import pickle
 import re
 import shutil
 import warnings
+from typing import Iterator
 from typing import List
 from typing import Optional
 
@@ -136,21 +137,22 @@ class DocumentClassifier:
 
     def train(self):
 
-        data = []
         labels_tags = []
         labels_correspondent = []
         labels_document_type = []
         labels_storage_path = []
 
+        docs_queryset = Document.objects.order_by("pk").exclude(tags__is_inbox_tag=True)
+
+        if docs_queryset.count() == 0:
+            raise ValueError("No training data available.")
+
         # Step 1: Extract and preprocess training data from the database.
         logger.debug("Gathering data from database...")
         m = hashlib.sha1()
-        for doc in Document.objects.order_by("pk").exclude(
-            tags__is_inbox_tag=True,
-        ):
+        for doc in docs_queryset:
             preprocessed_content = self.preprocess_content(doc.content)
             m.update(preprocessed_content.encode("utf-8"))
-            data.append(preprocessed_content)
 
             y = -1
             dt = doc.document_type
@@ -183,9 +185,6 @@ class DocumentClassifier:
             m.update(y.to_bytes(4, "little", signed=True))
             labels_storage_path.append(y)
 
-        if not data:
-            raise ValueError("No training data available.")
-
         new_data_hash = m.digest()
 
         if self.data_hash and new_data_hash == self.data_hash:
@@ -207,7 +206,7 @@ class DocumentClassifier:
         logger.debug(
             "{} documents, {} tag(s), {} correspondent(s), "
             "{} document type(s). {} storage path(es)".format(
-                len(data),
+                docs_queryset.count(),
                 num_tags,
                 num_correspondents,
                 num_document_types,
@@ -221,12 +220,18 @@ class DocumentClassifier:
 
         # Step 2: vectorize data
         logger.debug("Vectorizing data...")
+
+        def content_generator() -> Iterator[str]:
+            for doc in docs_queryset:
+                yield self.preprocess_content(doc.content)
+
         self.data_vectorizer = CountVectorizer(
             analyzer="word",
             ngram_range=(1, 2),
             min_df=0.01,
         )
-        data_vectorized = self.data_vectorizer.fit_transform(data)
+
+        data_vectorized = self.data_vectorizer.fit_transform(content_generator())
 
         # See the notes here:
         # https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.CountVectorizer.html  # noqa: 501
@@ -341,7 +346,7 @@ class DocumentClassifier:
 
         return content
 
-    def predict_correspondent(self, content):
+    def predict_correspondent(self, content: str):
         if self.correspondent_classifier:
             X = self.data_vectorizer.transform([self.preprocess_content(content)])
             correspondent_id = self.correspondent_classifier.predict(X)
@@ -352,7 +357,7 @@ class DocumentClassifier:
         else:
             return None
 
-    def predict_document_type(self, content):
+    def predict_document_type(self, content: str):
         if self.document_type_classifier:
             X = self.data_vectorizer.transform([self.preprocess_content(content)])
             document_type_id = self.document_type_classifier.predict(X)
@@ -363,7 +368,7 @@ class DocumentClassifier:
         else:
             return None
 
-    def predict_tags(self, content):
+    def predict_tags(self, content: str):
         from sklearn.utils.multiclass import type_of_target
 
         if self.tags_classifier:
@@ -384,7 +389,7 @@ class DocumentClassifier:
         else:
             return []
 
-    def predict_storage_path(self, content):
+    def predict_storage_path(self, content: str):
         if self.storage_path_classifier:
             X = self.data_vectorizer.transform([self.preprocess_content(content)])
             storage_path_id = self.storage_path_classifier.predict(X)
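One trade-off worth noting: the documents are now read from the database twice, once in Step 1 for hashing and label collection and again inside `content_generator()` during vectorization, and each `docs_queryset.count()` call is an extra COUNT query; the commit spends extra database reads to shrink the peak resident set. A rough, standalone way to observe the list-versus-generator difference with synthetic data (not project code; numbers will vary by machine):

import tracemalloc
from typing import Callable, Iterable, Iterator

from sklearn.feature_extraction.text import CountVectorizer


def texts(n: int = 5000) -> Iterator[str]:
    # Synthetic stand-in for preprocessed document contents.
    for i in range(n):
        yield ("invoice payment due reference %d " % i) * 40


def peak_kib(make_input: Callable[[], Iterable[str]]) -> float:
    # Measure peak Python heap allocation during one fit_transform().
    tracemalloc.start()
    CountVectorizer(analyzer="word", ngram_range=(1, 2)).fit_transform(make_input())
    _, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    return peak / 1024


print("list peak KiB:", peak_kib(lambda: list(texts())))  # all strings held at once
print("gen  peak KiB:", peak_kib(texts))                  # one string at a time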