From 7b3a6877c37c0f891082ef62b93e3a7d0be64e5b Mon Sep 17 00:00:00 2001 From: shamoon <4887959+shamoon@users.noreply.github.com> Date: Fri, 12 Sep 2025 11:49:31 -0700 Subject: [PATCH] Try replacing with TreeNodeModel --- src/documents/bulk_edit.py | 8 +- src/documents/migrations/1070_tag_parent.py | 26 --- ...ors_count_tag_tn_ancestors_pks_and_more.py | 159 ++++++++++++++++++ src/documents/models.py | 36 +--- src/documents/serialisers.py | 19 +-- src/documents/signals/handlers.py | 4 +- src/documents/tests/test_tag_hierarchy.py | 20 +-- src/documents/views.py | 2 +- 8 files changed, 187 insertions(+), 87 deletions(-) delete mode 100644 src/documents/migrations/1070_tag_parent.py create mode 100644 src/documents/migrations/1070_tag_tn_ancestors_count_tag_tn_ancestors_pks_and_more.py diff --git a/src/documents/bulk_edit.py b/src/documents/bulk_edit.py index e07e03691..4774ac5e8 100644 --- a/src/documents/bulk_edit.py +++ b/src/documents/bulk_edit.py @@ -98,7 +98,7 @@ def set_document_type(doc_ids: list[int], document_type: DocumentType) -> Litera def add_tag(doc_ids: list[int], tag: int) -> Literal["OK"]: tag_obj = Tag.objects.get(pk=tag) - tags_to_add = [tag_obj, *tag_obj.get_all_ancestors()] + tags_to_add = [tag_obj, *tag_obj.get_ancestors()] DocumentTagRelationship = Document.tags.through to_create = [] @@ -124,7 +124,7 @@ def add_tag(doc_ids: list[int], tag: int) -> Literal["OK"]: def remove_tag(doc_ids: list[int], tag: int) -> Literal["OK"]: tag_obj = Tag.objects.get(pk=tag) - tags_to_remove = [tag_obj, *tag_obj.get_all_descendants()] + tags_to_remove = [tag_obj, *tag_obj.get_descendants()] tag_ids = [t.id for t in tags_to_remove] DocumentTagRelationship = Document.tags.through @@ -154,13 +154,13 @@ def modify_tags( expanded_add_tags: set[int] = set() for tag_id in add_tags: t = Tag.objects.get(pk=tag_id) - expanded_add_tags.update([t.id for t in [t, *t.get_all_ancestors()]]) + expanded_add_tags.update([t.id for t in [t, *t.get_ancestors()]]) # remove with all descendants expanded_remove_tags: set[int] = set() for tag_id in remove_tags: t = Tag.objects.get(pk=tag_id) - expanded_remove_tags.update([t.id for t in [t, *t.get_all_descendants()]]) + expanded_remove_tags.update([t.id for t in [t, *t.get_descendants()]]) if expanded_remove_tags: DocumentTagRelationship.objects.filter( diff --git a/src/documents/migrations/1070_tag_parent.py b/src/documents/migrations/1070_tag_parent.py deleted file mode 100644 index 8993335f3..000000000 --- a/src/documents/migrations/1070_tag_parent.py +++ /dev/null @@ -1,26 +0,0 @@ -# Generated by Django 5.1.5 on 2025-02-10 06:02 - -import django.db.models.deletion -from django.db import migrations -from django.db import models - - -class Migration(migrations.Migration): - dependencies = [ - ("documents", "1069_workflowtrigger_filter_has_storage_path_and_more"), - ] - - operations = [ - migrations.AddField( - model_name="tag", - name="parent", - field=models.ForeignKey( - blank=True, - null=True, - on_delete=django.db.models.deletion.CASCADE, - related_name="children", - to="documents.tag", - verbose_name="parent", - ), - ), - ] diff --git a/src/documents/migrations/1070_tag_tn_ancestors_count_tag_tn_ancestors_pks_and_more.py b/src/documents/migrations/1070_tag_tn_ancestors_count_tag_tn_ancestors_pks_and_more.py new file mode 100644 index 000000000..983355dac --- /dev/null +++ b/src/documents/migrations/1070_tag_tn_ancestors_count_tag_tn_ancestors_pks_and_more.py @@ -0,0 +1,159 @@ +# Generated by Django 5.2.6 on 2025-09-12 18:42 + +import django.core.validators +import django.db.models.deletion +from django.db import migrations +from django.db import models + + +class Migration(migrations.Migration): + dependencies = [ + ("documents", "1069_workflowtrigger_filter_has_storage_path_and_more"), + ] + + operations = [ + migrations.AddField( + model_name="tag", + name="tn_ancestors_count", + field=models.PositiveIntegerField( + default=0, + editable=False, + verbose_name="Ancestors count", + ), + ), + migrations.AddField( + model_name="tag", + name="tn_ancestors_pks", + field=models.TextField( + blank=True, + default="", + editable=False, + verbose_name="Ancestors pks", + ), + ), + migrations.AddField( + model_name="tag", + name="tn_children_count", + field=models.PositiveIntegerField( + default=0, + editable=False, + verbose_name="Children count", + ), + ), + migrations.AddField( + model_name="tag", + name="tn_children_pks", + field=models.TextField( + blank=True, + default="", + editable=False, + verbose_name="Children pks", + ), + ), + migrations.AddField( + model_name="tag", + name="tn_depth", + field=models.PositiveIntegerField( + default=0, + editable=False, + validators=[ + django.core.validators.MinValueValidator(0), + django.core.validators.MaxValueValidator(10), + ], + verbose_name="Depth", + ), + ), + migrations.AddField( + model_name="tag", + name="tn_descendants_count", + field=models.PositiveIntegerField( + default=0, + editable=False, + verbose_name="Descendants count", + ), + ), + migrations.AddField( + model_name="tag", + name="tn_descendants_pks", + field=models.TextField( + blank=True, + default="", + editable=False, + verbose_name="Descendants pks", + ), + ), + migrations.AddField( + model_name="tag", + name="tn_index", + field=models.PositiveIntegerField( + default=0, + editable=False, + verbose_name="Index", + ), + ), + migrations.AddField( + model_name="tag", + name="tn_level", + field=models.PositiveIntegerField( + default=1, + editable=False, + validators=[ + django.core.validators.MinValueValidator(1), + django.core.validators.MaxValueValidator(10), + ], + verbose_name="Level", + ), + ), + migrations.AddField( + model_name="tag", + name="tn_order", + field=models.PositiveIntegerField( + default=0, + editable=False, + verbose_name="Order", + ), + ), + migrations.AddField( + model_name="tag", + name="tn_parent", + field=models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="tn_children", + to="documents.tag", + verbose_name="Parent", + ), + ), + migrations.AddField( + model_name="tag", + name="tn_priority", + field=models.PositiveIntegerField( + default=0, + validators=[ + django.core.validators.MinValueValidator(0), + django.core.validators.MaxValueValidator(9999999999), + ], + verbose_name="Priority", + ), + ), + migrations.AddField( + model_name="tag", + name="tn_siblings_count", + field=models.PositiveIntegerField( + default=0, + editable=False, + verbose_name="Siblings count", + ), + ), + migrations.AddField( + model_name="tag", + name="tn_siblings_pks", + field=models.TextField( + blank=True, + default="", + editable=False, + verbose_name="Siblings pks", + ), + ), + ] diff --git a/src/documents/models.py b/src/documents/models.py index c8778b4f5..6b7bf84a7 100644 --- a/src/documents/models.py +++ b/src/documents/models.py @@ -14,6 +14,7 @@ from django.db import models from django.utils import timezone from django.utils.translation import gettext_lazy as _ from multiselectfield import MultiSelectField +from treenode.models import TreeNodeModel if settings.AUDIT_LOG_ENABLED: from auditlog.registry import auditlog @@ -97,7 +98,7 @@ class Correspondent(MatchingModel): verbose_name_plural = _("correspondents") -class Tag(MatchingModel): +class Tag(MatchingModel, TreeNodeModel): color = models.CharField(_("color"), max_length=7, default="#a6cee3") # Maximum allowed nesting depth for tags (root = 1, max depth = 5) MAX_NESTING_DEPTH: Final[int] = 5 @@ -111,35 +112,12 @@ class Tag(MatchingModel): ), ) - parent = models.ForeignKey( - "self", - blank=True, - null=True, - on_delete=models.CASCADE, - related_name="children", - verbose_name=_("parent"), - ) - class Meta(MatchingModel.Meta): verbose_name = _("tag") verbose_name_plural = _("tags") - def get_all_descendants(self): - descendants = [] - for child in self.children.all(): - descendants.append(child) - descendants.extend(child.get_all_descendants()) - return descendants - - def get_all_ancestors(self): - ancestors = [] - if self.parent: - ancestors.append(self.parent) - ancestors.extend(self.parent.get_all_ancestors()) - return ancestors - - def subtree_height(self, node) -> int: - children = list(node.children.all()) + def subtree_height(self, node: TreeNodeModel) -> int: + children = list(node.children) if not children: return 0 return 1 + max(self.subtree_height(child) for child in children) @@ -153,16 +131,14 @@ class Tag(MatchingModel): if ( self.parent and self.pk is not None - and any( - ancestor.pk == self.pk for ancestor in self.parent.get_all_ancestors() - ) + and any(ancestor.pk == self.pk for ancestor in self.parent.get_ancestors()) ): raise ValidationError(_("Cannot set parent to a descendant.")) # Enforce maximum nesting depth new_parent_depth = 0 if self.parent: - new_parent_depth = len(self.parent.get_all_ancestors()) + 1 + new_parent_depth = len(self.parent.get_ancestors()) + 1 if self.pk is None: # Unsaved tag cannot have children; treat as leaf height = 0 diff --git a/src/documents/serialisers.py b/src/documents/serialisers.py index 956da8129..14560c2c3 100644 --- a/src/documents/serialisers.py +++ b/src/documents/serialisers.py @@ -541,17 +541,8 @@ class TagSerializer(MatchingModelSerializer, OwnedObjectSerializer): text_color = serializers.SerializerMethodField() - children = SerializerMethodField() - - @extend_schema_field( - field=serializers.ListSerializer( - child=serializers.PrimaryKeyRelatedField( - queryset=Tag.objects.all(), - ), - ), - ) def get_children(self, obj): - return TagSerializer(obj.children.all(), many=True).data + return obj.get_children_pks() class Meta: model = Tag @@ -588,16 +579,16 @@ class TagSerializer(MatchingModelSerializer, OwnedObjectSerializer): # Temporarily set parent on the instance if updating and use model clean() original_parent = self.instance.parent try: - self.instance.parent = parent + self.instance.tn_parent = parent self.instance.clean() except ValidationError as e: logger.debug("Tag parent validation failed: %s", e) raise serializers.ValidationError({"parent": _("Invalid parent tag.")}) finally: - self.instance.parent = original_parent + self.instance.tn_parent = original_parent else: # For new instances, create a transient Tag and validate - temp = Tag(parent=parent) + temp = Tag(tn_parent=parent) try: temp.clean() except ValidationError as e: @@ -1073,7 +1064,7 @@ class DocumentSerializer( # add all parent tags all_ancestor_tags = set(validated_data["tags"]) for tag in validated_data["tags"]: - all_ancestor_tags.update(tag.get_all_ancestors()) + all_ancestor_tags.update(tag.get_ancestors()) validated_data["tags"] = list(all_ancestor_tags) # remove any children for parents that are being removed tag_parents_being_removed = [ diff --git a/src/documents/signals/handlers.py b/src/documents/signals/handlers.py index f75831294..8086c46b3 100644 --- a/src/documents/signals/handlers.py +++ b/src/documents/signals/handlers.py @@ -770,7 +770,7 @@ def run_workflows( tag_ids_to_add: set[int] = set() for tag in action.assign_tags.all(): tag_ids_to_add.add(tag.pk) - tag_ids_to_add.update(t.pk for t in tag.get_all_ancestors()) + tag_ids_to_add.update(t.pk for t in tag.get_ancestors()) if not use_overrides: doc_tag_ids[:] = list(set(doc_tag_ids) | tag_ids_to_add) @@ -923,7 +923,7 @@ def run_workflows( tag_ids_to_remove: set[int] = set() for tag in action.remove_tags.all(): tag_ids_to_remove.add(tag.pk) - tag_ids_to_remove.update(t.pk for t in tag.get_all_descendants()) + tag_ids_to_remove.update(t.pk for t in tag.get_descendants()) if not use_overrides: doc_tag_ids[:] = [t for t in doc_tag_ids if t not in tag_ids_to_remove] diff --git a/src/documents/tests/test_tag_hierarchy.py b/src/documents/tests/test_tag_hierarchy.py index 8a4f703ab..0be167f76 100644 --- a/src/documents/tests/test_tag_hierarchy.py +++ b/src/documents/tests/test_tag_hierarchy.py @@ -19,7 +19,7 @@ class TestTagHierarchy(APITestCase): self.client.force_authenticate(user=self.user) self.parent = Tag.objects.create(name="Parent") - self.child = Tag.objects.create(name="Child", parent=self.parent) + self.child = Tag.objects.create(name="Child", tn_parent=self.parent) patcher = mock.patch("documents.bulk_edit.bulk_update_documents.delay") self.async_task = patcher.start() @@ -134,8 +134,8 @@ class TestTagHierarchy(APITestCase): def test_cannot_set_parent_to_descendant(self): a = Tag.objects.create(name="A") - b = Tag.objects.create(name="B", parent=a) - c = Tag.objects.create(name="C", parent=b) + b = Tag.objects.create(name="B", tn_parent=a) + c = Tag.objects.create(name="C", tn_parent=b) # Attempt to set A's parent to C (descendant) should fail resp = self.client.patch( @@ -148,9 +148,9 @@ class TestTagHierarchy(APITestCase): def test_max_depth_on_create(self): a = Tag.objects.create(name="A1") - b = Tag.objects.create(name="B1", parent=a) - c = Tag.objects.create(name="C1", parent=b) - d = Tag.objects.create(name="D1", parent=c) + b = Tag.objects.create(name="B1", tn_parent=a) + c = Tag.objects.create(name="C1", tn_parent=b) + d = Tag.objects.create(name="D1", tn_parent=c) # Creating E under D yields depth 5: allowed resp_ok = self.client.post( @@ -176,12 +176,12 @@ class TestTagHierarchy(APITestCase): def test_max_depth_on_move_subtree(self): a = Tag.objects.create(name="A2") - b = Tag.objects.create(name="B2", parent=a) - c = Tag.objects.create(name="C2", parent=b) - d = Tag.objects.create(name="D2", parent=c) + b = Tag.objects.create(name="B2", tn_parent=a) + c = Tag.objects.create(name="C2", tn_parent=b) + d = Tag.objects.create(name="D2", tn_parent=c) x = Tag.objects.create(name="X2") - y = Tag.objects.create(name="Y2", parent=x) + y = Tag.objects.create(name="Y2", tn_parent=x) assert y.parent_id == x.id # Moving X under D would make deepest node Y exceed depth 5 -> reject diff --git a/src/documents/views.py b/src/documents/views.py index d7585be7e..adf6a2a90 100644 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -354,7 +354,7 @@ class TagViewSet(ModelViewSet, PermissionsAwareDocumentCountMixin): affected = set() if new_parent: - parents_to_add = [new_parent, *new_parent.get_all_ancestors()] + parents_to_add = [new_parent, *new_parent.get_ancestors()] to_create = [] for parent in parents_to_add: missing = Document.objects.filter(id__in=doc_ids).exclude(tags=parent)