From 3df43d828a87aff2f5feeb6c6b933e286b888017 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antoine=20M=C3=A9rino?= Date: Tue, 30 Sep 2025 18:48:44 +0200 Subject: [PATCH] Performance: Cache django-guardian permissions when counting documents (#10657) Fixes N+1 queries in tag, correspondent, storage path, custom field, and document type list views. Reduces SQL queries from 160 to 9. --- src/documents/serialisers.py | 55 ++++++++------ src/documents/tests/test_views.py | 116 ++++++++++++++++++++++++++++++ src/documents/views.py | 101 +++++++++++++++++++++++++- 3 files changed, 250 insertions(+), 22 deletions(-) diff --git a/src/documents/serialisers.py b/src/documents/serialisers.py index 1608a0e4e..ce0192074 100644 --- a/src/documents/serialisers.py +++ b/src/documents/serialisers.py @@ -6,6 +6,7 @@ import re from datetime import datetime from decimal import Decimal from typing import TYPE_CHECKING +from typing import Literal import magic from celery import states @@ -252,6 +253,35 @@ class OwnedObjectSerializer( except KeyError: pass + def _get_perms(self, obj, codename: str, target: Literal["users", "groups"]): + """ + Get the given permissions from context or from django-guardian. + + :param codename: The permission codename, e.g. 
'view' or 'change' + :param target: 'users' or 'groups' + """ + key = f"{target}_{codename}_perms" + cached = self.context.get(key, {}).get(obj.pk) + if cached is not None: + return list(cached) + + # Permission not found in the context, get it from guardian + if target == "users": + return list( + get_users_with_perms( + obj, + only_with_perms_in=[f"{codename}_{obj.__class__.__name__.lower()}"], + with_group_users=False, + ).values_list("id", flat=True), + ) + else: # groups + return list( + get_groups_with_only_permission( + obj, + codename=f"{codename}_{obj.__class__.__name__.lower()}", + ).values_list("id", flat=True), + ) + @extend_schema_field( field={ "type": "object", @@ -286,31 +316,14 @@ class OwnedObjectSerializer( }, ) def get_permissions(self, obj) -> dict: - view_codename = f"view_{obj.__class__.__name__.lower()}" - change_codename = f"change_{obj.__class__.__name__.lower()}" - return { "view": { - "users": get_users_with_perms( - obj, - only_with_perms_in=[view_codename], - with_group_users=False, - ).values_list("id", flat=True), - "groups": get_groups_with_only_permission( - obj, - codename=view_codename, - ).values_list("id", flat=True), + "users": self._get_perms(obj, "view", "users"), + "groups": self._get_perms(obj, "view", "groups"), }, "change": { - "users": get_users_with_perms( - obj, - only_with_perms_in=[change_codename], - with_group_users=False, - ).values_list("id", flat=True), - "groups": get_groups_with_only_permission( - obj, - codename=change_codename, - ).values_list("id", flat=True), + "users": self._get_perms(obj, "change", "users"), + "groups": self._get_perms(obj, "change", "groups"), }, } diff --git a/src/documents/tests/test_views.py b/src/documents/tests/test_views.py index 4c987e3af..57562c02c 100644 --- a/src/documents/tests/test_views.py +++ b/src/documents/tests/test_views.py @@ -1,17 +1,23 @@ +import json import tempfile from datetime import timedelta from pathlib import Path from django.conf import settings +from 
django.contrib.auth.models import Group from django.contrib.auth.models import Permission from django.contrib.auth.models import User +from django.db import connection from django.test import TestCase from django.test import override_settings +from django.test.utils import CaptureQueriesContext from django.utils import timezone +from guardian.shortcuts import assign_perm from rest_framework import status from documents.models import Document from documents.models import ShareLink +from documents.models import Tag from documents.tests.utils import DirectoriesMixin from paperless.models import ApplicationConfiguration @@ -154,3 +160,113 @@ class TestViews(DirectoriesMixin, TestCase): response.render() self.assertEqual(response.request["PATH_INFO"], "/accounts/login/") self.assertContains(response, b"Share link has expired") + + def test_list_with_full_permissions(self): + """ + GIVEN: + - Tags with different permissions + WHEN: + - Request to get tag list with full permissions is made + THEN: + - Tag list is returned with the right permission information + """ + user2 = User.objects.create(username="user2") + user3 = User.objects.create(username="user3") + group1 = Group.objects.create(name="group1") + group2 = Group.objects.create(name="group2") + group3 = Group.objects.create(name="group3") + t1 = Tag.objects.create(name="invoice", pk=1) + assign_perm("view_tag", self.user, t1) + assign_perm("view_tag", user2, t1) + assign_perm("view_tag", user3, t1) + assign_perm("view_tag", group1, t1) + assign_perm("view_tag", group2, t1) + assign_perm("view_tag", group3, t1) + assign_perm("change_tag", self.user, t1) + assign_perm("change_tag", user2, t1) + assign_perm("change_tag", group1, t1) + assign_perm("change_tag", group2, t1) + + Tag.objects.create(name="bank statement", pk=2) + d1 = Document.objects.create( + title="Invoice 1", + content="This is the invoice of a very expensive item", + checksum="A", + ) + d1.tags.add(t1) + d2 = Document.objects.create( + 
+ title="Invoice 2", + content="Internet invoice, I should pay it to continue contributing", + checksum="B", + ) + d2.tags.add(t1) + + view_permissions = Permission.objects.filter( + codename__contains="view_tag", + ) + self.user.user_permissions.add(*view_permissions) + self.user.save() + + self.client.force_login(self.user) + response = self.client.get("/api/tags/?page=1&full_perms=true") + results = json.loads(response.content)["results"] + for tag in results: + if tag["name"] == "invoice": + assert tag["permissions"] == { + "view": { + "users": [self.user.pk, user2.pk, user3.pk], + "groups": [group1.pk, group2.pk, group3.pk], + }, + "change": { + "users": [self.user.pk, user2.pk], + "groups": [group1.pk, group2.pk], + }, + } + elif tag["name"] == "bank statement": + assert tag["permissions"] == { + "view": {"users": [], "groups": []}, + "change": {"users": [], "groups": []}, + } + else: + assert False, f"Unexpected tag found: {tag['name']}" + + def test_list_no_n_plus_1_queries(self): + """ + GIVEN: + - Tags with different permissions + WHEN: + - Request to get tag list with full permissions is made + THEN: + - Permissions are not queried in the database tag by tag, + i.e. 
there are no N+1 queries + """ + view_permissions = Permission.objects.filter( + codename__contains="view_tag", + ) + self.user.user_permissions.add(*view_permissions) + self.user.save() + self.client.force_login(self.user) + + # Start by a small list, and count the number of SQL queries + for i in range(2): + Tag.objects.create(name=f"tag_{i}") + + with CaptureQueriesContext(connection) as ctx_small: + response_small = self.client.get("/api/tags/?full_perms=true") + assert response_small.status_code == 200 + num_queries_small = len(ctx_small.captured_queries) + + # Complete the list, and count the number of SQL queries again + for i in range(2, 50): + Tag.objects.create(name=f"tag_{i}") + + with CaptureQueriesContext(connection) as ctx_large: + response_large = self.client.get("/api/tags/?full_perms=true") + assert response_large.status_code == 200 + num_queries_large = len(ctx_large.captured_queries) + + # A few additional queries are allowed, but not a linear explosion + assert num_queries_large <= num_queries_small + 5, ( + f"Possible N+1 queries detected: {num_queries_small} queries for 2 tags, " + f"but {num_queries_large} queries for 50 tags" + ) diff --git a/src/documents/views.py b/src/documents/views.py index 86eab92e3..bce7428cd 100644 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -5,9 +5,11 @@ import platform import re import tempfile import zipfile +from collections import defaultdict from datetime import datetime from pathlib import Path from time import mktime +from typing import Literal from unicodedata import normalize from urllib.parse import quote from urllib.parse import urlparse @@ -19,6 +21,7 @@ from celery import states from django.conf import settings from django.contrib.auth.models import Group from django.contrib.auth.models import User +from django.contrib.contenttypes.models import ContentType from django.db import connections from django.db.migrations.loader import MigrationLoader from django.db.migrations.recorder import 
MigrationRecorder @@ -56,6 +59,8 @@ from drf_spectacular.utils import OpenApiParameter from drf_spectacular.utils import extend_schema from drf_spectacular.utils import extend_schema_view from drf_spectacular.utils import inline_serializer +from guardian.utils import get_group_obj_perms_model +from guardian.utils import get_user_obj_perms_model from langdetect import detect from packaging import version as packaging_version from redis import Redis @@ -254,7 +259,101 @@ class PassUserMixin(GenericAPIView): return super().get_serializer(*args, **kwargs) -class PermissionsAwareDocumentCountMixin(PassUserMixin): +class BulkPermissionMixin: + """ + Prefetch Django-Guardian permissions for a list before serialization, to avoid N+1 queries. + """ + + def get_permission_codenames(self): + model_name = self.queryset.model.__name__.lower() + return { + "view": f"view_{model_name}", + "change": f"change_{model_name}", + } + + def _get_object_perms( + self, + objects: list, + perm_codenames: list[str], + actor: Literal["users", "groups"], + ) -> dict[int, dict[str, list[int]]]: + """ + Collect object-level permissions for either users or groups. 
+ """ + model = self.queryset.model + obj_perm_model = ( + get_user_obj_perms_model(model) + if actor == "users" + else get_group_obj_perms_model(model) + ) + id_field = "user_id" if actor == "users" else "group_id" + ctype = ContentType.objects.get_for_model(model) + object_pks = [obj.pk for obj in objects] + + perms_qs = obj_perm_model.objects.filter( + content_type=ctype, + object_pk__in=object_pks, + permission__codename__in=perm_codenames, + ).values_list("object_pk", id_field, "permission__codename") + + perms: dict[int, dict[str, list[int]]] = defaultdict(lambda: defaultdict(list)) + for object_pk, actor_id, codename in perms_qs: + perms[int(object_pk)][codename].append(actor_id) + + # Ensure that all objects have all codenames, even if empty + for pk in object_pks: + for codename in perm_codenames: + perms[pk][codename] + + return perms + + def get_serializer_context(self): + """ + Get all permissions of the current list of objects at once and pass them to the serializer. + This avoids fetching permissions object by object in the database. 
+ """ + context = super().get_serializer_context() + try: + full_perms = get_boolean( + str(self.request.query_params.get("full_perms", "false")), + ) + except ValueError: + full_perms = False + + if not full_perms: + return context + + # Check which objects are being paginated + page = getattr(self, "paginator", None) + if page and hasattr(page, "page"): + queryset = page.page.object_list + elif hasattr(self, "page"): + queryset = self.page + else: + queryset = self.filter_queryset(self.get_queryset()) + + codenames = self.get_permission_codenames() + perm_names = [codenames["view"], codenames["change"]] + user_perms = self._get_object_perms(queryset, perm_names, actor="users") + group_perms = self._get_object_perms(queryset, perm_names, actor="groups") + + context["users_view_perms"] = { + pk: user_perms[pk][codenames["view"]] for pk in user_perms + } + context["users_change_perms"] = { + pk: user_perms[pk][codenames["change"]] for pk in user_perms + } + context["groups_view_perms"] = { + pk: group_perms[pk][codenames["view"]] for pk in group_perms + } + context["groups_change_perms"] = { + pk: group_perms[pk][codenames["change"]] for pk in group_perms + } + + return context + + +class PermissionsAwareDocumentCountMixin(BulkPermissionMixin, PassUserMixin): """ Mixin to add document count to queryset, permissions-aware if needed """