Compare commits


3 Commits

Author SHA1 Message Date
shamoon 016bccdcdf Remove skip_asn stuff 2025-12-31 11:52:31 -08:00
shamoon 92deebddd4 "Handoff" ASN when merging or editing PDFs 2025-12-31 11:50:27 -08:00
shamoon c7efcee3d6 First, release ASNs before document replacement (and restore if needed) 2025-12-31 10:42:07 -08:00
12 changed files with 133 additions and 317 deletions

View File

@@ -1,139 +0,0 @@
# noqa: INP001
"""
Ad-hoc script to gauge Tag + treenode performance locally.

It bootstraps a fresh SQLite DB in a temp folder (or PAPERLESS_DATA_DIR),
uses locmem cache/redis to avoid external services, creates synthetic tags,
and measures:
  - creation time
  - query count and wall time for the Tag list view

Usage:
    PAPERLESS_DEBUG=1 PAPERLESS_REDIS=locmem:// PYTHONPATH=src \
    PAPERLESS_DATA_DIR=/tmp/paperless-tags-probe \
    .venv/bin/python scripts/tag_perf_probe.py
"""

import os
import sys
import time
from collections.abc import Iterable
from contextlib import contextmanager

os.environ.setdefault("DJANGO_SETTINGS_MODULE", "paperless.settings")
os.environ.setdefault("PAPERLESS_DEBUG", "1")
os.environ.setdefault("PAPERLESS_REDIS", "locmem://")
os.environ.setdefault("PYTHONPATH", "src")

import django

django.setup()

from django.contrib.auth import get_user_model  # noqa: E402
from django.core.management import call_command  # noqa: E402
from django.db import connection  # noqa: E402
from django.test.client import RequestFactory  # noqa: E402
from rest_framework.test import force_authenticate  # noqa: E402
from treenode.signals import no_signals  # noqa: E402

from documents.models import Tag  # noqa: E402
from documents.views import TagViewSet  # noqa: E402

User = get_user_model()


@contextmanager
def count_queries():
    total = 0

    def wrapper(execute, sql, params, many, context):
        nonlocal total
        total += 1
        return execute(sql, params, many, context)

    with connection.execute_wrapper(wrapper):
        yield lambda: total


def measure_list(tag_count: int, user) -> tuple[int, float]:
    """Render Tag list with page_size=tag_count and return (queries, seconds)."""
    rf = RequestFactory()
    view = TagViewSet.as_view({"get": "list"})
    request = rf.get("/api/tags/", {"page_size": tag_count})
    force_authenticate(request, user=user)
    with count_queries() as get_count:
        start = time.perf_counter()
        response = view(request)
        response.render()
        elapsed = time.perf_counter() - start
        total_queries = get_count()
    return total_queries, elapsed


def bulk_create_tags(count: int, parents: Iterable[Tag] | None = None) -> None:
    """Create tags; when parents provided, create one child per parent."""
    if parents is None:
        Tag.objects.bulk_create([Tag(name=f"Flat {i}") for i in range(count)])
        return
    children = []
    for p in parents:
        children.append(Tag(name=f"Child {p.id}", tn_parent=p))
    Tag.objects.bulk_create(children)


def run():
    # Ensure tables exist when pointing at a fresh DATA_DIR.
    call_command("migrate", interactive=False, verbosity=0)
    user, _ = User.objects.get_or_create(
        username="admin",
        defaults={"is_superuser": True, "is_staff": True},
    )

    # Flat scenario
    Tag.objects.all().delete()
    start = time.perf_counter()
    bulk_create_tags(200)
    flat_create = time.perf_counter() - start
    q, t = measure_list(tag_count=200, user=user)
    print(f"Flat create 200 -> {flat_create:.2f}s; list -> {q} queries, {t:.2f}s")  # noqa: T201

    # Nested scenario (parents + 2 children each => 600 total)
    Tag.objects.all().delete()
    start = time.perf_counter()
    with no_signals():  # avoid per-save tree rebuild; rebuild once
        parents = Tag.objects.bulk_create([Tag(name=f"Parent {i}") for i in range(200)])
        children = []
        for p in parents:
            children.extend(
                Tag(name=f"Child {p.id}-{j}", tn_parent=p) for j in range(2)
            )
        Tag.objects.bulk_create(children)
    Tag.update_tree()
    nested_create = time.perf_counter() - start
    q, t = measure_list(tag_count=600, user=user)
    print(f"Nested create 600 -> {nested_create:.2f}s; list -> {q} queries, {t:.2f}s")  # noqa: T201

    # Larger nested scenario (1 child per parent, 3000 total)
    Tag.objects.all().delete()
    start = time.perf_counter()
    with no_signals():
        parents = Tag.objects.bulk_create(
            [Tag(name=f"Parent {i}") for i in range(1500)],
        )
        bulk_create_tags(0, parents=parents)
    Tag.update_tree()
    big_create = time.perf_counter() - start
    q, t = measure_list(tag_count=3000, user=user)
    print(f"Nested create 3000 -> {big_create:.2f}s; list -> {q} queries, {t:.2f}s")  # noqa: T201


if __name__ == "__main__":
    if "runserver" in sys.argv:
        print("Run directly: .venv/bin/python scripts/tag_perf_probe.py")  # noqa: T201
        sys.exit(1)
    run()

View File

@@ -1,9 +1,5 @@
 from django.apps import AppConfig
-from django.db.models.signals import post_delete
-from django.db.models.signals import post_save
 from django.utils.translation import gettext_lazy as _
-from treenode.signals import post_delete_treenode
-from treenode.signals import post_save_treenode

 class DocumentsConfig(AppConfig):
@@ -12,14 +8,12 @@ class DocumentsConfig(AppConfig):
     verbose_name = _("Documents")

     def ready(self):
-        from documents.models import Tag
         from documents.signals import document_consumption_finished
         from documents.signals import document_updated
         from documents.signals.handlers import add_inbox_tags
         from documents.signals.handlers import add_to_index
         from documents.signals.handlers import run_workflows_added
         from documents.signals.handlers import run_workflows_updated
-        from documents.signals.handlers import schedule_tag_tree_update
         from documents.signals.handlers import set_correspondent
         from documents.signals.handlers import set_document_type
         from documents.signals.handlers import set_storage_path
@@ -34,29 +28,6 @@ class DocumentsConfig(AppConfig):
         document_consumption_finished.connect(run_workflows_added)
         document_updated.connect(run_workflows_updated)

-        # treenode updates the entire tree on every save/delete via hooks
-        # so disconnect for Tags and run once-per-transaction.
-        post_save.disconnect(
-            post_save_treenode,
-            sender=Tag,
-            dispatch_uid="post_save_treenode",
-        )
-        post_delete.disconnect(
-            post_delete_treenode,
-            sender=Tag,
-            dispatch_uid="post_delete_treenode",
-        )
-        post_save.connect(
-            schedule_tag_tree_update,
-            sender=Tag,
-            dispatch_uid="paperless_tag_mark_dirty_save",
-        )
-        post_delete.connect(
-            schedule_tag_tree_update,
-            sender=Tag,
-            dispatch_uid="paperless_tag_mark_dirty_delete",
-        )
         import documents.schema  # noqa: F401

         AppConfig.ready(self)

View File

@@ -186,11 +186,7 @@ class BarcodePlugin(ConsumeTaskPlugin):
         # Update/overwrite an ASN if possible
         # After splitting, as otherwise each split document gets the same ASN
-        if (
-            self.settings.barcode_enable_asn
-            and not self.metadata.skip_asn
-            and (located_asn := self.asn) is not None
-        ):
+        if self.settings.barcode_enable_asn and (located_asn := self.asn) is not None:
             logger.info(f"Found ASN in barcode: {located_asn}")
             self.metadata.asn = located_asn

View File

@@ -7,7 +7,6 @@ from pathlib import Path
 from typing import TYPE_CHECKING
 from typing import Literal

-from celery import chain
 from celery import chord
 from celery import group
 from celery import shared_task
@@ -38,6 +37,42 @@ if TYPE_CHECKING:
 logger: logging.Logger = logging.getLogger("paperless.bulk_edit")


+@shared_task(bind=True)
+def restore_archive_serial_numbers_task(
+    self,
+    backup: dict[int, int],
+    *args,
+    **kwargs,
+) -> None:
+    restore_archive_serial_numbers(backup)
+
+
+def release_archive_serial_numbers(doc_ids: list[int]) -> dict[int, int]:
+    """
+    Clears ASNs on documents that are about to be replaced so new documents
+    can be assigned ASNs without uniqueness collisions. Returns a backup map
+    of doc_id -> previous ASN for potential restoration.
+    """
+    qs = Document.objects.filter(
+        id__in=doc_ids,
+        archive_serial_number__isnull=False,
+    ).only("pk", "archive_serial_number")
+    backup = dict(qs.values_list("pk", "archive_serial_number"))
+    qs.update(archive_serial_number=None)
+    logger.info(f"Released archive serial numbers for documents {list(backup.keys())}")
+    return backup
+
+
+def restore_archive_serial_numbers(backup: dict[int, int]) -> None:
+    """
+    Restores ASNs using the provided backup map, intended for
+    rollback when replacement consumption fails.
+    """
+    for doc_id, asn in backup.items():
+        Document.objects.filter(pk=doc_id).update(archive_serial_number=asn)
+    logger.info(f"Restored archive serial numbers for documents {list(backup.keys())}")
+
+
 def set_correspondent(
     doc_ids: list[int],
     correspondent: Correspondent,
@@ -386,6 +421,7 @@ def merge(
     merged_pdf = pikepdf.new()
     version: str = merged_pdf.pdf_version
+    handoff_asn: int | None = None

     # use doc_ids to preserve order
     for doc_id in doc_ids:
         doc = qs.get(id=doc_id)
@@ -401,6 +437,8 @@ def merge(
                 version = max(version, pdf.pdf_version)
                 merged_pdf.pages.extend(pdf.pages)
             affected_docs.append(doc.id)
+            if handoff_asn is None and doc.archive_serial_number is not None:
+                handoff_asn = doc.archive_serial_number
         except Exception as e:
             logger.exception(
                 f"Error merging document {doc.id}, it will not be included in the merge: {e}",
@@ -426,6 +464,8 @@ def merge(
                 DocumentMetadataOverrides.from_document(metadata_document)
             )
             overrides.title = metadata_document.title + " (merged)"
+            if metadata_document.archive_serial_number is not None:
+                handoff_asn = metadata_document.archive_serial_number
         else:
             overrides = DocumentMetadataOverrides()
     else:
@@ -433,8 +473,9 @@ def merge(
     if user is not None:
         overrides.owner_id = user.id

-    # Avoid copying or detecting ASN from merged PDFs to prevent collision
-    overrides.skip_asn = True
+    if delete_originals and handoff_asn is not None:
+        overrides.asn = handoff_asn

     logger.info("Adding merged document to the task queue.")
@@ -447,12 +488,20 @@ def merge(
     )

     if delete_originals:
+        backup = release_archive_serial_numbers(affected_docs)
         logger.info(
             "Queueing removal of original documents after consumption of merged document",
         )
-        chain(consume_task, delete.si(affected_docs)).delay()
-    else:
-        consume_task.delay()
+        try:
+            consume_task.apply_async(
+                link=[delete.si(affected_docs)],
+                link_error=[restore_archive_serial_numbers_task.s(backup)],
+            )
+        except Exception:
+            restore_archive_serial_numbers(backup)
+            raise
+    else:
+        consume_task.delay()

     return "OK"
@@ -508,10 +557,20 @@ def split(
         )

     if delete_originals:
+        backup = release_archive_serial_numbers([doc.id])
         logger.info(
             "Queueing removal of original document after consumption of the split documents",
         )
-        chord(header=consume_tasks, body=delete.si([doc.id])).delay()
+        try:
+            chord(
+                header=consume_tasks,
+                body=delete.si([doc.id]),
+            ).apply_async(
+                link_error=[restore_archive_serial_numbers_task.s(backup)],
+            )
+        except Exception:
+            restore_archive_serial_numbers(backup)
+            raise
     else:
         group(consume_tasks).delay()
@@ -614,7 +673,8 @@ def edit_pdf(
     )
     if user is not None:
         overrides.owner_id = user.id
+    if delete_original and len(pdf_docs) == 1:
+        overrides.asn = doc.archive_serial_number

     for idx, pdf in enumerate(pdf_docs, start=1):
         filepath: Path = (
             Path(tempfile.mkdtemp(dir=settings.SCRATCH_DIR))
@@ -633,7 +693,17 @@ def edit_pdf(
         )

     if delete_original:
-        chord(header=consume_tasks, body=delete.si([doc.id])).delay()
+        backup = release_archive_serial_numbers([doc.id])
+        try:
+            chord(
+                header=consume_tasks,
+                body=delete.si([doc.id]),
+            ).apply_async(
+                link_error=[restore_archive_serial_numbers_task.s(backup)],
+            )
+        except Exception:
+            restore_archive_serial_numbers(backup)
+            raise
     else:
         group(consume_tasks).delay()
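
Taken together, these bulk_edit changes release the originals' ASNs before queueing the replacement, hand the first available ASN (or the metadata document's) to the merged result only when the originals will be deleted, and restore the released ASNs through the Celery link_error callback, or synchronously if enqueueing itself raises. A minimal sketch of the resulting caller-visible behaviour, assuming three existing documents whose IDs are hypothetical and of which only the first holds an ASN:

from documents import bulk_edit

# Hypothetical IDs; assume document 101 holds ASN 7 and the others have none.
doc_ids = [101, 102, 103]

# delete_originals=True: ASNs are released up front, the merged document is
# queued with overrides.asn = 7, the delete task runs via `link` after a
# successful consume, and restore_archive_serial_numbers_task runs via
# `link_error` if consumption fails.
bulk_edit.merge(doc_ids, delete_originals=True)

# Without delete_originals the originals keep their ASNs and the merged copy
# is queued with no ASN at all, so no uniqueness collision can occur.
bulk_edit.merge(doc_ids)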

View File

@@ -696,7 +696,7 @@ class ConsumerPlugin(
                 pk=self.metadata.storage_path_id,
             )

-        if self.metadata.asn is not None and not self.metadata.skip_asn:
+        if self.metadata.asn is not None:
             document.archive_serial_number = self.metadata.asn

         if self.metadata.owner_id:
@@ -812,8 +812,8 @@ class ConsumerPreflightPlugin(
         """
         Check that if override_asn is given, it is unique and within a valid range
         """
-        if self.metadata.skip_asn or self.metadata.asn is None:
-            # if skip is set or ASN is None
+        if self.metadata.asn is None:
+            # if ASN is None
             return

         # Validate the range is above zero and less than uint32_t max
         # otherwise, Whoosh can't handle it in the index

View File

@@ -30,7 +30,6 @@ class DocumentMetadataOverrides:
     change_users: list[int] | None = None
     change_groups: list[int] | None = None
     custom_fields: dict | None = None
-    skip_asn: bool = False

     def update(self, other: "DocumentMetadataOverrides") -> "DocumentMetadataOverrides":
         """
@@ -50,8 +49,6 @@
             self.storage_path_id = other.storage_path_id
         if other.owner_id is not None:
             self.owner_id = other.owner_id
-        if other.skip_asn:
-            self.skip_asn = True

         # merge
         if self.tag_ids is None:
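
With skip_asn gone from DocumentMetadataOverrides, intent is expressed positively: an override either carries an ASN to apply or leaves it unset. A minimal sketch of the two cases, assuming DocumentMetadataOverrides is imported as elsewhere in this diff and using a hypothetical ASN value:

# Explicit handoff: ConsumerPlugin applies this because metadata.asn is not None.
overrides = DocumentMetadataOverrides(asn=7)

# No ASN set: nothing is copied onto the new document, so a barcode-detected
# ASN (when barcode_enable_asn is enabled) can still be assigned.
plain_overrides = DocumentMetadataOverrides()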

View File

@@ -580,34 +580,30 @@ class TagSerializer(MatchingModelSerializer, OwnedObjectSerializer):
         ),
     )

     def get_children(self, obj):
-        children_map = self.context.get("children_map")
-        if children_map is not None:
-            children = children_map.get(obj.pk, [])
-        else:
-            filter_q = self.context.get("document_count_filter")
-            request = self.context.get("request")
-            if filter_q is None:
-                user = getattr(request, "user", None) if request else None
-                filter_q = get_document_count_filter_for_user(user)
-                self.context["document_count_filter"] = filter_q
-            children = (
-                obj.get_children_queryset()
-                .select_related("owner")
-                .annotate(document_count=Count("documents", filter=filter_q))
-            )
+        filter_q = self.context.get("document_count_filter")
+        request = self.context.get("request")
+        if filter_q is None:
+            user = getattr(request, "user", None) if request else None
+            filter_q = get_document_count_filter_for_user(user)
+            self.context["document_count_filter"] = filter_q
+        children_queryset = (
+            obj.get_children_queryset()
+            .select_related("owner")
+            .annotate(document_count=Count("documents", filter=filter_q))
+        )
         view = self.context.get("view")
         ordering = (
-            OrderingFilter().get_ordering(request, children, view)
+            OrderingFilter().get_ordering(request, children_queryset, view)
             if request and view
             else None
         )
         ordering = ordering or (Lower("name"),)
-        children = children.order_by(*ordering)
+        children_queryset = children_queryset.order_by(*ordering)
         serializer = TagSerializer(
-            children,
+            children_queryset,
             many=True,
             user=self.user,
             full_perms=self.full_perms,
View File

@@ -19,7 +19,6 @@ from django.db import DatabaseError
 from django.db import close_old_connections
 from django.db import connections
 from django.db import models
-from django.db import transaction
 from django.db.models import Q
 from django.dispatch import receiver
 from django.utils import timezone
@@ -61,8 +60,6 @@ if TYPE_CHECKING:
 logger = logging.getLogger("paperless.handlers")

-_tag_tree_update_scheduled = False


 def add_inbox_tags(sender, document: Document, logging_group=None, **kwargs):
     if document.owner is not None:
@@ -947,26 +944,3 @@ def close_connection_pool_on_worker_init(**kwargs):
     for conn in connections.all(initialized_only=True):
         if conn.alias == "default" and hasattr(conn, "pool") and conn.pool:
             conn.close_pool()
-
-
-def schedule_tag_tree_update(**_kwargs):
-    """
-    Schedule a single Tag.update_tree() at transaction commit.
-
-    Treenode's default post_save hooks rebuild the entire tree on every save,
-    which is very slow for large tag sets so collapse to one update per
-    transaction.
-    """
-    global _tag_tree_update_scheduled
-
-    if _tag_tree_update_scheduled:
-        return
-
-    _tag_tree_update_scheduled = True
-
-    def _run():
-        global _tag_tree_update_scheduled
-        try:
-            Tag.update_tree()
-        finally:
-            _tag_tree_update_scheduled = False
-
-    transaction.on_commit(_run)

View File

@@ -602,23 +602,21 @@ class TestPDFActions(DirectoriesMixin, TestCase):
             expected_filename,
         )
         self.assertEqual(consume_file_args[1].title, None)
-        self.assertTrue(consume_file_args[1].skip_asn)
+        # No metadata_document_id, delete_originals False, so ASN should be None
+        self.assertIsNone(consume_file_args[1].asn)

         # With metadata_document_id overrides
         result = bulk_edit.merge(doc_ids, metadata_document_id=metadata_document_id)
         consume_file_args, _ = mock_consume_file.call_args
         self.assertEqual(consume_file_args[1].title, "B (merged)")
         self.assertEqual(consume_file_args[1].created, self.doc2.created)
-        self.assertTrue(consume_file_args[1].skip_asn)
         self.assertEqual(result, "OK")

     @mock.patch("documents.bulk_edit.delete.si")
     @mock.patch("documents.tasks.consume_file.s")
-    @mock.patch("documents.bulk_edit.chain")
     def test_merge_and_delete_originals(
         self,
-        mock_chain,
         mock_consume_file,
         mock_delete_documents,
     ):
@@ -632,6 +630,12 @@ class TestPDFActions(DirectoriesMixin, TestCase):
         - Document deletion task should be called
         """
         doc_ids = [self.doc1.id, self.doc2.id, self.doc3.id]
+        self.doc1.archive_serial_number = 101
+        self.doc2.archive_serial_number = 102
+        self.doc3.archive_serial_number = 103
+        self.doc1.save()
+        self.doc2.save()
+        self.doc3.save()

         result = bulk_edit.merge(doc_ids, delete_originals=True)
         self.assertEqual(result, "OK")
@@ -642,7 +646,8 @@ class TestPDFActions(DirectoriesMixin, TestCase):
         mock_consume_file.assert_called()
         mock_delete_documents.assert_called()
-        mock_chain.assert_called_once()
+        consume_sig = mock_consume_file.return_value
+        consume_sig.apply_async.assert_called_once()

         consume_file_args, _ = mock_consume_file.call_args
         self.assertEqual(
@@ -650,7 +655,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
             expected_filename,
         )
         self.assertEqual(consume_file_args[1].title, None)
-        self.assertTrue(consume_file_args[1].skip_asn)
+        self.assertEqual(consume_file_args[1].asn, 101)

         delete_documents_args, _ = mock_delete_documents.call_args
         self.assertEqual(
@@ -658,6 +663,13 @@ class TestPDFActions(DirectoriesMixin, TestCase):
             doc_ids,
         )

+        self.doc1.refresh_from_db()
+        self.doc2.refresh_from_db()
+        self.doc3.refresh_from_db()
+        self.assertIsNone(self.doc1.archive_serial_number)
+        self.assertIsNone(self.doc2.archive_serial_number)
+        self.assertIsNone(self.doc3.archive_serial_number)
+
     @mock.patch("documents.tasks.consume_file.s")
     def test_merge_with_archive_fallback(self, mock_consume_file):
         """
@@ -726,6 +738,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
         self.assertEqual(mock_consume_file.call_count, 2)
         consume_file_args, _ = mock_consume_file.call_args
         self.assertEqual(consume_file_args[1].title, "B (split 2)")
+        self.assertIsNone(consume_file_args[1].asn)

         self.assertEqual(result, "OK")
@@ -750,6 +763,8 @@ class TestPDFActions(DirectoriesMixin, TestCase):
         """
         doc_ids = [self.doc2.id]
         pages = [[1, 2], [3]]
+        self.doc2.archive_serial_number = 200
+        self.doc2.save()

         result = bulk_edit.split(doc_ids, pages, delete_originals=True)
         self.assertEqual(result, "OK")
@@ -767,6 +782,9 @@ class TestPDFActions(DirectoriesMixin, TestCase):
             doc_ids,
         )

+        self.doc2.refresh_from_db()
+        self.assertIsNone(self.doc2.archive_serial_number)
+
     @mock.patch("documents.tasks.consume_file.delay")
     @mock.patch("pikepdf.Pdf.save")
     def test_split_with_errors(self, mock_save_pdf, mock_consume_file):
@@ -967,10 +985,16 @@ class TestPDFActions(DirectoriesMixin, TestCase):
         mock_chord.return_value.delay.return_value = None
         doc_ids = [self.doc2.id]
         operations = [{"page": 1}, {"page": 2}]
+        self.doc2.archive_serial_number = 250
+        self.doc2.save()

         result = bulk_edit.edit_pdf(doc_ids, operations, delete_original=True)
         self.assertEqual(result, "OK")
         mock_chord.assert_called_once()
+        consume_file_args, _ = mock_consume_file.call_args
+        self.assertEqual(consume_file_args[1].asn, 250)
+        self.doc2.refresh_from_db()
+        self.assertIsNone(self.doc2.archive_serial_number)

     @mock.patch("documents.tasks.update_document_content_maybe_archive_file.delay")
     def test_edit_pdf_with_update_document(self, mock_update_document):

View File

@@ -412,14 +412,6 @@ class TestConsumer(
         self.assertEqual(document.archive_serial_number, 123)

         self._assert_first_last_send_progress()

-    def testMetadataOverridesSkipAsnPropagation(self):
-        overrides = DocumentMetadataOverrides()
-        incoming = DocumentMetadataOverrides(skip_asn=True)
-        overrides.update(incoming)
-        self.assertTrue(overrides.skip_asn)
-
     def testOverrideTitlePlaceholders(self):
         c = Correspondent.objects.create(name="Correspondent Name")
         dt = DocumentType.objects.create(name="DocType Name")

View File

@@ -1,7 +1,6 @@
 from unittest import mock

 from django.contrib.auth.models import User
-from django.db import transaction
 from rest_framework.test import APITestCase

 from documents import bulk_edit
@@ -11,7 +10,6 @@ from documents.models import Workflow
 from documents.models import WorkflowAction
 from documents.models import WorkflowTrigger
 from documents.serialisers import TagSerializer
-from documents.signals import handlers
 from documents.signals.handlers import run_workflows
@@ -252,31 +250,3 @@ class TestTagHierarchy(APITestCase):
             row for row in response.data["results"] if row["id"] == self.parent.pk
         )
         assert any(child["id"] == self.child.pk for child in parent_entry["children"])
-
-    def test_tag_tree_deferred_update_runs_on_commit(self):
-        from django.db import transaction
-
-        # Create tags inside an explicit transaction and commit.
-        with transaction.atomic():
-            parent = Tag.objects.create(name="Parent 2")
-            child = Tag.objects.create(name="Child 2", tn_parent=parent)
-
-        # After commit, tn_* fields should be populated.
-        parent.refresh_from_db()
-        child.refresh_from_db()
-        assert parent.tn_children_count == 1
-        assert child.tn_ancestors_count == 1
-
-    def test_tag_tree_update_runs_once_per_transaction(self):
-        handlers._tag_tree_update_scheduled = False
-
-        with mock.patch("documents.signals.handlers.Tag.update_tree") as update_tree:
-            with self.captureOnCommitCallbacks(execute=True) as callbacks:
-                with transaction.atomic():
-                    handlers.schedule_tag_tree_update()
-                    handlers.schedule_tag_tree_update()
-                    update_tree.assert_not_called()
-                    assert handlers._tag_tree_update_scheduled is True
-
-        assert len(callbacks) == 1
-        update_tree.assert_called_once()
-        assert handlers._tag_tree_update_scheduled is False

View File

@@ -448,43 +448,8 @@ class TagViewSet(ModelViewSet, PermissionsAwareDocumentCountMixin):
     def get_serializer_context(self):
         context = super().get_serializer_context()
         context["document_count_filter"] = self.get_document_count_filter()
-        if hasattr(self, "_children_map"):
-            context["children_map"] = self._children_map
         return context

-    def list(self, request, *args, **kwargs):
-        """
-        Build a children map once to avoid per-parent queries in the serializer.
-        """
-        queryset = self.filter_queryset(self.get_queryset())
-        ordering = OrderingFilter().get_ordering(request, queryset, self) or (
-            Lower("name"),
-        )
-        queryset = queryset.order_by(*ordering)
-
-        all_tags = list(queryset)
-        descendant_pks = {pk for tag in all_tags for pk in tag.get_descendants_pks()}
-        if descendant_pks:
-            filter_q = self.get_document_count_filter()
-            children_source = (
-                Tag.objects.filter(pk__in=descendant_pks | {t.pk for t in all_tags})
-                .select_related("owner")
-                .annotate(document_count=Count("documents", filter=filter_q))
-                .order_by(*ordering)
-            )
-        else:
-            children_source = all_tags
-
-        children_map = {}
-        for tag in children_source:
-            children_map.setdefault(tag.tn_parent_id, []).append(tag)
-        self._children_map = children_map
-
-        page = self.paginate_queryset(queryset)
-        serializer = self.get_serializer(page, many=True)
-        return self.get_paginated_response(serializer.data)
-
     def perform_update(self, serializer):
         old_parent = self.get_object().get_parent()
         tag = serializer.save()