centralized classifier loading, better error handling, no error messages when auto matching is not used

This commit is contained in:
jonaswinkler 2021-01-30 14:22:23 +01:00
parent a40e4fe3bc
commit 87a18eae2d
8 changed files with 131 additions and 43 deletions

View File

@ -26,6 +26,34 @@ def preprocess_content(content):
return content
def load_classifier():
if not os.path.isfile(settings.MODEL_FILE):
logger.debug(
f"Document classification model does not exist (yet), not "
f"performing automatic matching."
)
return None
try:
classifier = DocumentClassifier()
classifier.reload()
except (EOFError, IncompatibleClassifierVersionError) as e:
# there's something wrong with the model file.
logger.error(
f"Unrecoverable error while loading document classification model: "
f"{str(e)}, deleting model file."
)
os.unlink(settings.MODEL_FILE)
classifier = None
except OSError as e:
logger.error(
f"Error while loading document classification model: {str(e)}"
)
classifier = None
return classifier
class DocumentClassifier(object):
FORMAT_VERSION = 6

View File

@ -11,7 +11,7 @@ from django.utils import timezone
from filelock import FileLock
from rest_framework.reverse import reverse
from .classifier import DocumentClassifier, IncompatibleClassifierVersionError
from .classifier import load_classifier
from .file_handling import create_source_path_directory, \
generate_unique_filename
from .loggers import LoggingMixin
@ -201,14 +201,7 @@ class Consumer(LoggingMixin):
# reloading the classifier multiple times, since there are multiple
# post-consume hooks that all require the classifier.
try:
classifier = DocumentClassifier()
classifier.reload()
except (OSError, EOFError, IncompatibleClassifierVersionError) as e:
self.log(
"warning",
f"Cannot classify documents: {e}.")
classifier = None
classifier = load_classifier()
# now that everything is done, we can start to store the document
# in the system. This will be a transaction and reasonably fast.

View File

