API support for downloading compressed archives of multiple documents

This commit is contained in:
jonaswinkler 2021-02-20 16:09:29 +01:00
parent 4de4789605
commit 31f03ef1d3
5 changed files with 274 additions and 26 deletions

@ -0,0 +1,60 @@
from zipfile import ZipFile
from documents.models import Document
class BulkArchiveStrategy:
def __init__(self, zipf: ZipFile):
self.zipf = zipf
def make_unique_filename(self,
doc: Document,
archive: bool = False,
folder: str = ""):
counter = 0
while True:
filename = folder + doc.get_public_filename(archive, counter)
if filename in self.zipf.namelist():
counter += 1
else:
return filename
def add_document(self, doc: Document):
raise NotImplementedError() # pragma: no cover
class OriginalsOnlyStrategy(BulkArchiveStrategy):
def add_document(self, doc: Document):
self.zipf.write(doc.source_path, self.make_unique_filename(doc))
class ArchiveOnlyStrategy(BulkArchiveStrategy):
def __init__(self, zipf):
super(ArchiveOnlyStrategy, self).__init__(zipf)
def add_document(self, doc: Document):
if doc.has_archive_version:
self.zipf.write(doc.archive_path,
self.make_unique_filename(doc, archive=True))
else:
self.zipf.write(doc.source_path,
self.make_unique_filename(doc))
class OriginalAndArchiveStrategy(BulkArchiveStrategy):
def add_document(self, doc: Document):
if doc.has_archive_version:
self.zipf.write(
doc.archive_path, self.make_unique_filename(
doc, archive=True, folder="archive/"
)
)
self.zipf.write(
doc.source_path,
self.make_unique_filename(doc, folder="originals/")
)

