Added task id to pre/post consume script as env

This commit is contained in:
André Heuer 2023-08-20 20:55:18 +02:00 committed by shamoon
parent 407a119b9a
commit 8f8a99a645
3 changed files with 33 additions and 15 deletions

View File

@ -186,7 +186,7 @@ class Consumer(LoggingMixin):
f"Not consuming {self.filename}: Given ASN already exists!", f"Not consuming {self.filename}: Given ASN already exists!",
) )
def run_pre_consume_script(self): def run_pre_consume_script(self, task_id):
""" """
If one is configured and exists, run the pre-consume script and If one is configured and exists, run the pre-consume script and
handle its output and/or errors handle its output and/or errors
@ -209,6 +209,7 @@ class Consumer(LoggingMixin):
script_env = os.environ.copy() script_env = os.environ.copy()
script_env["DOCUMENT_SOURCE_PATH"] = original_file_path script_env["DOCUMENT_SOURCE_PATH"] = original_file_path
script_env["DOCUMENT_WORKING_PATH"] = working_file_path script_env["DOCUMENT_WORKING_PATH"] = working_file_path
script_env["TASK_ID"] = task_id
try: try:
completed_proc = run( completed_proc = run(
@ -233,7 +234,7 @@ class Consumer(LoggingMixin):
exception=e, exception=e,
) )
def run_post_consume_script(self, document: Document): def run_post_consume_script(self, document: Document, task_id):
""" """
If one is configured and exists, run the pre-consume script and If one is configured and exists, run the pre-consume script and
handle its output and/or errors handle its output and/or errors
@ -279,6 +280,7 @@ class Consumer(LoggingMixin):
",".join(document.tags.all().values_list("name", flat=True)), ",".join(document.tags.all().values_list("name", flat=True)),
) )
script_env["DOCUMENT_ORIGINAL_FILENAME"] = str(document.original_filename) script_env["DOCUMENT_ORIGINAL_FILENAME"] = str(document.original_filename)
script_env["TASK_ID"] = task_id
try: try:
completed_proc = run( completed_proc = run(
@ -388,7 +390,7 @@ class Consumer(LoggingMixin):
logging_group=self.logging_group, logging_group=self.logging_group,
) )
self.run_pre_consume_script() self.run_pre_consume_script(task_id=self.task_id)
def progress_callback(current_progress, max_progress): # pragma: no cover def progress_callback(current_progress, max_progress): # pragma: no cover
# recalculate progress to be within 20 and 80 # recalculate progress to be within 20 and 80
@ -553,7 +555,7 @@ class Consumer(LoggingMixin):
document_parser.cleanup() document_parser.cleanup()
tempdir.cleanup() tempdir.cleanup()
self.run_post_consume_script(document) self.run_post_consume_script(document, task_id=self.task_id)
self.log.info(f"Document {document} consumption finished") self.log.info(f"Document {document} consumption finished")

View File

@ -91,8 +91,9 @@ def train_classifier():
logger.warning("Classifier error: " + str(e)) logger.warning("Classifier error: " + str(e))
@shared_task @shared_task(bind=True)
def consume_file( def consume_file(
self,
input_doc: ConsumableDocument, input_doc: ConsumableDocument,
overrides: Optional[DocumentMetadataOverrides] = None, overrides: Optional[DocumentMetadataOverrides] = None,
): ):
@ -163,6 +164,7 @@ def consume_file(
override_created=overrides.created, override_created=overrides.created,
override_asn=overrides.asn, override_asn=overrides.asn,
override_owner_id=overrides.owner_id, override_owner_id=overrides.owner_id,
task_id=self.request.id,
) )
if document: if document:

View File

