Refactor: Reduce number of SQL queries when serializing List[Document] (#7505)

This commit is contained in:
Yichi Yang 2024-08-25 12:20:24 +08:00 committed by GitHub
parent 982eeb0d24
commit 057ce29676
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2,6 +2,7 @@ import datetime
import math import math
import re import re
import zoneinfo import zoneinfo
from collections.abc import Iterable
from decimal import Decimal from decimal import Decimal
import magic import magic
@ -232,22 +233,45 @@ class OwnedObjectSerializer(
) )
) )
def get_is_shared_by_requester(self, obj: Document): @staticmethod
ctype = ContentType.objects.get_for_model(obj) def get_shared_object_pks(objects: Iterable):
"""
Return the primary keys of the subset of objects that are shared.
"""
try:
first_obj = next(iter(objects))
except StopIteration:
return set()
ctype = ContentType.objects.get_for_model(first_obj)
object_pks = list(obj.pk for obj in objects)
pk_type = type(first_obj.pk)
def get_pks_for_permission_type(model):
return map(
pk_type, # coerce the pk to be the same type of the provided objects
model.objects.filter(
content_type=ctype,
object_pk__in=object_pks,
)
.values_list("object_pk", flat=True)
.distinct(),
)
UserObjectPermission = get_user_obj_perms_model() UserObjectPermission = get_user_obj_perms_model()
GroupObjectPermission = get_group_obj_perms_model() GroupObjectPermission = get_group_obj_perms_model()
return obj.owner == self.user and ( user_permission_pks = get_pks_for_permission_type(UserObjectPermission)
UserObjectPermission.objects.filter( group_permission_pks = get_pks_for_permission_type(GroupObjectPermission)
content_type=ctype,
object_pk=obj.pk, return set(user_permission_pks) | set(group_permission_pks)
).count()
> 0 def get_is_shared_by_requester(self, obj: Document):
or GroupObjectPermission.objects.filter( # First check the context to see if `shared_object_pks` is set by the parent.
content_type=ctype, shared_object_pks = self.context.get("shared_object_pks")
object_pk=obj.pk, # If not just check if the current object is shared.
).count() if shared_object_pks is None:
> 0 shared_object_pks = self.get_shared_object_pks([obj])
) return obj.owner == self.user and obj.id in shared_object_pks
permissions = SerializerMethodField(read_only=True) permissions = SerializerMethodField(read_only=True)
user_can_change = SerializerMethodField(read_only=True) user_can_change = SerializerMethodField(read_only=True)
@ -303,6 +327,14 @@ class OwnedObjectSerializer(
return super().update(instance, validated_data) return super().update(instance, validated_data)
class OwnedObjectListSerializer(serializers.ListSerializer):
def to_representation(self, documents):
self.child.context["shared_object_pks"] = self.child.get_shared_object_pks(
documents,
)
return super().to_representation(documents)
class CorrespondentSerializer(MatchingModelSerializer, OwnedObjectSerializer): class CorrespondentSerializer(MatchingModelSerializer, OwnedObjectSerializer):
last_correspondence = serializers.DateTimeField(read_only=True, required=False) last_correspondence = serializers.DateTimeField(read_only=True, required=False)
@ -863,35 +895,69 @@ class DocumentSerializer(
"custom_fields", "custom_fields",
"remove_inbox_tags", "remove_inbox_tags",
) )
list_serializer_class = OwnedObjectListSerializer
class SearchResultListSerializer(serializers.ListSerializer):
def to_representation(self, hits):
document_ids = [hit["id"] for hit in hits]
# Fetch all Document objects in the list in one SQL query.
documents = self.child.fetch_documents(document_ids)
self.child.context["documents"] = documents
# Also check if they are shared with other users / groups.
self.child.context["shared_object_pks"] = self.child.get_shared_object_pks(
documents.values(),
)
return super().to_representation(hits)
class SearchResultSerializer(DocumentSerializer): class SearchResultSerializer(DocumentSerializer):
def to_representation(self, instance): @staticmethod
doc = ( def fetch_documents(ids):
Document.objects.select_related( """
Return a dict that maps given document IDs to Document objects.
"""
return {
document.id: document
for document in Document.objects.select_related(
"correspondent", "correspondent",
"storage_path", "storage_path",
"document_type", "document_type",
"owner", "owner",
) )
.prefetch_related("tags", "custom_fields", "notes") .prefetch_related("tags", "custom_fields", "notes")
.get(id=instance["id"]) .filter(id__in=ids)
) }
def to_representation(self, hit):
# Again we first check if the parent has already fetched the documents.
documents = self.context.get("documents")
# Otherwise we fetch this document.
if documents is None: # pragma: no cover
# In practice we only serialize **lists** of whoosh.searching.Hit.
# I'm keeping this check for completeness but marking it no cover for now.
documents = self.fetch_documents([hit["id"]])
document = documents[hit["id"]]
notes = ",".join( notes = ",".join(
[str(c.note) for c in doc.notes.all()], [str(c.note) for c in document.notes.all()],
) )
r = super().to_representation(doc) r = super().to_representation(document)
r["__search_hit__"] = { r["__search_hit__"] = {
"score": instance.score, "score": hit.score,
"highlights": instance.highlights("content", text=doc.content), "highlights": hit.highlights("content", text=document.content),
"note_highlights": ( "note_highlights": (
instance.highlights("notes", text=notes) if doc else None hit.highlights("notes", text=notes) if document else None
), ),
"rank": instance.rank, "rank": hit.rank,
} }
return r return r
class Meta(DocumentSerializer.Meta):
list_serializer_class = SearchResultListSerializer
class SavedViewFilterRuleSerializer(serializers.ModelSerializer): class SavedViewFilterRuleSerializer(serializers.ModelSerializer):
class Meta: class Meta: