Fix: retrieve document_count for tag children (#11125)

This commit is contained in:
shamoon
2025-10-22 11:13:15 -07:00
committed by GitHub
parent 0ebd9f24b5
commit 13161ebb01
4 changed files with 71 additions and 15 deletions

View File

@@ -2,6 +2,7 @@ from django.contrib.auth.models import Group
from django.contrib.auth.models import Permission
from django.contrib.auth.models import User
from django.contrib.contenttypes.models import ContentType
from django.db.models import Q
from django.db.models import QuerySet
from guardian.core import ObjectPermissionChecker
from guardian.models import GroupObjectPermission
@@ -12,6 +13,8 @@ from guardian.shortcuts import remove_perm
from rest_framework.permissions import BasePermission
from rest_framework.permissions import DjangoObjectPermissions
from documents.models import Document
class PaperlessObjectPermissions(DjangoObjectPermissions):
"""
@@ -125,6 +128,25 @@ def set_permissions_for_object(permissions: list[str], object, *, merge: bool =
)
def get_document_count_filter_for_user(user):
"""
Return the Q object used to filter document counts for the given user.
"""
if user is None or not getattr(user, "is_authenticated", False):
return Q(documents__deleted_at__isnull=True, documents__owner__isnull=True)
if getattr(user, "is_superuser", False):
return Q(documents__deleted_at__isnull=True)
return Q(
documents__deleted_at__isnull=True,
documents__id__in=get_objects_for_user_owner_aware(
user,
"documents.view_document",
Document,
).values_list("id", flat=True),
)
def get_objects_for_user_owner_aware(user, perms, Model) -> QuerySet:
objects_owned = Model.objects.filter(owner=user)
objects_unowned = Model.objects.filter(owner__isnull=True)

View File

@@ -20,6 +20,7 @@ from django.core.validators import EmailValidator
from django.core.validators import MaxLengthValidator
from django.core.validators import RegexValidator
from django.core.validators import integer_validator
from django.db.models import Count
from django.utils.crypto import get_random_string
from django.utils.dateparse import parse_datetime
from django.utils.text import slugify
@@ -65,6 +66,7 @@ from documents.models import WorkflowActionEmail
from documents.models import WorkflowActionWebhook
from documents.models import WorkflowTrigger
from documents.parsers import is_mime_type_supported
from documents.permissions import get_document_count_filter_for_user
from documents.permissions import get_groups_with_only_permission
from documents.permissions import set_permissions_for_object
from documents.templating.filepath import validate_filepath_template_and_render
@@ -572,8 +574,16 @@ class TagSerializer(MatchingModelSerializer, OwnedObjectSerializer):
),
)
def get_children(self, obj):
filter_q = self.context.get("document_count_filter")
if filter_q is None:
request = self.context.get("request")
user = getattr(request, "user", None) if request else None
filter_q = get_document_count_filter_for_user(user)
self.context["document_count_filter"] = filter_q
serializer = TagSerializer(
obj.get_children(),
obj.get_children_queryset()
.select_related("owner")
.annotate(document_count=Count("documents", filter=filter_q)),
many=True,
context=self.context,
)

View File

@@ -9,6 +9,7 @@ from documents.models import Tag
from documents.models import Workflow
from documents.models import WorkflowAction
from documents.models import WorkflowTrigger
from documents.serialisers import TagSerializer
from documents.signals.handlers import run_workflows
@@ -121,6 +122,31 @@ class TestTagHierarchy(APITestCase):
tags = set(self.document.tags.values_list("pk", flat=True))
assert tags == {self.parent.pk, orphan.pk}
def test_child_document_count_included_when_parent_paginated(self):
self.document.tags.add(self.child)
response = self.client.get(
"/api/tags/",
{"page_size": 1, "ordering": "-name"},
)
assert response.status_code == 200
assert response.data["results"][0]["id"] == self.parent.pk
children = response.data["results"][0]["children"]
assert len(children) == 1
child_entry = children[0]
assert child_entry["id"] == self.child.pk
assert child_entry["document_count"] == 1
def test_tag_serializer_populates_document_filter_context(self):
context = {}
serializer = TagSerializer(self.parent, context=context)
assert serializer.data # triggers serialization
assert "document_count_filter" in context
def test_cannot_set_parent_to_self(self):
tag = Tag.objects.create(name="Selfie")
resp = self.client.patch(

View File

@@ -141,6 +141,7 @@ from documents.permissions import AcknowledgeTasksPermissions
from documents.permissions import PaperlessAdminPermissions
from documents.permissions import PaperlessNotePermissions
from documents.permissions import PaperlessObjectPermissions
from documents.permissions import get_document_count_filter_for_user
from documents.permissions import get_objects_for_user_owner_aware
from documents.permissions import has_perms_owner_aware
from documents.permissions import set_permissions_for_object
@@ -364,21 +365,13 @@ class PermissionsAwareDocumentCountMixin(BulkPermissionMixin, PassUserMixin):
Mixin to add document count to queryset, permissions-aware if needed
"""
def get_document_count_filter(self):
request = getattr(self, "request", None)
user = getattr(request, "user", None) if request else None
return get_document_count_filter_for_user(user)
def get_queryset(self):
filter = (
Q(documents__deleted_at__isnull=True)
if self.request.user is None or self.request.user.is_superuser
else (
Q(
documents__deleted_at__isnull=True,
documents__id__in=get_objects_for_user_owner_aware(
self.request.user,
"documents.view_document",
Document,
).values_list("id", flat=True),
)
)
)
filter = self.get_document_count_filter()
return (
super()
.get_queryset()
@@ -447,6 +440,11 @@ class TagViewSet(ModelViewSet, PermissionsAwareDocumentCountMixin):
filterset_class = TagFilterSet
ordering_fields = ("color", "name", "matching_algorithm", "match", "document_count")
def get_serializer_context(self):
context = super().get_serializer_context()
context["document_count_filter"] = self.get_document_count_filter()
return context
def perform_update(self, serializer):
old_parent = self.get_object().get_parent()
tag = serializer.save()