From 057ce29676a573466a84943eb6ae3a54a766b976 Mon Sep 17 00:00:00 2001
From: Yichi Yang
Date: Sun, 25 Aug 2024 12:20:24 +0800
Subject: [PATCH] Refactor: Reduce number of SQL queries when serializing List[Document] (#7505)

---
 src/documents/serialisers.py | 116 +++++++++++++++++++++++++++--------
 1 file changed, 91 insertions(+), 25 deletions(-)

diff --git a/src/documents/serialisers.py b/src/documents/serialisers.py
index 0c0813aa4..747d744b6 100644
--- a/src/documents/serialisers.py
+++ b/src/documents/serialisers.py
@@ -2,6 +2,7 @@ import datetime
 import math
 import re
 import zoneinfo
+from collections.abc import Iterable
 from decimal import Decimal
 
 import magic
@@ -232,22 +233,45 @@ class OwnedObjectSerializer(
             )
         )
 
-    def get_is_shared_by_requester(self, obj: Document):
-        ctype = ContentType.objects.get_for_model(obj)
+    @staticmethod
+    def get_shared_object_pks(objects: Iterable):
+        """
+        Return the primary keys of the subset of objects that are shared.
+        """
+        try:
+            first_obj = next(iter(objects))
+        except StopIteration:
+            return set()
+
+        ctype = ContentType.objects.get_for_model(first_obj)
+        object_pks = list(obj.pk for obj in objects)
+        pk_type = type(first_obj.pk)
+
+        def get_pks_for_permission_type(model):
+            return map(
+                pk_type,  # coerce the pk to be the same type of the provided objects
+                model.objects.filter(
+                    content_type=ctype,
+                    object_pk__in=object_pks,
+                )
+                .values_list("object_pk", flat=True)
+                .distinct(),
+            )
+
         UserObjectPermission = get_user_obj_perms_model()
         GroupObjectPermission = get_group_obj_perms_model()
-        return obj.owner == self.user and (
-            UserObjectPermission.objects.filter(
-                content_type=ctype,
-                object_pk=obj.pk,
-            ).count()
-            > 0
-            or GroupObjectPermission.objects.filter(
-                content_type=ctype,
-                object_pk=obj.pk,
-            ).count()
-            > 0
-        )
+        user_permission_pks = get_pks_for_permission_type(UserObjectPermission)
+        group_permission_pks = get_pks_for_permission_type(GroupObjectPermission)
+
+        return set(user_permission_pks) | set(group_permission_pks)
+
+    def get_is_shared_by_requester(self, obj: Document):
+        # First check the context to see if `shared_object_pks` is set by the parent.
+        shared_object_pks = self.context.get("shared_object_pks")
+        # If not just check if the current object is shared.
+        if shared_object_pks is None:
+            shared_object_pks = self.get_shared_object_pks([obj])
+        return obj.owner == self.user and obj.id in shared_object_pks
 
     permissions = SerializerMethodField(read_only=True)
     user_can_change = SerializerMethodField(read_only=True)
@@ -303,6 +327,14 @@ class OwnedObjectSerializer(
         return super().update(instance, validated_data)
 
 
+class OwnedObjectListSerializer(serializers.ListSerializer):
+    def to_representation(self, documents):
+        self.child.context["shared_object_pks"] = self.child.get_shared_object_pks(
+            documents,
+        )
+        return super().to_representation(documents)
+
+
 class CorrespondentSerializer(MatchingModelSerializer, OwnedObjectSerializer):
     last_correspondence = serializers.DateTimeField(read_only=True, required=False)
 
@@ -863,35 +895,69 @@ class DocumentSerializer(
             "custom_fields",
             "remove_inbox_tags",
         )
+        list_serializer_class = OwnedObjectListSerializer
+
+
+class SearchResultListSerializer(serializers.ListSerializer):
+    def to_representation(self, hits):
+        document_ids = [hit["id"] for hit in hits]
+        # Fetch all Document objects in the list in one SQL query.
+        documents = self.child.fetch_documents(document_ids)
+        self.child.context["documents"] = documents
+        # Also check if they are shared with other users / groups.
+        self.child.context["shared_object_pks"] = self.child.get_shared_object_pks(
+            documents.values(),
+        )
+
+        return super().to_representation(hits)
 
 
 class SearchResultSerializer(DocumentSerializer):
-    def to_representation(self, instance):
-        doc = (
-            Document.objects.select_related(
+    @staticmethod
+    def fetch_documents(ids):
+        """
+        Return a dict that maps given document IDs to Document objects.
+        """
+        return {
+            document.id: document
+            for document in Document.objects.select_related(
                 "correspondent",
                 "storage_path",
                 "document_type",
                 "owner",
             )
             .prefetch_related("tags", "custom_fields", "notes")
-            .get(id=instance["id"])
-        )
+            .filter(id__in=ids)
+        }
+
+    def to_representation(self, hit):
+        # Again we first check if the parent has already fetched the documents.
+        documents = self.context.get("documents")
+        # Otherwise we fetch this document.
+        if documents is None:  # pragma: no cover
+            # In practice we only serialize **lists** of whoosh.searching.Hit.
+            # I'm keeping this check for completeness but marking it no cover for now.
+            documents = self.fetch_documents([hit["id"]])
+        document = documents[hit["id"]]
+
         notes = ",".join(
-            [str(c.note) for c in doc.notes.all()],
+            [str(c.note) for c in document.notes.all()],
         )
-        r = super().to_representation(doc)
+        r = super().to_representation(document)
         r["__search_hit__"] = {
-            "score": instance.score,
-            "highlights": instance.highlights("content", text=doc.content),
+            "score": hit.score,
+            "highlights": hit.highlights("content", text=document.content),
             "note_highlights": (
-                instance.highlights("notes", text=notes) if doc else None
+                hit.highlights("notes", text=notes) if document else None
             ),
-            "rank": instance.rank,
+            "rank": hit.rank,
         }
 
         return r
 
+    class Meta(DocumentSerializer.Meta):
+        list_serializer_class = SearchResultListSerializer
+
 
 class SavedViewFilterRuleSerializer(serializers.ModelSerializer):
     class Meta: