mirror of
				https://github.com/paperless-ngx/paperless-ngx.git
				synced 2025-10-30 03:56:23 -05:00 
			
		
		
		
	centralized classifier loading, better error handling, no error messages when auto matching is not used
This commit is contained in:
		| @@ -26,6 +26,34 @@ def preprocess_content(content): | ||||
|     return content | ||||
|  | ||||
|  | ||||
| def load_classifier(): | ||||
|     if not os.path.isfile(settings.MODEL_FILE): | ||||
|         logger.debug( | ||||
|             f"Document classification model does not exist (yet), not " | ||||
|             f"performing automatic matching." | ||||
|         ) | ||||
|         return None | ||||
|  | ||||
|     try: | ||||
|         classifier = DocumentClassifier() | ||||
|         classifier.reload() | ||||
|     except (EOFError, IncompatibleClassifierVersionError) as e: | ||||
|         # there's something wrong with the model file. | ||||
|         logger.error( | ||||
|             f"Unrecoverable error while loading document classification model: " | ||||
|             f"{str(e)}, deleting model file." | ||||
|         ) | ||||
|         os.unlink(settings.MODEL_FILE) | ||||
|         classifier = None | ||||
|     except OSError as e: | ||||
|         logger.error( | ||||
|             f"Error while loading document classification model: {str(e)}" | ||||
|         ) | ||||
|         classifier = None | ||||
|  | ||||
|     return classifier | ||||
|  | ||||
|  | ||||
| class DocumentClassifier(object): | ||||
|  | ||||
|     FORMAT_VERSION = 6 | ||||
|   | ||||
| @@ -11,7 +11,7 @@ from django.utils import timezone | ||||
| from filelock import FileLock | ||||
| from rest_framework.reverse import reverse | ||||
|  | ||||
| from .classifier import DocumentClassifier, IncompatibleClassifierVersionError | ||||
| from .classifier import load_classifier | ||||
| from .file_handling import create_source_path_directory, \ | ||||
|     generate_unique_filename | ||||
| from .loggers import LoggingMixin | ||||
| @@ -201,14 +201,7 @@ class Consumer(LoggingMixin): | ||||
|         #   reloading the classifier multiple times, since there are multiple | ||||
|         #   post-consume hooks that all require the classifier. | ||||
|  | ||||
|         try: | ||||
|             classifier = DocumentClassifier() | ||||
|             classifier.reload() | ||||
|         except (OSError, EOFError, IncompatibleClassifierVersionError) as e: | ||||
|             self.log( | ||||
|                 "warning", | ||||
|                 f"Cannot classify documents: {e}.") | ||||
|             classifier = None | ||||
|         classifier = load_classifier() | ||||
|  | ||||
|         # now that everything is done, we can start to store the document | ||||
|         # in the system. This will be a transaction and reasonably fast. | ||||
|   | ||||
| @@ -2,8 +2,7 @@ import logging | ||||
|  | ||||
| from django.core.management.base import BaseCommand | ||||
|  | ||||
| from documents.classifier import DocumentClassifier, \ | ||||
|     IncompatibleClassifierVersionError | ||||
| from documents.classifier import load_classifier | ||||
| from documents.models import Document | ||||
| from ...mixins import Renderable | ||||
| from ...signals.handlers import set_correspondent, set_document_type, set_tags | ||||
| @@ -70,13 +69,7 @@ class Command(Renderable, BaseCommand): | ||||
|             queryset = Document.objects.all() | ||||
|         documents = queryset.distinct() | ||||
|  | ||||
|         classifier = DocumentClassifier() | ||||
|         try: | ||||
|             classifier.reload() | ||||
|         except (OSError, EOFError, IncompatibleClassifierVersionError) as e: | ||||
|             logging.getLogger(__name__).warning( | ||||
|                 f"Cannot classify documents: {e}.") | ||||
|             classifier = None | ||||
|         classifier = load_classifier() | ||||
|  | ||||
|         for document in documents: | ||||
|             logging.getLogger(__name__).info( | ||||
|   | ||||
| @@ -6,10 +6,9 @@ from django.db.models.signals import post_save | ||||
| from whoosh.writing import AsyncWriter | ||||
|  | ||||
| from documents import index, sanity_checker | ||||
| from documents.classifier import DocumentClassifier, \ | ||||
|     IncompatibleClassifierVersionError | ||||
| from documents.classifier import DocumentClassifier, load_classifier | ||||
| from documents.consumer import Consumer, ConsumerError | ||||
| from documents.models import Document | ||||
| from documents.models import Document, Tag, DocumentType, Correspondent | ||||
| from documents.sanity_checker import SanityFailedError | ||||
|  | ||||
|  | ||||
| @@ -30,13 +29,18 @@ def index_reindex(): | ||||
|  | ||||
|  | ||||
| def train_classifier(): | ||||
|     classifier = DocumentClassifier() | ||||
|     if (not Tag.objects.filter( | ||||
|                 matching_algorithm=Tag.MATCH_AUTO).exists() and | ||||
|         not DocumentType.objects.filter( | ||||
|             matching_algorithm=Tag.MATCH_AUTO).exists() and | ||||
|         not Correspondent.objects.filter( | ||||
|             matching_algorithm=Tag.MATCH_AUTO).exists()): | ||||
|  | ||||
|     try: | ||||
|         # load the classifier, since we might not have to train it again. | ||||
|         classifier.reload() | ||||
|     except (OSError, EOFError, IncompatibleClassifierVersionError): | ||||
|         # This is what we're going to fix here. | ||||
|         return | ||||
|  | ||||
|     classifier = load_classifier() | ||||
|  | ||||
|     if not classifier: | ||||
|         classifier = DocumentClassifier() | ||||
|  | ||||
|     try: | ||||
| @@ -52,7 +56,7 @@ def train_classifier(): | ||||
|             ) | ||||
|  | ||||
|     except Exception as e: | ||||
|         logging.getLogger(__name__).error( | ||||
|         logging.getLogger(__name__).warning( | ||||
|             "Classifier error: " + str(e) | ||||
|         ) | ||||
|  | ||||
|   | ||||
| @@ -1,10 +1,13 @@ | ||||
| import os | ||||
| import tempfile | ||||
| from pathlib import Path | ||||
| from time import sleep | ||||
| from unittest import mock | ||||
|  | ||||
| from django.conf import settings | ||||
| from django.test import TestCase, override_settings | ||||
|  | ||||
| from documents.classifier import DocumentClassifier, IncompatibleClassifierVersionError | ||||
| from documents.classifier import DocumentClassifier, IncompatibleClassifierVersionError, load_classifier | ||||
| from documents.models import Correspondent, Document, Tag, DocumentType | ||||
| from documents.tests.utils import DirectoriesMixin | ||||
|  | ||||
| @@ -235,3 +238,30 @@ class TestClassifier(DirectoriesMixin, TestCase): | ||||
|         self.classifier.train() | ||||
|         self.assertListEqual(self.classifier.predict_tags(doc1.content), [t1.pk]) | ||||
|         self.assertListEqual(self.classifier.predict_tags(doc2.content), []) | ||||
|  | ||||
|     def test_load_classifier_not_exists(self): | ||||
|         self.assertFalse(os.path.exists(settings.MODEL_FILE)) | ||||
|         self.assertIsNone(load_classifier()) | ||||
|  | ||||
|     @mock.patch("documents.classifier.DocumentClassifier.reload") | ||||
|     def test_load_classifier(self, reload): | ||||
|         Path(settings.MODEL_FILE).touch() | ||||
|         self.assertIsNotNone(load_classifier()) | ||||
|  | ||||
|     @mock.patch("documents.classifier.DocumentClassifier.reload") | ||||
|     def test_load_classifier_incompatible_version(self, reload): | ||||
|         Path(settings.MODEL_FILE).touch() | ||||
|         self.assertTrue(os.path.exists(settings.MODEL_FILE)) | ||||
|  | ||||
|         reload.side_effect = IncompatibleClassifierVersionError() | ||||
|         self.assertIsNone(load_classifier()) | ||||
|         self.assertFalse(os.path.exists(settings.MODEL_FILE)) | ||||
|  | ||||
|     @mock.patch("documents.classifier.DocumentClassifier.reload") | ||||
|     def test_load_classifier_os_error(self, reload): | ||||
|         Path(settings.MODEL_FILE).touch() | ||||
|         self.assertTrue(os.path.exists(settings.MODEL_FILE)) | ||||
|  | ||||
|         reload.side_effect = OSError() | ||||
|         self.assertIsNone(load_classifier()) | ||||
|         self.assertTrue(os.path.exists(settings.MODEL_FILE)) | ||||
|   | ||||
| @@ -420,7 +420,7 @@ class TestConsumer(DirectoriesMixin, TestCase): | ||||
|         self.assertIsNotNone(os.path.isfile(document.title)) | ||||
|         self.assertTrue(os.path.isfile(document.source_path)) | ||||
|  | ||||
|     @mock.patch("documents.consumer.DocumentClassifier") | ||||
|     @mock.patch("documents.consumer.load_classifier") | ||||
|     def testClassifyDocument(self, m): | ||||
|         correspondent = Correspondent.objects.create(name="test") | ||||
|         dtype = DocumentType.objects.create(name="test") | ||||
|   | ||||
| @@ -1,11 +1,12 @@ | ||||
| from datetime import datetime | ||||
| import os | ||||
| from unittest import mock | ||||
|  | ||||
| from django.conf import settings | ||||
| from django.test import TestCase | ||||
| from django.utils import timezone | ||||
|  | ||||
| from documents import tasks | ||||
| from documents.models import Document | ||||
| from documents.models import Document, Tag, Correspondent, DocumentType | ||||
| from documents.sanity_checker import SanityError, SanityFailedError | ||||
| from documents.tests.utils import DirectoriesMixin | ||||
|  | ||||
| @@ -22,8 +23,55 @@ class TestTasks(DirectoriesMixin, TestCase): | ||||
|  | ||||
|         tasks.index_optimize() | ||||
|  | ||||
|     def test_train_classifier(self): | ||||
|     @mock.patch("documents.tasks.load_classifier") | ||||
|     def test_train_classifier_no_auto_matching(self, load_classifier): | ||||
|         tasks.train_classifier() | ||||
|         load_classifier.assert_not_called() | ||||
|  | ||||
|     @mock.patch("documents.tasks.load_classifier") | ||||
|     def test_train_classifier_with_auto_tag(self, load_classifier): | ||||
|         load_classifier.return_value = None | ||||
|         Tag.objects.create(matching_algorithm=Tag.MATCH_AUTO, name="test") | ||||
|         tasks.train_classifier() | ||||
|         load_classifier.assert_called_once() | ||||
|         self.assertFalse(os.path.isfile(settings.MODEL_FILE)) | ||||
|  | ||||
|     @mock.patch("documents.tasks.load_classifier") | ||||
|     def test_train_classifier_with_auto_type(self, load_classifier): | ||||
|         load_classifier.return_value = None | ||||
|         DocumentType.objects.create(matching_algorithm=Tag.MATCH_AUTO, name="test") | ||||
|         tasks.train_classifier() | ||||
|         load_classifier.assert_called_once() | ||||
|         self.assertFalse(os.path.isfile(settings.MODEL_FILE)) | ||||
|  | ||||
|     @mock.patch("documents.tasks.load_classifier") | ||||
|     def test_train_classifier_with_auto_correspondent(self, load_classifier): | ||||
|         load_classifier.return_value = None | ||||
|         Correspondent.objects.create(matching_algorithm=Tag.MATCH_AUTO, name="test") | ||||
|         tasks.train_classifier() | ||||
|         load_classifier.assert_called_once() | ||||
|         self.assertFalse(os.path.isfile(settings.MODEL_FILE)) | ||||
|  | ||||
|     def test_train_classifier(self): | ||||
|         c = Correspondent.objects.create(matching_algorithm=Tag.MATCH_AUTO, name="test") | ||||
|         doc = Document.objects.create(correspondent=c, content="test", title="test") | ||||
|         self.assertFalse(os.path.isfile(settings.MODEL_FILE)) | ||||
|  | ||||
|         tasks.train_classifier() | ||||
|         self.assertTrue(os.path.isfile(settings.MODEL_FILE)) | ||||
|         mtime = os.stat(settings.MODEL_FILE).st_mtime | ||||
|  | ||||
|         tasks.train_classifier() | ||||
|         self.assertTrue(os.path.isfile(settings.MODEL_FILE)) | ||||
|         mtime2 = os.stat(settings.MODEL_FILE).st_mtime | ||||
|         self.assertEqual(mtime, mtime2) | ||||
|  | ||||
|         doc.content = "test2" | ||||
|         doc.save() | ||||
|         tasks.train_classifier() | ||||
|         self.assertTrue(os.path.isfile(settings.MODEL_FILE)) | ||||
|         mtime3 = os.stat(settings.MODEL_FILE).st_mtime | ||||
|         self.assertNotEqual(mtime2, mtime3) | ||||
|  | ||||
|     @mock.patch("documents.tasks.sanity_checker.check_sanity") | ||||
|     def test_sanity_check(self, m): | ||||
|   | ||||
| @@ -34,7 +34,7 @@ from rest_framework.viewsets import ( | ||||
| import documents.index as index | ||||
| from paperless.db import GnuPG | ||||
| from paperless.views import StandardPagination | ||||
| from .classifier import DocumentClassifier, IncompatibleClassifierVersionError | ||||
| from .classifier import load_classifier | ||||
| from .filters import ( | ||||
|     CorrespondentFilterSet, | ||||
|     DocumentFilterSet, | ||||
| @@ -259,15 +259,7 @@ class DocumentViewSet(RetrieveModelMixin, | ||||
|         except Document.DoesNotExist: | ||||
|             raise Http404() | ||||
|  | ||||
|         try: | ||||
|             classifier = DocumentClassifier() | ||||
|             classifier.reload() | ||||
|         except (OSError, EOFError, IncompatibleClassifierVersionError) as e: | ||||
|             logging.getLogger(__name__).warning( | ||||
|                 "Cannot load classifier: Not providing auto matching " | ||||
|                 "suggestions" | ||||
|             ) | ||||
|             classifier = None | ||||
|         classifier = load_classifier() | ||||
|  | ||||
|         return Response({ | ||||
|             "correspondents": [ | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 jonaswinkler
					jonaswinkler