@ -2,8 +2,7 @@ import logging
from django.core.management.base import BaseCommand
from documents.classifier import DocumentClassifier, \
IncompatibleClassifierVersionError
from documents.classifier import load_classifier
from documents.models import Document
from ...mixins import Renderable
from ...signals.handlers import set_correspondent, set_document_type, set_tags
@ -70,13 +69,7 @@ class Command(Renderable, BaseCommand):
queryset = Document.objects.all()
documents = queryset.distinct()
classifier = DocumentClassifier()
try:
classifier.reload()
except (OSError, EOFError, IncompatibleClassifierVersionError) as e:
logging.getLogger(__name__).warning(
f"Cannot classify documents: {e}.")
classifier = None
classifier = load_classifier()
for document in documents:
logging.getLogger(__name__).info(

View File

@ -6,10 +6,9 @@ from django.db.models.signals import post_save
from whoosh.writing import AsyncWriter
from documents import index, sanity_checker
from documents.classifier import DocumentClassifier, \
IncompatibleClassifierVersionError
from documents.classifier import DocumentClassifier, load_classifier
from documents.consumer import Consumer, ConsumerError
from documents.models import Document
from documents.models import Document, Tag, DocumentType, Correspondent
from documents.sanity_checker import SanityFailedError
@ -30,13 +29,18 @@ def index_reindex():
def train_classifier():
classifier = DocumentClassifier()
if (not Tag.objects.filter(
matching_algorithm=Tag.MATCH_AUTO).exists() and
not DocumentType.objects.filter(
matching_algorithm=Tag.MATCH_AUTO).exists() and
not Correspondent.objects.filter(
matching_algorithm=Tag.MATCH_AUTO).exists()):
try:
# load the classifier, since we might not have to train it again.
classifier.reload()
except (OSError, EOFError, IncompatibleClassifierVersionError):
# This is what we're going to fix here.
return
classifier = load_classifier()
if not classifier:
classifier = DocumentClassifier()
try:
@ -52,7 +56,7 @@ def train_classifier():
)
except Exception as e:
logging.getLogger(__name__).error(
logging.getLogger(__name__).warning(
"Classifier error: " + str(e)
)

View File

@ -1,10 +1,13 @@
import os
import tempfile
from pathlib import Path
from time import sleep
from unittest import mock
from django.conf import settings
from django.test import TestCase, override_settings
from documents.classifier import DocumentClassifier, IncompatibleClassifierVersionError
from documents.classifier import DocumentClassifier, IncompatibleClassifierVersionError, load_classifier
from documents.models import Correspondent, Document, Tag, DocumentType
from documents.tests.utils import DirectoriesMixin
@ -235,3 +238,30 @@ class TestClassifier(DirectoriesMixin, TestCase):
self.classifier.train()
self.assertListEqual(self.classifier.predict_tags(doc1.content), [t1.pk])
self.assertListEqual(self.classifier.predict_tags(doc2.content), [])
def test_load_classifier_not_exists(self):
self.assertFalse(os.path.exists(settings.MODEL_FILE))
self.assertIsNone(load_classifier())
@mock.patch("documents.classifier.DocumentClassifier.reload")
def test_load_classifier(self, reload):
Path(settings.MODEL_FILE).touch()
self.assertIsNotNone(load_classifier())
@mock.patch("documents.classifier.DocumentClassifier.reload")
def test_load_classifier_incompatible_version(self, reload):
Path(settings.MODEL_FILE).touch()
self.assertTrue(os.path.exists(settings.MODEL_FILE))
reload.side_effect = IncompatibleClassifierVersionError()
self.assertIsNone(load_classifier())
self.assertFalse(os.path.exists(settings.MODEL_FILE))
@mock.patch("documents.classifier.DocumentClassifier.reload")
def test_load_classifier_os_error(self, reload):
Path(settings.MODEL_FILE).touch()
self.assertTrue(os.path.exists(settings.MODEL_FILE))
reload.side_effect = OSError()
self.assertIsNone(load_classifier())
self.assertTrue(os.path.exists(settings.MODEL_FILE))

View File

@ -420,7 +420,7 @@ class TestConsumer(DirectoriesMixin, TestCase):
self.assertIsNotNone(os.path.isfile(document.title))
self.assertTrue(os.path.isfile(document.source_path))
@mock.patch("documents.consumer.DocumentClassifier")
@mock.patch("documents.consumer.load_classifier")
def testClassifyDocument(self, m):
correspondent = Correspondent.objects.create(name="test")
dtype = DocumentType.objects.create(name="test")

View File

@ -1,11 +1,12 @@
from datetime import datetime
import os
from unittest import mock
from django.conf import settings
from django.test import TestCase
from django.utils import timezone
from documents import tasks
from documents.models import Document
from documents.models import Document, Tag, Correspondent, DocumentType
from documents.sanity_checker import SanityError, SanityFailedError
from documents.tests.utils import DirectoriesMixin
@ -22,8 +23,55 @@ class TestTasks(DirectoriesMixin, TestCase):
tasks.index_optimize()
def test_train_classifier(self):
@mock.patch("documents.tasks.load_classifier")
def test_train_classifier_no_auto_matching(self, load_classifier):
tasks.train_classifier()
load_classifier.assert_not_called()
@mock.patch("documents.tasks.load_classifier")
def test_train_classifier_with_auto_tag(self, load_classifier):
load_classifier.return_value = None
Tag.objects.create(matching_algorithm=Tag.MATCH_AUTO, name="test")
tasks.train_classifier()
load_classifier.assert_called_once()
self.assertFalse(os.path.isfile(settings.MODEL_FILE))
@mock.patch("documents.tasks.load_classifier")
def test_train_classifier_with_auto_type(self, load_classifier):
load_classifier.return_value = None
DocumentType.objects.create(matching_algorithm=Tag.MATCH_AUTO, name="test")
tasks.train_classifier()
load_classifier.assert_called_once()
self.assertFalse(os.path.isfile(settings.MODEL_FILE))
@mock.patch("documents.tasks.load_classifier")
def test_train_classifier_with_auto_correspondent(self, load_classifier):
load_classifier.return_value = None
Correspondent.objects.create(matching_algorithm=Tag.MATCH_AUTO, name="test")
tasks.train_classifier()
load_classifier.assert_called_once()
self.assertFalse(os.path.isfile(settings.MODEL_FILE))
def test_train_classifier(self):
c = Correspondent.objects.create(matching_algorithm=Tag.MATCH_AUTO, name="test")
doc = Document.objects.create(correspondent=c, content="test", title="test")
self.assertFalse(os.path.isfile(settings.MODEL_FILE))
tasks.train_classifier()
self.assertTrue(os.path.isfile(settings.MODEL_FILE))
mtime = os.stat(settings.MODEL_FILE).st_mtime
tasks.train_classifier()
self.assertTrue(os.path.isfile(settings.MODEL_FILE))
mtime2 = os.stat(settings.MODEL_FILE).st_mtime
self.assertEqual(mtime, mtime2)
doc.content = "test2"
doc.save()
tasks.train_classifier()
self.assertTrue(os.path.isfile(settings.MODEL_FILE))
mtime3 = os.stat(settings.MODEL_FILE).st_mtime
self.assertNotEqual(mtime2, mtime3)
@mock.patch("documents.tasks.sanity_checker.check_sanity")
def test_sanity_check(self, m):

View File

@ -34,7 +34,7 @@ from rest_framework.viewsets import (
import documents.index as index
from paperless.db import GnuPG
from paperless.views import StandardPagination
from .classifier import DocumentClassifier, IncompatibleClassifierVersionError
from .classifier import load_classifier
from .filters import (
CorrespondentFilterSet,
DocumentFilterSet,
@ -259,15 +259,7 @@ class DocumentViewSet(RetrieveModelMixin,
except Document.DoesNotExist:
raise Http404()
try:
classifier = DocumentClassifier()
classifier.reload()
except (OSError, EOFError, IncompatibleClassifierVersionError) as e:
logging.getLogger(__name__).warning(
"Cannot load classifier: Not providing auto matching "
"suggestions"
)
classifier = None
classifier = load_classifier()
return Response({
"correspondents": [