Compare commits

..

3 Commits

Author SHA1 Message Date
shamoon
61e0bd7eb6 A ChatGPT script for testing 2025-12-30 10:24:53 -08:00
shamoon
d127361411 Optimize tag children retrieval 2025-12-30 10:24:53 -08:00
shamoon
d45dee6d39 Run Tag tree updates once per transaction 2025-12-30 10:24:53 -08:00
7 changed files with 284 additions and 21 deletions

View File

@@ -32,7 +32,7 @@ RUN set -eux \
# Purpose: Installs s6-overlay and rootfs # Purpose: Installs s6-overlay and rootfs
# Comments: # Comments:
# - Don't leave anything extra in here either # - Don't leave anything extra in here either
FROM ghcr.io/astral-sh/uv:0.9.19-python3.12-trixie-slim AS s6-overlay-base FROM ghcr.io/astral-sh/uv:0.9.15-python3.12-trixie-slim AS s6-overlay-base
WORKDIR /usr/src/s6 WORKDIR /usr/src/s6

139
scripts/tag_perf_probe.py Normal file
View File

@@ -0,0 +1,139 @@
# noqa: INP001
"""
Ad-hoc script to gauge Tag + treenode performance locally.
It bootstraps a fresh SQLite DB in a temp folder (or PAPERLESS_DATA_DIR),
uses locmem cache/redis to avoid external services, creates synthetic tags,
and measures:
- creation time
- query count and wall time for the Tag list view
Usage:
PAPERLESS_DEBUG=1 PAPERLESS_REDIS=locmem:// PYTHONPATH=src \
PAPERLESS_DATA_DIR=/tmp/paperless-tags-probe \
.venv/bin/python scripts/tag_perf_probe.py
"""
import os
import sys
import time
from collections.abc import Iterable
from contextlib import contextmanager
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "paperless.settings")
os.environ.setdefault("PAPERLESS_DEBUG", "1")
os.environ.setdefault("PAPERLESS_REDIS", "locmem://")
os.environ.setdefault("PYTHONPATH", "src")
import django
django.setup()
from django.contrib.auth import get_user_model # noqa: E402
from django.core.management import call_command # noqa: E402
from django.db import connection # noqa: E402
from django.test.client import RequestFactory # noqa: E402
from rest_framework.test import force_authenticate # noqa: E402
from treenode.signals import no_signals # noqa: E402
from documents.models import Tag # noqa: E402
from documents.views import TagViewSet # noqa: E402
User = get_user_model()
@contextmanager
def count_queries():
    """
    Count the statements executed through ``django.db.connection``.

    Installs an execute wrapper for the duration of the ``with`` block and
    yields a zero-argument callable that returns how many statements have
    passed through the wrapper so far.
    """
    # Single-slot list so the inner function can mutate the tally in place.
    tally = [0]

    def counting_wrapper(execute, sql, params, many, context):
        tally[0] += 1
        # Delegate unchanged so the wrapped statement still runs normally.
        return execute(sql, params, many, context)

    with connection.execute_wrapper(counting_wrapper):
        yield lambda: tally[0]
def measure_list(tag_count: int, user) -> tuple[int, float]:
    """
    Render the Tag list endpoint once and report its cost.

    The request asks for ``page_size=tag_count`` and is authenticated as
    ``user``. Returns ``(queries, seconds)``: the number of statements counted
    by ``count_queries`` and the wall time spent calling and rendering the view.
    """
    list_view = TagViewSet.as_view({"get": "list"})
    request = RequestFactory().get("/api/tags/", {"page_size": tag_count})
    force_authenticate(request, user=user)

    with count_queries() as executed_so_far:
        began = time.perf_counter()
        # Rendering is included in the timing on purpose: serialization is
        # part of what the probe measures.
        list_view(request).render()
        seconds = time.perf_counter() - began
        query_total = executed_so_far()

    return query_total, seconds
