More treenode cleanup

This commit is contained in:
shamoon
2025-09-12 12:37:29 -07:00
parent 5e80eafe66
commit 667c06452f
4 changed files with 31 additions and 50 deletions

View File

@@ -116,37 +116,20 @@ class Tag(MatchingModel, TreeNodeModel):
verbose_name = _("tag") verbose_name = _("tag")
verbose_name_plural = _("tags") verbose_name_plural = _("tags")
def subtree_height(self, node: TreeNodeModel) -> int:
children = list(node.get_children())
if not children:
return 0
return 1 + max(self.subtree_height(child) for child in children)
def clean(self): def clean(self):
# Prevent self-parenting # Prevent self-parenting and assigning a descendant as parent
if self.parent == self: parent = self.get_parent()
raise ValidationError(_("Cannot set itself as parent.")) if parent == self:
raise ValidationError({"parent": _("Cannot set itself as parent.")})
# Prevent assigning a descendant as parent if parent and self.pk is not None and self.is_ancestor_of(parent):
if ( raise ValidationError({"parent": _("Cannot set parent to a descendant.")})
self.parent
and self.pk is not None
and any(ancestor.pk == self.pk for ancestor in self.parent.get_ancestors())
):
raise ValidationError(_("Cannot set parent to a descendant."))
# Enforce maximum nesting depth # Enforce maximum nesting depth
new_parent_depth = 0 new_parent_depth = 0
if self.parent: if parent:
new_parent_depth = len(self.parent.get_ancestors()) + 1 new_parent_depth = parent.get_ancestors_count() + 1
if self.pk is None:
# Unsaved tag cannot have children; treat as leaf height = 0 if self.pk is None else self.get_depth()
height = 0
else:
try:
height = self.subtree_height(self)
except RecursionError:
raise ValidationError(_("Invalid tag hierarchy."))
deepest_new_depth = (new_parent_depth + 1) + height deepest_new_depth = (new_parent_depth + 1) + height
if deepest_new_depth > self.MAX_NESTING_DEPTH: if deepest_new_depth > self.MAX_NESTING_DEPTH:
raise ValidationError(_("Maximum nesting depth exceeded.")) raise ValidationError(_("Maximum nesting depth exceeded."))
@@ -442,8 +425,9 @@ class Document(SoftDeleteModel, ModelWithOwner):
def add_nested_tags(self, tags): def add_nested_tags(self, tags):
for tag in tags: for tag in tags:
self.tags.add(tag) self.tags.add(tag)
if tag.parent: parent = tag.get_parent()
self.add_nested_tags([tag.parent]) if parent:
self.add_nested_tags([parent])
class SavedView(ModelWithOwner): class SavedView(ModelWithOwner):

View File

@@ -583,19 +583,28 @@ class TagSerializer(MatchingModelSerializer, OwnedObjectSerializer):
def validate(self, attrs): def validate(self, attrs):
# Validate when changing parent # Validate when changing parent
parent = attrs.get("parent", self.instance.parent if self.instance else None) parent = attrs.get(
"tn_parent",
self.instance.get_parent() if self.instance else None,
)
if self.instance: if self.instance:
# Temporarily set parent on the instance if updating and use model clean() # Temporarily set parent on the instance if updating and use model clean()
original_parent = self.instance.parent original_parent = self.instance.get_parent()
try: try:
self.instance.set_parent(parent) # Temporarily set tn_parent in-memory to validate clean()
self.instance.tn_parent = parent
self.instance.clean() self.instance.clean()
except ValueError as e: except ValueError as e:
logger.debug("Tag parent validation failed: %s", e) logger.debug("Tag parent validation failed: %s", e)
raise serializers.ValidationError({"parent": _("Invalid parent tag.")}) raise serializers.ValidationError({"parent": _("Invalid parent tag.")})
except ValidationError as e:
logger.debug("Tag parent validation failed: %s", e)
if getattr(e, "message_dict", None):
raise serializers.ValidationError(e.message_dict)
raise serializers.ValidationError({"non_field_errors": e.messages})
finally: finally:
self.instance.set_parent(original_parent) self.instance.tn_parent = original_parent
else: else:
# For new instances, create a transient Tag and validate # For new instances, create a transient Tag and validate
temp = Tag(tn_parent=parent) temp = Tag(tn_parent=parent)

View File

@@ -1,7 +1,6 @@
from unittest import mock from unittest import mock
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.core.exceptions import ValidationError
from rest_framework.test import APITestCase from rest_framework.test import APITestCase
from documents import bulk_edit from documents import bulk_edit
@@ -130,7 +129,7 @@ class TestTagHierarchy(APITestCase):
format="json", format="json",
) )
assert resp.status_code == 400 assert resp.status_code == 400
assert "parent" in resp.data assert "Cannot set itself as parent" in str(resp.data["parent"])
def test_cannot_set_parent_to_descendant(self): def test_cannot_set_parent_to_descendant(self):
a = Tag.objects.create(name="A") a = Tag.objects.create(name="A")
@@ -144,7 +143,7 @@ class TestTagHierarchy(APITestCase):
format="json", format="json",
) )
assert resp.status_code == 400 assert resp.status_code == 400
assert "Cannot set parent to a descendant" in str(resp.data["non_field_errors"]) assert "Cannot set parent to a descendant" in str(resp.data["parent"])
def test_max_depth_on_create(self): def test_max_depth_on_create(self):
a = Tag.objects.create(name="A1") a = Tag.objects.create(name="A1")
@@ -203,15 +202,4 @@ class TestTagHierarchy(APITestCase):
) )
assert resp_ok.status_code in (200, 202) assert resp_ok.status_code in (200, 202)
x.refresh_from_db() x.refresh_from_db()
assert x.parent_id == c.id assert x.parent_pk == c.id
def test_invalid_hierarchy_recursion_error(self):
t = Tag.objects.create(name="TagA")
with mock.patch(
"documents.models.Tag.subtree_height",
side_effect=RecursionError,
):
with self.assertRaises(ValidationError) as cm:
t.clean()
assert "Invalid tag hierarchy" in str(cm.exception)

View File

@@ -342,9 +342,9 @@ class TagViewSet(ModelViewSet, PermissionsAwareDocumentCountMixin):
ordering_fields = ("color", "name", "matching_algorithm", "match", "document_count") ordering_fields = ("color", "name", "matching_algorithm", "match", "document_count")
def perform_update(self, serializer): def perform_update(self, serializer):
old_parent = self.get_object().parent old_parent = self.get_object().get_parent()
tag = serializer.save() tag = serializer.save()
new_parent = tag.parent new_parent = tag.get_parent()
if old_parent != new_parent: if old_parent != new_parent:
self._update_document_parent_tags(tag, old_parent, new_parent) self._update_document_parent_tags(tag, old_parent, new_parent)