Merge pull request #1648 from paperless-ngx/feature-use-celery

Feature: Transition to celery for background tasks
This commit is contained in:
shamoon
2022-10-10 00:07:55 -07:00
committed by GitHub
34 changed files with 959 additions and 297 deletions

View File

@@ -1,11 +1,12 @@
import itertools
from django.db.models import Q
from django_q.tasks import async_task
from documents.models import Correspondent
from documents.models import Document
from documents.models import DocumentType
from documents.models import StoragePath
from documents.tasks import bulk_update_documents
from documents.tasks import update_document_archive_file
def set_correspondent(doc_ids, correspondent):
@@ -16,7 +17,7 @@ def set_correspondent(doc_ids, correspondent):
affected_docs = [doc.id for doc in qs]
qs.update(correspondent=correspondent)
async_task("documents.tasks.bulk_update_documents", document_ids=affected_docs)
bulk_update_documents.delay(document_ids=affected_docs)
return "OK"
@@ -31,8 +32,7 @@ def set_storage_path(doc_ids, storage_path):
affected_docs = [doc.id for doc in qs]
qs.update(storage_path=storage_path)
async_task(
"documents.tasks.bulk_update_documents",
bulk_update_documents.delay(
document_ids=affected_docs,
)
@@ -47,7 +47,7 @@ def set_document_type(doc_ids, document_type):
affected_docs = [doc.id for doc in qs]
qs.update(document_type=document_type)
async_task("documents.tasks.bulk_update_documents", document_ids=affected_docs)
bulk_update_documents.delay(document_ids=affected_docs)
return "OK"
@@ -63,7 +63,7 @@ def add_tag(doc_ids, tag):
[DocumentTagRelationship(document_id=doc, tag_id=tag) for doc in affected_docs],
)
async_task("documents.tasks.bulk_update_documents", document_ids=affected_docs)
bulk_update_documents.delay(document_ids=affected_docs)
return "OK"
@@ -79,7 +79,7 @@ def remove_tag(doc_ids, tag):
Q(document_id__in=affected_docs) & Q(tag_id=tag),
).delete()
async_task("documents.tasks.bulk_update_documents", document_ids=affected_docs)
bulk_update_documents.delay(document_ids=affected_docs)
return "OK"
@@ -103,7 +103,7 @@ def modify_tags(doc_ids, add_tags, remove_tags):
ignore_conflicts=True,
)
async_task("documents.tasks.bulk_update_documents", document_ids=affected_docs)
bulk_update_documents.delay(document_ids=affected_docs)
return "OK"
@@ -123,8 +123,7 @@ def delete(doc_ids):
def redo_ocr(doc_ids):
for document_id in doc_ids:
async_task(
"documents.tasks.update_document_archive_file",
update_document_archive_file.delay(
document_id=document_id,
)

View File

@@ -11,9 +11,9 @@ from typing import Final
from django.conf import settings
from django.core.management.base import BaseCommand
from django.core.management.base import CommandError
from django_q.tasks import async_task
from documents.models import Tag
from documents.parsers import is_file_ext_supported
from documents.tasks import consume_file
from watchdog.events import FileSystemEventHandler
from watchdog.observers.polling import PollingObserver
@@ -92,11 +92,9 @@ def _consume(filepath):
try:
logger.info(f"Adding {filepath} to the task queue.")
async_task(
"documents.tasks.consume_file",
consume_file.delay(
filepath,
override_tag_ids=tag_ids if tag_ids else None,
task_name=os.path.basename(filepath)[:100],
)
except Exception:
# Catch all so that the consumer won't crash.

View File

@@ -1,34 +1,14 @@
# Generated by Django 3.1.3 on 2020-11-09 16:36
from django.db import migrations
from django.db.migrations import RunPython
from django_q.models import Schedule
from django_q.tasks import schedule
def add_schedules(apps, schema_editor):
schedule(
"documents.tasks.train_classifier",
name="Train the classifier",
schedule_type=Schedule.HOURLY,
)
schedule(
"documents.tasks.index_optimize",
name="Optimize the index",
schedule_type=Schedule.DAILY,
)
def remove_schedules(apps, schema_editor):
Schedule.objects.filter(func="documents.tasks.train_classifier").delete()
Schedule.objects.filter(func="documents.tasks.index_optimize").delete()
class Migration(migrations.Migration):
dependencies = [
("documents", "1000_update_paperless_all"),
("django_q", "0013_task_attempt_count"),
]
operations = [RunPython(add_schedules, remove_schedules)]
operations = [
migrations.RunPython(migrations.RunPython.noop, migrations.RunPython.noop)
]

View File

@@ -2,27 +2,12 @@
from django.db import migrations
from django.db.migrations import RunPython
from django_q.models import Schedule
from django_q.tasks import schedule
def add_schedules(apps, schema_editor):
schedule(
"documents.tasks.sanity_check",
name="Perform sanity check",
schedule_type=Schedule.WEEKLY,
)
def remove_schedules(apps, schema_editor):
Schedule.objects.filter(func="documents.tasks.sanity_check").delete()
class Migration(migrations.Migration):
dependencies = [
("documents", "1003_mime_types"),
("django_q", "0013_task_attempt_count"),
]
operations = [RunPython(add_schedules, remove_schedules)]
operations = [RunPython(migrations.RunPython.noop, migrations.RunPython.noop)]

View File

@@ -4,28 +4,9 @@ from django.db import migrations, models
import django.db.models.deletion
def init_paperless_tasks(apps, schema_editor):
PaperlessTask = apps.get_model("documents", "PaperlessTask")
Task = apps.get_model("django_q", "Task")
for task in Task.objects.filter(func="documents.tasks.consume_file"):
if not hasattr(task, "paperlesstask"):
paperlesstask = PaperlessTask.objects.create(
attempted_task=task,
task_id=task.id,
name=task.name,
created=task.started,
started=task.started,
acknowledged=True,
)
task.paperlesstask = paperlesstask
task.save()
class Migration(migrations.Migration):
dependencies = [
("django_q", "0014_schedule_cluster"),
("documents", "1021_webp_thumbnail_conversion"),
]
@@ -60,10 +41,12 @@ class Migration(migrations.Migration):
null=True,
on_delete=django.db.models.deletion.CASCADE,
related_name="attempted_task",
to="django_q.task",
# This is a dummy field, 1026 will fix up the column
# This manual change is required, as django doesn't django doesn't really support
# removing an app which has migration deps like this
to="documents.document",
),
),
],
),
migrations.RunPython(init_paperless_tasks, migrations.RunPython.noop),
)
]

