mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-09-14 21:45:37 -05:00
Try replacing with TreeNodeModel
This commit is contained in:
@@ -98,7 +98,7 @@ def set_document_type(doc_ids: list[int], document_type: DocumentType) -> Litera
|
||||
|
||||
def add_tag(doc_ids: list[int], tag: int) -> Literal["OK"]:
|
||||
tag_obj = Tag.objects.get(pk=tag)
|
||||
tags_to_add = [tag_obj, *tag_obj.get_all_ancestors()]
|
||||
tags_to_add = [tag_obj, *tag_obj.get_ancestors()]
|
||||
|
||||
DocumentTagRelationship = Document.tags.through
|
||||
to_create = []
|
||||
@@ -124,7 +124,7 @@ def add_tag(doc_ids: list[int], tag: int) -> Literal["OK"]:
|
||||
|
||||
def remove_tag(doc_ids: list[int], tag: int) -> Literal["OK"]:
|
||||
tag_obj = Tag.objects.get(pk=tag)
|
||||
tags_to_remove = [tag_obj, *tag_obj.get_all_descendants()]
|
||||
tags_to_remove = [tag_obj, *tag_obj.get_descendants()]
|
||||
tag_ids = [t.id for t in tags_to_remove]
|
||||
|
||||
DocumentTagRelationship = Document.tags.through
|
||||
@@ -154,13 +154,13 @@ def modify_tags(
|
||||
expanded_add_tags: set[int] = set()
|
||||
for tag_id in add_tags:
|
||||
t = Tag.objects.get(pk=tag_id)
|
||||
expanded_add_tags.update([t.id for t in [t, *t.get_all_ancestors()]])
|
||||
expanded_add_tags.update([t.id for t in [t, *t.get_ancestors()]])
|
||||
|
||||
# remove with all descendants
|
||||
expanded_remove_tags: set[int] = set()
|
||||
for tag_id in remove_tags:
|
||||
t = Tag.objects.get(pk=tag_id)
|
||||
expanded_remove_tags.update([t.id for t in [t, *t.get_all_descendants()]])
|
||||
expanded_remove_tags.update([t.id for t in [t, *t.get_descendants()]])
|
||||
|
||||
if expanded_remove_tags:
|
||||
DocumentTagRelationship.objects.filter(
|
||||
|
@@ -1,26 +0,0 @@
|
||||
# Generated by Django 5.1.5 on 2025-02-10 06:02
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations
|
||||
from django.db import models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("documents", "1069_workflowtrigger_filter_has_storage_path_and_more"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name="tag",
|
||||
name="parent",
|
||||
field=models.ForeignKey(
|
||||
blank=True,
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
related_name="children",
|
||||
to="documents.tag",
|
||||
verbose_name="parent",
|
||||
),
|
||||
),
|
||||
]
|
@@ -0,0 +1,159 @@
|
||||
# Generated by Django 5.2.6 on 2025-09-12 18:42
|
||||
|
||||
import django.core.validators
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations
|
||||
from django.db import models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("documents", "1069_workflowtrigger_filter_has_storage_path_and_more"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name="tag",
|
||||
name="tn_ancestors_count",
|
||||
field=models.PositiveIntegerField(
|
||||
default=0,
|
||||
editable=False,
|
||||
verbose_name="Ancestors count",
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="tag",
|
||||
name="tn_ancestors_pks",
|
||||
field=models.TextField(
|
||||
blank=True,
|
||||
default="",
|
||||
editable=False,
|
||||
verbose_name="Ancestors pks",
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="tag",
|
||||
name="tn_children_count",
|
||||
field=models.PositiveIntegerField(
|
||||
default=0,
|
||||
editable=False,
|
||||
verbose_name="Children count",
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="tag",
|
||||
name="tn_children_pks",
|
||||
field=models.TextField(
|
||||
blank=True,
|
||||
default="",
|
||||
editable=False,
|
||||
verbose_name="Children pks",
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="tag",
|
||||
name="tn_depth",
|
||||
field=models.PositiveIntegerField(
|
||||
default=0,
|
||||
editable=False,
|
||||
validators=[
|
||||
django.core.validators.MinValueValidator(0),
|
||||
django.core.validators.MaxValueValidator(10),
|
||||
],
|
||||
verbose_name="Depth",
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="tag",
|
||||
name="tn_descendants_count",
|
||||
field=models.PositiveIntegerField(
|
||||
default=0,
|
||||
editable=False,
|
||||
verbose_name="Descendants count",
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="tag",
|
||||
name="tn_descendants_pks",
|
||||
field=models.TextField(
|
||||
blank=True,
|
||||
default="",
|
||||
editable=False,
|
||||
verbose_name="Descendants pks",
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="tag",
|
||||
name="tn_index",
|
||||
field=models.PositiveIntegerField(
|
||||
default=0,
|
||||
editable=False,
|
||||
verbose_name="Index",
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="tag",
|
||||
name="tn_level",
|
||||
field=models.PositiveIntegerField(
|
||||
default=1,
|
||||
editable=False,
|
||||
validators=[
|
||||
django.core.validators.MinValueValidator(1),
|
||||
django.core.validators.MaxValueValidator(10),
|
||||
],
|
||||
verbose_name="Level",
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="tag",
|
||||
name="tn_order",
|
||||
field=models.PositiveIntegerField(
|
||||
default=0,
|
||||
editable=False,
|
||||
verbose_name="Order",
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="tag",
|
||||
name="tn_parent",
|
||||
field=models.ForeignKey(
|
||||
blank=True,
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
related_name="tn_children",
|
||||
to="documents.tag",
|
||||
verbose_name="Parent",
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="tag",
|
||||
name="tn_priority",
|
||||
field=models.PositiveIntegerField(
|
||||
default=0,
|
||||
validators=[
|
||||
django.core.validators.MinValueValidator(0),
|
||||
django.core.validators.MaxValueValidator(9999999999),
|
||||
],
|
||||
verbose_name="Priority",
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="tag",
|
||||
name="tn_siblings_count",
|
||||
field=models.PositiveIntegerField(
|
||||
default=0,
|
||||
editable=False,
|
||||
verbose_name="Siblings count",
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="tag",
|
||||
name="tn_siblings_pks",
|
||||
field=models.TextField(
|
||||
blank=True,
|
||||
default="",
|
||||
editable=False,
|
||||
verbose_name="Siblings pks",
|
||||
),
|
||||
),
|
||||
]
|
@@ -14,6 +14,7 @@ from django.db import models
|
||||
from django.utils import timezone
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from multiselectfield import MultiSelectField
|
||||
from treenode.models import TreeNodeModel
|
||||
|
||||
if settings.AUDIT_LOG_ENABLED:
|
||||
from auditlog.registry import auditlog
|
||||
@@ -97,7 +98,7 @@ class Correspondent(MatchingModel):
|
||||
verbose_name_plural = _("correspondents")
|
||||
|
||||
|
||||
class Tag(MatchingModel):
|
||||
class Tag(MatchingModel, TreeNodeModel):
|
||||
color = models.CharField(_("color"), max_length=7, default="#a6cee3")
|
||||
# Maximum allowed nesting depth for tags (root = 1, max depth = 5)
|
||||
MAX_NESTING_DEPTH: Final[int] = 5
|
||||
@@ -111,35 +112,12 @@ class Tag(MatchingModel):
|
||||
),
|
||||
)
|
||||
|
||||
parent = models.ForeignKey(
|
||||
"self",
|
||||
blank=True,
|
||||
null=True,
|
||||
on_delete=models.CASCADE,
|
||||
related_name="children",
|
||||
verbose_name=_("parent"),
|
||||
)
|
||||
|
||||
class Meta(MatchingModel.Meta):
|
||||
verbose_name = _("tag")
|
||||
verbose_name_plural = _("tags")
|
||||
|
||||
def get_all_descendants(self):
|
||||
descendants = []
|
||||
for child in self.children.all():
|
||||
descendants.append(child)
|
||||
descendants.extend(child.get_all_descendants())
|
||||
return descendants
|
||||
|
||||
def get_all_ancestors(self):
|
||||
ancestors = []
|
||||
if self.parent:
|
||||
ancestors.append(self.parent)
|
||||
ancestors.extend(self.parent.get_all_ancestors())
|
||||
return ancestors
|
||||
|
||||
def subtree_height(self, node) -> int:
|
||||
children = list(node.children.all())
|
||||
def subtree_height(self, node: TreeNodeModel) -> int:
|
||||
children = list(node.children)
|
||||
if not children:
|
||||
return 0
|
||||
return 1 + max(self.subtree_height(child) for child in children)
|
||||
@@ -153,16 +131,14 @@ class Tag(MatchingModel):
|
||||
if (
|
||||
self.parent
|
||||
and self.pk is not None
|
||||
and any(
|
||||
ancestor.pk == self.pk for ancestor in self.parent.get_all_ancestors()
|
||||
)
|
||||
and any(ancestor.pk == self.pk for ancestor in self.parent.get_ancestors())
|
||||
):
|
||||
raise ValidationError(_("Cannot set parent to a descendant."))
|
||||
|
||||
# Enforce maximum nesting depth
|
||||
new_parent_depth = 0
|
||||
if self.parent:
|
||||
new_parent_depth = len(self.parent.get_all_ancestors()) + 1
|
||||
new_parent_depth = len(self.parent.get_ancestors()) + 1
|
||||
if self.pk is None:
|
||||
# Unsaved tag cannot have children; treat as leaf
|
||||
height = 0
|
||||
|
@@ -541,17 +541,8 @@ class TagSerializer(MatchingModelSerializer, OwnedObjectSerializer):
|
||||
|
||||
text_color = serializers.SerializerMethodField()
|
||||
|
||||
children = SerializerMethodField()
|
||||
|
||||
@extend_schema_field(
|
||||
field=serializers.ListSerializer(
|
||||
child=serializers.PrimaryKeyRelatedField(
|
||||
queryset=Tag.objects.all(),
|
||||
),
|
||||
),
|
||||
)
|
||||
def get_children(self, obj):
|
||||
return TagSerializer(obj.children.all(), many=True).data
|
||||
return obj.get_children_pks()
|
||||
|
||||
class Meta:
|
||||
model = Tag
|
||||
@@ -588,16 +579,16 @@ class TagSerializer(MatchingModelSerializer, OwnedObjectSerializer):
|
||||
# Temporarily set parent on the instance if updating and use model clean()
|
||||
original_parent = self.instance.parent
|
||||
try:
|
||||
self.instance.parent = parent
|
||||
self.instance.tn_parent = parent
|
||||
self.instance.clean()
|
||||
except ValidationError as e:
|
||||
logger.debug("Tag parent validation failed: %s", e)
|
||||
raise serializers.ValidationError({"parent": _("Invalid parent tag.")})
|
||||
finally:
|
||||
self.instance.parent = original_parent
|
||||
self.instance.tn_parent = original_parent
|
||||
else:
|
||||
# For new instances, create a transient Tag and validate
|
||||
temp = Tag(parent=parent)
|
||||
temp = Tag(tn_parent=parent)
|
||||
try:
|
||||
temp.clean()
|
||||
except ValidationError as e:
|
||||
@@ -1073,7 +1064,7 @@ class DocumentSerializer(
|
||||
# add all parent tags
|
||||
all_ancestor_tags = set(validated_data["tags"])
|
||||
for tag in validated_data["tags"]:
|
||||
all_ancestor_tags.update(tag.get_all_ancestors())
|
||||
all_ancestor_tags.update(tag.get_ancestors())
|
||||
validated_data["tags"] = list(all_ancestor_tags)
|
||||
# remove any children for parents that are being removed
|
||||
tag_parents_being_removed = [
|
||||
|
@@ -770,7 +770,7 @@ def run_workflows(
|
||||
tag_ids_to_add: set[int] = set()
|
||||
for tag in action.assign_tags.all():
|
||||
tag_ids_to_add.add(tag.pk)
|
||||
tag_ids_to_add.update(t.pk for t in tag.get_all_ancestors())
|
||||
tag_ids_to_add.update(t.pk for t in tag.get_ancestors())
|
||||
|
||||
if not use_overrides:
|
||||
doc_tag_ids[:] = list(set(doc_tag_ids) | tag_ids_to_add)
|
||||
@@ -923,7 +923,7 @@ def run_workflows(
|
||||
tag_ids_to_remove: set[int] = set()
|
||||
for tag in action.remove_tags.all():
|
||||
tag_ids_to_remove.add(tag.pk)
|
||||
tag_ids_to_remove.update(t.pk for t in tag.get_all_descendants())
|
||||
tag_ids_to_remove.update(t.pk for t in tag.get_descendants())
|
||||
|
||||
if not use_overrides:
|
||||
doc_tag_ids[:] = [t for t in doc_tag_ids if t not in tag_ids_to_remove]
|
||||
|
@@ -19,7 +19,7 @@ class TestTagHierarchy(APITestCase):
|
||||
self.client.force_authenticate(user=self.user)
|
||||
|
||||
self.parent = Tag.objects.create(name="Parent")
|
||||
self.child = Tag.objects.create(name="Child", parent=self.parent)
|
||||
self.child = Tag.objects.create(name="Child", tn_parent=self.parent)
|
||||
|
||||
patcher = mock.patch("documents.bulk_edit.bulk_update_documents.delay")
|
||||
self.async_task = patcher.start()
|
||||
@@ -134,8 +134,8 @@ class TestTagHierarchy(APITestCase):
|
||||
|
||||
def test_cannot_set_parent_to_descendant(self):
|
||||
a = Tag.objects.create(name="A")
|
||||
b = Tag.objects.create(name="B", parent=a)
|
||||
c = Tag.objects.create(name="C", parent=b)
|
||||
b = Tag.objects.create(name="B", tn_parent=a)
|
||||
c = Tag.objects.create(name="C", tn_parent=b)
|
||||
|
||||
# Attempt to set A's parent to C (descendant) should fail
|
||||
resp = self.client.patch(
|
||||
@@ -148,9 +148,9 @@ class TestTagHierarchy(APITestCase):
|
||||
|
||||
def test_max_depth_on_create(self):
|
||||
a = Tag.objects.create(name="A1")
|
||||
b = Tag.objects.create(name="B1", parent=a)
|
||||
c = Tag.objects.create(name="C1", parent=b)
|
||||
d = Tag.objects.create(name="D1", parent=c)
|
||||
b = Tag.objects.create(name="B1", tn_parent=a)
|
||||
c = Tag.objects.create(name="C1", tn_parent=b)
|
||||
d = Tag.objects.create(name="D1", tn_parent=c)
|
||||
|
||||
# Creating E under D yields depth 5: allowed
|
||||
resp_ok = self.client.post(
|
||||
@@ -176,12 +176,12 @@ class TestTagHierarchy(APITestCase):
|
||||
|
||||
def test_max_depth_on_move_subtree(self):
|
||||
a = Tag.objects.create(name="A2")
|
||||
b = Tag.objects.create(name="B2", parent=a)
|
||||
c = Tag.objects.create(name="C2", parent=b)
|
||||
d = Tag.objects.create(name="D2", parent=c)
|
||||
b = Tag.objects.create(name="B2", tn_parent=a)
|
||||
c = Tag.objects.create(name="C2", tn_parent=b)
|
||||
d = Tag.objects.create(name="D2", tn_parent=c)
|
||||
|
||||
x = Tag.objects.create(name="X2")
|
||||
y = Tag.objects.create(name="Y2", parent=x)
|
||||
y = Tag.objects.create(name="Y2", tn_parent=x)
|
||||
assert y.parent_id == x.id
|
||||
|
||||
# Moving X under D would make deepest node Y exceed depth 5 -> reject
|
||||
|
@@ -354,7 +354,7 @@ class TagViewSet(ModelViewSet, PermissionsAwareDocumentCountMixin):
|
||||
affected = set()
|
||||
|
||||
if new_parent:
|
||||
parents_to_add = [new_parent, *new_parent.get_all_ancestors()]
|
||||
parents_to_add = [new_parent, *new_parent.get_ancestors()]
|
||||
to_create = []
|
||||
for parent in parents_to_add:
|
||||
missing = Document.objects.filter(id__in=doc_ids).exclude(tags=parent)
|
||||
|
Reference in New Issue
Block a user