Initial shot at backend custom field sorting

This commit is contained in:
shamoon 2024-12-10 12:19:38 -08:00
parent e44cfef662
commit a7a4ca86bd
3 changed files with 304 additions and 1 deletions

View File

@ -6,10 +6,16 @@ from collections.abc import Callable
from contextlib import contextmanager from contextlib import contextmanager
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
from django.db.models import Case
from django.db.models import CharField from django.db.models import CharField
from django.db.models import Count from django.db.models import Count
from django.db.models import DateTimeField
from django.db.models import FloatField
from django.db.models import IntegerField
from django.db.models import OuterRef from django.db.models import OuterRef
from django.db.models import Q from django.db.models import Q
from django.db.models import Value
from django.db.models import When
from django.db.models.functions import Cast from django.db.models.functions import Cast
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from django_filters.rest_framework import BooleanFilter from django_filters.rest_framework import BooleanFilter
@ -18,6 +24,7 @@ from django_filters.rest_framework import FilterSet
from guardian.utils import get_group_obj_perms_model from guardian.utils import get_group_obj_perms_model
from guardian.utils import get_user_obj_perms_model from guardian.utils import get_user_obj_perms_model
from rest_framework import serializers from rest_framework import serializers
from rest_framework.filters import OrderingFilter
from rest_framework_guardian.filters import ObjectPermissionsFilter from rest_framework_guardian.filters import ObjectPermissionsFilter
from documents.models import Correspondent from documents.models import Correspondent
@ -760,3 +767,158 @@ class ObjectOwnedPermissionsFilter(ObjectPermissionsFilter):
objects_owned = queryset.filter(owner=request.user) objects_owned = queryset.filter(owner=request.user)
objects_unowned = queryset.filter(owner__isnull=True) objects_unowned = queryset.filter(owner__isnull=True)
return objects_owned | objects_unowned return objects_owned | objects_unowned
class DocumentsOrderingFilter(OrderingFilter):
field_name = "ordering"
prefix = "custom_field_"
def __init__(self, *args, **kwargs):
super().__init__()
def filter_queryset(self, request, queryset, view):
param = request.query_params.get("ordering")
if param and self.prefix in param:
custom_field_id = int(param.split(self.prefix)[1])
field = CustomField.objects.get(pk=custom_field_id)
if not field:
raise ValueError("Custom field not found")
annotation = None
match field.data_type:
case CustomField.FieldDataType.STRING:
annotation = Case(
When(
custom_fields__field_id=custom_field_id,
then=Cast(
"custom_fields__value_text",
output_field=CharField(),
),
),
default=Value(""),
output_field=CharField(),
)
case CustomField.FieldDataType.INT:
annotation = Case(
When(
custom_fields__field_id=custom_field_id,
then=Cast(
"custom_fields__value_int",
output_field=IntegerField(),
),
),
default=Value(0),
output_field=IntegerField(),
)
case CustomField.FieldDataType.FLOAT:
annotation = Case(
When(
custom_fields__field_id=custom_field_id,
then=Cast(
"custom_fields__value_float",
output_field=FloatField(),
),
),
default=Value(0),
output_field=FloatField(),
)
case CustomField.FieldDataType.DATE:
annotation = Case(
When(
custom_fields__field_id=custom_field_id,
then=Cast(
"custom_fields__value_date",
output_field=DateTimeField(),
),
),
default=Value("1900-01-01T00:00:00Z"),
output_field=DateTimeField(),
)
case CustomField.FieldDataType.MONETARY:
annotation = Case(
When(
custom_fields__field_id=custom_field_id,
then=Cast(
"custom_fields__value_monetary_amount",
output_field=FloatField(),
),
),
default=Value(0),
output_field=FloatField(),
)
case CustomField.FieldDataType.SELECT:
annotation = Case(
When(
custom_fields__field_id=custom_field_id,
then=Cast(
"custom_fields__value_select_name",
output_field=CharField(),
),
),
default=Value(""),
output_field=CharField(),
)
case CustomField.FieldDataType.DOCUMENTLINK:
annotation = Case(
When(
custom_fields__field_id=custom_field_id,
then=Cast(
"custom_fields__value_document_ids",
output_field=CharField(),
),
),
default=Value(""),
output_field=CharField(),
)
case CustomField.FieldDataType.URL:
annotation = Case(
When(
custom_fields__field_id=custom_field_id,
then=Cast(
"custom_fields__value_url",
output_field=CharField(),
),
),
default=Value(""),
output_field=CharField(),
)
case CustomField.FieldDataType.BOOL:
annotation = Case(
When(
custom_fields__field_id=custom_field_id,
then=Cast(
"custom_fields__value_bool",
output_field=IntegerField(),
),
),
default=Value(0),
output_field=IntegerField(),
)
if not annotation:
raise ValueError("Invalid custom field data type")
queryset = queryset.annotate(
# We need to annotate the queryset with the custom field value
custom_field_value=annotation,
# We also need to annotate the queryset with a boolean for sorting whether the field exists
has_field=Case(
When(
custom_fields__field_id=custom_field_id,
then=Value(1),
),
default=Value(0),
output_field=IntegerField(),
),
)
return queryset.order_by(
"-has_field",
param.replace(
self.prefix + str(custom_field_id),
"custom_field_value",
),
)
return super().filter_queryset(request, queryset, view)