View File

@@ -0,0 +1,57 @@
# Generated by Django 4.1.1 on 2022-09-27 19:31
from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration):
dependencies = [
("django_celery_results", "0011_taskresult_periodic_task_name"),
("documents", "1025_alter_savedviewfilterrule_rule_type"),
]
operations = [
migrations.RemoveField(
model_name="paperlesstask",
name="created",
),
migrations.RemoveField(
model_name="paperlesstask",
name="name",
),
migrations.RemoveField(
model_name="paperlesstask",
name="started",
),
# Remove the field from the model
migrations.RemoveField(
model_name="paperlesstask",
name="attempted_task",
),
# Add the field back, pointing to the correct model
# This resolves a problem where the temporary change in 1022
# results in a type mismatch
migrations.AddField(
model_name="paperlesstask",
name="attempted_task",
field=models.OneToOneField(
blank=True,
null=True,
on_delete=django.db.models.deletion.CASCADE,
related_name="attempted_task",
to="django_celery_results.taskresult",
),
),
# Drop the django-q tables entirely
# Must be done last or there could be references here
migrations.RunSQL(
"DROP TABLE IF EXISTS django_q_ormq", reverse_sql=migrations.RunSQL.noop
),
migrations.RunSQL(
"DROP TABLE IF EXISTS django_q_schedule", reverse_sql=migrations.RunSQL.noop
),
migrations.RunSQL(
"DROP TABLE IF EXISTS django_q_task", reverse_sql=migrations.RunSQL.noop
),
]

View File

