mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-07-28 18:24:38 -05:00
Merge pull request #4037 from andreheuer/dev
Enhancement: add task id to pre/post consume script as env
This commit is contained in:
@@ -209,6 +209,7 @@ class Consumer(LoggingMixin):
|
||||
script_env = os.environ.copy()
|
||||
script_env["DOCUMENT_SOURCE_PATH"] = original_file_path
|
||||
script_env["DOCUMENT_WORKING_PATH"] = working_file_path
|
||||
script_env["TASK_ID"] = self.task_id or ""
|
||||
|
||||
try:
|
||||
completed_proc = run(
|
||||
@@ -279,6 +280,7 @@ class Consumer(LoggingMixin):
|
||||
",".join(document.tags.all().values_list("name", flat=True)),
|
||||
)
|
||||
script_env["DOCUMENT_ORIGINAL_FILENAME"] = str(document.original_filename)
|
||||
script_env["TASK_ID"] = self.task_id or ""
|
||||
|
||||
try:
|
||||
completed_proc = run(
|
||||
|
@@ -7,6 +7,7 @@ from typing import Type
|
||||
|
||||
import tqdm
|
||||
from asgiref.sync import async_to_sync
|
||||
from celery import Task
|
||||
from celery import shared_task
|
||||
from channels.layers import get_channel_layer
|
||||
from django.conf import settings
|
||||
@@ -91,8 +92,9 @@ def train_classifier():
|
||||
logger.warning("Classifier error: " + str(e))
|
||||
|
||||
|
||||
@shared_task
|
||||
@shared_task(bind=True)
|
||||
def consume_file(
|
||||
self: Task,
|
||||
input_doc: ConsumableDocument,
|
||||
overrides: Optional[DocumentMetadataOverrides] = None,
|
||||
):
|
||||
@@ -163,6 +165,7 @@ def consume_file(
|
||||
override_created=overrides.created,
|
||||
override_asn=overrides.asn,
|
||||
override_owner_id=overrides.owner_id,
|
||||
task_id=self.request.id,
|
||||
)
|
||||
|
||||
if document:
|
||||
|
@@ -4,6 +4,7 @@ import re
|
||||
import shutil
|
||||
import stat
|
||||
import tempfile
|
||||
import uuid
|
||||
from unittest import mock
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
@@ -862,6 +863,7 @@ class PreConsumeTestCase(TestCase):
|
||||
c = Consumer()
|
||||
c.original_path = "path-to-file"
|
||||
c.path = "/tmp/somewhere/path-to-file"
|
||||
c.task_id = str(uuid.uuid4())
|
||||
c.run_pre_consume_script()
|
||||
|
||||
m.assert_called_once()
|
||||
@@ -877,6 +879,7 @@ class PreConsumeTestCase(TestCase):
|
||||
subset = {
|
||||
"DOCUMENT_SOURCE_PATH": c.original_path,
|
||||
"DOCUMENT_WORKING_PATH": c.path,
|
||||
"TASK_ID": c.task_id,
|
||||
}
|
||||
self.assertDictEqual(environment, {**environment, **subset})
|
||||
|
||||
@@ -937,7 +940,10 @@ class PreConsumeTestCase(TestCase):
|
||||
with override_settings(PRE_CONSUME_SCRIPT=script.name):
|
||||
c = Consumer()
|
||||
c.path = "path-to-file"
|
||||
self.assertRaises(ConsumerError, c.run_pre_consume_script)
|
||||
self.assertRaises(
|
||||
ConsumerError,
|
||||
c.run_pre_consume_script,
|
||||
)
|
||||
|
||||
|
||||
class PostConsumeTestCase(TestCase):
|
||||
@@ -968,7 +974,11 @@ class PostConsumeTestCase(TestCase):
|
||||
doc = Document.objects.create(title="Test", mime_type="application/pdf")
|
||||
c = Consumer()
|
||||
c.filename = "somefile.pdf"
|
||||
self.assertRaises(ConsumerError, c.run_post_consume_script, doc)
|
||||
self.assertRaises(
|
||||
ConsumerError,
|
||||
c.run_post_consume_script,
|
||||
doc,
|
||||
)
|
||||
|
||||
@mock.patch("documents.consumer.run")
|
||||
def test_post_consume_script_simple(self, m):
|
||||
@@ -995,7 +1005,9 @@ class PostConsumeTestCase(TestCase):
|
||||
doc.tags.add(tag1)
|
||||
doc.tags.add(tag2)
|
||||
|
||||
Consumer().run_post_consume_script(doc)
|
||||
consumer = Consumer()
|
||||
consumer.task_id = str(uuid.uuid4())
|
||||
consumer.run_post_consume_script(doc)
|
||||
|
||||
m.assert_called_once()
|
||||
|
||||
@@ -1017,6 +1029,7 @@ class PostConsumeTestCase(TestCase):
|
||||
"DOCUMENT_THUMBNAIL_URL": f"/api/documents/{doc.pk}/thumb/",
|
||||
"DOCUMENT_CORRESPONDENT": "my_bank",
|
||||
"DOCUMENT_TAGS": "a,b",
|
||||
"TASK_ID": consumer.task_id,
|
||||
}
|
||||
|
||||
self.assertDictEqual(environment, {**environment, **subset})
|
||||
|
Reference in New Issue
Block a user