mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-09-16 21:55:37 -05:00
Try replacing with TreeNodeModel
This commit is contained in:
@@ -98,7 +98,7 @@ def set_document_type(doc_ids: list[int], document_type: DocumentType) -> Litera
|
|||||||
|
|
||||||
def add_tag(doc_ids: list[int], tag: int) -> Literal["OK"]:
|
def add_tag(doc_ids: list[int], tag: int) -> Literal["OK"]:
|
||||||
tag_obj = Tag.objects.get(pk=tag)
|
tag_obj = Tag.objects.get(pk=tag)
|
||||||
tags_to_add = [tag_obj, *tag_obj.get_all_ancestors()]
|
tags_to_add = [tag_obj, *tag_obj.get_ancestors()]
|
||||||
|
|
||||||
DocumentTagRelationship = Document.tags.through
|
DocumentTagRelationship = Document.tags.through
|
||||||
to_create = []
|
to_create = []
|
||||||
@@ -124,7 +124,7 @@ def add_tag(doc_ids: list[int], tag: int) -> Literal["OK"]:
|
|||||||
|
|
||||||
def remove_tag(doc_ids: list[int], tag: int) -> Literal["OK"]:
|
def remove_tag(doc_ids: list[int], tag: int) -> Literal["OK"]:
|
||||||
tag_obj = Tag.objects.get(pk=tag)
|
tag_obj = Tag.objects.get(pk=tag)
|
||||||
tags_to_remove = [tag_obj, *tag_obj.get_all_descendants()]
|
tags_to_remove = [tag_obj, *tag_obj.get_descendants()]
|
||||||
tag_ids = [t.id for t in tags_to_remove]
|
tag_ids = [t.id for t in tags_to_remove]
|
||||||
|
|
||||||
DocumentTagRelationship = Document.tags.through
|
DocumentTagRelationship = Document.tags.through
|
||||||
@@ -154,13 +154,13 @@ def modify_tags(
|
|||||||
expanded_add_tags: set[int] = set()
|
expanded_add_tags: set[int] = set()
|
||||||
for tag_id in add_tags:
|
for tag_id in add_tags:
|
||||||
t = Tag.objects.get(pk=tag_id)
|
t = Tag.objects.get(pk=tag_id)
|
||||||
expanded_add_tags.update([t.id for t in [t, *t.get_all_ancestors()]])
|
expanded_add_tags.update([t.id for t in [t, *t.get_ancestors()]])
|
||||||
|
|
||||||
# remove with all descendants
|
# remove with all descendants
|
||||||
expanded_remove_tags: set[int] = set()
|
expanded_remove_tags: set[int] = set()
|
||||||
for tag_id in remove_tags:
|
for tag_id in remove_tags:
|
||||||
t = Tag.objects.get(pk=tag_id)
|
t = Tag.objects.get(pk=tag_id)
|
||||||
expanded_remove_tags.update([t.id for t in [t, *t.get_all_descendants()]])
|
expanded_remove_tags.update([t.id for t in [t, *t.get_descendants()]])
|
||||||
|
|
||||||
if expanded_remove_tags:
|
if expanded_remove_tags:
|
||||||
DocumentTagRelationship.objects.filter(
|
DocumentTagRelationship.objects.filter(
|
||||||
|
@@ -1,26 +0,0 @@
|
|||||||
# Generated by Django 5.1.5 on 2025-02-10 06:02
|
|
||||||
|
|
||||||
import django.db.models.deletion
|
|
||||||
from django.db import migrations
|
|
||||||
from django.db import models
|
|
||||||
|
|
||||||
|
|
||||||
class Migration(migrations.Migration):
|
|
||||||
dependencies = [
|
|
||||||
("documents", "1069_workflowtrigger_filter_has_storage_path_and_more"),
|
|
||||||
]
|
|
||||||
|
|
||||||
operations = [
|
|
||||||
migrations.AddField(
|
|
||||||
model_name="tag",
|
|
||||||
name="parent",
|
|
||||||
field=models.ForeignKey(
|
|
||||||
blank=True,
|
|
||||||
null=True,
|
|
||||||
on_delete=django.db.models.deletion.CASCADE,
|
|
||||||
related_name="children",
|
|
||||||
to="documents.tag",
|
|
||||||
verbose_name="parent",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
]
|
|
@@ -0,0 +1,159 @@
|
|||||||
|
# Generated by Django 5.2.6 on 2025-09-12 18:42
|
||||||
|
|
||||||
|
import django.core.validators
|
||||||
|
import django.db.models.deletion
|
||||||
|
from django.db import migrations
|
||||||
|
from django.db import models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("documents", "1069_workflowtrigger_filter_has_storage_path_and_more"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="tag",
|
||||||
|
name="tn_ancestors_count",
|
||||||
|
field=models.PositiveIntegerField(
|
||||||
|
default=0,
|
||||||
|
editable=False,
|
||||||
|
verbose_name="Ancestors count",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="tag",
|
||||||
|
name="tn_ancestors_pks",
|
||||||
|
field=models.TextField(
|
||||||
|
blank=True,
|
||||||
|
default="",
|
||||||
|
editable=False,
|
||||||
|
verbose_name="Ancestors pks",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="tag",
|
||||||
|
name="tn_children_count",
|
||||||
|
field=models.PositiveIntegerField(
|
||||||
|
default=0,
|
||||||
|
editable=False,
|
||||||
|
verbose_name="Children count",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="tag",
|
||||||
|
name="tn_children_pks",
|
||||||
|
field=models.TextField(
|
||||||
|
blank=True,
|
||||||
|
default="",
|
||||||
|
editable=False,
|
||||||
|
verbose_name="Children pks",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="tag",
|
||||||
|
name="tn_depth",
|
||||||
|
field=models.PositiveIntegerField(
|
||||||
|
default=0,
|
||||||
|
editable=False,
|
||||||
|
validators=[
|
||||||
|
django.core.validators.MinValueValidator(0),
|
||||||
|
django.core.validators.MaxValueValidator(10),
|
||||||
|
],
|
||||||
|
verbose_name="Depth",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="tag",
|
||||||
|
name="tn_descendants_count",
|
||||||
|
field=models.PositiveIntegerField(
|
||||||
|
default=0,
|
||||||
|
editable=False,
|
||||||
|
verbose_name="Descendants count",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="tag",
|
||||||
|
name="tn_descendants_pks",
|
||||||
|
field=models.TextField(
|
||||||
|
blank=True,
|
||||||
|
default="",
|
||||||
|
editable=False,
|
||||||
|
verbose_name="Descendants pks",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="tag",
|
||||||
|
name="tn_index",
|
||||||
|
field=models.PositiveIntegerField(
|
||||||
|
default=0,
|
||||||
|
editable=False,
|
||||||
|
verbose_name="Index",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="tag",
|
||||||
|
name="tn_level",
|
||||||
|
field=models.PositiveIntegerField(
|
||||||
|
default=1,
|
||||||
|
editable=False,
|
||||||
|
validators=[
|
||||||
|
django.core.validators.MinValueValidator(1),
|
||||||
|
django.core.validators.MaxValueValidator(10),
|
||||||
|
],
|
||||||
|
verbose_name="Level",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="tag",
|
||||||
|
name="tn_order",
|
||||||
|
field=models.PositiveIntegerField(
|
||||||
|
default=0,
|
||||||
|
editable=False,
|
||||||
|
verbose_name="Order",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="tag",
|
||||||
|
name="tn_parent",
|
||||||
|
field=models.ForeignKey(
|
||||||
|
blank=True,
|
||||||
|
null=True,
|
||||||
|
on_delete=django.db.models.deletion.CASCADE,
|
||||||
|
related_name="tn_children",
|
||||||
|
to="documents.tag",
|
||||||
|
verbose_name="Parent",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="tag",
|
||||||
|
name="tn_priority",
|
||||||
|
field=models.PositiveIntegerField(
|
||||||
|
default=0,
|
||||||
|
validators=[
|
||||||
|
django.core.validators.MinValueValidator(0),
|
||||||
|
django.core.validators.MaxValueValidator(9999999999),
|
||||||
|
],
|
||||||
|
verbose_name="Priority",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="tag",
|
||||||
|
name="tn_siblings_count",
|
||||||
|
field=models.PositiveIntegerField(
|
||||||
|
default=0,
|
||||||
|
editable=False,
|
||||||
|
verbose_name="Siblings count",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="tag",
|
||||||
|
name="tn_siblings_pks",
|
||||||
|
field=models.TextField(
|
||||||
|
blank=True,
|
||||||
|
default="",
|
||||||
|
editable=False,
|
||||||
|
verbose_name="Siblings pks",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
@@ -14,6 +14,7 @@ from django.db import models
|
|||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
from django.utils.translation import gettext_lazy as _
|
from django.utils.translation import gettext_lazy as _
|
||||||
from multiselectfield import MultiSelectField
|
from multiselectfield import MultiSelectField
|
||||||
|
from treenode.models import TreeNodeModel
|
||||||
|
|
||||||
if settings.AUDIT_LOG_ENABLED:
|
if settings.AUDIT_LOG_ENABLED:
|
||||||
from auditlog.registry import auditlog
|
from auditlog.registry import auditlog
|
||||||
@@ -97,7 +98,7 @@ class Correspondent(MatchingModel):
|
|||||||
verbose_name_plural = _("correspondents")
|
verbose_name_plural = _("correspondents")
|
||||||
|
|
||||||
|
|
||||||
class Tag(MatchingModel):
|
class Tag(MatchingModel, TreeNodeModel):
|
||||||
color = models.CharField(_("color"), max_length=7, default="#a6cee3")
|
color = models.CharField(_("color"), max_length=7, default="#a6cee3")
|
||||||
# Maximum allowed nesting depth for tags (root = 1, max depth = 5)
|
# Maximum allowed nesting depth for tags (root = 1, max depth = 5)
|
||||||
MAX_NESTING_DEPTH: Final[int] = 5
|
MAX_NESTING_DEPTH: Final[int] = 5
|
||||||
@@ -111,35 +112,12 @@ class Tag(MatchingModel):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
parent = models.ForeignKey(
|
|
||||||
"self",
|
|
||||||
blank=True,
|
|
||||||
null=True,
|
|
||||||
on_delete=models.CASCADE,
|
|
||||||
related_name="children",
|
|
||||||
verbose_name=_("parent"),
|
|
||||||
)
|
|
||||||
|
|
||||||
class Meta(MatchingModel.Meta):
|
class Meta(MatchingModel.Meta):
|
||||||
verbose_name = _("tag")
|
verbose_name = _("tag")
|
||||||
verbose_name_plural = _("tags")
|
verbose_name_plural = _("tags")
|
||||||
|
|
||||||
def get_all_descendants(self):
|
def subtree_height(self, node: TreeNodeModel) -> int:
|
||||||
descendants = []
|
children = list(node.children)
|
||||||
for child in self.children.all():
|
|
||||||
descendants.append(child)
|
|
||||||
descendants.extend(child.get_all_descendants())
|
|
||||||
return descendants
|
|
||||||
|
|
||||||
def get_all_ancestors(self):
|
|
||||||
ancestors = []
|
|
||||||
if self.parent:
|
|
||||||
ancestors.append(self.parent)
|
|
||||||
ancestors.extend(self.parent.get_all_ancestors())
|
|
||||||
return ancestors
|
|
||||||
|
|
||||||
def subtree_height(self, node) -> int:
|
|
||||||
children = list(node.children.all())
|
|
||||||
if not children:
|
if not children:
|
||||||
return 0
|
return 0
|
||||||
return 1 + max(self.subtree_height(child) for child in children)
|
return 1 + max(self.subtree_height(child) for child in children)
|
||||||
@@ -153,16 +131,14 @@ class Tag(MatchingModel):
|
|||||||
if (
|
if (
|
||||||
self.parent
|
self.parent
|
||||||
and self.pk is not None
|
and self.pk is not None
|
||||||
and any(
|
and any(ancestor.pk == self.pk for ancestor in self.parent.get_ancestors())
|
||||||
ancestor.pk == self.pk for ancestor in self.parent.get_all_ancestors()
|
|
||||||
)
|
|
||||||
):
|
):
|
||||||
raise ValidationError(_("Cannot set parent to a descendant."))
|
raise ValidationError(_("Cannot set parent to a descendant."))
|
||||||
|
|
||||||
# Enforce maximum nesting depth
|
# Enforce maximum nesting depth
|
||||||
new_parent_depth = 0
|
new_parent_depth = 0
|
||||||
if self.parent:
|
if self.parent:
|
||||||
new_parent_depth = len(self.parent.get_all_ancestors()) + 1
|
new_parent_depth = len(self.parent.get_ancestors()) + 1
|
||||||
if self.pk is None:
|
if self.pk is None:
|
||||||
# Unsaved tag cannot have children; treat as leaf
|
# Unsaved tag cannot have children; treat as leaf
|
||||||
height = 0
|
height = 0
|
||||||
|
@@ -541,17 +541,8 @@ class TagSerializer(MatchingModelSerializer, OwnedObjectSerializer):
|
|||||||
|
|
||||||
text_color = serializers.SerializerMethodField()
|
text_color = serializers.SerializerMethodField()
|
||||||
|
|
||||||
children = SerializerMethodField()
|
|
||||||
|
|
||||||
@extend_schema_field(
|
|
||||||
field=serializers.ListSerializer(
|
|
||||||
child=serializers.PrimaryKeyRelatedField(
|
|
||||||
queryset=Tag.objects.all(),
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
def get_children(self, obj):
|
def get_children(self, obj):
|
||||||
return TagSerializer(obj.children.all(), many=True).data
|
return obj.get_children_pks()
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
model = Tag
|
model = Tag
|
||||||
@@ -588,16 +579,16 @@ class TagSerializer(MatchingModelSerializer, OwnedObjectSerializer):
|
|||||||
# Temporarily set parent on the instance if updating and use model clean()
|
# Temporarily set parent on the instance if updating and use model clean()
|
||||||
original_parent = self.instance.parent
|
original_parent = self.instance.parent
|
||||||
try:
|
try:
|
||||||
self.instance.parent = parent
|
self.instance.tn_parent = parent
|
||||||
self.instance.clean()
|
self.instance.clean()
|
||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
logger.debug("Tag parent validation failed: %s", e)
|
logger.debug("Tag parent validation failed: %s", e)
|
||||||
raise serializers.ValidationError({"parent": _("Invalid parent tag.")})
|
raise serializers.ValidationError({"parent": _("Invalid parent tag.")})
|
||||||
finally:
|
finally:
|
||||||
self.instance.parent = original_parent
|
self.instance.tn_parent = original_parent
|
||||||
else:
|
else:
|
||||||
# For new instances, create a transient Tag and validate
|
# For new instances, create a transient Tag and validate
|
||||||
temp = Tag(parent=parent)
|
temp = Tag(tn_parent=parent)
|
||||||
try:
|
try:
|
||||||
temp.clean()
|
temp.clean()
|
||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
@@ -1073,7 +1064,7 @@ class DocumentSerializer(
|
|||||||
# add all parent tags
|
# add all parent tags
|
||||||
all_ancestor_tags = set(validated_data["tags"])
|
all_ancestor_tags = set(validated_data["tags"])
|
||||||
for tag in validated_data["tags"]:
|
for tag in validated_data["tags"]:
|
||||||
all_ancestor_tags.update(tag.get_all_ancestors())
|
all_ancestor_tags.update(tag.get_ancestors())
|
||||||
validated_data["tags"] = list(all_ancestor_tags)
|
validated_data["tags"] = list(all_ancestor_tags)
|
||||||
# remove any children for parents that are being removed
|
# remove any children for parents that are being removed
|
||||||
tag_parents_being_removed = [
|
tag_parents_being_removed = [
|
||||||
|
@@ -770,7 +770,7 @@ def run_workflows(
|
|||||||
tag_ids_to_add: set[int] = set()
|
tag_ids_to_add: set[int] = set()
|
||||||
for tag in action.assign_tags.all():
|
for tag in action.assign_tags.all():
|
||||||
tag_ids_to_add.add(tag.pk)
|
tag_ids_to_add.add(tag.pk)
|
||||||
tag_ids_to_add.update(t.pk for t in tag.get_all_ancestors())
|
tag_ids_to_add.update(t.pk for t in tag.get_ancestors())
|
||||||
|
|
||||||
if not use_overrides:
|
if not use_overrides:
|
||||||
doc_tag_ids[:] = list(set(doc_tag_ids) | tag_ids_to_add)
|
doc_tag_ids[:] = list(set(doc_tag_ids) | tag_ids_to_add)
|
||||||
@@ -923,7 +923,7 @@ def run_workflows(
|
|||||||
tag_ids_to_remove: set[int] = set()
|
tag_ids_to_remove: set[int] = set()
|
||||||
for tag in action.remove_tags.all():
|
for tag in action.remove_tags.all():
|
||||||
tag_ids_to_remove.add(tag.pk)
|
tag_ids_to_remove.add(tag.pk)
|
||||||
tag_ids_to_remove.update(t.pk for t in tag.get_all_descendants())
|
tag_ids_to_remove.update(t.pk for t in tag.get_descendants())
|
||||||
|
|
||||||
if not use_overrides:
|
if not use_overrides:
|
||||||
doc_tag_ids[:] = [t for t in doc_tag_ids if t not in tag_ids_to_remove]
|
doc_tag_ids[:] = [t for t in doc_tag_ids if t not in tag_ids_to_remove]
|
||||||
|
@@ -19,7 +19,7 @@ class TestTagHierarchy(APITestCase):
|
|||||||
self.client.force_authenticate(user=self.user)
|
self.client.force_authenticate(user=self.user)
|
||||||
|
|
||||||
self.parent = Tag.objects.create(name="Parent")
|
self.parent = Tag.objects.create(name="Parent")
|
||||||
self.child = Tag.objects.create(name="Child", parent=self.parent)
|
self.child = Tag.objects.create(name="Child", tn_parent=self.parent)
|
||||||
|
|
||||||
patcher = mock.patch("documents.bulk_edit.bulk_update_documents.delay")
|
patcher = mock.patch("documents.bulk_edit.bulk_update_documents.delay")
|
||||||
self.async_task = patcher.start()
|
self.async_task = patcher.start()
|
||||||
@@ -134,8 +134,8 @@ class TestTagHierarchy(APITestCase):
|
|||||||
|
|
||||||
def test_cannot_set_parent_to_descendant(self):
|
def test_cannot_set_parent_to_descendant(self):
|
||||||
a = Tag.objects.create(name="A")
|
a = Tag.objects.create(name="A")
|
||||||
b = Tag.objects.create(name="B", parent=a)
|
b = Tag.objects.create(name="B", tn_parent=a)
|
||||||
c = Tag.objects.create(name="C", parent=b)
|
c = Tag.objects.create(name="C", tn_parent=b)
|
||||||
|
|
||||||
# Attempt to set A's parent to C (descendant) should fail
|
# Attempt to set A's parent to C (descendant) should fail
|
||||||
resp = self.client.patch(
|
resp = self.client.patch(
|
||||||
@@ -148,9 +148,9 @@ class TestTagHierarchy(APITestCase):
|
|||||||
|
|
||||||
def test_max_depth_on_create(self):
|
def test_max_depth_on_create(self):
|
||||||
a = Tag.objects.create(name="A1")
|
a = Tag.objects.create(name="A1")
|
||||||
b = Tag.objects.create(name="B1", parent=a)
|
b = Tag.objects.create(name="B1", tn_parent=a)
|
||||||
c = Tag.objects.create(name="C1", parent=b)
|
c = Tag.objects.create(name="C1", tn_parent=b)
|
||||||
d = Tag.objects.create(name="D1", parent=c)
|
d = Tag.objects.create(name="D1", tn_parent=c)
|
||||||
|
|
||||||
# Creating E under D yields depth 5: allowed
|
# Creating E under D yields depth 5: allowed
|
||||||
resp_ok = self.client.post(
|
resp_ok = self.client.post(
|
||||||
@@ -176,12 +176,12 @@ class TestTagHierarchy(APITestCase):
|
|||||||
|
|
||||||
def test_max_depth_on_move_subtree(self):
|
def test_max_depth_on_move_subtree(self):
|
||||||
a = Tag.objects.create(name="A2")
|
a = Tag.objects.create(name="A2")
|
||||||
b = Tag.objects.create(name="B2", parent=a)
|
b = Tag.objects.create(name="B2", tn_parent=a)
|
||||||
c = Tag.objects.create(name="C2", parent=b)
|
c = Tag.objects.create(name="C2", tn_parent=b)
|
||||||
d = Tag.objects.create(name="D2", parent=c)
|
d = Tag.objects.create(name="D2", tn_parent=c)
|
||||||
|
|
||||||
x = Tag.objects.create(name="X2")
|
x = Tag.objects.create(name="X2")
|
||||||
y = Tag.objects.create(name="Y2", parent=x)
|
y = Tag.objects.create(name="Y2", tn_parent=x)
|
||||||
assert y.parent_id == x.id
|
assert y.parent_id == x.id
|
||||||
|
|
||||||
# Moving X under D would make deepest node Y exceed depth 5 -> reject
|
# Moving X under D would make deepest node Y exceed depth 5 -> reject
|
||||||
|
@@ -354,7 +354,7 @@ class TagViewSet(ModelViewSet, PermissionsAwareDocumentCountMixin):
|
|||||||
affected = set()
|
affected = set()
|
||||||
|
|
||||||
if new_parent:
|
if new_parent:
|
||||||
parents_to_add = [new_parent, *new_parent.get_all_ancestors()]
|
parents_to_add = [new_parent, *new_parent.get_ancestors()]
|
||||||
to_create = []
|
to_create = []
|
||||||
for parent in parents_to_add:
|
for parent in parents_to_add:
|
||||||
missing = Document.objects.filter(id__in=doc_ids).exclude(tags=parent)
|
missing = Document.objects.filter(id__in=doc_ids).exclude(tags=parent)
|
||||||
|
Reference in New Issue
Block a user