centralized classifier loading, better error handling, no error messages when auto matching is not used

This commit is contained in:
jonaswinkler 2021-01-30 14:22:23 +01:00
parent a40e4fe3bc
commit 87a18eae2d
8 changed files with 131 additions and 43 deletions

View File

@ -26,6 +26,34 @@ def preprocess_content(content):
return content return content
def load_classifier():
if not os.path.isfile(settings.MODEL_FILE):
logger.debug(
f"Document classification model does not exist (yet), not "
f"performing automatic matching."
)
return None
try:
classifier = DocumentClassifier()
classifier.reload()
except (EOFError, IncompatibleClassifierVersionError) as e:
# there's something wrong with the model file.
logger.error(
f"Unrecoverable error while loading document classification model: "
f"{str(e)}, deleting model file."
)
os.unlink(settings.MODEL_FILE)
classifier = None
except OSError as e:
logger.error(
f"Error while loading document classification model: {str(e)}"
)
classifier = None
return classifier
class DocumentClassifier(object): class DocumentClassifier(object):
FORMAT_VERSION = 6 FORMAT_VERSION = 6

View File

@ -11,7 +11,7 @@ from django.utils import timezone
from filelock import FileLock from filelock import FileLock
from rest_framework.reverse import reverse from rest_framework.reverse import reverse
from .classifier import DocumentClassifier, IncompatibleClassifierVersionError from .classifier import load_classifier
from .file_handling import create_source_path_directory, \ from .file_handling import create_source_path_directory, \
generate_unique_filename generate_unique_filename
from .loggers import LoggingMixin from .loggers import LoggingMixin
@ -201,14 +201,7 @@ class Consumer(LoggingMixin):
# reloading the classifier multiple times, since there are multiple # reloading the classifier multiple times, since there are multiple
# post-consume hooks that all require the classifier. # post-consume hooks that all require the classifier.
try: classifier = load_classifier()
classifier = DocumentClassifier()
classifier.reload()
except (OSError, EOFError, IncompatibleClassifierVersionError) as e:
self.log(
"warning",
f"Cannot classify documents: {e}.")
classifier = None
# now that everything is done, we can start to store the document # now that everything is done, we can start to store the document
# in the system. This will be a transaction and reasonably fast. # in the system. This will be a transaction and reasonably fast.

View File

@ -2,8 +2,7 @@ import logging
from django.core.management.base import BaseCommand from django.core.management.base import BaseCommand
from documents.classifier import DocumentClassifier, \ from documents.classifier import load_classifier
IncompatibleClassifierVersionError
from documents.models import Document from documents.models import Document
from ...mixins import Renderable from ...mixins import Renderable
from ...signals.handlers import set_correspondent, set_document_type, set_tags from ...signals.handlers import set_correspondent, set_document_type, set_tags
@ -70,13 +69,7 @@ class Command(Renderable, BaseCommand):
queryset = Document.objects.all() queryset = Document.objects.all()
documents = queryset.distinct() documents = queryset.distinct()
classifier = DocumentClassifier() classifier = load_classifier()
try:
classifier.reload()
except (OSError, EOFError, IncompatibleClassifierVersionError) as e:
logging.getLogger(__name__).warning(
f"Cannot classify documents: {e}.")
classifier = None
for document in documents: for document in documents:
logging.getLogger(__name__).info( logging.getLogger(__name__).info(

View File

@ -6,10 +6,9 @@ from django.db.models.signals import post_save
from whoosh.writing import AsyncWriter from whoosh.writing import AsyncWriter
from documents import index, sanity_checker from documents import index, sanity_checker
from documents.classifier import DocumentClassifier, \ from documents.classifier import DocumentClassifier, load_classifier
IncompatibleClassifierVersionError
from documents.consumer import Consumer, ConsumerError from documents.consumer import Consumer, ConsumerError
from documents.models import Document from documents.models import Document, Tag, DocumentType, Correspondent
from documents.sanity_checker import SanityFailedError from documents.sanity_checker import SanityFailedError
@ -30,13 +29,18 @@ def index_reindex():
def train_classifier(): def train_classifier():
classifier = DocumentClassifier() if (not Tag.objects.filter(
matching_algorithm=Tag.MATCH_AUTO).exists() and
not DocumentType.objects.filter(
matching_algorithm=Tag.MATCH_AUTO).exists() and
not Correspondent.objects.filter(
matching_algorithm=Tag.MATCH_AUTO).exists()):
try: return
# load the classifier, since we might not have to train it again.
classifier.reload() classifier = load_classifier()
except (OSError, EOFError, IncompatibleClassifierVersionError):
# This is what we're going to fix here. if not classifier:
classifier = DocumentClassifier() classifier = DocumentClassifier()
try: try:
@ -52,7 +56,7 @@ def train_classifier():
) )
except Exception as e: except Exception as e:
logging.getLogger(__name__).error( logging.getLogger(__name__).warning(
"Classifier error: " + str(e) "Classifier error: " + str(e)
) )

