Mirror of https://github.com/paperless-ngx/paperless-ngx.git (synced 2025-10-30 03:56:23 -05:00)
Added code that trains models based on data from the database
New file: src/documents/management/commands/document_create_classifier.py (executable file, 100 lines added)
@@ -0,0 +1,100 @@
+import logging
+import os.path
+import pickle
+
+from django.core.management.base import BaseCommand
+from sklearn.feature_extraction.text import CountVectorizer
+from sklearn.multiclass import OneVsRestClassifier
+from sklearn.naive_bayes import MultinomialNB
+from sklearn.preprocessing import MultiLabelBinarizer, LabelEncoder
+
+from documents.models import Document
+from ...mixins import Renderable
+
+
+def preprocess_content(content):
+    content = content.lower()
+    content = content.strip()
+    content = content.replace("\n", " ")
+    content = content.replace("\r", " ")
+    while content.find("  ") > -1:
+        content = content.replace("  ", " ")
+    return content
+
+
+class Command(Renderable, BaseCommand):
+
+    help = """
+        There is no help.
+    """.replace("    ", "")
+
+    def __init__(self, *args, **kwargs):
+        BaseCommand.__init__(self, *args, **kwargs)
+
+    def handle(self, *args, **options):
+        data = list()
+        labels_tags = list()
+        labels_correspondent = list()
+        labels_type = list()
+
+        # Step 1: Extract and preprocess training data from the database.
+        logging.getLogger(__name__).info("Gathering data from database...")
+        for doc in Document.objects.exclude(tags__is_inbox_tag=True):
+            data.append(preprocess_content(doc.content))
+            labels_type.append(doc.document_type.name if doc.document_type is not None else "-")
+            labels_correspondent.append(doc.correspondent.name if doc.correspondent is not None else "-")
+            tags = [tag.name for tag in doc.tags.all()]
+            labels_tags.append(tags)
+
+        # Step 2: vectorize data
+        logging.getLogger(__name__).info("Vectorizing data...")
+        data_vectorizer = CountVectorizer(analyzer='char', ngram_range=(1, 5), min_df=0.05)
+        data_vectorized = data_vectorizer.fit_transform(data)
+
+        tags_binarizer = MultiLabelBinarizer()
+        labels_tags_vectorized = tags_binarizer.fit_transform(labels_tags)
+
+        correspondent_binarizer = LabelEncoder()
+        labels_correspondent_vectorized = correspondent_binarizer.fit_transform(labels_correspondent)
+
+        type_binarizer = LabelEncoder()
+        labels_type_vectorized = type_binarizer.fit_transform(labels_type)
+
+        # Step 3: train the classifiers
+        if len(tags_binarizer.classes_) > 0:
+            logging.getLogger(__name__).info("Training tags classifier")
+            tags_classifier = OneVsRestClassifier(MultinomialNB())
+            tags_classifier.fit(data_vectorized, labels_tags_vectorized)
+        else:
+            tags_classifier = None
+            logging.getLogger(__name__).info("There are no tags. Not training tags classifier.")
+
+        if len(correspondent_binarizer.classes_) > 0:
+            logging.getLogger(__name__).info("Training correspondent classifier")
+            correspondent_classifier = MultinomialNB()
+            correspondent_classifier.fit(data_vectorized, labels_correspondent_vectorized)
+        else:
+            correspondent_classifier = None
+            logging.getLogger(__name__).info("There are no correspondents. Not training correspondent classifier.")
+
+        if len(type_binarizer.classes_) > 0:
+            logging.getLogger(__name__).info("Training document type classifier")
+            type_classifier = MultinomialNB()
+            type_classifier.fit(data_vectorized, labels_type_vectorized)
+        else:
+            type_classifier = None
+            logging.getLogger(__name__).info("There are no document types. Not training document type classifier.")
+
+        models_root = os.path.abspath(os.path.join(os.path.dirname(__name__), "..", "models", "models.pickle"))
+        logging.getLogger(__name__).info("Saving models to " + models_root + "...")
+
+        with open(models_root, "wb") as f:
+            pickle.dump(data_vectorizer, f)
+
+            pickle.dump(tags_binarizer, f)
+            pickle.dump(correspondent_binarizer, f)
+            pickle.dump(type_binarizer, f)
+
+            pickle.dump(tags_classifier, f)
+            pickle.dump(correspondent_classifier, f)
+            pickle.dump(type_classifier, f)
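
The command above pickles seven objects into a single file in a fixed order: the vectorizer, the three label binarizers/encoders, and the three classifiers (any of which may be None). A consumer has to unpickle them in exactly that order. Below is a minimal sketch of such a consumer; it is not part of this commit, and the path "models.pickle", the variable new_document_text, and the inlined text normalization are assumptions for illustration.

# Hypothetical sketch (not part of this commit): load the pickled objects in
# the same order they were dumped, then classify the text of a new document.
import pickle

with open("models.pickle", "rb") as f:          # assumed path to the saved file
    data_vectorizer = pickle.load(f)
    tags_binarizer = pickle.load(f)
    correspondent_binarizer = pickle.load(f)
    type_binarizer = pickle.load(f)
    tags_classifier = pickle.load(f)
    correspondent_classifier = pickle.load(f)
    type_classifier = pickle.load(f)

new_document_text = "Example invoice text ..."  # placeholder document content

# Same normalization idea as preprocess_content above: lowercase and collapse whitespace.
text = " ".join(new_document_text.lower().split())

# Reuse the character n-gram vocabulary learned during training.
X = data_vectorizer.transform([text])

if tags_classifier is not None:
    # OneVsRestClassifier.predict returns a 0/1 indicator matrix;
    # MultiLabelBinarizer.inverse_transform maps it back to tag names.
    predicted_tags = tags_binarizer.inverse_transform(tags_classifier.predict(X))[0]

if correspondent_classifier is not None:
    # MultinomialNB.predict returns encoded class indices;
    # LabelEncoder.inverse_transform maps them back to names ("-" means no correspondent).
    predicted_correspondent = correspondent_binarizer.inverse_transform(
        correspondent_classifier.predict(X))[0]

if type_classifier is not None:
    predicted_type = type_binarizer.inverse_transform(type_classifier.predict(X))[0]
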
@@ -1,11 +1,19 @@
 from collections import Counter
 
 from django.core.management.base import BaseCommand
 
-from documents.models import Document, Tag
+from documents.models import Document
 from ...mixins import Renderable
 
 
+def preprocess_content(content):
+    content = content.lower()
+    content = content.strip()
+    content = content.replace("\n", " ")
+    content = content.replace("\r", " ")
+    while content.find("  ") > -1:
+        content = content.replace("  ", " ")
+    return content
 
 
 class Command(Renderable, BaseCommand):
 
     help = """
@@ -15,15 +23,6 @@ class Command(Renderable, BaseCommand):
     def __init__(self, *args, **kwargs):
         BaseCommand.__init__(self, *args, **kwargs)
 
-    def preprocess_content(self, content):
-        content = content.lower()
-        content = content.strip()
-        content = content.replace("\n", " ")
-        content = content.replace("\r", " ")
-        while content.find("  ") > -1:
-            content = content.replace("  ", " ")
-        return content
-
     def handle(self, *args, **options):
         with open("dataset_tags.txt", "w") as f:
             for doc in Document.objects.exclude(tags__is_inbox_tag=True):
@@ -32,19 +31,19 @@ class Command(Renderable, BaseCommand):
                     labels.append(tag.name)
                 f.write(",".join(labels))
                 f.write(";")
-                f.write(self.preprocess_content(doc.content))
+                f.write(preprocess_content(doc.content))
                 f.write("\n")
 
         with open("dataset_types.txt", "w") as f:
             for doc in Document.objects.exclude(tags__is_inbox_tag=True):
                 f.write(doc.document_type.name if doc.document_type is not None else "None")
                 f.write(";")
-                f.write(self.preprocess_content(doc.content))
+                f.write(preprocess_content(doc.content))
                 f.write("\n")
 
         with open("dataset_correspondents.txt", "w") as f:
             for doc in Document.objects.exclude(tags__is_inbox_tag=True):
                 f.write(doc.correspondent.name if doc.correspondent is not None else "None")
                 f.write(";")
-                f.write(self.preprocess_content(doc.content))
+                f.write(preprocess_content(doc.content))
                 f.write("\n")
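
For reference, each line the export command writes has the form labels;content: in dataset_tags.txt the first field is a comma-separated list of tag names (possibly empty), in the other two files it is a single name or "None", and the content is the normalized document text. The following is a minimal sketch of reading dataset_tags.txt back in; the helper name read_tag_dataset is hypothetical, and it assumes tag names contain neither ";" nor "," since the command does not escape them.

# Hypothetical sketch (not part of this commit): parse dataset_tags.txt as written above.
def read_tag_dataset(path="dataset_tags.txt"):
    samples = []
    with open(path) as f:
        for line in f:
            # Split on the first ";" only, because the document content itself may contain ";".
            labels_part, _, content = line.rstrip("\n").partition(";")
            tags = labels_part.split(",") if labels_part else []
            samples.append((tags, content))
    return samples
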
Author: Jonas Winkler