View File

@ -5,6 +5,7 @@ import tempfile
import uuid import uuid
import zoneinfo import zoneinfo
from binascii import hexlify from binascii import hexlify
from datetime import date
from datetime import timedelta from datetime import timedelta
from pathlib import Path from pathlib import Path
from unittest import mock from unittest import mock
@ -2762,3 +2763,141 @@ class TestDocumentApiV2(DirectoriesMixin, APITestCase):
self.client.get(f"/api/tags/{t.id}/", format="json").data["text_color"], self.client.get(f"/api/tags/{t.id}/", format="json").data["text_color"],
"#000000", "#000000",
) )
class TestDocumentApiCustomFieldsSorting(DirectoriesMixin, APITestCase):
def setUp(self):
super().setUp()
self.user = User.objects.create_superuser(username="temp_admin")
self.client.force_authenticate(user=self.user)
self.doc1 = Document.objects.create(
title="none1",
checksum="A",
mime_type="application/pdf",
)
self.doc2 = Document.objects.create(
title="none2",
checksum="B",
mime_type="application/pdf",
)
self.doc3 = Document.objects.create(
title="none3",
checksum="C",
mime_type="application/pdf",
)
cache.clear()
def test_document_custom_fields_sorting(self):
"""
GIVEN:
- Documents with custom fields
WHEN:
- API request for document filtering with custom field sorting
THEN:
- Documents are sorted by custom field values
"""
values = {
CustomField.FieldDataType.STRING: {
"values": ["foo", "bar", "baz"],
"field_name": CustomFieldInstance.TYPE_TO_DATA_STORE_NAME_MAP[
CustomField.FieldDataType.STRING
],
},
CustomField.FieldDataType.INT: {
"values": [1, 2, 3],
"field_name": CustomFieldInstance.TYPE_TO_DATA_STORE_NAME_MAP[
CustomField.FieldDataType.INT
],
},
CustomField.FieldDataType.FLOAT: {
"values": [1.1, 2.2, 3.3],
"field_name": CustomFieldInstance.TYPE_TO_DATA_STORE_NAME_MAP[
CustomField.FieldDataType.FLOAT
],
},
CustomField.FieldDataType.BOOL: {
"values": [True, False, False],
"field_name": CustomFieldInstance.TYPE_TO_DATA_STORE_NAME_MAP[
CustomField.FieldDataType.BOOL
],
},
CustomField.FieldDataType.DATE: {
"values": [date(2021, 1, 1), date(2021, 1, 2), date(2021, 1, 3)],
"field_name": CustomFieldInstance.TYPE_TO_DATA_STORE_NAME_MAP[
CustomField.FieldDataType.DATE
],
},
CustomField.FieldDataType.URL: {
"values": [
"http://example.com",
"http://example.net",
"http://example.org",
],
"field_name": CustomFieldInstance.TYPE_TO_DATA_STORE_NAME_MAP[
CustomField.FieldDataType.URL
],
},
CustomField.FieldDataType.MONETARY: {
"values": ["USD123.00", "USD456.00", "USD789.00"],
"field_name": CustomFieldInstance.TYPE_TO_DATA_STORE_NAME_MAP[
CustomField.FieldDataType.MONETARY
],
},
CustomField.FieldDataType.DOCUMENTLINK: {
"values": [self.doc1.pk, self.doc2.pk, self.doc3.pk],
"field_name": CustomFieldInstance.TYPE_TO_DATA_STORE_NAME_MAP[
CustomField.FieldDataType.DOCUMENTLINK
],
},
CustomField.FieldDataType.SELECT: {
"values": ["abc-123", "def-456", "ghi-789"],
"field_name": CustomFieldInstance.TYPE_TO_DATA_STORE_NAME_MAP[
CustomField.FieldDataType.SELECT
],
"extra_data": {
"select_options": [
{"label": "Option 1", "id": "abc-123"},
{"label": "Option 2", "id": "def-456"},
{"label": "Option 3", "id": "ghi-789"},
],
},
},
}
for data_type, data in values.items():
custom_field = CustomField.objects.create(
name=f"custom field {data_type}",
data_type=data_type,
extra_data=data.get("extra_data", {}),
)
for i, value in enumerate(data["values"]):
CustomFieldInstance.objects.create(
document=[self.doc1, self.doc2, self.doc3][i],
field=custom_field,
**{data["field_name"]: value},
)
response = self.client.get(
f"/api/documents/?ordering=custom_fields__{custom_field.pk}",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
results = response.data["results"]
self.assertEqual(len(results), 3)
self.assertEqual(
[results[0]["id"], results[1]["id"], results[2]["id"]],
[self.doc3.id, self.doc2.id, self.doc1.id],
)
response = self.client.get(
f"/api/documents/?ordering=-custom_fields__{custom_field.pk}",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
results = response.data["results"]
self.assertEqual(len(results), 3)
self.assertEqual(
[results[0]["id"], results[1]["id"], results[2]["id"]],
[self.doc3.id, self.doc2.id, self.doc1.id],
)

View File

@ -96,6 +96,7 @@ from documents.data_models import DocumentSource
from documents.filters import CorrespondentFilterSet from documents.filters import CorrespondentFilterSet
from documents.filters import CustomFieldFilterSet from documents.filters import CustomFieldFilterSet
from documents.filters import DocumentFilterSet from documents.filters import DocumentFilterSet
from documents.filters import DocumentsOrderingFilter
from documents.filters import DocumentTypeFilterSet from documents.filters import DocumentTypeFilterSet
from documents.filters import ObjectOwnedOrGrantedPermissionsFilter from documents.filters import ObjectOwnedOrGrantedPermissionsFilter
from documents.filters import ObjectOwnedPermissionsFilter from documents.filters import ObjectOwnedPermissionsFilter
@ -350,7 +351,7 @@ class DocumentViewSet(
filter_backends = ( filter_backends = (
DjangoFilterBackend, DjangoFilterBackend,
SearchFilter, SearchFilter,
OrderingFilter, DocumentsOrderingFilter,
ObjectOwnedOrGrantedPermissionsFilter, ObjectOwnedOrGrantedPermissionsFilter,
) )
filterset_class = DocumentFilterSet filterset_class = DocumentFilterSet
@ -367,6 +368,7 @@ class DocumentViewSet(
"num_notes", "num_notes",
"owner", "owner",
"page_count", "page_count",
"custom_field_",
) )
def get_queryset(self): def get_queryset(self):