mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-04-02 13:45:10 -05:00
centralized classifier loading, better error handling, no error messages when auto matching is not used
This commit is contained in:
parent
a40e4fe3bc
commit
87a18eae2d
@ -26,6 +26,34 @@ def preprocess_content(content):
|
|||||||
return content
|
return content
|
||||||
|
|
||||||
|
|
||||||
|
def load_classifier():
|
||||||
|
if not os.path.isfile(settings.MODEL_FILE):
|
||||||
|
logger.debug(
|
||||||
|
f"Document classification model does not exist (yet), not "
|
||||||
|
f"performing automatic matching."
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
classifier = DocumentClassifier()
|
||||||
|
classifier.reload()
|
||||||
|
except (EOFError, IncompatibleClassifierVersionError) as e:
|
||||||
|
# there's something wrong with the model file.
|
||||||
|
logger.error(
|
||||||
|
f"Unrecoverable error while loading document classification model: "
|
||||||
|
f"{str(e)}, deleting model file."
|
||||||
|
)
|
||||||
|
os.unlink(settings.MODEL_FILE)
|
||||||
|
classifier = None
|
||||||
|
except OSError as e:
|
||||||
|
logger.error(
|
||||||
|
f"Error while loading document classification model: {str(e)}"
|
||||||
|
)
|
||||||
|
classifier = None
|
||||||
|
|
||||||
|
return classifier
|
||||||
|
|
||||||
|
|
||||||
class DocumentClassifier(object):
|
class DocumentClassifier(object):
|
||||||
|
|
||||||
FORMAT_VERSION = 6
|
FORMAT_VERSION = 6
|
||||||
|
@ -11,7 +11,7 @@ from django.utils import timezone
|
|||||||
from filelock import FileLock
|
from filelock import FileLock
|
||||||
from rest_framework.reverse import reverse
|
from rest_framework.reverse import reverse
|
||||||
|
|
||||||
from .classifier import DocumentClassifier, IncompatibleClassifierVersionError
|
from .classifier import load_classifier
|
||||||
from .file_handling import create_source_path_directory, \
|
from .file_handling import create_source_path_directory, \
|
||||||
generate_unique_filename
|
generate_unique_filename
|
||||||
from .loggers import LoggingMixin
|
from .loggers import LoggingMixin
|
||||||
@ -201,14 +201,7 @@ class Consumer(LoggingMixin):
|
|||||||
# reloading the classifier multiple times, since there are multiple
|
# reloading the classifier multiple times, since there are multiple
|
||||||
# post-consume hooks that all require the classifier.
|
# post-consume hooks that all require the classifier.
|
||||||
|
|
||||||
try:
|
classifier = load_classifier()
|
||||||
classifier = DocumentClassifier()
|
|
||||||
classifier.reload()
|
|
||||||
except (OSError, EOFError, IncompatibleClassifierVersionError) as e:
|
|
||||||
self.log(
|
|
||||||
"warning",
|
|
||||||
f"Cannot classify documents: {e}.")
|
|
||||||
classifier = None
|
|
||||||
|
|
||||||
# now that everything is done, we can start to store the document
|
# now that everything is done, we can start to store the document
|
||||||
# in the system. This will be a transaction and reasonably fast.
|
# in the system. This will be a transaction and reasonably fast.
|
||||||
|
@ -2,8 +2,7 @@ import logging
|
|||||||
|
|
||||||
from django.core.management.base import BaseCommand
|
from django.core.management.base import BaseCommand
|
||||||
|
|
||||||
from documents.classifier import DocumentClassifier, \
|
from documents.classifier import load_classifier
|
||||||
IncompatibleClassifierVersionError
|
|
||||||
from documents.models import Document
|
from documents.models import Document
|
||||||
from ...mixins import Renderable
|
from ...mixins import Renderable
|
||||||
from ...signals.handlers import set_correspondent, set_document_type, set_tags
|
from ...signals.handlers import set_correspondent, set_document_type, set_tags
|
||||||
@ -70,13 +69,7 @@ class Command(Renderable, BaseCommand):
|
|||||||
queryset = Document.objects.all()
|
queryset = Document.objects.all()
|
||||||
documents = queryset.distinct()
|
documents = queryset.distinct()
|
||||||
|
|
||||||
classifier = DocumentClassifier()
|
classifier = load_classifier()
|
||||||
try:
|
|
||||||
classifier.reload()
|
|
||||||
except (OSError, EOFError, IncompatibleClassifierVersionError) as e:
|
|
||||||
logging.getLogger(__name__).warning(
|
|
||||||
f"Cannot classify documents: {e}.")
|
|
||||||
classifier = None
|
|
||||||
|
|
||||||
for document in documents:
|
for document in documents:
|
||||||
logging.getLogger(__name__).info(
|
logging.getLogger(__name__).info(
|
||||||
|
@ -6,10 +6,9 @@ from django.db.models.signals import post_save
|
|||||||
from whoosh.writing import AsyncWriter
|
from whoosh.writing import AsyncWriter
|
||||||
|
|
||||||
from documents import index, sanity_checker
|
from documents import index, sanity_checker
|
||||||
from documents.classifier import DocumentClassifier, \
|
from documents.classifier import DocumentClassifier, load_classifier
|
||||||
IncompatibleClassifierVersionError
|
|
||||||
from documents.consumer import Consumer, ConsumerError
|
from documents.consumer import Consumer, ConsumerError
|
||||||
from documents.models import Document
|
from documents.models import Document, Tag, DocumentType, Correspondent
|
||||||
from documents.sanity_checker import SanityFailedError
|
from documents.sanity_checker import SanityFailedError
|
||||||
|
|
||||||
|
|
||||||
@ -30,13 +29,18 @@ def index_reindex():
|
|||||||
|
|
||||||
|
|
||||||
def train_classifier():
|
def train_classifier():
|
||||||
classifier = DocumentClassifier()
|
if (not Tag.objects.filter(
|
||||||
|
matching_algorithm=Tag.MATCH_AUTO).exists() and
|
||||||
|
not DocumentType.objects.filter(
|
||||||
|
matching_algorithm=Tag.MATCH_AUTO).exists() and
|
||||||
|
not Correspondent.objects.filter(
|
||||||
|
matching_algorithm=Tag.MATCH_AUTO).exists()):
|
||||||
|
|
||||||
try:
|
return
|
||||||
# load the classifier, since we might not have to train it again.
|
|
||||||
classifier.reload()
|
classifier = load_classifier()
|
||||||
except (OSError, EOFError, IncompatibleClassifierVersionError):
|
|
||||||
# This is what we're going to fix here.
|
if not classifier:
|
||||||
classifier = DocumentClassifier()
|
classifier = DocumentClassifier()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -52,7 +56,7 @@ def train_classifier():
|
|||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.getLogger(__name__).error(
|
logging.getLogger(__name__).warning(
|
||||||
"Classifier error: " + str(e)
|
"Classifier error: " + str(e)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,10 +1,13 @@
|
|||||||
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
from time import sleep
|
from time import sleep
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
|
from django.conf import settings
|
||||||
from django.test import TestCase, override_settings
|
from django.test import TestCase, override_settings
|
||||||
|
|
||||||
from documents.classifier import DocumentClassifier, IncompatibleClassifierVersionError
|
from documents.classifier import DocumentClassifier, IncompatibleClassifierVersionError, load_classifier
|
||||||
from documents.models import Correspondent, Document, Tag, DocumentType
|
from documents.models import Correspondent, Document, Tag, DocumentType
|
||||||
from documents.tests.utils import DirectoriesMixin
|
from documents.tests.utils import DirectoriesMixin
|
||||||
|
|
||||||
@ -235,3 +238,30 @@ class TestClassifier(DirectoriesMixin, TestCase):
|
|||||||
self.classifier.train()
|
self.classifier.train()
|
||||||
self.assertListEqual(self.classifier.predict_tags(doc1.content), [t1.pk])
|
self.assertListEqual(self.classifier.predict_tags(doc1.content), [t1.pk])
|
||||||
self.assertListEqual(self.classifier.predict_tags(doc2.content), [])
|
self.assertListEqual(self.classifier.predict_tags(doc2.content), [])
|
||||||
|
|
||||||
|
def test_load_classifier_not_exists(self):
|
||||||
|
self.assertFalse(os.path.exists(settings.MODEL_FILE))
|
||||||
|
self.assertIsNone(load_classifier())
|
||||||
|
|
||||||
|
@mock.patch("documents.classifier.DocumentClassifier.reload")
|
||||||
|
def test_load_classifier(self, reload):
|
||||||
|
Path(settings.MODEL_FILE).touch()
|
||||||
|
self.assertIsNotNone(load_classifier())
|
||||||
|
|
||||||
|
@mock.patch("documents.classifier.DocumentClassifier.reload")
|
||||||
|
def test_load_classifier_incompatible_version(self, reload):
|
||||||
|
Path(settings.MODEL_FILE).touch()
|
||||||
|
self.assertTrue(os.path.exists(settings.MODEL_FILE))
|
||||||
|
|
||||||
|
reload.side_effect = IncompatibleClassifierVersionError()
|
||||||
|
self.assertIsNone(load_classifier())
|
||||||
|
self.assertFalse(os.path.exists(settings.MODEL_FILE))
|
||||||
|
|
||||||
|
@mock.patch("documents.classifier.DocumentClassifier.reload")
|
||||||
|
def test_load_classifier_os_error(self, reload):
|
||||||
|
Path(settings.MODEL_FILE).touch()
|
||||||
|
self.assertTrue(os.path.exists(settings.MODEL_FILE))
|
||||||
|
|
||||||
|
reload.side_effect = OSError()
|
||||||
|
self.assertIsNone(load_classifier())
|
||||||
|
self.assertTrue(os.path.exists(settings.MODEL_FILE))
|
||||||
|
@ -420,7 +420,7 @@ class TestConsumer(DirectoriesMixin, TestCase):
|
|||||||
self.assertIsNotNone(os.path.isfile(document.title))
|
self.assertIsNotNone(os.path.isfile(document.title))
|
||||||
self.assertTrue(os.path.isfile(document.source_path))
|
self.assertTrue(os.path.isfile(document.source_path))
|
||||||
|
|
||||||
@mock.patch("documents.consumer.DocumentClassifier")
|
@mock.patch("documents.consumer.load_classifier")
|
||||||
def testClassifyDocument(self, m):
|
def testClassifyDocument(self, m):
|
||||||
correspondent = Correspondent.objects.create(name="test")
|
correspondent = Correspondent.objects.create(name="test")
|
||||||
dtype = DocumentType.objects.create(name="test")
|
dtype = DocumentType.objects.create(name="test")
|
||||||
|
@ -1,11 +1,12 @@
|
|||||||
from datetime import datetime
|
import os
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
|
from django.conf import settings
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
|
|
||||||
from documents import tasks
|
from documents import tasks
|
||||||
from documents.models import Document
|
from documents.models import Document, Tag, Correspondent, DocumentType
|
||||||
from documents.sanity_checker import SanityError, SanityFailedError
|
from documents.sanity_checker import SanityError, SanityFailedError
|
||||||
from documents.tests.utils import DirectoriesMixin
|
from documents.tests.utils import DirectoriesMixin
|
||||||
|
|
||||||
@ -22,8 +23,55 @@ class TestTasks(DirectoriesMixin, TestCase):
|
|||||||
|
|
||||||
tasks.index_optimize()
|
tasks.index_optimize()
|
||||||
|
|
||||||
def test_train_classifier(self):
|
@mock.patch("documents.tasks.load_classifier")
|
||||||
|
def test_train_classifier_no_auto_matching(self, load_classifier):
|
||||||
tasks.train_classifier()
|
tasks.train_classifier()
|
||||||
|
load_classifier.assert_not_called()
|
||||||
|
|
||||||
|
@mock.patch("documents.tasks.load_classifier")
|
||||||
|
def test_train_classifier_with_auto_tag(self, load_classifier):
|
||||||
|
load_classifier.return_value = None
|
||||||
|
Tag.objects.create(matching_algorithm=Tag.MATCH_AUTO, name="test")
|
||||||
|
tasks.train_classifier()
|
||||||
|
load_classifier.assert_called_once()
|
||||||
|
self.assertFalse(os.path.isfile(settings.MODEL_FILE))
|
||||||
|
|
||||||
|
@mock.patch("documents.tasks.load_classifier")
|
||||||
|
def test_train_classifier_with_auto_type(self, load_classifier):
|
||||||
|
load_classifier.return_value = None
|
||||||
|
DocumentType.objects.create(matching_algorithm=Tag.MATCH_AUTO, name="test")
|
||||||
|
tasks.train_classifier()
|
||||||
|
load_classifier.assert_called_once()
|
||||||
|
self.assertFalse(os.path.isfile(settings.MODEL_FILE))
|
||||||
|
|
||||||
|
@mock.patch("documents.tasks.load_classifier")
|
||||||
|
def test_train_classifier_with_auto_correspondent(self, load_classifier):
|
||||||
|
load_classifier.return_value = None
|
||||||
|
Correspondent.objects.create(matching_algorithm=Tag.MATCH_AUTO, name="test")
|
||||||
|
tasks.train_classifier()
|
||||||
|
load_classifier.assert_called_once()
|
||||||
|
self.assertFalse(os.path.isfile(settings.MODEL_FILE))
|
||||||
|
|
||||||
|
def test_train_classifier(self):
|
||||||
|
c = Correspondent.objects.create(matching_algorithm=Tag.MATCH_AUTO, name="test")
|
||||||
|
doc = Document.objects.create(correspondent=c, content="test", title="test")
|
||||||
|
self.assertFalse(os.path.isfile(settings.MODEL_FILE))
|
||||||
|
|
||||||
|
tasks.train_classifier()
|
||||||
|
self.assertTrue(os.path.isfile(settings.MODEL_FILE))
|
||||||
|
mtime = os.stat(settings.MODEL_FILE).st_mtime
|
||||||
|
|
||||||
|
tasks.train_classifier()
|
||||||
|
self.assertTrue(os.path.isfile(settings.MODEL_FILE))
|
||||||
|
mtime2 = os.stat(settings.MODEL_FILE).st_mtime
|
||||||
|
self.assertEqual(mtime, mtime2)
|
||||||
|
|
||||||
|
doc.content = "test2"
|
||||||
|
doc.save()
|
||||||
|
tasks.train_classifier()
|
||||||
|
self.assertTrue(os.path.isfile(settings.MODEL_FILE))
|
||||||
|
mtime3 = os.stat(settings.MODEL_FILE).st_mtime
|
||||||
|
self.assertNotEqual(mtime2, mtime3)
|
||||||
|
|
||||||
@mock.patch("documents.tasks.sanity_checker.check_sanity")
|
@mock.patch("documents.tasks.sanity_checker.check_sanity")
|
||||||
def test_sanity_check(self, m):
|
def test_sanity_check(self, m):
|
||||||
|
@ -34,7 +34,7 @@ from rest_framework.viewsets import (
|
|||||||
import documents.index as index
|
import documents.index as index
|
||||||
from paperless.db import GnuPG
|
from paperless.db import GnuPG
|
||||||
from paperless.views import StandardPagination
|
from paperless.views import StandardPagination
|
||||||
from .classifier import DocumentClassifier, IncompatibleClassifierVersionError
|
from .classifier import load_classifier
|
||||||
from .filters import (
|
from .filters import (
|
||||||
CorrespondentFilterSet,
|
CorrespondentFilterSet,
|
||||||
DocumentFilterSet,
|
DocumentFilterSet,
|
||||||
@ -259,15 +259,7 @@ class DocumentViewSet(RetrieveModelMixin,
|
|||||||
except Document.DoesNotExist:
|
except Document.DoesNotExist:
|
||||||
raise Http404()
|
raise Http404()
|
||||||
|
|
||||||
try:
|
classifier = load_classifier()
|
||||||
classifier = DocumentClassifier()
|
|
||||||
classifier.reload()
|
|
||||||
except (OSError, EOFError, IncompatibleClassifierVersionError) as e:
|
|
||||||
logging.getLogger(__name__).warning(
|
|
||||||
"Cannot load classifier: Not providing auto matching "
|
|
||||||
"suggestions"
|
|
||||||
)
|
|
||||||
classifier = None
|
|
||||||
|
|
||||||
return Response({
|
return Response({
|
||||||
"correspondents": [
|
"correspondents": [
|
||||||
|
Loading…
x
Reference in New Issue
Block a user