@@ -12,7 +12,7 @@ from django.contrib.auth.models import User
from django.db import models
from django.utils import timezone
from django.utils.translation import gettext_lazy as _
from django_q.tasks import Task
from django_celery_results.models import TaskResult
from documents.parsers import get_default_file_extension
@@ -527,19 +527,16 @@ class UiSettings(models.Model):
class PaperlessTask(models.Model):
task_id = models.CharField(max_length=128)
name = models.CharField(max_length=256)
created = models.DateTimeField(_("created"), auto_now=True)
started = models.DateTimeField(_("started"), null=True)
acknowledged = models.BooleanField(default=False)
attempted_task = models.OneToOneField(
Task,
TaskResult,
on_delete=models.CASCADE,
related_name="attempted_task",
null=True,
blank=True,
)
acknowledged = models.BooleanField(default=False)
class Comment(models.Model):

View File

@@ -1,6 +1,14 @@
import datetime
import math
import re
from ast import literal_eval
from asyncio.log import logger
from pathlib import Path
from typing import Dict
from typing import Optional
from typing import Tuple
from celery import states
try:
import zoneinfo
@@ -18,12 +26,12 @@ from .models import Correspondent
from .models import Document
from .models import DocumentType
from .models import MatchingModel
from .models import PaperlessTask
from .models import SavedView
from .models import SavedViewFilterRule
from .models import StoragePath
from .models import Tag
from .models import UiSettings
from .models import PaperlessTask
from .parsers import is_mime_type_supported
@@ -629,7 +637,19 @@ class TasksViewSerializer(serializers.ModelSerializer):
class Meta:
model = PaperlessTask
depth = 1
fields = "__all__"
fields = (
"id",
"task_id",
"date_created",
"date_done",
"type",
"status",
"result",
"acknowledged",
"task_name",
"name",
"related_document",
)
type = serializers.SerializerMethodField()
@@ -641,24 +661,108 @@ class TasksViewSerializer(serializers.ModelSerializer):
def get_result(self, obj):
result = ""
if hasattr(obj, "attempted_task") and obj.attempted_task:
result = obj.attempted_task.result
if (
hasattr(obj, "attempted_task")
and obj.attempted_task
and obj.attempted_task.result
):
try:
result: str = obj.attempted_task.result
if "exc_message" in result:
# This is a dict in this case
result: Dict = literal_eval(result)
# This is a list, grab the first item (most recent)
result = result["exc_message"][0]
except Exception as e: # pragma: no cover
# Extra security if something is malformed
logger.warn(f"Error getting task result: {e}", exc_info=True)
return result
status = serializers.SerializerMethodField()
def get_status(self, obj):
if obj.attempted_task is None:
if obj.started:
return "started"
else:
return "queued"
elif obj.attempted_task.success:
return "complete"
elif not obj.attempted_task.success:
return "failed"
else:
return "unknown"
result = "unknown"
if hasattr(obj, "attempted_task") and obj.attempted_task:
result = obj.attempted_task.status
return result
date_created = serializers.SerializerMethodField()
def get_date_created(self, obj):
result = ""
if hasattr(obj, "attempted_task") and obj.attempted_task:
result = obj.attempted_task.date_created
return result
date_done = serializers.SerializerMethodField()
def get_date_done(self, obj):
result = ""
if hasattr(obj, "attempted_task") and obj.attempted_task:
result = obj.attempted_task.date_done
return result
task_id = serializers.SerializerMethodField()
def get_task_id(self, obj):
result = ""
if hasattr(obj, "attempted_task") and obj.attempted_task:
result = obj.attempted_task.task_id
return result
task_name = serializers.SerializerMethodField()
def get_task_name(self, obj):
result = ""
if hasattr(obj, "attempted_task") and obj.attempted_task:
result = obj.attempted_task.task_name
return result
name = serializers.SerializerMethodField()
def get_name(self, obj):
result = ""
if hasattr(obj, "attempted_task") and obj.attempted_task:
try:
task_kwargs: Optional[str] = obj.attempted_task.task_kwargs
# Try the override filename first (this is a webui created task?)
if task_kwargs is not None:
# It's a string, string of a dict. Who knows why...
kwargs = literal_eval(literal_eval(task_kwargs))
if "override_filename" in kwargs:
result = kwargs["override_filename"]
# Nothing was found, report the task first argument
if not len(result):
# There are always some arguments to the consume
task_args: Tuple = literal_eval(
literal_eval(obj.attempted_task.task_args),
)
filepath = Path(task_args[0])
result = filepath.name
except Exception as e: # pragma: no cover
# Extra security if something is malformed
logger.warn(f"Error getting file name from task: {e}", exc_info=True)
return result
related_document = serializers.SerializerMethodField()
def get_related_document(self, obj):
result = ""
regexp = r"New document id (\d+) created"
if (
hasattr(obj, "attempted_task")
and obj.attempted_task
and obj.attempted_task.result
and obj.attempted_task.status == states.SUCCESS
):
try:
result = re.search(regexp, obj.attempted_task.result).group(1)
except Exception:
pass
return result
class AcknowledgeTasksViewSerializer(serializers.Serializer):

