From c7efcee3d60545cb5f85131300d37b0d56a15b1c Mon Sep 17 00:00:00 2001 From: shamoon <4887959+shamoon@users.noreply.github.com> Date: Wed, 31 Dec 2025 09:49:20 -0800 Subject: [PATCH] First, release ASNs before document replacement (and restore if needed) --- src/documents/bulk_edit.py | 74 ++++++++++++++++++++++++--- src/documents/tests/test_bulk_edit.py | 37 +++++++++++--- 2 files changed, 98 insertions(+), 13 deletions(-) diff --git a/src/documents/bulk_edit.py b/src/documents/bulk_edit.py index 219947d09..883884694 100644 --- a/src/documents/bulk_edit.py +++ b/src/documents/bulk_edit.py @@ -7,7 +7,6 @@ from pathlib import Path from typing import TYPE_CHECKING from typing import Literal -from celery import chain from celery import chord from celery import group from celery import shared_task @@ -38,6 +37,42 @@ if TYPE_CHECKING: logger: logging.Logger = logging.getLogger("paperless.bulk_edit") +@shared_task(bind=True) +def restore_archive_serial_numbers_task( + self, + backup: dict[int, int], + *args, + **kwargs, +) -> None: + restore_archive_serial_numbers(backup) + + +def release_archive_serial_numbers(doc_ids: list[int]) -> dict[int, int]: + """ + Clears ASNs on documents that are about to be replaced so new documents + can be assigned ASNs without uniqueness collisions. Returns a backup map + of doc_id -> previous ASN for potential restoration. + """ + qs = Document.objects.filter( + id__in=doc_ids, + archive_serial_number__isnull=False, + ).only("pk", "archive_serial_number") + backup = dict(qs.values_list("pk", "archive_serial_number")) + qs.update(archive_serial_number=None) + logger.info(f"Released archive serial numbers for documents {list(backup.keys())}") + return backup + + +def restore_archive_serial_numbers(backup: dict[int, int]) -> None: + """ + Restores ASNs using the provided backup map, intended for + rollback when replacement consumption fails. + """ + for doc_id, asn in backup.items(): + Document.objects.filter(pk=doc_id).update(archive_serial_number=asn) + logger.info(f"Restored archive serial numbers for documents {list(backup.keys())}") + + def set_correspondent( doc_ids: list[int], correspondent: Correspondent, @@ -433,8 +468,6 @@ def merge( if user is not None: overrides.owner_id = user.id - # Avoid copying or detecting ASN from merged PDFs to prevent collision - overrides.skip_asn = True logger.info("Adding merged document to the task queue.") @@ -447,10 +480,18 @@ def merge( ) if delete_originals: + backup = release_archive_serial_numbers(affected_docs) logger.info( "Queueing removal of original documents after consumption of merged document", ) - chain(consume_task, delete.si(affected_docs)).delay() + try: + consume_task.apply_async( + link=[delete.si(affected_docs)], + link_error=[restore_archive_serial_numbers_task.s(backup)], + ) + except Exception: + restore_archive_serial_numbers(backup) + raise else: consume_task.delay() @@ -508,10 +549,20 @@ def split( ) if delete_originals: + backup = release_archive_serial_numbers([doc.id]) logger.info( "Queueing removal of original document after consumption of the split documents", ) - chord(header=consume_tasks, body=delete.si([doc.id])).delay() + try: + chord( + header=consume_tasks, + body=delete.si([doc.id]), + ).apply_async( + link_error=[restore_archive_serial_numbers_task.s(backup)], + ) + except Exception: + restore_archive_serial_numbers(backup) + raise else: group(consume_tasks).delay() @@ -614,7 +665,6 @@ def edit_pdf( ) if user is not None: overrides.owner_id = user.id - for idx, pdf in enumerate(pdf_docs, start=1): filepath: Path = ( Path(tempfile.mkdtemp(dir=settings.SCRATCH_DIR)) @@ -633,7 +683,17 @@ def edit_pdf( ) if delete_original: - chord(header=consume_tasks, body=delete.si([doc.id])).delay() + backup = release_archive_serial_numbers([doc.id]) + try: + chord( + header=consume_tasks, + body=delete.si([doc.id]), + ).apply_async( + link_error=[restore_archive_serial_numbers_task.s(backup)], + ) + except Exception: + restore_archive_serial_numbers(backup) + raise else: group(consume_tasks).delay() diff --git a/src/documents/tests/test_bulk_edit.py b/src/documents/tests/test_bulk_edit.py index c220c1e9b..950885b22 100644 --- a/src/documents/tests/test_bulk_edit.py +++ b/src/documents/tests/test_bulk_edit.py @@ -602,23 +602,21 @@ class TestPDFActions(DirectoriesMixin, TestCase): expected_filename, ) self.assertEqual(consume_file_args[1].title, None) - self.assertTrue(consume_file_args[1].skip_asn) + self.assertFalse(consume_file_args[1].skip_asn) # With metadata_document_id overrides result = bulk_edit.merge(doc_ids, metadata_document_id=metadata_document_id) consume_file_args, _ = mock_consume_file.call_args self.assertEqual(consume_file_args[1].title, "B (merged)") self.assertEqual(consume_file_args[1].created, self.doc2.created) - self.assertTrue(consume_file_args[1].skip_asn) + self.assertFalse(consume_file_args[1].skip_asn) self.assertEqual(result, "OK") @mock.patch("documents.bulk_edit.delete.si") @mock.patch("documents.tasks.consume_file.s") - @mock.patch("documents.bulk_edit.chain") def test_merge_and_delete_originals( self, - mock_chain, mock_consume_file, mock_delete_documents, ): @@ -632,6 +630,12 @@ class TestPDFActions(DirectoriesMixin, TestCase): - Document deletion task should be called """ doc_ids = [self.doc1.id, self.doc2.id, self.doc3.id] + self.doc1.archive_serial_number = 101 + self.doc2.archive_serial_number = 102 + self.doc3.archive_serial_number = 103 + self.doc1.save() + self.doc2.save() + self.doc3.save() result = bulk_edit.merge(doc_ids, delete_originals=True) self.assertEqual(result, "OK") @@ -642,7 +646,8 @@ class TestPDFActions(DirectoriesMixin, TestCase): mock_consume_file.assert_called() mock_delete_documents.assert_called() - mock_chain.assert_called_once() + consume_sig = mock_consume_file.return_value + consume_sig.apply_async.assert_called_once() consume_file_args, _ = mock_consume_file.call_args self.assertEqual( @@ -650,7 +655,7 @@ class TestPDFActions(DirectoriesMixin, TestCase): expected_filename, ) self.assertEqual(consume_file_args[1].title, None) - self.assertTrue(consume_file_args[1].skip_asn) + self.assertFalse(consume_file_args[1].skip_asn) delete_documents_args, _ = mock_delete_documents.call_args self.assertEqual( @@ -658,6 +663,13 @@ class TestPDFActions(DirectoriesMixin, TestCase): doc_ids, ) + self.doc1.refresh_from_db() + self.doc2.refresh_from_db() + self.doc3.refresh_from_db() + self.assertIsNone(self.doc1.archive_serial_number) + self.assertIsNone(self.doc2.archive_serial_number) + self.assertIsNone(self.doc3.archive_serial_number) + @mock.patch("documents.tasks.consume_file.s") def test_merge_with_archive_fallback(self, mock_consume_file): """ @@ -726,6 +738,7 @@ class TestPDFActions(DirectoriesMixin, TestCase): self.assertEqual(mock_consume_file.call_count, 2) consume_file_args, _ = mock_consume_file.call_args self.assertEqual(consume_file_args[1].title, "B (split 2)") + self.assertFalse(consume_file_args[1].skip_asn) self.assertEqual(result, "OK") @@ -750,6 +763,8 @@ class TestPDFActions(DirectoriesMixin, TestCase): """ doc_ids = [self.doc2.id] pages = [[1, 2], [3]] + self.doc2.archive_serial_number = 200 + self.doc2.save() result = bulk_edit.split(doc_ids, pages, delete_originals=True) self.assertEqual(result, "OK") @@ -757,6 +772,7 @@ class TestPDFActions(DirectoriesMixin, TestCase): self.assertEqual(mock_consume_file.call_count, 2) consume_file_args, _ = mock_consume_file.call_args self.assertEqual(consume_file_args[1].title, "B (split 2)") + self.assertFalse(consume_file_args[1].skip_asn) mock_delete_documents.assert_called() mock_chord.assert_called_once() @@ -767,6 +783,9 @@ class TestPDFActions(DirectoriesMixin, TestCase): doc_ids, ) + self.doc2.refresh_from_db() + self.assertIsNone(self.doc2.archive_serial_number) + @mock.patch("documents.tasks.consume_file.delay") @mock.patch("pikepdf.Pdf.save") def test_split_with_errors(self, mock_save_pdf, mock_consume_file): @@ -967,10 +986,16 @@ class TestPDFActions(DirectoriesMixin, TestCase): mock_chord.return_value.delay.return_value = None doc_ids = [self.doc2.id] operations = [{"page": 1}, {"page": 2}] + self.doc2.archive_serial_number = 250 + self.doc2.save() result = bulk_edit.edit_pdf(doc_ids, operations, delete_original=True) self.assertEqual(result, "OK") mock_chord.assert_called_once() + consume_file_args, _ = mock_consume_file.call_args + self.assertFalse(consume_file_args[1].skip_asn) + self.doc2.refresh_from_db() + self.assertIsNone(self.doc2.archive_serial_number) @mock.patch("documents.tasks.update_document_content_maybe_archive_file.delay") def test_edit_pdf_with_update_document(self, mock_update_document):