Updates how task_args and task_kwargs are parsed, and adds tests covering every case I can think of

This commit is contained in:
Trenton H 2022-09-28 09:09:51 -07:00
parent 6f6f006704
commit 5b66ef0a74
2 changed files with 166 additions and 31 deletions

View File

@ -1,8 +1,14 @@
import datetime
import json
import math
import os
import re
from ast import literal_eval
from asyncio.log import logger
from pathlib import Path
from typing import Dict
from typing import Optional
from typing import Tuple
from celery import states
try:
import zoneinfo
@ -646,16 +652,21 @@ class TasksViewSerializer(serializers.ModelSerializer):
def get_result(self, obj):
    """
    Return the displayable result of the task attached to ``obj``.

    Returns an empty string when ``obj`` has no ``attempted_task`` or the
    task has no (or an empty) ``result``.

    When the stored result text contains ``exc_message`` it is treated as
    the string form of a dict describing an exception: it is parsed with
    ``literal_eval`` and the first entry of its ``exc_message`` list is
    returned. Any other result text is returned unchanged.

    Parsing is best effort: if the text cannot be parsed, a warning is
    logged and the raw result text is returned instead of raising.
    """
    attempted_task = getattr(obj, "attempted_task", None)
    if not attempted_task or not attempted_task.result:
        return ""

    result = attempted_task.result
    try:
        if "exc_message" in result:
            # The result is the string form of a dict in this case
            exception_info = literal_eval(result)
            # exc_message is a list, grab the first item (most recent)
            result = exception_info["exc_message"][0]
    except Exception as e:  # pragma: no cover
        # Extra security if something is malformed
        # NOTE(review): `logger` is the one imported from asyncio.log at
        # the top of the file, which looks like an accidental auto-import;
        # consider a module-level logging.getLogger(...) instead.
        logger.warning(f"Error getting task result: {e}", exc_info=True)
    return result
status = serializers.SerializerMethodField()
@ -704,26 +715,25 @@ class TasksViewSerializer(serializers.ModelSerializer):
result = ""
if hasattr(obj, "attempted_task") and obj.attempted_task:
try:
# We have to make this a valid JSON object string
kwargs_json = json.loads(
obj.attempted_task.task_kwargs.strip('"')
.replace("'", '"')
.replace("None", '""'),
)
except Exception:
pass
task_kwargs: Optional[str] = obj.attempted_task.task_kwargs
# Try the override filename first (this is a webui created task?)
if task_kwargs is not None:
# It's a string, string of a dict. Who knows why...
kwargs = literal_eval(literal_eval(task_kwargs))
if "override_filename" in kwargs:
result = kwargs["override_filename"]
if kwargs_json and "override_filename" in kwargs_json:
result = kwargs_json["override_filename"]
else:
filepath = (
obj.attempted_task.task_args.replace('"', "")
.replace("'", "")
.replace("(", "")
.replace(")", "")
.replace(",", "")
)
result = os.path.split(filepath)[1]
# Nothing was found, report the task first argument
if not len(result):
# There are always some arguments to the consume
task_args: Tuple = literal_eval(
literal_eval(obj.attempted_task.task_args),
)
filepath = Path(task_args[0])
result = filepath.name
except Exception as e: # pragma: no cover
# Extra security if something is malformed
logger.warn(f"Error getting task result: {e}", exc_info=True)
return result
@ -735,7 +745,8 @@ class TasksViewSerializer(serializers.ModelSerializer):
if (
hasattr(obj, "attempted_task")
and obj.attempted_task
and obj.attempted_task.status == "SUCCESS"
and obj.attempted_task.result
and obj.attempted_task.status == states.SUCCESS
):
try:
result = re.search(regexp, obj.attempted_task.result).group(1)

View File

