diff --git a/src/documents/signals/handlers.py b/src/documents/signals/handlers.py index 271fb4597..a7f75c489 100644 --- a/src/documents/signals/handlers.py +++ b/src/documents/signals/handlers.py @@ -4,6 +4,7 @@ import shutil from celery import states from celery.signals import before_task_publish +from celery.signals import task_failure from celery.signals import task_postrun from celery.signals import task_prerun from django.conf import settings @@ -591,3 +592,29 @@ def task_postrun_handler( # Don't let an exception in the signal handlers prevent # a document from being consumed. logger.exception("Updating PaperlessTask failed") + + +@task_failure.connect +def task_failure_handler( + sender=None, + task_id=None, + exception=None, + args=None, + traceback=None, + **kwargs, +): + """ + Updates the result of a failed PaperlessTask. + + https://docs.celeryq.dev/en/stable/userguide/signals.html#task-failure + """ + try: + task_instance = PaperlessTask.objects.filter(task_id=task_id).first() + + if task_instance is not None and task_instance.result is None: + task_instance.status = states.FAILURE + task_instance.result = traceback + task_instance.date_done = timezone.now() + task_instance.save() + except Exception: # pragma: no cover + logger.exception("Updating PaperlessTask failed") diff --git a/src/documents/tests/test_task_signals.py b/src/documents/tests/test_task_signals.py index a6befc25e..d63df0b3c 100644 --- a/src/documents/tests/test_task_signals.py +++ b/src/documents/tests/test_task_signals.py @@ -9,6 +9,7 @@ from documents.models import PaperlessTask from documents.signals.handlers import before_task_publish_handler from documents.signals.handlers import task_postrun_handler from documents.signals.handlers import task_prerun_handler +from documents.signals.handlers import task_failure_handler from documents.tests.test_consumer import fake_magic_from_file from documents.tests.utils import DirectoriesMixin @@ -146,3 +147,44 @@ class TestTaskSignalHandler(DirectoriesMixin, TestCase): task = PaperlessTask.objects.get() self.assertEqual(celery.states.SUCCESS, task.status) + + def test_task_failure_handler(self): + """ + GIVEN: + - A celery task is started via the consume folder + WHEN: + - Task failed execution + THEN: + - The task is marked as failed + """ + headers = { + "id": str(uuid.uuid4()), + "task": "documents.tasks.consume_file", + } + body = ( + # args + ( + ConsumableDocument( + source=DocumentSource.ConsumeFolder, + original_file="/consume/hello-9.pdf", + ), + None, + ), + # kwargs + {}, + # celery stuff + {"callbacks": None, "errbacks": None, "chain": None, "chord": None}, + ) + self.util_call_before_task_publish_handler( + headers_to_use=headers, + body_to_use=body, + ) + + task_failure_handler( + task_id=headers["id"], + exception="Example failure", + ) + + task = PaperlessTask.objects.get() + + self.assertEqual(celery.states.FAILURE, task.status)