First, release ASNs before document replacement (and restore if needed)

This commit is contained in:
shamoon
2025-12-31 09:49:20 -08:00
parent 72fd05501b
commit c7efcee3d6
2 changed files with 98 additions and 13 deletions

View File

@@ -7,7 +7,6 @@ from pathlib import Path
from typing import TYPE_CHECKING
from typing import Literal
from celery import chain
from celery import chord
from celery import group
from celery import shared_task
@@ -38,6 +37,42 @@ if TYPE_CHECKING:
logger: logging.Logger = logging.getLogger("paperless.bulk_edit")
@shared_task(bind=True)
def restore_archive_serial_numbers_task(
self,
backup: dict[int, int],
*args,
**kwargs,
) -> None:
restore_archive_serial_numbers(backup)
def release_archive_serial_numbers(doc_ids: list[int]) -> dict[int, int]:
"""
Clears ASNs on documents that are about to be replaced so new documents
can be assigned ASNs without uniqueness collisions. Returns a backup map
of doc_id -> previous ASN for potential restoration.
"""
qs = Document.objects.filter(
id__in=doc_ids,
archive_serial_number__isnull=False,
).only("pk", "archive_serial_number")
backup = dict(qs.values_list("pk", "archive_serial_number"))
qs.update(archive_serial_number=None)
logger.info(f"Released archive serial numbers for documents {list(backup.keys())}")
return backup
def restore_archive_serial_numbers(backup: dict[int, int]) -> None:
"""
Restores ASNs using the provided backup map, intended for
rollback when replacement consumption fails.
"""
for doc_id, asn in backup.items():
Document.objects.filter(pk=doc_id).update(archive_serial_number=asn)
logger.info(f"Restored archive serial numbers for documents {list(backup.keys())}")
def set_correspondent(
doc_ids: list[int],
correspondent: Correspondent,
@@ -433,8 +468,6 @@ def merge(
if user is not None:
overrides.owner_id = user.id
# Avoid copying or detecting ASN from merged PDFs to prevent collision
overrides.skip_asn = True
logger.info("Adding merged document to the task queue.")
@@ -447,10 +480,18 @@ def merge(
)
if delete_originals:
backup = release_archive_serial_numbers(affected_docs)
logger.info(
"Queueing removal of original documents after consumption of merged document",
)
chain(consume_task, delete.si(affected_docs)).delay()
try:
consume_task.apply_async(
link=[delete.si(affected_docs)],
link_error=[restore_archive_serial_numbers_task.s(backup)],
)
except Exception:
restore_archive_serial_numbers(backup)
raise
else:
consume_task.delay()
@@ -508,10 +549,20 @@ def split(
)
if delete_originals:
backup = release_archive_serial_numbers([doc.id])
logger.info(
"Queueing removal of original document after consumption of the split documents",
)
chord(header=consume_tasks, body=delete.si([doc.id])).delay()
try:
chord(
header=consume_tasks,
body=delete.si([doc.id]),
).apply_async(
link_error=[restore_archive_serial_numbers_task.s(backup)],
)
except Exception:
restore_archive_serial_numbers(backup)
raise
else:
group(consume_tasks).delay()
@@ -614,7 +665,6 @@ def edit_pdf(
)
if user is not None:
overrides.owner_id = user.id
for idx, pdf in enumerate(pdf_docs, start=1):
filepath: Path = (
Path(tempfile.mkdtemp(dir=settings.SCRATCH_DIR))
@@ -633,7 +683,17 @@ def edit_pdf(
)
if delete_original:
chord(header=consume_tasks, body=delete.si([doc.id])).delay()
backup = release_archive_serial_numbers([doc.id])
try:
chord(
header=consume_tasks,
body=delete.si([doc.id]),
).apply_async(
link_error=[restore_archive_serial_numbers_task.s(backup)],
)
except Exception:
restore_archive_serial_numbers(backup)
raise
else:
group(consume_tasks).delay()

View File

@@ -602,23 +602,21 @@ class TestPDFActions(DirectoriesMixin, TestCase):
expected_filename,
)
self.assertEqual(consume_file_args[1].title, None)
self.assertTrue(consume_file_args[1].skip_asn)
self.assertFalse(consume_file_args[1].skip_asn)
# With metadata_document_id overrides
result = bulk_edit.merge(doc_ids, metadata_document_id=metadata_document_id)
consume_file_args, _ = mock_consume_file.call_args
self.assertEqual(consume_file_args[1].title, "B (merged)")
self.assertEqual(consume_file_args[1].created, self.doc2.created)
self.assertTrue(consume_file_args[1].skip_asn)
self.assertFalse(consume_file_args[1].skip_asn)
self.assertEqual(result, "OK")
@mock.patch("documents.bulk_edit.delete.si")
@mock.patch("documents.tasks.consume_file.s")
@mock.patch("documents.bulk_edit.chain")
def test_merge_and_delete_originals(
self,
mock_chain,
mock_consume_file,
mock_delete_documents,
):
@@ -632,6 +630,12 @@ class TestPDFActions(DirectoriesMixin, TestCase):
- Document deletion task should be called
"""
doc_ids = [self.doc1.id, self.doc2.id, self.doc3.id]
self.doc1.archive_serial_number = 101
self.doc2.archive_serial_number = 102
self.doc3.archive_serial_number = 103
self.doc1.save()
self.doc2.save()
self.doc3.save()
result = bulk_edit.merge(doc_ids, delete_originals=True)
self.assertEqual(result, "OK")
@@ -642,7 +646,8 @@ class TestPDFActions(DirectoriesMixin, TestCase):
mock_consume_file.assert_called()
mock_delete_documents.assert_called()
mock_chain.assert_called_once()
consume_sig = mock_consume_file.return_value
consume_sig.apply_async.assert_called_once()
consume_file_args, _ = mock_consume_file.call_args
self.assertEqual(
@@ -650,7 +655,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
expected_filename,
)
self.assertEqual(consume_file_args[1].title, None)
self.assertTrue(consume_file_args[1].skip_asn)
self.assertFalse(consume_file_args[1].skip_asn)
delete_documents_args, _ = mock_delete_documents.call_args
self.assertEqual(
@@ -658,6 +663,13 @@ class TestPDFActions(DirectoriesMixin, TestCase):
doc_ids,
)
self.doc1.refresh_from_db()
self.doc2.refresh_from_db()
self.doc3.refresh_from_db()
self.assertIsNone(self.doc1.archive_serial_number)
self.assertIsNone(self.doc2.archive_serial_number)
self.assertIsNone(self.doc3.archive_serial_number)
@mock.patch("documents.tasks.consume_file.s")
def test_merge_with_archive_fallback(self, mock_consume_file):
"""
@@ -726,6 +738,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
self.assertEqual(mock_consume_file.call_count, 2)
consume_file_args, _ = mock_consume_file.call_args
self.assertEqual(consume_file_args[1].title, "B (split 2)")
self.assertFalse(consume_file_args[1].skip_asn)
self.assertEqual(result, "OK")
@@ -750,6 +763,8 @@ class TestPDFActions(DirectoriesMixin, TestCase):
"""
doc_ids = [self.doc2.id]
pages = [[1, 2], [3]]
self.doc2.archive_serial_number = 200
self.doc2.save()
result = bulk_edit.split(doc_ids, pages, delete_originals=True)
self.assertEqual(result, "OK")
@@ -757,6 +772,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
self.assertEqual(mock_consume_file.call_count, 2)
consume_file_args, _ = mock_consume_file.call_args
self.assertEqual(consume_file_args[1].title, "B (split 2)")
self.assertFalse(consume_file_args[1].skip_asn)
mock_delete_documents.assert_called()
mock_chord.assert_called_once()
@@ -767,6 +783,9 @@ class TestPDFActions(DirectoriesMixin, TestCase):
doc_ids,
)
self.doc2.refresh_from_db()
self.assertIsNone(self.doc2.archive_serial_number)
@mock.patch("documents.tasks.consume_file.delay")
@mock.patch("pikepdf.Pdf.save")
def test_split_with_errors(self, mock_save_pdf, mock_consume_file):
@@ -967,10 +986,16 @@ class TestPDFActions(DirectoriesMixin, TestCase):
mock_chord.return_value.delay.return_value = None
doc_ids = [self.doc2.id]
operations = [{"page": 1}, {"page": 2}]
self.doc2.archive_serial_number = 250
self.doc2.save()
result = bulk_edit.edit_pdf(doc_ids, operations, delete_original=True)
self.assertEqual(result, "OK")
mock_chord.assert_called_once()
consume_file_args, _ = mock_consume_file.call_args
self.assertFalse(consume_file_args[1].skip_asn)
self.doc2.refresh_from_db()
self.assertIsNone(self.doc2.archive_serial_number)
@mock.patch("documents.tasks.update_document_content_maybe_archive_file.delay")
def test_edit_pdf_with_update_document(self, mock_update_document):