mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-07-28 18:24:38 -05:00
Merge branch 'main' into dev
This commit is contained in:
@@ -6,6 +6,7 @@ from documents.models import DocumentType
|
||||
from documents.models import MatchingModel
|
||||
from documents.models import StoragePath
|
||||
from documents.models import Tag
|
||||
from documents.permissions import get_objects_for_user_owner_aware
|
||||
|
||||
|
||||
logger = logging.getLogger("paperless.matching")
|
||||
@@ -19,40 +20,64 @@ def log_reason(matching_model, document, reason):
|
||||
)
|
||||
|
||||
|
||||
def match_correspondents(document, classifier):
|
||||
def match_correspondents(document, classifier, user=None):
|
||||
pred_id = classifier.predict_correspondent(document.content) if classifier else None
|
||||
|
||||
correspondents = Correspondent.objects.all()
|
||||
if user is not None:
|
||||
correspondents = get_objects_for_user_owner_aware(
|
||||
user,
|
||||
"documents.view_correspondent",
|
||||
Correspondent,
|
||||
)
|
||||
else:
|
||||
correspondents = Correspondent.objects.all()
|
||||
|
||||
return list(
|
||||
filter(lambda o: matches(o, document) or o.pk == pred_id, correspondents),
|
||||
)
|
||||
|
||||
|
||||
def match_document_types(document, classifier):
|
||||
def match_document_types(document, classifier, user=None):
|
||||
pred_id = classifier.predict_document_type(document.content) if classifier else None
|
||||
|
||||
document_types = DocumentType.objects.all()
|
||||
if user is not None:
|
||||
document_types = get_objects_for_user_owner_aware(
|
||||
user,
|
||||
"documents.view_documenttype",
|
||||
DocumentType,
|
||||
)
|
||||
else:
|
||||
document_types = DocumentType.objects.all()
|
||||
|
||||
return list(
|
||||
filter(lambda o: matches(o, document) or o.pk == pred_id, document_types),
|
||||
)
|
||||
|
||||
|
||||
def match_tags(document, classifier):
|
||||
def match_tags(document, classifier, user=None):
|
||||
predicted_tag_ids = classifier.predict_tags(document.content) if classifier else []
|
||||
|
||||
tags = Tag.objects.all()
|
||||
if user is not None:
|
||||
tags = get_objects_for_user_owner_aware(user, "documents.view_tag", Tag)
|
||||
else:
|
||||
tags = Tag.objects.all()
|
||||
|
||||
return list(
|
||||
filter(lambda o: matches(o, document) or o.pk in predicted_tag_ids, tags),
|
||||
)
|
||||
|
||||
|
||||
def match_storage_paths(document, classifier):
|
||||
def match_storage_paths(document, classifier, user=None):
|
||||
pred_id = classifier.predict_storage_path(document.content) if classifier else None
|
||||
|
||||
storage_paths = StoragePath.objects.all()
|
||||
if user is not None:
|
||||
storage_paths = get_objects_for_user_owner_aware(
|
||||
user,
|
||||
"documents.view_storagepath",
|
||||
StoragePath,
|
||||
)
|
||||
else:
|
||||
storage_paths = StoragePath.objects.all()
|
||||
|
||||
return list(
|
||||
filter(
|
||||
|
@@ -4,6 +4,7 @@ from django.contrib.auth.models import User
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from guardian.models import GroupObjectPermission
|
||||
from guardian.shortcuts import assign_perm
|
||||
from guardian.shortcuts import get_objects_for_user
|
||||
from guardian.shortcuts import get_users_with_perms
|
||||
from guardian.shortcuts import remove_perm
|
||||
from rest_framework.permissions import BasePermission
|
||||
@@ -101,3 +102,15 @@ def set_permissions_for_object(permissions, object):
|
||||
group,
|
||||
object,
|
||||
)
|
||||
|
||||
|
||||
def get_objects_for_user_owner_aware(user, perms, Model):
|
||||
objects_owned = Model.objects.filter(owner=user)
|
||||
objects_unowned = Model.objects.filter(owner__isnull=True)
|
||||
objects_with_perms = get_objects_for_user(
|
||||
user=user,
|
||||
perms=perms,
|
||||
klass=Model,
|
||||
accept_global_perms=False,
|
||||
)
|
||||
return objects_owned | objects_unowned | objects_with_perms
|
||||
|
@@ -4,6 +4,7 @@ import shutil
|
||||
|
||||
from celery import states
|
||||
from celery.signals import before_task_publish
|
||||
from celery.signals import task_failure
|
||||
from celery.signals import task_postrun
|
||||
from celery.signals import task_prerun
|
||||
from django.conf import settings
|
||||
@@ -591,3 +592,29 @@ def task_postrun_handler(
|
||||
# Don't let an exception in the signal handlers prevent
|
||||
# a document from being consumed.
|
||||
logger.exception("Updating PaperlessTask failed")
|
||||
|
||||
|
||||
@task_failure.connect
|
||||
def task_failure_handler(
|
||||
sender=None,
|
||||
task_id=None,
|
||||
exception=None,
|
||||
args=None,
|
||||
traceback=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Updates the result of a failed PaperlessTask.
|
||||
|
||||
https://docs.celeryq.dev/en/stable/userguide/signals.html#task-failure
|
||||
"""
|
||||
try:
|
||||
task_instance = PaperlessTask.objects.filter(task_id=task_id).first()
|
||||
|
||||
if task_instance is not None and task_instance.result is None:
|
||||
task_instance.status = states.FAILURE
|
||||
task_instance.result = traceback
|
||||
task_instance.date_done = timezone.now()
|
||||
task_instance.save()
|
||||
except Exception: # pragma: no cover
|
||||
logger.exception("Updating PaperlessTask failed")
|
||||
|
@@ -9,6 +9,7 @@ from documents.models import PaperlessTask
|
||||
from documents.signals.handlers import before_task_publish_handler
|
||||
from documents.signals.handlers import task_postrun_handler
|
||||
from documents.signals.handlers import task_prerun_handler
|
||||
from documents.signals.handlers import task_failure_handler
|
||||
from documents.tests.test_consumer import fake_magic_from_file
|
||||
from documents.tests.utils import DirectoriesMixin
|
||||
|
||||
@@ -146,3 +147,44 @@ class TestTaskSignalHandler(DirectoriesMixin, TestCase):
|
||||
task = PaperlessTask.objects.get()
|
||||
|
||||
self.assertEqual(celery.states.SUCCESS, task.status)
|
||||
|
||||
def test_task_failure_handler(self):
|
||||
"""
|
||||
GIVEN:
|
||||
- A celery task is started via the consume folder
|
||||
WHEN:
|
||||
- Task failed execution
|
||||
THEN:
|
||||
- The task is marked as failed
|
||||
"""
|
||||
headers = {
|
||||
"id": str(uuid.uuid4()),
|
||||
"task": "documents.tasks.consume_file",
|
||||
}
|
||||
body = (
|
||||
# args
|
||||
(
|
||||
ConsumableDocument(
|
||||
source=DocumentSource.ConsumeFolder,
|
||||
original_file="/consume/hello-9.pdf",
|
||||
),
|
||||
None,
|
||||
),
|
||||
# kwargs
|
||||
{},
|
||||
# celery stuff
|
||||
{"callbacks": None, "errbacks": None, "chain": None, "chord": None},
|
||||
)
|
||||
self.util_call_before_task_publish_handler(
|
||||
headers_to_use=headers,
|
||||
body_to_use=body,
|
||||
)
|
||||
|
||||
task_failure_handler(
|
||||
task_id=headers["id"],
|
||||
exception="Example failure",
|
||||
)
|
||||
|
||||
task = PaperlessTask.objects.get()
|
||||
|
||||
self.assertEqual(celery.states.FAILURE, task.status)
|
||||
|
@@ -401,12 +401,16 @@ class DocumentViewSet(
|
||||
|
||||
return Response(
|
||||
{
|
||||
"correspondents": [c.id for c in match_correspondents(doc, classifier)],
|
||||
"tags": [t.id for t in match_tags(doc, classifier)],
|
||||
"document_types": [
|
||||
dt.id for dt in match_document_types(doc, classifier)
|
||||
"correspondents": [
|
||||
c.id for c in match_correspondents(doc, classifier, request.user)
|
||||
],
|
||||
"tags": [t.id for t in match_tags(doc, classifier, request.user)],
|
||||
"document_types": [
|
||||
dt.id for dt in match_document_types(doc, classifier, request.user)
|
||||
],
|
||||
"storage_paths": [
|
||||
dt.id for dt in match_storage_paths(doc, classifier, request.user)
|
||||
],
|
||||
"storage_paths": [dt.id for dt in match_storage_paths(doc, classifier)],
|
||||
"dates": [
|
||||
date.strftime("%Y-%m-%d") for date in dates if date is not None
|
||||
],
|
||||
|
Reference in New Issue
Block a user