From 707da938443f29a020e005bf36acbde2ccb87171 Mon Sep 17 00:00:00 2001 From: shamoon <4887959+shamoon@users.noreply.github.com> Date: Sat, 25 Jan 2025 18:13:51 -0800 Subject: [PATCH] Fix: use api version > 7 for new custom field select format --- docs/api.md | 5 ++ src-ui/src/environments/environment.prod.ts | 2 +- src-ui/src/environments/environment.ts | 2 +- src/documents/serialisers.py | 77 ++++++++++++++----- src/documents/tests/test_api_custom_fields.py | 64 ++++++++++++++- src/documents/views.py | 4 + src/paperless/settings.py | 4 +- 7 files changed, 131 insertions(+), 27 deletions(-) diff --git a/docs/api.md b/docs/api.md index 49e63c04b..59f9d7efc 100644 --- a/docs/api.md +++ b/docs/api.md @@ -573,3 +573,8 @@ Initial API version. #### Version 6 - Moved acknowledge tasks endpoint to be under `/api/tasks/acknowledge/`. + +#### Version 7 + +- The format of select type custom fields has changed to return the options + as an array of objects with `id` and `label` fields. diff --git a/src-ui/src/environments/environment.prod.ts b/src-ui/src/environments/environment.prod.ts index 702b584cb..d2108ee86 100644 --- a/src-ui/src/environments/environment.prod.ts +++ b/src-ui/src/environments/environment.prod.ts @@ -3,7 +3,7 @@ const base_url = new URL(document.baseURI) export const environment = { production: true, apiBaseUrl: document.baseURI + 'api/', - apiVersion: '6', + apiVersion: '7', appTitle: 'Paperless-ngx', version: '2.14.5', webSocketHost: window.location.host, diff --git a/src-ui/src/environments/environment.ts b/src-ui/src/environments/environment.ts index 6256f3ae3..2cad64ce0 100644 --- a/src-ui/src/environments/environment.ts +++ b/src-ui/src/environments/environment.ts @@ -5,7 +5,7 @@ export const environment = { production: false, apiBaseUrl: 'http://localhost:8000/api/', - apiVersion: '6', + apiVersion: '7', appTitle: 'Paperless-ngx', version: 'DEVELOPMENT', webSocketHost: 'localhost:8000', diff --git a/src/documents/serialisers.py b/src/documents/serialisers.py index eb1eba8f1..a9746ea2b 100644 --- a/src/documents/serialisers.py +++ b/src/documents/serialisers.py @@ -495,7 +495,27 @@ class StoragePathField(serializers.PrimaryKeyRelatedField): return StoragePath.objects.all() +class ReadWriteSerializerMethodField(serializers.SerializerMethodField): + """ + Based on https://stackoverflow.com/a/62579804 + """ + + def __init__(self, method_name=None, *args, **kwargs): + self.method_name = method_name + kwargs["source"] = "*" + super(serializers.SerializerMethodField, self).__init__(*args, **kwargs) + + def to_internal_value(self, data): + return {self.field_name: data} + + class CustomFieldSerializer(serializers.ModelSerializer): + def __init__(self, *args, **kwargs): + self.api_version = int( + kwargs.pop("api_version", settings.REST_FRAMEWORK["ALLOWED_VERSIONS"][-1]), + ) + super().__init__(*args, **kwargs) + data_type = serializers.ChoiceField( choices=CustomField.FieldDataType, read_only=False, @@ -503,6 +523,8 @@ class CustomFieldSerializer(serializers.ModelSerializer): document_count = serializers.IntegerField(read_only=True) + extra_data = ReadWriteSerializerMethodField(required=False) + class Meta: model = CustomField fields = [ @@ -544,18 +566,39 @@ class CustomFieldSerializer(serializers.ModelSerializer): or "select_options" not in attrs["extra_data"] or not isinstance(attrs["extra_data"]["select_options"], list) or len(attrs["extra_data"]["select_options"]) == 0 - or not all( - len(option.get("label", "")) > 0 - for option in attrs["extra_data"]["select_options"] + or ( + # version 6 and below require a list of strings + self.api_version < 7 + and not all( + len(option) > 0 + for option in attrs["extra_data"]["select_options"] + ) + ) + or ( + # version 7 and above require a list of objects with labels + self.api_version >= 7 + and not all( + len(option.get("label", "")) > 0 + for option in attrs["extra_data"]["select_options"] + ) ) ): raise serializers.ValidationError( {"error": "extra_data.select_options must be a valid list"}, ) # labels are valid, generate ids if not present - for option in attrs["extra_data"]["select_options"]: - if option.get("id") is None: - option["id"] = get_random_string(length=16) + if self.api_version < 7: + attrs["extra_data"]["select_options"] = [ + { + "label": option, + "id": get_random_string(length=16), + } + for option in attrs["extra_data"]["select_options"] + ] + else: + for option in attrs["extra_data"]["select_options"]: + if option.get("id") is None: + option["id"] = get_random_string(length=16) elif ( "data_type" in attrs and attrs["data_type"] == CustomField.FieldDataType.MONETARY @@ -575,19 +618,15 @@ class CustomFieldSerializer(serializers.ModelSerializer): ) return super().validate(attrs) - -class ReadWriteSerializerMethodField(serializers.SerializerMethodField): - """ - Based on https://stackoverflow.com/a/62579804 - """ - - def __init__(self, method_name=None, *args, **kwargs): - self.method_name = method_name - kwargs["source"] = "*" - super(serializers.SerializerMethodField, self).__init__(*args, **kwargs) - - def to_internal_value(self, data): - return {self.field_name: data} + def get_extra_data(self, obj): + extra_data = obj.extra_data + if self.api_version < 7 and obj.data_type == CustomField.FieldDataType.SELECT: + # Convert the select options with ids to a list of strings + extra_data["select_options"] = [ + option["label"] for option in extra_data["select_options"] + ] + field = serializers.JSONField() + return field.to_representation(extra_data) class CustomFieldInstanceSerializer(serializers.ModelSerializer): diff --git a/src/documents/tests/test_api_custom_fields.py b/src/documents/tests/test_api_custom_fields.py index 8c809429f..cd15de006 100644 --- a/src/documents/tests/test_api_custom_fields.py +++ b/src/documents/tests/test_api_custom_fields.py @@ -43,10 +43,13 @@ class TestCustomFieldsAPI(DirectoriesMixin, APITestCase): ]: resp = self.client.post( self.ENDPOINT, - data={ - "data_type": field_type, - "name": name, - }, + data=json.dumps( + { + "data_type": field_type, + "name": name, + }, + ), + content_type="application/json", ) self.assertEqual(resp.status_code, status.HTTP_201_CREATED) @@ -272,6 +275,59 @@ class TestCustomFieldsAPI(DirectoriesMixin, APITestCase): doc.refresh_from_db() self.assertEqual(doc.custom_fields.first().value, None) + def test_custom_field_select_old_version(self): + """ + GIVEN: + - Select custom field exists with old version of select options + WHEN: + - API post request is made for custom fields with api version header < 7 + - API get request is made for custom fields with api version header < 7 + THEN: + - The select options are returned in the old format + """ + resp = self.client.post( + self.ENDPOINT, + headers={"Accept": "application/json; version=6"}, + data=json.dumps( + { + "data_type": "select", + "name": "Select Field", + "extra_data": { + "select_options": [ + "Option 1", + "Option 2", + ], + }, + }, + ), + content_type="application/json", + ) + self.assertEqual(resp.status_code, status.HTTP_201_CREATED) + + field = CustomField.objects.get(name="Select Field") + self.assertEqual( + field.extra_data["select_options"], + [ + {"label": "Option 1", "id": ANY}, + {"label": "Option 2", "id": ANY}, + ], + ) + + resp = self.client.get( + f"{self.ENDPOINT}{field.id}/", + headers={"Accept": "application/json; version=6"}, + ) + self.assertEqual(resp.status_code, status.HTTP_200_OK) + + data = resp.json() + self.assertEqual( + data["extra_data"]["select_options"], + [ + "Option 1", + "Option 2", + ], + ) + def test_create_custom_field_monetary_validation(self): """ GIVEN: diff --git a/src/documents/views.py b/src/documents/views.py index f98932a6f..a4e0e0f63 100644 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -2065,6 +2065,10 @@ class CustomFieldViewSet(ModelViewSet): ) ) + def get_serializer(self, *args, **kwargs): + kwargs.setdefault("api_version", self.request.version) + return super().get_serializer(*args, **kwargs) + class SystemStatusView(PassUserMixin): permission_classes = (IsAuthenticated,) diff --git a/src/paperless/settings.py b/src/paperless/settings.py index ef842dde6..dcfdc020d 100644 --- a/src/paperless/settings.py +++ b/src/paperless/settings.py @@ -341,10 +341,10 @@ REST_FRAMEWORK = { "rest_framework.authentication.SessionAuthentication", ], "DEFAULT_VERSIONING_CLASS": "rest_framework.versioning.AcceptHeaderVersioning", - "DEFAULT_VERSION": "1", + "DEFAULT_VERSION": "7", # Make sure these are ordered and that the most recent version appears # last - "ALLOWED_VERSIONS": ["1", "2", "3", "4", "5", "6"], + "ALLOWED_VERSIONS": ["1", "2", "3", "4", "5", "6", "7"], } if DEBUG: