From a7a4ca86bdb0cc8344a2bc71e84d11f57bf1df37 Mon Sep 17 00:00:00 2001 From: shamoon <4887959+shamoon@users.noreply.github.com> Date: Tue, 10 Dec 2024 12:19:38 -0800 Subject: [PATCH] Initial shot at backend custom field sorting --- src/documents/filters.py | 162 ++++++++++++++++++++++ src/documents/tests/test_api_documents.py | 139 +++++++++++++++++++ src/documents/views.py | 4 +- 3 files changed, 304 insertions(+), 1 deletion(-) diff --git a/src/documents/filters.py b/src/documents/filters.py index 237973b6f..b171d0127 100644 --- a/src/documents/filters.py +++ b/src/documents/filters.py @@ -6,10 +6,16 @@ from collections.abc import Callable from contextlib import contextmanager from django.contrib.contenttypes.models import ContentType +from django.db.models import Case from django.db.models import CharField from django.db.models import Count +from django.db.models import DateTimeField +from django.db.models import FloatField +from django.db.models import IntegerField from django.db.models import OuterRef from django.db.models import Q +from django.db.models import Value +from django.db.models import When from django.db.models.functions import Cast from django.utils.translation import gettext_lazy as _ from django_filters.rest_framework import BooleanFilter @@ -18,6 +24,7 @@ from django_filters.rest_framework import FilterSet from guardian.utils import get_group_obj_perms_model from guardian.utils import get_user_obj_perms_model from rest_framework import serializers +from rest_framework.filters import OrderingFilter from rest_framework_guardian.filters import ObjectPermissionsFilter from documents.models import Correspondent @@ -760,3 +767,158 @@ class ObjectOwnedPermissionsFilter(ObjectPermissionsFilter): objects_owned = queryset.filter(owner=request.user) objects_unowned = queryset.filter(owner__isnull=True) return objects_owned | objects_unowned + + +class DocumentsOrderingFilter(OrderingFilter): + field_name = "ordering" + prefix = "custom_field_" + + def __init__(self, *args, **kwargs): + super().__init__() + + def filter_queryset(self, request, queryset, view): + param = request.query_params.get("ordering") + if param and self.prefix in param: + custom_field_id = int(param.split(self.prefix)[1]) + field = CustomField.objects.get(pk=custom_field_id) + if not field: + raise ValueError("Custom field not found") + + annotation = None + match field.data_type: + case CustomField.FieldDataType.STRING: + annotation = Case( + When( + custom_fields__field_id=custom_field_id, + then=Cast( + "custom_fields__value_text", + output_field=CharField(), + ), + ), + default=Value(""), + output_field=CharField(), + ) + case CustomField.FieldDataType.INT: + annotation = Case( + When( + custom_fields__field_id=custom_field_id, + then=Cast( + "custom_fields__value_int", + output_field=IntegerField(), + ), + ), + default=Value(0), + output_field=IntegerField(), + ) + case CustomField.FieldDataType.FLOAT: + annotation = Case( + When( + custom_fields__field_id=custom_field_id, + then=Cast( + "custom_fields__value_float", + output_field=FloatField(), + ), + ), + default=Value(0), + output_field=FloatField(), + ) + case CustomField.FieldDataType.DATE: + annotation = Case( + When( + custom_fields__field_id=custom_field_id, + then=Cast( + "custom_fields__value_date", + output_field=DateTimeField(), + ), + ), + default=Value("1900-01-01T00:00:00Z"), + output_field=DateTimeField(), + ) + case CustomField.FieldDataType.MONETARY: + annotation = Case( + When( + custom_fields__field_id=custom_field_id, + then=Cast( + "custom_fields__value_monetary_amount", + output_field=FloatField(), + ), + ), + default=Value(0), + output_field=FloatField(), + ) + case CustomField.FieldDataType.SELECT: + annotation = Case( + When( + custom_fields__field_id=custom_field_id, + then=Cast( + "custom_fields__value_select_name", + output_field=CharField(), + ), + ), + default=Value(""), + output_field=CharField(), + ) + case CustomField.FieldDataType.DOCUMENTLINK: + annotation = Case( + When( + custom_fields__field_id=custom_field_id, + then=Cast( + "custom_fields__value_document_ids", + output_field=CharField(), + ), + ), + default=Value(""), + output_field=CharField(), + ) + case CustomField.FieldDataType.URL: + annotation = Case( + When( + custom_fields__field_id=custom_field_id, + then=Cast( + "custom_fields__value_url", + output_field=CharField(), + ), + ), + default=Value(""), + output_field=CharField(), + ) + + case CustomField.FieldDataType.BOOL: + annotation = Case( + When( + custom_fields__field_id=custom_field_id, + then=Cast( + "custom_fields__value_bool", + output_field=IntegerField(), + ), + ), + default=Value(0), + output_field=IntegerField(), + ) + + if not annotation: + raise ValueError("Invalid custom field data type") + + queryset = queryset.annotate( + # We need to annotate the queryset with the custom field value + custom_field_value=annotation, + # We also need to annotate the queryset with a boolean for sorting whether the field exists + has_field=Case( + When( + custom_fields__field_id=custom_field_id, + then=Value(1), + ), + default=Value(0), + output_field=IntegerField(), + ), + ) + + return queryset.order_by( + "-has_field", + param.replace( + self.prefix + str(custom_field_id), + "custom_field_value", + ), + ) + + return super().filter_queryset(request, queryset, view) diff --git a/src/documents/tests/test_api_documents.py b/src/documents/tests/test_api_documents.py index 8307d6c4c..70daa5d17 100644 --- a/src/documents/tests/test_api_documents.py +++ b/src/documents/tests/test_api_documents.py @@ -5,6 +5,7 @@ import tempfile import uuid import zoneinfo from binascii import hexlify +from datetime import date from datetime import timedelta from pathlib import Path from unittest import mock @@ -2762,3 +2763,141 @@ class TestDocumentApiV2(DirectoriesMixin, APITestCase): self.client.get(f"/api/tags/{t.id}/", format="json").data["text_color"], "#000000", ) + + +class TestDocumentApiCustomFieldsSorting(DirectoriesMixin, APITestCase): + def setUp(self): + super().setUp() + + self.user = User.objects.create_superuser(username="temp_admin") + self.client.force_authenticate(user=self.user) + + self.doc1 = Document.objects.create( + title="none1", + checksum="A", + mime_type="application/pdf", + ) + self.doc2 = Document.objects.create( + title="none2", + checksum="B", + mime_type="application/pdf", + ) + self.doc3 = Document.objects.create( + title="none3", + checksum="C", + mime_type="application/pdf", + ) + + cache.clear() + + def test_document_custom_fields_sorting(self): + """ + GIVEN: + - Documents with custom fields + WHEN: + - API request for document filtering with custom field sorting + THEN: + - Documents are sorted by custom field values + """ + values = { + CustomField.FieldDataType.STRING: { + "values": ["foo", "bar", "baz"], + "field_name": CustomFieldInstance.TYPE_TO_DATA_STORE_NAME_MAP[ + CustomField.FieldDataType.STRING + ], + }, + CustomField.FieldDataType.INT: { + "values": [1, 2, 3], + "field_name": CustomFieldInstance.TYPE_TO_DATA_STORE_NAME_MAP[ + CustomField.FieldDataType.INT + ], + }, + CustomField.FieldDataType.FLOAT: { + "values": [1.1, 2.2, 3.3], + "field_name": CustomFieldInstance.TYPE_TO_DATA_STORE_NAME_MAP[ + CustomField.FieldDataType.FLOAT + ], + }, + CustomField.FieldDataType.BOOL: { + "values": [True, False, False], + "field_name": CustomFieldInstance.TYPE_TO_DATA_STORE_NAME_MAP[ + CustomField.FieldDataType.BOOL + ], + }, + CustomField.FieldDataType.DATE: { + "values": [date(2021, 1, 1), date(2021, 1, 2), date(2021, 1, 3)], + "field_name": CustomFieldInstance.TYPE_TO_DATA_STORE_NAME_MAP[ + CustomField.FieldDataType.DATE + ], + }, + CustomField.FieldDataType.URL: { + "values": [ + "http://example.com", + "http://example.net", + "http://example.org", + ], + "field_name": CustomFieldInstance.TYPE_TO_DATA_STORE_NAME_MAP[ + CustomField.FieldDataType.URL + ], + }, + CustomField.FieldDataType.MONETARY: { + "values": ["USD123.00", "USD456.00", "USD789.00"], + "field_name": CustomFieldInstance.TYPE_TO_DATA_STORE_NAME_MAP[ + CustomField.FieldDataType.MONETARY + ], + }, + CustomField.FieldDataType.DOCUMENTLINK: { + "values": [self.doc1.pk, self.doc2.pk, self.doc3.pk], + "field_name": CustomFieldInstance.TYPE_TO_DATA_STORE_NAME_MAP[ + CustomField.FieldDataType.DOCUMENTLINK + ], + }, + CustomField.FieldDataType.SELECT: { + "values": ["abc-123", "def-456", "ghi-789"], + "field_name": CustomFieldInstance.TYPE_TO_DATA_STORE_NAME_MAP[ + CustomField.FieldDataType.SELECT + ], + "extra_data": { + "select_options": [ + {"label": "Option 1", "id": "abc-123"}, + {"label": "Option 2", "id": "def-456"}, + {"label": "Option 3", "id": "ghi-789"}, + ], + }, + }, + } + + for data_type, data in values.items(): + custom_field = CustomField.objects.create( + name=f"custom field {data_type}", + data_type=data_type, + extra_data=data.get("extra_data", {}), + ) + for i, value in enumerate(data["values"]): + CustomFieldInstance.objects.create( + document=[self.doc1, self.doc2, self.doc3][i], + field=custom_field, + **{data["field_name"]: value}, + ) + + response = self.client.get( + f"/api/documents/?ordering=custom_fields__{custom_field.pk}", + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + results = response.data["results"] + self.assertEqual(len(results), 3) + self.assertEqual( + [results[0]["id"], results[1]["id"], results[2]["id"]], + [self.doc3.id, self.doc2.id, self.doc1.id], + ) + + response = self.client.get( + f"/api/documents/?ordering=-custom_fields__{custom_field.pk}", + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + results = response.data["results"] + self.assertEqual(len(results), 3) + self.assertEqual( + [results[0]["id"], results[1]["id"], results[2]["id"]], + [self.doc3.id, self.doc2.id, self.doc1.id], + ) diff --git a/src/documents/views.py b/src/documents/views.py index 6d2c8cbd8..4e2e4a8bf 100644 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -96,6 +96,7 @@ from documents.data_models import DocumentSource from documents.filters import CorrespondentFilterSet from documents.filters import CustomFieldFilterSet from documents.filters import DocumentFilterSet +from documents.filters import DocumentsOrderingFilter from documents.filters import DocumentTypeFilterSet from documents.filters import ObjectOwnedOrGrantedPermissionsFilter from documents.filters import ObjectOwnedPermissionsFilter @@ -350,7 +351,7 @@ class DocumentViewSet( filter_backends = ( DjangoFilterBackend, SearchFilter, - OrderingFilter, + DocumentsOrderingFilter, ObjectOwnedOrGrantedPermissionsFilter, ) filterset_class = DocumentFilterSet @@ -367,6 +368,7 @@ class DocumentViewSet( "num_notes", "owner", "page_count", + "custom_field_", ) def get_queryset(self):