Compare commits

..

3 Commits

Author SHA1 Message Date
shamoon
016bccdcdf Remove skip_asn stuff 2025-12-31 11:52:31 -08:00
shamoon
92deebddd4 "Handoff" ASN when merging or editing PDFs 2025-12-31 11:50:27 -08:00
shamoon
c7efcee3d6 First, release ASNs before document replacement (and restore if needed) 2025-12-31 10:42:07 -08:00
12 changed files with 133 additions and 317 deletions

View File

@@ -1,139 +0,0 @@
# noqa: INP001
"""
Ad-hoc script to gauge Tag + treenode performance locally.
It bootstraps a fresh SQLite DB in a temp folder (or PAPERLESS_DATA_DIR),
uses locmem cache/redis to avoid external services, creates synthetic tags,
and measures:
- creation time
- query count and wall time for the Tag list view
Usage:
PAPERLESS_DEBUG=1 PAPERLESS_REDIS=locmem:// PYTHONPATH=src \
PAPERLESS_DATA_DIR=/tmp/paperless-tags-probe \
.venv/bin/python scripts/tag_perf_probe.py
"""
import os
import sys
import time
from collections.abc import Iterable
from contextlib import contextmanager
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "paperless.settings")
os.environ.setdefault("PAPERLESS_DEBUG", "1")
os.environ.setdefault("PAPERLESS_REDIS", "locmem://")
os.environ.setdefault("PYTHONPATH", "src")
import django
django.setup()
from django.contrib.auth import get_user_model # noqa: E402
from django.core.management import call_command # noqa: E402
from django.db import connection # noqa: E402
from django.test.client import RequestFactory # noqa: E402
from rest_framework.test import force_authenticate # noqa: E402
from treenode.signals import no_signals # noqa: E402
from documents.models import Tag # noqa: E402
from documents.views import TagViewSet # noqa: E402
User = get_user_model()
@contextmanager
def count_queries():
total = 0
def wrapper(execute, sql, params, many, context):
nonlocal total
total += 1
return execute(sql, params, many, context)
with connection.execute_wrapper(wrapper):
yield lambda: total
def measure_list(tag_count: int, user) -> tuple[int, float]:
"""Render Tag list with page_size=tag_count and return (queries, seconds)."""
rf = RequestFactory()
view = TagViewSet.as_view({"get": "list"})
request = rf.get("/api/tags/", {"page_size": tag_count})
force_authenticate(request, user=user)
with count_queries() as get_count:
start = time.perf_counter()
response = view(request)
response.render()
elapsed = time.perf_counter() - start
total_queries = get_count()
return total_queries, elapsed
def bulk_create_tags(count: int, parents: Iterable[Tag] | None = None) -> None:
"""Create tags; when parents provided, create one child per parent."""
if parents is None:
Tag.objects.bulk_create([Tag(name=f"Flat {i}") for i in range(count)])
return
children = []
for p in parents:
children.append(Tag(name=f"Child {p.id}", tn_parent=p))
Tag.objects.bulk_create(children)
def run():
# Ensure tables exist when pointing at a fresh DATA_DIR.
call_command("migrate", interactive=False, verbosity=0)
user, _ = User.objects.get_or_create(
username="admin",
defaults={"is_superuser": True, "is_staff": True},
)
# Flat scenario
Tag.objects.all().delete()
start = time.perf_counter()
bulk_create_tags(200)
flat_create = time.perf_counter() - start
q, t = measure_list(tag_count=200, user=user)
print(f"Flat create 200 -> {flat_create:.2f}s; list -> {q} queries, {t:.2f}s") # noqa: T201
# Nested scenario (parents + 2 children each => 600 total)
Tag.objects.all().delete()
start = time.perf_counter()
with no_signals(): # avoid per-save tree rebuild; rebuild once
parents = Tag.objects.bulk_create([Tag(name=f"Parent {i}") for i in range(200)])
children = []
for p in parents:
children.extend(
Tag(name=f"Child {p.id}-{j}", tn_parent=p) for j in range(2)
)
Tag.objects.bulk_create(children)
Tag.update_tree()
nested_create = time.perf_counter() - start
q, t = measure_list(tag_count=600, user=user)
print(f"Nested create 600 -> {nested_create:.2f}s; list -> {q} queries, {t:.2f}s") # noqa: T201
# Larger nested scenario (1 child per parent, 3000 total)
Tag.objects.all().delete()
start = time.perf_counter()
with no_signals():
parents = Tag.objects.bulk_create(
[Tag(name=f"Parent {i}") for i in range(1500)],
)
bulk_create_tags(0, parents=parents)
Tag.update_tree()
big_create = time.perf_counter() - start
q, t = measure_list(tag_count=3000, user=user)
print(f"Nested create 3000 -> {big_create:.2f}s; list -> {q} queries, {t:.2f}s") # noqa: T201
if __name__ == "__main__":
if "runserver" in sys.argv:
print("Run directly: .venv/bin/python scripts/tag_perf_probe.py") # noqa: T201
sys.exit(1)
run()

