Commit to paperless-ngx (mirror of https://github.com/paperless-ngx/paperless-ngx.git)
Author: Trenton H

Changes classifier training to hold less data in memory at the same time
src/documents/classifier.py

@@ -5,6 +5,7 @@ import pickle
 import re
 import shutil
 import warnings
+from typing import Iterator
 from typing import List
 from typing import Optional
 
@@ -136,21 +137,22 @@ class DocumentClassifier:
 
     def train(self):
 
-        data = []
         labels_tags = []
         labels_correspondent = []
         labels_document_type = []
         labels_storage_path = []
 
+        docs_queryset = Document.objects.order_by("pk").exclude(tags__is_inbox_tag=True)
+
+        if docs_queryset.count() == 0:
+            raise ValueError("No training data available.")
+
         # Step 1: Extract and preprocess training data from the database.
         logger.debug("Gathering data from database...")
         m = hashlib.sha1()
-        for doc in Document.objects.order_by("pk").exclude(
-            tags__is_inbox_tag=True,
-        ):
+        for doc in docs_queryset:
            preprocessed_content = self.preprocess_content(doc.content)
             m.update(preprocessed_content.encode("utf-8"))
-            data.append(preprocessed_content)
 
             y = -1
             dt = doc.document_type
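Note: the hunk above can drop the data list for hashing because hashlib digests are built incrementally; feeding each document's bytes to m.update() as the loop runs produces the same digest as hashing everything at once. A standalone sketch (the sample strings and labels are made up):

    import hashlib

    docs = ["first document", "second document"]
    labels = [3, -1]

    # Incremental updates, one document at a time...
    m = hashlib.sha1()
    for text, label in zip(docs, labels):
        m.update(text.encode("utf-8"))
        m.update(label.to_bytes(4, "little", signed=True))

    # ...match a single hash over the concatenated bytes.
    whole = b"".join(
        t.encode("utf-8") + n.to_bytes(4, "little", signed=True)
        for t, n in zip(docs, labels)
    )
    assert m.digest() == hashlib.sha1(whole).digest()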
@@ -183,9 +185,6 @@ class DocumentClassifier:
             m.update(y.to_bytes(4, "little", signed=True))
             labels_storage_path.append(y)
 
-        if not data:
-            raise ValueError("No training data available.")
-
         new_data_hash = m.digest()
 
         if self.data_hash and new_data_hash == self.data_hash:
@@ -207,7 +206,7 @@ class DocumentClassifier:
         logger.debug(
             "{} documents, {} tag(s), {} correspondent(s), "
             "{} document type(s). {} storage path(es)".format(
-                len(data),
+                docs_queryset.count(),
                 num_tags,
                 num_correspondents,
                 num_document_types,
@@ -221,12 +220,18 @@ class DocumentClassifier:
 
         # Step 2: vectorize data
         logger.debug("Vectorizing data...")
+
+        def content_generator() -> Iterator[str]:
+            for doc in docs_queryset:
+                yield self.preprocess_content(doc.content)
+
         self.data_vectorizer = CountVectorizer(
             analyzer="word",
             ngram_range=(1, 2),
             min_df=0.01,
         )
-        data_vectorized = self.data_vectorizer.fit_transform(data)
+        data_vectorized = self.data_vectorizer.fit_transform(content_generator())
+
 
         # See the notes here:
         # https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.CountVectorizer.html  # noqa: 501
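Note: this hunk carries the memory saving. CountVectorizer.fit_transform accepts any iterable of strings, so handing it a generator lets scikit-learn pull documents from the queryset one at a time instead of from a pre-built list; the tradeoff is that each document is preprocessed twice (once in the hash/labels loop, once here). A standalone sketch with made-up texts:

    from typing import Iterator

    from sklearn.feature_extraction.text import CountVectorizer

    texts = [
        "invoice for march utilities",
        "invoice for april utilities",
        "tax statement for the year",
    ]

    def content_generator() -> Iterator[str]:
        # Yield one document at a time; the full corpus never sits in a list.
        for text in texts:
            yield text

    vectorizer = CountVectorizer(analyzer="word", ngram_range=(1, 2), min_df=0.01)
    data_vectorized = vectorizer.fit_transform(content_generator())
    print(data_vectorized.shape)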
@@ -341,7 +346,7 @@
 
         return content
 
-    def predict_correspondent(self, content):
+    def predict_correspondent(self, content: str):
         if self.correspondent_classifier:
             X = self.data_vectorizer.transform([self.preprocess_content(content)])
             correspondent_id = self.correspondent_classifier.predict(X)
@@ -352,7 +357,7 @@
         else:
             return None
 
-    def predict_document_type(self, content):
+    def predict_document_type(self, content: str):
         if self.document_type_classifier:
             X = self.data_vectorizer.transform([self.preprocess_content(content)])
             document_type_id = self.document_type_classifier.predict(X)
@@ -363,7 +368,7 @@
         else:
             return None
 
-    def predict_tags(self, content):
+    def predict_tags(self, content: str):
         from sklearn.utils.multiclass import type_of_target
 
         if self.tags_classifier:
@@ -384,7 +389,7 @@
         else:
             return []
 
-    def predict_storage_path(self, content):
+    def predict_storage_path(self, content: str):
         if self.storage_path_classifier:
             X = self.data_vectorizer.transform([self.preprocess_content(content)])
             storage_path_id = self.storage_path_classifier.predict(X)
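Note: the four predict_* methods share one shape: transform a single preprocessed document with the already-fitted vectorizer, then ask the per-field classifier for an id. A standalone sketch of that round trip (LogisticRegression and the toy corpus are illustrative stand-ins, not the project's actual models):

    from sklearn.feature_extraction.text import CountVectorizer
    from sklearn.linear_model import LogisticRegression

    # Toy corpus with one label per document (e.g. a correspondent id).
    texts = ["invoice from acme", "acme invoice overdue", "letter from bob"]
    labels = [1, 1, 2]

    vectorizer = CountVectorizer(analyzer="word", ngram_range=(1, 2))
    X_train = vectorizer.fit_transform(texts)

    classifier = LogisticRegression().fit(X_train, labels)

    # Mirrors predict_correspondent(): transform() takes a one-element list,
    # predict() returns an array holding a single id.
    X = vectorizer.transform(["new invoice from acme"])
    print(classifier.predict(X)[0])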