diff --git a/src/documents/admin.py b/src/documents/admin.py index 59cbf1853..aad649acb 100644 --- a/src/documents/admin.py +++ b/src/documents/admin.py @@ -1,6 +1,7 @@ from django.conf import settings from django.contrib import admin from guardian.admin import GuardedModelAdmin +from treenode.admin import TreeNodeModelAdmin from documents.models import Correspondent from documents.models import CustomField @@ -14,6 +15,7 @@ from documents.models import SavedViewFilterRule from documents.models import ShareLink from documents.models import StoragePath from documents.models import Tag +from documents.tasks import update_document_parent_tags if settings.AUDIT_LOG_ENABLED: from auditlog.admin import LogEntryAdmin @@ -26,12 +28,25 @@ class CorrespondentAdmin(GuardedModelAdmin): list_editable = ("match", "matching_algorithm") -class TagAdmin(GuardedModelAdmin): +class TagAdmin(GuardedModelAdmin, TreeNodeModelAdmin): list_display = ("name", "color", "match", "matching_algorithm") list_filter = ("matching_algorithm",) list_editable = ("color", "match", "matching_algorithm") search_fields = ("color", "name") + def save_model(self, request, obj, form, change): + old_parent = None + if change and obj.pk: + tag = Tag.objects.get(pk=obj.pk) + old_parent = tag.get_parent() if tag else None + + super().save_model(request, obj, form, change) + + # sync parent tags on documents if changed + new_parent = obj.get_parent() + if old_parent != new_parent: + update_document_parent_tags(obj, new_parent) + class DocumentTypeAdmin(GuardedModelAdmin): list_display = ("name", "match", "matching_algorithm") diff --git a/src/documents/tasks.py b/src/documents/tasks.py index 89db54497..0be16274d 100644 --- a/src/documents/tasks.py +++ b/src/documents/tasks.py @@ -515,3 +515,44 @@ def check_scheduled_workflows(): workflow_to_run=workflow, document=document, ) + + +def update_document_parent_tags(tag: Tag, new_parent: Tag | None) -> None: + """ + When a tag's parent changes, ensure all documents containing the tag also have + the parent tag (and its ancestors) applied. + """ + if new_parent is None: + return + + DocumentTagRelationship = Document.tags.through + + doc_ids: list[int] = list( + Document.objects.filter(tags=tag).values_list("pk", flat=True), + ) + + if not doc_ids: + return + + parents_to_add = [new_parent, *new_parent.get_ancestors()] + + to_create: list = [] + affected: set[int] = set() + + for parent in parents_to_add: + missing_qs = Document.objects.filter(id__in=doc_ids).exclude(tags=parent) + missing_ids = list(missing_qs.values_list("pk", flat=True)) + to_create.extend( + DocumentTagRelationship(document_id=doc_id, tag_id=parent.id) + for doc_id in missing_ids + ) + affected.update(missing_ids) + + if to_create: + DocumentTagRelationship.objects.bulk_create( + to_create, + ignore_conflicts=True, + ) + + if affected: + bulk_update_documents.delay(document_ids=list(affected)) diff --git a/src/documents/tests/test_admin.py b/src/documents/tests/test_admin.py index ab32562a8..278014f7c 100644 --- a/src/documents/tests/test_admin.py +++ b/src/documents/tests/test_admin.py @@ -1,4 +1,5 @@ import types +from unittest.mock import patch from django.contrib.admin.sites import AdminSite from django.contrib.auth.models import User @@ -7,7 +8,9 @@ from django.utils import timezone from documents import index from documents.admin import DocumentAdmin +from documents.admin import TagAdmin from documents.models import Document +from documents.models import Tag from documents.tests.utils import DirectoriesMixin from paperless.admin import PaperlessUserAdmin @@ -70,6 +73,24 @@ class TestDocumentAdmin(DirectoriesMixin, TestCase): self.assertEqual(self.doc_admin.created_(doc), "2020-04-12") +class TestTagAdmin(DirectoriesMixin, TestCase): + def setUp(self) -> None: + super().setUp() + self.tag_admin = TagAdmin(model=Tag, admin_site=AdminSite()) + + @patch("documents.tasks.bulk_update_documents") + def test_parent_tags_get_added(self, mock_bulk_update): + document = Document.objects.create(title="test") + parent = Tag.objects.create(name="parent") + child = Tag.objects.create(name="child") + document.tags.add(child) + + child.tn_parent = parent + self.tag_admin.save_model(None, child, None, change=True) + document.refresh_from_db() + self.assertIn(parent, document.tags.all()) + + class TestPaperlessAdmin(DirectoriesMixin, TestCase): def setUp(self) -> None: super().setUp() diff --git a/src/documents/views.py b/src/documents/views.py index 4f4f182e8..21a6ba082 100644 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -169,6 +169,7 @@ from documents.tasks import empty_trash from documents.tasks import index_optimize from documents.tasks import sanity_check from documents.tasks import train_classifier +from documents.tasks import update_document_parent_tags from documents.templating.filepath import validate_filepath_template_and_render from documents.utils import get_boolean from paperless import version @@ -346,33 +347,7 @@ class TagViewSet(ModelViewSet, PermissionsAwareDocumentCountMixin): tag = serializer.save() new_parent = tag.get_parent() if old_parent != new_parent: - self._update_document_parent_tags(tag, old_parent, new_parent) - - def _update_document_parent_tags(self, tag, old_parent, new_parent): - DocumentTagRelationship = Document.tags.through - doc_ids = list(Document.objects.filter(tags=tag).values_list("pk", flat=True)) - affected = set() - - if new_parent: - parents_to_add = [new_parent, *new_parent.get_ancestors()] - to_create = [] - for parent in parents_to_add: - missing = Document.objects.filter(id__in=doc_ids).exclude(tags=parent) - to_create.extend( - DocumentTagRelationship(document_id=doc_id, tag_id=parent.id) - for doc_id in missing.values_list("pk", flat=True) - ) - affected.update(missing.values_list("pk", flat=True)) - if to_create: - DocumentTagRelationship.objects.bulk_create( - to_create, - ignore_conflicts=True, - ) - - if affected: - from documents.tasks import bulk_update_documents - - bulk_update_documents.delay(document_ids=list(affected)) + update_document_parent_tags(tag, new_parent) @extend_schema_view(**generate_object_with_permissions_schema(DocumentTypeSerializer))