mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-10-02 01:42:50 -05:00
Performance: Cache django-guardian permissions when counting documents (#10657)
Fixes N+1 queries in tag, correspondent, storage path, custom field, and document type list views. Reduces SQL queries from 160 to 9.
This commit is contained in:
@@ -6,6 +6,7 @@ import re
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
import magic
|
import magic
|
||||||
from celery import states
|
from celery import states
|
||||||
@@ -252,6 +253,35 @@ class OwnedObjectSerializer(
|
|||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def _get_perms(
    self,
    obj,
    codename: str,
    target: Literal["users", "groups"],
) -> list[int]:
    """
    Get the given permissions from context or from django-guardian.

    The serializer context is checked first, under the key
    ``"<target>_<codename>_perms"``, for a mapping of object pk to ids. When
    there is no entry for ``obj``, django-guardian is queried for this object.

    :param obj: The object whose permission holders are requested
    :param codename: The permission codename, e.g. 'view' or 'change'
    :param target: 'users' or 'groups'
    :return: The ids of the users or groups holding the permission on ``obj``
    :raises ValueError: If ``target`` is neither 'users' nor 'groups'
    """
    if target not in ("users", "groups"):
        # Without this check any other value would be answered with group ids
        raise ValueError(f"Unsupported permission target: {target!r}")

    cached = self.context.get(f"{target}_{codename}_perms", {}).get(obj.pk)
    # An empty list is a valid cached answer; only None means "not cached"
    if cached is not None:
        return list(cached)

    # Permission not found in the context, get it from guardian
    full_codename = f"{codename}_{obj.__class__.__name__.lower()}"
    if target == "users":
        holders = get_users_with_perms(
            obj,
            only_with_perms_in=[full_codename],
            with_group_users=False,
        )
    else:
        holders = get_groups_with_only_permission(obj, codename=full_codename)
    return list(holders.values_list("id", flat=True))
|
||||||
|
|
||||||
@extend_schema_field(
|
@extend_schema_field(
|
||||||
field={
|
field={
|
||||||
"type": "object",
|
"type": "object",
|
||||||
@@ -286,31 +316,14 @@ class OwnedObjectSerializer(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
def get_permissions(self, obj) -> dict:
|
def get_permissions(self, obj) -> dict:
|
||||||
view_codename = f"view_{obj.__class__.__name__.lower()}"
|
|
||||||
change_codename = f"change_{obj.__class__.__name__.lower()}"
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"view": {
|
"view": {
|
||||||
"users": get_users_with_perms(
|
"users": self._get_perms(obj, "view", "users"),
|
||||||
obj,
|
"groups": self._get_perms(obj, "view", "groups"),
|
||||||
only_with_perms_in=[view_codename],
|
|
||||||
with_group_users=False,
|
|
||||||
).values_list("id", flat=True),
|
|
||||||
"groups": get_groups_with_only_permission(
|
|
||||||
obj,
|
|
||||||
codename=view_codename,
|
|
||||||
).values_list("id", flat=True),
|
|
||||||
},
|
},
|
||||||
"change": {
|
"change": {
|
||||||
"users": get_users_with_perms(
|
"users": self._get_perms(obj, "change", "users"),
|
||||||
obj,
|
"groups": self._get_perms(obj, "change", "groups"),
|
||||||
only_with_perms_in=[change_codename],
|
|
||||||
with_group_users=False,
|
|
||||||
).values_list("id", flat=True),
|
|
||||||
"groups": get_groups_with_only_permission(
|
|
||||||
obj,
|
|
||||||
codename=change_codename,
|
|
||||||
).values_list("id", flat=True),
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -1,17 +1,23 @@
|
|||||||
|
import json
|
||||||
import tempfile
|
import tempfile
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
|
from django.contrib.auth.models import Group
|
||||||
from django.contrib.auth.models import Permission
|
from django.contrib.auth.models import Permission
|
||||||
from django.contrib.auth.models import User
|
from django.contrib.auth.models import User
|
||||||
|
from django.db import connection
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
from django.test import override_settings
|
from django.test import override_settings
|
||||||
|
from django.test.utils import CaptureQueriesContext
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
|
from guardian.shortcuts import assign_perm
|
||||||
from rest_framework import status
|
from rest_framework import status
|
||||||
|
|
||||||
from documents.models import Document
|
from documents.models import Document
|
||||||
from documents.models import ShareLink
|
from documents.models import ShareLink
|
||||||
|
from documents.models import Tag
|
||||||
from documents.tests.utils import DirectoriesMixin
|
from documents.tests.utils import DirectoriesMixin
|
||||||
from paperless.models import ApplicationConfiguration
|
from paperless.models import ApplicationConfiguration
|
||||||
|
|
||||||
@@ -154,3 +160,113 @@ class TestViews(DirectoriesMixin, TestCase):
|
|||||||
response.render()
|
response.render()
|
||||||
self.assertEqual(response.request["PATH_INFO"], "/accounts/login/")
|
self.assertEqual(response.request["PATH_INFO"], "/accounts/login/")
|
||||||
self.assertContains(response, b"Share link has expired")
|
self.assertContains(response, b"Share link has expired")
|
||||||
|
|
||||||
|
def test_list_with_full_permissions(self):
    """
    GIVEN:
        - Tags with different permissions
    WHEN:
        - Request to get tag list with full permissions is made
    THEN:
        - Tag list is returned with the right permission information
    """
    user2 = User.objects.create(username="user2")
    user3 = User.objects.create(username="user3")
    group1 = Group.objects.create(name="group1")
    group2 = Group.objects.create(name="group2")
    group3 = Group.objects.create(name="group3")
    t1 = Tag.objects.create(name="invoice", pk=1)
    assign_perm("view_tag", self.user, t1)
    assign_perm("view_tag", user2, t1)
    assign_perm("view_tag", user3, t1)
    assign_perm("view_tag", group1, t1)
    assign_perm("view_tag", group2, t1)
    assign_perm("view_tag", group3, t1)
    assign_perm("change_tag", self.user, t1)
    assign_perm("change_tag", user2, t1)
    assign_perm("change_tag", group1, t1)
    assign_perm("change_tag", group2, t1)

    Tag.objects.create(name="bank statement", pk=2)
    d1 = Document.objects.create(
        title="Invoice 1",
        content="This is the invoice of a very expensive item",
        checksum="A",
    )
    d1.tags.add(t1)
    d2 = Document.objects.create(
        title="Invoice 2",
        content="Internet invoice, I should pay it to continue contributing",
        checksum="B",
    )
    d2.tags.add(t1)

    view_permissions = Permission.objects.filter(
        codename__contains="view_tag",
    )
    self.user.user_permissions.add(*view_permissions)
    self.user.save()

    self.client.force_login(self.user)
    response = self.client.get("/api/tags/?page=1&full_perms=true")
    self.assertEqual(response.status_code, status.HTTP_200_OK)
    results = json.loads(response.content)["results"]

    # Exactly the two tags created above must be listed; otherwise the
    # per-tag checks below could pass without checking anything
    self.assertCountEqual(
        [tag["name"] for tag in results],
        ["invoice", "bank statement"],
    )
    permissions_by_tag = {tag["name"]: tag["permissions"] for tag in results}

    expected = {
        "invoice": {
            "view": {
                "users": [self.user.pk, user2.pk, user3.pk],
                "groups": [group1.pk, group2.pk, group3.pk],
            },
            "change": {
                "users": [self.user.pk, user2.pk],
                "groups": [group1.pk, group2.pk],
            },
        },
        "bank statement": {
            "view": {"users": [], "groups": []},
            "change": {"users": [], "groups": []},
        },
    }
    for name, expected_perms in expected.items():
        actual_perms = permissions_by_tag[name]
        self.assertEqual(actual_perms.keys(), expected_perms.keys())
        for action, expected_holders in expected_perms.items():
            self.assertEqual(
                actual_perms[action].keys(),
                expected_holders.keys(),
            )
            for target, expected_ids in expected_holders.items():
                with self.subTest(tag=name, action=action, target=target):
                    # The ids are compared without regard to order, since
                    # this test does not control the order they come back in
                    self.assertCountEqual(
                        actual_perms[action][target],
                        expected_ids,
                    )
|
||||||
|
|
||||||
|
def test_list_no_n_plus_1_queries(self):
    """
    GIVEN:
        - A tag list that grows from 2 to 50 tags
    WHEN:
        - Request to get tag list with full permissions is made for each size
    THEN:
        - Permissions are not queried in database tag by tag,
          i.e. there are no N+1 queries
    """
    self.user.user_permissions.add(
        *Permission.objects.filter(codename__contains="view_tag"),
    )
    self.user.save()
    self.client.force_login(self.user)

    def create_tags(indices) -> None:
        # One tag per index, named after it
        for index in indices:
            Tag.objects.create(name=f"tag_{index}")

    def count_list_queries() -> int:
        # Number of SQL queries issued by one successful list request
        with CaptureQueriesContext(connection) as captured:
            response = self.client.get("/api/tags/?full_perms=true")
            assert response.status_code == 200
        return len(captured.captured_queries)

    # Start by a small list, and count the number of SQL queries
    create_tags(range(2))
    num_queries_small = count_list_queries()

    # Complete the list, and count the number of SQL queries again
    create_tags(range(2, 50))
    num_queries_large = count_list_queries()

    # A few additional queries are allowed, but not a linear explosion
    assert num_queries_large <= num_queries_small + 5, (
        f"Possible N+1 queries detected: {num_queries_small} queries for 2 tags, "
        f"but {num_queries_large} queries for 50 tags"
    )
|
||||||
|
@@ -5,9 +5,11 @@ import platform
|
|||||||
import re
|
import re
|
||||||
import tempfile
|
import tempfile
|
||||||
import zipfile
|
import zipfile
|
||||||
|
from collections import defaultdict
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from time import mktime
|
from time import mktime
|
||||||
|
from typing import Literal
|
||||||
from unicodedata import normalize
|
from unicodedata import normalize
|
||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
@@ -19,6 +21,7 @@ from celery import states
|
|||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.contrib.auth.models import Group
|
from django.contrib.auth.models import Group
|
||||||
from django.contrib.auth.models import User
|
from django.contrib.auth.models import User
|
||||||
|
from django.contrib.contenttypes.models import ContentType
|
||||||
from django.db import connections
|
from django.db import connections
|
||||||
from django.db.migrations.loader import MigrationLoader
|
from django.db.migrations.loader import MigrationLoader
|
||||||
from django.db.migrations.recorder import MigrationRecorder
|
from django.db.migrations.recorder import MigrationRecorder
|
||||||
@@ -56,6 +59,8 @@ from drf_spectacular.utils import OpenApiParameter
|
|||||||
from drf_spectacular.utils import extend_schema
|
from drf_spectacular.utils import extend_schema
|
||||||
from drf_spectacular.utils import extend_schema_view
|
from drf_spectacular.utils import extend_schema_view
|
||||||
from drf_spectacular.utils import inline_serializer
|
from drf_spectacular.utils import inline_serializer
|
||||||
|
from guardian.utils import get_group_obj_perms_model
|
||||||
|
from guardian.utils import get_user_obj_perms_model
|
||||||
from langdetect import detect
|
from langdetect import detect
|
||||||
from packaging import version as packaging_version
|
from packaging import version as packaging_version
|
||||||
from redis import Redis
|
from redis import Redis
|
||||||
@@ -254,7 +259,101 @@ class PassUserMixin(GenericAPIView):
|
|||||||
return super().get_serializer(*args, **kwargs)
|
return super().get_serializer(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class PermissionsAwareDocumentCountMixin(PassUserMixin):
|
class BulkPermissionMixin:
    """
    Prefetch Django-Guardian permissions for a list before serialization, to avoid N+1 queries.

    The prefetched ids are placed in the serializer context under the keys
    ``users_view_perms``, ``users_change_perms``, ``groups_view_perms`` and
    ``groups_change_perms``, each mapping an object pk to a list of ids.
    """

    def get_permission_codenames(self):
        """
        Return the 'view' and 'change' permission codenames of the view's model.
        """
        model_name = self.queryset.model.__name__.lower()
        return {
            "view": f"view_{model_name}",
            "change": f"change_{model_name}",
        }

    def _get_object_perms(
        self,
        objects: list,
        perm_codenames: list[str],
        actor: Literal["users", "groups"],
    ) -> dict[int, dict[str, list[int]]]:
        """
        Collect object-level permissions for either users or groups.

        :param objects: The objects to look permissions up for
        :param perm_codenames: The full permission codenames to collect
        :param actor: 'users' or 'groups'
        :return: ``{object pk: {codename: [user or group ids]}}`` with an
            entry for every given object and codename, even if empty
        """
        model = self.queryset.model
        if actor == "users":
            obj_perm_model = get_user_obj_perms_model(model)
            id_field = "user_id"
        else:
            obj_perm_model = get_group_obj_perms_model(model)
            id_field = "group_id"
        ctype = ContentType.objects.get_for_model(model)
        object_pks = [obj.pk for obj in objects]

        perms_qs = obj_perm_model.objects.filter(
            content_type=ctype,
            object_pk__in=object_pks,
            permission__codename__in=perm_codenames,
        ).values_list("object_pk", id_field, "permission__codename")

        # Ensure that all objects have all codenames, even if empty
        perms: dict[int, dict[str, list[int]]] = {
            pk: {codename: [] for codename in perm_codenames} for pk in object_pks
        }
        # The filter above only yields rows for the pks and codenames seeded
        # here; object_pk is converted with int() to match the objects' pks
        for object_pk, actor_id, codename in perms_qs:
            perms[int(object_pk)][codename].append(actor_id)

        return perms

    def get_serializer_context(self):
        """
        Get all permissions of the current list of objects at once and pass them to the serializer.
        This avoids fetching permissions object by object in database.

        Nothing is prefetched unless the request asks for ``full_perms``, and
        nothing is prefetched on routes that address a single object.
        """
        context = super().get_serializer_context()

        # A route addressing a single object has no page, so the fallback
        # below would load permissions for the entire filtered queryset to
        # serialize one object.
        # NOTE(review): the context is presumably built before a write is
        # saved, so prefetched ids could be stale in the response - confirm.
        lookup_kwarg = getattr(self, "lookup_url_kwarg", None) or getattr(
            self,
            "lookup_field",
            None,
        )
        if lookup_kwarg in (getattr(self, "kwargs", None) or {}):
            return context

        try:
            full_perms = get_boolean(
                str(self.request.query_params.get("full_perms", "false")),
            )
        except ValueError:
            full_perms = False

        if not full_perms:
            return context

        # Check which objects are being paginated
        paginator = getattr(self, "paginator", None)
        if paginator and hasattr(paginator, "page"):
            objects = paginator.page.object_list
        elif hasattr(self, "page"):
            objects = self.page
        else:
            objects = self.filter_queryset(self.get_queryset())

        codenames = self.get_permission_codenames()
        perm_names = [codenames["view"], codenames["change"]]

        for actor in ("users", "groups"):
            actor_perms = self._get_object_perms(objects, perm_names, actor=actor)
            for action in ("view", "change"):
                context[f"{actor}_{action}_perms"] = {
                    pk: by_codename[codenames[action]]
                    for pk, by_codename in actor_perms.items()
                }

        return context
|
||||||
|
|
||||||
|
|
||||||
|
class PermissionsAwareDocumentCountMixin(BulkPermissionMixin, PassUserMixin):
|
||||||
"""
|
"""
|
||||||
Mixin to add document count to queryset, permissions-aware if needed
|
Mixin to add document count to queryset, permissions-aware if needed
|
||||||
"""
|
"""
|
||||||
|
Reference in New Issue
Block a user