def bulk_create_tags(count: int, parents: Iterable[Tag] | None = None) -> None:
    """
    Bulk-insert synthetic tags.

    Without ``parents``, creates ``count`` tags named ``Flat 0`` .. ``Flat
    count-1``. With ``parents``, ``count`` is ignored and exactly one child
    (named after the parent's id) is created per parent.
    """
    if parents is not None:
        new_tags = [Tag(name=f"Child {p.id}", tn_parent=p) for p in parents]
    else:
        new_tags = [Tag(name=f"Flat {i}") for i in range(count)]
    Tag.objects.bulk_create(new_tags)
def run():
    """
    Execute the three probe scenarios and print one result line for each.

    Applies migrations, makes sure an ``admin`` user exists, then for each
    scenario wipes all tags, times tag creation, and measures the Tag list
    view with a page size equal to the number of tags created.
    """
    # Ensure tables exist when pointing at a fresh DATA_DIR.
    call_command("migrate", interactive=False, verbosity=0)

    admin_defaults = {"is_superuser": True, "is_staff": True}
    user, _ = User.objects.get_or_create(username="admin", defaults=admin_defaults)

    # Flat scenario
    Tag.objects.all().delete()
    began = time.perf_counter()
    bulk_create_tags(200)
    flat_create = time.perf_counter() - began
    q, t = measure_list(tag_count=200, user=user)
    print(f"Flat create 200 -> {flat_create:.2f}s; list -> {q} queries, {t:.2f}s")  # noqa: T201

    # Nested scenario (200 parents + 2 children each => 600 total)
    Tag.objects.all().delete()
    began = time.perf_counter()
    with no_signals():  # avoid per-save tree rebuild; rebuild once
        parents = Tag.objects.bulk_create(
            [Tag(name=f"Parent {i}") for i in range(200)],
        )
        children = [
            Tag(name=f"Child {p.id}-{j}", tn_parent=p)
            for p in parents
            for j in range(2)
        ]
        Tag.objects.bulk_create(children)
    Tag.update_tree()
    nested_create = time.perf_counter() - began
    q, t = measure_list(tag_count=600, user=user)
    print(f"Nested create 600 -> {nested_create:.2f}s; list -> {q} queries, {t:.2f}s")  # noqa: T201

    # Larger nested scenario (1 child per parent, 3000 total)
    Tag.objects.all().delete()
    began = time.perf_counter()
    with no_signals():
        parents = Tag.objects.bulk_create(
            [Tag(name=f"Parent {i}") for i in range(1500)],
        )
        # The count argument is unused when parents are supplied.
        bulk_create_tags(0, parents=parents)
    Tag.update_tree()
    big_create = time.perf_counter() - began
    q, t = measure_list(tag_count=3000, user=user)
    print(f"Nested create 3000 -> {big_create:.2f}s; list -> {q} queries, {t:.2f}s")  # noqa: T201
if __name__ == "__main__":
    # Refuse to run when "runserver" shows up in argv and point the user at
    # the direct invocation instead.
    invoked_with_runserver = "runserver" in sys.argv
    if invoked_with_runserver:
        print("Run directly: .venv/bin/python scripts/tag_perf_probe.py")  # noqa: T201
        sys.exit(1)
    run()

View File