@ -2831,6 +2831,14 @@ class TestTasks(APITestCase):
self.assertEqual(returned_task2["task_name"], result2.task_name)
def test_acknowledge_tasks(self):
"""
GIVEN:
- Attempted celery tasks
WHEN:
- API call is made to mark the task as acknowledged
THEN:
- Task is marked as acknowledged
"""
result1 = TaskResult.objects.create(
task_id=str(uuid.uuid4()),
task_name="documents.tasks.some_task",
@ -2849,3 +2857,119 @@ class TestTasks(APITestCase):
response = self.client.get(self.ENDPOINT)
self.assertEqual(len(response.data), 0)
def test_task_result_no_error(self):
    """
    GIVEN:
        - A celery task which finished without raising
    WHEN:
        - The tasks endpoint is queried
    THEN:
        - The task's result text is part of the returned data
        - The related document id is taken from the result text
    """
    successful_task = TaskResult.objects.create(
        result="Success. New document id 1 created",
        status=celery.states.SUCCESS,
        task_name="documents.tasks.some_task",
        task_id=str(uuid.uuid4()),
    )
    PaperlessTask.objects.create(attempted_task=successful_task)

    response = self.client.get(self.ENDPOINT)

    self.assertEqual(response.status_code, 200)
    self.assertEqual(len(response.data), 1)

    task_data = response.data[0]
    self.assertEqual(task_data["result"], "Success. New document id 1 created")
    self.assertEqual(task_data["related_document"], "1")
def test_task_result_with_error(self):
    """
    GIVEN:
        - A celery task completed with an exception
    WHEN:
        - API call is made to get tasks
    THEN:
        - The returned result is the exception info
    """
    result1 = TaskResult.objects.create(
        task_id=str(uuid.uuid4()),
        task_name="documents.tasks.some_task",
        # A task which raised is recorded as FAILURE, not SUCCESS
        status=celery.states.FAILURE,
        # NOTE(review): a dict is passed here; presumably it is stored as
        # its string form, which is what the serializer parses - confirm.
        result={
            "exc_type": "ConsumerError",
            "exc_message": ["test.pdf: Not consuming test.pdf: It is a duplicate."],
            "exc_module": "documents.consumer",
        },
    )
    _ = PaperlessTask.objects.create(attempted_task=result1)

    response = self.client.get(self.ENDPOINT)

    self.assertEqual(response.status_code, 200)
    self.assertEqual(len(response.data), 1)

    returned_data = response.data[0]

    # Only the first exception message is reported
    self.assertEqual(
        returned_data["result"],
        "test.pdf: Not consuming test.pdf: It is a duplicate.",
    )
def test_task_name_webui(self):
    """
    GIVEN:
        - A celery task was attempted
        - Its kwargs carry an override filename, as for a webui upload
    WHEN:
        - The tasks endpoint is queried
    THEN:
        - The reported task name is the override filename
    """
    upload_args = "\"('/tmp/paperless/paperless-upload-5iq7skzc',)\""
    upload_kwargs = "\"{'override_filename': 'test.pdf', 'override_title': None, 'override_correspondent_id': None, 'override_document_type_id': None, 'override_tag_ids': None, 'task_id': '466e8fe7-7193-4698-9fff-72f0340e2082', 'override_created': None}\""

    webui_task = TaskResult.objects.create(
        task_id=str(uuid.uuid4()),
        task_name="documents.tasks.some_task",
        status=celery.states.SUCCESS,
        task_args=upload_args,
        task_kwargs=upload_kwargs,
    )
    PaperlessTask.objects.create(attempted_task=webui_task)

    response = self.client.get(self.ENDPOINT)

    self.assertEqual(response.status_code, 200)
    self.assertEqual(len(response.data), 1)

    task_data = response.data[0]
    self.assertEqual(task_data["name"], "test.pdf")
def test_task_name_consume_folder(self):
    """
    GIVEN:
        - A celery task was attempted
        - Its kwargs carry no override filename, as for a consume folder file
    WHEN:
        - The tasks endpoint is queried
    THEN:
        - The reported task name is the file name of the first task argument
    """
    folder_args = "\"('/consume/anothertest.pdf',)\""
    folder_kwargs = "\"{'override_tag_ids': None}\""

    folder_task = TaskResult.objects.create(
        task_id=str(uuid.uuid4()),
        task_name="documents.tasks.some_task",
        status=celery.states.SUCCESS,
        task_args=folder_args,
        task_kwargs=folder_kwargs,
    )
    PaperlessTask.objects.create(attempted_task=folder_task)

    response = self.client.get(self.ENDPOINT)

    self.assertEqual(response.status_code, 200)
    self.assertEqual(len(response.data), 1)

    task_data = response.data[0]
    self.assertEqual(task_data["name"], "anothertest.pdf")