From 87a18eae2d37b1c05052aafb7c4142a13ad10892 Mon Sep 17 00:00:00 2001 From: jonaswinkler Date: Sat, 30 Jan 2021 14:22:23 +0100 Subject: [PATCH] centralized classifier loading, better error handling, no error messages when auto matching is not used --- src/documents/classifier.py | 28 ++++++++++ src/documents/consumer.py | 11 +--- .../management/commands/document_retagger.py | 11 +--- src/documents/tasks.py | 24 +++++---- src/documents/tests/test_classifier.py | 32 ++++++++++- src/documents/tests/test_consumer.py | 2 +- src/documents/tests/test_tasks.py | 54 +++++++++++++++++-- src/documents/views.py | 12 +---- 8 files changed, 131 insertions(+), 43 deletions(-) diff --git a/src/documents/classifier.py b/src/documents/classifier.py index 60c9abeec..41cd05412 100755 --- a/src/documents/classifier.py +++ b/src/documents/classifier.py @@ -26,6 +26,34 @@ def preprocess_content(content): return content +def load_classifier(): + if not os.path.isfile(settings.MODEL_FILE): + logger.debug( + f"Document classification model does not exist (yet), not " + f"performing automatic matching." + ) + return None + + try: + classifier = DocumentClassifier() + classifier.reload() + except (EOFError, IncompatibleClassifierVersionError) as e: + # there's something wrong with the model file. + logger.error( + f"Unrecoverable error while loading document classification model: " + f"{str(e)}, deleting model file." + ) + os.unlink(settings.MODEL_FILE) + classifier = None + except OSError as e: + logger.error( + f"Error while loading document classification model: {str(e)}" + ) + classifier = None + + return classifier + + class DocumentClassifier(object): FORMAT_VERSION = 6 diff --git a/src/documents/consumer.py b/src/documents/consumer.py index 5418e3b59..5e76ad03a 100755 --- a/src/documents/consumer.py +++ b/src/documents/consumer.py @@ -11,7 +11,7 @@ from django.utils import timezone from filelock import FileLock from rest_framework.reverse import reverse -from .classifier import DocumentClassifier, IncompatibleClassifierVersionError +from .classifier import load_classifier from .file_handling import create_source_path_directory, \ generate_unique_filename from .loggers import LoggingMixin @@ -201,14 +201,7 @@ class Consumer(LoggingMixin): # reloading the classifier multiple times, since there are multiple # post-consume hooks that all require the classifier. - try: - classifier = DocumentClassifier() - classifier.reload() - except (OSError, EOFError, IncompatibleClassifierVersionError) as e: - self.log( - "warning", - f"Cannot classify documents: {e}.") - classifier = None + classifier = load_classifier() # now that everything is done, we can start to store the document # in the system. This will be a transaction and reasonably fast. diff --git a/src/documents/management/commands/document_retagger.py b/src/documents/management/commands/document_retagger.py index 0fb9782c1..b2f5d8918 100755 --- a/src/documents/management/commands/document_retagger.py +++ b/src/documents/management/commands/document_retagger.py @@ -2,8 +2,7 @@ import logging from django.core.management.base import BaseCommand -from documents.classifier import DocumentClassifier, \ - IncompatibleClassifierVersionError +from documents.classifier import load_classifier from documents.models import Document from ...mixins import Renderable from ...signals.handlers import set_correspondent, set_document_type, set_tags @@ -70,13 +69,7 @@ class Command(Renderable, BaseCommand): queryset = Document.objects.all() documents = queryset.distinct() - classifier = DocumentClassifier() - try: - classifier.reload() - except (OSError, EOFError, IncompatibleClassifierVersionError) as e: - logging.getLogger(__name__).warning( - f"Cannot classify documents: {e}.") - classifier = None + classifier = load_classifier() for document in documents: logging.getLogger(__name__).info( diff --git a/src/documents/tasks.py b/src/documents/tasks.py index 38ff532b5..4e74d7350 100644 --- a/src/documents/tasks.py +++ b/src/documents/tasks.py @@ -6,10 +6,9 @@ from django.db.models.signals import post_save from whoosh.writing import AsyncWriter from documents import index, sanity_checker -from documents.classifier import DocumentClassifier, \ - IncompatibleClassifierVersionError +from documents.classifier import DocumentClassifier, load_classifier from documents.consumer import Consumer, ConsumerError -from documents.models import Document +from documents.models import Document, Tag, DocumentType, Correspondent from documents.sanity_checker import SanityFailedError @@ -30,13 +29,18 @@ def index_reindex(): def train_classifier(): - classifier = DocumentClassifier() + if (not Tag.objects.filter( + matching_algorithm=Tag.MATCH_AUTO).exists() and + not DocumentType.objects.filter( + matching_algorithm=Tag.MATCH_AUTO).exists() and + not Correspondent.objects.filter( + matching_algorithm=Tag.MATCH_AUTO).exists()): - try: - # load the classifier, since we might not have to train it again. - classifier.reload() - except (OSError, EOFError, IncompatibleClassifierVersionError): - # This is what we're going to fix here. + return + + classifier = load_classifier() + + if not classifier: classifier = DocumentClassifier() try: @@ -52,7 +56,7 @@ def train_classifier(): ) except Exception as e: - logging.getLogger(__name__).error( + logging.getLogger(__name__).warning( "Classifier error: " + str(e) ) diff --git a/src/documents/tests/test_classifier.py b/src/documents/tests/test_classifier.py index 9e999794d..43c38b691 100644 --- a/src/documents/tests/test_classifier.py +++ b/src/documents/tests/test_classifier.py @@ -1,10 +1,13 @@ +import os import tempfile +from pathlib import Path from time import sleep from unittest import mock +from django.conf import settings from django.test import TestCase, override_settings -from documents.classifier import DocumentClassifier, IncompatibleClassifierVersionError +from documents.classifier import DocumentClassifier, IncompatibleClassifierVersionError, load_classifier from documents.models import Correspondent, Document, Tag, DocumentType from documents.tests.utils import DirectoriesMixin @@ -235,3 +238,30 @@ class TestClassifier(DirectoriesMixin, TestCase): self.classifier.train() self.assertListEqual(self.classifier.predict_tags(doc1.content), [t1.pk]) self.assertListEqual(self.classifier.predict_tags(doc2.content), []) + + def test_load_classifier_not_exists(self): + self.assertFalse(os.path.exists(settings.MODEL_FILE)) + self.assertIsNone(load_classifier()) + + @mock.patch("documents.classifier.DocumentClassifier.reload") + def test_load_classifier(self, reload): + Path(settings.MODEL_FILE).touch() + self.assertIsNotNone(load_classifier()) + + @mock.patch("documents.classifier.DocumentClassifier.reload") + def test_load_classifier_incompatible_version(self, reload): + Path(settings.MODEL_FILE).touch() + self.assertTrue(os.path.exists(settings.MODEL_FILE)) + + reload.side_effect = IncompatibleClassifierVersionError() + self.assertIsNone(load_classifier()) + self.assertFalse(os.path.exists(settings.MODEL_FILE)) + + @mock.patch("documents.classifier.DocumentClassifier.reload") + def test_load_classifier_os_error(self, reload): + Path(settings.MODEL_FILE).touch() + self.assertTrue(os.path.exists(settings.MODEL_FILE)) + + reload.side_effect = OSError() + self.assertIsNone(load_classifier()) + self.assertTrue(os.path.exists(settings.MODEL_FILE)) diff --git a/src/documents/tests/test_consumer.py b/src/documents/tests/test_consumer.py index a6861a541..02d1d0004 100644 --- a/src/documents/tests/test_consumer.py +++ b/src/documents/tests/test_consumer.py @@ -420,7 +420,7 @@ class TestConsumer(DirectoriesMixin, TestCase): self.assertIsNotNone(os.path.isfile(document.title)) self.assertTrue(os.path.isfile(document.source_path)) - @mock.patch("documents.consumer.DocumentClassifier") + @mock.patch("documents.consumer.load_classifier") def testClassifyDocument(self, m): correspondent = Correspondent.objects.create(name="test") dtype = DocumentType.objects.create(name="test") diff --git a/src/documents/tests/test_tasks.py b/src/documents/tests/test_tasks.py index 653590707..eb310d357 100644 --- a/src/documents/tests/test_tasks.py +++ b/src/documents/tests/test_tasks.py @@ -1,11 +1,12 @@ -from datetime import datetime +import os from unittest import mock +from django.conf import settings from django.test import TestCase from django.utils import timezone from documents import tasks -from documents.models import Document +from documents.models import Document, Tag, Correspondent, DocumentType from documents.sanity_checker import SanityError, SanityFailedError from documents.tests.utils import DirectoriesMixin @@ -22,8 +23,55 @@ class TestTasks(DirectoriesMixin, TestCase): tasks.index_optimize() - def test_train_classifier(self): + @mock.patch("documents.tasks.load_classifier") + def test_train_classifier_no_auto_matching(self, load_classifier): tasks.train_classifier() + load_classifier.assert_not_called() + + @mock.patch("documents.tasks.load_classifier") + def test_train_classifier_with_auto_tag(self, load_classifier): + load_classifier.return_value = None + Tag.objects.create(matching_algorithm=Tag.MATCH_AUTO, name="test") + tasks.train_classifier() + load_classifier.assert_called_once() + self.assertFalse(os.path.isfile(settings.MODEL_FILE)) + + @mock.patch("documents.tasks.load_classifier") + def test_train_classifier_with_auto_type(self, load_classifier): + load_classifier.return_value = None + DocumentType.objects.create(matching_algorithm=Tag.MATCH_AUTO, name="test") + tasks.train_classifier() + load_classifier.assert_called_once() + self.assertFalse(os.path.isfile(settings.MODEL_FILE)) + + @mock.patch("documents.tasks.load_classifier") + def test_train_classifier_with_auto_correspondent(self, load_classifier): + load_classifier.return_value = None + Correspondent.objects.create(matching_algorithm=Tag.MATCH_AUTO, name="test") + tasks.train_classifier() + load_classifier.assert_called_once() + self.assertFalse(os.path.isfile(settings.MODEL_FILE)) + + def test_train_classifier(self): + c = Correspondent.objects.create(matching_algorithm=Tag.MATCH_AUTO, name="test") + doc = Document.objects.create(correspondent=c, content="test", title="test") + self.assertFalse(os.path.isfile(settings.MODEL_FILE)) + + tasks.train_classifier() + self.assertTrue(os.path.isfile(settings.MODEL_FILE)) + mtime = os.stat(settings.MODEL_FILE).st_mtime + + tasks.train_classifier() + self.assertTrue(os.path.isfile(settings.MODEL_FILE)) + mtime2 = os.stat(settings.MODEL_FILE).st_mtime + self.assertEqual(mtime, mtime2) + + doc.content = "test2" + doc.save() + tasks.train_classifier() + self.assertTrue(os.path.isfile(settings.MODEL_FILE)) + mtime3 = os.stat(settings.MODEL_FILE).st_mtime + self.assertNotEqual(mtime2, mtime3) @mock.patch("documents.tasks.sanity_checker.check_sanity") def test_sanity_check(self, m): diff --git a/src/documents/views.py b/src/documents/views.py index 43ae2b103..6fbb42976 100755 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -34,7 +34,7 @@ from rest_framework.viewsets import ( import documents.index as index from paperless.db import GnuPG from paperless.views import StandardPagination -from .classifier import DocumentClassifier, IncompatibleClassifierVersionError +from .classifier import load_classifier from .filters import ( CorrespondentFilterSet, DocumentFilterSet, @@ -259,15 +259,7 @@ class DocumentViewSet(RetrieveModelMixin, except Document.DoesNotExist: raise Http404() - try: - classifier = DocumentClassifier() - classifier.reload() - except (OSError, EOFError, IncompatibleClassifierVersionError) as e: - logging.getLogger(__name__).warning( - "Cannot load classifier: Not providing auto matching " - "suggestions" - ) - classifier = None + classifier = load_classifier() return Response({ "correspondents": [