Updates how task_args and task_kwargs are parsed, adds testing to cover everything I can think of

This commit is contained in:
Trenton H 2022-09-28 09:09:51 -07:00
parent 6f6f006704
commit 5b66ef0a74
2 changed files with 166 additions and 31 deletions

View File

@ -1,8 +1,14 @@
import datetime import datetime
import json
import math import math
import os
import re import re
from ast import literal_eval
from asyncio.log import logger
from pathlib import Path
from typing import Dict
from typing import Optional
from typing import Tuple
from celery import states
try: try:
import zoneinfo import zoneinfo
@ -646,16 +652,21 @@ class TasksViewSerializer(serializers.ModelSerializer):
def get_result(self, obj): def get_result(self, obj):
result = "" result = ""
if hasattr(obj, "attempted_task") and obj.attempted_task: if (
hasattr(obj, "attempted_task")
and obj.attempted_task
and obj.attempted_task.result
):
try: try:
result_json = json.loads(obj.attempted_task.result) result: str = obj.attempted_task.result
except Exception: if "exc_message" in result:
pass # This is a dict in this case
result: Dict = literal_eval(result)
if result_json and "exc_message" in result_json: # This is a list, grab the first item (most recent)
result = result_json["exc_message"] result = result["exc_message"][0]
else: except Exception as e: # pragma: no cover
result = obj.attempted_task.result.strip('"') # Extra security if something is malformed
logger.warn(f"Error getting task result: {e}", exc_info=True)
return result return result
status = serializers.SerializerMethodField() status = serializers.SerializerMethodField()
@ -704,26 +715,25 @@ class TasksViewSerializer(serializers.ModelSerializer):
result = "" result = ""
if hasattr(obj, "attempted_task") and obj.attempted_task: if hasattr(obj, "attempted_task") and obj.attempted_task:
try: try:
# We have to make this a valid JSON object string task_kwargs: Optional[str] = obj.attempted_task.task_kwargs
kwargs_json = json.loads( # Try the override filename first (this is a webui created task?)
obj.attempted_task.task_kwargs.strip('"') if task_kwargs is not None:
.replace("'", '"') # It's a string, string of a dict. Who knows why...
.replace("None", '""'), kwargs = literal_eval(literal_eval(task_kwargs))
) if "override_filename" in kwargs:
except Exception: result = kwargs["override_filename"]
pass
if kwargs_json and "override_filename" in kwargs_json: # Nothing was found, report the task first argument
result = kwargs_json["override_filename"] if not len(result):
else: # There are always some arguments to the consume
filepath = ( task_args: Tuple = literal_eval(
obj.attempted_task.task_args.replace('"', "") literal_eval(obj.attempted_task.task_args),
.replace("'", "") )
.replace("(", "") filepath = Path(task_args[0])
.replace(")", "") result = filepath.name
.replace(",", "") except Exception as e: # pragma: no cover
) # Extra security if something is malformed
result = os.path.split(filepath)[1] logger.warn(f"Error getting task result: {e}", exc_info=True)
return result return result
@ -735,7 +745,8 @@ class TasksViewSerializer(serializers.ModelSerializer):
if ( if (
hasattr(obj, "attempted_task") hasattr(obj, "attempted_task")
and obj.attempted_task and obj.attempted_task
and obj.attempted_task.status == "SUCCESS" and obj.attempted_task.result
and obj.attempted_task.status == states.SUCCESS
): ):
try: try:
result = re.search(regexp, obj.attempted_task.result).group(1) result = re.search(regexp, obj.attempted_task.result).group(1)

View File

@ -2831,6 +2831,14 @@ class TestTasks(APITestCase):
self.assertEqual(returned_task2["task_name"], result2.task_name) self.assertEqual(returned_task2["task_name"], result2.task_name)
def test_acknowledge_tasks(self): def test_acknowledge_tasks(self):
"""
GIVEN:
- Attempted celery tasks
WHEN:
- API call is made to mark task as acknowledged
THEN:
- Task is marked as acknowledged
"""
result1 = TaskResult.objects.create( result1 = TaskResult.objects.create(
task_id=str(uuid.uuid4()), task_id=str(uuid.uuid4()),
task_name="documents.tasks.some_task", task_name="documents.tasks.some_task",
@ -2849,3 +2857,119 @@ class TestTasks(APITestCase):
response = self.client.get(self.ENDPOINT) response = self.client.get(self.ENDPOINT)
self.assertEqual(len(response.data), 0) self.assertEqual(len(response.data), 0)
def test_task_result_no_error(self):
    """
    GIVEN:
        - A celery task stored with a SUCCESS status and a plain-text result
    WHEN:
        - The task list is requested from the API
    THEN:
        - Exactly one task is returned
        - Its "result" is the stored result text
        - Its "related_document" is "1"
    """
    expected_result = "Success. New document id 1 created"
    PaperlessTask.objects.create(
        attempted_task=TaskResult.objects.create(
            task_id=str(uuid.uuid4()),
            task_name="documents.tasks.some_task",
            status=celery.states.SUCCESS,
            result=expected_result,
        ),
    )

    response = self.client.get(self.ENDPOINT)

    self.assertEqual(response.status_code, 200)
    self.assertEqual(len(response.data), 1)

    task_data = response.data[0]
    self.assertEqual(task_data["result"], expected_result)
    self.assertEqual(task_data["related_document"], "1")
def test_task_result_with_error(self):
    """
    GIVEN:
        - A celery task which ended by raising an exception
    WHEN:
        - API call is made to get tasks
    THEN:
        - Exactly one task is returned
        - The returned result is the first exception message
    """
    result1 = TaskResult.objects.create(
        task_id=str(uuid.uuid4()),
        task_name="documents.tasks.some_task",
        # A task that raised is recorded as FAILURE by celery, so the
        # fixture uses that state to match the scenario described above
        status=celery.states.FAILURE,
        # NOTE(review): the result is handed over as a dict; presumably the
        # model field stringifies it when saving - confirm against the
        # TaskResult model
        result={
            "exc_type": "ConsumerError",
            "exc_message": ["test.pdf: Not consuming test.pdf: It is a duplicate."],
            "exc_module": "documents.consumer",
        },
    )
    _ = PaperlessTask.objects.create(attempted_task=result1)

    response = self.client.get(self.ENDPOINT)

    self.assertEqual(response.status_code, 200)
    self.assertEqual(len(response.data), 1)

    returned_data = response.data[0]

    self.assertEqual(
        returned_data["result"],
        "test.pdf: Not consuming test.pdf: It is a duplicate.",
    )
def test_task_name_webui(self):
    """
    GIVEN:
        - An attempted celery task whose kwargs carry an override filename
          (the shape of a task created through the webui)
    WHEN:
        - The task list is requested from the API
    THEN:
        - The returned task name is the override filename
    """
    # Both values are a quoted string wrapping the repr of a Python object
    upload_args = "\"('/tmp/paperless/paperless-upload-5iq7skzc',)\""
    upload_kwargs = "\"{'override_filename': 'test.pdf', 'override_title': None, 'override_correspondent_id': None, 'override_document_type_id': None, 'override_tag_ids': None, 'task_id': '466e8fe7-7193-4698-9fff-72f0340e2082', 'override_created': None}\""

    PaperlessTask.objects.create(
        attempted_task=TaskResult.objects.create(
            task_id=str(uuid.uuid4()),
            task_name="documents.tasks.some_task",
            status=celery.states.SUCCESS,
            task_args=upload_args,
            task_kwargs=upload_kwargs,
        ),
    )

    response = self.client.get(self.ENDPOINT)

    self.assertEqual(response.status_code, 200)
    self.assertEqual(len(response.data), 1)

    task_data = response.data[0]
    self.assertEqual(task_data["name"], "test.pdf")
def test_task_name_consume_folder(self):
    """
    GIVEN:
        - An attempted celery task without an override filename in its kwargs
          (the shape of a task created through the consume folder)
    WHEN:
        - The task list is requested from the API
    THEN:
        - The returned task name is the file name of the first task argument
    """
    # Both values are a quoted string wrapping the repr of a Python object
    consume_args = "\"('/consume/anothertest.pdf',)\""
    consume_kwargs = "\"{'override_tag_ids': None}\""

    PaperlessTask.objects.create(
        attempted_task=TaskResult.objects.create(
            task_id=str(uuid.uuid4()),
            task_name="documents.tasks.some_task",
            status=celery.states.SUCCESS,
            task_args=consume_args,
            task_kwargs=consume_kwargs,
        ),
    )

    response = self.client.get(self.ENDPOINT)

    self.assertEqual(response.status_code, 200)
    self.assertEqual(len(response.data), 1)

    task_data = response.data[0]
    self.assertEqual(task_data["name"], "anothertest.pdf")