From 3115106dc1aab3f9f2995abba313d39686815b65 Mon Sep 17 00:00:00 2001 From: shamoon <4887959+shamoon@users.noreply.github.com> Date: Fri, 5 Jan 2024 19:04:31 -0800 Subject: [PATCH] Enhancement: add basic filters for listing custom fields (#5257) --- src/documents/filters.py | 10 ++++++ src/documents/tests/test_api_custom_fields.py | 36 +++++++++++++++++++ src/documents/views.py | 6 ++++ 3 files changed, 52 insertions(+) diff --git a/src/documents/filters.py b/src/documents/filters.py index c63484ee2..bab20a4dc 100644 --- a/src/documents/filters.py +++ b/src/documents/filters.py @@ -12,6 +12,7 @@ from guardian.utils import get_user_obj_perms_model from rest_framework_guardian.filters import ObjectPermissionsFilter from documents.models import Correspondent +from documents.models import CustomField from documents.models import Document from documents.models import DocumentType from documents.models import Log @@ -141,6 +142,15 @@ class SharedByUser(Filter): ) +class CustomFieldFilterSet(FilterSet): + class Meta: + model = CustomField + fields = { + "id": ID_KWARGS, + "name": CHAR_KWARGS, + } + + class CustomFieldsFilter(Filter): def filter(self, qs, value): if value: diff --git a/src/documents/tests/test_api_custom_fields.py b/src/documents/tests/test_api_custom_fields.py index af16d12b1..cf33e2800 100644 --- a/src/documents/tests/test_api_custom_fields.py +++ b/src/documents/tests/test_api_custom_fields.py @@ -662,3 +662,39 @@ class TestCustomField(DirectoriesMixin, APITestCase): self.assertEqual(resp.status_code, status.HTTP_200_OK) self.assertEqual(doc5.custom_fields.first().value, [1]) + + def test_custom_field_filters(self): + custom_field_string = CustomField.objects.create( + name="Test Custom Field String", + data_type=CustomField.FieldDataType.STRING, + ) + custom_field_date = CustomField.objects.create( + name="Test Custom Field Date", + data_type=CustomField.FieldDataType.DATE, + ) + custom_field_int = CustomField.objects.create( + name="Test Custom Field Int", + data_type=CustomField.FieldDataType.INT, + ) + + response = self.client.get( + f"{self.ENDPOINT}?id={custom_field_string.id}", + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + results = response.data["results"] + self.assertEqual(len(results), 1) + + response = self.client.get( + f"{self.ENDPOINT}?id__in={custom_field_string.id},{custom_field_date.id}", + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + results = response.data["results"] + self.assertEqual(len(results), 2) + + response = self.client.get( + f"{self.ENDPOINT}?name__icontains=Int", + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + results = response.data["results"] + self.assertEqual(len(results), 1) + self.assertEqual(results[0]["name"], custom_field_int.name) diff --git a/src/documents/views.py b/src/documents/views.py index 83f2fc321..d6b90cbfd 100644 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -66,6 +66,7 @@ from documents.data_models import ConsumableDocument from documents.data_models import DocumentMetadataOverrides from documents.data_models import DocumentSource from documents.filters import CorrespondentFilterSet +from documents.filters import CustomFieldFilterSet from documents.filters import DocumentFilterSet from documents.filters import DocumentTypeFilterSet from documents.filters import ObjectOwnedOrGrantedPermissionsFilter @@ -1438,6 +1439,11 @@ class CustomFieldViewSet(ModelViewSet): serializer_class = CustomFieldSerializer pagination_class = StandardPagination + filter_backends = ( + DjangoFilterBackend, + OrderingFilter, + ) + filterset_class = CustomFieldFilterSet model = CustomField