This commit is contained in:
shamoon 2025-01-25 22:27:30 -08:00
parent 1ee7728fc5
commit 2ae70b6482
3 changed files with 99 additions and 82 deletions

View File

@ -495,20 +495,6 @@ class StoragePathField(serializers.PrimaryKeyRelatedField):
return StoragePath.objects.all() return StoragePath.objects.all()
class ReadWriteSerializerMethodField(serializers.SerializerMethodField):
"""
Based on https://stackoverflow.com/a/62579804
"""
def __init__(self, method_name=None, *args, **kwargs):
self.method_name = method_name
kwargs["source"] = "*"
super(serializers.SerializerMethodField, self).__init__(*args, **kwargs)
def to_internal_value(self, data):
return {self.field_name: data}
class CustomFieldSerializer(serializers.ModelSerializer): class CustomFieldSerializer(serializers.ModelSerializer):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
context = kwargs.get("context") context = kwargs.get("context")
@ -526,8 +512,6 @@ class CustomFieldSerializer(serializers.ModelSerializer):
document_count = serializers.IntegerField(read_only=True) document_count = serializers.IntegerField(read_only=True)
extra_data = ReadWriteSerializerMethodField(required=False)
class Meta: class Meta:
model = CustomField model = CustomField
fields = [ fields = [
@ -569,36 +553,15 @@ class CustomFieldSerializer(serializers.ModelSerializer):
or "select_options" not in attrs["extra_data"] or "select_options" not in attrs["extra_data"]
or not isinstance(attrs["extra_data"]["select_options"], list) or not isinstance(attrs["extra_data"]["select_options"], list)
or len(attrs["extra_data"]["select_options"]) == 0 or len(attrs["extra_data"]["select_options"]) == 0
or ( or not all(
# version 6 and below require a list of strings
self.api_version < 7
and not all(
len(option) > 0
for option in attrs["extra_data"]["select_options"]
)
)
or (
# version 7 and above require a list of objects with labels
self.api_version >= 7
and not all(
len(option.get("label", "")) > 0 len(option.get("label", "")) > 0
for option in attrs["extra_data"]["select_options"] for option in attrs["extra_data"]["select_options"]
) )
)
): ):
raise serializers.ValidationError( raise serializers.ValidationError(
{"error": "extra_data.select_options must be a valid list"}, {"error": "extra_data.select_options must be a valid list"},
) )
# labels are valid, generate ids if not present # labels are valid, generate ids if not present
if self.api_version < 7:
attrs["extra_data"]["select_options"] = [
{
"label": option,
"id": get_random_string(length=16),
}
for option in attrs["extra_data"]["select_options"]
]
else:
for option in attrs["extra_data"]["select_options"]: for option in attrs["extra_data"]["select_options"]:
if option.get("id") is None: if option.get("id") is None:
option["id"] = get_random_string(length=16) option["id"] = get_random_string(length=16)
@ -621,14 +584,51 @@ class CustomFieldSerializer(serializers.ModelSerializer):
) )
return super().validate(attrs) return super().validate(attrs)
def get_extra_data(self, obj): def to_internal_value(self, data):
extra_data = obj.extra_data ret = super().to_internal_value(data)
if self.api_version < 7 and obj.data_type == CustomField.FieldDataType.SELECT:
# Convert the select options with ids to a list of strings if (
extra_data["select_options"] = [ self.api_version < 7
option["label"] for option in extra_data["select_options"] and ret.get("data_type", "") == CustomField.FieldDataType.SELECT
and isinstance(ret.get("extra_data", {}).get("select_options"), list)
):
ret["extra_data"]["select_options"] = [
{
"label": option,
"id": get_random_string(length=16),
}
for option in ret["extra_data"]["select_options"]
] ]
return serializers.JSONField().to_representation(extra_data)
return ret
def to_representation(self, instance):
ret = super().to_representation(instance)
if (
self.api_version < 7
and instance.data_type == CustomField.FieldDataType.SELECT
):
# Convert the select options with ids to a list of strings
ret["extra_data"]["select_options"] = [
option["label"] for option in ret["extra_data"]["select_options"]
]
return ret
class ReadWriteSerializerMethodField(serializers.SerializerMethodField):
"""
Based on https://stackoverflow.com/a/62579804
"""
def __init__(self, method_name=None, *args, **kwargs):
self.method_name = method_name
kwargs["source"] = "*"
super(serializers.SerializerMethodField, self).__init__(*args, **kwargs)
def to_internal_value(self, data):
return {self.field_name: data}
class CustomFieldInstanceSerializer(serializers.ModelSerializer): class CustomFieldInstanceSerializer(serializers.ModelSerializer):
@ -659,21 +659,6 @@ class CustomFieldInstanceSerializer(serializers.ModelSerializer):
return instance return instance
def get_value(self, obj: CustomFieldInstance): def get_value(self, obj: CustomFieldInstance):
api_version = int(
self.context.get("request").version
if self.context.get("request")
else settings.REST_FRAMEWORK["DEFAULT_VERSION"],
)
if api_version < 7 and obj.field.data_type == CustomField.FieldDataType.SELECT:
# return the index of the option in the field.extra_data["select_options"] list
return next(
(
idx
for idx, option in enumerate(obj.field.extra_data["select_options"])
if option["id"] == obj.value
),
None,
)
return obj.value return obj.value
def validate(self, data): def validate(self, data):
@ -686,11 +671,6 @@ class CustomFieldInstanceSerializer(serializers.ModelSerializer):
""" """
data = super().validate(data) data = super().validate(data)
field: CustomField = data["field"] field: CustomField = data["field"]
api_version = int(
self.context.get("request").version
if self.context.get("request")
else settings.REST_FRAMEWORK["DEFAULT_VERSION"],
)
if "value" in data and data["value"] is not None: if "value" in data and data["value"] is not None:
if ( if (
@ -720,13 +700,6 @@ class CustomFieldInstanceSerializer(serializers.ModelSerializer):
elif field.data_type == CustomField.FieldDataType.SELECT: elif field.data_type == CustomField.FieldDataType.SELECT:
select_options = field.extra_data["select_options"] select_options = field.extra_data["select_options"]
if api_version < 7:
# Convert the index of the option in the field.extra_data["select_options"]
# list to the options unique id
data["value"] = field.extra_data["select_options"][data["value"]][
"id"
]
try: try:
next( next(
option option
@ -752,6 +725,53 @@ class CustomFieldInstanceSerializer(serializers.ModelSerializer):
return data return data
def to_internal_value(self, data):
ret = super().to_internal_value(data)
api_version = int(
self.context.get("request").version
if self.context.get("request")
else settings.REST_FRAMEWORK["DEFAULT_VERSION"],
)
if (
api_version < 7
and ret.get("field").data_type == CustomField.FieldDataType.SELECT
and ret.get("value") is not None
):
# Convert the index of the option in the field.extra_data["select_options"]
# list to the options unique id
ret["value"] = ret.get("field").extra_data["select_options"][ret["value"]][
"id"
]
return ret
def to_representation(self, instance):
ret = super().to_representation(instance)
api_version = int(
self.context.get("request").version
if self.context.get("request")
else settings.REST_FRAMEWORK["DEFAULT_VERSION"],
)
if (
api_version < 7
and instance.field.data_type == CustomField.FieldDataType.SELECT
):
# return the index of the option in the field.extra_data["select_options"] list
ret["value"] = next(
(
idx
for idx, option in enumerate(
instance.field.extra_data["select_options"],
)
if option["id"] == instance.value
),
None,
)
return ret
def reflect_doclinks( def reflect_doclinks(
self, self,
document: Document, document: Document,

View File

@ -249,7 +249,8 @@ class TestCustomFieldsAPI(DirectoriesMixin, APITestCase):
resp = self.client.patch( resp = self.client.patch(
f"{self.ENDPOINT}{custom_field_select.id}/", f"{self.ENDPOINT}{custom_field_select.id}/",
json.dumps( headers={"Accept": "application/json; version=7"},
data=json.dumps(
{ {
"extra_data": { "extra_data": {
"select_options": [ "select_options": [

View File

@ -2065,10 +2065,6 @@ class CustomFieldViewSet(ModelViewSet):
) )
) )
def get_serializer(self, *args, **kwargs):
kwargs.setdefault("context", self.get_serializer_context())
return super().get_serializer(*args, **kwargs)
class SystemStatusView(PassUserMixin): class SystemStatusView(PassUserMixin):
permission_classes = (IsAuthenticated,) permission_classes = (IsAuthenticated,)