View File

@ -1,10 +1,13 @@
import os
import tempfile import tempfile
from pathlib import Path
from time import sleep from time import sleep
from unittest import mock from unittest import mock
from django.conf import settings
from django.test import TestCase, override_settings from django.test import TestCase, override_settings
from documents.classifier import DocumentClassifier, IncompatibleClassifierVersionError from documents.classifier import DocumentClassifier, IncompatibleClassifierVersionError, load_classifier
from documents.models import Correspondent, Document, Tag, DocumentType from documents.models import Correspondent, Document, Tag, DocumentType
from documents.tests.utils import DirectoriesMixin from documents.tests.utils import DirectoriesMixin
@ -235,3 +238,30 @@ class TestClassifier(DirectoriesMixin, TestCase):
self.classifier.train() self.classifier.train()
self.assertListEqual(self.classifier.predict_tags(doc1.content), [t1.pk]) self.assertListEqual(self.classifier.predict_tags(doc1.content), [t1.pk])
self.assertListEqual(self.classifier.predict_tags(doc2.content), []) self.assertListEqual(self.classifier.predict_tags(doc2.content), [])
def test_load_classifier_not_exists(self):
self.assertFalse(os.path.exists(settings.MODEL_FILE))
self.assertIsNone(load_classifier())
@mock.patch("documents.classifier.DocumentClassifier.reload")
def test_load_classifier(self, reload):
Path(settings.MODEL_FILE).touch()
self.assertIsNotNone(load_classifier())
@mock.patch("documents.classifier.DocumentClassifier.reload")
def test_load_classifier_incompatible_version(self, reload):
Path(settings.MODEL_FILE).touch()
self.assertTrue(os.path.exists(settings.MODEL_FILE))
reload.side_effect = IncompatibleClassifierVersionError()
self.assertIsNone(load_classifier())
self.assertFalse(os.path.exists(settings.MODEL_FILE))
@mock.patch("documents.classifier.DocumentClassifier.reload")
def test_load_classifier_os_error(self, reload):
Path(settings.MODEL_FILE).touch()
self.assertTrue(os.path.exists(settings.MODEL_FILE))
reload.side_effect = OSError()
self.assertIsNone(load_classifier())
self.assertTrue(os.path.exists(settings.MODEL_FILE))

View File

@ -420,7 +420,7 @@ class TestConsumer(DirectoriesMixin, TestCase):
self.assertIsNotNone(os.path.isfile(document.title)) self.assertIsNotNone(os.path.isfile(document.title))
self.assertTrue(os.path.isfile(document.source_path)) self.assertTrue(os.path.isfile(document.source_path))
@mock.patch("documents.consumer.DocumentClassifier") @mock.patch("documents.consumer.load_classifier")
def testClassifyDocument(self, m): def testClassifyDocument(self, m):
correspondent = Correspondent.objects.create(name="test") correspondent = Correspondent.objects.create(name="test")
dtype = DocumentType.objects.create(name="test") dtype = DocumentType.objects.create(name="test")

View File

@ -1,11 +1,12 @@
from datetime import datetime import os
from unittest import mock from unittest import mock
from django.conf import settings
from django.test import TestCase from django.test import TestCase
from django.utils import timezone from django.utils import timezone
from documents import tasks from documents import tasks
from documents.models import Document from documents.models import Document, Tag, Correspondent, DocumentType
from documents.sanity_checker import SanityError, SanityFailedError from documents.sanity_checker import SanityError, SanityFailedError
from documents.tests.utils import DirectoriesMixin from documents.tests.utils import DirectoriesMixin
@ -22,8 +23,55 @@ class TestTasks(DirectoriesMixin, TestCase):
tasks.index_optimize() tasks.index_optimize()
def test_train_classifier(self): @mock.patch("documents.tasks.load_classifier")
def test_train_classifier_no_auto_matching(self, load_classifier):
tasks.train_classifier() tasks.train_classifier()
load_classifier.assert_not_called()
@mock.patch("documents.tasks.load_classifier")
def test_train_classifier_with_auto_tag(self, load_classifier):
load_classifier.return_value = None
Tag.objects.create(matching_algorithm=Tag.MATCH_AUTO, name="test")
tasks.train_classifier()
load_classifier.assert_called_once()
self.assertFalse(os.path.isfile(settings.MODEL_FILE))
@mock.patch("documents.tasks.load_classifier")
def test_train_classifier_with_auto_type(self, load_classifier):
load_classifier.return_value = None
DocumentType.objects.create(matching_algorithm=Tag.MATCH_AUTO, name="test")
tasks.train_classifier()
load_classifier.assert_called_once()
self.assertFalse(os.path.isfile(settings.MODEL_FILE))
@mock.patch("documents.tasks.load_classifier")
def test_train_classifier_with_auto_correspondent(self, load_classifier):
load_classifier.return_value = None
Correspondent.objects.create(matching_algorithm=Tag.MATCH_AUTO, name="test")
tasks.train_classifier()
load_classifier.assert_called_once()
self.assertFalse(os.path.isfile(settings.MODEL_FILE))
def test_train_classifier(self):
c = Correspondent.objects.create(matching_algorithm=Tag.MATCH_AUTO, name="test")
doc = Document.objects.create(correspondent=c, content="test", title="test")
self.assertFalse(os.path.isfile(settings.MODEL_FILE))
tasks.train_classifier()
self.assertTrue(os.path.isfile(settings.MODEL_FILE))
mtime = os.stat(settings.MODEL_FILE).st_mtime
tasks.train_classifier()
self.assertTrue(os.path.isfile(settings.MODEL_FILE))
mtime2 = os.stat(settings.MODEL_FILE).st_mtime
self.assertEqual(mtime, mtime2)
doc.content = "test2"
doc.save()
tasks.train_classifier()
self.assertTrue(os.path.isfile(settings.MODEL_FILE))
mtime3 = os.stat(settings.MODEL_FILE).st_mtime
self.assertNotEqual(mtime2, mtime3)
@mock.patch("documents.tasks.sanity_checker.check_sanity") @mock.patch("documents.tasks.sanity_checker.check_sanity")
def test_sanity_check(self, m): def test_sanity_check(self, m):

View File

@ -34,7 +34,7 @@ from rest_framework.viewsets import (
import documents.index as index import documents.index as index
from paperless.db import GnuPG from paperless.db import GnuPG
from paperless.views import StandardPagination from paperless.views import StandardPagination
from .classifier import DocumentClassifier, IncompatibleClassifierVersionError from .classifier import load_classifier
from .filters import ( from .filters import (
CorrespondentFilterSet, CorrespondentFilterSet,
DocumentFilterSet, DocumentFilterSet,
@ -259,15 +259,7 @@ class DocumentViewSet(RetrieveModelMixin,
except Document.DoesNotExist: except Document.DoesNotExist:
raise Http404() raise Http404()
try: classifier = load_classifier()
classifier = DocumentClassifier()
classifier.reload()
except (OSError, EOFError, IncompatibleClassifierVersionError) as e:
logging.getLogger(__name__).warning(
"Cannot load classifier: Not providing auto matching "
"suggestions"
)
classifier = None
return Response({ return Response({
"correspondents": [ "correspondents": [