diff --git a/src/documents/bulk_download.py b/src/documents/bulk_download.py index 6beefa23b..87d97afcc 100644 --- a/src/documents/bulk_download.py +++ b/src/documents/bulk_download.py @@ -1,18 +1,29 @@ +import os from zipfile import ZipFile from documents.models import Document class BulkArchiveStrategy: - def __init__(self, zipf: ZipFile): + def __init__(self, zipf: ZipFile, follow_formatting: bool = False): self.zipf = zipf + if follow_formatting: + self.make_unique_filename = self._formatted_filepath + else: + self.make_unique_filename = self._filename_only - def make_unique_filename( + def _filename_only( self, doc: Document, archive: bool = False, folder: str = "", ): + """ + Constructs a unique name for the given document to be used inside the + zip file. + + The filename might not be unique enough, so a counter is appended if needed + """ counter = 0 while True: filename = folder + doc.get_public_filename(archive, counter) @@ -21,6 +32,25 @@ class BulkArchiveStrategy: else: return filename + def _formatted_filepath( + self, + doc: Document, + archive: bool = False, + folder: str = "", + ): + """ + Constructs a full file path for the given document to be used inside + the zipfile. + + The path is already unique, as handled when a document is consumed or updated + """ + if archive and doc.has_archive_version: + in_archive_path = os.path.join(folder, doc.archive_filename) + else: + in_archive_path = os.path.join(folder, doc.filename) + + return in_archive_path + def add_document(self, doc: Document): raise NotImplementedError() # pragma: no cover @@ -31,9 +61,6 @@ class OriginalsOnlyStrategy(BulkArchiveStrategy): class ArchiveOnlyStrategy(BulkArchiveStrategy): - def __init__(self, zipf): - super().__init__(zipf) - def add_document(self, doc: Document): if doc.has_archive_version: self.zipf.write( diff --git a/src/documents/models.py b/src/documents/models.py index c1b9c88bc..1ee6dfedb 100644 --- a/src/documents/models.py +++ b/src/documents/models.py @@ -287,6 +287,9 @@ class Document(models.Model): return open(self.archive_path, "rb") def get_public_filename(self, archive=False, counter=0, suffix=None) -> str: + """ + Returns a sanitized filename for the document, not including any paths. + """ result = str(self) if counter: diff --git a/src/documents/serialisers.py b/src/documents/serialisers.py index db282cacd..3e6ec4390 100644 --- a/src/documents/serialisers.py +++ b/src/documents/serialisers.py @@ -551,6 +551,10 @@ class BulkDownloadSerializer(DocumentListSerializer): default="none", ) + follow_formatting = serializers.BooleanField( + default=False, + ) + def validate_compression(self, compression): import zipfile diff --git a/src/documents/tests/test_api.py b/src/documents/tests/test_api.py index d876984bd..a7e2be53f 100644 --- a/src/documents/tests/test_api.py +++ b/src/documents/tests/test_api.py @@ -2329,6 +2329,9 @@ class TestBulkEdit(DirectoriesMixin, APITestCase): class TestBulkDownload(DirectoriesMixin, APITestCase): + + ENDPOINT = "/api/documents/bulk_download/" + def setUp(self): super().setUp() @@ -2379,7 +2382,7 @@ class TestBulkDownload(DirectoriesMixin, APITestCase): def test_download_originals(self): response = self.client.post( - "/api/documents/bulk_download/", + self.ENDPOINT, json.dumps( {"documents": [self.doc2.id, self.doc3.id], "content": "originals"}, ), @@ -2402,7 +2405,7 @@ class TestBulkDownload(DirectoriesMixin, APITestCase): def test_download_default(self): response = self.client.post( - "/api/documents/bulk_download/", + self.ENDPOINT, json.dumps({"documents": [self.doc2.id, self.doc3.id]}), content_type="application/json", ) @@ -2423,7 +2426,7 @@ class TestBulkDownload(DirectoriesMixin, APITestCase): def test_download_both(self): response = self.client.post( - "/api/documents/bulk_download/", + self.ENDPOINT, json.dumps({"documents": [self.doc2.id, self.doc3.id], "content": "both"}), content_type="application/json", ) @@ -2457,7 +2460,7 @@ class TestBulkDownload(DirectoriesMixin, APITestCase): def test_filename_clashes(self): response = self.client.post( - "/api/documents/bulk_download/", + self.ENDPOINT, json.dumps({"documents": [self.doc2.id, self.doc2b.id]}), content_type="application/json", ) @@ -2479,13 +2482,145 @@ class TestBulkDownload(DirectoriesMixin, APITestCase): def test_compression(self): response = self.client.post( - "/api/documents/bulk_download/", + self.ENDPOINT, json.dumps( {"documents": [self.doc2.id, self.doc2b.id], "compression": "lzma"}, ), content_type="application/json", ) + @override_settings(FILENAME_FORMAT="{correspondent}/{title}") + def test_formatted_download_originals(self): + + c = Correspondent.objects.create(name="test") + c2 = Correspondent.objects.create(name="a space name") + + self.doc2.correspondent = c + self.doc2.title = "This is Doc 2" + self.doc2.save() + + self.doc3.correspondent = c2 + self.doc3.title = "Title 2 - Doc 3" + self.doc3.save() + + response = self.client.post( + self.ENDPOINT, + json.dumps( + { + "documents": [self.doc2.id, self.doc3.id], + "content": "originals", + "follow_formatting": True, + }, + ), + content_type="application/json", + ) + + self.assertEqual(response.status_code, 200) + self.assertEqual(response["Content-Type"], "application/zip") + + with zipfile.ZipFile(io.BytesIO(response.content)) as zipf: + self.assertEqual(len(zipf.filelist), 2) + self.assertIn("a space name/Title 2 - Doc 3.jpg", zipf.namelist()) + self.assertIn("test/This is Doc 2.pdf", zipf.namelist()) + + with self.doc2.source_file as f: + self.assertEqual(f.read(), zipf.read("test/This is Doc 2.pdf")) + + with self.doc3.source_file as f: + self.assertEqual( + f.read(), + zipf.read("a space name/Title 2 - Doc 3.jpg"), + ) + + @override_settings(FILENAME_FORMAT="somewhere/{title}") + def test_formatted_download_archive(self): + + self.doc2.title = "This is Doc 2" + self.doc2.save() + + self.doc3.title = "Title 2 - Doc 3" + self.doc3.save() + print(self.doc3.archive_path) + print(self.doc3.archive_filename) + + response = self.client.post( + self.ENDPOINT, + json.dumps( + { + "documents": [self.doc2.id, self.doc3.id], + "follow_formatting": True, + }, + ), + content_type="application/json", + ) + + self.assertEqual(response.status_code, 200) + self.assertEqual(response["Content-Type"], "application/zip") + + with zipfile.ZipFile(io.BytesIO(response.content)) as zipf: + self.assertEqual(len(zipf.filelist), 2) + self.assertIn("somewhere/This is Doc 2.pdf", zipf.namelist()) + self.assertIn("somewhere/Title 2 - Doc 3.pdf", zipf.namelist()) + + with self.doc2.source_file as f: + self.assertEqual(f.read(), zipf.read("somewhere/This is Doc 2.pdf")) + + with self.doc3.archive_file as f: + self.assertEqual(f.read(), zipf.read("somewhere/Title 2 - Doc 3.pdf")) + + @override_settings(FILENAME_FORMAT="{document_type}/{title}") + def test_formatted_download_both(self): + + dc1 = DocumentType.objects.create(name="bill") + dc2 = DocumentType.objects.create(name="statement") + + self.doc2.document_type = dc1 + self.doc2.title = "This is Doc 2" + self.doc2.save() + + self.doc3.document_type = dc2 + self.doc3.title = "Title 2 - Doc 3" + self.doc3.save() + + response = self.client.post( + self.ENDPOINT, + json.dumps( + { + "documents": [self.doc2.id, self.doc3.id], + "content": "both", + "follow_formatting": True, + }, + ), + content_type="application/json", + ) + + self.assertEqual(response.status_code, 200) + self.assertEqual(response["Content-Type"], "application/zip") + + with zipfile.ZipFile(io.BytesIO(response.content)) as zipf: + self.assertEqual(len(zipf.filelist), 3) + self.assertIn("originals/bill/This is Doc 2.pdf", zipf.namelist()) + self.assertIn("archive/statement/Title 2 - Doc 3.pdf", zipf.namelist()) + self.assertIn("originals/statement/Title 2 - Doc 3.jpg", zipf.namelist()) + + with self.doc2.source_file as f: + self.assertEqual( + f.read(), + zipf.read("originals/bill/This is Doc 2.pdf"), + ) + + with self.doc3.archive_file as f: + self.assertEqual( + f.read(), + zipf.read("archive/statement/Title 2 - Doc 3.pdf"), + ) + + with self.doc3.source_file as f: + self.assertEqual( + f.read(), + zipf.read("originals/statement/Title 2 - Doc 3.jpg"), + ) + class TestApiAuth(DirectoriesMixin, APITestCase): def test_auth_required(self): diff --git a/src/documents/views.py b/src/documents/views.py index 10225be6f..ce82cbfaa 100644 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -745,6 +745,7 @@ class BulkDownloadView(GenericAPIView): ids = serializer.validated_data.get("documents") compression = serializer.validated_data.get("compression") content = serializer.validated_data.get("content") + follow_filename_format = serializer.validated_data.get("follow_formatting") os.makedirs(settings.SCRATCH_DIR, exist_ok=True) temp = tempfile.NamedTemporaryFile( @@ -761,7 +762,7 @@ class BulkDownloadView(GenericAPIView): strategy_class = ArchiveOnlyStrategy with zipfile.ZipFile(temp.name, "w", compression) as zipf: - strategy = strategy_class(zipf) + strategy = strategy_class(zipf, follow_filename_format) for id in ids: doc = Document.objects.get(id=id) strategy.add_document(doc)