@@ -1,5 +1,9 @@
from django.apps import AppConfig from django.apps import AppConfig
from django.db.models.signals import post_delete
from django.db.models.signals import post_save
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from treenode.signals import post_delete_treenode
from treenode.signals import post_save_treenode
class DocumentsConfig(AppConfig): class DocumentsConfig(AppConfig):
@@ -8,12 +12,14 @@ class DocumentsConfig(AppConfig):
verbose_name = _("Documents") verbose_name = _("Documents")
def ready(self): def ready(self):
from documents.models import Tag
from documents.signals import document_consumption_finished from documents.signals import document_consumption_finished
from documents.signals import document_updated from documents.signals import document_updated
from documents.signals.handlers import add_inbox_tags from documents.signals.handlers import add_inbox_tags
from documents.signals.handlers import add_to_index from documents.signals.handlers import add_to_index
from documents.signals.handlers import run_workflows_added from documents.signals.handlers import run_workflows_added
from documents.signals.handlers import run_workflows_updated from documents.signals.handlers import run_workflows_updated
from documents.signals.handlers import schedule_tag_tree_update
from documents.signals.handlers import set_correspondent from documents.signals.handlers import set_correspondent
from documents.signals.handlers import set_document_type from documents.signals.handlers import set_document_type
from documents.signals.handlers import set_storage_path from documents.signals.handlers import set_storage_path
@@ -28,6 +34,29 @@ class DocumentsConfig(AppConfig):
document_consumption_finished.connect(run_workflows_added) document_consumption_finished.connect(run_workflows_added)
document_updated.connect(run_workflows_updated) document_updated.connect(run_workflows_updated)
# treenode updates the entire tree on every save/delete via hooks
# so disconnect for Tags and run once-per-transaction.
post_save.disconnect(
post_save_treenode,
sender=Tag,
dispatch_uid="post_save_treenode",
)
post_delete.disconnect(
post_delete_treenode,
sender=Tag,
dispatch_uid="post_delete_treenode",
)
post_save.connect(
schedule_tag_tree_update,
sender=Tag,
dispatch_uid="paperless_tag_mark_dirty_save",
)
post_delete.connect(
schedule_tag_tree_update,
sender=Tag,
dispatch_uid="paperless_tag_mark_dirty_delete",
)
import documents.schema # noqa: F401 import documents.schema # noqa: F401
AppConfig.ready(self) AppConfig.ready(self)

View File

