diff --git a/src/documents/serialisers.py b/src/documents/serialisers.py index d163e769d..dda834094 100644 --- a/src/documents/serialisers.py +++ b/src/documents/serialisers.py @@ -495,20 +495,6 @@ class StoragePathField(serializers.PrimaryKeyRelatedField): return StoragePath.objects.all() -class ReadWriteSerializerMethodField(serializers.SerializerMethodField): - """ - Based on https://stackoverflow.com/a/62579804 - """ - - def __init__(self, method_name=None, *args, **kwargs): - self.method_name = method_name - kwargs["source"] = "*" - super(serializers.SerializerMethodField, self).__init__(*args, **kwargs) - - def to_internal_value(self, data): - return {self.field_name: data} - - class CustomFieldSerializer(serializers.ModelSerializer): def __init__(self, *args, **kwargs): context = kwargs.get("context") @@ -526,8 +512,6 @@ class CustomFieldSerializer(serializers.ModelSerializer): document_count = serializers.IntegerField(read_only=True) - extra_data = ReadWriteSerializerMethodField(required=False) - class Meta: model = CustomField fields = [ @@ -569,39 +553,18 @@ class CustomFieldSerializer(serializers.ModelSerializer): or "select_options" not in attrs["extra_data"] or not isinstance(attrs["extra_data"]["select_options"], list) or len(attrs["extra_data"]["select_options"]) == 0 - or ( - # version 6 and below require a list of strings - self.api_version < 7 - and not all( - len(option) > 0 - for option in attrs["extra_data"]["select_options"] - ) - ) - or ( - # version 7 and above require a list of objects with labels - self.api_version >= 7 - and not all( - len(option.get("label", "")) > 0 - for option in attrs["extra_data"]["select_options"] - ) + or not all( + len(option.get("label", "")) > 0 + for option in attrs["extra_data"]["select_options"] ) ): raise serializers.ValidationError( {"error": "extra_data.select_options must be a valid list"}, ) # labels are valid, generate ids if not present - if self.api_version < 7: - attrs["extra_data"]["select_options"] = [ - { - "label": option, - "id": get_random_string(length=16), - } - for option in attrs["extra_data"]["select_options"] - ] - else: - for option in attrs["extra_data"]["select_options"]: - if option.get("id") is None: - option["id"] = get_random_string(length=16) + for option in attrs["extra_data"]["select_options"]: + if option.get("id") is None: + option["id"] = get_random_string(length=16) elif ( "data_type" in attrs and attrs["data_type"] == CustomField.FieldDataType.MONETARY @@ -621,14 +584,51 @@ class CustomFieldSerializer(serializers.ModelSerializer): ) return super().validate(attrs) - def get_extra_data(self, obj): - extra_data = obj.extra_data - if self.api_version < 7 and obj.data_type == CustomField.FieldDataType.SELECT: - # Convert the select options with ids to a list of strings - extra_data["select_options"] = [ - option["label"] for option in extra_data["select_options"] + def to_internal_value(self, data): + ret = super().to_internal_value(data) + + if ( + self.api_version < 7 + and ret.get("data_type", "") == CustomField.FieldDataType.SELECT + and isinstance(ret.get("extra_data", {}).get("select_options"), list) + ): + ret["extra_data"]["select_options"] = [ + { + "label": option, + "id": get_random_string(length=16), + } + for option in ret["extra_data"]["select_options"] ] - return serializers.JSONField().to_representation(extra_data) + + return ret + + def to_representation(self, instance): + ret = super().to_representation(instance) + + if ( + self.api_version < 7 + and instance.data_type == CustomField.FieldDataType.SELECT + ): + # Convert the select options with ids to a list of strings + ret["extra_data"]["select_options"] = [ + option["label"] for option in ret["extra_data"]["select_options"] + ] + + return ret + + +class ReadWriteSerializerMethodField(serializers.SerializerMethodField): + """ + Based on https://stackoverflow.com/a/62579804 + """ + + def __init__(self, method_name=None, *args, **kwargs): + self.method_name = method_name + kwargs["source"] = "*" + super(serializers.SerializerMethodField, self).__init__(*args, **kwargs) + + def to_internal_value(self, data): + return {self.field_name: data} class CustomFieldInstanceSerializer(serializers.ModelSerializer): @@ -659,21 +659,6 @@ class CustomFieldInstanceSerializer(serializers.ModelSerializer): return instance def get_value(self, obj: CustomFieldInstance): - api_version = int( - self.context.get("request").version - if self.context.get("request") - else settings.REST_FRAMEWORK["DEFAULT_VERSION"], - ) - if api_version < 7 and obj.field.data_type == CustomField.FieldDataType.SELECT: - # return the index of the option in the field.extra_data["select_options"] list - return next( - ( - idx - for idx, option in enumerate(obj.field.extra_data["select_options"]) - if option["id"] == obj.value - ), - None, - ) return obj.value def validate(self, data): @@ -686,11 +671,6 @@ class CustomFieldInstanceSerializer(serializers.ModelSerializer): """ data = super().validate(data) field: CustomField = data["field"] - api_version = int( - self.context.get("request").version - if self.context.get("request") - else settings.REST_FRAMEWORK["DEFAULT_VERSION"], - ) if "value" in data and data["value"] is not None: if ( @@ -720,13 +700,6 @@ class CustomFieldInstanceSerializer(serializers.ModelSerializer): elif field.data_type == CustomField.FieldDataType.SELECT: select_options = field.extra_data["select_options"] - if api_version < 7: - # Convert the index of the option in the field.extra_data["select_options"] - # list to the options unique id - data["value"] = field.extra_data["select_options"][data["value"]][ - "id" - ] - try: next( option @@ -752,6 +725,53 @@ class CustomFieldInstanceSerializer(serializers.ModelSerializer): return data + def to_internal_value(self, data): + ret = super().to_internal_value(data) + + api_version = int( + self.context.get("request").version + if self.context.get("request") + else settings.REST_FRAMEWORK["DEFAULT_VERSION"], + ) + if ( + api_version < 7 + and ret.get("field").data_type == CustomField.FieldDataType.SELECT + and ret.get("value") is not None + ): + # Convert the index of the option in the field.extra_data["select_options"] + # list to the options unique id + ret["value"] = ret.get("field").extra_data["select_options"][ret["value"]][ + "id" + ] + + return ret + + def to_representation(self, instance): + ret = super().to_representation(instance) + + api_version = int( + self.context.get("request").version + if self.context.get("request") + else settings.REST_FRAMEWORK["DEFAULT_VERSION"], + ) + if ( + api_version < 7 + and instance.field.data_type == CustomField.FieldDataType.SELECT + ): + # return the index of the option in the field.extra_data["select_options"] list + ret["value"] = next( + ( + idx + for idx, option in enumerate( + instance.field.extra_data["select_options"], + ) + if option["id"] == instance.value + ), + None, + ) + + return ret + def reflect_doclinks( self, document: Document, diff --git a/src/documents/tests/test_api_custom_fields.py b/src/documents/tests/test_api_custom_fields.py index 8cb42aa8d..8a9e32174 100644 --- a/src/documents/tests/test_api_custom_fields.py +++ b/src/documents/tests/test_api_custom_fields.py @@ -249,7 +249,8 @@ class TestCustomFieldsAPI(DirectoriesMixin, APITestCase): resp = self.client.patch( f"{self.ENDPOINT}{custom_field_select.id}/", - json.dumps( + headers={"Accept": "application/json; version=7"}, + data=json.dumps( { "extra_data": { "select_options": [ diff --git a/src/documents/views.py b/src/documents/views.py index 2868c1e4e..f98932a6f 100644 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -2065,10 +2065,6 @@ class CustomFieldViewSet(ModelViewSet): ) ) - def get_serializer(self, *args, **kwargs): - kwargs.setdefault("context", self.get_serializer_context()) - return super().get_serializer(*args, **kwargs) - class SystemStatusView(PassUserMixin): permission_classes = (IsAuthenticated,)