mirror of
				https://github.com/paperless-ngx/paperless-ngx.git
				synced 2025-10-30 03:56:23 -05:00 
			
		
		
		
	Updates how task_args and task_kwargs are parsed, and adds tests to cover every case I can think of
This commit is contained in:
		| @@ -1,8 +1,14 @@ | ||||
| import datetime | ||||
| import json | ||||
| import math | ||||
| import os | ||||
| import re | ||||
| from ast import literal_eval | ||||
| from asyncio.log import logger | ||||
| from pathlib import Path | ||||
| from typing import Dict | ||||
| from typing import Optional | ||||
| from typing import Tuple | ||||
|  | ||||
| from celery import states | ||||
|  | ||||
| try: | ||||
|     import zoneinfo | ||||
| @@ -646,16 +652,21 @@ class TasksViewSerializer(serializers.ModelSerializer): | ||||
|  | ||||
|     def get_result(self, obj): | ||||
|         result = "" | ||||
|         if hasattr(obj, "attempted_task") and obj.attempted_task: | ||||
|         if ( | ||||
|             hasattr(obj, "attempted_task") | ||||
|             and obj.attempted_task | ||||
|             and obj.attempted_task.result | ||||
|         ): | ||||
|             try: | ||||
|                 result_json = json.loads(obj.attempted_task.result) | ||||
|             except Exception: | ||||
|                 pass | ||||
|  | ||||
|             if result_json and "exc_message" in result_json: | ||||
|                 result = result_json["exc_message"] | ||||
|             else: | ||||
|                 result = obj.attempted_task.result.strip('"') | ||||
|                 result: str = obj.attempted_task.result | ||||
|                 if "exc_message" in result: | ||||
|                     # This is a dict in this case | ||||
|                     result: Dict = literal_eval(result) | ||||
|                     # This is a list, grab the first item (most recent) | ||||
|                     result = result["exc_message"][0] | ||||
|             except Exception as e:  # pragma: no cover | ||||
|                 # Extra security if something is malformed | ||||
|                 logger.warn(f"Error getting task result: {e}", exc_info=True) | ||||
|         return result | ||||
|  | ||||
|     status = serializers.SerializerMethodField() | ||||
| @@ -704,26 +715,25 @@ class TasksViewSerializer(serializers.ModelSerializer): | ||||
|         result = "" | ||||
|         if hasattr(obj, "attempted_task") and obj.attempted_task: | ||||
|             try: | ||||
|                 # We have to make this a valid JSON object string | ||||
|                 kwargs_json = json.loads( | ||||
|                     obj.attempted_task.task_kwargs.strip('"') | ||||
|                     .replace("'", '"') | ||||
|                     .replace("None", '""'), | ||||
|                 ) | ||||
|             except Exception: | ||||
|                 pass | ||||
|                 task_kwargs: Optional[str] = obj.attempted_task.task_kwargs | ||||
|                 # Try the override filename first (this is a webui created task?) | ||||
|                 if task_kwargs is not None: | ||||
|                     # It's a string, string of a dict.  Who knows why... | ||||
|                     kwargs = literal_eval(literal_eval(task_kwargs)) | ||||
|                     if "override_filename" in kwargs: | ||||
|                         result = kwargs["override_filename"] | ||||
|  | ||||
|             if kwargs_json and "override_filename" in kwargs_json: | ||||
|                 result = kwargs_json["override_filename"] | ||||
|             else: | ||||
|                 filepath = ( | ||||
|                     obj.attempted_task.task_args.replace('"', "") | ||||
|                     .replace("'", "") | ||||
|                     .replace("(", "") | ||||
|                     .replace(")", "") | ||||
|                     .replace(",", "") | ||||
|                 ) | ||||
|                 result = os.path.split(filepath)[1] | ||||
|                 # Nothing was found, report the task first argument | ||||
|                 if not len(result): | ||||
|                     # There are always some arguments to the consume | ||||
|                     task_args: Tuple = literal_eval( | ||||
|                         literal_eval(obj.attempted_task.task_args), | ||||
|                     ) | ||||
|                     filepath = Path(task_args[0]) | ||||
|                     result = filepath.name | ||||
|             except Exception as e:  # pragma: no cover | ||||
|                 # Extra security if something is malformed | ||||
|                 logger.warn(f"Error getting task result: {e}", exc_info=True) | ||||
|  | ||||
|         return result | ||||
|  | ||||
| @@ -735,7 +745,8 @@ class TasksViewSerializer(serializers.ModelSerializer): | ||||
|         if ( | ||||
|             hasattr(obj, "attempted_task") | ||||
|             and obj.attempted_task | ||||
|             and obj.attempted_task.status == "SUCCESS" | ||||
|             and obj.attempted_task.result | ||||
|             and obj.attempted_task.status == states.SUCCESS | ||||
|         ): | ||||
|             try: | ||||
|                 result = re.search(regexp, obj.attempted_task.result).group(1) | ||||
|   | ||||
| @@ -2831,6 +2831,14 @@ class TestTasks(APITestCase): | ||||
|         self.assertEqual(returned_task2["task_name"], result2.task_name) | ||||
|  | ||||
|     def test_acknowledge_tasks(self): | ||||
|         """ | ||||
|         GIVEN: | ||||
|             - Attempted celery tasks | ||||
|         WHEN: | ||||
|             - API call is made to mark task as acknowledged | ||||
|         THEN: | ||||
|             - Task is marked as acknowledged | ||||
|         """ | ||||
|         result1 = TaskResult.objects.create( | ||||
|             task_id=str(uuid.uuid4()), | ||||
|             task_name="documents.tasks.some_task", | ||||
| @@ -2849,3 +2857,119 @@ class TestTasks(APITestCase): | ||||
|  | ||||
|         response = self.client.get(self.ENDPOINT) | ||||
|         self.assertEqual(len(response.data), 0) | ||||
|  | ||||
|     def test_task_result_no_error(self): | ||||
|         """ | ||||
|         GIVEN: | ||||
|             - A celery task completed without error | ||||
|         WHEN: | ||||
|             - API call is made to get tasks | ||||
|         THEN: | ||||
|             - The returned data includes the task result | ||||
|         """ | ||||
|         result1 = TaskResult.objects.create( | ||||
|             task_id=str(uuid.uuid4()), | ||||
|             task_name="documents.tasks.some_task", | ||||
|             status=celery.states.SUCCESS, | ||||
|             result="Success. New document id 1 created", | ||||
|         ) | ||||
|         _ = PaperlessTask.objects.create(attempted_task=result1) | ||||
|  | ||||
|         response = self.client.get(self.ENDPOINT) | ||||
|  | ||||
|         self.assertEqual(response.status_code, 200) | ||||
|         self.assertEqual(len(response.data), 1) | ||||
|  | ||||
|         returned_data = response.data[0] | ||||
|  | ||||
|         self.assertEqual(returned_data["result"], "Success. New document id 1 created") | ||||
|         self.assertEqual(returned_data["related_document"], "1") | ||||
|  | ||||
|     def test_task_result_with_error(self): | ||||
|         """ | ||||
|         GIVEN: | ||||
|             - A celery task completed with an exception | ||||
|         WHEN: | ||||
|             - API call is made to get tasks | ||||
|         THEN: | ||||
|             - The returned result is the exception info | ||||
|         """ | ||||
|         result1 = TaskResult.objects.create( | ||||
|             task_id=str(uuid.uuid4()), | ||||
|             task_name="documents.tasks.some_task", | ||||
|             status=celery.states.SUCCESS, | ||||
|             result={ | ||||
|                 "exc_type": "ConsumerError", | ||||
|                 "exc_message": ["test.pdf: Not consuming test.pdf: It is a duplicate."], | ||||
|                 "exc_module": "documents.consumer", | ||||
|             }, | ||||
|         ) | ||||
|         _ = PaperlessTask.objects.create(attempted_task=result1) | ||||
|  | ||||
|         response = self.client.get(self.ENDPOINT) | ||||
|  | ||||
|         self.assertEqual(response.status_code, 200) | ||||
|         self.assertEqual(len(response.data), 1) | ||||
|  | ||||
|         returned_data = response.data[0] | ||||
|  | ||||
|         self.assertEqual( | ||||
|             returned_data["result"], | ||||
|             "test.pdf: Not consuming test.pdf: It is a duplicate.", | ||||
|         ) | ||||
|  | ||||
|     def test_task_name_webui(self): | ||||
|         """ | ||||
|         GIVEN: | ||||
|             - Attempted celery task | ||||
|             - Task was created through the webui | ||||
|         WHEN: | ||||
|             - API call is made to get tasks | ||||
|         THEN: | ||||
|             - Returned data include the filename | ||||
|         """ | ||||
|         result1 = TaskResult.objects.create( | ||||
|             task_id=str(uuid.uuid4()), | ||||
|             task_name="documents.tasks.some_task", | ||||
|             status=celery.states.SUCCESS, | ||||
|             task_args="\"('/tmp/paperless/paperless-upload-5iq7skzc',)\"", | ||||
|             task_kwargs="\"{'override_filename': 'test.pdf', 'override_title': None, 'override_correspondent_id': None, 'override_document_type_id': None, 'override_tag_ids': None, 'task_id': '466e8fe7-7193-4698-9fff-72f0340e2082', 'override_created': None}\"", | ||||
|         ) | ||||
|         _ = PaperlessTask.objects.create(attempted_task=result1) | ||||
|  | ||||
|         response = self.client.get(self.ENDPOINT) | ||||
|  | ||||
|         self.assertEqual(response.status_code, 200) | ||||
|         self.assertEqual(len(response.data), 1) | ||||
|  | ||||
|         returned_data = response.data[0] | ||||
|  | ||||
|         self.assertEqual(returned_data["name"], "test.pdf") | ||||
|  | ||||
|     def test_task_name_consume_folder(self): | ||||
|         """ | ||||
|         GIVEN: | ||||
|             - Attempted celery task | ||||
|             - Task was created through the consume folder | ||||
|         WHEN: | ||||
|             - API call is made to get tasks | ||||
|         THEN: | ||||
|             - Returned data include the filename | ||||
|         """ | ||||
|         result1 = TaskResult.objects.create( | ||||
|             task_id=str(uuid.uuid4()), | ||||
|             task_name="documents.tasks.some_task", | ||||
|             status=celery.states.SUCCESS, | ||||
|             task_args="\"('/consume/anothertest.pdf',)\"", | ||||
|             task_kwargs="\"{'override_tag_ids': None}\"", | ||||
|         ) | ||||
|         _ = PaperlessTask.objects.create(attempted_task=result1) | ||||
|  | ||||
|         response = self.client.get(self.ENDPOINT) | ||||
|  | ||||
|         self.assertEqual(response.status_code, 200) | ||||
|         self.assertEqual(len(response.data), 1) | ||||
|  | ||||
|         returned_data = response.data[0] | ||||
|  | ||||
|         self.assertEqual(returned_data["name"], "anothertest.pdf") | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Trenton H
					Trenton H