diff --git a/src/documents/serialisers.py b/src/documents/serialisers.py index 9d8552e6f..c686a52f7 100644 --- a/src/documents/serialisers.py +++ b/src/documents/serialisers.py @@ -1668,9 +1668,8 @@ class PostDocumentSerializer(serializers.Serializer): max_value=Document.ARCHIVE_SERIAL_NUMBER_MAX, ) - custom_fields = serializers.PrimaryKeyRelatedField( - many=True, - queryset=CustomField.objects.all(), + # Accept either a list of custom field ids or a dict mapping id -> value + custom_fields = serializers.JSONField( label="Custom fields", write_only=True, required=False, @@ -1682,12 +1681,6 @@ class PostDocumentSerializer(serializers.Serializer): required=False, ) - custom_fields_w_values = serializers.JSONField( - label="Custom fields with values", - write_only=True, - required=False, - ) - def validate_document(self, document): document_data = document.file.read() mime_type = magic.from_buffer(document_data, mime=True) @@ -1733,33 +1726,61 @@ class PostDocumentSerializer(serializers.Serializer): return None def validate_custom_fields(self, custom_fields): - if custom_fields: - return [custom_field.id for custom_field in custom_fields] - else: + if not custom_fields: return None - def validate_custom_fields_w_values(self, custom_fields_w_values): - if custom_fields_w_values: + if isinstance(custom_fields, dict): custom_field_serializer = CustomFieldInstanceSerializer() - for field_id, value in custom_fields_w_values.items(): + normalized = {} + for field_id, value in custom_fields.items(): try: - field = CustomField.objects.get(id=field_id) + field_id_int = int(field_id) + except (TypeError, ValueError): + raise serializers.ValidationError( + _("Custom field id must be an integer: %(id)s") + % {"id": field_id}, + ) + + try: + field = CustomField.objects.get(id=field_id_int) except CustomField.DoesNotExist: raise serializers.ValidationError( _("Custom field with id %(id)s does not exist") - % { - "id": field_id, - }, + % {"id": field_id_int}, ) - # validate the value using the CustomFieldInstanceSerializer + custom_field_serializer.validate( { "field": field, "value": value, }, ) - # Normalize keys to integers for later - return {int(k): v for k, v in custom_fields_w_values.items()} + normalized[field_id_int] = value + + return normalized + + if isinstance(custom_fields, list): + try: + ids = [int(i) for i in custom_fields] + except (TypeError, ValueError): + raise serializers.ValidationError( + _( + "Custom fields must be a list of integers or an object mapping ids to values.", + ), + ) + if CustomField.objects.filter(id__in=ids).count() != len(set(ids)): + raise serializers.ValidationError( + _("Some custom fields don't exist or were specified twice."), + ) + return ids + + raise serializers.ValidationError( + _( + "Custom fields must be a list of integers or an object mapping ids to values.", + ), + ) + + # custom_fields_w_values handled via validate_custom_fields def validate_created(self, created): # support datetime format for created for backwards compatibility diff --git a/src/documents/tests/test_api_documents.py b/src/documents/tests/test_api_documents.py index bab03267d..48d52eca0 100644 --- a/src/documents/tests/test_api_documents.py +++ b/src/documents/tests/test_api_documents.py @@ -1554,7 +1554,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase): "/api/documents/post_document/", { "document": f, - "custom_fields_w_values": json.dumps({"3456": "a string"}), + "custom_fields": json.dumps({"3456": "a string"}), }, ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) @@ -1575,7 +1575,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase): "/api/documents/post_document/", { "document": f, - "custom_fields_w_values": json.dumps( + "custom_fields": json.dumps( { str(cf_string.id): "a string", str(cf_int.id): 123, diff --git a/src/documents/views.py b/src/documents/views.py index 55e401a3a..7baf12c69 100644 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -1497,10 +1497,7 @@ class PostDocumentView(GenericAPIView): title = serializer.validated_data.get("title") created = serializer.validated_data.get("created") archive_serial_number = serializer.validated_data.get("archive_serial_number") - custom_field_ids = serializer.validated_data.get("custom_fields") - custom_fields_w_values = serializer.validated_data.get( - "custom_fields_w_values", - ) + cf = serializer.validated_data.get("custom_fields") from_webui = serializer.validated_data.get("from_webui") t = int(mktime(datetime.now().timetuple())) @@ -1520,12 +1517,10 @@ class PostDocumentView(GenericAPIView): original_file=temp_file_path, ) custom_fields = None - if custom_fields_w_values: - custom_fields = { - cf_id: value for cf_id, value in custom_fields_w_values.items() - } - elif custom_field_ids: - custom_fields = {cf_id: None for cf_id in custom_field_ids} + if isinstance(cf, dict) and cf: + custom_fields = {cf_id: value for cf_id, value in cf.items()} + elif isinstance(cf, list) and cf: + custom_fields = {cf_id: None for cf_id in cf} input_doc_overrides = DocumentMetadataOverrides( filename=doc_name, title=title,