Optimize tag children retrieval

This commit is contained in:
shamoon
2025-12-15 10:56:23 -08:00
parent 917e3f3c60
commit 02a63144cf
2 changed files with 59 additions and 20 deletions

View File

@@ -578,30 +578,34 @@ class TagSerializer(MatchingModelSerializer, OwnedObjectSerializer):
),
)
def get_children(self, obj):
filter_q = self.context.get("document_count_filter")
request = self.context.get("request")
if filter_q is None:
user = getattr(request, "user", None) if request else None
filter_q = get_document_count_filter_for_user(user)
self.context["document_count_filter"] = filter_q
children_map = self.context.get("children_map")
if children_map is not None:
children = children_map.get(obj.pk, [])
else:
filter_q = self.context.get("document_count_filter")
request = self.context.get("request")
if filter_q is None:
user = getattr(request, "user", None) if request else None
filter_q = get_document_count_filter_for_user(user)
self.context["document_count_filter"] = filter_q
children_queryset = (
obj.get_children_queryset()
.select_related("owner")
.annotate(document_count=Count("documents", filter=filter_q))
)
children = (
obj.get_children_queryset()
.select_related("owner")
.annotate(document_count=Count("documents", filter=filter_q))
)
view = self.context.get("view")
ordering = (
OrderingFilter().get_ordering(request, children_queryset, view)
if request and view
else None
)
ordering = ordering or (Lower("name"),)
children_queryset = children_queryset.order_by(*ordering)
view = self.context.get("view")
ordering = (
OrderingFilter().get_ordering(request, children, view)
if request and view
else None
)
ordering = ordering or (Lower("name"),)
children = children.order_by(*ordering)
serializer = TagSerializer(
children_queryset,
children,
many=True,
user=self.user,
full_perms=self.full_perms,

View File

@@ -448,8 +448,43 @@ class TagViewSet(ModelViewSet, PermissionsAwareDocumentCountMixin):
def get_serializer_context(self):
context = super().get_serializer_context()
context["document_count_filter"] = self.get_document_count_filter()
if hasattr(self, "_children_map"):
context["children_map"] = self._children_map
return context
def list(self, request, *args, **kwargs):
"""
Build a children map once to avoid per-parent queries in the serializer.
"""
queryset = self.filter_queryset(self.get_queryset())
ordering = OrderingFilter().get_ordering(request, queryset, self) or (
Lower("name"),
)
queryset = queryset.order_by(*ordering)
all_tags = list(queryset)
descendant_pks = {pk for tag in all_tags for pk in tag.get_descendants_pks()}
if descendant_pks:
filter_q = self.get_document_count_filter()
children_source = (
Tag.objects.filter(pk__in=descendant_pks | {t.pk for t in all_tags})
.select_related("owner")
.annotate(document_count=Count("documents", filter=filter_q))
.order_by(*ordering)
)
else:
children_source = all_tags
children_map = {}
for tag in children_source:
children_map.setdefault(tag.tn_parent_id, []).append(tag)
self._children_map = children_map
page = self.paginate_queryset(queryset)
serializer = self.get_serializer(page, many=True)
return self.get_paginated_response(serializer.data)
def perform_update(self, serializer):
old_parent = self.get_object().get_parent()
tag = serializer.save()