diff --git a/src/documents/models.py b/src/documents/models.py index 7adfa4d19..6e2a6ebcb 100644 --- a/src/documents/models.py +++ b/src/documents/models.py @@ -116,37 +116,20 @@ class Tag(MatchingModel, TreeNodeModel): verbose_name = _("tag") verbose_name_plural = _("tags") - def subtree_height(self, node: TreeNodeModel) -> int: - children = list(node.get_children()) - if not children: - return 0 - return 1 + max(self.subtree_height(child) for child in children) - def clean(self): - # Prevent self-parenting - if self.parent == self: - raise ValidationError(_("Cannot set itself as parent.")) - - # Prevent assigning a descendant as parent - if ( - self.parent - and self.pk is not None - and any(ancestor.pk == self.pk for ancestor in self.parent.get_ancestors()) - ): - raise ValidationError(_("Cannot set parent to a descendant.")) + # Prevent self-parenting and assigning a descendant as parent + parent = self.get_parent() + if parent == self: + raise ValidationError({"parent": _("Cannot set itself as parent.")}) + if parent and self.pk is not None and self.is_ancestor_of(parent): + raise ValidationError({"parent": _("Cannot set parent to a descendant.")}) # Enforce maximum nesting depth new_parent_depth = 0 - if self.parent: - new_parent_depth = len(self.parent.get_ancestors()) + 1 - if self.pk is None: - # Unsaved tag cannot have children; treat as leaf - height = 0 - else: - try: - height = self.subtree_height(self) - except RecursionError: - raise ValidationError(_("Invalid tag hierarchy.")) + if parent: + new_parent_depth = parent.get_ancestors_count() + 1 + + height = 0 if self.pk is None else self.get_depth() deepest_new_depth = (new_parent_depth + 1) + height if deepest_new_depth > self.MAX_NESTING_DEPTH: raise ValidationError(_("Maximum nesting depth exceeded.")) @@ -442,8 +425,9 @@ class Document(SoftDeleteModel, ModelWithOwner): def add_nested_tags(self, tags): for tag in tags: self.tags.add(tag) - if tag.parent: - self.add_nested_tags([tag.parent]) + parent = tag.get_parent() + if parent: + self.add_nested_tags([parent]) class SavedView(ModelWithOwner): diff --git a/src/documents/serialisers.py b/src/documents/serialisers.py index 447229205..c65ebf099 100644 --- a/src/documents/serialisers.py +++ b/src/documents/serialisers.py @@ -583,19 +583,28 @@ class TagSerializer(MatchingModelSerializer, OwnedObjectSerializer): def validate(self, attrs): # Validate when changing parent - parent = attrs.get("parent", self.instance.parent if self.instance else None) + parent = attrs.get( + "tn_parent", + self.instance.get_parent() if self.instance else None, + ) if self.instance: # Temporarily set parent on the instance if updating and use model clean() - original_parent = self.instance.parent + original_parent = self.instance.get_parent() try: - self.instance.set_parent(parent) + # Temporarily set tn_parent in-memory to validate clean() + self.instance.tn_parent = parent self.instance.clean() except ValueError as e: logger.debug("Tag parent validation failed: %s", e) raise serializers.ValidationError({"parent": _("Invalid parent tag.")}) + except ValidationError as e: + logger.debug("Tag parent validation failed: %s", e) + if getattr(e, "message_dict", None): + raise serializers.ValidationError(e.message_dict) + raise serializers.ValidationError({"non_field_errors": e.messages}) finally: - self.instance.set_parent(original_parent) + self.instance.tn_parent = original_parent else: # For new instances, create a transient Tag and validate temp = Tag(tn_parent=parent) diff --git a/src/documents/tests/test_tag_hierarchy.py b/src/documents/tests/test_tag_hierarchy.py index 767e190cb..ce1c756fb 100644 --- a/src/documents/tests/test_tag_hierarchy.py +++ b/src/documents/tests/test_tag_hierarchy.py @@ -1,7 +1,6 @@ from unittest import mock from django.contrib.auth.models import User -from django.core.exceptions import ValidationError from rest_framework.test import APITestCase from documents import bulk_edit @@ -130,7 +129,7 @@ class TestTagHierarchy(APITestCase): format="json", ) assert resp.status_code == 400 - assert "parent" in resp.data + assert "Cannot set itself as parent" in str(resp.data["parent"]) def test_cannot_set_parent_to_descendant(self): a = Tag.objects.create(name="A") @@ -144,7 +143,7 @@ class TestTagHierarchy(APITestCase): format="json", ) assert resp.status_code == 400 - assert "Cannot set parent to a descendant" in str(resp.data["non_field_errors"]) + assert "Cannot set parent to a descendant" in str(resp.data["parent"]) def test_max_depth_on_create(self): a = Tag.objects.create(name="A1") @@ -203,15 +202,4 @@ class TestTagHierarchy(APITestCase): ) assert resp_ok.status_code in (200, 202) x.refresh_from_db() - assert x.parent_id == c.id - - def test_invalid_hierarchy_recursion_error(self): - t = Tag.objects.create(name="TagA") - - with mock.patch( - "documents.models.Tag.subtree_height", - side_effect=RecursionError, - ): - with self.assertRaises(ValidationError) as cm: - t.clean() - assert "Invalid tag hierarchy" in str(cm.exception) + assert x.parent_pk == c.id diff --git a/src/documents/views.py b/src/documents/views.py index adf6a2a90..4f4f182e8 100644 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -342,9 +342,9 @@ class TagViewSet(ModelViewSet, PermissionsAwareDocumentCountMixin): ordering_fields = ("color", "name", "matching_algorithm", "match", "document_count") def perform_update(self, serializer): - old_parent = self.get_object().parent + old_parent = self.get_object().get_parent() tag = serializer.save() - new_parent = tag.parent + new_parent = tag.get_parent() if old_parent != new_parent: self._update_document_parent_tags(tag, old_parent, new_parent)