View File

@@ -2,7 +2,6 @@ import logging
import os
import shutil
import django_q
from django.conf import settings
from django.contrib.admin.models import ADDITION
from django.contrib.admin.models import LogEntry
@@ -14,6 +13,7 @@ from django.db.models import Q
from django.dispatch import receiver
from django.utils import termcolors
from django.utils import timezone
from django_celery_results.models import TaskResult
from filelock import FileLock
from .. import matching
@@ -25,7 +25,6 @@ from ..models import MatchingModel
from ..models import PaperlessTask
from ..models import Tag
logger = logging.getLogger("paperless.handlers")
@@ -503,47 +502,19 @@ def add_to_index(sender, document, **kwargs):
index.add_or_update_document(document)
@receiver(django_q.signals.pre_enqueue)
def init_paperless_task(sender, task, **kwargs):
if task["func"] == "documents.tasks.consume_file":
try:
paperless_task, created = PaperlessTask.objects.get_or_create(
task_id=task["id"],
)
paperless_task.name = task["name"]
paperless_task.created = task["started"]
paperless_task.save()
except Exception as e:
# Don't let an exception in the signal handlers prevent
# a document from being consumed.
logger.error(f"Creating PaperlessTask failed: {e}")
@receiver(django_q.signals.pre_execute)
def paperless_task_started(sender, task, **kwargs):
@receiver(models.signals.post_save, sender=TaskResult)
def update_paperless_task(sender, instance: TaskResult, **kwargs):
try:
if task["func"] == "documents.tasks.consume_file":
paperless_task, created = PaperlessTask.objects.get_or_create(
task_id=task["id"],
)
paperless_task.started = timezone.now()
paperless_task.save()
except PaperlessTask.DoesNotExist:
pass
except Exception as e:
logger.error(f"Creating PaperlessTask failed: {e}")
@receiver(models.signals.post_save, sender=django_q.models.Task)
def update_paperless_task(sender, instance, **kwargs):
try:
if instance.func == "documents.tasks.consume_file":
paperless_task, created = PaperlessTask.objects.get_or_create(
task_id=instance.id,
if instance.task_name == "documents.tasks.consume_file":
paperless_task, _ = PaperlessTask.objects.get_or_create(
task_id=instance.task_id,
)
paperless_task.name = instance.task_name
paperless_task.created = instance.date_created
paperless_task.completed = instance.date_done
paperless_task.attempted_task = instance
paperless_task.save()
except PaperlessTask.DoesNotExist:
pass
except Exception as e:
# Don't let an exception in the signal handlers prevent
# a document from being consumed.
logger.error(f"Creating PaperlessTask failed: {e}")

View File

@@ -8,6 +8,7 @@ from typing import Type
import tqdm
from asgiref.sync import async_to_sync
from celery import shared_task
from channels.layers import get_channel_layer
from django.conf import settings
from django.db import transaction
@@ -36,6 +37,7 @@ from whoosh.writing import AsyncWriter
logger = logging.getLogger("paperless.tasks")
@shared_task
def index_optimize():
ix = index.open_index()
writer = AsyncWriter(ix)
@@ -52,6 +54,7 @@ def index_reindex(progress_bar_disable=False):
index.update_document(writer, document)
@shared_task
def train_classifier():
if (
not Tag.objects.filter(matching_algorithm=Tag.MATCH_AUTO).exists()
@@ -80,6 +83,7 @@ def train_classifier():
logger.warning("Classifier error: " + str(e))
@shared_task
def consume_file(
path,
override_filename=None,
@@ -183,6 +187,7 @@ def consume_file(
)
@shared_task
def sanity_check():
messages = sanity_checker.check_sanity()
@@ -198,6 +203,7 @@ def sanity_check():
return "No issues detected."
@shared_task
def bulk_update_documents(document_ids):
documents = Document.objects.filter(id__in=document_ids)
@@ -211,6 +217,7 @@ def bulk_update_documents(document_ids):
index.update_document(writer, doc)
@shared_task
def update_document_archive_file(document_id):
"""
Re-creates the archive file of a document, including new OCR content and thumbnail

View File

@@ -10,6 +10,8 @@ import zipfile
from unittest import mock
from unittest.mock import MagicMock
import celery
try:
import zoneinfo
except ImportError:
@@ -20,7 +22,6 @@ from django.conf import settings
from django.contrib.auth.models import User
from django.test import override_settings
from django.utils import timezone
from django_q.models import Task
from documents import bulk_edit
from documents import index
from documents.models import Correspondent
@@ -31,7 +32,7 @@ from documents.models import PaperlessTask
from documents.models import SavedView
from documents.models import StoragePath
from documents.models import Tag
from documents.models import UiSettings
from django_celery_results.models import TaskResult
from documents.models import Comment
from documents.models import StoragePath
from documents.tests.utils import DirectoriesMixin
@@ -790,7 +791,7 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
self.assertEqual(response.status_code, 200)
self.assertEqual(response.data["documents_inbox"], None)
@mock.patch("documents.views.async_task")
@mock.patch("documents.views.consume_file.delay")
def test_upload(self, m):
with open(
@@ -813,7 +814,7 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
self.assertIsNone(kwargs["override_document_type_id"])
self.assertIsNone(kwargs["override_tag_ids"])
@mock.patch("documents.views.async_task")
@mock.patch("documents.views.consume_file.delay")
def test_upload_empty_metadata(self, m):
with open(
@@ -836,7 +837,7 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
self.assertIsNone(kwargs["override_document_type_id"])
self.assertIsNone(kwargs["override_tag_ids"])
@mock.patch("documents.views.async_task")
@mock.patch("documents.views.consume_file.delay")
def test_upload_invalid_form(self, m):
with open(
@@ -850,7 +851,7 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
self.assertEqual(response.status_code, 400)
m.assert_not_called()
@mock.patch("documents.views.async_task")
@mock.patch("documents.views.consume_file.delay")
def test_upload_invalid_file(self, m):
with open(
@@ -864,7 +865,7 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
self.assertEqual(response.status_code, 400)
m.assert_not_called()
@mock.patch("documents.views.async_task")
@mock.patch("documents.views.consume_file.delay")
def test_upload_with_title(self, async_task):
with open(
os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"),
@@ -882,7 +883,7 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
self.assertEqual(kwargs["override_title"], "my custom title")
@mock.patch("documents.views.async_task")
@mock.patch("documents.views.consume_file.delay")
def test_upload_with_correspondent(self, async_task):
c = Correspondent.objects.create(name="test-corres")
with open(
@@ -901,7 +902,7 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
self.assertEqual(kwargs["override_correspondent_id"], c.id)
@mock.patch("documents.views.async_task")
@mock.patch("documents.views.consume_file.delay")
def test_upload_with_invalid_correspondent(self, async_task):
with open(
os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"),
@@ -915,7 +916,7 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
async_task.assert_not_called()
@mock.patch("documents.views.async_task")
@mock.patch("documents.views.consume_file.delay")
def test_upload_with_document_type(self, async_task):
dt = DocumentType.objects.create(name="invoice")
with open(
@@ -934,7 +935,7 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
self.assertEqual(kwargs["override_document_type_id"], dt.id)
@mock.patch("documents.views.async_task")
@mock.patch("documents.views.consume_file.delay")
def test_upload_with_invalid_document_type(self, async_task):
with open(
os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"),
@@ -948,7 +949,7 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
async_task.assert_not_called()
@mock.patch("documents.views.async_task")
@mock.patch("documents.views.consume_file.delay")
def test_upload_with_tags(self, async_task):
t1 = Tag.objects.create(name="tag1")
t2 = Tag.objects.create(name="tag2")
@@ -968,7 +969,7 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
self.assertCountEqual(kwargs["override_tag_ids"], [t1.id, t2.id])
@mock.patch("documents.views.async_task")
@mock.patch("documents.views.consume_file.delay")
def test_upload_with_invalid_tags(self, async_task):
t1 = Tag.objects.create(name="tag1")
t2 = Tag.objects.create(name="tag2")
@@ -984,7 +985,7 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
async_task.assert_not_called()
@mock.patch("documents.views.async_task")
@mock.patch("documents.views.consume_file.delay")
def test_upload_with_created(self, async_task):
created = datetime.datetime(
2022,
@@ -1619,7 +1620,7 @@ class TestBulkEdit(DirectoriesMixin, APITestCase):
user = User.objects.create_superuser(username="temp_admin")
self.client.force_authenticate(user=user)
patcher = mock.patch("documents.bulk_edit.async_task")
patcher = mock.patch("documents.bulk_edit.bulk_update_documents.delay")
self.async_task = patcher.start()
self.addCleanup(patcher.stop)
self.c1 = Correspondent.objects.create(name="c1")
@@ -2738,7 +2739,7 @@ class TestApiStoragePaths(DirectoriesMixin, APITestCase):
class TestTasks(APITestCase):
ENDPOINT = "/api/tasks/"
ENDPOINT_ACKOWLEDGE = "/api/acknowledge_tasks/"
ENDPOINT_ACKNOWLEDGE = "/api/acknowledge_tasks/"
def setUp(self):
super().setUp()
@@ -2747,16 +2748,27 @@ class TestTasks(APITestCase):
self.client.force_authenticate(user=self.user)
def test_get_tasks(self):
task_id1 = str(uuid.uuid4())
PaperlessTask.objects.create(task_id=task_id1)
Task.objects.create(
id=task_id1,
started=timezone.now() - datetime.timedelta(seconds=30),
stopped=timezone.now(),
func="documents.tasks.consume_file",
"""
GIVEN:
- Attempted celery tasks
WHEN:
- API call is made to get tasks
THEN:
- Attempting and pending tasks are serialized and provided
"""
result1 = TaskResult.objects.create(
task_id=str(uuid.uuid4()),
task_name="documents.tasks.some_great_task",
status=celery.states.PENDING,
)
task_id2 = str(uuid.uuid4())
PaperlessTask.objects.create(task_id=task_id2)
PaperlessTask.objects.create(attempted_task=result1)
result2 = TaskResult.objects.create(
task_id=str(uuid.uuid4()),
task_name="documents.tasks.some_awesome_task",
status=celery.states.STARTED,
)
PaperlessTask.objects.create(attempted_task=result2)
response = self.client.get(self.ENDPOINT)
@@ -2764,25 +2776,155 @@ class TestTasks(APITestCase):
self.assertEqual(len(response.data), 2)
returned_task1 = response.data[1]
returned_task2 = response.data[0]
self.assertEqual(returned_task1["task_id"], task_id1)
self.assertEqual(returned_task1["status"], "complete")
self.assertIsNotNone(returned_task1["attempted_task"])
self.assertEqual(returned_task2["task_id"], task_id2)
self.assertEqual(returned_task2["status"], "queued")
self.assertIsNone(returned_task2["attempted_task"])
self.assertEqual(returned_task1["task_id"], result1.task_id)
self.assertEqual(returned_task1["status"], celery.states.PENDING)
self.assertEqual(returned_task1["task_name"], result1.task_name)
self.assertEqual(returned_task2["task_id"], result2.task_id)
self.assertEqual(returned_task2["status"], celery.states.STARTED)
self.assertEqual(returned_task2["task_name"], result2.task_name)
def test_acknowledge_tasks(self):
task_id = str(uuid.uuid4())
task = PaperlessTask.objects.create(task_id=task_id)
"""
GIVEN:
- Attempted celery tasks
WHEN:
- API call is made to get mark task as acknowledged
THEN:
- Task is marked as acknowledged
"""
result1 = TaskResult.objects.create(
task_id=str(uuid.uuid4()),
task_name="documents.tasks.some_task",
status=celery.states.PENDING,
)
task = PaperlessTask.objects.create(attempted_task=result1)
response = self.client.get(self.ENDPOINT)
self.assertEqual(len(response.data), 1)
response = self.client.post(
self.ENDPOINT_ACKOWLEDGE,
self.ENDPOINT_ACKNOWLEDGE,
{"tasks": [task.id]},
)
self.assertEqual(response.status_code, 200)
response = self.client.get(self.ENDPOINT)
self.assertEqual(len(response.data), 0)
def test_task_result_no_error(self):
"""
GIVEN:
- A celery task completed without error
WHEN:
- API call is made to get tasks
THEN:
- The returned data includes the task result
"""
result1 = TaskResult.objects.create(
task_id=str(uuid.uuid4()),
task_name="documents.tasks.some_task",
status=celery.states.SUCCESS,
result="Success. New document id 1 created",
)
_ = PaperlessTask.objects.create(attempted_task=result1)
response = self.client.get(self.ENDPOINT)
self.assertEqual(response.status_code, 200)
self.assertEqual(len(response.data), 1)
returned_data = response.data[0]
self.assertEqual(returned_data["result"], "Success. New document id 1 created")
self.assertEqual(returned_data["related_document"], "1")
def test_task_result_with_error(self):
"""
GIVEN:
- A celery task completed with an exception
WHEN:
- API call is made to get tasks
THEN:
- The returned result is the exception info
"""
result1 = TaskResult.objects.create(
task_id=str(uuid.uuid4()),
task_name="documents.tasks.some_task",
status=celery.states.SUCCESS,
result={
"exc_type": "ConsumerError",
"exc_message": ["test.pdf: Not consuming test.pdf: It is a duplicate."],
"exc_module": "documents.consumer",
},
)
_ = PaperlessTask.objects.create(attempted_task=result1)
response = self.client.get(self.ENDPOINT)
self.assertEqual(response.status_code, 200)
self.assertEqual(len(response.data), 1)
returned_data = response.data[0]
self.assertEqual(
returned_data["result"],
"test.pdf: Not consuming test.pdf: It is a duplicate.",
)
def test_task_name_webui(self):
"""
GIVEN:
- Attempted celery task
- Task was created through the webui
WHEN:
- API call is made to get tasks
THEN:
- Returned data include the filename
"""
result1 = TaskResult.objects.create(
task_id=str(uuid.uuid4()),
task_name="documents.tasks.some_task",
status=celery.states.SUCCESS,
task_args="\"('/tmp/paperless/paperless-upload-5iq7skzc',)\"",
task_kwargs="\"{'override_filename': 'test.pdf', 'override_title': None, 'override_correspondent_id': None, 'override_document_type_id': None, 'override_tag_ids': None, 'task_id': '466e8fe7-7193-4698-9fff-72f0340e2082', 'override_created': None}\"",
)
_ = PaperlessTask.objects.create(attempted_task=result1)
response = self.client.get(self.ENDPOINT)
self.assertEqual(response.status_code, 200)
self.assertEqual(len(response.data), 1)
returned_data = response.data[0]
self.assertEqual(returned_data["name"], "test.pdf")
def test_task_name_consume_folder(self):
"""
GIVEN:
- Attempted celery task
- Task was created through the consume folder
WHEN:
- API call is made to get tasks
THEN:
- Returned data include the filename
"""
result1 = TaskResult.objects.create(
task_id=str(uuid.uuid4()),
task_name="documents.tasks.some_task",
status=celery.states.SUCCESS,
task_args="\"('/consume/anothertest.pdf',)\"",
task_kwargs="\"{'override_tag_ids': None}\"",
)
_ = PaperlessTask.objects.create(attempted_task=result1)
response = self.client.get(self.ENDPOINT)
self.assertEqual(response.status_code, 200)
self.assertEqual(len(response.data), 1)
returned_data = response.data[0]
self.assertEqual(returned_data["name"], "anothertest.pdf")

View File

@@ -43,7 +43,7 @@ class ConsumerMixin:
super().setUp()
self.t = None
patcher = mock.patch(
"documents.management.commands.document_consumer.async_task",
"documents.tasks.consume_file.delay",
)
self.task_mock = patcher.start()
self.addCleanup(patcher.stop)
@@ -76,7 +76,7 @@ class ConsumerMixin:
# A bogus async_task that will simply check the file for
# completeness and raise an exception otherwise.
def bogus_task(self, func, filename, **kwargs):
def bogus_task(self, filename, **kwargs):
eq = filecmp.cmp(filename, self.sample_file, shallow=False)
if not eq:
print("Consumed an INVALID file.")
@@ -115,7 +115,7 @@ class TestConsumer(DirectoriesMixin, ConsumerMixin, TransactionTestCase):
self.task_mock.assert_called_once()
args, kwargs = self.task_mock.call_args
self.assertEqual(args[1], f)
self.assertEqual(args[0], f)
def test_consume_file_invalid_ext(self):
self.t_start()
@@ -135,7 +135,7 @@ class TestConsumer(DirectoriesMixin, ConsumerMixin, TransactionTestCase):
self.task_mock.assert_called_once()
args, kwargs = self.task_mock.call_args
self.assertEqual(args[1], f)
self.assertEqual(args[0], f)
@mock.patch("documents.management.commands.document_consumer.logger.error")
def test_slow_write_pdf(self, error_logger):
@@ -155,7 +155,7 @@ class TestConsumer(DirectoriesMixin, ConsumerMixin, TransactionTestCase):
self.task_mock.assert_called_once()
args, kwargs = self.task_mock.call_args
self.assertEqual(args[1], fname)
self.assertEqual(args[0], fname)
@mock.patch("documents.management.commands.document_consumer.logger.error")
def test_slow_write_and_move(self, error_logger):
@@ -175,7 +175,7 @@ class TestConsumer(DirectoriesMixin, ConsumerMixin, TransactionTestCase):
self.task_mock.assert_called_once()
args, kwargs = self.task_mock.call_args
self.assertEqual(args[1], fname2)
self.assertEqual(args[0], fname2)
error_logger.assert_not_called()
@@ -193,7 +193,7 @@ class TestConsumer(DirectoriesMixin, ConsumerMixin, TransactionTestCase):
self.task_mock.assert_called_once()
args, kwargs = self.task_mock.call_args
self.assertEqual(args[1], fname)
self.assertEqual(args[0], fname)
# assert that we have an error logged with this invalid file.
error_logger.assert_called_once()
@@ -241,7 +241,7 @@ class TestConsumer(DirectoriesMixin, ConsumerMixin, TransactionTestCase):
self.assertEqual(2, self.task_mock.call_count)
fnames = [
os.path.basename(args[1]) for args, _ in self.task_mock.call_args_list
os.path.basename(args[0]) for args, _ in self.task_mock.call_args_list
]
self.assertCountEqual(fnames, ["my_file.pdf", "my_second_file.pdf"])
@@ -338,7 +338,7 @@ class TestConsumerTags(DirectoriesMixin, ConsumerMixin, TransactionTestCase):
tag_ids.append(Tag.objects.get(name=tag_names[1]).pk)
args, kwargs = self.task_mock.call_args
self.assertEqual(args[1], f)
self.assertEqual(args[0], f)
# assertCountEqual has a bad name, but test that the first
# sequence contains the same elements as second, regardless of

View File

@@ -28,7 +28,7 @@ from django.utils.translation import get_language
from django.views.decorators.cache import cache_control
from django.views.generic import TemplateView
from django_filters.rest_framework import DjangoFilterBackend
from django_q.tasks import async_task
from documents.tasks import consume_file
from packaging import version as packaging_version
from paperless import version
from paperless.db import GnuPG
@@ -615,8 +615,7 @@ class PostDocumentView(GenericAPIView):
task_id = str(uuid.uuid4())
async_task(
"documents.tasks.consume_file",
consume_file.delay(
temp_filename,
override_filename=doc_name,
override_title=title,
@@ -624,7 +623,6 @@ class PostDocumentView(GenericAPIView):
override_document_type_id=document_type_id,
override_tag_ids=tag_ids,
task_id=task_id,
task_name=os.path.basename(doc_name)[:100],
override_created=created,
)
@@ -888,8 +886,9 @@ class TasksViewSet(ReadOnlyModelViewSet):
queryset = (
PaperlessTask.objects.filter(
acknowledged=False,
attempted_task__isnull=False,
)
.order_by("created")
.order_by("attempted_task__date_created")
.reverse()
)