mirror of
				https://github.com/paperless-ngx/paperless-ngx.git
				synced 2025-11-03 03:16:10 -06:00 
			
		
		
		
	More treenode cleanup
This commit is contained in:
		@@ -116,37 +116,20 @@ class Tag(MatchingModel, TreeNodeModel):
 | 
				
			|||||||
        verbose_name = _("tag")
 | 
					        verbose_name = _("tag")
 | 
				
			||||||
        verbose_name_plural = _("tags")
 | 
					        verbose_name_plural = _("tags")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def subtree_height(self, node: TreeNodeModel) -> int:
 | 
					 | 
				
			||||||
        children = list(node.get_children())
 | 
					 | 
				
			||||||
        if not children:
 | 
					 | 
				
			||||||
            return 0
 | 
					 | 
				
			||||||
        return 1 + max(self.subtree_height(child) for child in children)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def clean(self):
 | 
					    def clean(self):
 | 
				
			||||||
        # Prevent self-parenting
 | 
					        # Prevent self-parenting and assigning a descendant as parent
 | 
				
			||||||
        if self.parent == self:
 | 
					        parent = self.get_parent()
 | 
				
			||||||
            raise ValidationError(_("Cannot set itself as parent."))
 | 
					        if parent == self:
 | 
				
			||||||
 | 
					            raise ValidationError({"parent": _("Cannot set itself as parent.")})
 | 
				
			||||||
        # Prevent assigning a descendant as parent
 | 
					        if parent and self.pk is not None and self.is_ancestor_of(parent):
 | 
				
			||||||
        if (
 | 
					            raise ValidationError({"parent": _("Cannot set parent to a descendant.")})
 | 
				
			||||||
            self.parent
 | 
					 | 
				
			||||||
            and self.pk is not None
 | 
					 | 
				
			||||||
            and any(ancestor.pk == self.pk for ancestor in self.parent.get_ancestors())
 | 
					 | 
				
			||||||
        ):
 | 
					 | 
				
			||||||
            raise ValidationError(_("Cannot set parent to a descendant."))
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Enforce maximum nesting depth
 | 
					        # Enforce maximum nesting depth
 | 
				
			||||||
        new_parent_depth = 0
 | 
					        new_parent_depth = 0
 | 
				
			||||||
        if self.parent:
 | 
					        if parent:
 | 
				
			||||||
            new_parent_depth = len(self.parent.get_ancestors()) + 1
 | 
					            new_parent_depth = parent.get_ancestors_count() + 1
 | 
				
			||||||
        if self.pk is None:
 | 
					
 | 
				
			||||||
            # Unsaved tag cannot have children; treat as leaf
 | 
					        height = 0 if self.pk is None else self.get_depth()
 | 
				
			||||||
            height = 0
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            try:
 | 
					 | 
				
			||||||
                height = self.subtree_height(self)
 | 
					 | 
				
			||||||
            except RecursionError:
 | 
					 | 
				
			||||||
                raise ValidationError(_("Invalid tag hierarchy."))
 | 
					 | 
				
			||||||
        deepest_new_depth = (new_parent_depth + 1) + height
 | 
					        deepest_new_depth = (new_parent_depth + 1) + height
 | 
				
			||||||
        if deepest_new_depth > self.MAX_NESTING_DEPTH:
 | 
					        if deepest_new_depth > self.MAX_NESTING_DEPTH:
 | 
				
			||||||
            raise ValidationError(_("Maximum nesting depth exceeded."))
 | 
					            raise ValidationError(_("Maximum nesting depth exceeded."))
 | 
				
			||||||
