From 69143c5e0e62ffb3aebf3a3ddf2ba969935e349b Mon Sep 17 00:00:00 2001 From: shamoon <4887959+shamoon@users.noreply.github.com> Date: Sat, 13 Sep 2025 07:38:08 -0700 Subject: [PATCH] Handle parent tag removal in backend --- src/documents/serialisers.py | 37 +++++++++++++---------- src/documents/tests/test_tag_hierarchy.py | 10 ++++++ 2 files changed, 31 insertions(+), 16 deletions(-) diff --git a/src/documents/serialisers.py b/src/documents/serialisers.py index ce2b0c48a..0b01e221b 100644 --- a/src/documents/serialisers.py +++ b/src/documents/serialisers.py @@ -1088,22 +1088,27 @@ class DocumentSerializer( doc_id, ) if "tags" in validated_data: - # add all parent tags - all_ancestor_tags = set(validated_data["tags"]) - for tag in validated_data["tags"]: - all_ancestor_tags.update(tag.get_ancestors()) - validated_data["tags"] = list(all_ancestor_tags) - # remove any children for parents that are being removed - tag_parents_being_removed = [ - tag - for tag in instance.tags.all() - if tag not in validated_data["tags"] and tag.get_children_count() > 0 - ] - validated_data["tags"] = [ - tag - for tag in validated_data["tags"] - if tag not in tag_parents_being_removed - ] + # Respect tag hierarchy on updates: + # - Adding a child adds its ancestors + # - Removing a parent removes all its descendants + prev_tags = set(instance.tags.all()) + requested_tags = set(validated_data["tags"]) + + # Tags being removed in this update and all descendants + removed_tags = prev_tags - requested_tags + blocked_tags = set(removed_tags) + for t in removed_tags: + blocked_tags.update(t.get_descendants()) + + # Add all parent tags + final_tags = set(requested_tags) + for t in requested_tags: + final_tags.update(t.get_ancestors()) + + # Drop removed parents and their descendants + final_tags.difference_update(blocked_tags) + + validated_data["tags"] = list(final_tags) if validated_data.get("remove_inbox_tags"): tag_ids_being_added = ( [ diff --git a/src/documents/tests/test_tag_hierarchy.py b/src/documents/tests/test_tag_hierarchy.py index ce1c756fb..bb9eb6a60 100644 --- a/src/documents/tests/test_tag_hierarchy.py +++ b/src/documents/tests/test_tag_hierarchy.py @@ -51,6 +51,16 @@ class TestTagHierarchy(APITestCase): tags = set(self.document.tags.values_list("pk", flat=True)) assert tags == {self.parent.pk, self.child.pk} + def test_document_api_remove_parent_removes_children(self): + self.document.add_nested_tags([self.parent, self.child]) + self.client.patch( + f"/api/documents/{self.document.pk}/", + {"tags": [self.child.pk]}, + format="json", + ) + self.document.refresh_from_db() + assert self.document.tags.count() == 0 + def test_document_api_remove_parent_removes_child(self): self.document.add_nested_tags([self.child]) self.client.patch(