diff --git a/src/documents/barcodes.py b/src/documents/barcodes.py index 1f520c546..82b81fecd 100644 --- a/src/documents/barcodes.py +++ b/src/documents/barcodes.py @@ -11,7 +11,6 @@ from typing import List from typing import Optional import img2pdf -import magic from django.conf import settings from pdf2image import convert_from_path from pdf2image.exceptions import PDFPageCountError @@ -63,7 +62,7 @@ class DocumentBarcodeInfo: @lru_cache(maxsize=8) -def supported_file_type(mime_type) -> bool: +def supported_file_type(mime_type: str) -> bool: """ Determines if the file is valid for barcode processing, based on MIME type and settings @@ -115,33 +114,16 @@ def barcode_reader(image: Image) -> List[str]: return barcodes -def get_file_mime_type(path: Path) -> str: - """ - Determines the file type, based on MIME type. - - Returns the MIME type. - """ - mime_type = magic.from_file(path, mime=True) - logger.debug(f"Detected mime type: {mime_type}") - return mime_type - - def convert_from_tiff_to_pdf(filepath: Path) -> Path: """ converts a given TIFF image file to pdf into a temporary directory. Returns the new pdf file. """ - mime_type = get_file_mime_type(filepath) tempdir = tempfile.mkdtemp(prefix="paperless-", dir=settings.SCRATCH_DIR) # use old file name with pdf extension - if mime_type == "image/tiff": - newpath = Path(tempdir) / Path(filepath.name).with_suffix(".pdf") - else: - logger.warning( - f"Cannot convert mime type {mime_type} from {filepath} to pdf.", - ) - return None + newpath = Path(tempdir) / Path(filepath.name).with_suffix(".pdf") + with Image.open(filepath) as im: has_alpha_layer = im.mode in ("RGBA", "LA") if has_alpha_layer: @@ -162,6 +144,7 @@ def convert_from_tiff_to_pdf(filepath: Path) -> Path: def scan_file_for_barcodes( filepath: Path, + mime_type: str, ) -> DocumentBarcodeInfo: """ Scan the provided pdf file for any barcodes @@ -186,7 +169,6 @@ def scan_file_for_barcodes( return detected_barcodes pdf_filepath = None - mime_type = get_file_mime_type(filepath) barcodes = [] if supported_file_type(mime_type): diff --git a/src/documents/consumer.py b/src/documents/consumer.py index 175c80876..797345ba6 100644 --- a/src/documents/consumer.py +++ b/src/documents/consumer.py @@ -284,7 +284,7 @@ class Consumer(LoggingMixin): def try_consume_file( self, - path, + path: Path, override_filename=None, override_title=None, override_correspondent_id=None, diff --git a/src/documents/data_models.py b/src/documents/data_models.py new file mode 100644 index 000000000..f904743d4 --- /dev/null +++ b/src/documents/data_models.py @@ -0,0 +1,62 @@ +import dataclasses +import datetime +import enum +from pathlib import Path +from typing import List +from typing import Optional + +import magic + + +@dataclasses.dataclass +class DocumentMetadataOverrides: + """ + Manages overrides for document fields which normally would + be set from content or matching. All fields default to None, + meaning no override is happening + """ + + filename: Optional[str] = None + title: Optional[str] = None + correspondent_id: Optional[int] = None + document_type_id: Optional[int] = None + tag_ids: Optional[List[int]] = None + created: Optional[datetime.datetime] = None + asn: Optional[int] = None + owner_id: Optional[int] = None + + +class DocumentSource(enum.IntEnum): + """ + The source of an incoming document. 
May have other uses in the future
+    """
+
+    ConsumeFolder = enum.auto()
+    ApiUpload = enum.auto()
+    MailFetch = enum.auto()
+
+
+@dataclasses.dataclass
+class ConsumableDocument:
+    """
+    Encapsulates an incoming document, either from consume folder, API upload
+    or mail fetching and certain useful operations on it.
+    """
+
+    source: DocumentSource
+    original_file: Path
+    mime_type: str = dataclasses.field(init=False, default=None)
+
+    def __post_init__(self):
+        """
+        After a dataclass is initialized, this is called to finalize some data
+        1. Make sure the original path is an absolute, fully qualified path
+        2. Get the mime type of the file
+        """
+        # Always fully qualify the path first thing
+        # Just in case, convert to a path if it's a str
+        self.original_file = Path(self.original_file).resolve()
+
+        # Get the file type once at init
+        # Note this function isn't called when the object is unpickled
+        self.mime_type = magic.from_file(self.original_file, mime=True)
diff --git a/src/documents/management/commands/document_consumer.py b/src/documents/management/commands/document_consumer.py
index d4ace3f1b..27749ea7c 100644
--- a/src/documents/management/commands/document_consumer.py
+++ b/src/documents/management/commands/document_consumer.py
@@ -13,6 +13,9 @@ from typing import Set
 from django.conf import settings
 from django.core.management.base import BaseCommand
 from django.core.management.base import CommandError
+from documents.data_models import ConsumableDocument
+from documents.data_models import DocumentMetadataOverrides
+from documents.data_models import DocumentSource
 from documents.models import Tag
 from documents.parsers import is_file_ext_supported
 from documents.tasks import consume_file
@@ -122,8 +125,11 @@ def _consume(filepath: str) -> None:
     try:
         logger.info(f"Adding {filepath} to the task queue.")
         consume_file.delay(
-            filepath,
-            override_tag_ids=list(tag_ids) if tag_ids else None,
+            ConsumableDocument(
+                source=DocumentSource.ConsumeFolder,
+                original_file=filepath,
+            ),
+            DocumentMetadataOverrides(tag_ids=tag_ids),
         )
     except Exception:
         # Catch all so that the consumer won't crash.
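The hunks above show the shape of the new task interface: a caller builds a ConsumableDocument plus an optional DocumentMetadataOverrides and passes them as the only two positional arguments to consume_file, instead of the old override_* keyword list. The sketch below is a minimal illustration of that calling convention, assuming the module paths introduced by this patch (documents.data_models, documents.tasks); the file path and override values are hypothetical and used purely for illustration. Note that mime_type is never supplied by the caller, since ConsumableDocument.__post_init__ resolves the path and detects the type with python-magic.

from documents.data_models import ConsumableDocument
from documents.data_models import DocumentMetadataOverrides
from documents.data_models import DocumentSource
from documents.tasks import consume_file

# Describe the incoming file. The source enum records where it came from;
# __post_init__ fully resolves the path and fills in mime_type via magic.
input_doc = ConsumableDocument(
    source=DocumentSource.ApiUpload,
    original_file="/tmp/paperless/upload-example.pdf",  # hypothetical path
)

# Every override field defaults to None, meaning "no override";
# set only the values that are actually known. (Hypothetical values.)
overrides = DocumentMetadataOverrides(
    title="Example title",
    tag_ids=[1, 2],
)

# The task now takes exactly two positional arguments.
consume_file.delay(input_doc, overrides)

Because every field of DocumentMetadataOverrides is optional, passing a bare DocumentMetadataOverrides() (or None, which the task replaces with one) is equivalent to the old call with no override keyword arguments at all.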
diff --git a/src/documents/signals/handlers.py b/src/documents/signals/handlers.py index 670ceae64..92f8e6159 100644 --- a/src/documents/signals/handlers.py +++ b/src/documents/signals/handlers.py @@ -1,7 +1,6 @@ import logging import os import shutil -from pathlib import Path from celery import states from celery.signals import before_task_publish @@ -533,17 +532,9 @@ def before_task_publish_handler(sender=None, headers=None, body=None, **kwargs): try: task_args = body[0] - task_kwargs = body[1] + input_doc, _ = task_args - task_file_name = "" - if "override_filename" in task_kwargs: - task_file_name = task_kwargs["override_filename"] - - # Nothing was found, report the task first argument - if not len(task_file_name): - # There are always some arguments to the consume, first is always filename - filepath = Path(task_args[0]) - task_file_name = filepath.name + task_file_name = input_doc.original_file.name PaperlessTask.objects.create( task_id=headers["id"], diff --git a/src/documents/tasks.py b/src/documents/tasks.py index fbc754e52..5c300bca2 100644 --- a/src/documents/tasks.py +++ b/src/documents/tasks.py @@ -1,13 +1,10 @@ import hashlib import logging -import os import shutil import uuid -from pathlib import Path from typing import Optional from typing import Type -import dateutil.parser import tqdm from asgiref.sync import async_to_sync from celery import shared_task @@ -22,6 +19,9 @@ from documents.classifier import DocumentClassifier from documents.classifier import load_classifier from documents.consumer import Consumer from documents.consumer import ConsumerError +from documents.data_models import ConsumableDocument +from documents.data_models import DocumentMetadataOverrides +from documents.data_models import DocumentSource from documents.file_handling import create_source_path_directory from documents.file_handling import generate_unique_filename from documents.models import Correspondent @@ -88,34 +88,20 @@ def train_classifier(): @shared_task def consume_file( - path, - override_filename=None, - override_title=None, - override_correspondent_id=None, - override_document_type_id=None, - override_tag_ids=None, - task_id=None, - override_created=None, - override_owner_id=None, - override_archive_serial_num: Optional[int] = None, + input_doc: ConsumableDocument, + overrides: Optional[DocumentMetadataOverrides] = None, ): - path = Path(path).resolve() - asn = None - - # Celery converts this to a string, but everything expects a datetime - # Long term solution is to not use JSON for the serializer but pickle instead - # TODO: This will be resolved in kombu 5.3, expected with celery 5.3 - # More types will be retained through JSON encode/decode - if override_created is not None and isinstance(override_created, str): - try: - override_created = dateutil.parser.isoparse(override_created) - except Exception: - pass + # Default no overrides + if overrides is None: + overrides = DocumentMetadataOverrides() # read all barcodes in the current document if settings.CONSUMER_ENABLE_BARCODES or settings.CONSUMER_ENABLE_ASN_BARCODE: - doc_barcode_info = barcodes.scan_file_for_barcodes(path) + doc_barcode_info = barcodes.scan_file_for_barcodes( + input_doc.original_file, + input_doc.mime_type, + ) # split document by separator pages, if enabled if settings.CONSUMER_ENABLE_BARCODES: @@ -123,7 +109,7 @@ def consume_file( if len(separators) > 0: logger.debug( - f"Pages with separators found in: {str(path)}", + f"Pages with separators found in: {input_doc.original_file}", ) document_list = 
barcodes.separate_pages( doc_barcode_info.pdf_path, @@ -136,18 +122,20 @@ def consume_file( # Move it to consume directory to be picked up # Otherwise, use the current parent to keep possible tags # from subdirectories - try: - # is_relative_to would be nicer, but new in 3.9 - _ = path.relative_to(settings.SCRATCH_DIR) + if input_doc.source != DocumentSource.ConsumeFolder: save_to_dir = settings.CONSUMPTION_DIR - except ValueError: - save_to_dir = path.parent + else: + # Note this uses the original file, because it's in the + # consume folder already and may include additional path + # components for tagging + # the .path is somewhere in scratch in this case + save_to_dir = input_doc.original_file.parent for n, document in enumerate(document_list): # save to consumption dir # rename it to the original filename with number prefix - if override_filename: - newname = f"{str(n)}_" + override_filename + if overrides.filename is not None: + newname = f"{str(n)}_{overrides.filename}" else: newname = None @@ -158,24 +146,27 @@ def consume_file( ) # Split file has been copied safely, remove it - os.remove(document) + document.unlink() # And clean up the directory as well, now it's empty - shutil.rmtree(os.path.dirname(document_list[0])) + shutil.rmtree(document_list[0].parent) - # Delete the PDF file which was split - os.remove(doc_barcode_info.pdf_path) + # This file has been split into multiple files without issue + # remove the original and working copy + input_doc.original_file.unlink() - # If the original was a TIFF, remove the original file as well - if str(doc_barcode_info.pdf_path) != str(path): - logger.debug(f"Deleting file {path}") - os.unlink(path) + # If the original file was a TIFF, remove the PDF generated from it + if input_doc.mime_type == "image/tiff": + logger.debug( + f"Deleting file {doc_barcode_info.pdf_path}", + ) + doc_barcode_info.pdf_path.unlink() # notify the sender, otherwise the progress bar # in the UI stays stuck payload = { - "filename": override_filename or path.name, - "task_id": task_id, + "filename": overrides.filename or input_doc.original_file.name, + "task_id": None, "current_progress": 100, "max_progress": 100, "status": "SUCCESS", @@ -194,22 +185,21 @@ def consume_file( # try reading the ASN from barcode if settings.CONSUMER_ENABLE_ASN_BARCODE: - asn = barcodes.get_asn_from_barcodes(doc_barcode_info.barcodes) - if asn: - logger.info(f"Found ASN in barcode: {asn}") + overrides.asn = barcodes.get_asn_from_barcodes(doc_barcode_info.barcodes) + if overrides.asn: + logger.info(f"Found ASN in barcode: {overrides.asn}") # continue with consumption if no barcode was found document = Consumer().try_consume_file( - path, - override_filename=override_filename, - override_title=override_title, - override_correspondent_id=override_correspondent_id, - override_document_type_id=override_document_type_id, - override_tag_ids=override_tag_ids, - task_id=task_id, - override_created=override_created, - override_asn=override_archive_serial_num or asn, - override_owner_id=override_owner_id, + input_doc.original_file, + override_filename=overrides.filename, + override_title=overrides.title, + override_correspondent_id=overrides.correspondent_id, + override_document_type_id=overrides.document_type_id, + override_tag_ids=overrides.tag_ids, + override_created=overrides.created, + override_asn=overrides.asn, + override_owner_id=overrides.owner_id, ) if document: diff --git a/src/documents/tests/test_api.py b/src/documents/tests/test_api.py index 958f0b3fe..da60ab3c4 100644 --- 
a/src/documents/tests/test_api.py +++ b/src/documents/tests/test_api.py @@ -32,6 +32,7 @@ from documents import bulk_edit from documents import index from documents.models import Correspondent from documents.models import Document +from documents.tests.utils import DocumentConsumeDelayMixin from documents.models import DocumentType from documents.models import MatchingModel from documents.models import PaperlessTask @@ -45,7 +46,7 @@ from rest_framework.test import APITestCase from whoosh.writing import AsyncWriter -class TestDocumentApi(DirectoriesMixin, APITestCase): +class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase): def setUp(self): super().setUp() @@ -1085,10 +1086,11 @@ class TestDocumentApi(DirectoriesMixin, APITestCase): self.assertEqual(response.data["documents_inbox"], None) self.assertEqual(response.data["inbox_tag"], None) - @mock.patch("documents.views.consume_file.delay") - def test_upload(self, m): + def test_upload(self): - m.return_value = celery.result.AsyncResult(id=str(uuid.uuid4())) + self.consume_file_mock.return_value = celery.result.AsyncResult( + id=str(uuid.uuid4()), + ) with open( os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"), @@ -1101,21 +1103,22 @@ class TestDocumentApi(DirectoriesMixin, APITestCase): self.assertEqual(response.status_code, status.HTTP_200_OK) - m.assert_called_once() + self.consume_file_mock.assert_called_once() - args, kwargs = m.call_args - file_path = Path(args[0]) - self.assertEqual(file_path.name, "simple.pdf") - self.assertIn(Path(settings.SCRATCH_DIR), file_path.parents) - self.assertIsNone(kwargs["override_title"]) - self.assertIsNone(kwargs["override_correspondent_id"]) - self.assertIsNone(kwargs["override_document_type_id"]) - self.assertIsNone(kwargs["override_tag_ids"]) + input_doc, overrides = self.get_last_consume_delay_call_args() - @mock.patch("documents.views.consume_file.delay") - def test_upload_empty_metadata(self, m): + self.assertEqual(input_doc.original_file.name, "simple.pdf") + self.assertIn(Path(settings.SCRATCH_DIR), input_doc.original_file.parents) + self.assertIsNone(overrides.title) + self.assertIsNone(overrides.correspondent_id) + self.assertIsNone(overrides.document_type_id) + self.assertIsNone(overrides.tag_ids) - m.return_value = celery.result.AsyncResult(id=str(uuid.uuid4())) + def test_upload_empty_metadata(self): + + self.consume_file_mock.return_value = celery.result.AsyncResult( + id=str(uuid.uuid4()), + ) with open( os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"), @@ -1128,21 +1131,22 @@ class TestDocumentApi(DirectoriesMixin, APITestCase): self.assertEqual(response.status_code, status.HTTP_200_OK) - m.assert_called_once() + self.consume_file_mock.assert_called_once() - args, kwargs = m.call_args - file_path = Path(args[0]) - self.assertEqual(file_path.name, "simple.pdf") - self.assertIn(Path(settings.SCRATCH_DIR), file_path.parents) - self.assertIsNone(kwargs["override_title"]) - self.assertIsNone(kwargs["override_correspondent_id"]) - self.assertIsNone(kwargs["override_document_type_id"]) - self.assertIsNone(kwargs["override_tag_ids"]) + input_doc, overrides = self.get_last_consume_delay_call_args() - @mock.patch("documents.views.consume_file.delay") - def test_upload_invalid_form(self, m): + self.assertEqual(input_doc.original_file.name, "simple.pdf") + self.assertIn(Path(settings.SCRATCH_DIR), input_doc.original_file.parents) + self.assertIsNone(overrides.title) + self.assertIsNone(overrides.correspondent_id) + 
self.assertIsNone(overrides.document_type_id) + self.assertIsNone(overrides.tag_ids) - m.return_value = celery.result.AsyncResult(id=str(uuid.uuid4())) + def test_upload_invalid_form(self): + + self.consume_file_mock.return_value = celery.result.AsyncResult( + id=str(uuid.uuid4()), + ) with open( os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"), @@ -1153,12 +1157,13 @@ class TestDocumentApi(DirectoriesMixin, APITestCase): {"documenst": f}, ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - m.assert_not_called() + self.consume_file_mock.assert_not_called() - @mock.patch("documents.views.consume_file.delay") - def test_upload_invalid_file(self, m): + def test_upload_invalid_file(self): - m.return_value = celery.result.AsyncResult(id=str(uuid.uuid4())) + self.consume_file_mock.return_value = celery.result.AsyncResult( + id=str(uuid.uuid4()), + ) with open( os.path.join(os.path.dirname(__file__), "samples", "simple.zip"), @@ -1169,12 +1174,13 @@ class TestDocumentApi(DirectoriesMixin, APITestCase): {"document": f}, ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - m.assert_not_called() + self.consume_file_mock.assert_not_called() - @mock.patch("documents.views.consume_file.delay") - def test_upload_with_title(self, async_task): + def test_upload_with_title(self): - async_task.return_value = celery.result.AsyncResult(id=str(uuid.uuid4())) + self.consume_file_mock.return_value = celery.result.AsyncResult( + id=str(uuid.uuid4()), + ) with open( os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"), @@ -1186,16 +1192,20 @@ class TestDocumentApi(DirectoriesMixin, APITestCase): ) self.assertEqual(response.status_code, status.HTTP_200_OK) - async_task.assert_called_once() + self.consume_file_mock.assert_called_once() - args, kwargs = async_task.call_args + _, overrides = self.get_last_consume_delay_call_args() - self.assertEqual(kwargs["override_title"], "my custom title") + self.assertEqual(overrides.title, "my custom title") + self.assertIsNone(overrides.correspondent_id) + self.assertIsNone(overrides.document_type_id) + self.assertIsNone(overrides.tag_ids) - @mock.patch("documents.views.consume_file.delay") - def test_upload_with_correspondent(self, async_task): + def test_upload_with_correspondent(self): - async_task.return_value = celery.result.AsyncResult(id=str(uuid.uuid4())) + self.consume_file_mock.return_value = celery.result.AsyncResult( + id=str(uuid.uuid4()), + ) c = Correspondent.objects.create(name="test-corres") with open( @@ -1208,16 +1218,20 @@ class TestDocumentApi(DirectoriesMixin, APITestCase): ) self.assertEqual(response.status_code, status.HTTP_200_OK) - async_task.assert_called_once() + self.consume_file_mock.assert_called_once() - args, kwargs = async_task.call_args + _, overrides = self.get_last_consume_delay_call_args() - self.assertEqual(kwargs["override_correspondent_id"], c.id) + self.assertEqual(overrides.correspondent_id, c.id) + self.assertIsNone(overrides.title) + self.assertIsNone(overrides.document_type_id) + self.assertIsNone(overrides.tag_ids) - @mock.patch("documents.views.consume_file.delay") - def test_upload_with_invalid_correspondent(self, async_task): + def test_upload_with_invalid_correspondent(self): - async_task.return_value = celery.result.AsyncResult(id=str(uuid.uuid4())) + self.consume_file_mock.return_value = celery.result.AsyncResult( + id=str(uuid.uuid4()), + ) with open( os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"), @@ -1229,12 +1243,13 @@ class 
TestDocumentApi(DirectoriesMixin, APITestCase): ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - async_task.assert_not_called() + self.consume_file_mock.assert_not_called() - @mock.patch("documents.views.consume_file.delay") - def test_upload_with_document_type(self, async_task): + def test_upload_with_document_type(self): - async_task.return_value = celery.result.AsyncResult(id=str(uuid.uuid4())) + self.consume_file_mock.return_value = celery.result.AsyncResult( + id=str(uuid.uuid4()), + ) dt = DocumentType.objects.create(name="invoice") with open( @@ -1247,16 +1262,20 @@ class TestDocumentApi(DirectoriesMixin, APITestCase): ) self.assertEqual(response.status_code, status.HTTP_200_OK) - async_task.assert_called_once() + self.consume_file_mock.assert_called_once() - args, kwargs = async_task.call_args + _, overrides = self.get_last_consume_delay_call_args() - self.assertEqual(kwargs["override_document_type_id"], dt.id) + self.assertEqual(overrides.document_type_id, dt.id) + self.assertIsNone(overrides.correspondent_id) + self.assertIsNone(overrides.title) + self.assertIsNone(overrides.tag_ids) - @mock.patch("documents.views.consume_file.delay") - def test_upload_with_invalid_document_type(self, async_task): + def test_upload_with_invalid_document_type(self): - async_task.return_value = celery.result.AsyncResult(id=str(uuid.uuid4())) + self.consume_file_mock.return_value = celery.result.AsyncResult( + id=str(uuid.uuid4()), + ) with open( os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"), @@ -1268,12 +1287,13 @@ class TestDocumentApi(DirectoriesMixin, APITestCase): ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - async_task.assert_not_called() + self.consume_file_mock.assert_not_called() - @mock.patch("documents.views.consume_file.delay") - def test_upload_with_tags(self, async_task): + def test_upload_with_tags(self): - async_task.return_value = celery.result.AsyncResult(id=str(uuid.uuid4())) + self.consume_file_mock.return_value = celery.result.AsyncResult( + id=str(uuid.uuid4()), + ) t1 = Tag.objects.create(name="tag1") t2 = Tag.objects.create(name="tag2") @@ -1287,16 +1307,20 @@ class TestDocumentApi(DirectoriesMixin, APITestCase): ) self.assertEqual(response.status_code, status.HTTP_200_OK) - async_task.assert_called_once() + self.consume_file_mock.assert_called_once() - args, kwargs = async_task.call_args + _, overrides = self.get_last_consume_delay_call_args() - self.assertCountEqual(kwargs["override_tag_ids"], [t1.id, t2.id]) + self.assertCountEqual(overrides.tag_ids, [t1.id, t2.id]) + self.assertIsNone(overrides.document_type_id) + self.assertIsNone(overrides.correspondent_id) + self.assertIsNone(overrides.title) - @mock.patch("documents.views.consume_file.delay") - def test_upload_with_invalid_tags(self, async_task): + def test_upload_with_invalid_tags(self): - async_task.return_value = celery.result.AsyncResult(id=str(uuid.uuid4())) + self.consume_file_mock.return_value = celery.result.AsyncResult( + id=str(uuid.uuid4()), + ) t1 = Tag.objects.create(name="tag1") t2 = Tag.objects.create(name="tag2") @@ -1310,12 +1334,13 @@ class TestDocumentApi(DirectoriesMixin, APITestCase): ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - async_task.assert_not_called() + self.consume_file_mock.assert_not_called() - @mock.patch("documents.views.consume_file.delay") - def test_upload_with_created(self, async_task): + def test_upload_with_created(self): - async_task.return_value = 
celery.result.AsyncResult(id=str(uuid.uuid4())) + self.consume_file_mock.return_value = celery.result.AsyncResult( + id=str(uuid.uuid4()), + ) created = datetime.datetime( 2022, @@ -1337,16 +1362,17 @@ class TestDocumentApi(DirectoriesMixin, APITestCase): ) self.assertEqual(response.status_code, status.HTTP_200_OK) - async_task.assert_called_once() + self.consume_file_mock.assert_called_once() - args, kwargs = async_task.call_args + _, overrides = self.get_last_consume_delay_call_args() - self.assertEqual(kwargs["override_created"], created) + self.assertEqual(overrides.created, created) - @mock.patch("documents.views.consume_file.delay") - def test_upload_with_asn(self, m): + def test_upload_with_asn(self): - m.return_value = celery.result.AsyncResult(id=str(uuid.uuid4())) + self.consume_file_mock.return_value = celery.result.AsyncResult( + id=str(uuid.uuid4()), + ) with open( os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"), @@ -1359,17 +1385,16 @@ class TestDocumentApi(DirectoriesMixin, APITestCase): self.assertEqual(response.status_code, status.HTTP_200_OK) - m.assert_called_once() + self.consume_file_mock.assert_called_once() - args, kwargs = m.call_args - file_path = Path(args[0]) - self.assertEqual(file_path.name, "simple.pdf") - self.assertIn(Path(settings.SCRATCH_DIR), file_path.parents) - self.assertIsNone(kwargs["override_title"]) - self.assertIsNone(kwargs["override_correspondent_id"]) - self.assertIsNone(kwargs["override_document_type_id"]) - self.assertIsNone(kwargs["override_tag_ids"]) - self.assertEqual(500, kwargs["override_archive_serial_num"]) + input_doc, overrides = self.get_last_consume_delay_call_args() + + self.assertEqual(input_doc.original_file.name, "simple.pdf") + self.assertEqual(overrides.filename, "simple.pdf") + self.assertIsNone(overrides.correspondent_id) + self.assertIsNone(overrides.document_type_id) + self.assertIsNone(overrides.tag_ids) + self.assertEqual(500, overrides.asn) def test_get_metadata(self): doc = Document.objects.create( diff --git a/src/documents/tests/test_barcodes.py b/src/documents/tests/test_barcodes.py index a1e08c5cf..975a3cc1b 100644 --- a/src/documents/tests/test_barcodes.py +++ b/src/documents/tests/test_barcodes.py @@ -10,6 +10,9 @@ from django.test import TestCase from documents import barcodes from documents import tasks from documents.consumer import ConsumerError +from documents.data_models import ConsumableDocument +from documents.data_models import DocumentMetadataOverrides +from documents.data_models import DocumentSource from documents.tests.utils import DirectoriesMixin from documents.tests.utils import FileSystemAssertsMixin from PIL import Image @@ -183,46 +186,14 @@ class TestBarcode(DirectoriesMixin, FileSystemAssertsMixin, TestCase): img = Image.open(test_file) self.assertEqual(barcodes.barcode_reader(img), ["CUSTOM BARCODE"]) - def test_get_mime_type(self): - """ - GIVEN: - - - WHEN: - - - THEN: - - - """ - tiff_file = self.SAMPLE_DIR / "simple.tiff" - - pdf_file = self.SAMPLE_DIR / "simple.pdf" - - png_file = self.BARCODE_SAMPLE_DIR / "barcode-128-custom.png" - - tiff_file_no_extension = settings.SCRATCH_DIR / "testfile1" - pdf_file_no_extension = settings.SCRATCH_DIR / "testfile2" - shutil.copy(tiff_file, tiff_file_no_extension) - shutil.copy(pdf_file, pdf_file_no_extension) - - self.assertEqual(barcodes.get_file_mime_type(tiff_file), "image/tiff") - self.assertEqual(barcodes.get_file_mime_type(pdf_file), "application/pdf") - self.assertEqual( - 
barcodes.get_file_mime_type(tiff_file_no_extension), - "image/tiff", - ) - self.assertEqual( - barcodes.get_file_mime_type(pdf_file_no_extension), - "application/pdf", - ) - self.assertEqual(barcodes.get_file_mime_type(png_file), "image/png") - def test_convert_from_tiff_to_pdf(self): """ GIVEN: - - + - Multi-page TIFF image WHEN: - - + - Conversion to PDF THEN: - - + - The file converts without error """ test_file = self.SAMPLE_DIR / "simple.tiff" @@ -233,34 +204,20 @@ class TestBarcode(DirectoriesMixin, FileSystemAssertsMixin, TestCase): self.assertIsFile(target_file) self.assertEqual(target_file.suffix, ".pdf") - def test_convert_error_from_pdf_to_pdf(self): - """ - GIVEN: - - - WHEN: - - - THEN: - - - """ - test_file = self.SAMPLE_DIR / "simple.pdf" - - dst = settings.SCRATCH_DIR / "simple.pdf" - shutil.copy(test_file, dst) - self.assertIsNone(barcodes.convert_from_tiff_to_pdf(dst)) - def test_scan_file_for_separating_barcodes(self): """ GIVEN: - - + - PDF containing barcodes WHEN: - - + - File is scanned for barcodes THEN: - - + - Correct page index located """ test_file = self.BARCODE_SAMPLE_DIR / "patch-code-t.pdf" doc_barcode_info = barcodes.scan_file_for_barcodes( test_file, + "application/pdf", ) separator_page_numbers = barcodes.get_separating_barcodes( doc_barcode_info.barcodes, @@ -272,15 +229,17 @@ class TestBarcode(DirectoriesMixin, FileSystemAssertsMixin, TestCase): def test_scan_file_for_separating_barcodes_none_present(self): """ GIVEN: - - + - File with no barcodes WHEN: - - + - File is scanned THEN: - - + - No barcodes detected + - No pages to split on """ test_file = self.SAMPLE_DIR / "simple.pdf" doc_barcode_info = barcodes.scan_file_for_barcodes( test_file, + "application/pdf", ) separator_page_numbers = barcodes.get_separating_barcodes( doc_barcode_info.barcodes, @@ -302,6 +261,7 @@ class TestBarcode(DirectoriesMixin, FileSystemAssertsMixin, TestCase): doc_barcode_info = barcodes.scan_file_for_barcodes( test_file, + "application/pdf", ) separator_page_numbers = barcodes.get_separating_barcodes( doc_barcode_info.barcodes, @@ -323,6 +283,7 @@ class TestBarcode(DirectoriesMixin, FileSystemAssertsMixin, TestCase): doc_barcode_info = barcodes.scan_file_for_barcodes( test_file, + "application/pdf", ) separator_page_numbers = barcodes.get_separating_barcodes( doc_barcode_info.barcodes, @@ -345,6 +306,7 @@ class TestBarcode(DirectoriesMixin, FileSystemAssertsMixin, TestCase): doc_barcode_info = barcodes.scan_file_for_barcodes( test_file, + "application/pdf", ) separator_page_numbers = barcodes.get_separating_barcodes( doc_barcode_info.barcodes, @@ -366,6 +328,7 @@ class TestBarcode(DirectoriesMixin, FileSystemAssertsMixin, TestCase): doc_barcode_info = barcodes.scan_file_for_barcodes( test_file, + "application/pdf", ) separator_page_numbers = barcodes.get_separating_barcodes( doc_barcode_info.barcodes, @@ -388,6 +351,7 @@ class TestBarcode(DirectoriesMixin, FileSystemAssertsMixin, TestCase): doc_barcode_info = barcodes.scan_file_for_barcodes( test_file, + "application/pdf", ) separator_page_numbers = barcodes.get_separating_barcodes( doc_barcode_info.barcodes, @@ -411,6 +375,7 @@ class TestBarcode(DirectoriesMixin, FileSystemAssertsMixin, TestCase): doc_barcode_info = barcodes.scan_file_for_barcodes( test_file, + "application/pdf", ) separator_page_numbers = barcodes.get_separating_barcodes( doc_barcode_info.barcodes, @@ -435,6 +400,7 @@ class TestBarcode(DirectoriesMixin, FileSystemAssertsMixin, TestCase): doc_barcode_info = barcodes.scan_file_for_barcodes( test_file, + 
"application/pdf", ) separator_page_numbers = barcodes.get_separating_barcodes( doc_barcode_info.barcodes, @@ -459,6 +425,7 @@ class TestBarcode(DirectoriesMixin, FileSystemAssertsMixin, TestCase): doc_barcode_info = barcodes.scan_file_for_barcodes( test_file, + "application/pdf", ) separator_page_numbers = barcodes.get_separating_barcodes( doc_barcode_info.barcodes, @@ -482,6 +449,7 @@ class TestBarcode(DirectoriesMixin, FileSystemAssertsMixin, TestCase): doc_barcode_info = barcodes.scan_file_for_barcodes( test_file, + "application/pdf", ) separator_page_numbers = barcodes.get_separating_barcodes( doc_barcode_info.barcodes, @@ -504,6 +472,7 @@ class TestBarcode(DirectoriesMixin, FileSystemAssertsMixin, TestCase): doc_barcode_info = barcodes.scan_file_for_barcodes( test_file, + "application/pdf", ) separator_page_numbers = barcodes.get_separating_barcodes( doc_barcode_info.barcodes, @@ -636,6 +605,7 @@ class TestBarcode(DirectoriesMixin, FileSystemAssertsMixin, TestCase): doc_barcode_info = barcodes.scan_file_for_barcodes( test_file, + "application/pdf", ) separator_page_numbers = barcodes.get_separating_barcodes( doc_barcode_info.barcodes, @@ -673,7 +643,16 @@ class TestBarcode(DirectoriesMixin, FileSystemAssertsMixin, TestCase): shutil.copy(test_file, dst) with mock.patch("documents.tasks.async_to_sync"): - self.assertEqual(tasks.consume_file(dst), "File successfully split") + self.assertEqual( + tasks.consume_file( + ConsumableDocument( + source=DocumentSource.ConsumeFolder, + original_file=dst, + ), + None, + ), + "File successfully split", + ) @override_settings( CONSUMER_ENABLE_BARCODES=True, @@ -694,7 +673,17 @@ class TestBarcode(DirectoriesMixin, FileSystemAssertsMixin, TestCase): shutil.copy(test_file, dst) with mock.patch("documents.tasks.async_to_sync"): - self.assertEqual(tasks.consume_file(dst), "File successfully split") + self.assertEqual( + tasks.consume_file( + ConsumableDocument( + source=DocumentSource.ConsumeFolder, + original_file=dst, + ), + None, + ), + "File successfully split", + ) + self.assertFalse(dst.exists()) @override_settings( CONSUMER_ENABLE_BARCODES=True, @@ -717,7 +706,16 @@ class TestBarcode(DirectoriesMixin, FileSystemAssertsMixin, TestCase): shutil.copy(test_file, dst) with self.assertLogs("paperless.barcodes", level="WARNING") as cm: - self.assertIn("Success", tasks.consume_file(dst)) + self.assertIn( + "Success", + tasks.consume_file( + ConsumableDocument( + source=DocumentSource.ConsumeFolder, + original_file=dst, + ), + None, + ), + ) self.assertListEqual( cm.output, @@ -754,7 +752,17 @@ class TestBarcode(DirectoriesMixin, FileSystemAssertsMixin, TestCase): shutil.copy(test_file, dst) with mock.patch("documents.tasks.async_to_sync"): - self.assertEqual(tasks.consume_file(dst), "File successfully split") + self.assertEqual( + tasks.consume_file( + ConsumableDocument( + source=DocumentSource.ConsumeFolder, + original_file=dst, + ), + None, + ), + "File successfully split", + ) + self.assertFalse(dst.exists()) def test_scan_file_for_separating_barcodes_password(self): """ @@ -769,6 +777,7 @@ class TestBarcode(DirectoriesMixin, FileSystemAssertsMixin, TestCase): with self.assertLogs("paperless.barcodes", level="WARNING") as cm: doc_barcode_info = barcodes.scan_file_for_barcodes( test_file, + "application/pdf", ) warning = cm.output[0] expected_str = "WARNING:paperless.barcodes:File is likely password protected, not checking for barcodes" @@ -798,6 +807,7 @@ class TestBarcode(DirectoriesMixin, FileSystemAssertsMixin, TestCase): doc_barcode_info = 
barcodes.scan_file_for_barcodes( test_file, + "application/pdf", ) separator_page_numbers = barcodes.get_separating_barcodes( doc_barcode_info.barcodes, @@ -835,6 +845,7 @@ class TestBarcode(DirectoriesMixin, FileSystemAssertsMixin, TestCase): doc_barcode_info = barcodes.scan_file_for_barcodes( test_file, + "application/pdf", ) separator_page_numbers = barcodes.get_separating_barcodes( doc_barcode_info.barcodes, @@ -855,7 +866,7 @@ class TestBarcode(DirectoriesMixin, FileSystemAssertsMixin, TestCase): self.assertEqual(len(document_list), 5) -class TestAsnBarcodes(DirectoriesMixin, TestCase): +class TestAsnBarcode(DirectoriesMixin, TestCase): SAMPLE_DIR = Path(__file__).parent / "samples" @@ -923,6 +934,7 @@ class TestAsnBarcodes(DirectoriesMixin, TestCase): doc_barcode_info = barcodes.scan_file_for_barcodes( test_file, + "application/pdf", ) asn = barcodes.get_asn_from_barcodes(doc_barcode_info.barcodes) @@ -944,6 +956,7 @@ class TestAsnBarcodes(DirectoriesMixin, TestCase): doc_barcode_info = barcodes.scan_file_for_barcodes( test_file, + "application/pdf", ) asn = barcodes.get_asn_from_barcodes(doc_barcode_info.barcodes) @@ -970,7 +983,13 @@ class TestAsnBarcodes(DirectoriesMixin, TestCase): shutil.copy(test_file, dst) with mock.patch("documents.consumer.Consumer.try_consume_file") as mocked_call: - tasks.consume_file(dst) + tasks.consume_file( + ConsumableDocument( + source=DocumentSource.ConsumeFolder, + original_file=dst, + ), + None, + ) args, kwargs = mocked_call.call_args @@ -991,6 +1010,7 @@ class TestAsnBarcodes(DirectoriesMixin, TestCase): doc_barcode_info = barcodes.scan_file_for_barcodes( test_file, + "application/pdf", ) asn = barcodes.get_asn_from_barcodes(doc_barcode_info.barcodes) @@ -1010,6 +1030,7 @@ class TestAsnBarcodes(DirectoriesMixin, TestCase): doc_barcode_info = barcodes.scan_file_for_barcodes( test_file, + "application/pdf", ) asn = barcodes.get_asn_from_barcodes(doc_barcode_info.barcodes) @@ -1032,12 +1053,17 @@ class TestAsnBarcodes(DirectoriesMixin, TestCase): dst = self.dirs.scratch_dir / "barcode-128-asn-too-large.pdf" shutil.copy(src, dst) + input_doc = ConsumableDocument( + source=DocumentSource.ConsumeFolder, + original_file=dst, + ) + with mock.patch("documents.consumer.Consumer._send_progress"): self.assertRaisesMessage( ConsumerError, "Given ASN 4294967296 is out of range [0, 4,294,967,295]", tasks.consume_file, - dst, + input_doc, ) @@ -1055,5 +1081,5 @@ class TestBarcodeZxing(TestBarcode): reason="No zxingcpp", ) @override_settings(CONSUMER_BARCODE_SCANNER="ZXING") -class TestAsnBarcodesZxing(TestAsnBarcodes): +class TestAsnBarcodesZxing(TestAsnBarcode): pass diff --git a/src/documents/tests/test_management_consumer.py b/src/documents/tests/test_management_consumer.py index 3db8de034..637a8cb20 100644 --- a/src/documents/tests/test_management_consumer.py +++ b/src/documents/tests/test_management_consumer.py @@ -1,6 +1,7 @@ import filecmp import os import shutil +from pathlib import Path from threading import Thread from time import sleep from unittest import mock @@ -11,9 +12,12 @@ from django.core.management import CommandError from django.test import override_settings from django.test import TransactionTestCase from documents.consumer import ConsumerError +from documents.data_models import ConsumableDocument +from documents.data_models import DocumentMetadataOverrides from documents.management.commands import document_consumer from documents.models import Tag from documents.tests.utils import DirectoriesMixin +from documents.tests.utils import 
DocumentConsumeDelayMixin class ConsumerThread(Thread): @@ -35,18 +39,19 @@ def chunked(size, source): yield source[i : i + size] -class ConsumerMixin: +class ConsumerThreadMixin(DocumentConsumeDelayMixin): + """ + Provides a thread which runs the consumer management command at setUp + and stops it at tearDown + """ - sample_file = os.path.join(os.path.dirname(__file__), "samples", "simple.pdf") + sample_file: Path = ( + Path(__file__).parent / Path("samples") / Path("simple.pdf") + ).resolve() def setUp(self) -> None: super().setUp() self.t = None - patcher = mock.patch( - "documents.tasks.consume_file.delay", - ) - self.task_mock = patcher.start() - self.addCleanup(patcher.stop) def t_start(self): self.t = ConsumerThread() @@ -67,7 +72,7 @@ class ConsumerMixin: def wait_for_task_mock_call(self, expected_call_count=1): n = 0 while n < 50: - if self.task_mock.call_count >= expected_call_count: + if self.consume_file_mock.call_count >= expected_call_count: # give task_mock some time to finish and raise errors sleep(1) return @@ -76,8 +81,12 @@ class ConsumerMixin: # A bogus async_task that will simply check the file for # completeness and raise an exception otherwise. - def bogus_task(self, filename, **kwargs): - eq = filecmp.cmp(filename, self.sample_file, shallow=False) + def bogus_task( + self, + input_doc: ConsumableDocument, + overrides=None, + ): + eq = filecmp.cmp(input_doc.original_file, self.sample_file, shallow=False) if not eq: print("Consumed an INVALID file.") raise ConsumerError("Incomplete File READ FAILED") @@ -103,19 +112,20 @@ class ConsumerMixin: @override_settings( CONSUMER_INOTIFY_DELAY=0.01, ) -class TestConsumer(DirectoriesMixin, ConsumerMixin, TransactionTestCase): +class TestConsumer(DirectoriesMixin, ConsumerThreadMixin, TransactionTestCase): def test_consume_file(self): self.t_start() - f = os.path.join(self.dirs.consumption_dir, "my_file.pdf") + f = Path(os.path.join(self.dirs.consumption_dir, "my_file.pdf")) shutil.copy(self.sample_file, f) self.wait_for_task_mock_call() - self.task_mock.assert_called_once() + self.consume_file_mock.assert_called_once() - args, kwargs = self.task_mock.call_args - self.assertEqual(args[0], f) + input_doc, _ = self.get_last_consume_delay_call_args() + + self.assertEqual(input_doc.original_file, f) def test_consume_file_invalid_ext(self): self.t_start() @@ -125,26 +135,27 @@ class TestConsumer(DirectoriesMixin, ConsumerMixin, TransactionTestCase): self.wait_for_task_mock_call() - self.task_mock.assert_not_called() + self.consume_file_mock.assert_not_called() def test_consume_existing_file(self): - f = os.path.join(self.dirs.consumption_dir, "my_file.pdf") + f = Path(os.path.join(self.dirs.consumption_dir, "my_file.pdf")) shutil.copy(self.sample_file, f) self.t_start() - self.task_mock.assert_called_once() + self.consume_file_mock.assert_called_once() - args, kwargs = self.task_mock.call_args - self.assertEqual(args[0], f) + input_doc, _ = self.get_last_consume_delay_call_args() + + self.assertEqual(input_doc.original_file, f) @mock.patch("documents.management.commands.document_consumer.logger.error") def test_slow_write_pdf(self, error_logger): - self.task_mock.side_effect = self.bogus_task + self.consume_file_mock.side_effect = self.bogus_task self.t_start() - fname = os.path.join(self.dirs.consumption_dir, "my_file.pdf") + fname = Path(os.path.join(self.dirs.consumption_dir, "my_file.pdf")) self.slow_write_file(fname) @@ -152,48 +163,52 @@ class TestConsumer(DirectoriesMixin, ConsumerMixin, TransactionTestCase): 
error_logger.assert_not_called() - self.task_mock.assert_called_once() + self.consume_file_mock.assert_called_once() - args, kwargs = self.task_mock.call_args - self.assertEqual(args[0], fname) + input_doc, _ = self.get_last_consume_delay_call_args() + + self.assertEqual(input_doc.original_file, fname) @mock.patch("documents.management.commands.document_consumer.logger.error") def test_slow_write_and_move(self, error_logger): - self.task_mock.side_effect = self.bogus_task + self.consume_file_mock.side_effect = self.bogus_task self.t_start() - fname = os.path.join(self.dirs.consumption_dir, "my_file.~df") - fname2 = os.path.join(self.dirs.consumption_dir, "my_file.pdf") + fname = Path(os.path.join(self.dirs.consumption_dir, "my_file.~df")) + fname2 = Path(os.path.join(self.dirs.consumption_dir, "my_file.pdf")) self.slow_write_file(fname) shutil.move(fname, fname2) self.wait_for_task_mock_call() - self.task_mock.assert_called_once() + self.consume_file_mock.assert_called_once() - args, kwargs = self.task_mock.call_args - self.assertEqual(args[0], fname2) + input_doc, _ = self.get_last_consume_delay_call_args() + + self.assertEqual(input_doc.original_file, fname2) error_logger.assert_not_called() @mock.patch("documents.management.commands.document_consumer.logger.error") def test_slow_write_incomplete(self, error_logger): - self.task_mock.side_effect = self.bogus_task + self.consume_file_mock.side_effect = self.bogus_task self.t_start() - fname = os.path.join(self.dirs.consumption_dir, "my_file.pdf") + fname = Path(os.path.join(self.dirs.consumption_dir, "my_file.pdf")) self.slow_write_file(fname, incomplete=True) self.wait_for_task_mock_call() - self.task_mock.assert_called_once() - args, kwargs = self.task_mock.call_args - self.assertEqual(args[0], fname) + self.consume_file_mock.assert_called_once() + + input_doc, _ = self.get_last_consume_delay_call_args() + + self.assertEqual(input_doc.original_file, fname) # assert that we have an error logged with this invalid file. 
error_logger.assert_called_once() @@ -209,7 +224,7 @@ class TestConsumer(DirectoriesMixin, ConsumerMixin, TransactionTestCase): self.assertRaises(CommandError, call_command, "document_consumer", "--oneshot") def test_mac_write(self): - self.task_mock.side_effect = self.bogus_task + self.consume_file_mock.side_effect = self.bogus_task self.t_start() @@ -238,12 +253,13 @@ class TestConsumer(DirectoriesMixin, ConsumerMixin, TransactionTestCase): self.wait_for_task_mock_call(expected_call_count=2) - self.assertEqual(2, self.task_mock.call_count) + self.assertEqual(2, self.consume_file_mock.call_count) - fnames = [ - os.path.basename(args[0]) for args, _ in self.task_mock.call_args_list - ] - self.assertCountEqual(fnames, ["my_file.pdf", "my_second_file.pdf"]) + consumed_files = [] + for input_doc, _ in self.get_all_consume_delay_call_args(): + consumed_files.append(input_doc.original_file.name) + + self.assertCountEqual(consumed_files, ["my_file.pdf", "my_second_file.pdf"]) def test_is_ignored(self): test_paths = [ @@ -341,7 +357,7 @@ class TestConsumer(DirectoriesMixin, ConsumerMixin, TransactionTestCase): self.wait_for_task_mock_call() - self.task_mock.assert_not_called() + self.consume_file_mock.assert_not_called() @override_settings( @@ -373,7 +389,7 @@ class TestConsumerRecursivePolling(TestConsumer): pass -class TestConsumerTags(DirectoriesMixin, ConsumerMixin, TransactionTestCase): +class TestConsumerTags(DirectoriesMixin, ConsumerThreadMixin, TransactionTestCase): @override_settings(CONSUMER_RECURSIVE=True, CONSUMER_SUBDIRS_AS_TAGS=True) def test_consume_file_with_path_tags(self): @@ -387,7 +403,7 @@ class TestConsumerTags(DirectoriesMixin, ConsumerMixin, TransactionTestCase): path = os.path.join(self.dirs.consumption_dir, *tag_names) os.makedirs(path, exist_ok=True) - f = os.path.join(path, "my_file.pdf") + f = Path(os.path.join(path, "my_file.pdf")) # Wait at least inotify read_delay for recursive watchers # to be created for the new directories sleep(1) @@ -395,18 +411,19 @@ class TestConsumerTags(DirectoriesMixin, ConsumerMixin, TransactionTestCase): self.wait_for_task_mock_call() - self.task_mock.assert_called_once() + self.consume_file_mock.assert_called_once() # Add the pk of the Tag created by _consume() tag_ids.append(Tag.objects.get(name=tag_names[1]).pk) - args, kwargs = self.task_mock.call_args - self.assertEqual(args[0], f) + input_doc, overrides = self.get_last_consume_delay_call_args() + + self.assertEqual(input_doc.original_file, f) # assertCountEqual has a bad name, but test that the first # sequence contains the same elements as second, regardless of # their order. 
- self.assertCountEqual(kwargs["override_tag_ids"], tag_ids) + self.assertCountEqual(overrides.tag_ids, tag_ids) @override_settings( CONSUMER_POLLING=1, diff --git a/src/documents/tests/test_task_signals.py b/src/documents/tests/test_task_signals.py index e21879802..18600d709 100644 --- a/src/documents/tests/test_task_signals.py +++ b/src/documents/tests/test_task_signals.py @@ -1,76 +1,21 @@ +import uuid +from unittest import mock + import celery from django.test import TestCase +from documents.data_models import ConsumableDocument +from documents.data_models import DocumentMetadataOverrides +from documents.data_models import DocumentSource from documents.models import PaperlessTask from documents.signals.handlers import before_task_publish_handler from documents.signals.handlers import task_postrun_handler from documents.signals.handlers import task_prerun_handler +from documents.tests.test_consumer import fake_magic_from_file from documents.tests.utils import DirectoriesMixin +@mock.patch("documents.consumer.magic.from_file", fake_magic_from_file) class TestTaskSignalHandler(DirectoriesMixin, TestCase): - - HEADERS_CONSUME = { - "lang": "py", - "task": "documents.tasks.consume_file", - "id": "52d31e24-9dcc-4c32-9e16-76007e9add5e", - "shadow": None, - "eta": None, - "expires": None, - "group": None, - "group_index": None, - "retries": 0, - "timelimit": [None, None], - "root_id": "52d31e24-9dcc-4c32-9e16-76007e9add5e", - "parent_id": None, - "argsrepr": "('/consume/hello-999.pdf',)", - "kwargsrepr": "{'override_tag_ids': None}", - "origin": "gen260@paperless-ngx-dev-webserver", - "ignore_result": False, - } - - BODY_CONSUME = ( - # args - ("/consume/hello-999.pdf",), - # kwargs - {"override_tag_ids": None}, - {"callbacks": None, "errbacks": None, "chain": None, "chord": None}, - ) - - HEADERS_WEB_UI = { - "lang": "py", - "task": "documents.tasks.consume_file", - "id": "6e88a41c-e5f8-4631-9972-68c314512498", - "shadow": None, - "eta": None, - "expires": None, - "group": None, - "group_index": None, - "retries": 0, - "timelimit": [None, None], - "root_id": "6e88a41c-e5f8-4631-9972-68c314512498", - "parent_id": None, - "argsrepr": "('/tmp/paperless/paperless-upload-st9lmbvx',)", - "kwargsrepr": "{'override_filename': 'statement.pdf', 'override_title': None, 'override_correspondent_id': None, 'override_document_type_id': None, 'override_tag_ids': None, 'task_id': 'f5622ca9-3707-4ed0-b418-9680b912572f', 'override_created': None}", - "origin": "gen342@paperless-ngx-dev-webserver", - "ignore_result": False, - } - - BODY_WEB_UI = ( - # args - ("/tmp/paperless/paperless-upload-st9lmbvx",), - # kwargs - { - "override_filename": "statement.pdf", - "override_title": None, - "override_correspondent_id": None, - "override_document_type_id": None, - "override_tag_ids": None, - "task_id": "f5622ca9-3707-4ed0-b418-9680b912572f", - "override_created": None, - }, - {"callbacks": None, "errbacks": None, "chain": None, "chord": None}, - ) - def util_call_before_task_publish_handler(self, headers_to_use, body_to_use): """ Simple utility to call the pre-run handle and ensure it created a single task @@ -91,41 +36,36 @@ class TestTaskSignalHandler(DirectoriesMixin, TestCase): THEN: - The task is created and marked as pending """ + headers = { + "id": str(uuid.uuid4()), + "task": "documents.tasks.consume_file", + } + body = ( + # args + ( + ConsumableDocument( + source=DocumentSource.ConsumeFolder, + original_file="/consume/hello-999.pdf", + ), + None, + ), + # kwargs + {}, + # celery stuff + {"callbacks": None, 
"errbacks": None, "chain": None, "chord": None}, + ) self.util_call_before_task_publish_handler( - headers_to_use=self.HEADERS_CONSUME, - body_to_use=self.BODY_CONSUME, + headers_to_use=headers, + body_to_use=body, ) task = PaperlessTask.objects.get() self.assertIsNotNone(task) - self.assertEqual(self.HEADERS_CONSUME["id"], task.task_id) + self.assertEqual(headers["id"], task.task_id) self.assertEqual("hello-999.pdf", task.task_file_name) self.assertEqual("documents.tasks.consume_file", task.task_name) self.assertEqual(celery.states.PENDING, task.status) - def test_before_task_publish_handler_webui(self): - """ - GIVEN: - - A celery task is started via the web ui - WHEN: - - Task before publish handler is called - THEN: - - The task is created and marked as pending - """ - self.util_call_before_task_publish_handler( - headers_to_use=self.HEADERS_WEB_UI, - body_to_use=self.BODY_WEB_UI, - ) - - task = PaperlessTask.objects.get() - - self.assertIsNotNone(task) - - self.assertEqual(self.HEADERS_WEB_UI["id"], task.task_id) - self.assertEqual("statement.pdf", task.task_file_name) - self.assertEqual("documents.tasks.consume_file", task.task_name) - self.assertEqual(celery.states.PENDING, task.status) - def test_task_prerun_handler(self): """ GIVEN: @@ -135,12 +75,32 @@ class TestTaskSignalHandler(DirectoriesMixin, TestCase): THEN: - The task is marked as started """ - self.util_call_before_task_publish_handler( - headers_to_use=self.HEADERS_CONSUME, - body_to_use=self.BODY_CONSUME, + + headers = { + "id": str(uuid.uuid4()), + "task": "documents.tasks.consume_file", + } + body = ( + # args + ( + ConsumableDocument( + source=DocumentSource.ConsumeFolder, + original_file="/consume/hello-99.pdf", + ), + None, + ), + # kwargs + {}, + # celery stuff + {"callbacks": None, "errbacks": None, "chain": None, "chord": None}, ) - task_prerun_handler(task_id=self.HEADERS_CONSUME["id"]) + self.util_call_before_task_publish_handler( + headers_to_use=headers, + body_to_use=body, + ) + + task_prerun_handler(task_id=headers["id"]) task = PaperlessTask.objects.get() @@ -155,13 +115,31 @@ class TestTaskSignalHandler(DirectoriesMixin, TestCase): THEN: - The task is marked as started """ + headers = { + "id": str(uuid.uuid4()), + "task": "documents.tasks.consume_file", + } + body = ( + # args + ( + ConsumableDocument( + source=DocumentSource.ConsumeFolder, + original_file="/consume/hello-9.pdf", + ), + None, + ), + # kwargs + {}, + # celery stuff + {"callbacks": None, "errbacks": None, "chain": None, "chord": None}, + ) self.util_call_before_task_publish_handler( - headers_to_use=self.HEADERS_CONSUME, - body_to_use=self.BODY_CONSUME, + headers_to_use=headers, + body_to_use=body, ) task_postrun_handler( - task_id=self.HEADERS_CONSUME["id"], + task_id=headers["id"], retval="Success. 
New document id 1 created", state=celery.states.SUCCESS, ) diff --git a/src/documents/tests/utils.py b/src/documents/tests/utils.py index 0a8da9ef9..26760f780 100644 --- a/src/documents/tests/utils.py +++ b/src/documents/tests/utils.py @@ -4,6 +4,8 @@ from collections import namedtuple from contextlib import contextmanager from os import PathLike from pathlib import Path +from typing import Iterator +from typing import Tuple from typing import Union from unittest import mock @@ -12,6 +14,8 @@ from django.db import connection from django.db.migrations.executor import MigrationExecutor from django.test import override_settings from django.test import TransactionTestCase +from documents.data_models import ConsumableDocument +from documents.data_models import DocumentMetadataOverrides def setup_directories(): @@ -116,6 +120,11 @@ class ConsumerProgressMixin: class DocumentConsumeDelayMixin: + """ + Provides mocking of the consume_file asynchronous task and useful utilities + for decoding its arguments + """ + def setUp(self) -> None: self.consume_file_patcher = mock.patch("documents.tasks.consume_file.delay") self.consume_file_mock = self.consume_file_patcher.start() @@ -125,6 +134,47 @@ class DocumentConsumeDelayMixin: super().tearDown() self.consume_file_patcher.stop() + def get_last_consume_delay_call_args( + self, + ) -> Tuple[ConsumableDocument, DocumentMetadataOverrides]: + """ + Returns the most recent arguments to the async task + """ + # Must be at least 1 call + self.consume_file_mock.assert_called() + + args, _ = self.consume_file_mock.call_args + input_doc, overrides = args + + return (input_doc, overrides) + + def get_all_consume_delay_call_args( + self, + ) -> Iterator[Tuple[ConsumableDocument, DocumentMetadataOverrides]]: + """ + Iterates over all calls to the async task and returns the arguments + """ + + for args, _ in self.consume_file_mock.call_args_list: + input_doc, overrides = args + + yield (input_doc, overrides) + + def get_specific_consume_delay_call_args( + self, + index: int, + ) -> Iterator[Tuple[ConsumableDocument, DocumentMetadataOverrides]]: + """ + Returns the arguments of a specific call to the async task + """ + # Must be at least 1 call + self.consume_file_mock.assert_called() + + args, _ = self.consume_file_mock.call_args_list[index] + input_doc, overrides = args + + return (input_doc, overrides) + class TestMigrations(TransactionTestCase): @property @@ -140,7 +190,7 @@ class TestMigrations(TransactionTestCase): assert ( self.migrate_from and self.migrate_to - ), "TestCase '{}' must define migrate_from and migrate_to properties".format( + ), "TestCase '{}' must define migrate_from and migrate_to properties".format( type(self).__name__, ) self.migrate_from = [(self.app, self.migrate_from)] diff --git a/src/documents/views.py b/src/documents/views.py index 1b30ec770..a50d9f7f4 100644 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -5,7 +5,6 @@ import os import re import tempfile import urllib -import uuid import zipfile from datetime import datetime from pathlib import Path @@ -65,6 +64,9 @@ from .bulk_download import ArchiveOnlyStrategy from .bulk_download import OriginalAndArchiveStrategy from .bulk_download import OriginalsOnlyStrategy from .classifier import load_classifier +from .data_models import ConsumableDocument +from .data_models import DocumentMetadataOverrides +from .data_models import DocumentSource from .filters import CorrespondentFilterSet from .filters import DocumentFilterSet from .filters import DocumentTypeFilterSet @@ -692,19 
+694,24 @@ class PostDocumentView(GenericAPIView): os.utime(temp_file_path, times=(t, t)) - task_id = str(uuid.uuid4()) + input_doc = ConsumableDocument( + source=DocumentSource.ApiUpload, + original_file=temp_file_path, + ) + input_doc_overrides = DocumentMetadataOverrides( + filename=doc_name, + title=title, + correspondent_id=correspondent_id, + document_type_id=document_type_id, + tag_ids=tag_ids, + created=created, + asn=archive_serial_number, + owner_id=request.user.id, + ) async_task = consume_file.delay( - # Paths are not JSON friendly - str(temp_file_path), - override_title=title, - override_correspondent_id=correspondent_id, - override_document_type_id=document_type_id, - override_tag_ids=tag_ids, - task_id=task_id, - override_created=created, - override_owner_id=request.user.id, - override_archive_serial_num=archive_serial_number, + input_doc, + input_doc_overrides, ) return Response(async_task.id) diff --git a/src/paperless_mail/mail.py b/src/paperless_mail/mail.py index 50a578563..06dd3ac6c 100644 --- a/src/paperless_mail/mail.py +++ b/src/paperless_mail/mail.py @@ -21,6 +21,9 @@ from django.conf import settings from django.db import DatabaseError from django.utils.timezone import is_naive from django.utils.timezone import make_aware +from documents.data_models import ConsumableDocument +from documents.data_models import DocumentMetadataOverrides +from documents.data_models import DocumentSource from documents.loggers import LoggingMixin from documents.models import Correspondent from documents.parsers import is_mime_type_supported @@ -694,18 +697,22 @@ class MailAccountHandler(LoggingMixin): f"{message.subject} from {message.from_}", ) + input_doc = ConsumableDocument( + source=DocumentSource.MailFetch, + original_file=temp_filename, + ) + doc_overrides = DocumentMetadataOverrides( + title=title, + filename=pathvalidate.sanitize_filename(att.filename), + correspondent_id=correspondent.id if correspondent else None, + document_type_id=doc_type.id if doc_type else None, + tag_ids=tag_ids, + owner_id=rule.owner.id if rule.owner else None, + ) + consume_task = consume_file.s( - path=temp_filename, - override_filename=pathvalidate.sanitize_filename( - att.filename, - ), - override_title=title, - override_correspondent_id=correspondent.id - if correspondent - else None, - override_document_type_id=doc_type.id if doc_type else None, - override_tag_ids=tag_ids, - override_owner_id=rule.owner.id if rule.owner else None, + input_doc, + doc_overrides, ) consume_tasks.append(consume_task) @@ -770,16 +777,22 @@ class MailAccountHandler(LoggingMixin): f"{message.subject} from {message.from_}", ) + input_doc = ConsumableDocument( + source=DocumentSource.MailFetch, + original_file=temp_filename, + ) + doc_overrides = DocumentMetadataOverrides( + title=message.subject, + filename=pathvalidate.sanitize_filename(f"{message.subject}.eml"), + correspondent_id=correspondent.id if correspondent else None, + document_type_id=doc_type.id if doc_type else None, + tag_ids=tag_ids, + owner_id=rule.owner.id if rule.owner else None, + ) + consume_task = consume_file.s( - path=temp_filename, - override_filename=pathvalidate.sanitize_filename( - message.subject + ".eml", - ), - override_title=message.subject, - override_correspondent_id=correspondent.id if correspondent else None, - override_document_type_id=doc_type.id if doc_type else None, - override_tag_ids=tag_ids, - override_owner_id=rule.owner.id if rule.owner else None, + input_doc, + doc_overrides, ) queue_consumption_tasks( diff --git 
a/src/paperless_mail/tests/test_mail.py b/src/paperless_mail/tests/test_mail.py index c0bfccba5..e08f0ad18 100644 --- a/src/paperless_mail/tests/test_mail.py +++ b/src/paperless_mail/tests/test_mail.py @@ -12,8 +12,11 @@ from unittest import mock from django.core.management import call_command from django.db import DatabaseError from django.test import TestCase +from documents.data_models import ConsumableDocument +from documents.data_models import DocumentMetadataOverrides from documents.models import Correspondent from documents.tests.utils import DirectoriesMixin +from documents.tests.utils import DocumentConsumeDelayMixin from documents.tests.utils import FileSystemAssertsMixin from imap_tools import EmailAddress from imap_tools import FolderInfo @@ -194,7 +197,11 @@ def fake_magic_from_buffer(buffer, mime=False): @mock.patch("paperless_mail.mail.magic.from_buffer", fake_magic_from_buffer) -class TestMail(DirectoriesMixin, FileSystemAssertsMixin, TestCase): +class TestMail( + DirectoriesMixin, + FileSystemAssertsMixin, + TestCase, +): def setUp(self): self._used_uids = set() @@ -409,6 +416,8 @@ class TestMail(DirectoriesMixin, FileSystemAssertsMixin, TestCase): self.assertEqual(result, 2) + self._queue_consumption_tasks_mock.assert_called() + self.assert_queue_consumption_tasks_call_args( [ [ @@ -426,7 +435,7 @@ class TestMail(DirectoriesMixin, FileSystemAssertsMixin, TestCase): result = self.mail_account_handler._handle_message(message, rule) - self.assertFalse(self._queue_consumption_tasks_mock.called) + self._queue_consumption_tasks_mock.assert_not_called() self.assertEqual(result, 0) def test_handle_unknown_mime_type(self): @@ -541,7 +550,6 @@ class TestMail(DirectoriesMixin, FileSystemAssertsMixin, TestCase): for (pattern, matches) in tests: with self.subTest(msg=pattern): - print(f"PATTERN {pattern}") self._queue_consumption_tasks_mock.reset_mock() account = MailAccount(name=str(uuid.uuid4())) account.save() @@ -855,7 +863,7 @@ class TestMail(DirectoriesMixin, FileSystemAssertsMixin, TestCase): self.mail_account_handler.handle_mail_account(account) self.bogus_mailbox.folder.list.assert_called_once() - self.assertEqual(self._queue_consumption_tasks_mock.call_count, 0) + self._queue_consumption_tasks_mock.assert_not_called() def test_error_folder_set_error_listing(self): """ @@ -888,7 +896,7 @@ class TestMail(DirectoriesMixin, FileSystemAssertsMixin, TestCase): self.mail_account_handler.handle_mail_account(account) self.bogus_mailbox.folder.list.assert_called_once() - self.assertEqual(self._queue_consumption_tasks_mock.call_count, 0) + self._queue_consumption_tasks_mock.assert_not_called() @mock.patch("paperless_mail.mail.MailAccountHandler._get_correspondent") def test_error_skip_mail(self, m): @@ -1002,7 +1010,7 @@ class TestMail(DirectoriesMixin, FileSystemAssertsMixin, TestCase): self.reset_bogus_mailbox() self._queue_consumption_tasks_mock.reset_mock() - self.assertEqual(self._queue_consumption_tasks_mock.call_count, 0) + self._queue_consumption_tasks_mock.assert_not_called() self.assertEqual(len(self.bogus_mailbox.messages), 3) self.mail_account_handler.handle_mail_account(account) @@ -1041,7 +1049,7 @@ class TestMail(DirectoriesMixin, FileSystemAssertsMixin, TestCase): ) self.assertEqual(len(self.bogus_mailbox.messages), 3) - self.assertEqual(self._queue_consumption_tasks_mock.call_count, 0) + self._queue_consumption_tasks_mock.assert_not_called() self.assertEqual(len(self.bogus_mailbox.fetch("UNSEEN", False)), 2) self.mail_account_handler.handle_mail_account(account) @@ 
-1148,13 +1156,21 @@ class TestMail(DirectoriesMixin, FileSystemAssertsMixin, TestCase): consume_tasks, expected_signatures, ): + input_doc, overrides = consume_task.args + # assert the file exists - self.assertIsFile(consume_task.kwargs["path"]) + self.assertIsFile(input_doc.original_file) # assert all expected arguments are present in the signature for key, value in expected_signature.items(): - self.assertIn(key, consume_task.kwargs) - self.assertEqual(consume_task.kwargs[key], value) + if key == "override_correspondent_id": + self.assertEqual(overrides.correspondent_id, value) + elif key == "override_filename": + self.assertEqual(overrides.filename, value) + elif key == "override_title": + self.assertEqual(overrides.title, value) + else: + self.fail("No match for expected arg") def apply_mail_actions(self): """