diff --git a/src/documents/bulk_edit.py b/src/documents/bulk_edit.py index f6adfc8a9..8e3869fb7 100644 --- a/src/documents/bulk_edit.py +++ b/src/documents/bulk_edit.py @@ -23,6 +23,7 @@ from documents.models import CustomFieldInstance from documents.models import Document from documents.models import DocumentType from documents.models import StoragePath +from documents.models import Tag from documents.permissions import set_permissions_for_object from documents.plugins.helpers import DocumentsStatusManager from documents.tasks import bulk_update_documents @@ -91,31 +92,46 @@ def set_document_type(doc_ids: list[int], document_type: DocumentType) -> Litera def add_tag(doc_ids: list[int], tag: int) -> Literal["OK"]: - qs = Document.objects.filter(Q(id__in=doc_ids) & ~Q(tags__id=tag)).only("pk") - affected_docs = list(qs.values_list("pk", flat=True)) + tag_obj = Tag.objects.get(pk=tag) + tags_to_add = [tag_obj, *tag_obj.get_all_ancestors()] DocumentTagRelationship = Document.tags.through + to_create = [] + affected_docs: set[int] = set() - DocumentTagRelationship.objects.bulk_create( - [DocumentTagRelationship(document_id=doc, tag_id=tag) for doc in affected_docs], - ) + for t in tags_to_add: + qs = Document.objects.filter(Q(id__in=doc_ids) & ~Q(tags__id=t.id)).only("pk") + doc_ids_missing_tag = list(qs.values_list("pk", flat=True)) + affected_docs.update(doc_ids_missing_tag) + to_create.extend( + DocumentTagRelationship(document_id=doc, tag_id=t.id) + for doc in doc_ids_missing_tag + ) - bulk_update_documents.delay(document_ids=affected_docs) + if to_create: + DocumentTagRelationship.objects.bulk_create(to_create) + + if affected_docs: + bulk_update_documents.delay(document_ids=list(affected_docs)) return "OK" def remove_tag(doc_ids: list[int], tag: int) -> Literal["OK"]: - qs = Document.objects.filter(Q(id__in=doc_ids) & Q(tags__id=tag)).only("pk") - affected_docs = list(qs.values_list("pk", flat=True)) + tag_obj = Tag.objects.get(pk=tag) + tags_to_remove = [tag_obj, *tag_obj.get_all_descendants()] + tag_ids = [t.id for t in tags_to_remove] DocumentTagRelationship = Document.tags.through + qs = DocumentTagRelationship.objects.filter( + document_id__in=doc_ids, + tag_id__in=tag_ids, + ) + affected_docs = list(qs.values_list("document_id", flat=True).distinct()) + qs.delete() - DocumentTagRelationship.objects.filter( - Q(document_id__in=affected_docs) & Q(tag_id=tag), - ).delete() - - bulk_update_documents.delay(document_ids=affected_docs) + if affected_docs: + bulk_update_documents.delay(document_ids=affected_docs) return "OK" @@ -127,23 +143,35 @@ def modify_tags( ) -> Literal["OK"]: qs = Document.objects.filter(id__in=doc_ids).only("pk") affected_docs = list(qs.values_list("pk", flat=True)) - DocumentTagRelationship = Document.tags.through - DocumentTagRelationship.objects.filter( - document_id__in=affected_docs, - tag_id__in=remove_tags, - ).delete() + # add with all ancestors + expanded_add_tags: set[int] = set() + for tag_id in add_tags: + t = Tag.objects.get(pk=tag_id) + expanded_add_tags.update([t.id for t in [t, *t.get_all_ancestors()]]) - DocumentTagRelationship.objects.bulk_create( - [ - DocumentTagRelationship(document_id=doc, tag_id=tag) - for (doc, tag) in itertools.product(affected_docs, add_tags) - ], - ignore_conflicts=True, - ) + # remove with all descendants + expanded_remove_tags: set[int] = set() + for tag_id in remove_tags: + t = Tag.objects.get(pk=tag_id) + expanded_remove_tags.update([t.id for t in [t, *t.get_all_descendants()]]) - bulk_update_documents.delay(document_ids=affected_docs) + if expanded_remove_tags: + DocumentTagRelationship.objects.filter( + document_id__in=affected_docs, + tag_id__in=expanded_remove_tags, + ).delete() + + to_create = [ + DocumentTagRelationship(document_id=doc, tag_id=tag) + for (doc, tag) in itertools.product(affected_docs, expanded_add_tags) + ] + if to_create: + DocumentTagRelationship.objects.bulk_create(to_create, ignore_conflicts=True) + + if affected_docs: + bulk_update_documents.delay(document_ids=affected_docs) return "OK" diff --git a/src/documents/signals/handlers.py b/src/documents/signals/handlers.py index fa940d89c..afd12b412 100644 --- a/src/documents/signals/handlers.py +++ b/src/documents/signals/handlers.py @@ -627,14 +627,17 @@ def run_workflows( def assignment_action(): if action.assign_tags.exists(): + tag_ids_to_add: set[int] = set() + for tag in action.assign_tags.all(): + tag_ids_to_add.add(tag.pk) + tag_ids_to_add.update(t.pk for t in tag.get_all_ancestors()) + if not use_overrides: - doc_tag_ids.extend(action.assign_tags.values_list("pk", flat=True)) + doc_tag_ids[:] = list(set(doc_tag_ids) | tag_ids_to_add) else: if overrides.tag_ids is None: overrides.tag_ids = [] - overrides.tag_ids.extend( - action.assign_tags.values_list("pk", flat=True), - ) + overrides.tag_ids = list(set(overrides.tag_ids) | tag_ids_to_add) if action.assign_correspondent: if not use_overrides: @@ -760,14 +763,17 @@ def run_workflows( else: overrides.tag_ids = None else: + tag_ids_to_remove: set[int] = set() + for tag in action.remove_tags.all(): + tag_ids_to_remove.add(tag.pk) + tag_ids_to_remove.update(t.pk for t in tag.get_all_descendants()) + if not use_overrides: - for tag in action.remove_tags.filter( - pk__in=document.tags.values_list("pk", flat=True), - ): - doc_tag_ids.remove(tag.pk) + doc_tag_ids[:] = [t for t in doc_tag_ids if t not in tag_ids_to_remove] elif overrides.tag_ids: - for tag in action.remove_tags.filter(pk__in=overrides.tag_ids): - overrides.tag_ids.remove(tag.pk) + overrides.tag_ids = [ + t for t in overrides.tag_ids if t not in tag_ids_to_remove + ] if not use_overrides and ( action.remove_all_correspondents