Try replacing with TreeNodeModel

This commit is contained in:
shamoon
2025-09-12 11:49:31 -07:00
parent 113e9a329a
commit 7b3a6877c3
8 changed files with 187 additions and 87 deletions

View File

@@ -98,7 +98,7 @@ def set_document_type(doc_ids: list[int], document_type: DocumentType) -> Litera
def add_tag(doc_ids: list[int], tag: int) -> Literal["OK"]:
tag_obj = Tag.objects.get(pk=tag)
tags_to_add = [tag_obj, *tag_obj.get_all_ancestors()]
tags_to_add = [tag_obj, *tag_obj.get_ancestors()]
DocumentTagRelationship = Document.tags.through
to_create = []
@@ -124,7 +124,7 @@ def add_tag(doc_ids: list[int], tag: int) -> Literal["OK"]:
def remove_tag(doc_ids: list[int], tag: int) -> Literal["OK"]:
tag_obj = Tag.objects.get(pk=tag)
tags_to_remove = [tag_obj, *tag_obj.get_all_descendants()]
tags_to_remove = [tag_obj, *tag_obj.get_descendants()]
tag_ids = [t.id for t in tags_to_remove]
DocumentTagRelationship = Document.tags.through
@@ -154,13 +154,13 @@ def modify_tags(
expanded_add_tags: set[int] = set()
for tag_id in add_tags:
t = Tag.objects.get(pk=tag_id)
expanded_add_tags.update([t.id for t in [t, *t.get_all_ancestors()]])
expanded_add_tags.update([t.id for t in [t, *t.get_ancestors()]])
# remove with all descendants
expanded_remove_tags: set[int] = set()
for tag_id in remove_tags:
t = Tag.objects.get(pk=tag_id)
expanded_remove_tags.update([t.id for t in [t, *t.get_all_descendants()]])
expanded_remove_tags.update([t.id for t in [t, *t.get_descendants()]])
if expanded_remove_tags:
DocumentTagRelationship.objects.filter(

View File

@@ -1,26 +0,0 @@
# Generated by Django 5.1.5 on 2025-02-10 06:02
import django.db.models.deletion
from django.db import migrations
from django.db import models
class Migration(migrations.Migration):
dependencies = [
("documents", "1069_workflowtrigger_filter_has_storage_path_and_more"),
]
operations = [
migrations.AddField(
model_name="tag",
name="parent",
field=models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.CASCADE,
related_name="children",
to="documents.tag",
verbose_name="parent",
),
),
]

View File

@@ -0,0 +1,159 @@
# Generated by Django 5.2.6 on 2025-09-12 18:42
import django.core.validators
import django.db.models.deletion
from django.db import migrations
from django.db import models
class Migration(migrations.Migration):
dependencies = [
("documents", "1069_workflowtrigger_filter_has_storage_path_and_more"),
]
operations = [
migrations.AddField(
model_name="tag",
name="tn_ancestors_count",
field=models.PositiveIntegerField(
default=0,
editable=False,
verbose_name="Ancestors count",
),
),
migrations.AddField(
model_name="tag",
name="tn_ancestors_pks",
field=models.TextField(
blank=True,
default="",
editable=False,
verbose_name="Ancestors pks",
),
),
migrations.AddField(
model_name="tag",
name="tn_children_count",
field=models.PositiveIntegerField(
default=0,
editable=False,
verbose_name="Children count",
),
),
migrations.AddField(
model_name="tag",
name="tn_children_pks",
field=models.TextField(
blank=True,
default="",
editable=False,
verbose_name="Children pks",
),
),
migrations.AddField(
model_name="tag",
name="tn_depth",
field=models.PositiveIntegerField(
default=0,
editable=False,
validators=[
django.core.validators.MinValueValidator(0),
django.core.validators.MaxValueValidator(10),
],
verbose_name="Depth",
),
),
migrations.AddField(
model_name="tag",
name="tn_descendants_count",
field=models.PositiveIntegerField(
default=0,
editable=False,
verbose_name="Descendants count",
),
),
migrations.AddField(
model_name="tag",
name="tn_descendants_pks",
field=models.TextField(
blank=True,
default="",
editable=False,
verbose_name="Descendants pks",
),
),
migrations.AddField(
model_name="tag",
name="tn_index",
field=models.PositiveIntegerField(
default=0,
editable=False,
verbose_name="Index",
),
),
migrations.AddField(
model_name="tag",
name="tn_level",
field=models.PositiveIntegerField(
default=1,
editable=False,
validators=[
django.core.validators.MinValueValidator(1),
django.core.validators.MaxValueValidator(10),
],
verbose_name="Level",
),
),
migrations.AddField(
model_name="tag",
name="tn_order",
field=models.PositiveIntegerField(
default=0,
editable=False,
verbose_name="Order",
),
),
migrations.AddField(
model_name="tag",
name="tn_parent",
field=models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.CASCADE,
related_name="tn_children",
to="documents.tag",
verbose_name="Parent",
),
),
migrations.AddField(
model_name="tag",
name="tn_priority",
field=models.PositiveIntegerField(
default=0,
validators=[
django.core.validators.MinValueValidator(0),
django.core.validators.MaxValueValidator(9999999999),
],
verbose_name="Priority",
),
),
migrations.AddField(
model_name="tag",
name="tn_siblings_count",
field=models.PositiveIntegerField(
default=0,
editable=False,
verbose_name="Siblings count",
),
),
migrations.AddField(
model_name="tag",
name="tn_siblings_pks",
field=models.TextField(
blank=True,
default="",
editable=False,
verbose_name="Siblings pks",
),
),
]