@@ -580,6 +580,10 @@ class TagSerializer(MatchingModelSerializer, OwnedObjectSerializer):
), ),
) )
def get_children(self, obj): def get_children(self, obj):
children_map = self.context.get("children_map")
if children_map is not None:
children = children_map.get(obj.pk, [])
else:
filter_q = self.context.get("document_count_filter") filter_q = self.context.get("document_count_filter")
request = self.context.get("request") request = self.context.get("request")
if filter_q is None: if filter_q is None:
@@ -587,7 +591,7 @@ class TagSerializer(MatchingModelSerializer, OwnedObjectSerializer):
filter_q = get_document_count_filter_for_user(user) filter_q = get_document_count_filter_for_user(user)
self.context["document_count_filter"] = filter_q self.context["document_count_filter"] = filter_q
children_queryset = ( children = (
obj.get_children_queryset() obj.get_children_queryset()
.select_related("owner") .select_related("owner")
.annotate(document_count=Count("documents", filter=filter_q)) .annotate(document_count=Count("documents", filter=filter_q))
@@ -595,15 +599,15 @@ class TagSerializer(MatchingModelSerializer, OwnedObjectSerializer):
view = self.context.get("view") view = self.context.get("view")
ordering = ( ordering = (
OrderingFilter().get_ordering(request, children_queryset, view) OrderingFilter().get_ordering(request, children, view)
if request and view if request and view
else None else None
) )
ordering = ordering or (Lower("name"),) ordering = ordering or (Lower("name"),)
children_queryset = children_queryset.order_by(*ordering) children = children.order_by(*ordering)
serializer = TagSerializer( serializer = TagSerializer(
children_queryset, children,
many=True, many=True,
user=self.user, user=self.user,
full_perms=self.full_perms, full_perms=self.full_perms,

View File

@@ -19,6 +19,7 @@ from django.db import DatabaseError
from django.db import close_old_connections from django.db import close_old_connections
from django.db import connections from django.db import connections
from django.db import models from django.db import models
from django.db import transaction
from django.db.models import Q from django.db.models import Q
from django.dispatch import receiver from django.dispatch import receiver
from django.utils import timezone from django.utils import timezone
@@ -60,6 +61,8 @@ if TYPE_CHECKING:
logger = logging.getLogger("paperless.handlers") logger = logging.getLogger("paperless.handlers")
_tag_tree_update_scheduled = False
def add_inbox_tags(sender, document: Document, logging_group=None, **kwargs): def add_inbox_tags(sender, document: Document, logging_group=None, **kwargs):
if document.owner is not None: if document.owner is not None:
@@ -944,3 +947,26 @@ def close_connection_pool_on_worker_init(**kwargs):
for conn in connections.all(initialized_only=True): for conn in connections.all(initialized_only=True):
if conn.alias == "default" and hasattr(conn, "pool") and conn.pool: if conn.alias == "default" and hasattr(conn, "pool") and conn.pool:
conn.close_pool() conn.close_pool()
def _run_scheduled_tag_tree_update():
    """Rebuild the Tag tree once and clear the module-level scheduled marker."""
    global _tag_tree_update_scheduled
    try:
        Tag.update_tree()
    finally:
        # Always clear the marker, even if the rebuild raised, so the next
        # save can schedule again.
        _tag_tree_update_scheduled = False


def schedule_tag_tree_update(**_kwargs):
    """
    Schedule a single Tag.update_tree() at transaction commit.

    Treenode's default post_save hooks rebuild the entire tree on every save,
    which is very slow for large tag sets, so collapse to one update per
    transaction.

    The module-level ``_tag_tree_update_scheduled`` flag alone is not a safe
    guard: it is only cleared when the commit callback runs. A rolled-back
    transaction discards the callback and would leave the flag set forever,
    and the flag is shared by every thread/connection in the process. The
    flag is therefore trusted only when this connection still has the rebuild
    callback queued; otherwise a fresh callback is registered.

    Accepts and ignores arbitrary signal keyword arguments.
    """
    global _tag_tree_update_scheduled

    if _tag_tree_update_scheduled:
        conn = transaction.get_connection()
        # Each pending entry is a tuple whose second item is the callback
        # (true for both the 2- and 3-tuple layouts Django has used).
        pending_here = any(
            entry[1] is _run_scheduled_tag_tree_update
            for entry in getattr(conn, "run_on_commit", ())
        )
        if pending_here:
            return

    _tag_tree_update_scheduled = True
    # Outside a transaction (autocommit) Django invokes the callback
    # immediately, which also resets the flag right away.
    transaction.on_commit(_run_scheduled_tag_tree_update)

View File

@@ -1,6 +1,7 @@
from unittest import mock from unittest import mock
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.db import transaction
from rest_framework.test import APITestCase from rest_framework.test import APITestCase
from documents import bulk_edit from documents import bulk_edit
@@ -10,6 +11,7 @@ from documents.models import Workflow
from documents.models import WorkflowAction from documents.models import WorkflowAction
from documents.models import WorkflowTrigger from documents.models import WorkflowTrigger
from documents.serialisers import TagSerializer from documents.serialisers import TagSerializer
from documents.signals import handlers
from documents.signals.handlers import run_workflows from documents.signals.handlers import run_workflows
@@ -250,3 +252,31 @@ class TestTagHierarchy(APITestCase):
row for row in response.data["results"] if row["id"] == self.parent.pk row for row in response.data["results"] if row["id"] == self.parent.pk
) )
assert any(child["id"] == self.child.pk for child in parent_entry["children"]) assert any(child["id"] == self.child.pk for child in parent_entry["children"])
def test_tag_tree_deferred_update_runs_on_commit(self):
    """
    Creating tags in a transaction populates the tn_* fields once the
    deferred tree update has run.
    """
    # A marker left set by earlier saves in this process (whose commit
    # callbacks never ran) would suppress scheduling; start from a clean state.
    handlers._tag_tree_update_scheduled = False

    # The test case wraps every test in a transaction that is never
    # committed, so on_commit callbacks must be captured and executed
    # explicitly for the deferred update to actually run.
    with self.captureOnCommitCallbacks(execute=True):
        with transaction.atomic():
            parent = Tag.objects.create(name="Parent 2")
            child = Tag.objects.create(name="Child 2", tn_parent=parent)

    # After the callbacks ran, tn_* fields should be populated.
    parent.refresh_from_db()
    child.refresh_from_db()
    assert parent.tn_children_count == 1
    assert child.tn_ancestors_count == 1
def test_tag_tree_update_runs_once_per_transaction(self):
    """
    Two schedule calls inside one transaction register a single commit
    callback, and the tree rebuild runs exactly once when it is executed.
    """
    # Start from a known state regardless of what earlier saves left behind.
    handlers._tag_tree_update_scheduled = False

    patched_rebuild = mock.patch("documents.signals.handlers.Tag.update_tree")
    with patched_rebuild as update_tree:
        with (
            self.captureOnCommitCallbacks(execute=True) as callbacks,
            transaction.atomic(),
        ):
            for _ in range(2):
                handlers.schedule_tag_tree_update()
            # Nothing may run before the callbacks are executed.
            update_tree.assert_not_called()
            assert handlers._tag_tree_update_scheduled is True

        assert len(callbacks) == 1
        update_tree.assert_called_once()
        assert handlers._tag_tree_update_scheduled is False

View File

@@ -448,8 +448,43 @@ class TagViewSet(ModelViewSet, PermissionsAwareDocumentCountMixin):
def get_serializer_context(self): def get_serializer_context(self):
context = super().get_serializer_context() context = super().get_serializer_context()
context["document_count_filter"] = self.get_document_count_filter() context["document_count_filter"] = self.get_document_count_filter()
if hasattr(self, "_children_map"):
context["children_map"] = self._children_map
return context return context
def list(self, request, *args, **kwargs):
    """
    Build a children map once to avoid per-parent queries in the serializer.

    The filtered queryset is ordered by the requested ordering (falling back
    to lower-cased name), every tag plus all of its descendants is loaded in
    one annotated query, and the results are grouped by ``tn_parent_id`` into
    ``self._children_map``. The response is paginated when a paginator is
    active and returned as a plain list otherwise.
    """
    # Imported locally so this method does not rely on a module-level import.
    from rest_framework.response import Response

    queryset = self.filter_queryset(self.get_queryset())
    ordering = OrderingFilter().get_ordering(request, queryset, self) or (
        Lower("name"),
    )
    queryset = queryset.order_by(*ordering)

    # `list` here is the builtin; the method name is not in scope in its body.
    all_tags = list(queryset)
    descendant_pks = {pk for tag in all_tags for pk in tag.get_descendants_pks()}

    if descendant_pks:
        filter_q = self.get_document_count_filter()
        listed_pks = {tag.pk for tag in all_tags}
        children_source = (
            Tag.objects.filter(pk__in=descendant_pks | listed_pks)
            .select_related("owner")
            .annotate(document_count=Count("documents", filter=filter_q))
            .order_by(*ordering)
        )
    else:
        # No nesting among the listed tags: reuse the rows already loaded.
        children_source = all_tags

    children_map = {}
    for tag in children_source:
        children_map.setdefault(tag.tn_parent_id, []).append(tag)
    # Must be set before get_serializer() so the serializer context sees it.
    self._children_map = children_map

    page = self.paginate_queryset(queryset)
    if page is None:
        # Pagination disabled: get_paginated_response() would fail, so
        # return the full, already-evaluated result set instead.
        serializer = self.get_serializer(all_tags, many=True)
        return Response(serializer.data)

    serializer = self.get_serializer(page, many=True)
    return self.get_paginated_response(serializer.data)
def perform_update(self, serializer): def perform_update(self, serializer):
old_parent = self.get_object().get_parent() old_parent = self.get_object().get_parent()
tag = serializer.save() tag = serializer.save()