Performance: Cache django-guardian permissions when counting documents (#10657)

Fixes N+1 queries in tag, correspondent, storage path, custom field,
and document type list views.
Reduces SQL queries from 160 to 9.
This commit is contained in:
Antoine Mérino
2025-09-30 18:48:44 +02:00
committed by GitHub
parent 643e2b4a8e
commit 3df43d828a
3 changed files with 250 additions and 22 deletions

View File

@@ -6,6 +6,7 @@ import re
from datetime import datetime
from decimal import Decimal
from typing import TYPE_CHECKING
from typing import Literal
import magic
from celery import states
@@ -252,6 +253,35 @@ class OwnedObjectSerializer(
except KeyError:
pass
def _get_perms(self, obj, codename: str, target: Literal["users", "groups"]):
"""
Get the given permissions from context or from django-guardian.
:param codename: The permission codename, e.g. 'view' or 'change'
:param target: 'users' or 'groups'
"""
key = f"{target}_{codename}_perms"
cached = self.context.get(key, {}).get(obj.pk)
if cached is not None:
return list(cached)
# Permission not found in the context, get it from guardian
if target == "users":
return list(
get_users_with_perms(
obj,
only_with_perms_in=[f"{codename}_{obj.__class__.__name__.lower()}"],
with_group_users=False,
).values_list("id", flat=True),
)
else: # groups
return list(
get_groups_with_only_permission(
obj,
codename=f"{codename}_{obj.__class__.__name__.lower()}",
).values_list("id", flat=True),
)
@extend_schema_field(
field={
"type": "object",
@@ -286,31 +316,14 @@ class OwnedObjectSerializer(
},
)
def get_permissions(self, obj) -> dict:
view_codename = f"view_{obj.__class__.__name__.lower()}"
change_codename = f"change_{obj.__class__.__name__.lower()}"
return {
"view": {
"users": get_users_with_perms(
obj,
only_with_perms_in=[view_codename],
with_group_users=False,
).values_list("id", flat=True),
"groups": get_groups_with_only_permission(
obj,
codename=view_codename,
).values_list("id", flat=True),
"users": self._get_perms(obj, "view", "users"),
"groups": self._get_perms(obj, "view", "groups"),
},
"change": {
"users": get_users_with_perms(
obj,
only_with_perms_in=[change_codename],
with_group_users=False,
).values_list("id", flat=True),
"groups": get_groups_with_only_permission(
obj,
codename=change_codename,
).values_list("id", flat=True),
"users": self._get_perms(obj, "change", "users"),
"groups": self._get_perms(obj, "change", "groups"),
},
}

View File

@@ -1,17 +1,23 @@
import json
import tempfile
from datetime import timedelta
from pathlib import Path
from django.conf import settings
from django.contrib.auth.models import Group
from django.contrib.auth.models import Permission
from django.contrib.auth.models import User
from django.db import connection
from django.test import TestCase
from django.test import override_settings
from django.test.utils import CaptureQueriesContext
from django.utils import timezone
from guardian.shortcuts import assign_perm
from rest_framework import status
from documents.models import Document
from documents.models import ShareLink
from documents.models import Tag
from documents.tests.utils import DirectoriesMixin
from paperless.models import ApplicationConfiguration
@@ -154,3 +160,113 @@ class TestViews(DirectoriesMixin, TestCase):
response.render()
self.assertEqual(response.request["PATH_INFO"], "/accounts/login/")
self.assertContains(response, b"Share link has expired")
def test_list_with_full_permissions(self):
"""
GIVEN:
- Tags with different permissions
WHEN:
- Request to get tag list with full permissions is made
THEN:
- Tag list is returned with the right permission information
"""
user2 = User.objects.create(username="user2")
user3 = User.objects.create(username="user3")
group1 = Group.objects.create(name="group1")
group2 = Group.objects.create(name="group2")
group3 = Group.objects.create(name="group3")
t1 = Tag.objects.create(name="invoice", pk=1)
assign_perm("view_tag", self.user, t1)
assign_perm("view_tag", user2, t1)
assign_perm("view_tag", user3, t1)
assign_perm("view_tag", group1, t1)
assign_perm("view_tag", group2, t1)
assign_perm("view_tag", group3, t1)
assign_perm("change_tag", self.user, t1)
assign_perm("change_tag", user2, t1)
assign_perm("change_tag", group1, t1)
assign_perm("change_tag", group2, t1)
Tag.objects.create(name="bank statement", pk=2)
d1 = Document.objects.create(
title="Invoice 1",
content="This is the invoice of a very expensive item",
checksum="A",
)
d1.tags.add(t1)
d2 = Document.objects.create(
title="Invoice 2",
content="Internet invoice, I should pay it to continue contributing",
checksum="B",
)
d2.tags.add(t1)
view_permissions = Permission.objects.filter(
codename__contains="view_tag",
)
self.user.user_permissions.add(*view_permissions)
self.user.save()
self.client.force_login(self.user)
response = self.client.get("/api/tags/?page=1&full_perms=true")
results = json.loads(response.content)["results"]
for tag in results:
if tag["name"] == "invoice":
assert tag["permissions"] == {
"view": {
"users": [self.user.pk, user2.pk, user3.pk],
"groups": [group1.pk, group2.pk, group3.pk],
},
"change": {
"users": [self.user.pk, user2.pk],
"groups": [group1.pk, group2.pk],
},
}
elif tag["name"] == "bank statement":
assert tag["permissions"] == {
"view": {"users": [], "groups": []},
"change": {"users": [], "groups": []},
}
else:
assert False, f"Unexpected tag found: {tag['name']}"
def test_list_no_n_plus_1_queries(self):
"""
GIVEN:
- Tags with different permissions
WHEN:
- Request to get tag list with full permissions is made
THEN:
- Permissions are not queries in database tag by tag,
i.e. there are no N+1 queries
"""
view_permissions = Permission.objects.filter(
codename__contains="view_tag",
)
self.user.user_permissions.add(*view_permissions)
self.user.save()
self.client.force_login(self.user)
# Start by a small list, and count the number of SQL queries
for i in range(2):
Tag.objects.create(name=f"tag_{i}")
with CaptureQueriesContext(connection) as ctx_small:
response_small = self.client.get("/api/tags/?full_perms=true")
assert response_small.status_code == 200
num_queries_small = len(ctx_small.captured_queries)
# Complete the list, and count the number of SQL queries again
for i in range(2, 50):
Tag.objects.create(name=f"tag_{i}")
with CaptureQueriesContext(connection) as ctx_large:
response_large = self.client.get("/api/tags/?full_perms=true")
assert response_large.status_code == 200
num_queries_large = len(ctx_large.captured_queries)
# A few additional queries are allowed, but not a linear explosion
assert num_queries_large <= num_queries_small + 5, (
f"Possible N+1 queries detected: {num_queries_small} queries for 2 tags, "
f"but {num_queries_large} queries for 50 tags"
)

View File

@@ -5,9 +5,11 @@ import platform
import re
import tempfile
import zipfile
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from time import mktime
from typing import Literal
from unicodedata import normalize
from urllib.parse import quote
from urllib.parse import urlparse
@@ -19,6 +21,7 @@ from celery import states
from django.conf import settings
from django.contrib.auth.models import Group
from django.contrib.auth.models import User
from django.contrib.contenttypes.models import ContentType
from django.db import connections
from django.db.migrations.loader import MigrationLoader
from django.db.migrations.recorder import MigrationRecorder
@@ -56,6 +59,8 @@ from drf_spectacular.utils import OpenApiParameter
from drf_spectacular.utils import extend_schema
from drf_spectacular.utils import extend_schema_view
from drf_spectacular.utils import inline_serializer
from guardian.utils import get_group_obj_perms_model
from guardian.utils import get_user_obj_perms_model
from langdetect import detect
from packaging import version as packaging_version
from redis import Redis
@@ -254,7 +259,101 @@ class PassUserMixin(GenericAPIView):
return super().get_serializer(*args, **kwargs)
class PermissionsAwareDocumentCountMixin(PassUserMixin):
class BulkPermissionMixin:
"""
Prefetch Django-Guardian permissions for a list before serialization, to avoid N+1 queries.
"""
def get_permission_codenames(self):
model_name = self.queryset.model.__name__.lower()
return {
"view": f"view_{model_name}",
"change": f"change_{model_name}",
}
def _get_object_perms(
self,
objects: list,
perm_codenames: list[str],
actor: Literal["users", "groups"],
) -> dict[int, dict[str, list[int]]]:
"""
Collect object-level permissions for either users or groups.
"""
model = self.queryset.model
obj_perm_model = (
get_user_obj_perms_model(model)
if actor == "users"
else get_group_obj_perms_model(model)
)
id_field = "user_id" if actor == "users" else "group_id"
ctype = ContentType.objects.get_for_model(model)
object_pks = [obj.pk for obj in objects]
perms_qs = obj_perm_model.objects.filter(
content_type=ctype,
object_pk__in=object_pks,
permission__codename__in=perm_codenames,
).values_list("object_pk", id_field, "permission__codename")
perms: dict[int, dict[str, list[int]]] = defaultdict(lambda: defaultdict(list))
for object_pk, actor_id, codename in perms_qs:
perms[int(object_pk)][codename].append(actor_id)
# Ensure that all objects have all codenames, even if empty
for pk in object_pks:
for codename in perm_codenames:
perms[pk][codename]
return perms
def get_serializer_context(self):
"""
Get all permissions of the current list of objects at once and pass them to the serializer.
This avoid fetching permissions object by object in database.
"""
context = super().get_serializer_context()
try:
full_perms = get_boolean(
str(self.request.query_params.get("full_perms", "false")),
)
except ValueError:
full_perms = False
if not full_perms:
return context
# Check which objects are being paginated
page = getattr(self, "paginator", None)
if page and hasattr(page, "page"):
queryset = page.page.object_list
elif hasattr(self, "page"):
queryset = self.page
else:
queryset = self.filter_queryset(self.get_queryset())
codenames = self.get_permission_codenames()
perm_names = [codenames["view"], codenames["change"]]
user_perms = self._get_object_perms(queryset, perm_names, actor="users")
group_perms = self._get_object_perms(queryset, perm_names, actor="groups")
context["users_view_perms"] = {
pk: user_perms[pk][codenames["view"]] for pk in user_perms
}
context["users_change_perms"] = {
pk: user_perms[pk][codenames["change"]] for pk in user_perms
}
context["groups_view_perms"] = {
pk: group_perms[pk][codenames["view"]] for pk in group_perms
}
context["groups_change_perms"] = {
pk: group_perms[pk][codenames["change"]] for pk in group_perms
}
return context
class PermissionsAwareDocumentCountMixin(BulkPermissionMixin, PassUserMixin):
"""
Mixin to add document count to queryset, permissions-aware if needed
"""