From 5b66ef0a748fd5570361a2a1ed6147e0462568d2 Mon Sep 17 00:00:00 2001 From: Trenton H Date: Wed, 28 Sep 2022 09:09:51 -0700 Subject: [PATCH] Updates how task_args and task_kwargs are parsed, adds testing to cover everything I can think of --- src/documents/serialisers.py | 73 +++++++++++-------- src/documents/tests/test_api.py | 124 ++++++++++++++++++++++++++++++++ 2 files changed, 166 insertions(+), 31 deletions(-) diff --git a/src/documents/serialisers.py b/src/documents/serialisers.py index 172992de4..cd99c43cd 100644 --- a/src/documents/serialisers.py +++ b/src/documents/serialisers.py @@ -1,8 +1,14 @@ import datetime -import json import math -import os import re +from ast import literal_eval +from asyncio.log import logger +from pathlib import Path +from typing import Dict +from typing import Optional +from typing import Tuple + +from celery import states try: import zoneinfo @@ -646,16 +652,21 @@ class TasksViewSerializer(serializers.ModelSerializer): def get_result(self, obj): result = "" - if hasattr(obj, "attempted_task") and obj.attempted_task: + if ( + hasattr(obj, "attempted_task") + and obj.attempted_task + and obj.attempted_task.result + ): try: - result_json = json.loads(obj.attempted_task.result) - except Exception: - pass - - if result_json and "exc_message" in result_json: - result = result_json["exc_message"] - else: - result = obj.attempted_task.result.strip('"') + result: str = obj.attempted_task.result + if "exc_message" in result: + # This is a dict in this case + result: Dict = literal_eval(result) + # This is a list, grab the first item (most recent) + result = result["exc_message"][0] + except Exception as e: # pragma: no cover + # Extra security if something is malformed + logger.warn(f"Error getting task result: {e}", exc_info=True) return result status = serializers.SerializerMethodField() @@ -704,26 +715,25 @@ class TasksViewSerializer(serializers.ModelSerializer): result = "" if hasattr(obj, "attempted_task") and obj.attempted_task: try: - # We have to make this a valid JSON object string - kwargs_json = json.loads( - obj.attempted_task.task_kwargs.strip('"') - .replace("'", '"') - .replace("None", '""'), - ) - except Exception: - pass + task_kwargs: Optional[str] = obj.attempted_task.task_kwargs + # Try the override filename first (this is a webui created task?) + if task_kwargs is not None: + # It's a string, string of a dict. Who knows why... + kwargs = literal_eval(literal_eval(task_kwargs)) + if "override_filename" in kwargs: + result = kwargs["override_filename"] - if kwargs_json and "override_filename" in kwargs_json: - result = kwargs_json["override_filename"] - else: - filepath = ( - obj.attempted_task.task_args.replace('"', "") - .replace("'", "") - .replace("(", "") - .replace(")", "") - .replace(",", "") - ) - result = os.path.split(filepath)[1] + # Nothing was found, report the task first argument + if not len(result): + # There are always some arguments to the consume + task_args: Tuple = literal_eval( + literal_eval(obj.attempted_task.task_args), + ) + filepath = Path(task_args[0]) + result = filepath.name + except Exception as e: # pragma: no cover + # Extra security if something is malformed + logger.warn(f"Error getting task result: {e}", exc_info=True) return result @@ -735,7 +745,8 @@ class TasksViewSerializer(serializers.ModelSerializer): if ( hasattr(obj, "attempted_task") and obj.attempted_task - and obj.attempted_task.status == "SUCCESS" + and obj.attempted_task.result + and obj.attempted_task.status == states.SUCCESS ): try: result = re.search(regexp, obj.attempted_task.result).group(1) diff --git a/src/documents/tests/test_api.py b/src/documents/tests/test_api.py index 89e340501..ec89a19e8 100644 --- a/src/documents/tests/test_api.py +++ b/src/documents/tests/test_api.py @@ -2831,6 +2831,14 @@ class TestTasks(APITestCase): self.assertEqual(returned_task2["task_name"], result2.task_name) def test_acknowledge_tasks(self): + """ + GIVEN: + - Attempted celery tasks + WHEN: + - API call is made to get mark task as acknowledged + THEN: + - Task is marked as acknowledged + """ result1 = TaskResult.objects.create( task_id=str(uuid.uuid4()), task_name="documents.tasks.some_task", @@ -2849,3 +2857,119 @@ class TestTasks(APITestCase): response = self.client.get(self.ENDPOINT) self.assertEqual(len(response.data), 0) + + def test_task_result_no_error(self): + """ + GIVEN: + - A celery task completed without error + WHEN: + - API call is made to get tasks + THEN: + - The returned data includes the task result + """ + result1 = TaskResult.objects.create( + task_id=str(uuid.uuid4()), + task_name="documents.tasks.some_task", + status=celery.states.SUCCESS, + result="Success. New document id 1 created", + ) + _ = PaperlessTask.objects.create(attempted_task=result1) + + response = self.client.get(self.ENDPOINT) + + self.assertEqual(response.status_code, 200) + self.assertEqual(len(response.data), 1) + + returned_data = response.data[0] + + self.assertEqual(returned_data["result"], "Success. New document id 1 created") + self.assertEqual(returned_data["related_document"], "1") + + def test_task_result_with_error(self): + """ + GIVEN: + - A celery task completed with an exception + WHEN: + - API call is made to get tasks + THEN: + - The returned result is the exception info + """ + result1 = TaskResult.objects.create( + task_id=str(uuid.uuid4()), + task_name="documents.tasks.some_task", + status=celery.states.SUCCESS, + result={ + "exc_type": "ConsumerError", + "exc_message": ["test.pdf: Not consuming test.pdf: It is a duplicate."], + "exc_module": "documents.consumer", + }, + ) + _ = PaperlessTask.objects.create(attempted_task=result1) + + response = self.client.get(self.ENDPOINT) + + self.assertEqual(response.status_code, 200) + self.assertEqual(len(response.data), 1) + + returned_data = response.data[0] + + self.assertEqual( + returned_data["result"], + "test.pdf: Not consuming test.pdf: It is a duplicate.", + ) + + def test_task_name_webui(self): + """ + GIVEN: + - Attempted celery task + - Task was created through the webui + WHEN: + - API call is made to get tasks + THEN: + - Returned data include the filename + """ + result1 = TaskResult.objects.create( + task_id=str(uuid.uuid4()), + task_name="documents.tasks.some_task", + status=celery.states.SUCCESS, + task_args="\"('/tmp/paperless/paperless-upload-5iq7skzc',)\"", + task_kwargs="\"{'override_filename': 'test.pdf', 'override_title': None, 'override_correspondent_id': None, 'override_document_type_id': None, 'override_tag_ids': None, 'task_id': '466e8fe7-7193-4698-9fff-72f0340e2082', 'override_created': None}\"", + ) + _ = PaperlessTask.objects.create(attempted_task=result1) + + response = self.client.get(self.ENDPOINT) + + self.assertEqual(response.status_code, 200) + self.assertEqual(len(response.data), 1) + + returned_data = response.data[0] + + self.assertEqual(returned_data["name"], "test.pdf") + + def test_task_name_consume_folder(self): + """ + GIVEN: + - Attempted celery task + - Task was created through the consume folder + WHEN: + - API call is made to get tasks + THEN: + - Returned data include the filename + """ + result1 = TaskResult.objects.create( + task_id=str(uuid.uuid4()), + task_name="documents.tasks.some_task", + status=celery.states.SUCCESS, + task_args="\"('/consume/anothertest.pdf',)\"", + task_kwargs="\"{'override_tag_ids': None}\"", + ) + _ = PaperlessTask.objects.create(attempted_task=result1) + + response = self.client.get(self.ENDPOINT) + + self.assertEqual(response.status_code, 200) + self.assertEqual(len(response.data), 1) + + returned_data = response.data[0] + + self.assertEqual(returned_data["name"], "anothertest.pdf")