@ -4,6 +4,7 @@ import re
import shutil import shutil
import stat import stat
import tempfile import tempfile
import uuid
from unittest import mock from unittest import mock
from unittest.mock import MagicMock from unittest.mock import MagicMock
@ -802,7 +803,7 @@ class PreConsumeTestCase(TestCase):
def test_no_pre_consume_script(self, m): def test_no_pre_consume_script(self, m):
c = Consumer() c = Consumer()
c.path = "path-to-file" c.path = "path-to-file"
c.run_pre_consume_script() c.run_pre_consume_script(str(uuid.uuid4()))
m.assert_not_called() m.assert_not_called()
@mock.patch("documents.consumer.run") @mock.patch("documents.consumer.run")
@ -812,7 +813,7 @@ class PreConsumeTestCase(TestCase):
c = Consumer() c = Consumer()
c.filename = "somefile.pdf" c.filename = "somefile.pdf"
c.path = "path-to-file" c.path = "path-to-file"
self.assertRaises(ConsumerError, c.run_pre_consume_script) self.assertRaises(ConsumerError, c.run_pre_consume_script, str(uuid.uuid4()))
@mock.patch("documents.consumer.run") @mock.patch("documents.consumer.run")
def test_pre_consume_script(self, m): def test_pre_consume_script(self, m):
@ -821,7 +822,8 @@ class PreConsumeTestCase(TestCase):
c = Consumer() c = Consumer()
c.original_path = "path-to-file" c.original_path = "path-to-file"
c.path = "/tmp/somewhere/path-to-file" c.path = "/tmp/somewhere/path-to-file"
c.run_pre_consume_script() task_id = str(uuid.uuid4())
c.run_pre_consume_script(task_id)
m.assert_called_once() m.assert_called_once()
@ -836,6 +838,7 @@ class PreConsumeTestCase(TestCase):
subset = { subset = {
"DOCUMENT_SOURCE_PATH": c.original_path, "DOCUMENT_SOURCE_PATH": c.original_path,
"DOCUMENT_WORKING_PATH": c.path, "DOCUMENT_WORKING_PATH": c.path,
"TASK_ID": task_id,
} }
self.assertDictEqual(environment, {**environment, **subset}) self.assertDictEqual(environment, {**environment, **subset})
@ -864,7 +867,7 @@ class PreConsumeTestCase(TestCase):
c = Consumer() c = Consumer()
c.path = "path-to-file" c.path = "path-to-file"
c.run_pre_consume_script() c.run_pre_consume_script(str(uuid.uuid4()))
self.assertIn( self.assertIn(
"INFO:paperless.consumer:This message goes to stdout", "INFO:paperless.consumer:This message goes to stdout",
cm.output, cm.output,
@ -896,7 +899,11 @@ class PreConsumeTestCase(TestCase):
with override_settings(PRE_CONSUME_SCRIPT=script.name): with override_settings(PRE_CONSUME_SCRIPT=script.name):
c = Consumer() c = Consumer()
c.path = "path-to-file" c.path = "path-to-file"
self.assertRaises(ConsumerError, c.run_pre_consume_script) self.assertRaises(
ConsumerError,
c.run_pre_consume_script,
str(uuid.uuid4()),
)
class PostConsumeTestCase(TestCase): class PostConsumeTestCase(TestCase):
@ -917,7 +924,7 @@ class PostConsumeTestCase(TestCase):
doc.tags.add(tag1) doc.tags.add(tag1)
doc.tags.add(tag2) doc.tags.add(tag2)
Consumer().run_post_consume_script(doc) Consumer().run_post_consume_script(doc, str(uuid.uuid4()))
m.assert_not_called() m.assert_not_called()
@ -927,7 +934,12 @@ class PostConsumeTestCase(TestCase):
doc = Document.objects.create(title="Test", mime_type="application/pdf") doc = Document.objects.create(title="Test", mime_type="application/pdf")
c = Consumer() c = Consumer()
c.filename = "somefile.pdf" c.filename = "somefile.pdf"
self.assertRaises(ConsumerError, c.run_post_consume_script, doc) self.assertRaises(
ConsumerError,
c.run_post_consume_script,
doc,
str(uuid.uuid4()),
)
@mock.patch("documents.consumer.run") @mock.patch("documents.consumer.run")
def test_post_consume_script_simple(self, m): def test_post_consume_script_simple(self, m):
@ -935,7 +947,7 @@ class PostConsumeTestCase(TestCase):
with override_settings(POST_CONSUME_SCRIPT=script.name): with override_settings(POST_CONSUME_SCRIPT=script.name):
doc = Document.objects.create(title="Test", mime_type="application/pdf") doc = Document.objects.create(title="Test", mime_type="application/pdf")
Consumer().run_post_consume_script(doc) Consumer().run_post_consume_script(doc, str(uuid.uuid4()))
m.assert_called_once() m.assert_called_once()
@ -953,8 +965,9 @@ class PostConsumeTestCase(TestCase):
tag2 = Tag.objects.create(name="b") tag2 = Tag.objects.create(name="b")
doc.tags.add(tag1) doc.tags.add(tag1)
doc.tags.add(tag2) doc.tags.add(tag2)
task_id = str(uuid.uuid4())
Consumer().run_post_consume_script(doc) Consumer().run_post_consume_script(doc, task_id)
m.assert_called_once() m.assert_called_once()
@ -976,6 +989,7 @@ class PostConsumeTestCase(TestCase):
"DOCUMENT_THUMBNAIL_URL": f"/api/documents/{doc.pk}/thumb/", "DOCUMENT_THUMBNAIL_URL": f"/api/documents/{doc.pk}/thumb/",
"DOCUMENT_CORRESPONDENT": "my_bank", "DOCUMENT_CORRESPONDENT": "my_bank",
"DOCUMENT_TAGS": "a,b", "DOCUMENT_TAGS": "a,b",
"TASK_ID": task_id,
} }
self.assertDictEqual(environment, {**environment, **subset}) self.assertDictEqual(environment, {**environment, **subset})
@ -1004,4 +1018,4 @@ class PostConsumeTestCase(TestCase):
doc = Document.objects.create(title="Test", mime_type="application/pdf") doc = Document.objects.create(title="Test", mime_type="application/pdf")
c.path = "path-to-file" c.path = "path-to-file"
with self.assertRaises(ConsumerError): with self.assertRaises(ConsumerError):
c.run_post_consume_script(doc) c.run_post_consume_script(doc, str(uuid.uuid4()))