From a8eb460c3a8afa94580efb72cb36c009b4dcc413 Mon Sep 17 00:00:00 2001 From: shamoon <4887959+shamoon@users.noreply.github.com> Date: Thu, 12 Dec 2024 23:15:28 -0800 Subject: [PATCH] Sheesh, solved with subqueries --- src/documents/filters.py | 193 +++++++++++++++------------------------ 1 file changed, 76 insertions(+), 117 deletions(-) diff --git a/src/documents/filters.py b/src/documents/filters.py index 1f1a89af0..ff2cc2936 100644 --- a/src/documents/filters.py +++ b/src/documents/filters.py @@ -9,12 +9,12 @@ from django.contrib.contenttypes.models import ContentType from django.db.models import Case from django.db.models import CharField from django.db.models import Count -from django.db.models import DateTimeField from django.db.models import Exists -from django.db.models import FloatField from django.db.models import IntegerField from django.db.models import OuterRef from django.db.models import Q +from django.db.models import Subquery +from django.db.models import Sum from django.db.models import Value from django.db.models import When from django.db.models.functions import Cast @@ -785,133 +785,101 @@ class DocumentsOrderingFilter(OrderingFilter): annotation = None match field.data_type: case CustomField.FieldDataType.STRING: - annotation = Case( - When( - custom_fields__field_id=custom_field_id, - then=Cast( - "custom_fields__value_text", - output_field=CharField(), - ), - ), - default=Value(""), - output_field=CharField(), + annotation = Subquery( + CustomFieldInstance.objects.filter( + document_id=OuterRef("id"), + field_id=custom_field_id, + ).values("value_text")[:1], ) case CustomField.FieldDataType.INT: - annotation = Case( - When( - custom_fields__field_id=custom_field_id, - then=Cast( - "custom_fields__value_int", - output_field=IntegerField(), - ), - ), - default=Value(0), - output_field=IntegerField(), + annotation = Subquery( + CustomFieldInstance.objects.filter( + document_id=OuterRef("id"), + field_id=custom_field_id, + ).values("value_int")[:1], ) case CustomField.FieldDataType.FLOAT: - annotation = Case( - When( - custom_fields__field_id=custom_field_id, - then=Cast( - "custom_fields__value_float", - output_field=FloatField(), - ), - ), - default=Value(0), - output_field=FloatField(), + annotation = Subquery( + CustomFieldInstance.objects.filter( + document_id=OuterRef("id"), + field_id=custom_field_id, + ).values("value_float")[:1], ) case CustomField.FieldDataType.DATE: - annotation = Case( - When( - custom_fields__field_id=custom_field_id, - then=Cast( - "custom_fields__value_date", - output_field=DateTimeField(), - ), - ), - default=Value("1900-01-01T00:00:00Z"), - output_field=DateTimeField(), + annotation = Subquery( + CustomFieldInstance.objects.filter( + document_id=OuterRef("id"), + field_id=custom_field_id, + ).values("value_date")[:1], ) case CustomField.FieldDataType.MONETARY: - annotation = Case( - When( - custom_fields__field_id=custom_field_id, - then=Cast( - "custom_fields__value_monetary_amount", - output_field=FloatField(), - ), - ), - default=Value(0), - output_field=FloatField(), + annotation = Subquery( + CustomFieldInstance.objects.filter( + document_id=OuterRef("id"), + field_id=custom_field_id, + ).values("value_monetary_amount")[:1], ) case CustomField.FieldDataType.SELECT: - select_options = field.extra_data.get("select_options", []) - whens = [] - # Create a case for each select option - for option in select_options: - whens.append( - When( - custom_fields__field_id=custom_field_id, - custom_fields__value_select=option.get("id"), - then=Value(option["label"], output_field=CharField()), - ), + # Select options are a little more complicated since the value is the id of the option, not + # the label. Additionally, to support sqlite we can't use StringAgg, so we need to create a + # case statement for each option, setting the value to the index of the option in a list + # sorted by label, and then summing the results to give a single value for the annotation + + select_options = sorted( + field.extra_data.get("select_options", []), + key=lambda x: x.get("label"), + ) + whens = [ + When( + custom_fields__field_id=custom_field_id, + custom_fields__value_select=option.get("id"), + then=Value(idx, output_field=IntegerField()), ) - annotation = Case( - *whens, - default=Value(""), - output_field=CharField(), + for idx, option in enumerate(select_options) + ] + whens.append( + When( + custom_fields__field_id=custom_field_id, + custom_fields__value_select__isnull=True, + then=Value( + len(select_options), + output_field=IntegerField(), + ), + ), + ) + annotation = Sum( + Case( + *whens, + default=Value(0), + output_field=IntegerField(), + ), ) case CustomField.FieldDataType.DOCUMENTLINK: - annotation = Case( - When( - custom_fields__field_id=custom_field_id, - then=Cast( - "custom_fields__value_document_ids", - output_field=CharField(), - ), - ), - default=Value(""), - output_field=CharField(), + annotation = Subquery( + CustomFieldInstance.objects.filter( + document_id=OuterRef("id"), + field_id=custom_field_id, + ).values("value_document_ids")[:1], ) case CustomField.FieldDataType.URL: - annotation = Case( - When( - custom_fields__field_id=custom_field_id, - then=Cast( - "custom_fields__value_url", - output_field=CharField(), - ), - ), - default=Value(""), - output_field=CharField(), + annotation = Subquery( + CustomFieldInstance.objects.filter( + document_id=OuterRef("id"), + field_id=custom_field_id, + ).values("value_url")[:1], ) - case CustomField.FieldDataType.BOOL: - annotation = Case( - When( - custom_fields__field_id=custom_field_id, - custom_fields__value_bool=True, - then=Value( - 1, - output_field=IntegerField(), - ), - ), - When( - custom_fields__field_id=custom_field_id, - custom_fields__value_bool=False, - then=Value( - 0, - output_field=IntegerField(), - ), - ), - default=Value(0), - output_field=IntegerField(), + annotation = Subquery( + CustomFieldInstance.objects.filter( + document_id=OuterRef("id"), + field_id=custom_field_id, + ).values("value_bool")[:1], ) if not annotation: raise ValueError("Invalid custom field data type") - ids_sorted_by_cf = ( + queryset = ( queryset.annotate( # We need to annotate the queryset with the custom field value custom_field_value=annotation, @@ -930,16 +898,7 @@ class DocumentsOrderingFilter(OrderingFilter): "custom_field_value", ), ) - .values_list("id", flat=True) + .distinct() ) - # We need to preserve the order of the ids sorted by custom field, see https://docs.djangoproject.com/en/dev/ref/models/querysets/#distinct - preserved = Case( - *[ - When(id=id, then=position) - for position, id in enumerate(ids_sorted_by_cf) - ], - ) - queryset = queryset.filter(id__in=ids_sorted_by_cf).order_by(preserved) - return super().filter_queryset(request, queryset, view)