diff --git a/src/documents/bulk_download.py b/src/documents/bulk_download.py
new file mode 100644
index 000000000..8c675b4b5
--- /dev/null
+++ b/src/documents/bulk_download.py
@@ -0,0 +1,60 @@
+from zipfile import ZipFile
+
+from documents.models import Document
+
+
+class BulkArchiveStrategy:
+
+    def __init__(self, zipf: ZipFile):
+        self.zipf = zipf
+
+    def make_unique_filename(self,
+                             doc: Document,
+                             archive: bool = False,
+                             folder: str = ""):
+        counter = 0
+        while True:
+            filename = folder + doc.get_public_filename(archive, counter)
+            if filename in self.zipf.namelist():
+                counter += 1
+            else:
+                return filename
+
+    def add_document(self, doc: Document):
+        raise NotImplementedError()  # pragma: no cover
+
+
+class OriginalsOnlyStrategy(BulkArchiveStrategy):
+
+    def add_document(self, doc: Document):
+        self.zipf.write(doc.source_path, self.make_unique_filename(doc))
+
+
+class ArchiveOnlyStrategy(BulkArchiveStrategy):
+
+    def __init__(self, zipf):
+        super(ArchiveOnlyStrategy, self).__init__(zipf)
+
+    def add_document(self, doc: Document):
+        if doc.has_archive_version:
+            self.zipf.write(doc.archive_path,
+                            self.make_unique_filename(doc, archive=True))
+        else:
+            self.zipf.write(doc.source_path,
+                            self.make_unique_filename(doc))
+
+
+class OriginalAndArchiveStrategy(BulkArchiveStrategy):
+
+    def add_document(self, doc: Document):
+        if doc.has_archive_version:
+            self.zipf.write(
+                doc.archive_path, self.make_unique_filename(
+                    doc, archive=True, folder="archive/"
+                )
+            )
+
+        self.zipf.write(
+            doc.source_path,
+            self.make_unique_filename(doc, folder="originals/")
+        )
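Note on the new module: BulkArchiveStrategy.make_unique_filename avoids name clashes inside the zip by incrementing a counter until Document.get_public_filename() yields a name not yet present in the archive. As a minimal, standalone sketch of the same idea (the "_NN" suffix format is an assumption based on the tests below; a plain base name stands in for the Document model):

import io
import zipfile


def unique_name(zipf: zipfile.ZipFile, base: str, folder: str = "") -> str:
    # Keep incrementing the counter until the candidate name is not yet
    # present in the archive, mirroring make_unique_filename() above.
    stem, ext = base.rsplit(".", 1)
    counter = 0
    while True:
        suffix = f"_{counter:02}" if counter else ""
        candidate = f"{folder}{stem}{suffix}.{ext}"
        if candidate not in zipf.namelist():
            return candidate
        counter += 1


buffer = io.BytesIO()
with zipfile.ZipFile(buffer, "w") as zipf:
    zipf.writestr(unique_name(zipf, "document A.pdf"), b"first")
    zipf.writestr(unique_name(zipf, "document A.pdf"), b"second")
    print(zipf.namelist())  # ['document A.pdf', 'document A_01.pdf']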
name="tags"): if not type(tags) == list: raise serializers.ValidationError(f"{name} must be a list") @@ -238,10 +246,6 @@ class BulkEditSerializer(serializers.Serializer): raise serializers.ValidationError( f"Some tags in {name} don't exist or were specified twice.") - def validate_documents(self, documents): - self._validate_document_id_list(documents) - return documents - def validate_method(self, method): if method == "set_correspondent": return bulk_edit.set_correspondent @@ -392,9 +396,24 @@ class PostDocumentSerializer(serializers.Serializer): return None -class SelectionDataSerializer(serializers.Serializer): +class BulkDownloadSerializer(DocumentListSerializer): - documents = serializers.ListField( - required=True, - child=serializers.IntegerField() + content = serializers.ChoiceField( + choices=["archive", "originals", "both"], + default="archive" ) + + compression = serializers.ChoiceField( + choices=["none", "deflated", "bzip2", "lzma"], + default="none" + ) + + def validate_compression(self, compression): + import zipfile + + return { + "none": zipfile.ZIP_STORED, + "deflated": zipfile.ZIP_DEFLATED, + "bzip2": zipfile.ZIP_BZIP2, + "lzma": zipfile.ZIP_LZMA + }[compression] diff --git a/src/documents/tests/test_api.py b/src/documents/tests/test_api.py index 01e7210a5..7486154e1 100644 --- a/src/documents/tests/test_api.py +++ b/src/documents/tests/test_api.py @@ -1,7 +1,10 @@ +import datetime +import io import json import os import shutil import tempfile +import zipfile from unittest import mock from django.conf import settings @@ -1123,6 +1126,113 @@ class TestBulkEdit(DirectoriesMixin, APITestCase): self.assertCountEqual(response.data['selected_document_types'], [{"id": self.c1.id, "document_count": 1}, {"id": self.c2.id, "document_count": 0}]) +class TestBulkDownload(DirectoriesMixin, APITestCase): + + def setUp(self): + super(TestBulkDownload, self).setUp() + + user = User.objects.create_superuser(username="temp_admin") + self.client.force_login(user=user) + + self.doc1 = Document.objects.create(title="unrelated", checksum="A") + self.doc2 = Document.objects.create(title="document A", filename="docA.pdf", mime_type="application/pdf", checksum="B", created=datetime.datetime(2021, 1, 1)) + self.doc2b = Document.objects.create(title="document A", filename="docA2.pdf", mime_type="application/pdf", checksum="D", created=datetime.datetime(2021, 1, 1)) + self.doc3 = Document.objects.create(title="document B", filename="docB.jpg", mime_type="image/jpeg", checksum="C", created=datetime.datetime(2020, 3, 21), archive_filename="docB.pdf", archive_checksum="D") + + shutil.copy(os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"), self.doc2.source_path) + shutil.copy(os.path.join(os.path.dirname(__file__), "samples", "simple.png"), self.doc2b.source_path) + shutil.copy(os.path.join(os.path.dirname(__file__), "samples", "simple.jpg"), self.doc3.source_path) + shutil.copy(os.path.join(os.path.dirname(__file__), "samples", "test_with_bom.pdf"), self.doc3.archive_path) + + def test_download_originals(self): + response = self.client.post("/api/documents/bulk_download/", json.dumps({ + "documents": [self.doc2.id, self.doc3.id], + "content": "originals" + }), content_type='application/json') + + self.assertEqual(response.status_code, 200) + self.assertEqual(response['Content-Type'], 'application/zip') + + with zipfile.ZipFile(io.BytesIO(response.content)) as zipf: + self.assertEqual(len(zipf.filelist), 2) + self.assertIn("2021-01-01 document A.pdf", zipf.namelist()) + 
self.assertIn("2020-03-21 document B.jpg", zipf.namelist()) + + with self.doc2.source_file as f: + self.assertEqual(f.read(), zipf.read("2021-01-01 document A.pdf")) + + with self.doc3.source_file as f: + self.assertEqual(f.read(), zipf.read("2020-03-21 document B.jpg")) + + def test_download_default(self): + response = self.client.post("/api/documents/bulk_download/", json.dumps({ + "documents": [self.doc2.id, self.doc3.id] + }), content_type='application/json') + + self.assertEqual(response.status_code, 200) + self.assertEqual(response['Content-Type'], 'application/zip') + + with zipfile.ZipFile(io.BytesIO(response.content)) as zipf: + self.assertEqual(len(zipf.filelist), 2) + self.assertIn("2021-01-01 document A.pdf", zipf.namelist()) + self.assertIn("2020-03-21 document B.pdf", zipf.namelist()) + + with self.doc2.source_file as f: + self.assertEqual(f.read(), zipf.read("2021-01-01 document A.pdf")) + + with self.doc3.archive_file as f: + self.assertEqual(f.read(), zipf.read("2020-03-21 document B.pdf")) + + def test_download_both(self): + response = self.client.post("/api/documents/bulk_download/", json.dumps({ + "documents": [self.doc2.id, self.doc3.id], + "content": "both" + }), content_type='application/json') + + self.assertEqual(response.status_code, 200) + self.assertEqual(response['Content-Type'], 'application/zip') + + with zipfile.ZipFile(io.BytesIO(response.content)) as zipf: + self.assertEqual(len(zipf.filelist), 3) + self.assertIn("originals/2021-01-01 document A.pdf", zipf.namelist()) + self.assertIn("archive/2020-03-21 document B.pdf", zipf.namelist()) + self.assertIn("originals/2020-03-21 document B.jpg", zipf.namelist()) + + with self.doc2.source_file as f: + self.assertEqual(f.read(), zipf.read("originals/2021-01-01 document A.pdf")) + + with self.doc3.archive_file as f: + self.assertEqual(f.read(), zipf.read("archive/2020-03-21 document B.pdf")) + + with self.doc3.source_file as f: + self.assertEqual(f.read(), zipf.read("originals/2020-03-21 document B.jpg")) + + def test_filename_clashes(self): + response = self.client.post("/api/documents/bulk_download/", json.dumps({ + "documents": [self.doc2.id, self.doc2b.id] + }), content_type='application/json') + + self.assertEqual(response.status_code, 200) + self.assertEqual(response['Content-Type'], 'application/zip') + + with zipfile.ZipFile(io.BytesIO(response.content)) as zipf: + self.assertEqual(len(zipf.filelist), 2) + + self.assertIn("2021-01-01 document A.pdf", zipf.namelist()) + self.assertIn("2021-01-01 document A_01.pdf", zipf.namelist()) + + with self.doc2.source_file as f: + self.assertEqual(f.read(), zipf.read("2021-01-01 document A.pdf")) + + with self.doc2b.source_file as f: + self.assertEqual(f.read(), zipf.read("2021-01-01 document A_01.pdf")) + + def test_compression(self): + response = self.client.post("/api/documents/bulk_download/", json.dumps({ + "documents": [self.doc2.id, self.doc2b.id], + "compression": "lzma" + }), content_type='application/json') + class TestApiAuth(APITestCase): def test_auth_required(self): @@ -1146,4 +1256,5 @@ class TestApiAuth(APITestCase): self.assertEqual(self.client.get("/api/search/").status_code, 401) self.assertEqual(self.client.get("/api/search/auto_complete/").status_code, 401) self.assertEqual(self.client.get("/api/documents/bulk_edit/").status_code, 401) + self.assertEqual(self.client.get("/api/documents/bulk_download/").status_code, 401) self.assertEqual(self.client.get("/api/documents/selection_data/").status_code, 401) diff --git a/src/documents/views.py 
diff --git a/src/documents/views.py b/src/documents/views.py
index 68d6e3c77..d886324ae 100755
--- a/src/documents/views.py
+++ b/src/documents/views.py
@@ -2,6 +2,7 @@ import logging
 import os
 import tempfile
 import uuid
+import zipfile
 from datetime import datetime
 from time import mktime
 
@@ -34,6 +35,8 @@ from rest_framework.viewsets import (
 from paperless.db import GnuPG
 from paperless.views import StandardPagination
 
+from .bulk_download import OriginalAndArchiveStrategy, OriginalsOnlyStrategy, \
+    ArchiveOnlyStrategy
 from .classifier import load_classifier
 from .filters import (
     CorrespondentFilterSet,
@@ -51,7 +54,9 @@ from .serialisers import (
     DocumentTypeSerializer,
     PostDocumentSerializer,
     SavedViewSerializer,
-    BulkEditSerializer, SelectionDataSerializer
+    BulkEditSerializer,
+    DocumentListSerializer,
+    BulkDownloadSerializer
 )
 
 
@@ -444,7 +449,7 @@ class PostDocumentView(APIView):
 class SelectionDataView(APIView):
 
     permission_classes = (IsAuthenticated,)
-    serializer_class = SelectionDataSerializer
+    serializer_class = DocumentListSerializer
     parser_classes = (parsers.MultiPartParser, parsers.JSONParser)
 
     def get_serializer_context(self):
@@ -606,3 +611,52 @@ class StatisticsView(APIView):
             'documents_total': documents_total,
             'documents_inbox': documents_inbox,
         })
+
+
+class BulkDownloadView(APIView):
+
+    permission_classes = (IsAuthenticated,)
+    serializer_class = BulkDownloadSerializer
+    parser_classes = (parsers.JSONParser,)
+
+    def get_serializer_context(self):
+        return {
+            'request': self.request,
+            'format': self.format_kwarg,
+            'view': self
+        }
+
+    def get_serializer(self, *args, **kwargs):
+        kwargs['context'] = self.get_serializer_context()
+        return self.serializer_class(*args, **kwargs)
+
+    def post(self, request, format=None):
+        serializer = self.get_serializer(data=request.data)
+        serializer.is_valid(raise_exception=True)
+
+        ids = serializer.validated_data.get('documents')
+        compression = serializer.validated_data.get('compression')
+        content = serializer.validated_data.get('content')
+
+        os.makedirs(settings.SCRATCH_DIR, exist_ok=True)
+        temp = tempfile.NamedTemporaryFile(dir=settings.SCRATCH_DIR, suffix="-compressed-archive", delete=False)
+
+        if content == 'both':
+            strategy_class = OriginalAndArchiveStrategy
+        elif content == 'originals':
+            strategy_class = OriginalsOnlyStrategy
+        else:
+            strategy_class = ArchiveOnlyStrategy
+
+        with zipfile.ZipFile(temp.name, "w", compression) as zipf:
+            strategy = strategy_class(zipf)
+            for id in ids:
+                doc = Document.objects.get(id=id)
+                strategy.add_document(doc)
+
+        with open(temp.name, "rb") as f:
+            response = HttpResponse(f, content_type="application/zip")
+            response["Content-Disposition"] = '{}; filename="{}"'.format(
+                "attachment", "documents.zip")
+
+            return response
diff --git a/src/paperless/urls.py b/src/paperless/urls.py
index 40f4bd754..4e0b8f191 100755
--- a/src/paperless/urls.py
+++ b/src/paperless/urls.py
@@ -23,7 +23,8 @@ from documents.views import (
     PostDocumentView,
     SavedViewViewSet,
     BulkEditView,
-    SelectionDataView
+    SelectionDataView,
+    BulkDownloadView
 )
 from paperless.views import FaviconView
 
@@ -63,6 +64,9 @@ urlpatterns = [
         re_path(r"^documents/selection_data/", SelectionDataView.as_view(),
                 name="selection_data"),
 
+        re_path(r"^documents/bulk_download/", BulkDownloadView.as_view(),
+                name="bulk_download"),
+
         path('token/', views.obtain_auth_token)
 
     ] + api_router.urls)),
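With the route registered, a client can POST to the new endpoint and unpack the zip from the response body. A hedged sketch over plain HTTP (the host, token and document ids are placeholders, and the requests library is assumed to be available):

import io
import zipfile

import requests

response = requests.post(
    "http://localhost:8000/api/documents/bulk_download/",
    json={"documents": [2, 3], "content": "both", "compression": "deflated"},
    headers={"Authorization": "Token <api-token>"},
)
response.raise_for_status()

with zipfile.ZipFile(io.BytesIO(response.content)) as zipf:
    print(zipf.namelist())  # e.g. originals/... and archive/... entries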