diff --git a/src/documents/serialisers.py b/src/documents/serialisers.py index 69cbb4092..5418ec0fb 100644 --- a/src/documents/serialisers.py +++ b/src/documents/serialisers.py @@ -3,6 +3,7 @@ from django.utils.text import slugify from rest_framework import serializers from rest_framework.fields import SerializerMethodField +from . import bulk_edit from .models import Correspondent, Tag, Document, Log, DocumentType from .parsers import is_mime_type_supported @@ -164,11 +165,10 @@ class LogSerializer(serializers.ModelSerializer): class BulkEditSerializer(serializers.Serializer): - documents = serializers.PrimaryKeyRelatedField( - many=True, + documents = serializers.ListField( + child=serializers.IntegerField(), label="Documents", - write_only=True, - queryset=Document.objects.all() + write_only=True ) method = serializers.ChoiceField( @@ -185,6 +185,20 @@ class BulkEditSerializer(serializers.Serializer): parameters = serializers.DictField(allow_empty=True) + def validate_method(self, method): + if method == "set_correspondent": + return bulk_edit.set_correspondent + elif method == "set_document_type": + return bulk_edit.set_document_type + elif method == "add_tag": + return bulk_edit.add_tag + elif method == "remove_tag": + return bulk_edit.remove_tag + elif method == "delete": + return bulk_edit.delete + else: + raise serializers.ValidationError("Unsupported method.") + def validate(self, attrs): return attrs diff --git a/src/documents/views.py b/src/documents/views.py index 10cb30eb3..4ce78348e 100755 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -278,6 +278,17 @@ class BulkEditView(APIView): serializer = self.get_serializer(data=request.data) serializer.is_valid(raise_exception=True) + method = serializer.validated_data.get("method") + parameters = serializer.validated_data.get("parameters") + documents = serializer.validated_data.get("documents") + + try: + # TODO: parameter validation + result = method(documents, **parameters) + return Response({"result": result}) + except Exception as e: + return HttpResponseBadRequest(str(e)) + class PostDocumentView(APIView): diff --git a/src/paperless/urls.py b/src/paperless/urls.py index 9b390b139..dc416f05f 100755 --- a/src/paperless/urls.py +++ b/src/paperless/urls.py @@ -17,7 +17,8 @@ from documents.views import ( IndexView, SearchAutoCompleteView, StatisticsView, - PostDocumentView + PostDocumentView, + BulkEditView ) from paperless.views import FaviconView @@ -50,6 +51,10 @@ urlpatterns = [ re_path(r"^documents/post_document/", PostDocumentView.as_view(), name="post_document"), + + re_path(r"^documents/bulk_edit/", BulkEditView.as_view(), + name="bulk_edit"), + path('token/', views.obtain_auth_token) ] + api_router.urls)),