Mirror of https://github.com/paperless-ngx/paperless-ngx.git, synced 2025-10-12 02:26:09 -05:00

Merge branch 'dev' into feature-ai
@@ -1,6 +1,7 @@
from django.conf import settings
from django.contrib import admin
from guardian.admin import GuardedModelAdmin
from treenode.admin import TreeNodeModelAdmin

from documents.models import Correspondent
from documents.models import CustomField
@@ -14,6 +15,7 @@ from documents.models import SavedViewFilterRule
from documents.models import ShareLink
from documents.models import StoragePath
from documents.models import Tag
from documents.tasks import update_document_parent_tags

if settings.AUDIT_LOG_ENABLED:
    from auditlog.admin import LogEntryAdmin
@@ -26,12 +28,25 @@ class CorrespondentAdmin(GuardedModelAdmin):
    list_editable = ("match", "matching_algorithm")


class TagAdmin(GuardedModelAdmin):
class TagAdmin(GuardedModelAdmin, TreeNodeModelAdmin):
    list_display = ("name", "color", "match", "matching_algorithm")
    list_filter = ("matching_algorithm",)
    list_editable = ("color", "match", "matching_algorithm")
    search_fields = ("color", "name")

    def save_model(self, request, obj, form, change):
        old_parent = None
        if change and obj.pk:
            tag = Tag.objects.get(pk=obj.pk)
            old_parent = tag.get_parent() if tag else None

        super().save_model(request, obj, form, change)

        # sync parent tags on documents if changed
        new_parent = obj.get_parent()
        if new_parent and old_parent != new_parent:
            update_document_parent_tags(obj, new_parent)


class DocumentTypeAdmin(GuardedModelAdmin):
    list_display = ("name", "match", "matching_algorithm")
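The admin hook above detects a parent change by comparing the tag's parent before and after save, then pushes the new hierarchy onto existing documents. A minimal sketch of the same flow outside the admin, assuming tags named "child" and "parent" already exist (the names are illustrative only):

from documents.models import Tag
from documents.tasks import update_document_parent_tags

child = Tag.objects.get(name="child")
new_parent = Tag.objects.get(name="parent")
old_parent = child.get_parent()

child.tn_parent = new_parent  # django-treenode parent field, as used in the tests below
child.save()

if new_parent and old_parent != new_parent:
    # add the new parent (and its ancestors) to every document tagged with child
    update_document_parent_tags(child, new_parent)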
@@ -164,6 +164,9 @@ class BarcodePlugin(ConsumeTaskPlugin):
                    mailrule_id=self.input_doc.mailrule_id,
                    # Can't use the same folder or the consumer might grab it again
                    original_file=(tmp_dir / new_document.name).resolve(),
                    # Add the optional original_path for later use in
                    # workflow matching
                    original_path=self.input_doc.original_file,
                ),
                # All the same metadata
                self.metadata,
@@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import itertools
|
||||
import logging
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
@@ -13,6 +12,7 @@ from celery import chord
|
||||
from celery import group
|
||||
from celery import shared_task
|
||||
from django.conf import settings
|
||||
from django.db import transaction
|
||||
from django.db.models import Q
|
||||
from django.utils import timezone
|
||||
|
||||
@@ -25,6 +25,7 @@ from documents.models import CustomFieldInstance
|
||||
from documents.models import Document
|
||||
from documents.models import DocumentType
|
||||
from documents.models import StoragePath
|
||||
from documents.models import Tag
|
||||
from documents.permissions import set_permissions_for_object
|
||||
from documents.plugins.helpers import DocumentsStatusManager
|
||||
from documents.tasks import bulk_update_documents
|
||||
@@ -96,31 +97,45 @@ def set_document_type(doc_ids: list[int], document_type: DocumentType) -> Litera
|
||||
|
||||
|
||||
def add_tag(doc_ids: list[int], tag: int) -> Literal["OK"]:
|
||||
qs = Document.objects.filter(Q(id__in=doc_ids) & ~Q(tags__id=tag)).only("pk")
|
||||
affected_docs = list(qs.values_list("pk", flat=True))
|
||||
tag_obj = Tag.objects.get(pk=tag)
|
||||
tags_to_add = [tag_obj, *tag_obj.get_ancestors()]
|
||||
|
||||
DocumentTagRelationship = Document.tags.through
|
||||
to_create = []
|
||||
affected_docs: set[int] = set()
|
||||
|
||||
DocumentTagRelationship.objects.bulk_create(
|
||||
[DocumentTagRelationship(document_id=doc, tag_id=tag) for doc in affected_docs],
|
||||
)
|
||||
for t in tags_to_add:
|
||||
qs = Document.objects.filter(Q(id__in=doc_ids) & ~Q(tags__id=t.id)).only("pk")
|
||||
doc_ids_missing_tag = list(qs.values_list("pk", flat=True))
|
||||
affected_docs.update(doc_ids_missing_tag)
|
||||
to_create.extend(
|
||||
DocumentTagRelationship(document_id=doc, tag_id=t.id)
|
||||
for doc in doc_ids_missing_tag
|
||||
)
|
||||
|
||||
bulk_update_documents.delay(document_ids=affected_docs)
|
||||
if to_create:
|
||||
DocumentTagRelationship.objects.bulk_create(to_create)
|
||||
|
||||
if affected_docs:
|
||||
bulk_update_documents.delay(document_ids=list(affected_docs))
|
||||
|
||||
return "OK"
|
||||
|
||||
|
||||
def remove_tag(doc_ids: list[int], tag: int) -> Literal["OK"]:
|
||||
qs = Document.objects.filter(Q(id__in=doc_ids) & Q(tags__id=tag)).only("pk")
|
||||
affected_docs = list(qs.values_list("pk", flat=True))
|
||||
tag_obj = Tag.objects.get(pk=tag)
|
||||
tag_ids = [tag_obj.id, *tag_obj.get_descendants_pks()]
|
||||
|
||||
DocumentTagRelationship = Document.tags.through
|
||||
qs = DocumentTagRelationship.objects.filter(
|
||||
document_id__in=doc_ids,
|
||||
tag_id__in=tag_ids,
|
||||
)
|
||||
affected_docs = list(qs.values_list("document_id", flat=True).distinct())
|
||||
qs.delete()
|
||||
|
||||
DocumentTagRelationship.objects.filter(
|
||||
Q(document_id__in=affected_docs) & Q(tag_id=tag),
|
||||
).delete()
|
||||
|
||||
bulk_update_documents.delay(document_ids=affected_docs)
|
||||
if affected_docs:
|
||||
bulk_update_documents.delay(document_ids=affected_docs)
|
||||
|
||||
return "OK"
|
||||
|
||||
@@ -132,23 +147,57 @@ def modify_tags(
|
||||
) -> Literal["OK"]:
|
||||
qs = Document.objects.filter(id__in=doc_ids).only("pk")
|
||||
affected_docs = list(qs.values_list("pk", flat=True))
|
||||
|
||||
DocumentTagRelationship = Document.tags.through
|
||||
|
||||
DocumentTagRelationship.objects.filter(
|
||||
document_id__in=affected_docs,
|
||||
tag_id__in=remove_tags,
|
||||
).delete()
|
||||
# add with all ancestors
|
||||
expanded_add_tags: set[int] = set()
|
||||
add_tag_objects = Tag.objects.filter(pk__in=add_tags)
|
||||
for t in add_tag_objects:
|
||||
expanded_add_tags.add(int(t.id))
|
||||
expanded_add_tags.update(int(pk) for pk in t.get_ancestors_pks())
|
||||
|
||||
DocumentTagRelationship.objects.bulk_create(
|
||||
[
|
||||
DocumentTagRelationship(document_id=doc, tag_id=tag)
|
||||
for (doc, tag) in itertools.product(affected_docs, add_tags)
|
||||
],
|
||||
ignore_conflicts=True,
|
||||
)
|
||||
# remove with all descendants
|
||||
expanded_remove_tags: set[int] = set()
|
||||
remove_tag_objects = Tag.objects.filter(pk__in=remove_tags)
|
||||
for t in remove_tag_objects:
|
||||
expanded_remove_tags.add(int(t.id))
|
||||
expanded_remove_tags.update(int(pk) for pk in t.get_descendants_pks())
|
||||
|
||||
bulk_update_documents.delay(document_ids=affected_docs)
|
||||
try:
|
||||
with transaction.atomic():
|
||||
if expanded_remove_tags:
|
||||
DocumentTagRelationship.objects.filter(
|
||||
document_id__in=affected_docs,
|
||||
tag_id__in=expanded_remove_tags,
|
||||
).delete()
|
||||
|
||||
to_create = []
|
||||
if expanded_add_tags:
|
||||
existing_pairs = set(
|
||||
DocumentTagRelationship.objects.filter(
|
||||
document_id__in=affected_docs,
|
||||
tag_id__in=expanded_add_tags,
|
||||
).values_list("document_id", "tag_id"),
|
||||
)
|
||||
|
||||
to_create = [
|
||||
DocumentTagRelationship(document_id=doc, tag_id=tag)
|
||||
for doc in affected_docs
|
||||
for tag in expanded_add_tags
|
||||
if (doc, tag) not in existing_pairs
|
||||
]
|
||||
|
||||
if to_create:
|
||||
DocumentTagRelationship.objects.bulk_create(
|
||||
to_create,
|
||||
ignore_conflicts=True,
|
||||
)
|
||||
|
||||
if affected_docs:
|
||||
bulk_update_documents.delay(document_ids=affected_docs)
|
||||
except Exception as e:
|
||||
logger.error(f"Error modifying tags: {e}")
|
||||
return "ERROR"
|
||||
|
||||
return "OK"
|
||||
|
||||
|
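The reworked bulk edit helpers above expand a single tag id into its hierarchy before touching the through table: adding a tag also adds its ancestors, and removing a tag also removes its descendants. A short sketch of that expansion, assuming an existing Parent -> Child hierarchy (names are illustrative):

from documents.models import Tag

child = Tag.objects.get(name="Child")

tags_to_add = [child, *child.get_ancestors()]                 # Child plus Parent
tag_ids_to_remove = [child.id, *child.get_descendants_pks()]  # Child plus its whole subtree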
@@ -689,7 +689,7 @@ class ConsumerPlugin(

        if self.metadata.tag_ids:
            for tag_id in self.metadata.tag_ids:
                document.tags.add(Tag.objects.get(pk=tag_id))
                document.add_nested_tags([Tag.objects.get(pk=tag_id)])

        if self.metadata.storage_path_id:
            document.storage_path = StoragePath.objects.get(
@@ -156,6 +156,7 @@ class ConsumableDocument:

    source: DocumentSource
    original_file: Path
    original_path: Path | None = None
    mailrule_id: int | None = None
    mime_type: str = dataclasses.field(init=False, default=None)
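The new original_path field ties the two hunks above together: when the barcode plugin re-queues a split file from a temporary directory, it records where the file originally came from. A sketch of the resulting dataclass instance (the paths and the import location are assumptions, not part of this diff):

from pathlib import Path

from documents.data_models import ConsumableDocument, DocumentSource

split_doc = ConsumableDocument(
    source=DocumentSource.ConsumeFolder,
    original_file=Path("/tmp/paperless-split/page-1.pdf"),  # where the split page now lives
    original_path=Path("/consume/invoices/invoice.pdf"),    # where the user originally dropped it
)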
@@ -82,6 +82,13 @@ def _is_ignored(filepath: Path) -> bool:


def _consume(filepath: Path) -> None:
    # Check permissions early
    try:
        filepath.stat()
    except (PermissionError, OSError):
        logger.warning(f"Not consuming file {filepath}: Permission denied.")
        return

    if filepath.is_dir() or _is_ignored(filepath):
        return

@@ -323,7 +330,12 @@ class Command(BaseCommand):

            # Also make sure the file still exists; some scanners might write a
            # temporary file first
            file_still_exists = filepath.exists() and filepath.is_file()
            try:
                file_still_exists = filepath.exists() and filepath.is_file()
            except (PermissionError, OSError):  # pragma: no cover
                # If we can't check, let it fail in the _consume function
                file_still_exists = True
                continue

            if waited_long_enough and file_still_exists:
                _consume(filepath)
@@ -92,6 +92,9 @@ class Command(MultiProcessMixin, ProgressBarMixin, BaseCommand):
            # doc to doc is obviously not useful
            if first_doc.pk == second_doc.pk:
                continue
            # Skip empty documents (e.g. password-protected)
            if first_doc.content.strip() == "" or second_doc.content.strip() == "":
                continue
            # Skip pairs which have already been matched together
            # doc 1 to doc 2 is the same as doc 2 to doc 1
            doc_1_to_doc_2 = (first_doc.pk, second_doc.pk)
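The comment above relies on the pair (doc 1, doc 2) being equivalent to (doc 2, doc 1). One way to express that de-duplication, shown as a standalone sketch rather than the command's actual loop (documents stands in for the queryset the command iterates):

checked: set[tuple[int, int]] = set()

for first_doc in documents:
    for second_doc in documents:
        if first_doc.pk == second_doc.pk:
            continue
        pair = tuple(sorted((first_doc.pk, second_doc.pk)))
        if pair in checked:
            continue  # doc 1 vs doc 2 was already compared as doc 2 vs doc 1
        checked.add(pair)
        # ... compare first_doc.content against second_doc.content ...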
@@ -314,11 +314,19 @@ def consumable_document_matches_workflow(
    trigger_matched = False

    # Document path vs trigger path

    # Use the original_path if set, else use the original_file
    match_against = (
        document.original_path
        if document.original_path is not None
        else document.original_file
    )

    if (
        trigger.filter_path is not None
        and len(trigger.filter_path) > 0
        and not fnmatch(
            document.original_file,
            match_against,
            trigger.filter_path,
        )
    ):
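With the change above, path filters are evaluated against the path the file had before any pre-processing moved it. A small illustration using the stdlib fnmatch and made-up paths (both paths are hypothetical):

from fnmatch import fnmatch

original_path = "/consume/invoices/2024/invoice.pdf"  # hypothetical original location
tmp_file = "/tmp/paperless-split/invoice_page_1.pdf"  # hypothetical post-split location

filter_path = "*/invoices/*"
fnmatch(tmp_file, filter_path)       # False: the temporary location no longer matches
fnmatch(original_path, filter_path)  # True: matching against original_path still succeeds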
@@ -17,7 +17,7 @@ def move_sender_strings_to_sender_model(apps, schema_editor):
|
||||
if document.sender:
|
||||
(
|
||||
DOCUMENT_SENDER_MAP[document.pk],
|
||||
created,
|
||||
_,
|
||||
) = sender_model.objects.get_or_create(
|
||||
name=document.sender,
|
||||
defaults={"slug": slugify(document.sender)},
|
||||
|
@@ -0,0 +1,159 @@
|
||||
# Generated by Django 5.2.6 on 2025-09-12 18:42
|
||||
|
||||
import django.core.validators
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations
|
||||
from django.db import models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("documents", "1070_customfieldinstance_value_long_text_and_more"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name="tag",
|
||||
name="tn_ancestors_count",
|
||||
field=models.PositiveIntegerField(
|
||||
default=0,
|
||||
editable=False,
|
||||
verbose_name="Ancestors count",
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="tag",
|
||||
name="tn_ancestors_pks",
|
||||
field=models.TextField(
|
||||
blank=True,
|
||||
default="",
|
||||
editable=False,
|
||||
verbose_name="Ancestors pks",
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="tag",
|
||||
name="tn_children_count",
|
||||
field=models.PositiveIntegerField(
|
||||
default=0,
|
||||
editable=False,
|
||||
verbose_name="Children count",
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="tag",
|
||||
name="tn_children_pks",
|
||||
field=models.TextField(
|
||||
blank=True,
|
||||
default="",
|
||||
editable=False,
|
||||
verbose_name="Children pks",
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="tag",
|
||||
name="tn_depth",
|
||||
field=models.PositiveIntegerField(
|
||||
default=0,
|
||||
editable=False,
|
||||
validators=[
|
||||
django.core.validators.MinValueValidator(0),
|
||||
django.core.validators.MaxValueValidator(10),
|
||||
],
|
||||
verbose_name="Depth",
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="tag",
|
||||
name="tn_descendants_count",
|
||||
field=models.PositiveIntegerField(
|
||||
default=0,
|
||||
editable=False,
|
||||
verbose_name="Descendants count",
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="tag",
|
||||
name="tn_descendants_pks",
|
||||
field=models.TextField(
|
||||
blank=True,
|
||||
default="",
|
||||
editable=False,
|
||||
verbose_name="Descendants pks",
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="tag",
|
||||
name="tn_index",
|
||||
field=models.PositiveIntegerField(
|
||||
default=0,
|
||||
editable=False,
|
||||
verbose_name="Index",
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="tag",
|
||||
name="tn_level",
|
||||
field=models.PositiveIntegerField(
|
||||
default=1,
|
||||
editable=False,
|
||||
validators=[
|
||||
django.core.validators.MinValueValidator(1),
|
||||
django.core.validators.MaxValueValidator(10),
|
||||
],
|
||||
verbose_name="Level",
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="tag",
|
||||
name="tn_order",
|
||||
field=models.PositiveIntegerField(
|
||||
default=0,
|
||||
editable=False,
|
||||
verbose_name="Order",
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="tag",
|
||||
name="tn_parent",
|
||||
field=models.ForeignKey(
|
||||
blank=True,
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
related_name="tn_children",
|
||||
to="documents.tag",
|
||||
verbose_name="Parent",
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="tag",
|
||||
name="tn_priority",
|
||||
field=models.PositiveIntegerField(
|
||||
default=0,
|
||||
validators=[
|
||||
django.core.validators.MinValueValidator(0),
|
||||
django.core.validators.MaxValueValidator(9999999999),
|
||||
],
|
||||
verbose_name="Priority",
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="tag",
|
||||
name="tn_siblings_count",
|
||||
field=models.PositiveIntegerField(
|
||||
default=0,
|
||||
editable=False,
|
||||
verbose_name="Siblings count",
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="tag",
|
||||
name="tn_siblings_pks",
|
||||
field=models.TextField(
|
||||
blank=True,
|
||||
default="",
|
||||
editable=False,
|
||||
verbose_name="Siblings pks",
|
||||
),
|
||||
),
|
||||
]
|
@@ -7,12 +7,14 @@ from celery import states
|
||||
from django.conf import settings
|
||||
from django.contrib.auth.models import Group
|
||||
from django.contrib.auth.models import User
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.core.validators import MaxValueValidator
|
||||
from django.core.validators import MinValueValidator
|
||||
from django.db import models
|
||||
from django.utils import timezone
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from multiselectfield import MultiSelectField
|
||||
from treenode.models import TreeNodeModel
|
||||
|
||||
if settings.AUDIT_LOG_ENABLED:
|
||||
from auditlog.registry import auditlog
|
||||
@@ -96,8 +98,10 @@ class Correspondent(MatchingModel):
|
||||
verbose_name_plural = _("correspondents")
|
||||
|
||||
|
||||
class Tag(MatchingModel):
|
||||
class Tag(MatchingModel, TreeNodeModel):
|
||||
color = models.CharField(_("color"), max_length=7, default="#a6cee3")
|
||||
# Maximum allowed nesting depth for tags (root = 1, max depth = 5)
|
||||
MAX_NESTING_DEPTH: Final[int] = 5
|
||||
|
||||
is_inbox_tag = models.BooleanField(
|
||||
_("is inbox tag"),
|
||||
@@ -108,10 +112,30 @@ class Tag(MatchingModel):
|
||||
),
|
||||
)
|
||||
|
||||
class Meta(MatchingModel.Meta):
|
||||
class Meta(MatchingModel.Meta, TreeNodeModel.Meta):
|
||||
verbose_name = _("tag")
|
||||
verbose_name_plural = _("tags")
|
||||
|
||||
def clean(self):
|
||||
# Prevent self-parenting and assigning a descendant as parent
|
||||
parent = self.get_parent()
|
||||
if parent == self:
|
||||
raise ValidationError({"parent": _("Cannot set itself as parent.")})
|
||||
if parent and self.pk is not None and self.is_ancestor_of(parent):
|
||||
raise ValidationError({"parent": _("Cannot set parent to a descendant.")})
|
||||
|
||||
# Enforce maximum nesting depth
|
||||
new_parent_depth = 0
|
||||
if parent:
|
||||
new_parent_depth = parent.get_ancestors_count() + 1
|
||||
|
||||
height = 0 if self.pk is None else self.get_depth()
|
||||
deepest_new_depth = (new_parent_depth + 1) + height
|
||||
if deepest_new_depth > self.MAX_NESTING_DEPTH:
|
||||
raise ValidationError(_("Maximum nesting depth exceeded."))
|
||||
|
||||
return super().clean()
|
||||
|
||||
|
||||
class DocumentType(MatchingModel):
|
||||
class Meta(MatchingModel.Meta):
|
||||
@@ -398,6 +422,15 @@ class Document(SoftDeleteModel, ModelWithOwner):
|
||||
def created_date(self):
|
||||
return self.created
|
||||
|
||||
def add_nested_tags(self, tags):
|
||||
tag_ids = set()
|
||||
for tag in tags:
|
||||
tag_ids.add(tag.id)
|
||||
tag_ids.update(tag.get_ancestors_pks())
|
||||
|
||||
tags_to_add = self.tags.model.objects.filter(id__in=tag_ids)
|
||||
self.tags.add(*tags_to_add)
|
||||
|
||||
|
||||
class SavedView(ModelWithOwner):
|
||||
class DisplayMode(models.TextChoices):
|
||||
|
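A worked example of the depth arithmetic in Tag.clean(), using the same shape as the hierarchy tests further below (levels count from the root, root = 1, MAX_NESTING_DEPTH = 5):

# Chain A -> B -> C -> D (levels 1..4) plus a separate pair X -> Y.
#
# Moving X under D:
#   new_parent_depth  = D.get_ancestors_count() + 1 = 3 + 1 = 4
#   height of X       = X.get_depth()               = 1   (Y sits below X)
#   deepest_new_depth = (4 + 1) + 1 = 6 > 5  -> ValidationError("Maximum nesting depth exceeded.")
#
# Moving X under C instead:
#   deepest_new_depth = (3 + 1) + 1 = 5 <= 5 -> allowed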
@@ -6,6 +6,7 @@ import re
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Literal
|
||||
|
||||
import magic
|
||||
from celery import states
|
||||
@@ -13,6 +14,7 @@ from django.conf import settings
|
||||
from django.contrib.auth.models import Group
|
||||
from django.contrib.auth.models import User
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.core.validators import DecimalValidator
|
||||
from django.core.validators import MaxLengthValidator
|
||||
from django.core.validators import RegexValidator
|
||||
@@ -251,6 +253,35 @@ class OwnedObjectSerializer(
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
def _get_perms(self, obj, codename: str, target: Literal["users", "groups"]):
|
||||
"""
|
||||
Get the given permissions from context or from django-guardian.
|
||||
|
||||
:param codename: The permission codename, e.g. 'view' or 'change'
|
||||
:param target: 'users' or 'groups'
|
||||
"""
|
||||
key = f"{target}_{codename}_perms"
|
||||
cached = self.context.get(key, {}).get(obj.pk)
|
||||
if cached is not None:
|
||||
return list(cached)
|
||||
|
||||
# Permission not found in the context, get it from guardian
|
||||
if target == "users":
|
||||
return list(
|
||||
get_users_with_perms(
|
||||
obj,
|
||||
only_with_perms_in=[f"{codename}_{obj.__class__.__name__.lower()}"],
|
||||
with_group_users=False,
|
||||
).values_list("id", flat=True),
|
||||
)
|
||||
else: # groups
|
||||
return list(
|
||||
get_groups_with_only_permission(
|
||||
obj,
|
||||
codename=f"{codename}_{obj.__class__.__name__.lower()}",
|
||||
).values_list("id", flat=True),
|
||||
)
|
||||
|
||||
@extend_schema_field(
|
||||
field={
|
||||
"type": "object",
|
||||
@@ -285,31 +316,14 @@ class OwnedObjectSerializer(
|
||||
},
|
||||
)
|
||||
def get_permissions(self, obj) -> dict:
|
||||
view_codename = f"view_{obj.__class__.__name__.lower()}"
|
||||
change_codename = f"change_{obj.__class__.__name__.lower()}"
|
||||
|
||||
return {
|
||||
"view": {
|
||||
"users": get_users_with_perms(
|
||||
obj,
|
||||
only_with_perms_in=[view_codename],
|
||||
with_group_users=False,
|
||||
).values_list("id", flat=True),
|
||||
"groups": get_groups_with_only_permission(
|
||||
obj,
|
||||
codename=view_codename,
|
||||
).values_list("id", flat=True),
|
||||
"users": self._get_perms(obj, "view", "users"),
|
||||
"groups": self._get_perms(obj, "view", "groups"),
|
||||
},
|
||||
"change": {
|
||||
"users": get_users_with_perms(
|
||||
obj,
|
||||
only_with_perms_in=[change_codename],
|
||||
with_group_users=False,
|
||||
).values_list("id", flat=True),
|
||||
"groups": get_groups_with_only_permission(
|
||||
obj,
|
||||
codename=change_codename,
|
||||
).values_list("id", flat=True),
|
||||
"users": self._get_perms(obj, "change", "users"),
|
||||
"groups": self._get_perms(obj, "change", "groups"),
|
||||
},
|
||||
}
|
||||
|
||||
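The rewritten get_permissions() above only hits django-guardian when the lookup is not already in the serializer context; the context keys follow the f"{target}_{codename}_perms" format from _get_perms(). A sketch of the context shape a list view could provide (the view-side code is not part of this hunk, so the variable names here are assumptions):

serializer = TagSerializer(
    tags_page,
    many=True,
    context={
        "users_view_perms": users_view_map,      # {tag_pk: [user ids]}
        "groups_view_perms": groups_view_map,    # {tag_pk: [group ids]}
        "users_change_perms": users_change_map,
        "groups_change_perms": groups_change_map,
    },
)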
@@ -540,6 +554,32 @@ class TagSerializer(MatchingModelSerializer, OwnedObjectSerializer):
|
||||
|
||||
text_color = serializers.SerializerMethodField()
|
||||
|
||||
# map to treenode's tn_parent
|
||||
parent = serializers.PrimaryKeyRelatedField(
|
||||
queryset=Tag.objects.all(),
|
||||
allow_null=True,
|
||||
required=False,
|
||||
source="tn_parent",
|
||||
)
|
||||
|
||||
@extend_schema_field(
|
||||
field=serializers.ListSerializer(
|
||||
child=serializers.PrimaryKeyRelatedField(
|
||||
queryset=Tag.objects.all(),
|
||||
),
|
||||
),
|
||||
)
|
||||
def get_children(self, obj):
|
||||
serializer = TagSerializer(
|
||||
obj.get_children(),
|
||||
many=True,
|
||||
context=self.context,
|
||||
)
|
||||
return serializer.data
|
||||
|
||||
# children as nested Tag objects
|
||||
children = serializers.SerializerMethodField()
|
||||
|
||||
class Meta:
|
||||
model = Tag
|
||||
fields = (
|
||||
@@ -557,6 +597,8 @@ class TagSerializer(MatchingModelSerializer, OwnedObjectSerializer):
|
||||
"permissions",
|
||||
"user_can_change",
|
||||
"set_permissions",
|
||||
"parent",
|
||||
"children",
|
||||
)
|
||||
|
||||
def validate_color(self, color):
|
||||
@@ -565,6 +607,36 @@ class TagSerializer(MatchingModelSerializer, OwnedObjectSerializer):
|
||||
raise serializers.ValidationError(_("Invalid color."))
|
||||
return color
|
||||
|
||||
def validate(self, attrs):
|
||||
# Validate when changing parent
|
||||
parent = attrs.get(
|
||||
"tn_parent",
|
||||
self.instance.get_parent() if self.instance else None,
|
||||
)
|
||||
|
||||
if self.instance:
|
||||
# Temporarily set parent on the instance if updating and use model clean()
|
||||
original_parent = self.instance.get_parent()
|
||||
try:
|
||||
# Temporarily set tn_parent in-memory to validate clean()
|
||||
self.instance.tn_parent = parent
|
||||
self.instance.clean()
|
||||
except ValidationError as e:
|
||||
logger.debug("Tag parent validation failed: %s", e)
|
||||
raise e
|
||||
finally:
|
||||
self.instance.tn_parent = original_parent
|
||||
else:
|
||||
# For new instances, create a transient Tag and validate
|
||||
temp = Tag(tn_parent=parent)
|
||||
try:
|
||||
temp.clean()
|
||||
except ValidationError as e:
|
||||
logger.debug("Tag parent validation failed: %s", e)
|
||||
raise serializers.ValidationError({"parent": _("Invalid parent tag.")})
|
||||
|
||||
return super().validate(attrs)
|
||||
|
||||
|
||||
class CorrespondentField(serializers.PrimaryKeyRelatedField):
|
||||
def get_queryset(self):
|
||||
@@ -1028,6 +1100,28 @@ class DocumentSerializer(
|
||||
custom_field_instance.field,
|
||||
doc_id,
|
||||
)
|
||||
if "tags" in validated_data:
|
||||
# Respect tag hierarchy on updates:
|
||||
# - Adding a child adds its ancestors
|
||||
# - Removing a parent removes all its descendants
|
||||
prev_tags = set(instance.tags.all())
|
||||
requested_tags = set(validated_data["tags"])
|
||||
|
||||
# Tags being removed in this update and all descendants
|
||||
removed_tags = prev_tags - requested_tags
|
||||
blocked_tags = set(removed_tags)
|
||||
for t in removed_tags:
|
||||
blocked_tags.update(t.get_descendants())
|
||||
|
||||
# Add all parent tags
|
||||
final_tags = set(requested_tags)
|
||||
for t in requested_tags:
|
||||
final_tags.update(t.get_ancestors())
|
||||
|
||||
# Drop removed parents and their descendants
|
||||
final_tags.difference_update(blocked_tags)
|
||||
|
||||
validated_data["tags"] = list(final_tags)
|
||||
if validated_data.get("remove_inbox_tags"):
|
||||
tag_ids_being_added = (
|
||||
[
|
||||
@@ -1668,9 +1762,8 @@ class PostDocumentSerializer(serializers.Serializer):
|
||||
max_value=Document.ARCHIVE_SERIAL_NUMBER_MAX,
|
||||
)
|
||||
|
||||
custom_fields = serializers.PrimaryKeyRelatedField(
|
||||
many=True,
|
||||
queryset=CustomField.objects.all(),
|
||||
# Accept either a list of custom field ids or a dict mapping id -> value
|
||||
custom_fields = serializers.JSONField(
|
||||
label="Custom fields",
|
||||
write_only=True,
|
||||
required=False,
|
||||
@@ -1727,11 +1820,60 @@ class PostDocumentSerializer(serializers.Serializer):
|
||||
return None
|
||||
|
||||
def validate_custom_fields(self, custom_fields):
|
||||
if custom_fields:
|
||||
return [custom_field.id for custom_field in custom_fields]
|
||||
else:
|
||||
if not custom_fields:
|
||||
return None
|
||||
|
||||
# Normalize single values to a list
|
||||
if isinstance(custom_fields, int):
|
||||
custom_fields = [custom_fields]
|
||||
if isinstance(custom_fields, dict):
|
||||
custom_field_serializer = CustomFieldInstanceSerializer()
|
||||
normalized = {}
|
||||
for field_id, value in custom_fields.items():
|
||||
try:
|
||||
field_id_int = int(field_id)
|
||||
except (TypeError, ValueError):
|
||||
raise serializers.ValidationError(
|
||||
_("Custom field id must be an integer: %(id)s")
|
||||
% {"id": field_id},
|
||||
)
|
||||
try:
|
||||
field = CustomField.objects.get(id=field_id_int)
|
||||
except CustomField.DoesNotExist:
|
||||
raise serializers.ValidationError(
|
||||
_("Custom field with id %(id)s does not exist")
|
||||
% {"id": field_id_int},
|
||||
)
|
||||
custom_field_serializer.validate(
|
||||
{
|
||||
"field": field,
|
||||
"value": value,
|
||||
},
|
||||
)
|
||||
normalized[field_id_int] = value
|
||||
return normalized
|
||||
elif isinstance(custom_fields, list):
|
||||
try:
|
||||
ids = [int(i) for i in custom_fields]
|
||||
except (TypeError, ValueError):
|
||||
raise serializers.ValidationError(
|
||||
_(
|
||||
"Custom fields must be a list of integers or an object mapping ids to values.",
|
||||
),
|
||||
)
|
||||
if CustomField.objects.filter(id__in=ids).count() != len(set(ids)):
|
||||
raise serializers.ValidationError(
|
||||
_("Some custom fields don't exist or were specified twice."),
|
||||
)
|
||||
return ids
|
||||
raise serializers.ValidationError(
|
||||
_(
|
||||
"Custom fields must be a list of integers or an object mapping ids to values.",
|
||||
),
|
||||
)
|
||||
|
||||
# custom_fields_w_values handled via validate_custom_fields
|
||||
|
||||
def validate_created(self, created):
|
||||
# support datetime format for created for backwards compatibility
|
||||
if isinstance(created, datetime):
|
||||
|
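The upload serializer now accepts custom_fields either as a plain list of field ids or as a mapping of id -> value; the API test added further below exercises exactly this. A condensed sketch of the mapping form (client, file_handle and the field objects are stand-ins):

import json

response = client.post(
    "/api/documents/post_document/",
    {
        "document": file_handle,
        "custom_fields": json.dumps(
            {
                str(cf_string.id): "a string",
                str(cf_int.id): 123,
            },
        ),
    },
)
# values are validated per field type and end up in overrides.custom_fields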
@@ -73,7 +73,7 @@ def add_inbox_tags(sender, document: Document, logging_group=None, **kwargs):
|
||||
else:
|
||||
tags = Tag.objects.all()
|
||||
inbox_tags = tags.filter(is_inbox_tag=True)
|
||||
document.tags.add(*inbox_tags)
|
||||
document.add_nested_tags(inbox_tags)
|
||||
|
||||
|
||||
def _suggestion_printer(
|
||||
@@ -262,7 +262,7 @@ def set_tags(
|
||||
extra={"group": logging_group},
|
||||
)
|
||||
|
||||
document.tags.add(*relevant_tags)
|
||||
document.add_nested_tags(relevant_tags)
|
||||
|
||||
|
||||
def set_storage_path(
|
||||
@@ -778,14 +778,17 @@ def run_workflows(
|
||||
|
||||
def assignment_action():
|
||||
if action.assign_tags.exists():
|
||||
tag_ids_to_add: set[int] = set()
|
||||
for tag in action.assign_tags.all():
|
||||
tag_ids_to_add.add(tag.pk)
|
||||
tag_ids_to_add.update(int(pk) for pk in tag.get_ancestors_pks())
|
||||
|
||||
if not use_overrides:
|
||||
doc_tag_ids.extend(action.assign_tags.values_list("pk", flat=True))
|
||||
doc_tag_ids[:] = list(set(doc_tag_ids) | tag_ids_to_add)
|
||||
else:
|
||||
if overrides.tag_ids is None:
|
||||
overrides.tag_ids = []
|
||||
overrides.tag_ids.extend(
|
||||
action.assign_tags.values_list("pk", flat=True),
|
||||
)
|
||||
overrides.tag_ids = list(set(overrides.tag_ids) | tag_ids_to_add)
|
||||
|
||||
if action.assign_correspondent:
|
||||
if not use_overrides:
|
||||
@@ -928,14 +931,17 @@ def run_workflows(
|
||||
else:
|
||||
overrides.tag_ids = None
|
||||
else:
|
||||
tag_ids_to_remove: set[int] = set()
|
||||
for tag in action.remove_tags.all():
|
||||
tag_ids_to_remove.add(tag.pk)
|
||||
tag_ids_to_remove.update(int(pk) for pk in tag.get_descendants_pks())
|
||||
|
||||
if not use_overrides:
|
||||
for tag in action.remove_tags.filter(
|
||||
pk__in=document.tags.values_list("pk", flat=True),
|
||||
):
|
||||
doc_tag_ids.remove(tag.pk)
|
||||
doc_tag_ids[:] = [t for t in doc_tag_ids if t not in tag_ids_to_remove]
|
||||
elif overrides.tag_ids:
|
||||
for tag in action.remove_tags.filter(pk__in=overrides.tag_ids):
|
||||
overrides.tag_ids.remove(tag.pk)
|
||||
overrides.tag_ids = [
|
||||
t for t in overrides.tag_ids if t not in tag_ids_to_remove
|
||||
]
|
||||
|
||||
if not use_overrides and (
|
||||
action.remove_all_correspondents
|
||||
|
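Net effect of the removal branch above on the in-memory id list, as a tiny sketch (the three pk variables are placeholders):

doc_tag_ids = [parent_pk, child_pk, other_pk]
tag_ids_to_remove = {parent_pk, child_pk}  # the removed tag plus all of its descendants

doc_tag_ids[:] = [t for t in doc_tag_ids if t not in tag_ids_to_remove]
# -> [other_pk]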
@@ -532,6 +532,54 @@ def check_scheduled_workflows():
    )


def update_document_parent_tags(tag: Tag, new_parent: Tag) -> None:
    """
    When a tag's parent changes, ensure all documents containing the tag also have
    the parent tag (and its ancestors) applied.
    """
    doc_tag_relationship = Document.tags.through

    doc_ids: list[int] = list(
        Document.objects.filter(tags=tag).values_list("pk", flat=True),
    )

    if not doc_ids:
        return

    parent_ids = [new_parent.id, *new_parent.get_ancestors_pks()]

    # de-duplicate while preserving order
    parent_ids = list(dict.fromkeys(parent_ids))

    existing_pairs = set(
        doc_tag_relationship.objects.filter(
            document_id__in=doc_ids,
            tag_id__in=parent_ids,
        ).values_list("document_id", "tag_id"),
    )

    to_create: list = []
    affected: set[int] = set()

    for doc_id in doc_ids:
        for parent_id in parent_ids:
            if (doc_id, parent_id) in existing_pairs:
                continue

            to_create.append(
                doc_tag_relationship(document_id=doc_id, tag_id=parent_id),
            )
            affected.add(doc_id)

    if to_create:
        doc_tag_relationship.objects.bulk_create(
            to_create,
            ignore_conflicts=True,
        )

    if affected:
        bulk_update_documents.delay(document_ids=list(affected))


@shared_task
def llmindex_index(
    *,
@@ -1,4 +1,5 @@
|
||||
import types
|
||||
from unittest.mock import patch
|
||||
|
||||
from django.contrib.admin.sites import AdminSite
|
||||
from django.contrib.auth.models import User
|
||||
@@ -7,7 +8,9 @@ from django.utils import timezone
|
||||
|
||||
from documents import index
|
||||
from documents.admin import DocumentAdmin
|
||||
from documents.admin import TagAdmin
|
||||
from documents.models import Document
|
||||
from documents.models import Tag
|
||||
from documents.tests.utils import DirectoriesMixin
|
||||
from paperless.admin import PaperlessUserAdmin
|
||||
|
||||
@@ -70,6 +73,24 @@ class TestDocumentAdmin(DirectoriesMixin, TestCase):
|
||||
self.assertEqual(self.doc_admin.created_(doc), "2020-04-12")
|
||||
|
||||
|
||||
class TestTagAdmin(DirectoriesMixin, TestCase):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
self.tag_admin = TagAdmin(model=Tag, admin_site=AdminSite())
|
||||
|
||||
@patch("documents.tasks.bulk_update_documents")
|
||||
def test_parent_tags_get_added(self, mock_bulk_update):
|
||||
document = Document.objects.create(title="test")
|
||||
parent = Tag.objects.create(name="parent")
|
||||
child = Tag.objects.create(name="child")
|
||||
document.tags.add(child)
|
||||
|
||||
child.tn_parent = parent
|
||||
self.tag_admin.save_model(None, child, None, change=True)
|
||||
document.refresh_from_db()
|
||||
self.assertIn(parent, document.tags.all())
|
||||
|
||||
|
||||
class TestPaperlessAdmin(DirectoriesMixin, TestCase):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
|
@@ -839,7 +839,7 @@ class TestBulkEditAPI(DirectoriesMixin, APITestCase):
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
m.assert_called()
|
||||
args, kwargs = m.call_args
|
||||
_, kwargs = m.call_args
|
||||
self.assertEqual(kwargs["merge"], False)
|
||||
|
||||
response = self.client.post(
|
||||
@@ -857,7 +857,7 @@ class TestBulkEditAPI(DirectoriesMixin, APITestCase):
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
m.assert_called()
|
||||
args, kwargs = m.call_args
|
||||
_, kwargs = m.call_args
|
||||
self.assertEqual(kwargs["merge"], True)
|
||||
|
||||
@mock.patch("documents.serialisers.bulk_edit.set_storage_path")
|
||||
|
@@ -1,4 +1,5 @@
|
||||
import datetime
|
||||
import json
|
||||
import shutil
|
||||
import tempfile
|
||||
import uuid
|
||||
@@ -1528,7 +1529,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
|
||||
|
||||
input_doc, overrides = self.get_last_consume_delay_call_args()
|
||||
|
||||
new_overrides, msg = run_workflows(
|
||||
new_overrides, _ = run_workflows(
|
||||
trigger_type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION,
|
||||
document=input_doc,
|
||||
logging_group=None,
|
||||
@@ -1537,6 +1538,86 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
|
||||
overrides.update(new_overrides)
|
||||
self.assertEqual(overrides.custom_fields, {cf.id: None, cf2.id: 123})
|
||||
|
||||
def test_upload_with_custom_field_values(self):
|
||||
"""
|
||||
GIVEN: A document with a source file
|
||||
WHEN: Upload the document with custom fields and values
|
||||
THEN: Metadata is set correctly
|
||||
"""
|
||||
self.consume_file_mock.return_value = celery.result.AsyncResult(
|
||||
id=str(uuid.uuid4()),
|
||||
)
|
||||
|
||||
cf_string = CustomField.objects.create(
|
||||
name="stringfield",
|
||||
data_type=CustomField.FieldDataType.STRING,
|
||||
)
|
||||
cf_int = CustomField.objects.create(
|
||||
name="intfield",
|
||||
data_type=CustomField.FieldDataType.INT,
|
||||
)
|
||||
|
||||
with (Path(__file__).parent / "samples" / "simple.pdf").open("rb") as f:
|
||||
response = self.client.post(
|
||||
"/api/documents/post_document/",
|
||||
{
|
||||
"document": f,
|
||||
"custom_fields": json.dumps(
|
||||
{
|
||||
str(cf_string.id): "a string",
|
||||
str(cf_int.id): 123,
|
||||
},
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
self.consume_file_mock.assert_called_once()
|
||||
|
||||
input_doc, overrides = self.get_last_consume_delay_call_args()
|
||||
|
||||
self.assertEqual(input_doc.original_file.name, "simple.pdf")
|
||||
self.assertEqual(overrides.filename, "simple.pdf")
|
||||
self.assertEqual(
|
||||
overrides.custom_fields,
|
||||
{cf_string.id: "a string", cf_int.id: 123},
|
||||
)
|
||||
|
||||
def test_upload_with_custom_fields_errors(self):
|
||||
"""
|
||||
GIVEN: A document with a source file
|
||||
WHEN: Upload the document with invalid custom fields payloads
|
||||
THEN: The upload is rejected
|
||||
"""
|
||||
self.consume_file_mock.return_value = celery.result.AsyncResult(
|
||||
id=str(uuid.uuid4()),
|
||||
)
|
||||
|
||||
error_payloads = [
|
||||
# Non-integer key in mapping
|
||||
{"custom_fields": json.dumps({"abc": "a string"})},
|
||||
# List with non-integer entry
|
||||
{"custom_fields": json.dumps(["abc"])},
|
||||
# Nonexistent id in mapping
|
||||
{"custom_fields": json.dumps({99999999: "a string"})},
|
||||
# Nonexistent id in list
|
||||
{"custom_fields": json.dumps([99999999])},
|
||||
# Invalid type (JSON string, not list/dict/int)
|
||||
{"custom_fields": json.dumps("not-a-supported-structure")},
|
||||
]
|
||||
|
||||
for payload in error_payloads:
|
||||
with (Path(__file__).parent / "samples" / "simple.pdf").open("rb") as f:
|
||||
data = {"document": f, **payload}
|
||||
response = self.client.post(
|
||||
"/api/documents/post_document/",
|
||||
data,
|
||||
)
|
||||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
self.consume_file_mock.assert_not_called()
|
||||
|
||||
def test_upload_with_webui_source(self):
|
||||
"""
|
||||
GIVEN: A document with a source file
|
||||
@@ -1557,7 +1638,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
|
||||
|
||||
self.consume_file_mock.assert_called_once()
|
||||
|
||||
input_doc, overrides = self.get_last_consume_delay_call_args()
|
||||
input_doc, _ = self.get_last_consume_delay_call_args()
|
||||
|
||||
self.assertEqual(input_doc.source, WorkflowTrigger.DocumentSourceChoices.WEB_UI)
|
||||
|
||||
|
@@ -614,14 +614,16 @@ class TestBarcodeNewConsume(
|
||||
self.assertIsNotFile(temp_copy)
|
||||
|
||||
# Check the split files exist
|
||||
# Check the original_path is set
|
||||
# Check the source is unchanged
|
||||
# Check the overrides are unchanged
|
||||
for (
|
||||
new_input_doc,
|
||||
new_doc_overrides,
|
||||
) in self.get_all_consume_delay_call_args():
|
||||
self.assertEqual(new_input_doc.source, DocumentSource.ConsumeFolder)
|
||||
self.assertIsFile(new_input_doc.original_file)
|
||||
self.assertEqual(new_input_doc.original_path, temp_copy)
|
||||
self.assertEqual(new_input_doc.source, DocumentSource.ConsumeFolder)
|
||||
self.assertEqual(overrides, new_doc_overrides)
|
||||
|
||||
|
||||
|
@@ -74,7 +74,7 @@ class TestBulkEdit(DirectoriesMixin, TestCase):
|
||||
)
|
||||
self.assertEqual(Document.objects.filter(correspondent=self.c2).count(), 3)
|
||||
self.async_task.assert_called_once()
|
||||
args, kwargs = self.async_task.call_args
|
||||
_, kwargs = self.async_task.call_args
|
||||
self.assertCountEqual(kwargs["document_ids"], [self.doc1.id, self.doc2.id])
|
||||
|
||||
def test_unset_correspondent(self):
|
||||
@@ -82,7 +82,7 @@ class TestBulkEdit(DirectoriesMixin, TestCase):
|
||||
bulk_edit.set_correspondent([self.doc1.id, self.doc2.id, self.doc3.id], None)
|
||||
self.assertEqual(Document.objects.filter(correspondent=self.c2).count(), 0)
|
||||
self.async_task.assert_called_once()
|
||||
args, kwargs = self.async_task.call_args
|
||||
_, kwargs = self.async_task.call_args
|
||||
self.assertCountEqual(kwargs["document_ids"], [self.doc2.id, self.doc3.id])
|
||||
|
||||
def test_set_document_type(self):
|
||||
@@ -93,7 +93,7 @@ class TestBulkEdit(DirectoriesMixin, TestCase):
|
||||
)
|
||||
self.assertEqual(Document.objects.filter(document_type=self.dt2).count(), 3)
|
||||
self.async_task.assert_called_once()
|
||||
args, kwargs = self.async_task.call_args
|
||||
_, kwargs = self.async_task.call_args
|
||||
self.assertCountEqual(kwargs["document_ids"], [self.doc1.id, self.doc2.id])
|
||||
|
||||
def test_unset_document_type(self):
|
||||
@@ -101,7 +101,7 @@ class TestBulkEdit(DirectoriesMixin, TestCase):
|
||||
bulk_edit.set_document_type([self.doc1.id, self.doc2.id, self.doc3.id], None)
|
||||
self.assertEqual(Document.objects.filter(document_type=self.dt2).count(), 0)
|
||||
self.async_task.assert_called_once()
|
||||
args, kwargs = self.async_task.call_args
|
||||
_, kwargs = self.async_task.call_args
|
||||
self.assertCountEqual(kwargs["document_ids"], [self.doc2.id, self.doc3.id])
|
||||
|
||||
def test_set_document_storage_path(self):
|
||||
@@ -123,7 +123,7 @@ class TestBulkEdit(DirectoriesMixin, TestCase):
|
||||
self.assertEqual(Document.objects.filter(storage_path=None).count(), 4)
|
||||
|
||||
self.async_task.assert_called_once()
|
||||
args, kwargs = self.async_task.call_args
|
||||
_, kwargs = self.async_task.call_args
|
||||
|
||||
self.assertCountEqual(kwargs["document_ids"], [self.doc1.id])
|
||||
|
||||
@@ -154,7 +154,7 @@ class TestBulkEdit(DirectoriesMixin, TestCase):
|
||||
self.assertEqual(Document.objects.filter(storage_path=None).count(), 5)
|
||||
|
||||
self.async_task.assert_called()
|
||||
args, kwargs = self.async_task.call_args
|
||||
_, kwargs = self.async_task.call_args
|
||||
|
||||
self.assertCountEqual(kwargs["document_ids"], [self.doc1.id])
|
||||
|
||||
@@ -166,7 +166,7 @@ class TestBulkEdit(DirectoriesMixin, TestCase):
|
||||
)
|
||||
self.assertEqual(Document.objects.filter(tags__id=self.t1.id).count(), 4)
|
||||
self.async_task.assert_called_once()
|
||||
args, kwargs = self.async_task.call_args
|
||||
_, kwargs = self.async_task.call_args
|
||||
self.assertCountEqual(kwargs["document_ids"], [self.doc1.id, self.doc3.id])
|
||||
|
||||
def test_remove_tag(self):
|
||||
@@ -174,7 +174,7 @@ class TestBulkEdit(DirectoriesMixin, TestCase):
|
||||
bulk_edit.remove_tag([self.doc1.id, self.doc3.id, self.doc4.id], self.t1.id)
|
||||
self.assertEqual(Document.objects.filter(tags__id=self.t1.id).count(), 1)
|
||||
self.async_task.assert_called_once()
|
||||
args, kwargs = self.async_task.call_args
|
||||
_, kwargs = self.async_task.call_args
|
||||
self.assertCountEqual(kwargs["document_ids"], [self.doc4.id])
|
||||
|
||||
def test_modify_tags(self):
|
||||
@@ -191,7 +191,7 @@ class TestBulkEdit(DirectoriesMixin, TestCase):
|
||||
self.assertCountEqual(list(self.doc3.tags.all()), [self.t2, tag_unrelated])
|
||||
|
||||
self.async_task.assert_called_once()
|
||||
args, kwargs = self.async_task.call_args
|
||||
_, kwargs = self.async_task.call_args
|
||||
# TODO: doc3 should not be affected, but the query for that is rather complicated
|
||||
self.assertCountEqual(kwargs["document_ids"], [self.doc2.id, self.doc3.id])
|
||||
|
||||
@@ -248,7 +248,7 @@ class TestBulkEdit(DirectoriesMixin, TestCase):
|
||||
)
|
||||
|
||||
self.async_task.assert_called_once()
|
||||
args, kwargs = self.async_task.call_args
|
||||
_, kwargs = self.async_task.call_args
|
||||
self.assertCountEqual(kwargs["document_ids"], [self.doc1.id, self.doc2.id])
|
||||
|
||||
def test_modify_custom_fields_with_values(self):
|
||||
@@ -325,7 +325,7 @@ class TestBulkEdit(DirectoriesMixin, TestCase):
|
||||
)
|
||||
|
||||
self.async_task.assert_called_once()
|
||||
args, kwargs = self.async_task.call_args
|
||||
_, kwargs = self.async_task.call_args
|
||||
self.assertCountEqual(kwargs["document_ids"], [self.doc1.id, self.doc2.id])
|
||||
|
||||
# removal of document link cf, should also remove symmetric link
|
||||
|
@@ -209,6 +209,26 @@ class TestConsumer(DirectoriesMixin, ConsumerThreadMixin, TransactionTestCase):
|
||||
# assert that we have an error logged with this invalid file.
|
||||
error_logger.assert_called_once()
|
||||
|
||||
@mock.patch("documents.management.commands.document_consumer.logger.warning")
|
||||
def test_permission_error_on_prechecks(self, warning_logger):
|
||||
filepath = Path(self.dirs.consumption_dir) / "selinux.txt"
|
||||
filepath.touch()
|
||||
|
||||
original_stat = Path.stat
|
||||
|
||||
def raising_stat(self, *args, **kwargs):
|
||||
if self == filepath:
|
||||
raise PermissionError("Permission denied")
|
||||
return original_stat(self, *args, **kwargs)
|
||||
|
||||
with mock.patch("pathlib.Path.stat", new=raising_stat):
|
||||
document_consumer._consume(filepath)
|
||||
|
||||
warning_logger.assert_called_once()
|
||||
(args, _) = warning_logger.call_args
|
||||
self.assertIn("Permission denied", args[0])
|
||||
self.consume_file_mock.assert_not_called()
|
||||
|
||||
@override_settings(CONSUMPTION_DIR="does_not_exist")
|
||||
def test_consumption_directory_invalid(self):
|
||||
self.assertRaises(CommandError, call_command, "document_consumer", "--oneshot")
|
||||
|
@@ -206,3 +206,29 @@ class TestFuzzyMatchCommand(TestCase):
|
||||
self.assertEqual(Document.objects.count(), 2)
|
||||
self.assertIsNotNone(Document.objects.get(pk=1))
|
||||
self.assertIsNotNone(Document.objects.get(pk=2))
|
||||
|
||||
def test_empty_content(self):
|
||||
"""
|
||||
GIVEN:
|
||||
- 2 documents exist, content is empty (e.g. password-protected)
|
||||
WHEN:
|
||||
- Command is called
|
||||
THEN:
|
||||
- No matches are found
|
||||
"""
|
||||
Document.objects.create(
|
||||
checksum="BEEFCAFE",
|
||||
title="A",
|
||||
content="",
|
||||
mime_type="application/pdf",
|
||||
filename="test.pdf",
|
||||
)
|
||||
Document.objects.create(
|
||||
checksum="DEADBEAF",
|
||||
title="A",
|
||||
content="",
|
||||
mime_type="application/pdf",
|
||||
filename="other_test.pdf",
|
||||
)
|
||||
stdout, _ = self.call_command()
|
||||
self.assertIn("No matches found", stdout)
|
||||
|
@@ -123,14 +123,14 @@ class TestRetagger(DirectoriesMixin, TestCase):
|
||||
|
||||
def test_add_type(self):
|
||||
call_command("document_retagger", "--document_type")
|
||||
d_first, d_second, d_unrelated, d_auto = self.get_updated_docs()
|
||||
d_first, d_second, _, _ = self.get_updated_docs()
|
||||
|
||||
self.assertEqual(d_first.document_type, self.doctype_first)
|
||||
self.assertEqual(d_second.document_type, self.doctype_second)
|
||||
|
||||
def test_add_correspondent(self):
|
||||
call_command("document_retagger", "--correspondent")
|
||||
d_first, d_second, d_unrelated, d_auto = self.get_updated_docs()
|
||||
d_first, d_second, _, _ = self.get_updated_docs()
|
||||
|
||||
self.assertEqual(d_first.correspondent, self.correspondent_first)
|
||||
self.assertEqual(d_second.correspondent, self.correspondent_second)
|
||||
@@ -160,7 +160,7 @@ class TestRetagger(DirectoriesMixin, TestCase):
|
||||
|
||||
def test_add_tags_suggest(self):
|
||||
call_command("document_retagger", "--tags", "--suggest")
|
||||
d_first, d_second, d_unrelated, d_auto = self.get_updated_docs()
|
||||
d_first, d_second, _, d_auto = self.get_updated_docs()
|
||||
|
||||
self.assertEqual(d_first.tags.count(), 0)
|
||||
self.assertEqual(d_second.tags.count(), 0)
|
||||
@@ -168,14 +168,14 @@ class TestRetagger(DirectoriesMixin, TestCase):
|
||||
|
||||
def test_add_type_suggest(self):
|
||||
call_command("document_retagger", "--document_type", "--suggest")
|
||||
d_first, d_second, d_unrelated, d_auto = self.get_updated_docs()
|
||||
d_first, d_second, _, _ = self.get_updated_docs()
|
||||
|
||||
self.assertIsNone(d_first.document_type)
|
||||
self.assertIsNone(d_second.document_type)
|
||||
|
||||
def test_add_correspondent_suggest(self):
|
||||
call_command("document_retagger", "--correspondent", "--suggest")
|
||||
d_first, d_second, d_unrelated, d_auto = self.get_updated_docs()
|
||||
d_first, d_second, _, _ = self.get_updated_docs()
|
||||
|
||||
self.assertIsNone(d_first.correspondent)
|
||||
self.assertIsNone(d_second.correspondent)
|
||||
@@ -187,7 +187,7 @@ class TestRetagger(DirectoriesMixin, TestCase):
|
||||
"--suggest",
|
||||
"--base-url=http://localhost",
|
||||
)
|
||||
d_first, d_second, d_unrelated, d_auto = self.get_updated_docs()
|
||||
d_first, d_second, _, d_auto = self.get_updated_docs()
|
||||
|
||||
self.assertEqual(d_first.tags.count(), 0)
|
||||
self.assertEqual(d_second.tags.count(), 0)
|
||||
@@ -200,7 +200,7 @@ class TestRetagger(DirectoriesMixin, TestCase):
|
||||
"--suggest",
|
||||
"--base-url=http://localhost",
|
||||
)
|
||||
d_first, d_second, d_unrelated, d_auto = self.get_updated_docs()
|
||||
d_first, d_second, _, _ = self.get_updated_docs()
|
||||
|
||||
self.assertIsNone(d_first.document_type)
|
||||
self.assertIsNone(d_second.document_type)
|
||||
@@ -212,7 +212,7 @@ class TestRetagger(DirectoriesMixin, TestCase):
|
||||
"--suggest",
|
||||
"--base-url=http://localhost",
|
||||
)
|
||||
d_first, d_second, d_unrelated, d_auto = self.get_updated_docs()
|
||||
d_first, d_second, _, _ = self.get_updated_docs()
|
||||
|
||||
self.assertIsNone(d_first.correspondent)
|
||||
self.assertIsNone(d_second.correspondent)
|
||||
|
@@ -4,6 +4,7 @@ import shutil
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from django.conf import settings
|
||||
from django.test import override_settings
|
||||
|
||||
@@ -281,6 +282,7 @@ class TestMigrateArchiveFilesErrors(DirectoriesMixin, TestMigrations):
|
||||
migrate_to = "1012_fix_archive_files"
|
||||
auto_migrate = False
|
||||
|
||||
@pytest.mark.skip(reason="Fails with migration tearDown util. Needs investigation.")
|
||||
def test_archive_missing(self):
|
||||
Document = self.apps.get_model("documents", "Document")
|
||||
|
||||
@@ -300,6 +302,7 @@ class TestMigrateArchiveFilesErrors(DirectoriesMixin, TestMigrations):
|
||||
self.performMigration,
|
||||
)
|
||||
|
||||
@pytest.mark.skip(reason="Fails with migration tearDown util. Needs investigation.")
|
||||
def test_parser_missing(self):
|
||||
Document = self.apps.get_model("documents", "Document")
|
||||
|
||||
|
src/documents/tests/test_tag_hierarchy.py (new file, 205 lines)
@@ -0,0 +1,205 @@
|
||||
from unittest import mock
|
||||
|
||||
from django.contrib.auth.models import User
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
from documents import bulk_edit
|
||||
from documents.models import Document
|
||||
from documents.models import Tag
|
||||
from documents.models import Workflow
|
||||
from documents.models import WorkflowAction
|
||||
from documents.models import WorkflowTrigger
|
||||
from documents.signals.handlers import run_workflows
|
||||
|
||||
|
||||
class TestTagHierarchy(APITestCase):
|
||||
def setUp(self):
|
||||
self.user = User.objects.create_superuser(username="admin")
|
||||
self.client.force_authenticate(user=self.user)
|
||||
|
||||
self.parent = Tag.objects.create(name="Parent")
|
||||
self.child = Tag.objects.create(name="Child", tn_parent=self.parent)
|
||||
|
||||
patcher = mock.patch("documents.bulk_edit.bulk_update_documents.delay")
|
||||
self.async_task = patcher.start()
|
||||
self.addCleanup(patcher.stop)
|
||||
|
||||
self.document = Document.objects.create(
|
||||
title="doc",
|
||||
content="",
|
||||
checksum="1",
|
||||
mime_type="application/pdf",
|
||||
)
|
||||
|
||||
def test_document_api_add_child_adds_parent(self):
|
||||
self.client.patch(
|
||||
f"/api/documents/{self.document.pk}/",
|
||||
{"tags": [self.child.pk]},
|
||||
format="json",
|
||||
)
|
||||
self.document.refresh_from_db()
|
||||
tags = set(self.document.tags.values_list("pk", flat=True))
|
||||
assert tags == {self.parent.pk, self.child.pk}
|
||||
|
||||
def test_document_api_remove_parent_removes_children(self):
|
||||
self.document.add_nested_tags([self.parent, self.child])
|
||||
self.client.patch(
|
||||
f"/api/documents/{self.document.pk}/",
|
||||
{"tags": [self.child.pk]},
|
||||
format="json",
|
||||
)
|
||||
self.document.refresh_from_db()
|
||||
assert self.document.tags.count() == 0
|
||||
|
||||
def test_document_api_remove_parent_removes_child(self):
|
||||
self.document.add_nested_tags([self.child])
|
||||
self.client.patch(
|
||||
f"/api/documents/{self.document.pk}/",
|
||||
{"tags": []},
|
||||
format="json",
|
||||
)
|
||||
self.document.refresh_from_db()
|
||||
assert self.document.tags.count() == 0
|
||||
|
||||
def test_bulk_edit_respects_hierarchy(self):
|
||||
bulk_edit.add_tag([self.document.pk], self.child.pk)
|
||||
self.document.refresh_from_db()
|
||||
tags = set(self.document.tags.values_list("pk", flat=True))
|
||||
assert tags == {self.parent.pk, self.child.pk}
|
||||
|
||||
bulk_edit.remove_tag([self.document.pk], self.parent.pk)
|
||||
self.document.refresh_from_db()
|
||||
assert self.document.tags.count() == 0
|
||||
|
||||
bulk_edit.modify_tags([self.document.pk], [self.child.pk], [])
|
||||
self.document.refresh_from_db()
|
||||
tags = set(self.document.tags.values_list("pk", flat=True))
|
||||
assert tags == {self.parent.pk, self.child.pk}
|
||||
|
||||
bulk_edit.modify_tags([self.document.pk], [], [self.parent.pk])
|
||||
self.document.refresh_from_db()
|
||||
assert self.document.tags.count() == 0
|
||||
|
||||
def test_workflow_actions(self):
|
||||
workflow = Workflow.objects.create(name="wf", order=0)
|
||||
trigger = WorkflowTrigger.objects.create(
|
||||
type=WorkflowTrigger.WorkflowTriggerType.DOCUMENT_ADDED,
|
||||
)
|
||||
assign_action = WorkflowAction.objects.create()
|
||||
assign_action.assign_tags.add(self.child)
|
||||
workflow.triggers.add(trigger)
|
||||
workflow.actions.add(assign_action)
|
||||
|
||||
run_workflows(trigger.type, self.document)
|
||||
self.document.refresh_from_db()
|
||||
tags = set(self.document.tags.values_list("pk", flat=True))
|
||||
assert tags == {self.parent.pk, self.child.pk}
|
||||
|
||||
# removal
|
||||
removal_action = WorkflowAction.objects.create(
|
||||
type=WorkflowAction.WorkflowActionType.REMOVAL,
|
||||
)
|
||||
removal_action.remove_tags.add(self.parent)
|
||||
workflow.actions.clear()
|
||||
workflow.actions.add(removal_action)
|
||||
|
||||
run_workflows(trigger.type, self.document)
|
||||
self.document.refresh_from_db()
|
||||
assert self.document.tags.count() == 0
|
||||
|
||||
def test_tag_view_parent_update_adds_parent_to_docs(self):
|
||||
orphan = Tag.objects.create(name="Orphan")
|
||||
self.document.tags.add(orphan)
|
||||
|
||||
self.client.patch(
|
||||
f"/api/tags/{orphan.pk}/",
|
||||
{"parent": self.parent.pk},
|
||||
format="json",
|
||||
)
|
||||
|
||||
self.document.refresh_from_db()
|
||||
tags = set(self.document.tags.values_list("pk", flat=True))
|
||||
assert tags == {self.parent.pk, orphan.pk}
|
||||
|
||||
def test_cannot_set_parent_to_self(self):
|
||||
tag = Tag.objects.create(name="Selfie")
|
||||
resp = self.client.patch(
|
||||
f"/api/tags/{tag.pk}/",
|
||||
{"parent": tag.pk},
|
||||
format="json",
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
assert "Cannot set itself as parent" in str(resp.data["parent"])
|
||||
|
||||
def test_cannot_set_parent_to_descendant(self):
|
||||
a = Tag.objects.create(name="A")
|
||||
b = Tag.objects.create(name="B", tn_parent=a)
|
||||
c = Tag.objects.create(name="C", tn_parent=b)
|
||||
|
||||
# Attempt to set A's parent to C (descendant) should fail
|
||||
resp = self.client.patch(
|
||||
f"/api/tags/{a.pk}/",
|
||||
{"parent": c.pk},
|
||||
format="json",
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
assert "Cannot set parent to a descendant" in str(resp.data["parent"])
|
||||
|
||||
def test_max_depth_on_create(self):
|
||||
a = Tag.objects.create(name="A1")
|
||||
b = Tag.objects.create(name="B1", tn_parent=a)
|
||||
c = Tag.objects.create(name="C1", tn_parent=b)
|
||||
d = Tag.objects.create(name="D1", tn_parent=c)
|
||||
|
||||
# Creating E under D yields depth 5: allowed
|
||||
resp_ok = self.client.post(
|
||||
"/api/tags/",
|
||||
{"name": "E1", "parent": d.pk},
|
||||
format="json",
|
||||
)
|
||||
assert resp_ok.status_code in (200, 201)
|
||||
e_id = (
|
||||
resp_ok.data["id"] if resp_ok.status_code == 201 else resp_ok.data.get("id")
|
||||
)
|
||||
assert e_id is not None
|
||||
|
||||
# Creating F under E would yield depth 6: rejected
|
||||
resp_fail = self.client.post(
|
||||
"/api/tags/",
|
||||
{"name": "F1", "parent": e_id},
|
||||
format="json",
|
||||
)
|
||||
assert resp_fail.status_code == 400
|
||||
assert "parent" in resp_fail.data
|
||||
assert "Invalid" in str(resp_fail.data["parent"])
|
||||
|
||||
def test_max_depth_on_move_subtree(self):
|
||||
a = Tag.objects.create(name="A2")
|
||||
b = Tag.objects.create(name="B2", tn_parent=a)
|
||||
c = Tag.objects.create(name="C2", tn_parent=b)
|
||||
d = Tag.objects.create(name="D2", tn_parent=c)
|
||||
|
||||
x = Tag.objects.create(name="X2")
|
||||
y = Tag.objects.create(name="Y2", tn_parent=x)
|
||||
assert y.parent_pk == x.pk
|
||||
|
||||
# Moving X under D would make deepest node Y exceed depth 5 -> reject
|
||||
resp_fail = self.client.patch(
|
||||
f"/api/tags/{x.pk}/",
|
||||
{"parent": d.pk},
|
||||
format="json",
|
||||
)
|
||||
assert resp_fail.status_code == 400
|
||||
assert "Maximum nesting depth exceeded" in str(
|
||||
resp_fail.data["non_field_errors"],
|
||||
)
|
||||
|
||||
# Moving X under C (depth 3) should be allowed (deepest becomes 5)
|
||||
resp_ok = self.client.patch(
|
||||
f"/api/tags/{x.pk}/",
|
||||
{"parent": c.pk},
|
||||
format="json",
|
||||
)
|
||||
assert resp_ok.status_code in (200, 202)
|
||||
x.refresh_from_db()
|
||||
assert x.parent_pk == c.id
|
@@ -1,3 +1,4 @@
import json
import tempfile
from datetime import timedelta
from pathlib import Path
@@ -5,11 +6,15 @@ from unittest.mock import MagicMock
from unittest.mock import patch

from django.conf import settings
from django.contrib.auth.models import Group
from django.contrib.auth.models import Permission
from django.contrib.auth.models import User
from django.db import connection
from django.test import TestCase
from django.test import override_settings
from django.test.utils import CaptureQueriesContext
from django.utils import timezone
from guardian.shortcuts import assign_perm
from rest_framework import status

from documents.caching import get_llm_suggestion_cache
@@ -164,6 +169,116 @@ class TestViews(DirectoriesMixin, TestCase):
        self.assertEqual(response.request["PATH_INFO"], "/accounts/login/")
        self.assertContains(response, b"Share link has expired")

    def test_list_with_full_permissions(self):
        """
        GIVEN:
            - Tags with different permissions
        WHEN:
            - Request to get tag list with full permissions is made
        THEN:
            - Tag list is returned with the right permission information
        """
        user2 = User.objects.create(username="user2")
        user3 = User.objects.create(username="user3")
        group1 = Group.objects.create(name="group1")
        group2 = Group.objects.create(name="group2")
        group3 = Group.objects.create(name="group3")
        t1 = Tag.objects.create(name="invoice", pk=1)
        assign_perm("view_tag", self.user, t1)
        assign_perm("view_tag", user2, t1)
        assign_perm("view_tag", user3, t1)
        assign_perm("view_tag", group1, t1)
        assign_perm("view_tag", group2, t1)
        assign_perm("view_tag", group3, t1)
        assign_perm("change_tag", self.user, t1)
        assign_perm("change_tag", user2, t1)
        assign_perm("change_tag", group1, t1)
        assign_perm("change_tag", group2, t1)

        Tag.objects.create(name="bank statement", pk=2)
        d1 = Document.objects.create(
            title="Invoice 1",
            content="This is the invoice of a very expensive item",
            checksum="A",
        )
        d1.tags.add(t1)
        d2 = Document.objects.create(
            title="Invoice 2",
            content="Internet invoice, I should pay it to continue contributing",
            checksum="B",
        )
        d2.tags.add(t1)

        view_permissions = Permission.objects.filter(
            codename__contains="view_tag",
        )
        self.user.user_permissions.add(*view_permissions)
        self.user.save()

        self.client.force_login(self.user)
        response = self.client.get("/api/tags/?page=1&full_perms=true")
        results = json.loads(response.content)["results"]
        for tag in results:
            if tag["name"] == "invoice":
                assert tag["permissions"] == {
                    "view": {
                        "users": [self.user.pk, user2.pk, user3.pk],
                        "groups": [group1.pk, group2.pk, group3.pk],
                    },
                    "change": {
                        "users": [self.user.pk, user2.pk],
                        "groups": [group1.pk, group2.pk],
                    },
                }
            elif tag["name"] == "bank statement":
                assert tag["permissions"] == {
                    "view": {"users": [], "groups": []},
                    "change": {"users": [], "groups": []},
                }
            else:
                assert False, f"Unexpected tag found: {tag['name']}"

    def test_list_no_n_plus_1_queries(self):
        """
        GIVEN:
            - Tags with different permissions
        WHEN:
            - Request to get tag list with full permissions is made
        THEN:
            - Permissions are not queried from the database tag by tag,
              i.e. there are no N+1 queries
        """
        view_permissions = Permission.objects.filter(
            codename__contains="view_tag",
        )
        self.user.user_permissions.add(*view_permissions)
        self.user.save()
        self.client.force_login(self.user)

        # Start with a small list, and count the number of SQL queries
        for i in range(2):
            Tag.objects.create(name=f"tag_{i}")

        with CaptureQueriesContext(connection) as ctx_small:
            response_small = self.client.get("/api/tags/?full_perms=true")
        assert response_small.status_code == 200
        num_queries_small = len(ctx_small.captured_queries)

        # Complete the list, and count the number of SQL queries again
        for i in range(2, 50):
            Tag.objects.create(name=f"tag_{i}")

        with CaptureQueriesContext(connection) as ctx_large:
            response_large = self.client.get("/api/tags/?full_perms=true")
        assert response_large.status_code == 200
        num_queries_large = len(ctx_large.captured_queries)

        # A few additional queries are allowed, but not a linear explosion
        assert num_queries_large <= num_queries_small + 5, (
            f"Possible N+1 queries detected: {num_queries_small} queries for 2 tags, "
            f"but {num_queries_large} queries for 50 tags"
        )


class TestAISuggestions(DirectoriesMixin, TestCase):
    def setUp(self):
@@ -327,6 +327,19 @@ class TestMigrations(TransactionTestCase):
    def setUpBeforeMigration(self, apps):
        pass

    def tearDown(self):
        """
        Ensure the database schema is restored to the latest migration after
        each migration test, so subsequent tests run against HEAD.
        """
        try:
            executor = MigrationExecutor(connection)
            executor.loader.build_graph()
            targets = executor.loader.graph.leaf_nodes()
            executor.migrate(targets)
        finally:
            super().tearDown()


class SampleDirMixin:
    SAMPLE_DIR = Path(__file__).parent / "samples"
@@ -6,9 +6,11 @@ import platform
import re
import tempfile
import zipfile
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from time import mktime
from typing import Literal
from unicodedata import normalize
from urllib.parse import quote
from urllib.parse import urlparse
@@ -21,6 +23,7 @@ from django.conf import settings
from django.contrib.auth.decorators import login_required
from django.contrib.auth.models import Group
from django.contrib.auth.models import User
from django.contrib.contenttypes.models import ContentType
from django.db import connections
from django.db.migrations.loader import MigrationLoader
from django.db.migrations.recorder import MigrationRecorder
@@ -60,6 +63,8 @@ from drf_spectacular.utils import OpenApiParameter
from drf_spectacular.utils import extend_schema
from drf_spectacular.utils import extend_schema_view
from drf_spectacular.utils import inline_serializer
from guardian.utils import get_group_obj_perms_model
from guardian.utils import get_user_obj_perms_model
from langdetect import detect
from packaging import version as packaging_version
from redis import Redis
@@ -175,6 +180,7 @@ from documents.tasks import empty_trash
from documents.tasks import index_optimize
from documents.tasks import sanity_check
from documents.tasks import train_classifier
from documents.tasks import update_document_parent_tags
from documents.templating.filepath import validate_filepath_template_and_render
from documents.utils import get_boolean
from paperless import version
@@ -268,7 +274,101 @@ class PassUserMixin(GenericAPIView):
        return super().get_serializer(*args, **kwargs)


class PermissionsAwareDocumentCountMixin(PassUserMixin):
class BulkPermissionMixin:
    """
    Prefetch Django-Guardian permissions for a list before serialization, to avoid N+1 queries.
    """

    def get_permission_codenames(self):
        model_name = self.queryset.model.__name__.lower()
        return {
            "view": f"view_{model_name}",
            "change": f"change_{model_name}",
        }

    def _get_object_perms(
        self,
        objects: list,
        perm_codenames: list[str],
        actor: Literal["users", "groups"],
    ) -> dict[int, dict[str, list[int]]]:
        """
        Collect object-level permissions for either users or groups.
        """
        model = self.queryset.model
        obj_perm_model = (
            get_user_obj_perms_model(model)
            if actor == "users"
            else get_group_obj_perms_model(model)
        )
        id_field = "user_id" if actor == "users" else "group_id"
        ctype = ContentType.objects.get_for_model(model)
        object_pks = [obj.pk for obj in objects]

        perms_qs = obj_perm_model.objects.filter(
            content_type=ctype,
            object_pk__in=object_pks,
            permission__codename__in=perm_codenames,
        ).values_list("object_pk", id_field, "permission__codename")

        perms: dict[int, dict[str, list[int]]] = defaultdict(lambda: defaultdict(list))
        for object_pk, actor_id, codename in perms_qs:
            perms[int(object_pk)][codename].append(actor_id)

        # Ensure that all objects have all codenames, even if empty
        for pk in object_pks:
            for codename in perm_codenames:
                perms[pk][codename]

        return perms

    def get_serializer_context(self):
        """
        Get all permissions of the current list of objects at once and pass them to the serializer.
        This avoids fetching permissions object by object from the database.
        """
        context = super().get_serializer_context()
        try:
            full_perms = get_boolean(
                str(self.request.query_params.get("full_perms", "false")),
            )
        except ValueError:
            full_perms = False

        if not full_perms:
            return context

        # Check which objects are being paginated
        page = getattr(self, "paginator", None)
        if page and hasattr(page, "page"):
            queryset = page.page.object_list
        elif hasattr(self, "page"):
            queryset = self.page
        else:
            queryset = self.filter_queryset(self.get_queryset())

        codenames = self.get_permission_codenames()
        perm_names = [codenames["view"], codenames["change"]]
        user_perms = self._get_object_perms(queryset, perm_names, actor="users")
        group_perms = self._get_object_perms(queryset, perm_names, actor="groups")

        context["users_view_perms"] = {
            pk: user_perms[pk][codenames["view"]] for pk in user_perms
        }
        context["users_change_perms"] = {
            pk: user_perms[pk][codenames["change"]] for pk in user_perms
        }
        context["groups_view_perms"] = {
            pk: group_perms[pk][codenames["view"]] for pk in group_perms
        }
        context["groups_change_perms"] = {
            pk: group_perms[pk][codenames["change"]] for pk in group_perms
        }

        return context


class PermissionsAwareDocumentCountMixin(BulkPermissionMixin, PassUserMixin):
    """
    Mixin to add document count to queryset, permissions-aware if needed
    """
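
For illustration only, this is roughly how a serializer method could consume the per-object maps that the mixin above places in the serializer context; the real paperless-ngx serializer may be structured differently, but the output shape matches what the full_perms tests earlier in this diff assert.

def get_permissions_sketch(self, obj):
    # Sketch: read the prefetched maps instead of querying guardian per object.
    ctx = self.context
    if "users_view_perms" not in ctx:
        return None  # bulk maps absent; a per-object fallback would go here
    return {
        "view": {
            "users": ctx["users_view_perms"].get(obj.pk, []),
            "groups": ctx["groups_view_perms"].get(obj.pk, []),
        },
        "change": {
            "users": ctx["users_change_perms"].get(obj.pk, []),
            "groups": ctx["groups_change_perms"].get(obj.pk, []),
        },
    }
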
@@ -356,6 +456,13 @@ class TagViewSet(ModelViewSet, PermissionsAwareDocumentCountMixin):
    filterset_class = TagFilterSet
    ordering_fields = ("color", "name", "matching_algorithm", "match", "document_count")

    def perform_update(self, serializer):
        old_parent = self.get_object().get_parent()
        tag = serializer.save()
        new_parent = tag.get_parent()
        if new_parent and old_parent != new_parent:
            update_document_parent_tags(tag, new_parent)
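
The update_document_parent_tags task invoked above is defined in documents.tasks but its body is not part of this excerpt. As a loose sketch of the intent only (propagating the new ancestor chain to documents that carry the moved tag), under assumed django-treenode helpers:

from documents.models import Document


def update_document_parent_tags_sketch(tag, new_parent):
    # Sketch only -- the real task in documents.tasks may differ.
    ancestor_ids = [new_parent.pk] + [a.pk for a in new_parent.get_ancestors()]
    for doc in Document.objects.filter(tags__id=tag.pk).distinct():
        doc.tags.add(*ancestor_ids)
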
@extend_schema_view(**generate_object_with_permissions_schema(DocumentTypeSerializer))
class DocumentTypeViewSet(ModelViewSet, PermissionsAwareDocumentCountMixin):
@@ -1624,7 +1731,7 @@ class PostDocumentView(GenericAPIView):
        title = serializer.validated_data.get("title")
        created = serializer.validated_data.get("created")
        archive_serial_number = serializer.validated_data.get("archive_serial_number")
        custom_field_ids = serializer.validated_data.get("custom_fields")
        cf = serializer.validated_data.get("custom_fields")
        from_webui = serializer.validated_data.get("from_webui")

        t = int(mktime(datetime.now().timetuple()))
@@ -1643,6 +1750,11 @@ class PostDocumentView(GenericAPIView):
            source=DocumentSource.WebUI if from_webui else DocumentSource.ApiUpload,
            original_file=temp_file_path,
        )
        custom_fields = None
        if isinstance(cf, dict) and cf:
            custom_fields = cf
        elif isinstance(cf, list) and cf:
            custom_fields = dict.fromkeys(cf, None)
        input_doc_overrides = DocumentMetadataOverrides(
            filename=doc_name,
            title=title,
@@ -1653,10 +1765,7 @@ class PostDocumentView(GenericAPIView):
            created=created,
            asn=archive_serial_number,
            owner_id=request.user.id,
            # TODO: set values
            custom_fields={cf_id: None for cf_id in custom_field_ids}
            if custom_field_ids
            else None,
            custom_fields=custom_fields,
        )

        async_task = consume_file.delay(
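
The cf normalization added above means the upload endpoint accepts custom fields either as an id-to-value mapping or as a plain list of ids. Both illustrative payloads below (field ids and values are made up, not taken from the diff) end up as a dict of field id to value:

# Both shapes normalize to {field_id: value}; a bare list implies "no value yet".
payload_as_dict = {"custom_fields": {3: "ACME", 7: None}}  # kept as-is
payload_as_list = {"custom_fields": [3, 7]}                # becomes {3: None, 7: None}
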
File diff suppressed because it is too large
@@ -347,6 +347,7 @@ INSTALLED_APPS = [
    "allauth.mfa",
    "drf_spectacular",
    "drf_spectacular_sidecar",
    "treenode",
    *env_apps,
]

@@ -954,7 +955,7 @@ CELERY_ACCEPT_CONTENT = ["application/json", "application/x-python-serialize"]
CELERY_BEAT_SCHEDULE = _parse_beat_schedule()

# https://docs.celeryq.dev/en/stable/userguide/configuration.html#beat-schedule-filename
CELERY_BEAT_SCHEDULE_FILENAME = DATA_DIR / "celerybeat-schedule.db"
CELERY_BEAT_SCHEDULE_FILENAME = str(DATA_DIR / "celerybeat-schedule.db")


# Cachalot: Database read cache.
@@ -21,7 +21,7 @@ TEST_CHANNEL_LAYERS = {
class TestWebSockets(TestCase):
    async def test_no_auth(self):
        communicator = WebsocketCommunicator(application, "/ws/status/")
        connected, subprotocol = await communicator.connect()
        connected, _ = await communicator.connect()
        self.assertFalse(connected)
        await communicator.disconnect()

@@ -31,7 +31,7 @@ class TestWebSockets(TestCase):
        _authenticated.return_value = True

        communicator = WebsocketCommunicator(application, "/ws/status/")
        connected, subprotocol = await communicator.connect()
        connected, _ = await communicator.connect()
        self.assertTrue(connected)

        message = {"type": "status_update", "data": {"task_id": "test"}}
@@ -63,7 +63,7 @@ class TestWebSockets(TestCase):
        _authenticated.return_value = True

        communicator = WebsocketCommunicator(application, "/ws/status/")
        connected, subprotocol = await communicator.connect()
        connected, _ = await communicator.connect()
        self.assertTrue(connected)

        await communicator.disconnect()
@@ -73,7 +73,7 @@ class TestWebSockets(TestCase):
        _authenticated.return_value = True

        communicator = WebsocketCommunicator(application, "/ws/status/")
        connected, subprotocol = await communicator.connect()
        connected, _ = await communicator.connect()
        self.assertTrue(connected)

        message = {"type": "status_update", "data": {"task_id": "test"}}
@@ -98,7 +98,7 @@ class TestWebSockets(TestCase):
        communicator.scope["user"].is_superuser = False
        communicator.scope["user"].id = 1

        connected, subprotocol = await communicator.connect()
        connected, _ = await communicator.connect()
        self.assertTrue(connected)

        # Test as owner
@@ -141,7 +141,7 @@ class TestWebSockets(TestCase):
        _authenticated.return_value = True

        communicator = WebsocketCommunicator(application, "/ws/status/")
        connected, subprotocol = await communicator.connect()
        connected, _ = await communicator.connect()
        self.assertTrue(connected)

        message = {"type": "documents_deleted", "data": {"documents": [1, 2, 3]}}
@@ -58,6 +58,7 @@ from paperless.views import UserViewSet
from paperless_mail.views import MailAccountViewSet
from paperless_mail.views import MailRuleViewSet
from paperless_mail.views import OauthCallbackView
from paperless_mail.views import ProcessedMailViewSet

api_router = DefaultRouter()
api_router.register(r"correspondents", CorrespondentViewSet)
@@ -78,6 +79,7 @@ api_router.register(r"workflow_actions", WorkflowActionViewSet)
api_router.register(r"workflows", WorkflowViewSet)
api_router.register(r"custom_fields", CustomFieldViewSet)
api_router.register(r"config", ApplicationConfigurationViewSet)
api_router.register(r"processed_mail", ProcessedMailViewSet)


urlpatterns = [
src/paperless_mail/filters.py (new file, 12 lines)
@@ -0,0 +1,12 @@
from django_filters import FilterSet

from paperless_mail.models import ProcessedMail


class ProcessedMailFilterSet(FilterSet):
    class Meta:
        model = ProcessedMail
        fields = {
            "rule": ["exact"],
            "status": ["exact"],
        }
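
With this FilterSet wired into the processed-mail viewset later in this diff, clients can filter the endpoint by rule and status. An illustrative request from a DRF test client (the rule id and status value are made up):

from rest_framework.test import APIClient

client = APIClient()
# Assumes an already-authenticated user, as in the TestAPIProcessedMails setUp further down.
response = client.get("/api/processed_mail/?rule=1&status=FAILED")
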
@@ -6,6 +6,7 @@ from documents.serialisers import OwnedObjectSerializer
from documents.serialisers import TagsField
from paperless_mail.models import MailAccount
from paperless_mail.models import MailRule
from paperless_mail.models import ProcessedMail


class ObfuscatedPasswordField(serializers.CharField):
@@ -130,3 +131,20 @@ class MailRuleSerializer(OwnedObjectSerializer):
        if value > 36500:  # ~100 years
            raise serializers.ValidationError("Maximum mail age is unreasonably large.")
        return value


class ProcessedMailSerializer(OwnedObjectSerializer):
    class Meta:
        model = ProcessedMail
        fields = [
            "id",
            "owner",
            "rule",
            "folder",
            "uid",
            "subject",
            "received",
            "processed",
            "status",
            "error",
        ]
@@ -3,6 +3,7 @@ from unittest import mock

from django.contrib.auth.models import Permission
from django.contrib.auth.models import User
from django.utils import timezone
from guardian.shortcuts import assign_perm
from rest_framework import status
from rest_framework.test import APITestCase
@@ -13,6 +14,7 @@ from documents.models import Tag
from documents.tests.utils import DirectoriesMixin
from paperless_mail.models import MailAccount
from paperless_mail.models import MailRule
from paperless_mail.models import ProcessedMail
from paperless_mail.tests.test_mail import BogusMailBox


@@ -721,3 +723,285 @@ class TestAPIMailRules(DirectoriesMixin, APITestCase):

        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
        self.assertIn("maximum_age", response.data)


class TestAPIProcessedMails(DirectoriesMixin, APITestCase):
    ENDPOINT = "/api/processed_mail/"

    def setUp(self):
        super().setUp()

        self.user = User.objects.create_user(username="temp_admin")
        self.user.user_permissions.add(*Permission.objects.all())
        self.user.save()
        self.client.force_authenticate(user=self.user)

    def test_get_processed_mails_owner_aware(self):
        """
        GIVEN:
            - Configured processed mails with different users
        WHEN:
            - API call is made to get processed mails
        THEN:
            - Only unowned, owned by user or granted processed mails are provided
        """
        user2 = User.objects.create_user(username="temp_admin2")

        account = MailAccount.objects.create(
            name="Email1",
            username="username1",
            password="password1",
            imap_server="server.example.com",
            imap_port=443,
            imap_security=MailAccount.ImapSecurity.SSL,
            character_set="UTF-8",
        )

        rule = MailRule.objects.create(
            name="Rule1",
            account=account,
            folder="INBOX",
            filter_from="from@example.com",
            order=0,
        )

        pm1 = ProcessedMail.objects.create(
            rule=rule,
            folder="INBOX",
            uid="1",
            subject="Subj1",
            received=timezone.now(),
            processed=timezone.now(),
            status="SUCCESS",
            error=None,
        )

        pm2 = ProcessedMail.objects.create(
            rule=rule,
            folder="INBOX",
            uid="2",
            subject="Subj2",
            received=timezone.now(),
            processed=timezone.now(),
            status="FAILED",
            error="err",
            owner=self.user,
        )

        ProcessedMail.objects.create(
            rule=rule,
            folder="INBOX",
            uid="3",
            subject="Subj3",
            received=timezone.now(),
            processed=timezone.now(),
            status="SUCCESS",
            error=None,
            owner=user2,
        )

        pm4 = ProcessedMail.objects.create(
            rule=rule,
            folder="INBOX",
            uid="4",
            subject="Subj4",
            received=timezone.now(),
            processed=timezone.now(),
            status="SUCCESS",
            error=None,
        )
        pm4.owner = user2
        pm4.save()
        assign_perm("view_processedmail", self.user, pm4)

        response = self.client.get(self.ENDPOINT)

        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(response.data["count"], 3)
        returned_ids = {r["id"] for r in response.data["results"]}
        self.assertSetEqual(returned_ids, {pm1.id, pm2.id, pm4.id})

    def test_get_processed_mails_filter_by_rule(self):
        """
        GIVEN:
            - Processed mails belonging to two different rules
        WHEN:
            - API call is made with rule filter
        THEN:
            - Only processed mails for that rule are returned
        """
        account = MailAccount.objects.create(
            name="Email1",
            username="username1",
            password="password1",
            imap_server="server.example.com",
            imap_port=443,
            imap_security=MailAccount.ImapSecurity.SSL,
            character_set="UTF-8",
        )

        rule1 = MailRule.objects.create(
            name="Rule1",
            account=account,
            folder="INBOX",
            filter_from="from1@example.com",
            order=0,
        )
        rule2 = MailRule.objects.create(
            name="Rule2",
            account=account,
            folder="INBOX",
            filter_from="from2@example.com",
            order=1,
        )

        pm1 = ProcessedMail.objects.create(
            rule=rule1,
            folder="INBOX",
            uid="r1-1",
            subject="R1-A",
            received=timezone.now(),
            processed=timezone.now(),
            status="SUCCESS",
            error=None,
            owner=self.user,
        )
        pm2 = ProcessedMail.objects.create(
            rule=rule1,
            folder="INBOX",
            uid="r1-2",
            subject="R1-B",
            received=timezone.now(),
            processed=timezone.now(),
            status="FAILED",
            error="e",
        )
        ProcessedMail.objects.create(
            rule=rule2,
            folder="INBOX",
            uid="r2-1",
            subject="R2-A",
            received=timezone.now(),
            processed=timezone.now(),
            status="SUCCESS",
            error=None,
        )

        response = self.client.get(f"{self.ENDPOINT}?rule={rule1.pk}")

        self.assertEqual(response.status_code, status.HTTP_200_OK)
        returned_ids = {r["id"] for r in response.data["results"]}
        self.assertSetEqual(returned_ids, {pm1.id, pm2.id})

    def test_bulk_delete_processed_mails(self):
        """
        GIVEN:
            - Processed mails belonging to two different rules and different users
        WHEN:
            - API call is made to bulk delete some of the processed mails
        THEN:
            - Only the specified processed mails are deleted, respecting ownership and permissions
        """
        user2 = User.objects.create_user(username="temp_admin2")

        account = MailAccount.objects.create(
            name="Email1",
            username="username1",
            password="password1",
            imap_server="server.example.com",
            imap_port=443,
            imap_security=MailAccount.ImapSecurity.SSL,
            character_set="UTF-8",
        )

        rule = MailRule.objects.create(
            name="Rule1",
            account=account,
            folder="INBOX",
            filter_from="from@example.com",
            order=0,
        )

        # unowned and owned by self, and one with explicit object perm
        pm_unowned = ProcessedMail.objects.create(
            rule=rule,
            folder="INBOX",
            uid="u1",
            subject="Unowned",
            received=timezone.now(),
            processed=timezone.now(),
            status="SUCCESS",
            error=None,
        )
        pm_owned = ProcessedMail.objects.create(
            rule=rule,
            folder="INBOX",
            uid="u2",
            subject="Owned",
            received=timezone.now(),
            processed=timezone.now(),
            status="FAILED",
            error="e",
            owner=self.user,
        )
        pm_granted = ProcessedMail.objects.create(
            rule=rule,
            folder="INBOX",
            uid="u3",
            subject="Granted",
            received=timezone.now(),
            processed=timezone.now(),
            status="SUCCESS",
            error=None,
            owner=user2,
        )
        assign_perm("delete_processedmail", self.user, pm_granted)
        pm_forbidden = ProcessedMail.objects.create(
            rule=rule,
            folder="INBOX",
            uid="u4",
            subject="Forbidden",
            received=timezone.now(),
            processed=timezone.now(),
            status="SUCCESS",
            error=None,
            owner=user2,
        )

        # Success for allowed items
        response = self.client.post(
            f"{self.ENDPOINT}bulk_delete/",
            data={
                "mail_ids": [pm_unowned.id, pm_owned.id, pm_granted.id],
            },
            format="json",
        )
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(response.data["result"], "OK")
        self.assertSetEqual(
            set(response.data["deleted_mail_ids"]),
            {pm_unowned.id, pm_owned.id, pm_granted.id},
        )
        self.assertFalse(ProcessedMail.objects.filter(id=pm_unowned.id).exists())
        self.assertFalse(ProcessedMail.objects.filter(id=pm_owned.id).exists())
        self.assertFalse(ProcessedMail.objects.filter(id=pm_granted.id).exists())
        self.assertTrue(ProcessedMail.objects.filter(id=pm_forbidden.id).exists())

        # 403 and not deleted
        response = self.client.post(
            f"{self.ENDPOINT}bulk_delete/",
            data={
                "mail_ids": [pm_forbidden.id],
            },
            format="json",
        )
        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
        self.assertTrue(ProcessedMail.objects.filter(id=pm_forbidden.id).exists())

        # invalid mail_ids payload
        response = self.client.post(
            f"{self.ENDPOINT}bulk_delete/",
            data={"mail_ids": "not-a-list"},
            format="json",
        )
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
@@ -3,8 +3,10 @@ import logging
from datetime import timedelta

from django.http import HttpResponseBadRequest
from django.http import HttpResponseForbidden
from django.http import HttpResponseRedirect
from django.utils import timezone
from django_filters.rest_framework import DjangoFilterBackend
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import extend_schema
from drf_spectacular.utils import extend_schema_view
@@ -12,23 +14,29 @@ from drf_spectacular.utils import inline_serializer
from httpx_oauth.oauth2 import GetAccessTokenError
from rest_framework import serializers
from rest_framework.decorators import action
from rest_framework.filters import OrderingFilter
from rest_framework.generics import GenericAPIView
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from rest_framework.viewsets import ModelViewSet
from rest_framework.viewsets import ReadOnlyModelViewSet

from documents.filters import ObjectOwnedOrGrantedPermissionsFilter
from documents.permissions import PaperlessObjectPermissions
from documents.permissions import has_perms_owner_aware
from documents.views import PassUserMixin
from paperless.views import StandardPagination
from paperless_mail.filters import ProcessedMailFilterSet
from paperless_mail.mail import MailError
from paperless_mail.mail import get_mailbox
from paperless_mail.mail import mailbox_login
from paperless_mail.models import MailAccount
from paperless_mail.models import MailRule
from paperless_mail.models import ProcessedMail
from paperless_mail.oauth import PaperlessMailOAuth2Manager
from paperless_mail.serialisers import MailAccountSerializer
from paperless_mail.serialisers import MailRuleSerializer
from paperless_mail.serialisers import ProcessedMailSerializer
from paperless_mail.tasks import process_mail_accounts


@@ -126,6 +134,34 @@ class MailAccountViewSet(ModelViewSet, PassUserMixin):
        return Response({"result": "OK"})


class ProcessedMailViewSet(ReadOnlyModelViewSet, PassUserMixin):
    permission_classes = (IsAuthenticated, PaperlessObjectPermissions)
    serializer_class = ProcessedMailSerializer
    pagination_class = StandardPagination
    filter_backends = (
        DjangoFilterBackend,
        OrderingFilter,
        ObjectOwnedOrGrantedPermissionsFilter,
    )
    filterset_class = ProcessedMailFilterSet

    queryset = ProcessedMail.objects.all().order_by("-processed")

    @action(methods=["post"], detail=False)
    def bulk_delete(self, request):
        mail_ids = request.data.get("mail_ids", [])
        if not isinstance(mail_ids, list) or not all(
            isinstance(i, int) for i in mail_ids
        ):
            return HttpResponseBadRequest("mail_ids must be a list of integers")
        mails = ProcessedMail.objects.filter(id__in=mail_ids)
        for mail in mails:
            if not has_perms_owner_aware(request.user, "delete_processedmail", mail):
                return HttpResponseForbidden("Insufficient permissions")
            mail.delete()
        return Response({"result": "OK", "deleted_mail_ids": mail_ids})


class MailRuleViewSet(ModelViewSet, PassUserMixin):
    model = MailRule
@@ -132,7 +132,7 @@ class RasterisedDocumentParser(DocumentParser):
    def get_dpi(self, image) -> int | None:
        try:
            with Image.open(image) as im:
                x, y = im.info["dpi"]
                x, _ = im.info["dpi"]
                return round(x)
        except Exception as e:
            self.log.warning(f"Error while getting DPI from image {image}: {e}")
@@ -141,7 +141,7 @@ class RasterisedDocumentParser(DocumentParser):
    def calculate_a4_dpi(self, image) -> int | None:
        try:
            with Image.open(image) as im:
                width, height = im.size
                width, _ = im.size
                # divide image width by A4 width (210mm) in inches.
                dpi = int(width / (21 / 2.54))
                self.log.debug(f"Estimated DPI {dpi} based on image width {width}")
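
As a quick sanity check of the A4 estimate above: an A4 page is 210 mm, or about 8.27 inches, wide, so a scan that is 2480 pixels wide works out to 2480 / 8.27 ≈ 300 DPI (int() truncates the 299.96 result to 299).

# Worked example of the estimate computed by calculate_a4_dpi (width in pixels is illustrative).
>>> int(2480 / (21 / 2.54))
299
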