mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-09-12 21:35:40 -05:00
More treenode cleanup
This commit is contained in:
@@ -116,37 +116,20 @@ class Tag(MatchingModel, TreeNodeModel):
|
|||||||
verbose_name = _("tag")
|
verbose_name = _("tag")
|
||||||
verbose_name_plural = _("tags")
|
verbose_name_plural = _("tags")
|
||||||
|
|
||||||
def subtree_height(self, node: TreeNodeModel) -> int:
|
|
||||||
children = list(node.get_children())
|
|
||||||
if not children:
|
|
||||||
return 0
|
|
||||||
return 1 + max(self.subtree_height(child) for child in children)
|
|
||||||
|
|
||||||
def clean(self):
|
def clean(self):
|
||||||
# Prevent self-parenting
|
# Prevent self-parenting and assigning a descendant as parent
|
||||||
if self.parent == self:
|
parent = self.get_parent()
|
||||||
raise ValidationError(_("Cannot set itself as parent."))
|
if parent == self:
|
||||||
|
raise ValidationError({"parent": _("Cannot set itself as parent.")})
|
||||||
# Prevent assigning a descendant as parent
|
if parent and self.pk is not None and self.is_ancestor_of(parent):
|
||||||
if (
|
raise ValidationError({"parent": _("Cannot set parent to a descendant.")})
|
||||||
self.parent
|
|
||||||
and self.pk is not None
|
|
||||||
and any(ancestor.pk == self.pk for ancestor in self.parent.get_ancestors())
|
|
||||||
):
|
|
||||||
raise ValidationError(_("Cannot set parent to a descendant."))
|
|
||||||
|
|
||||||
# Enforce maximum nesting depth
|
# Enforce maximum nesting depth
|
||||||
new_parent_depth = 0
|
new_parent_depth = 0
|
||||||
if self.parent:
|
if parent:
|
||||||
new_parent_depth = len(self.parent.get_ancestors()) + 1
|
new_parent_depth = parent.get_ancestors_count() + 1
|
||||||
if self.pk is None:
|
|
||||||
# Unsaved tag cannot have children; treat as leaf
|
height = 0 if self.pk is None else self.get_depth()
|
||||||
height = 0
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
height = self.subtree_height(self)
|
|
||||||
except RecursionError:
|
|
||||||
raise ValidationError(_("Invalid tag hierarchy."))
|
|
||||||
deepest_new_depth = (new_parent_depth + 1) + height
|
deepest_new_depth = (new_parent_depth + 1) + height
|
||||||
if deepest_new_depth > self.MAX_NESTING_DEPTH:
|
if deepest_new_depth > self.MAX_NESTING_DEPTH:
|
||||||
raise ValidationError(_("Maximum nesting depth exceeded."))
|
raise ValidationError(_("Maximum nesting depth exceeded."))
|
||||||
@@ -442,8 +425,9 @@ class Document(SoftDeleteModel, ModelWithOwner):
|
|||||||
def add_nested_tags(self, tags):
|
def add_nested_tags(self, tags):
|
||||||
for tag in tags:
|
for tag in tags:
|
||||||
self.tags.add(tag)
|
self.tags.add(tag)
|
||||||
if tag.parent:
|
parent = tag.get_parent()
|
||||||
self.add_nested_tags([tag.parent])
|
if parent:
|
||||||
|
self.add_nested_tags([parent])
|
||||||
|
|
||||||
|
|
||||||
class SavedView(ModelWithOwner):
|
class SavedView(ModelWithOwner):
|
||||||
|
@@ -583,19 +583,28 @@ class TagSerializer(MatchingModelSerializer, OwnedObjectSerializer):
|
|||||||
|
|
||||||
def validate(self, attrs):
|
def validate(self, attrs):
|
||||||
# Validate when changing parent
|
# Validate when changing parent
|
||||||
parent = attrs.get("parent", self.instance.parent if self.instance else None)
|
parent = attrs.get(
|
||||||
|
"tn_parent",
|
||||||
|
self.instance.get_parent() if self.instance else None,
|
||||||
|
)
|
||||||
|
|
||||||
if self.instance:
|
if self.instance:
|
||||||
# Temporarily set parent on the instance if updating and use model clean()
|
# Temporarily set parent on the instance if updating and use model clean()
|
||||||
original_parent = self.instance.parent
|
original_parent = self.instance.get_parent()
|
||||||
try:
|
try:
|
||||||
self.instance.set_parent(parent)
|
# Temporarily set tn_parent in-memory to validate clean()
|
||||||
|
self.instance.tn_parent = parent
|
||||||
self.instance.clean()
|
self.instance.clean()
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.debug("Tag parent validation failed: %s", e)
|
logger.debug("Tag parent validation failed: %s", e)
|
||||||
raise serializers.ValidationError({"parent": _("Invalid parent tag.")})
|
raise serializers.ValidationError({"parent": _("Invalid parent tag.")})
|
||||||
|
except ValidationError as e:
|
||||||
|
logger.debug("Tag parent validation failed: %s", e)
|
||||||
|
if getattr(e, "message_dict", None):
|
||||||
|
raise serializers.ValidationError(e.message_dict)
|
||||||
|
raise serializers.ValidationError({"non_field_errors": e.messages})
|
||||||
finally:
|
finally:
|
||||||
self.instance.set_parent(original_parent)
|
self.instance.tn_parent = original_parent
|
||||||
else:
|
else:
|
||||||
# For new instances, create a transient Tag and validate
|
# For new instances, create a transient Tag and validate
|
||||||
temp = Tag(tn_parent=parent)
|
temp = Tag(tn_parent=parent)
|
||||||
|
@@ -1,7 +1,6 @@
|
|||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
from django.contrib.auth.models import User
|
from django.contrib.auth.models import User
|
||||||
from django.core.exceptions import ValidationError
|
|
||||||
from rest_framework.test import APITestCase
|
from rest_framework.test import APITestCase
|
||||||
|
|
||||||
from documents import bulk_edit
|
from documents import bulk_edit
|
||||||
@@ -130,7 +129,7 @@ class TestTagHierarchy(APITestCase):
|
|||||||
format="json",
|
format="json",
|
||||||
)
|
)
|
||||||
assert resp.status_code == 400
|
assert resp.status_code == 400
|
||||||
assert "parent" in resp.data
|
assert "Cannot set itself as parent" in str(resp.data["parent"])
|
||||||
|
|
||||||
def test_cannot_set_parent_to_descendant(self):
|
def test_cannot_set_parent_to_descendant(self):
|
||||||
a = Tag.objects.create(name="A")
|
a = Tag.objects.create(name="A")
|
||||||
@@ -144,7 +143,7 @@ class TestTagHierarchy(APITestCase):
|
|||||||
format="json",
|
format="json",
|
||||||
)
|
)
|
||||||
assert resp.status_code == 400
|
assert resp.status_code == 400
|
||||||
assert "Cannot set parent to a descendant" in str(resp.data["non_field_errors"])
|
assert "Cannot set parent to a descendant" in str(resp.data["parent"])
|
||||||
|
|
||||||
def test_max_depth_on_create(self):
|
def test_max_depth_on_create(self):
|
||||||
a = Tag.objects.create(name="A1")
|
a = Tag.objects.create(name="A1")
|
||||||
@@ -203,15 +202,4 @@ class TestTagHierarchy(APITestCase):
|
|||||||
)
|
)
|
||||||
assert resp_ok.status_code in (200, 202)
|
assert resp_ok.status_code in (200, 202)
|
||||||
x.refresh_from_db()
|
x.refresh_from_db()
|
||||||
assert x.parent_id == c.id
|
assert x.parent_pk == c.id
|
||||||
|
|
||||||
def test_invalid_hierarchy_recursion_error(self):
|
|
||||||
t = Tag.objects.create(name="TagA")
|
|
||||||
|
|
||||||
with mock.patch(
|
|
||||||
"documents.models.Tag.subtree_height",
|
|
||||||
side_effect=RecursionError,
|
|
||||||
):
|
|
||||||
with self.assertRaises(ValidationError) as cm:
|
|
||||||
t.clean()
|
|
||||||
assert "Invalid tag hierarchy" in str(cm.exception)
|
|
||||||
|
@@ -342,9 +342,9 @@ class TagViewSet(ModelViewSet, PermissionsAwareDocumentCountMixin):
|
|||||||
ordering_fields = ("color", "name", "matching_algorithm", "match", "document_count")
|
ordering_fields = ("color", "name", "matching_algorithm", "match", "document_count")
|
||||||
|
|
||||||
def perform_update(self, serializer):
|
def perform_update(self, serializer):
|
||||||
old_parent = self.get_object().parent
|
old_parent = self.get_object().get_parent()
|
||||||
tag = serializer.save()
|
tag = serializer.save()
|
||||||
new_parent = tag.parent
|
new_parent = tag.get_parent()
|
||||||
if old_parent != new_parent:
|
if old_parent != new_parent:
|
||||||
self._update_document_parent_tags(tag, old_parent, new_parent)
|
self._update_document_parent_tags(tag, old_parent, new_parent)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user