Mirror of https://github.com/paperless-ngx/paperless-ngx.git
	don't load sklearn libraries unless needed
src/documents/classifier.py
@@ -5,10 +5,6 @@ import pickle
 import re
 
 from django.conf import settings
-from sklearn.feature_extraction.text import CountVectorizer
-from sklearn.neural_network import MLPClassifier
-from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
-from sklearn.utils.multiclass import type_of_target
 
 from documents.models import Document, MatchingModel
 
@@ -109,6 +105,10 @@ class DocumentClassifier(object):
             pickle.dump(self.document_type_classifier, f)
 
     def train(self):
+        from sklearn.feature_extraction.text import CountVectorizer
+        from sklearn.neural_network import MLPClassifier
+        from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
+
         data = list()
         labels_tags = list()
         labels_correspondent = list()
@@ -265,6 +265,8 @@ class DocumentClassifier(object):
             return None
 
     def predict_tags(self, content):
+        from sklearn.utils.multiclass import type_of_target
+
         if self.tags_classifier:
             X = self.data_vectorizer.transform([preprocess_content(content)])
             y = self.tags_classifier.predict(X)
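The change above defers the scikit-learn imports from module scope into the methods that actually use them, so importing documents.classifier stays cheap unless a model is trained or queried. Below is a minimal sketch of that lazy-import pattern with invented names; it illustrates the technique only and is not the real DocumentClassifier.

class LazyClassifier:
    """Illustrative only: defers heavy imports until they are needed."""

    def __init__(self):
        self.vectorizer = None
        self.model = None

    def train(self, texts, labels):
        # scikit-learn is imported only when training actually runs, so code
        # paths that never train or predict never pay the import cost.
        from sklearn.feature_extraction.text import CountVectorizer
        from sklearn.neural_network import MLPClassifier

        self.vectorizer = CountVectorizer()
        X = self.vectorizer.fit_transform(texts)
        self.model = MLPClassifier(max_iter=300)
        self.model.fit(X, labels)

    def predict(self, text):
        # The fitted objects created in train() are reused here; no fresh
        # sklearn import is needed at prediction time in this sketch.
        X = self.vectorizer.transform([text])
        return self.model.predict(X)[0]

The trade-off is that a missing scikit-learn installation now surfaces as an ImportError on the first call that needs it, rather than at module import time.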
							
								
								
									
										
BIN  src/documents/tests/data/model.pickle  (new file; binary file not shown)
src/documents/tests/test_classifier.py
@@ -130,6 +130,15 @@ class TestClassifier(DirectoriesMixin, TestCase):
         new_classifier.reload()
         self.assertFalse(new_classifier.train())
 
+    @override_settings(MODEL_FILE=os.path.join(os.path.dirname(__file__), "data", "model.pickle"))
+    def test_load_and_classify(self):
+        self.generate_test_data()
+
+        new_classifier = DocumentClassifier()
+        new_classifier.reload()
+
+        self.assertCountEqual(new_classifier.predict_tags(self.doc2.content), [45, 12])
+
     def test_one_correspondent_predict(self):
         c1 = Correspondent.objects.create(name="c1", matching_algorithm=Correspondent.MATCH_AUTO)
         doc1 = Document.objects.create(title="doc1", content="this is a document from c1", correspondent=c1, checksum="A")
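The new test pins MODEL_FILE to the pickled fixture committed alongside this change and checks that a classifier reloaded from disk still predicts tag IDs 45 and 12 for doc2, guarding the lazy-import refactor against breaking the loading of previously saved models. Django's override_settings decorator scopes the setting to this one test; here is a minimal sketch of that mechanism, assuming a configured Django test environment (the path and test names below are made up for the example):

import os

from django.conf import settings
from django.test import TestCase, override_settings


class OverrideSettingsSketch(TestCase):
    # The override applies only while this test runs; other tests keep
    # reading the project's real MODEL_FILE value.
    @override_settings(MODEL_FILE=os.path.join("/tmp", "example-model.pickle"))
    def test_reads_overridden_path(self):
        self.assertTrue(settings.MODEL_FILE.endswith("example-model.pickle"))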
src/paperless/settings.py
@@ -4,7 +4,6 @@ import multiprocessing
 import os
 import re
 
-import dateparser
 from dotenv import load_dotenv
 
 from django.utils.translation import gettext_lazy as _
@@ -491,6 +490,10 @@ if PAPERLESS_TIKA_ENABLED:
 
 # List dates that should be ignored when trying to parse date from document text
 IGNORE_DATES = set()
+
+if os.getenv("PAPERLESS_IGNORE_DATES", ""):
+    import dateparser
+
     for s in os.getenv("PAPERLESS_IGNORE_DATES", "").split(","):
         d = dateparser.parse(s)
         if d:
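In the same spirit, dateparser is now imported only when PAPERLESS_IGNORE_DATES is actually set, so a default configuration never loads it. A standalone sketch of this env-gated import pattern follows; the helper name is invented for the example.

import os


def _parse_ignore_dates(raw):
    # dateparser is imported lazily: configurations that never set the
    # environment variable never load the library at all.
    import dateparser

    dates = set()
    for s in raw.split(","):
        d = dateparser.parse(s)
        if d:
            dates.add(d.date())
    return dates


IGNORE_DATES = set()
if os.getenv("PAPERLESS_IGNORE_DATES", ""):
    IGNORE_DATES = _parse_ignore_dates(os.environ["PAPERLESS_IGNORE_DATES"])

For example, PAPERLESS_IGNORE_DATES=2020-01-01,2020-12-24 would populate IGNORE_DATES with those two dates, matching the comma-separated parsing shown in the diff.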
Author: jonaswinkler