@ -192,14 +192,34 @@ class SavedViewSerializer(serializers.ModelSerializer):
return saved_view return saved_view
class BulkEditSerializer(serializers.Serializer): class DocumentListSerializer(serializers.Serializer):
documents = serializers.ListField( documents = serializers.ListField(
child=serializers.IntegerField(), required=True,
label="Documents", label="Documents",
write_only=True write_only=True,
child=serializers.IntegerField()
) )
def _validate_document_id_list(self, documents, name="documents"):
if not type(documents) == list:
raise serializers.ValidationError(f"{name} must be a list")
if not all([type(i) == int for i in documents]):
raise serializers.ValidationError(
f"{name} must be a list of integers")
count = Document.objects.filter(id__in=documents).count()
if not count == len(documents):
raise serializers.ValidationError(
f"Some documents in {name} don't exist or were "
f"specified twice.")
def validate_documents(self, documents):
self._validate_document_id_list(documents)
return documents
class BulkEditSerializer(DocumentListSerializer):
method = serializers.ChoiceField( method = serializers.ChoiceField(
choices=[ choices=[
"set_correspondent", "set_correspondent",
@ -215,18 +235,6 @@ class BulkEditSerializer(serializers.Serializer):
parameters = serializers.DictField(allow_empty=True) parameters = serializers.DictField(allow_empty=True)
def _validate_document_id_list(self, documents, name="documents"):
if not type(documents) == list:
raise serializers.ValidationError(f"{name} must be a list")
if not all([type(i) == int for i in documents]):
raise serializers.ValidationError(
f"{name} must be a list of integers")
count = Document.objects.filter(id__in=documents).count()
if not count == len(documents):
raise serializers.ValidationError(
f"Some documents in {name} don't exist or were "
f"specified twice.")
def _validate_tag_id_list(self, tags, name="tags"): def _validate_tag_id_list(self, tags, name="tags"):
if not type(tags) == list: if not type(tags) == list:
raise serializers.ValidationError(f"{name} must be a list") raise serializers.ValidationError(f"{name} must be a list")
@ -238,10 +246,6 @@ class BulkEditSerializer(serializers.Serializer):
raise serializers.ValidationError( raise serializers.ValidationError(
f"Some tags in {name} don't exist or were specified twice.") f"Some tags in {name} don't exist or were specified twice.")
def validate_documents(self, documents):
self._validate_document_id_list(documents)
return documents
def validate_method(self, method): def validate_method(self, method):
if method == "set_correspondent": if method == "set_correspondent":
return bulk_edit.set_correspondent return bulk_edit.set_correspondent
@ -392,9 +396,24 @@ class PostDocumentSerializer(serializers.Serializer):
return None return None
class SelectionDataSerializer(serializers.Serializer): class BulkDownloadSerializer(DocumentListSerializer):
documents = serializers.ListField( content = serializers.ChoiceField(
required=True, choices=["archive", "originals", "both"],
child=serializers.IntegerField() default="archive"
) )
compression = serializers.ChoiceField(
choices=["none", "deflated", "bzip2", "lzma"],
default="none"
)
def validate_compression(self, compression):
import zipfile
return {
"none": zipfile.ZIP_STORED,
"deflated": zipfile.ZIP_DEFLATED,
"bzip2": zipfile.ZIP_BZIP2,
"lzma": zipfile.ZIP_LZMA
}[compression]

@ -1,7 +1,10 @@
import datetime
import io
import json import json
import os import os
import shutil import shutil
import tempfile import tempfile
import zipfile
from unittest import mock from unittest import mock
from django.conf import settings from django.conf import settings
@ -1123,6 +1126,113 @@ class TestBulkEdit(DirectoriesMixin, APITestCase):
self.assertCountEqual(response.data['selected_document_types'], [{"id": self.c1.id, "document_count": 1}, {"id": self.c2.id, "document_count": 0}]) self.assertCountEqual(response.data['selected_document_types'], [{"id": self.c1.id, "document_count": 1}, {"id": self.c2.id, "document_count": 0}])
class TestBulkDownload(DirectoriesMixin, APITestCase):
def setUp(self):
super(TestBulkDownload, self).setUp()
user = User.objects.create_superuser(username="temp_admin")
self.client.force_login(user=user)
self.doc1 = Document.objects.create(title="unrelated", checksum="A")
self.doc2 = Document.objects.create(title="document A", filename="docA.pdf", mime_type="application/pdf", checksum="B", created=datetime.datetime(2021, 1, 1))
self.doc2b = Document.objects.create(title="document A", filename="docA2.pdf", mime_type="application/pdf", checksum="D", created=datetime.datetime(2021, 1, 1))
self.doc3 = Document.objects.create(title="document B", filename="docB.jpg", mime_type="image/jpeg", checksum="C", created=datetime.datetime(2020, 3, 21), archive_filename="docB.pdf", archive_checksum="D")
shutil.copy(os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"), self.doc2.source_path)
shutil.copy(os.path.join(os.path.dirname(__file__), "samples", "simple.png"), self.doc2b.source_path)
shutil.copy(os.path.join(os.path.dirname(__file__), "samples", "simple.jpg"), self.doc3.source_path)
shutil.copy(os.path.join(os.path.dirname(__file__), "samples", "test_with_bom.pdf"), self.doc3.archive_path)
def test_download_originals(self):
response = self.client.post("/api/documents/bulk_download/", json.dumps({
"documents": [self.doc2.id, self.doc3.id],
"content": "originals"
}), content_type='application/json')
self.assertEqual(response.status_code, 200)
self.assertEqual(response['Content-Type'], 'application/zip')
with zipfile.ZipFile(io.BytesIO(response.content)) as zipf:
self.assertEqual(len(zipf.filelist), 2)
self.assertIn("2021-01-01 document A.pdf", zipf.namelist())
self.assertIn("2020-03-21 document B.jpg", zipf.namelist())
with self.doc2.source_file as f:
self.assertEqual(f.read(), zipf.read("2021-01-01 document A.pdf"))
with self.doc3.source_file as f:
self.assertEqual(f.read(), zipf.read("2020-03-21 document B.jpg"))
def test_download_default(self):
response = self.client.post("/api/documents/bulk_download/", json.dumps({
"documents": [self.doc2.id, self.doc3.id]
}), content_type='application/json')
self.assertEqual(response.status_code, 200)
self.assertEqual(response['Content-Type'], 'application/zip')
with zipfile.ZipFile(io.BytesIO(response.content)) as zipf:
self.assertEqual(len(zipf.filelist), 2)
self.assertIn("2021-01-01 document A.pdf", zipf.namelist())
self.assertIn("2020-03-21 document B.pdf", zipf.namelist())
with self.doc2.source_file as f:
self.assertEqual(f.read(), zipf.read("2021-01-01 document A.pdf"))
with self.doc3.archive_file as f:
self.assertEqual(f.read(), zipf.read("2020-03-21 document B.pdf"))
def test_download_both(self):
response = self.client.post("/api/documents/bulk_download/", json.dumps({
"documents": [self.doc2.id, self.doc3.id],
"content": "both"
}), content_type='application/json')
self.assertEqual(response.status_code, 200)
self.assertEqual(response['Content-Type'], 'application/zip')
with zipfile.ZipFile(io.BytesIO(response.content)) as zipf:
self.assertEqual(len(zipf.filelist), 3)
self.assertIn("originals/2021-01-01 document A.pdf", zipf.namelist())
self.assertIn("archive/2020-03-21 document B.pdf", zipf.namelist())
self.assertIn("originals/2020-03-21 document B.jpg", zipf.namelist())
with self.doc2.source_file as f:
self.assertEqual(f.read(), zipf.read("originals/2021-01-01 document A.pdf"))
with self.doc3.archive_file as f:
self.assertEqual(f.read(), zipf.read("archive/2020-03-21 document B.pdf"))
with self.doc3.source_file as f:
self.assertEqual(f.read(), zipf.read("originals/2020-03-21 document B.jpg"))
def test_filename_clashes(self):
response = self.client.post("/api/documents/bulk_download/", json.dumps({
"documents": [self.doc2.id, self.doc2b.id]
}), content_type='application/json')
self.assertEqual(response.status_code, 200)
self.assertEqual(response['Content-Type'], 'application/zip')
with zipfile.ZipFile(io.BytesIO(response.content)) as zipf:
self.assertEqual(len(zipf.filelist), 2)
self.assertIn("2021-01-01 document A.pdf", zipf.namelist())
self.assertIn("2021-01-01 document A_01.pdf", zipf.namelist())
with self.doc2.source_file as f:
self.assertEqual(f.read(), zipf.read("2021-01-01 document A.pdf"))
with self.doc2b.source_file as f:
self.assertEqual(f.read(), zipf.read("2021-01-01 document A_01.pdf"))
def test_compression(self):
response = self.client.post("/api/documents/bulk_download/", json.dumps({
"documents": [self.doc2.id, self.doc2b.id],
"compression": "lzma"
}), content_type='application/json')
class TestApiAuth(APITestCase): class TestApiAuth(APITestCase):
def test_auth_required(self): def test_auth_required(self):
@ -1146,4 +1256,5 @@ class TestApiAuth(APITestCase):
self.assertEqual(self.client.get("/api/search/").status_code, 401) self.assertEqual(self.client.get("/api/search/").status_code, 401)
self.assertEqual(self.client.get("/api/search/auto_complete/").status_code, 401) self.assertEqual(self.client.get("/api/search/auto_complete/").status_code, 401)
self.assertEqual(self.client.get("/api/documents/bulk_edit/").status_code, 401) self.assertEqual(self.client.get("/api/documents/bulk_edit/").status_code, 401)
self.assertEqual(self.client.get("/api/documents/bulk_download/").status_code, 401)
self.assertEqual(self.client.get("/api/documents/selection_data/").status_code, 401) self.assertEqual(self.client.get("/api/documents/selection_data/").status_code, 401)

@ -2,6 +2,7 @@ import logging
import os import os
import tempfile import tempfile
import uuid import uuid
import zipfile
from datetime import datetime from datetime import datetime
from time import mktime from time import mktime
@ -34,6 +35,8 @@ from rest_framework.viewsets import (
from paperless.db import GnuPG from paperless.db import GnuPG
from paperless.views import StandardPagination from paperless.views import StandardPagination
from .bulk_download import OriginalAndArchiveStrategy, OriginalsOnlyStrategy, \
ArchiveOnlyStrategy
from .classifier import load_classifier from .classifier import load_classifier
from .filters import ( from .filters import (
CorrespondentFilterSet, CorrespondentFilterSet,
@ -51,7 +54,9 @@ from .serialisers import (
DocumentTypeSerializer, DocumentTypeSerializer,
PostDocumentSerializer, PostDocumentSerializer,
SavedViewSerializer, SavedViewSerializer,
BulkEditSerializer, SelectionDataSerializer BulkEditSerializer,
DocumentListSerializer,
BulkDownloadSerializer
) )
@ -444,7 +449,7 @@ class PostDocumentView(APIView):
class SelectionDataView(APIView): class SelectionDataView(APIView):
permission_classes = (IsAuthenticated,) permission_classes = (IsAuthenticated,)
serializer_class = SelectionDataSerializer serializer_class = DocumentListSerializer
parser_classes = (parsers.MultiPartParser, parsers.JSONParser) parser_classes = (parsers.MultiPartParser, parsers.JSONParser)
def get_serializer_context(self): def get_serializer_context(self):
@ -606,3 +611,52 @@ class StatisticsView(APIView):
'documents_total': documents_total, 'documents_total': documents_total,
'documents_inbox': documents_inbox, 'documents_inbox': documents_inbox,
}) })
class BulkDownloadView(APIView):
permission_classes = (IsAuthenticated,)
serializer_class = BulkDownloadSerializer
parser_classes = (parsers.JSONParser,)
def get_serializer_context(self):
return {
'request': self.request,
'format': self.format_kwarg,
'view': self
}
def get_serializer(self, *args, **kwargs):
kwargs['context'] = self.get_serializer_context()
return self.serializer_class(*args, **kwargs)
def post(self, request, format=None):
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
ids = serializer.validated_data.get('documents')
compression = serializer.validated_data.get('compression')
content = serializer.validated_data.get('content')
os.makedirs(settings.SCRATCH_DIR, exist_ok=True)
temp = tempfile.NamedTemporaryFile(dir=settings.SCRATCH_DIR, suffix="-compressed-archive", delete=False)
if content == 'both':
strategy_class = OriginalAndArchiveStrategy
elif content == 'originals':
strategy_class = OriginalsOnlyStrategy
else:
strategy_class = ArchiveOnlyStrategy
with zipfile.ZipFile(temp.name, "w", compression) as zipf:
strategy = strategy_class(zipf)
for id in ids:
doc = Document.objects.get(id=id)
strategy.add_document(doc)
with open(temp.name, "rb") as f:
response = HttpResponse(f, content_type="application/zip")
response["Content-Disposition"] = '{}; filename="{}"'.format(
"attachment", "documents.zip")
return response

@ -23,7 +23,8 @@ from documents.views import (
PostDocumentView, PostDocumentView,
SavedViewViewSet, SavedViewViewSet,
BulkEditView, BulkEditView,
SelectionDataView SelectionDataView,
BulkDownloadView
) )
from paperless.views import FaviconView from paperless.views import FaviconView
@ -63,6 +64,9 @@ urlpatterns = [
re_path(r"^documents/selection_data/", SelectionDataView.as_view(), re_path(r"^documents/selection_data/", SelectionDataView.as_view(),
name="selection_data"), name="selection_data"),
re_path(r"^documents/bulk_download/", BulkDownloadView.as_view(),
name="bulk_download"),
path('token/', views.obtain_auth_token) path('token/', views.obtain_auth_token)
] + api_router.urls)), ] + api_router.urls)),