@@ -442,8 +425,9 @@ class Document(SoftDeleteModel, ModelWithOwner):
 | 
				
			|||||||
    def add_nested_tags(self, tags):
 | 
					    def add_nested_tags(self, tags):
 | 
				
			||||||
        for tag in tags:
 | 
					        for tag in tags:
 | 
				
			||||||
            self.tags.add(tag)
 | 
					            self.tags.add(tag)
 | 
				
			||||||
            if tag.parent:
 | 
					            parent = tag.get_parent()
 | 
				
			||||||
                self.add_nested_tags([tag.parent])
 | 
					            if parent:
 | 
				
			||||||
 | 
					                self.add_nested_tags([parent])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class SavedView(ModelWithOwner):
 | 
					class SavedView(ModelWithOwner):
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -583,19 +583,28 @@ class TagSerializer(MatchingModelSerializer, OwnedObjectSerializer):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def validate(self, attrs):
 | 
					    def validate(self, attrs):
 | 
				
			||||||
        # Validate when changing parent
 | 
					        # Validate when changing parent
 | 
				
			||||||
        parent = attrs.get("parent", self.instance.parent if self.instance else None)
 | 
					        parent = attrs.get(
 | 
				
			||||||
 | 
					            "tn_parent",
 | 
				
			||||||
 | 
					            self.instance.get_parent() if self.instance else None,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if self.instance:
 | 
					        if self.instance:
 | 
				
			||||||
            # Temporarily set parent on the instance if updating and use model clean()
 | 
					            # Temporarily set parent on the instance if updating and use model clean()
 | 
				
			||||||
            original_parent = self.instance.parent
 | 
					            original_parent = self.instance.get_parent()
 | 
				
			||||||
            try:
 | 
					            try:
 | 
				
			||||||
                self.instance.set_parent(parent)
 | 
					                # Temporarily set tn_parent in-memory to validate clean()
 | 
				
			||||||
 | 
					                self.instance.tn_parent = parent
 | 
				
			||||||
                self.instance.clean()
 | 
					                self.instance.clean()
 | 
				
			||||||
            except ValueError as e:
 | 
					            except ValueError as e:
 | 
				
			||||||
                logger.debug("Tag parent validation failed: %s", e)
 | 
					                logger.debug("Tag parent validation failed: %s", e)
 | 
				
			||||||
                raise serializers.ValidationError({"parent": _("Invalid parent tag.")})
 | 
					                raise serializers.ValidationError({"parent": _("Invalid parent tag.")})
 | 
				
			||||||
 | 
					            except ValidationError as e:
 | 
				
			||||||
 | 
					                logger.debug("Tag parent validation failed: %s", e)
 | 
				
			||||||
 | 
					                if getattr(e, "message_dict", None):
 | 
				
			||||||
 | 
					                    raise serializers.ValidationError(e.message_dict)
 | 
				
			||||||
 | 
					                raise serializers.ValidationError({"non_field_errors": e.messages})
 | 
				
			||||||
            finally:
 | 
					            finally:
 | 
				
			||||||
                self.instance.set_parent(original_parent)
 | 
					                self.instance.tn_parent = original_parent
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            # For new instances, create a transient Tag and validate
 | 
					            # For new instances, create a transient Tag and validate
 | 
				
			||||||
            temp = Tag(tn_parent=parent)
 | 
					            temp = Tag(tn_parent=parent)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,7 +1,6 @@
 | 
				
			|||||||
from unittest import mock
 | 
					from unittest import mock
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from django.contrib.auth.models import User
 | 
					from django.contrib.auth.models import User
 | 
				
			||||||
from django.core.exceptions import ValidationError
 | 
					 | 
				
			||||||
from rest_framework.test import APITestCase
 | 
					from rest_framework.test import APITestCase
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from documents import bulk_edit
 | 
					from documents import bulk_edit
 | 
				
			||||||
@@ -130,7 +129,7 @@ class TestTagHierarchy(APITestCase):
 | 
				
			|||||||
            format="json",
 | 
					            format="json",
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        assert resp.status_code == 400
 | 
					        assert resp.status_code == 400
 | 
				
			||||||
        assert "parent" in resp.data
 | 
					        assert "Cannot set itself as parent" in str(resp.data["parent"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_cannot_set_parent_to_descendant(self):
 | 
					    def test_cannot_set_parent_to_descendant(self):
 | 
				
			||||||
        a = Tag.objects.create(name="A")
 | 
					        a = Tag.objects.create(name="A")
 | 
				
			||||||
@@ -144,7 +143,7 @@ class TestTagHierarchy(APITestCase):
 | 
				
			|||||||
            format="json",
 | 
					            format="json",
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        assert resp.status_code == 400
 | 
					        assert resp.status_code == 400
 | 
				
			||||||
        assert "Cannot set parent to a descendant" in str(resp.data["non_field_errors"])
 | 
					        assert "Cannot set parent to a descendant" in str(resp.data["parent"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_max_depth_on_create(self):
 | 
					    def test_max_depth_on_create(self):
 | 
				
			||||||
        a = Tag.objects.create(name="A1")
 | 
					        a = Tag.objects.create(name="A1")
 | 
				
			||||||
@@ -203,15 +202,4 @@ class TestTagHierarchy(APITestCase):
 | 
				
			|||||||
        )
 | 
					        )
 | 
				
			||||||
        assert resp_ok.status_code in (200, 202)
 | 
					        assert resp_ok.status_code in (200, 202)
 | 
				
			||||||
        x.refresh_from_db()
 | 
					        x.refresh_from_db()
 | 
				
			||||||
        assert x.parent_id == c.id
 | 
					        assert x.parent_pk == c.id
 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_invalid_hierarchy_recursion_error(self):
 | 
					 | 
				
			||||||
        t = Tag.objects.create(name="TagA")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        with mock.patch(
 | 
					 | 
				
			||||||
            "documents.models.Tag.subtree_height",
 | 
					 | 
				
			||||||
            side_effect=RecursionError,
 | 
					 | 
				
			||||||
        ):
 | 
					 | 
				
			||||||
            with self.assertRaises(ValidationError) as cm:
 | 
					 | 
				
			||||||
                t.clean()
 | 
					 | 
				
			||||||
            assert "Invalid tag hierarchy" in str(cm.exception)
 | 
					 | 
				
			||||||
 
 | 
				
			|||||||
@@ -342,9 +342,9 @@ class TagViewSet(ModelViewSet, PermissionsAwareDocumentCountMixin):
 | 
				
			|||||||
    ordering_fields = ("color", "name", "matching_algorithm", "match", "document_count")
 | 
					    ordering_fields = ("color", "name", "matching_algorithm", "match", "document_count")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def perform_update(self, serializer):
 | 
					    def perform_update(self, serializer):
 | 
				
			||||||
        old_parent = self.get_object().parent
 | 
					        old_parent = self.get_object().get_parent()
 | 
				
			||||||
        tag = serializer.save()
 | 
					        tag = serializer.save()
 | 
				
			||||||
        new_parent = tag.parent
 | 
					        new_parent = tag.get_parent()
 | 
				
			||||||
        if old_parent != new_parent:
 | 
					        if old_parent != new_parent:
 | 
				
			||||||
            self._update_document_parent_tags(tag, old_parent, new_parent)
 | 
					            self._update_document_parent_tags(tag, old_parent, new_parent)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user