View File

@@ -1,9 +1,5 @@
from django.apps import AppConfig
from django.db.models.signals import post_delete
from django.db.models.signals import post_save
from django.utils.translation import gettext_lazy as _
from treenode.signals import post_delete_treenode
from treenode.signals import post_save_treenode
class DocumentsConfig(AppConfig):
@@ -12,14 +8,12 @@ class DocumentsConfig(AppConfig):
verbose_name = _("Documents")
def ready(self):
from documents.models import Tag
from documents.signals import document_consumption_finished
from documents.signals import document_updated
from documents.signals.handlers import add_inbox_tags
from documents.signals.handlers import add_to_index
from documents.signals.handlers import run_workflows_added
from documents.signals.handlers import run_workflows_updated
from documents.signals.handlers import schedule_tag_tree_update
from documents.signals.handlers import set_correspondent
from documents.signals.handlers import set_document_type
from documents.signals.handlers import set_storage_path
@@ -34,29 +28,6 @@ class DocumentsConfig(AppConfig):
document_consumption_finished.connect(run_workflows_added)
document_updated.connect(run_workflows_updated)
# treenode updates the entire tree on every save/delete via hooks
# so disconnect for Tags and run once-per-transaction.
post_save.disconnect(
post_save_treenode,
sender=Tag,
dispatch_uid="post_save_treenode",
)
post_delete.disconnect(
post_delete_treenode,
sender=Tag,
dispatch_uid="post_delete_treenode",
)
post_save.connect(
schedule_tag_tree_update,
sender=Tag,
dispatch_uid="paperless_tag_mark_dirty_save",
)
post_delete.connect(
schedule_tag_tree_update,
sender=Tag,
dispatch_uid="paperless_tag_mark_dirty_delete",
)
import documents.schema # noqa: F401
AppConfig.ready(self)

View File

@@ -186,11 +186,7 @@ class BarcodePlugin(ConsumeTaskPlugin):
# Update/overwrite an ASN if possible
# After splitting, as otherwise each split document gets the same ASN
if (
self.settings.barcode_enable_asn
and not self.metadata.skip_asn
and (located_asn := self.asn) is not None
):
if self.settings.barcode_enable_asn and (located_asn := self.asn) is not None:
logger.info(f"Found ASN in barcode: {located_asn}")
self.metadata.asn = located_asn

View File

