diff --git a/src/documents/serialisers.py b/src/documents/serialisers.py index 5c90c6f1c..ff4351895 100644 --- a/src/documents/serialisers.py +++ b/src/documents/serialisers.py @@ -578,30 +578,34 @@ class TagSerializer(MatchingModelSerializer, OwnedObjectSerializer): ), ) def get_children(self, obj): - filter_q = self.context.get("document_count_filter") - request = self.context.get("request") - if filter_q is None: - user = getattr(request, "user", None) if request else None - filter_q = get_document_count_filter_for_user(user) - self.context["document_count_filter"] = filter_q + children_map = self.context.get("children_map") + if children_map is not None: + children = children_map.get(obj.pk, []) + else: + filter_q = self.context.get("document_count_filter") + request = self.context.get("request") + if filter_q is None: + user = getattr(request, "user", None) if request else None + filter_q = get_document_count_filter_for_user(user) + self.context["document_count_filter"] = filter_q - children_queryset = ( - obj.get_children_queryset() - .select_related("owner") - .annotate(document_count=Count("documents", filter=filter_q)) - ) + children = ( + obj.get_children_queryset() + .select_related("owner") + .annotate(document_count=Count("documents", filter=filter_q)) + ) - view = self.context.get("view") - ordering = ( - OrderingFilter().get_ordering(request, children_queryset, view) - if request and view - else None - ) - ordering = ordering or (Lower("name"),) - children_queryset = children_queryset.order_by(*ordering) + view = self.context.get("view") + ordering = ( + OrderingFilter().get_ordering(request, children, view) + if request and view + else None + ) + ordering = ordering or (Lower("name"),) + children = children.order_by(*ordering) serializer = TagSerializer( - children_queryset, + children, many=True, user=self.user, full_perms=self.full_perms, diff --git a/src/documents/views.py b/src/documents/views.py index ba265926c..244e1cfcd 100644 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -448,8 +448,43 @@ class TagViewSet(ModelViewSet, PermissionsAwareDocumentCountMixin): def get_serializer_context(self): context = super().get_serializer_context() context["document_count_filter"] = self.get_document_count_filter() + if hasattr(self, "_children_map"): + context["children_map"] = self._children_map return context + def list(self, request, *args, **kwargs): + """ + Build a children map once to avoid per-parent queries in the serializer. + """ + queryset = self.filter_queryset(self.get_queryset()) + ordering = OrderingFilter().get_ordering(request, queryset, self) or ( + Lower("name"), + ) + queryset = queryset.order_by(*ordering) + + all_tags = list(queryset) + descendant_pks = {pk for tag in all_tags for pk in tag.get_descendants_pks()} + + if descendant_pks: + filter_q = self.get_document_count_filter() + children_source = ( + Tag.objects.filter(pk__in=descendant_pks | {t.pk for t in all_tags}) + .select_related("owner") + .annotate(document_count=Count("documents", filter=filter_q)) + .order_by(*ordering) + ) + else: + children_source = all_tags + + children_map = {} + for tag in children_source: + children_map.setdefault(tag.tn_parent_id, []).append(tag) + self._children_map = children_map + + page = self.paginate_queryset(queryset) + serializer = self.get_serializer(page, many=True) + return self.get_paginated_response(serializer.data) + def perform_update(self, serializer): old_parent = self.get_object().get_parent() tag = serializer.save()