mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-07-28 18:24:38 -05:00
Transitions the backend to celery and celery beat
This commit is contained in:

committed by
Trenton H

parent
74c1a99545
commit
09287701ae
@@ -1,11 +1,12 @@
|
||||
import itertools
|
||||
|
||||
from django.db.models import Q
|
||||
from django_q.tasks import async_task
|
||||
from documents.models import Correspondent
|
||||
from documents.models import Document
|
||||
from documents.models import DocumentType
|
||||
from documents.models import StoragePath
|
||||
from documents.tasks import bulk_update_documents
|
||||
from documents.tasks import update_document_archive_file
|
||||
|
||||
|
||||
def set_correspondent(doc_ids, correspondent):
|
||||
@@ -16,7 +17,7 @@ def set_correspondent(doc_ids, correspondent):
|
||||
affected_docs = [doc.id for doc in qs]
|
||||
qs.update(correspondent=correspondent)
|
||||
|
||||
async_task("documents.tasks.bulk_update_documents", document_ids=affected_docs)
|
||||
bulk_update_documents.delay(document_ids=affected_docs)
|
||||
|
||||
return "OK"
|
||||
|
||||
@@ -31,8 +32,7 @@ def set_storage_path(doc_ids, storage_path):
|
||||
affected_docs = [doc.id for doc in qs]
|
||||
qs.update(storage_path=storage_path)
|
||||
|
||||
async_task(
|
||||
"documents.tasks.bulk_update_documents",
|
||||
bulk_update_documents.delay(
|
||||
document_ids=affected_docs,
|
||||
)
|
||||
|
||||
@@ -47,7 +47,7 @@ def set_document_type(doc_ids, document_type):
|
||||
affected_docs = [doc.id for doc in qs]
|
||||
qs.update(document_type=document_type)
|
||||
|
||||
async_task("documents.tasks.bulk_update_documents", document_ids=affected_docs)
|
||||
bulk_update_documents.delay(document_ids=affected_docs)
|
||||
|
||||
return "OK"
|
||||
|
||||
@@ -63,7 +63,7 @@ def add_tag(doc_ids, tag):
|
||||
[DocumentTagRelationship(document_id=doc, tag_id=tag) for doc in affected_docs],
|
||||
)
|
||||
|
||||
async_task("documents.tasks.bulk_update_documents", document_ids=affected_docs)
|
||||
bulk_update_documents.delay(document_ids=affected_docs)
|
||||
|
||||
return "OK"
|
||||
|
||||
@@ -79,7 +79,7 @@ def remove_tag(doc_ids, tag):
|
||||
Q(document_id__in=affected_docs) & Q(tag_id=tag),
|
||||
).delete()
|
||||
|
||||
async_task("documents.tasks.bulk_update_documents", document_ids=affected_docs)
|
||||
bulk_update_documents.delay(document_ids=affected_docs)
|
||||
|
||||
return "OK"
|
||||
|
||||
@@ -103,7 +103,7 @@ def modify_tags(doc_ids, add_tags, remove_tags):
|
||||
ignore_conflicts=True,
|
||||
)
|
||||
|
||||
async_task("documents.tasks.bulk_update_documents", document_ids=affected_docs)
|
||||
bulk_update_documents.delay(document_ids=affected_docs)
|
||||
|
||||
return "OK"
|
||||
|
||||
@@ -123,8 +123,7 @@ def delete(doc_ids):
|
||||
def redo_ocr(doc_ids):
|
||||
|
||||
for document_id in doc_ids:
|
||||
async_task(
|
||||
"documents.tasks.update_document_archive_file",
|
||||
update_document_archive_file.delay(
|
||||
document_id=document_id,
|
||||
)
|
||||
|
||||
|
@@ -11,9 +11,9 @@ from typing import Final
|
||||
from django.conf import settings
|
||||
from django.core.management.base import BaseCommand
|
||||
from django.core.management.base import CommandError
|
||||
from django_q.tasks import async_task
|
||||
from documents.models import Tag
|
||||
from documents.parsers import is_file_ext_supported
|
||||
from documents.tasks import consume_file
|
||||
from watchdog.events import FileSystemEventHandler
|
||||
from watchdog.observers.polling import PollingObserver
|
||||
|
||||
@@ -92,11 +92,9 @@ def _consume(filepath):
|
||||
|
||||
try:
|
||||
logger.info(f"Adding {filepath} to the task queue.")
|
||||
async_task(
|
||||
"documents.tasks.consume_file",
|
||||
consume_file.delay(
|
||||
filepath,
|
||||
override_tag_ids=tag_ids if tag_ids else None,
|
||||
task_name=os.path.basename(filepath)[:100],
|
||||
)
|
||||
except Exception:
|
||||
# Catch all so that the consumer won't crash.
|
||||
|
@@ -1,34 +1,14 @@
|
||||
# Generated by Django 3.1.3 on 2020-11-09 16:36
|
||||
|
||||
from django.db import migrations
|
||||
from django.db.migrations import RunPython
|
||||
from django_q.models import Schedule
|
||||
from django_q.tasks import schedule
|
||||
|
||||
|
||||
def add_schedules(apps, schema_editor):
|
||||
schedule(
|
||||
"documents.tasks.train_classifier",
|
||||
name="Train the classifier",
|
||||
schedule_type=Schedule.HOURLY,
|
||||
)
|
||||
schedule(
|
||||
"documents.tasks.index_optimize",
|
||||
name="Optimize the index",
|
||||
schedule_type=Schedule.DAILY,
|
||||
)
|
||||
|
||||
|
||||
def remove_schedules(apps, schema_editor):
|
||||
Schedule.objects.filter(func="documents.tasks.train_classifier").delete()
|
||||
Schedule.objects.filter(func="documents.tasks.index_optimize").delete()
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("documents", "1000_update_paperless_all"),
|
||||
("django_q", "0013_task_attempt_count"),
|
||||
]
|
||||
|
||||
operations = [RunPython(add_schedules, remove_schedules)]
|
||||
operations = [
|
||||
migrations.RunPython(migrations.RunPython.noop, migrations.RunPython.noop)
|
||||
]
|
||||
|
@@ -2,27 +2,12 @@
|
||||
|
||||
from django.db import migrations
|
||||
from django.db.migrations import RunPython
|
||||
from django_q.models import Schedule
|
||||
from django_q.tasks import schedule
|
||||
|
||||
|
||||
def add_schedules(apps, schema_editor):
|
||||
schedule(
|
||||
"documents.tasks.sanity_check",
|
||||
name="Perform sanity check",
|
||||
schedule_type=Schedule.WEEKLY,
|
||||
)
|
||||
|
||||
|
||||
def remove_schedules(apps, schema_editor):
|
||||
Schedule.objects.filter(func="documents.tasks.sanity_check").delete()
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("documents", "1003_mime_types"),
|
||||
("django_q", "0013_task_attempt_count"),
|
||||
]
|
||||
|
||||
operations = [RunPython(add_schedules, remove_schedules)]
|
||||
operations = [RunPython(migrations.RunPython.noop, migrations.RunPython.noop)]
|
||||
|
@@ -4,28 +4,9 @@ from django.db import migrations, models
|
||||
import django.db.models.deletion
|
||||
|
||||
|
||||
def init_paperless_tasks(apps, schema_editor):
|
||||
PaperlessTask = apps.get_model("documents", "PaperlessTask")
|
||||
Task = apps.get_model("django_q", "Task")
|
||||
|
||||
for task in Task.objects.filter(func="documents.tasks.consume_file"):
|
||||
if not hasattr(task, "paperlesstask"):
|
||||
paperlesstask = PaperlessTask.objects.create(
|
||||
attempted_task=task,
|
||||
task_id=task.id,
|
||||
name=task.name,
|
||||
created=task.started,
|
||||
started=task.started,
|
||||
acknowledged=True,
|
||||
)
|
||||
task.paperlesstask = paperlesstask
|
||||
task.save()
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("django_q", "0014_schedule_cluster"),
|
||||
("documents", "1021_webp_thumbnail_conversion"),
|
||||
]
|
||||
|
||||
@@ -60,10 +41,12 @@ class Migration(migrations.Migration):
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
related_name="attempted_task",
|
||||
to="django_q.task",
|
||||
# This is a dummy field, AlterField in 1026 will set to correct value
|
||||
# This manual change is required, as django doesn't django doesn't really support
|
||||
# removing an app which has migration deps like this
|
||||
to="documents.document",
|
||||
),
|
||||
),
|
||||
],
|
||||
),
|
||||
migrations.RunPython(init_paperless_tasks, migrations.RunPython.noop),
|
||||
)
|
||||
]
|
||||
|
@@ -0,0 +1,42 @@
|
||||
# Generated by Django 4.0.7 on 2022-09-02 22:08
|
||||
|
||||
from django.db import migrations, models
|
||||
import django.db.models.deletion
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("django_celery_results", "0011_taskresult_periodic_task_name"),
|
||||
("documents", "1025_alter_savedviewfilterrule_rule_type"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.RemoveField(
|
||||
model_name="paperlesstask",
|
||||
name="created",
|
||||
),
|
||||
migrations.RemoveField(
|
||||
model_name="paperlesstask",
|
||||
name="name",
|
||||
),
|
||||
migrations.RemoveField(
|
||||
model_name="paperlesstask",
|
||||
name="started",
|
||||
),
|
||||
migrations.RemoveField(
|
||||
model_name="paperlesstask",
|
||||
name="task_id",
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name="paperlesstask",
|
||||
name="attempted_task",
|
||||
field=models.OneToOneField(
|
||||
blank=True,
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
related_name="attempted_task",
|
||||
to="django_celery_results.taskresult",
|
||||
),
|
||||
),
|
||||
]
|
24
src/documents/migrations/1027_drop_django_q.py
Normal file
24
src/documents/migrations/1027_drop_django_q.py
Normal file
@@ -0,0 +1,24 @@
|
||||
# Generated by Django 4.0.7 on 2022-09-05 21:39
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("documents", "1026_remove_paperlesstask_created_and_more"),
|
||||
]
|
||||
|
||||
# Manual SQL commands to drop the django_q related tables
|
||||
# if they exist
|
||||
operations = [
|
||||
migrations.RunSQL(
|
||||
"DROP TABLE IF EXISTS django_q_ormq", reverse_sql=migrations.RunSQL.noop
|
||||
),
|
||||
migrations.RunSQL(
|
||||
"DROP TABLE IF EXISTS django_q_schedule", reverse_sql=migrations.RunSQL.noop
|
||||
),
|
||||
migrations.RunSQL(
|
||||
"DROP TABLE IF EXISTS django_q_task", reverse_sql=migrations.RunSQL.noop
|
||||
),
|
||||
]
|
@@ -12,7 +12,7 @@ from django.contrib.auth.models import User
|
||||
from django.db import models
|
||||
from django.utils import timezone
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django_q.tasks import Task
|
||||
from django_celery_results.models import TaskResult
|
||||
from documents.parsers import get_default_file_extension
|
||||
|
||||
|
||||
@@ -527,19 +527,15 @@ class UiSettings(models.Model):
|
||||
|
||||
|
||||
class PaperlessTask(models.Model):
|
||||
acknowledged = models.BooleanField(default=False)
|
||||
|
||||
task_id = models.CharField(max_length=128)
|
||||
name = models.CharField(max_length=256)
|
||||
created = models.DateTimeField(_("created"), auto_now=True)
|
||||
started = models.DateTimeField(_("started"), null=True)
|
||||
attempted_task = models.OneToOneField(
|
||||
Task,
|
||||
TaskResult,
|
||||
on_delete=models.CASCADE,
|
||||
related_name="attempted_task",
|
||||
null=True,
|
||||
blank=True,
|
||||
)
|
||||
acknowledged = models.BooleanField(default=False)
|
||||
|
||||
|
||||
class Comment(models.Model):
|
||||
|
@@ -18,12 +18,12 @@ from .models import Correspondent
|
||||
from .models import Document
|
||||
from .models import DocumentType
|
||||
from .models import MatchingModel
|
||||
from .models import PaperlessTask
|
||||
from .models import SavedView
|
||||
from .models import SavedViewFilterRule
|
||||
from .models import StoragePath
|
||||
from .models import Tag
|
||||
from .models import UiSettings
|
||||
from .models import PaperlessTask
|
||||
from .parsers import is_mime_type_supported
|
||||
|
||||
|
||||
@@ -620,7 +620,17 @@ class TasksViewSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = PaperlessTask
|
||||
depth = 1
|
||||
fields = "__all__"
|
||||
fields = (
|
||||
"id",
|
||||
"type",
|
||||
"status",
|
||||
"result",
|
||||
"acknowledged",
|
||||
"date_created",
|
||||
"date_done",
|
||||
"task_name",
|
||||
"task_id",
|
||||
)
|
||||
|
||||
type = serializers.SerializerMethodField()
|
||||
|
||||
@@ -639,17 +649,42 @@ class TasksViewSerializer(serializers.ModelSerializer):
|
||||
status = serializers.SerializerMethodField()
|
||||
|
||||
def get_status(self, obj):
|
||||
if obj.attempted_task is None:
|
||||
if obj.started:
|
||||
return "started"
|
||||
else:
|
||||
return "queued"
|
||||
elif obj.attempted_task.success:
|
||||
return "complete"
|
||||
elif not obj.attempted_task.success:
|
||||
return "failed"
|
||||
else:
|
||||
return "unknown"
|
||||
result = "unknown"
|
||||
if hasattr(obj, "attempted_task") and obj.attempted_task:
|
||||
result = obj.attempted_task.status
|
||||
return result
|
||||
|
||||
date_created = serializers.SerializerMethodField()
|
||||
|
||||
def get_date_created(self, obj):
|
||||
result = ""
|
||||
if hasattr(obj, "attempted_task") and obj.attempted_task:
|
||||
result = obj.attempted_task.date_created
|
||||
return result
|
||||
|
||||
date_done = serializers.SerializerMethodField()
|
||||
|
||||
def get_date_done(self, obj):
|
||||
result = ""
|
||||
if hasattr(obj, "attempted_task") and obj.attempted_task:
|
||||
result = obj.attempted_task.date_done
|
||||
return result
|
||||
|
||||
task_name = serializers.SerializerMethodField()
|
||||
|
||||
def get_task_name(self, obj):
|
||||
result = ""
|
||||
if hasattr(obj, "attempted_task") and obj.attempted_task:
|
||||
result = obj.attempted_task.task_name
|
||||
return result
|
||||
|
||||
task_id = serializers.SerializerMethodField()
|
||||
|
||||
def get_task_id(self, obj):
|
||||
result = ""
|
||||
if hasattr(obj, "attempted_task") and obj.attempted_task:
|
||||
result = obj.attempted_task.task_id
|
||||
return result
|
||||
|
||||
|
||||
class AcknowledgeTasksViewSerializer(serializers.Serializer):
|
||||
|
@@ -2,7 +2,6 @@ import logging
|
||||
import os
|
||||
import shutil
|
||||
|
||||
import django_q
|
||||
from django.conf import settings
|
||||
from django.contrib.admin.models import ADDITION
|
||||
from django.contrib.admin.models import LogEntry
|
||||
@@ -14,6 +13,7 @@ from django.db.models import Q
|
||||
from django.dispatch import receiver
|
||||
from django.utils import termcolors
|
||||
from django.utils import timezone
|
||||
from django_celery_results.models import TaskResult
|
||||
from filelock import FileLock
|
||||
|
||||
from .. import matching
|
||||
@@ -25,7 +25,6 @@ from ..models import MatchingModel
|
||||
from ..models import PaperlessTask
|
||||
from ..models import Tag
|
||||
|
||||
|
||||
logger = logging.getLogger("paperless.handlers")
|
||||
|
||||
|
||||
@@ -503,47 +502,16 @@ def add_to_index(sender, document, **kwargs):
|
||||
index.add_or_update_document(document)
|
||||
|
||||
|
||||
@receiver(django_q.signals.pre_enqueue)
|
||||
def init_paperless_task(sender, task, **kwargs):
|
||||
if task["func"] == "documents.tasks.consume_file":
|
||||
try:
|
||||
paperless_task, created = PaperlessTask.objects.get_or_create(
|
||||
task_id=task["id"],
|
||||
)
|
||||
paperless_task.name = task["name"]
|
||||
paperless_task.created = task["started"]
|
||||
paperless_task.save()
|
||||
except Exception as e:
|
||||
# Don't let an exception in the signal handlers prevent
|
||||
# a document from being consumed.
|
||||
logger.error(f"Creating PaperlessTask failed: {e}")
|
||||
|
||||
|
||||
@receiver(django_q.signals.pre_execute)
|
||||
def paperless_task_started(sender, task, **kwargs):
|
||||
@receiver(models.signals.post_save, sender=TaskResult)
|
||||
def update_paperless_task(sender, instance: TaskResult, **kwargs):
|
||||
try:
|
||||
if task["func"] == "documents.tasks.consume_file":
|
||||
paperless_task, created = PaperlessTask.objects.get_or_create(
|
||||
task_id=task["id"],
|
||||
)
|
||||
paperless_task.started = timezone.now()
|
||||
paperless_task.save()
|
||||
except PaperlessTask.DoesNotExist:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"Creating PaperlessTask failed: {e}")
|
||||
|
||||
|
||||
@receiver(models.signals.post_save, sender=django_q.models.Task)
|
||||
def update_paperless_task(sender, instance, **kwargs):
|
||||
try:
|
||||
if instance.func == "documents.tasks.consume_file":
|
||||
paperless_task, created = PaperlessTask.objects.get_or_create(
|
||||
task_id=instance.id,
|
||||
if instance.task_name == "documents.tasks.consume_file":
|
||||
paperless_task, _ = PaperlessTask.objects.get_or_create(
|
||||
task_id=instance.task_id,
|
||||
)
|
||||
paperless_task.attempted_task = instance
|
||||
paperless_task.save()
|
||||
except PaperlessTask.DoesNotExist:
|
||||
pass
|
||||
except Exception as e:
|
||||
# Don't let an exception in the signal handlers prevent
|
||||
# a document from being consumed.
|
||||
logger.error(f"Creating PaperlessTask failed: {e}")
|
||||
|
@@ -8,6 +8,7 @@ from typing import Type
|
||||
|
||||
import tqdm
|
||||
from asgiref.sync import async_to_sync
|
||||
from celery import shared_task
|
||||
from channels.layers import get_channel_layer
|
||||
from django.conf import settings
|
||||
from django.db import transaction
|
||||
@@ -36,6 +37,7 @@ from whoosh.writing import AsyncWriter
|
||||
logger = logging.getLogger("paperless.tasks")
|
||||
|
||||
|
||||
@shared_task
|
||||
def index_optimize():
|
||||
ix = index.open_index()
|
||||
writer = AsyncWriter(ix)
|
||||
@@ -52,6 +54,7 @@ def index_reindex(progress_bar_disable=False):
|
||||
index.update_document(writer, document)
|
||||
|
||||
|
||||
@shared_task
|
||||
def train_classifier():
|
||||
if (
|
||||
not Tag.objects.filter(matching_algorithm=Tag.MATCH_AUTO).exists()
|
||||
@@ -80,6 +83,7 @@ def train_classifier():
|
||||
logger.warning("Classifier error: " + str(e))
|
||||
|
||||
|
||||
@shared_task
|
||||
def consume_file(
|
||||
path,
|
||||
override_filename=None,
|
||||
@@ -171,6 +175,7 @@ def consume_file(
|
||||
)
|
||||
|
||||
|
||||
@shared_task
|
||||
def sanity_check():
|
||||
messages = sanity_checker.check_sanity()
|
||||
|
||||
@@ -186,6 +191,7 @@ def sanity_check():
|
||||
return "No issues detected."
|
||||
|
||||
|
||||
@shared_task
|
||||
def bulk_update_documents(document_ids):
|
||||
documents = Document.objects.filter(id__in=document_ids)
|
||||
|
||||
@@ -199,6 +205,7 @@ def bulk_update_documents(document_ids):
|
||||
index.update_document(writer, doc)
|
||||
|
||||
|
||||
@shared_task
|
||||
def update_document_archive_file(document_id):
|
||||
"""
|
||||
Re-creates the archive file of a document, including new OCR content and thumbnail
|
||||
|
@@ -10,6 +10,8 @@ import zipfile
|
||||
from unittest import mock
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import celery
|
||||
|
||||
try:
|
||||
import zoneinfo
|
||||
except ImportError:
|
||||
@@ -32,6 +34,7 @@ from documents.models import SavedView
|
||||
from documents.models import StoragePath
|
||||
from documents.models import Tag
|
||||
from documents.models import UiSettings
|
||||
from django_celery_results.models import TaskResult
|
||||
from documents.models import Comment
|
||||
from documents.models import StoragePath
|
||||
from documents.tests.utils import DirectoriesMixin
|
||||
@@ -790,7 +793,7 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertEqual(response.data["documents_inbox"], None)
|
||||
|
||||
@mock.patch("documents.views.async_task")
|
||||
@mock.patch("documents.views.consume_file.delay")
|
||||
def test_upload(self, m):
|
||||
|
||||
with open(
|
||||
@@ -813,7 +816,7 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
|
||||
self.assertIsNone(kwargs["override_document_type_id"])
|
||||
self.assertIsNone(kwargs["override_tag_ids"])
|
||||
|
||||
@mock.patch("documents.views.async_task")
|
||||
@mock.patch("documents.views.consume_file.delay")
|
||||
def test_upload_empty_metadata(self, m):
|
||||
|
||||
with open(
|
||||
@@ -836,7 +839,7 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
|
||||
self.assertIsNone(kwargs["override_document_type_id"])
|
||||
self.assertIsNone(kwargs["override_tag_ids"])
|
||||
|
||||
@mock.patch("documents.views.async_task")
|
||||
@mock.patch("documents.views.consume_file.delay")
|
||||
def test_upload_invalid_form(self, m):
|
||||
|
||||
with open(
|
||||
@@ -850,7 +853,7 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
|
||||
self.assertEqual(response.status_code, 400)
|
||||
m.assert_not_called()
|
||||
|
||||
@mock.patch("documents.views.async_task")
|
||||
@mock.patch("documents.views.consume_file.delay")
|
||||
def test_upload_invalid_file(self, m):
|
||||
|
||||
with open(
|
||||
@@ -864,7 +867,7 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
|
||||
self.assertEqual(response.status_code, 400)
|
||||
m.assert_not_called()
|
||||
|
||||
@mock.patch("documents.views.async_task")
|
||||
@mock.patch("documents.views.consume_file.delay")
|
||||
def test_upload_with_title(self, async_task):
|
||||
with open(
|
||||
os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"),
|
||||
@@ -882,7 +885,7 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
|
||||
|
||||
self.assertEqual(kwargs["override_title"], "my custom title")
|
||||
|
||||
@mock.patch("documents.views.async_task")
|
||||
@mock.patch("documents.views.consume_file.delay")
|
||||
def test_upload_with_correspondent(self, async_task):
|
||||
c = Correspondent.objects.create(name="test-corres")
|
||||
with open(
|
||||
@@ -901,7 +904,7 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
|
||||
|
||||
self.assertEqual(kwargs["override_correspondent_id"], c.id)
|
||||
|
||||
@mock.patch("documents.views.async_task")
|
||||
@mock.patch("documents.views.consume_file.delay")
|
||||
def test_upload_with_invalid_correspondent(self, async_task):
|
||||
with open(
|
||||
os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"),
|
||||
@@ -915,7 +918,7 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
|
||||
|
||||
async_task.assert_not_called()
|
||||
|
||||
@mock.patch("documents.views.async_task")
|
||||
@mock.patch("documents.views.consume_file.delay")
|
||||
def test_upload_with_document_type(self, async_task):
|
||||
dt = DocumentType.objects.create(name="invoice")
|
||||
with open(
|
||||
@@ -934,7 +937,7 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
|
||||
|
||||
self.assertEqual(kwargs["override_document_type_id"], dt.id)
|
||||
|
||||
@mock.patch("documents.views.async_task")
|
||||
@mock.patch("documents.views.consume_file.delay")
|
||||
def test_upload_with_invalid_document_type(self, async_task):
|
||||
with open(
|
||||
os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"),
|
||||
@@ -948,7 +951,7 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
|
||||
|
||||
async_task.assert_not_called()
|
||||
|
||||
@mock.patch("documents.views.async_task")
|
||||
@mock.patch("documents.views.consume_file.delay")
|
||||
def test_upload_with_tags(self, async_task):
|
||||
t1 = Tag.objects.create(name="tag1")
|
||||
t2 = Tag.objects.create(name="tag2")
|
||||
@@ -968,7 +971,7 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
|
||||
|
||||
self.assertCountEqual(kwargs["override_tag_ids"], [t1.id, t2.id])
|
||||
|
||||
@mock.patch("documents.views.async_task")
|
||||
@mock.patch("documents.views.consume_file.delay")
|
||||
def test_upload_with_invalid_tags(self, async_task):
|
||||
t1 = Tag.objects.create(name="tag1")
|
||||
t2 = Tag.objects.create(name="tag2")
|
||||
@@ -984,7 +987,7 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
|
||||
|
||||
async_task.assert_not_called()
|
||||
|
||||
@mock.patch("documents.views.async_task")
|
||||
@mock.patch("documents.views.consume_file.delay")
|
||||
def test_upload_with_created(self, async_task):
|
||||
created = datetime.datetime(
|
||||
2022,
|
||||
@@ -1615,7 +1618,7 @@ class TestBulkEdit(DirectoriesMixin, APITestCase):
|
||||
user = User.objects.create_superuser(username="temp_admin")
|
||||
self.client.force_authenticate(user=user)
|
||||
|
||||
patcher = mock.patch("documents.bulk_edit.async_task")
|
||||
patcher = mock.patch("documents.bulk_edit.bulk_update_documents.delay")
|
||||
self.async_task = patcher.start()
|
||||
self.addCleanup(patcher.stop)
|
||||
self.c1 = Correspondent.objects.create(name="c1")
|
||||
@@ -2783,7 +2786,7 @@ class TestApiStoragePaths(DirectoriesMixin, APITestCase):
|
||||
|
||||
class TestTasks(APITestCase):
|
||||
ENDPOINT = "/api/tasks/"
|
||||
ENDPOINT_ACKOWLEDGE = "/api/acknowledge_tasks/"
|
||||
ENDPOINT_ACKNOWLEDGE = "/api/acknowledge_tasks/"
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
@@ -2792,39 +2795,60 @@ class TestTasks(APITestCase):
|
||||
self.client.force_authenticate(user=self.user)
|
||||
|
||||
def test_get_tasks(self):
|
||||
task_id1 = str(uuid.uuid4())
|
||||
PaperlessTask.objects.create(task_id=task_id1)
|
||||
Task.objects.create(
|
||||
id=task_id1,
|
||||
started=timezone.now() - datetime.timedelta(seconds=30),
|
||||
stopped=timezone.now(),
|
||||
func="documents.tasks.consume_file",
|
||||
"""
|
||||
GIVEN:
|
||||
- Attempted celery tasks
|
||||
WHEN:
|
||||
- API call is made to get tasks
|
||||
THEN:
|
||||
- Attempting and pending tasks are serialized and provided
|
||||
"""
|
||||
result1 = TaskResult.objects.create(
|
||||
task_id=str(uuid.uuid4()),
|
||||
task_name="documents.tasks.some_task",
|
||||
status=celery.states.PENDING,
|
||||
)
|
||||
task_id2 = str(uuid.uuid4())
|
||||
PaperlessTask.objects.create(task_id=task_id2)
|
||||
PaperlessTask.objects.create(attempted_task=result1)
|
||||
|
||||
result2 = TaskResult.objects.create(
|
||||
task_id=str(uuid.uuid4()),
|
||||
task_name="documents.tasks.other_task",
|
||||
status=celery.states.STARTED,
|
||||
)
|
||||
PaperlessTask.objects.create(attempted_task=result2)
|
||||
|
||||
response = self.client.get(self.ENDPOINT)
|
||||
from pprint import pprint
|
||||
|
||||
for x in response.data:
|
||||
pprint(x)
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertEqual(len(response.data), 2)
|
||||
returned_task1 = response.data[1]
|
||||
returned_task2 = response.data[0]
|
||||
self.assertEqual(returned_task1["task_id"], task_id1)
|
||||
self.assertEqual(returned_task1["status"], "complete")
|
||||
self.assertIsNotNone(returned_task1["attempted_task"])
|
||||
self.assertEqual(returned_task2["task_id"], task_id2)
|
||||
self.assertEqual(returned_task2["status"], "queued")
|
||||
self.assertIsNone(returned_task2["attempted_task"])
|
||||
|
||||
self.assertEqual(returned_task1["task_id"], result1.task_id)
|
||||
self.assertEqual(returned_task1["status"], celery.states.PENDING)
|
||||
self.assertEqual(returned_task1["task_name"], result1.task_name)
|
||||
|
||||
self.assertEqual(returned_task2["task_id"], result2.task_id)
|
||||
self.assertEqual(returned_task2["status"], celery.states.STARTED)
|
||||
self.assertEqual(returned_task2["task_name"], result2.task_name)
|
||||
|
||||
def test_acknowledge_tasks(self):
|
||||
task_id = str(uuid.uuid4())
|
||||
task = PaperlessTask.objects.create(task_id=task_id)
|
||||
result1 = TaskResult.objects.create(
|
||||
task_id=str(uuid.uuid4()),
|
||||
task_name="documents.tasks.some_task",
|
||||
status=celery.states.PENDING,
|
||||
)
|
||||
task = PaperlessTask.objects.create(attempted_task=result1)
|
||||
|
||||
response = self.client.get(self.ENDPOINT)
|
||||
self.assertEqual(len(response.data), 1)
|
||||
|
||||
response = self.client.post(
|
||||
self.ENDPOINT_ACKOWLEDGE,
|
||||
self.ENDPOINT_ACKNOWLEDGE,
|
||||
{"tasks": [task.id]},
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
@@ -43,7 +43,7 @@ class ConsumerMixin:
|
||||
super().setUp()
|
||||
self.t = None
|
||||
patcher = mock.patch(
|
||||
"documents.management.commands.document_consumer.async_task",
|
||||
"documents.tasks.consume_file.delay",
|
||||
)
|
||||
self.task_mock = patcher.start()
|
||||
self.addCleanup(patcher.stop)
|
||||
@@ -76,7 +76,7 @@ class ConsumerMixin:
|
||||
|
||||
# A bogus async_task that will simply check the file for
|
||||
# completeness and raise an exception otherwise.
|
||||
def bogus_task(self, func, filename, **kwargs):
|
||||
def bogus_task(self, filename, **kwargs):
|
||||
eq = filecmp.cmp(filename, self.sample_file, shallow=False)
|
||||
if not eq:
|
||||
print("Consumed an INVALID file.")
|
||||
@@ -115,7 +115,7 @@ class TestConsumer(DirectoriesMixin, ConsumerMixin, TransactionTestCase):
|
||||
self.task_mock.assert_called_once()
|
||||
|
||||
args, kwargs = self.task_mock.call_args
|
||||
self.assertEqual(args[1], f)
|
||||
self.assertEqual(args[0], f)
|
||||
|
||||
def test_consume_file_invalid_ext(self):
|
||||
self.t_start()
|
||||
@@ -135,7 +135,7 @@ class TestConsumer(DirectoriesMixin, ConsumerMixin, TransactionTestCase):
|
||||
self.task_mock.assert_called_once()
|
||||
|
||||
args, kwargs = self.task_mock.call_args
|
||||
self.assertEqual(args[1], f)
|
||||
self.assertEqual(args[0], f)
|
||||
|
||||
@mock.patch("documents.management.commands.document_consumer.logger.error")
|
||||
def test_slow_write_pdf(self, error_logger):
|
||||
@@ -155,7 +155,7 @@ class TestConsumer(DirectoriesMixin, ConsumerMixin, TransactionTestCase):
|
||||
self.task_mock.assert_called_once()
|
||||
|
||||
args, kwargs = self.task_mock.call_args
|
||||
self.assertEqual(args[1], fname)
|
||||
self.assertEqual(args[0], fname)
|
||||
|
||||
@mock.patch("documents.management.commands.document_consumer.logger.error")
|
||||
def test_slow_write_and_move(self, error_logger):
|
||||
@@ -175,7 +175,7 @@ class TestConsumer(DirectoriesMixin, ConsumerMixin, TransactionTestCase):
|
||||
self.task_mock.assert_called_once()
|
||||
|
||||
args, kwargs = self.task_mock.call_args
|
||||
self.assertEqual(args[1], fname2)
|
||||
self.assertEqual(args[0], fname2)
|
||||
|
||||
error_logger.assert_not_called()
|
||||
|
||||
@@ -193,7 +193,7 @@ class TestConsumer(DirectoriesMixin, ConsumerMixin, TransactionTestCase):
|
||||
|
||||
self.task_mock.assert_called_once()
|
||||
args, kwargs = self.task_mock.call_args
|
||||
self.assertEqual(args[1], fname)
|
||||
self.assertEqual(args[0], fname)
|
||||
|
||||
# assert that we have an error logged with this invalid file.
|
||||
error_logger.assert_called_once()
|
||||
@@ -241,7 +241,7 @@ class TestConsumer(DirectoriesMixin, ConsumerMixin, TransactionTestCase):
|
||||
self.assertEqual(2, self.task_mock.call_count)
|
||||
|
||||
fnames = [
|
||||
os.path.basename(args[1]) for args, _ in self.task_mock.call_args_list
|
||||
os.path.basename(args[0]) for args, _ in self.task_mock.call_args_list
|
||||
]
|
||||
self.assertCountEqual(fnames, ["my_file.pdf", "my_second_file.pdf"])
|
||||
|
||||
@@ -338,7 +338,7 @@ class TestConsumerTags(DirectoriesMixin, ConsumerMixin, TransactionTestCase):
|
||||
tag_ids.append(Tag.objects.get(name=tag_names[1]).pk)
|
||||
|
||||
args, kwargs = self.task_mock.call_args
|
||||
self.assertEqual(args[1], f)
|
||||
self.assertEqual(args[0], f)
|
||||
|
||||
# assertCountEqual has a bad name, but test that the first
|
||||
# sequence contains the same elements as second, regardless of
|
||||
|
@@ -28,7 +28,7 @@ from django.utils.translation import get_language
|
||||
from django.views.decorators.cache import cache_control
|
||||
from django.views.generic import TemplateView
|
||||
from django_filters.rest_framework import DjangoFilterBackend
|
||||
from django_q.tasks import async_task
|
||||
from documents.tasks import consume_file
|
||||
from packaging import version as packaging_version
|
||||
from paperless import version
|
||||
from paperless.db import GnuPG
|
||||
@@ -612,8 +612,7 @@ class PostDocumentView(GenericAPIView):
|
||||
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
async_task(
|
||||
"documents.tasks.consume_file",
|
||||
consume_file.delay(
|
||||
temp_filename,
|
||||
override_filename=doc_name,
|
||||
override_title=title,
|
||||
@@ -621,7 +620,6 @@ class PostDocumentView(GenericAPIView):
|
||||
override_document_type_id=document_type_id,
|
||||
override_tag_ids=tag_ids,
|
||||
task_id=task_id,
|
||||
task_name=os.path.basename(doc_name)[:100],
|
||||
override_created=created,
|
||||
)
|
||||
|
||||
@@ -882,7 +880,7 @@ class TasksViewSet(ReadOnlyModelViewSet):
|
||||
PaperlessTask.objects.filter(
|
||||
acknowledged=False,
|
||||
)
|
||||
.order_by("created")
|
||||
.order_by("attempted_task__date_created")
|
||||
.reverse()
|
||||
)
|
||||
|
||||
|
Reference in New Issue
Block a user