@@ -7,7 +7,6 @@ from pathlib import Path
from typing import TYPE_CHECKING
from typing import Literal
from celery import chain
from celery import chord
from celery import group
from celery import shared_task
@@ -38,6 +37,42 @@ if TYPE_CHECKING:
logger: logging.Logger = logging.getLogger("paperless.bulk_edit")
@shared_task(bind=True)
def restore_archive_serial_numbers_task(
self,
backup: dict[int, int],
*args,
**kwargs,
) -> None:
restore_archive_serial_numbers(backup)
def release_archive_serial_numbers(doc_ids: list[int]) -> dict[int, int]:
"""
Clears ASNs on documents that are about to be replaced so new documents
can be assigned ASNs without uniqueness collisions. Returns a backup map
of doc_id -> previous ASN for potential restoration.
"""
qs = Document.objects.filter(
id__in=doc_ids,
archive_serial_number__isnull=False,
).only("pk", "archive_serial_number")
backup = dict(qs.values_list("pk", "archive_serial_number"))
qs.update(archive_serial_number=None)
logger.info(f"Released archive serial numbers for documents {list(backup.keys())}")
return backup
def restore_archive_serial_numbers(backup: dict[int, int]) -> None:
"""
Restores ASNs using the provided backup map, intended for
rollback when replacement consumption fails.
"""
for doc_id, asn in backup.items():
Document.objects.filter(pk=doc_id).update(archive_serial_number=asn)
logger.info(f"Restored archive serial numbers for documents {list(backup.keys())}")
def set_correspondent(
doc_ids: list[int],
correspondent: Correspondent,
@@ -386,6 +421,7 @@ def merge(
merged_pdf = pikepdf.new()
version: str = merged_pdf.pdf_version
handoff_asn: int | None = None
# use doc_ids to preserve order
for doc_id in doc_ids:
doc = qs.get(id=doc_id)
@@ -401,6 +437,8 @@ def merge(
version = max(version, pdf.pdf_version)
merged_pdf.pages.extend(pdf.pages)
affected_docs.append(doc.id)
if handoff_asn is None and doc.archive_serial_number is not None:
handoff_asn = doc.archive_serial_number
except Exception as e:
logger.exception(
f"Error merging document {doc.id}, it will not be included in the merge: {e}",
@@ -426,6 +464,8 @@ def merge(
DocumentMetadataOverrides.from_document(metadata_document)
)
overrides.title = metadata_document.title + " (merged)"
if metadata_document.archive_serial_number is not None:
handoff_asn = metadata_document.archive_serial_number
else:
overrides = DocumentMetadataOverrides()
else:
@@ -433,8 +473,9 @@ def merge(
if user is not None:
overrides.owner_id = user.id
# Avoid copying or detecting ASN from merged PDFs to prevent collision
overrides.skip_asn = True
if delete_originals and handoff_asn is not None:
overrides.asn = handoff_asn
logger.info("Adding merged document to the task queue.")
@@ -447,12 +488,20 @@ def merge(
)
if delete_originals:
backup = release_archive_serial_numbers(affected_docs)
logger.info(
"Queueing removal of original documents after consumption of merged document",
)
chain(consume_task, delete.si(affected_docs)).delay()
else:
consume_task.delay()
try:
consume_task.apply_async(
link=[delete.si(affected_docs)],
link_error=[restore_archive_serial_numbers_task.s(backup)],
)
except Exception:
restore_archive_serial_numbers(backup)
raise
else:
consume_task.delay()
return "OK"
@@ -508,10 +557,20 @@ def split(
)
if delete_originals:
backup = release_archive_serial_numbers([doc.id])
logger.info(
"Queueing removal of original document after consumption of the split documents",
)
chord(header=consume_tasks, body=delete.si([doc.id])).delay()
try:
chord(
header=consume_tasks,
body=delete.si([doc.id]),
).apply_async(
link_error=[restore_archive_serial_numbers_task.s(backup)],
)
except Exception:
restore_archive_serial_numbers(backup)
raise
else:
group(consume_tasks).delay()
@@ -614,7 +673,8 @@ def edit_pdf(
)
if user is not None:
overrides.owner_id = user.id
if delete_original and len(pdf_docs) == 1:
overrides.asn = doc.archive_serial_number
for idx, pdf in enumerate(pdf_docs, start=1):
filepath: Path = (
Path(tempfile.mkdtemp(dir=settings.SCRATCH_DIR))
@@ -633,7 +693,17 @@ def edit_pdf(
)
if delete_original:
chord(header=consume_tasks, body=delete.si([doc.id])).delay()
backup = release_archive_serial_numbers([doc.id])
try:
chord(
header=consume_tasks,
body=delete.si([doc.id]),
).apply_async(
link_error=[restore_archive_serial_numbers_task.s(backup)],
)
except Exception:
restore_archive_serial_numbers(backup)
raise
else:
group(consume_tasks).delay()

View File

@@ -696,7 +696,7 @@ class ConsumerPlugin(
pk=self.metadata.storage_path_id,
)
if self.metadata.asn is not None and not self.metadata.skip_asn:
if self.metadata.asn is not None:
document.archive_serial_number = self.metadata.asn
if self.metadata.owner_id:
@@ -812,8 +812,8 @@ class ConsumerPreflightPlugin(
"""
Check that if override_asn is given, it is unique and within a valid range
"""
if self.metadata.skip_asn or self.metadata.asn is None:
# if skip is set or ASN is None
if self.metadata.asn is None:
# if ASN is None
return
# Validate the range is above zero and less than uint32_t max
# otherwise, Whoosh can't handle it in the index

View File

@@ -30,7 +30,6 @@ class DocumentMetadataOverrides:
change_users: list[int] | None = None
change_groups: list[int] | None = None
custom_fields: dict | None = None
skip_asn: bool = False
def update(self, other: "DocumentMetadataOverrides") -> "DocumentMetadataOverrides":
"""
@@ -50,8 +49,6 @@ class DocumentMetadataOverrides:
self.storage_path_id = other.storage_path_id
if other.owner_id is not None:
self.owner_id = other.owner_id
if other.skip_asn:
self.skip_asn = True
# merge
if self.tag_ids is None:

View File

@@ -580,34 +580,30 @@ class TagSerializer(MatchingModelSerializer, OwnedObjectSerializer):
),
)
def get_children(self, obj):
children_map = self.context.get("children_map")
if children_map is not None:
children = children_map.get(obj.pk, [])
else:
filter_q = self.context.get("document_count_filter")
request = self.context.get("request")
if filter_q is None:
user = getattr(request, "user", None) if request else None
filter_q = get_document_count_filter_for_user(user)
self.context["document_count_filter"] = filter_q
filter_q = self.context.get("document_count_filter")
request = self.context.get("request")
if filter_q is None:
user = getattr(request, "user", None) if request else None
filter_q = get_document_count_filter_for_user(user)
self.context["document_count_filter"] = filter_q
children = (
obj.get_children_queryset()
.select_related("owner")
.annotate(document_count=Count("documents", filter=filter_q))
)
children_queryset = (
obj.get_children_queryset()
.select_related("owner")
.annotate(document_count=Count("documents", filter=filter_q))
)
view = self.context.get("view")
ordering = (
OrderingFilter().get_ordering(request, children, view)
if request and view
else None
)
ordering = ordering or (Lower("name"),)
children = children.order_by(*ordering)
view = self.context.get("view")
ordering = (
OrderingFilter().get_ordering(request, children_queryset, view)
if request and view
else None
)
ordering = ordering or (Lower("name"),)
children_queryset = children_queryset.order_by(*ordering)
serializer = TagSerializer(
children,
children_queryset,
many=True,
user=self.user,
full_perms=self.full_perms,

View File

@@ -19,7 +19,6 @@ from django.db import DatabaseError
from django.db import close_old_connections
from django.db import connections
from django.db import models
from django.db import transaction
from django.db.models import Q
from django.dispatch import receiver
from django.utils import timezone
@@ -61,8 +60,6 @@ if TYPE_CHECKING:
logger = logging.getLogger("paperless.handlers")
_tag_tree_update_scheduled = False
def add_inbox_tags(sender, document: Document, logging_group=None, **kwargs):
if document.owner is not None:
@@ -947,26 +944,3 @@ def close_connection_pool_on_worker_init(**kwargs):
for conn in connections.all(initialized_only=True):
if conn.alias == "default" and hasattr(conn, "pool") and conn.pool:
conn.close_pool()
def schedule_tag_tree_update(**_kwargs):
"""
Schedule a single Tag.update_tree() at transaction commit.
Treenode's default post_save hooks rebuild the entire tree on every save,
which is very slow for large tag sets so collapse to one update per
transaction.
"""
global _tag_tree_update_scheduled
if _tag_tree_update_scheduled:
return
_tag_tree_update_scheduled = True
def _run():
global _tag_tree_update_scheduled
try:
Tag.update_tree()
finally:
_tag_tree_update_scheduled = False
transaction.on_commit(_run)

View File

@@ -602,23 +602,21 @@ class TestPDFActions(DirectoriesMixin, TestCase):
expected_filename,
)
self.assertEqual(consume_file_args[1].title, None)
self.assertTrue(consume_file_args[1].skip_asn)
# No metadata_document_id, delete_originals False, so ASN should be None
self.assertIsNone(consume_file_args[1].asn)
# With metadata_document_id overrides
result = bulk_edit.merge(doc_ids, metadata_document_id=metadata_document_id)
consume_file_args, _ = mock_consume_file.call_args
self.assertEqual(consume_file_args[1].title, "B (merged)")
self.assertEqual(consume_file_args[1].created, self.doc2.created)
self.assertTrue(consume_file_args[1].skip_asn)
self.assertEqual(result, "OK")
@mock.patch("documents.bulk_edit.delete.si")
@mock.patch("documents.tasks.consume_file.s")
@mock.patch("documents.bulk_edit.chain")
def test_merge_and_delete_originals(
self,
mock_chain,
mock_consume_file,
mock_delete_documents,
):
@@ -632,6 +630,12 @@ class TestPDFActions(DirectoriesMixin, TestCase):
- Document deletion task should be called
"""
doc_ids = [self.doc1.id, self.doc2.id, self.doc3.id]
self.doc1.archive_serial_number = 101
self.doc2.archive_serial_number = 102
self.doc3.archive_serial_number = 103
self.doc1.save()
self.doc2.save()
self.doc3.save()
result = bulk_edit.merge(doc_ids, delete_originals=True)
self.assertEqual(result, "OK")
@@ -642,7 +646,8 @@ class TestPDFActions(DirectoriesMixin, TestCase):
mock_consume_file.assert_called()
mock_delete_documents.assert_called()
mock_chain.assert_called_once()
consume_sig = mock_consume_file.return_value
consume_sig.apply_async.assert_called_once()
consume_file_args, _ = mock_consume_file.call_args
self.assertEqual(
@@ -650,7 +655,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
expected_filename,
)
self.assertEqual(consume_file_args[1].title, None)
self.assertTrue(consume_file_args[1].skip_asn)
self.assertEqual(consume_file_args[1].asn, 101)
delete_documents_args, _ = mock_delete_documents.call_args
self.assertEqual(
@@ -658,6 +663,13 @@ class TestPDFActions(DirectoriesMixin, TestCase):
doc_ids,
)
self.doc1.refresh_from_db()
self.doc2.refresh_from_db()
self.doc3.refresh_from_db()
self.assertIsNone(self.doc1.archive_serial_number)
self.assertIsNone(self.doc2.archive_serial_number)
self.assertIsNone(self.doc3.archive_serial_number)
@mock.patch("documents.tasks.consume_file.s")
def test_merge_with_archive_fallback(self, mock_consume_file):
"""
@@ -726,6 +738,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
self.assertEqual(mock_consume_file.call_count, 2)
consume_file_args, _ = mock_consume_file.call_args
self.assertEqual(consume_file_args[1].title, "B (split 2)")
self.assertIsNone(consume_file_args[1].asn)
self.assertEqual(result, "OK")
@@ -750,6 +763,8 @@ class TestPDFActions(DirectoriesMixin, TestCase):
"""
doc_ids = [self.doc2.id]
pages = [[1, 2], [3]]
self.doc2.archive_serial_number = 200
self.doc2.save()
result = bulk_edit.split(doc_ids, pages, delete_originals=True)
self.assertEqual(result, "OK")
@@ -767,6 +782,9 @@ class TestPDFActions(DirectoriesMixin, TestCase):
doc_ids,
)
self.doc2.refresh_from_db()
self.assertIsNone(self.doc2.archive_serial_number)
@mock.patch("documents.tasks.consume_file.delay")
@mock.patch("pikepdf.Pdf.save")
def test_split_with_errors(self, mock_save_pdf, mock_consume_file):
@@ -967,10 +985,16 @@ class TestPDFActions(DirectoriesMixin, TestCase):
mock_chord.return_value.delay.return_value = None
doc_ids = [self.doc2.id]
operations = [{"page": 1}, {"page": 2}]
self.doc2.archive_serial_number = 250
self.doc2.save()
result = bulk_edit.edit_pdf(doc_ids, operations, delete_original=True)
self.assertEqual(result, "OK")
mock_chord.assert_called_once()
consume_file_args, _ = mock_consume_file.call_args
self.assertEqual(consume_file_args[1].asn, 250)
self.doc2.refresh_from_db()
self.assertIsNone(self.doc2.archive_serial_number)
@mock.patch("documents.tasks.update_document_content_maybe_archive_file.delay")
def test_edit_pdf_with_update_document(self, mock_update_document):

View File

@@ -412,14 +412,6 @@ class TestConsumer(
self.assertEqual(document.archive_serial_number, 123)
self._assert_first_last_send_progress()
def testMetadataOverridesSkipAsnPropagation(self):
overrides = DocumentMetadataOverrides()
incoming = DocumentMetadataOverrides(skip_asn=True)
overrides.update(incoming)
self.assertTrue(overrides.skip_asn)
def testOverrideTitlePlaceholders(self):
c = Correspondent.objects.create(name="Correspondent Name")
dt = DocumentType.objects.create(name="DocType Name")

View File

@@ -1,7 +1,6 @@
from unittest import mock
from django.contrib.auth.models import User
from django.db import transaction
from rest_framework.test import APITestCase
from documents import bulk_edit
@@ -11,7 +10,6 @@ from documents.models import Workflow
from documents.models import WorkflowAction
from documents.models import WorkflowTrigger
from documents.serialisers import TagSerializer
from documents.signals import handlers
from documents.signals.handlers import run_workflows
@@ -252,31 +250,3 @@ class TestTagHierarchy(APITestCase):
row for row in response.data["results"] if row["id"] == self.parent.pk
)
assert any(child["id"] == self.child.pk for child in parent_entry["children"])
def test_tag_tree_deferred_update_runs_on_commit(self):
from django.db import transaction
# Create tags inside an explicit transaction and commit.
with transaction.atomic():
parent = Tag.objects.create(name="Parent 2")
child = Tag.objects.create(name="Child 2", tn_parent=parent)
# After commit, tn_* fields should be populated.
parent.refresh_from_db()
child.refresh_from_db()
assert parent.tn_children_count == 1
assert child.tn_ancestors_count == 1
def test_tag_tree_update_runs_once_per_transaction(self):
handlers._tag_tree_update_scheduled = False
with mock.patch("documents.signals.handlers.Tag.update_tree") as update_tree:
with self.captureOnCommitCallbacks(execute=True) as callbacks:
with transaction.atomic():
handlers.schedule_tag_tree_update()
handlers.schedule_tag_tree_update()
update_tree.assert_not_called()
assert handlers._tag_tree_update_scheduled is True
assert len(callbacks) == 1
update_tree.assert_called_once()
assert handlers._tag_tree_update_scheduled is False

View File

@@ -448,43 +448,8 @@ class TagViewSet(ModelViewSet, PermissionsAwareDocumentCountMixin):
def get_serializer_context(self):
context = super().get_serializer_context()
context["document_count_filter"] = self.get_document_count_filter()
if hasattr(self, "_children_map"):
context["children_map"] = self._children_map
return context
def list(self, request, *args, **kwargs):
"""
Build a children map once to avoid per-parent queries in the serializer.
"""
queryset = self.filter_queryset(self.get_queryset())
ordering = OrderingFilter().get_ordering(request, queryset, self) or (
Lower("name"),
)
queryset = queryset.order_by(*ordering)
all_tags = list(queryset)
descendant_pks = {pk for tag in all_tags for pk in tag.get_descendants_pks()}
if descendant_pks:
filter_q = self.get_document_count_filter()
children_source = (
Tag.objects.filter(pk__in=descendant_pks | {t.pk for t in all_tags})
.select_related("owner")
.annotate(document_count=Count("documents", filter=filter_q))
.order_by(*ordering)
)
else:
children_source = all_tags
children_map = {}
for tag in children_source:
children_map.setdefault(tag.tn_parent_id, []).append(tag)
self._children_map = children_map
page = self.paginate_queryset(queryset)
serializer = self.get_serializer(page, many=True)
return self.get_paginated_response(serializer.data)
def perform_update(self, serializer):
old_parent = self.get_object().get_parent()
tag = serializer.save()