View File

@@ -14,6 +14,7 @@ from django.db import models
from django.utils import timezone
from django.utils.translation import gettext_lazy as _
from multiselectfield import MultiSelectField
from treenode.models import TreeNodeModel
if settings.AUDIT_LOG_ENABLED:
from auditlog.registry import auditlog
@@ -97,7 +98,7 @@ class Correspondent(MatchingModel):
verbose_name_plural = _("correspondents")
class Tag(MatchingModel):
class Tag(MatchingModel, TreeNodeModel):
color = models.CharField(_("color"), max_length=7, default="#a6cee3")
# Maximum allowed nesting depth for tags (root = 1, max depth = 5)
MAX_NESTING_DEPTH: Final[int] = 5
@@ -111,35 +112,12 @@ class Tag(MatchingModel):
),
)
parent = models.ForeignKey(
"self",
blank=True,
null=True,
on_delete=models.CASCADE,
related_name="children",
verbose_name=_("parent"),
)
class Meta(MatchingModel.Meta):
verbose_name = _("tag")
verbose_name_plural = _("tags")
def get_all_descendants(self):
descendants = []
for child in self.children.all():
descendants.append(child)
descendants.extend(child.get_all_descendants())
return descendants
def get_all_ancestors(self):
ancestors = []
if self.parent:
ancestors.append(self.parent)
ancestors.extend(self.parent.get_all_ancestors())
return ancestors
def subtree_height(self, node) -> int:
children = list(node.children.all())
def subtree_height(self, node: TreeNodeModel) -> int:
children = list(node.children)
if not children:
return 0
return 1 + max(self.subtree_height(child) for child in children)
@@ -153,16 +131,14 @@ class Tag(MatchingModel):
if (
self.parent
and self.pk is not None
and any(
ancestor.pk == self.pk for ancestor in self.parent.get_all_ancestors()
)
and any(ancestor.pk == self.pk for ancestor in self.parent.get_ancestors())
):
raise ValidationError(_("Cannot set parent to a descendant."))
# Enforce maximum nesting depth
new_parent_depth = 0
if self.parent:
new_parent_depth = len(self.parent.get_all_ancestors()) + 1
new_parent_depth = len(self.parent.get_ancestors()) + 1
if self.pk is None:
# Unsaved tag cannot have children; treat as leaf
height = 0

View File

@@ -541,17 +541,8 @@ class TagSerializer(MatchingModelSerializer, OwnedObjectSerializer):
text_color = serializers.SerializerMethodField()
children = SerializerMethodField()
@extend_schema_field(
field=serializers.ListSerializer(
child=serializers.PrimaryKeyRelatedField(
queryset=Tag.objects.all(),
),
),
)
def get_children(self, obj):
return TagSerializer(obj.children.all(), many=True).data
return obj.get_children_pks()
class Meta:
model = Tag
@@ -588,16 +579,16 @@ class TagSerializer(MatchingModelSerializer, OwnedObjectSerializer):
# Temporarily set parent on the instance if updating and use model clean()
original_parent = self.instance.parent
try:
self.instance.parent = parent
self.instance.tn_parent = parent
self.instance.clean()
except ValidationError as e:
logger.debug("Tag parent validation failed: %s", e)
raise serializers.ValidationError({"parent": _("Invalid parent tag.")})
finally:
self.instance.parent = original_parent
self.instance.tn_parent = original_parent
else:
# For new instances, create a transient Tag and validate
temp = Tag(parent=parent)
temp = Tag(tn_parent=parent)
try:
temp.clean()
except ValidationError as e:
@@ -1073,7 +1064,7 @@ class DocumentSerializer(
# add all parent tags
all_ancestor_tags = set(validated_data["tags"])
for tag in validated_data["tags"]:
all_ancestor_tags.update(tag.get_all_ancestors())
all_ancestor_tags.update(tag.get_ancestors())
validated_data["tags"] = list(all_ancestor_tags)
# remove any children for parents that are being removed
tag_parents_being_removed = [

View File

@@ -770,7 +770,7 @@ def run_workflows(
tag_ids_to_add: set[int] = set()
for tag in action.assign_tags.all():
tag_ids_to_add.add(tag.pk)
tag_ids_to_add.update(t.pk for t in tag.get_all_ancestors())
tag_ids_to_add.update(t.pk for t in tag.get_ancestors())
if not use_overrides:
doc_tag_ids[:] = list(set(doc_tag_ids) | tag_ids_to_add)
@@ -923,7 +923,7 @@ def run_workflows(
tag_ids_to_remove: set[int] = set()
for tag in action.remove_tags.all():
tag_ids_to_remove.add(tag.pk)
tag_ids_to_remove.update(t.pk for t in tag.get_all_descendants())
tag_ids_to_remove.update(t.pk for t in tag.get_descendants())
if not use_overrides:
doc_tag_ids[:] = [t for t in doc_tag_ids if t not in tag_ids_to_remove]

View File

@@ -19,7 +19,7 @@ class TestTagHierarchy(APITestCase):
self.client.force_authenticate(user=self.user)
self.parent = Tag.objects.create(name="Parent")
self.child = Tag.objects.create(name="Child", parent=self.parent)
self.child = Tag.objects.create(name="Child", tn_parent=self.parent)
patcher = mock.patch("documents.bulk_edit.bulk_update_documents.delay")
self.async_task = patcher.start()
@@ -134,8 +134,8 @@ class TestTagHierarchy(APITestCase):
def test_cannot_set_parent_to_descendant(self):
a = Tag.objects.create(name="A")
b = Tag.objects.create(name="B", parent=a)
c = Tag.objects.create(name="C", parent=b)
b = Tag.objects.create(name="B", tn_parent=a)
c = Tag.objects.create(name="C", tn_parent=b)
# Attempt to set A's parent to C (descendant) should fail
resp = self.client.patch(
@@ -148,9 +148,9 @@ class TestTagHierarchy(APITestCase):
def test_max_depth_on_create(self):
a = Tag.objects.create(name="A1")
b = Tag.objects.create(name="B1", parent=a)
c = Tag.objects.create(name="C1", parent=b)
d = Tag.objects.create(name="D1", parent=c)
b = Tag.objects.create(name="B1", tn_parent=a)
c = Tag.objects.create(name="C1", tn_parent=b)
d = Tag.objects.create(name="D1", tn_parent=c)
# Creating E under D yields depth 5: allowed
resp_ok = self.client.post(
@@ -176,12 +176,12 @@ class TestTagHierarchy(APITestCase):
def test_max_depth_on_move_subtree(self):
a = Tag.objects.create(name="A2")
b = Tag.objects.create(name="B2", parent=a)
c = Tag.objects.create(name="C2", parent=b)
d = Tag.objects.create(name="D2", parent=c)
b = Tag.objects.create(name="B2", tn_parent=a)
c = Tag.objects.create(name="C2", tn_parent=b)
d = Tag.objects.create(name="D2", tn_parent=c)
x = Tag.objects.create(name="X2")
y = Tag.objects.create(name="Y2", parent=x)
y = Tag.objects.create(name="Y2", tn_parent=x)
assert y.parent_id == x.id
# Moving X under D would make deepest node Y exceed depth 5 -> reject

View File

@@ -354,7 +354,7 @@ class TagViewSet(ModelViewSet, PermissionsAwareDocumentCountMixin):
affected = set()
if new_parent:
parents_to_add = [new_parent, *new_parent.get_all_ancestors()]
parents_to_add = [new_parent, *new_parent.get_ancestors()]
to_create = []
for parent in parents_to_add:
missing = Document.objects.filter(id__in=doc_ids).exclude(tags=parent)