mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-04-25 10:49:30 -05:00
Merge filters
This commit is contained in:
parent
fe404e4a24
commit
6850a8a4e7
@ -1,950 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
import json
|
||||
import operator
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from django.db.models import Case
|
||||
from django.db.models import CharField
|
||||
from django.db.models import Count
|
||||
from django.db.models import Exists
|
||||
from django.db.models import IntegerField
|
||||
from django.db.models import OuterRef
|
||||
from django.db.models import Q
|
||||
from django.db.models import Subquery
|
||||
from django.db.models import Sum
|
||||
from django.db.models import Value
|
||||
from django.db.models import When
|
||||
from django.db.models.functions import Cast
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django_filters.rest_framework import BooleanFilter
|
||||
from django_filters.rest_framework import Filter
|
||||
from django_filters.rest_framework import FilterSet
|
||||
from drf_spectacular.utils import extend_schema_field
|
||||
from guardian.utils import get_group_obj_perms_model
|
||||
from guardian.utils import get_user_obj_perms_model
|
||||
from rest_framework import serializers
|
||||
from rest_framework.filters import OrderingFilter
|
||||
from rest_framework_guardian.filters import ObjectPermissionsFilter
|
||||
|
||||
from paperless.models import Correspondent
|
||||
from paperless.models import CustomField
|
||||
from paperless.models import CustomFieldInstance
|
||||
from paperless.models import Document
|
||||
from paperless.models import DocumentType
|
||||
from paperless.models import PaperlessTask
|
||||
from paperless.models import ShareLink
|
||||
from paperless.models import StoragePath
|
||||
from paperless.models import Tag
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
CHAR_KWARGS = ["istartswith", "iendswith", "icontains", "iexact"]
|
||||
ID_KWARGS = ["in", "exact"]
|
||||
INT_KWARGS = ["exact", "gt", "gte", "lt", "lte", "isnull"]
|
||||
DATE_KWARGS = [
|
||||
"year",
|
||||
"month",
|
||||
"day",
|
||||
"date__gt",
|
||||
"date__gte",
|
||||
"gt",
|
||||
"gte",
|
||||
"date__lt",
|
||||
"date__lte",
|
||||
"lt",
|
||||
"lte",
|
||||
]
|
||||
|
||||
CUSTOM_FIELD_QUERY_MAX_DEPTH = 10
|
||||
CUSTOM_FIELD_QUERY_MAX_ATOMS = 20
|
||||
|
||||
|
||||
class CorrespondentFilterSet(FilterSet):
|
||||
class Meta:
|
||||
model = Correspondent
|
||||
fields = {
|
||||
"id": ID_KWARGS,
|
||||
"name": CHAR_KWARGS,
|
||||
}
|
||||
|
||||
|
||||
class TagFilterSet(FilterSet):
|
||||
class Meta:
|
||||
model = Tag
|
||||
fields = {
|
||||
"id": ID_KWARGS,
|
||||
"name": CHAR_KWARGS,
|
||||
}
|
||||
|
||||
|
||||
class DocumentTypeFilterSet(FilterSet):
|
||||
class Meta:
|
||||
model = DocumentType
|
||||
fields = {
|
||||
"id": ID_KWARGS,
|
||||
"name": CHAR_KWARGS,
|
||||
}
|
||||
|
||||
|
||||
class StoragePathFilterSet(FilterSet):
|
||||
class Meta:
|
||||
model = StoragePath
|
||||
fields = {
|
||||
"id": ID_KWARGS,
|
||||
"name": CHAR_KWARGS,
|
||||
"path": CHAR_KWARGS,
|
||||
}
|
||||
|
||||
|
||||
class ObjectFilter(Filter):
|
||||
def __init__(self, *, exclude=False, in_list=False, field_name=""):
|
||||
super().__init__()
|
||||
self.exclude = exclude
|
||||
self.in_list = in_list
|
||||
self.field_name = field_name
|
||||
|
||||
def filter(self, qs, value):
|
||||
if not value:
|
||||
return qs
|
||||
|
||||
try:
|
||||
object_ids = [int(x) for x in value.split(",")]
|
||||
except ValueError:
|
||||
return qs
|
||||
|
||||
if self.in_list:
|
||||
qs = qs.filter(**{f"{self.field_name}__id__in": object_ids}).distinct()
|
||||
else:
|
||||
for obj_id in object_ids:
|
||||
if self.exclude:
|
||||
qs = qs.exclude(**{f"{self.field_name}__id": obj_id})
|
||||
else:
|
||||
qs = qs.filter(**{f"{self.field_name}__id": obj_id})
|
||||
|
||||
return qs
|
||||
|
||||
|
||||
@extend_schema_field(serializers.BooleanField)
|
||||
class InboxFilter(Filter):
|
||||
def filter(self, qs, value):
|
||||
if value == "true":
|
||||
return qs.filter(tags__is_inbox_tag=True)
|
||||
elif value == "false":
|
||||
return qs.exclude(tags__is_inbox_tag=True)
|
||||
else:
|
||||
return qs
|
||||
|
||||
|
||||
@extend_schema_field(serializers.CharField)
|
||||
class TitleContentFilter(Filter):
|
||||
def filter(self, qs, value):
|
||||
if value:
|
||||
return qs.filter(Q(title__icontains=value) | Q(content__icontains=value))
|
||||
else:
|
||||
return qs
|
||||
|
||||
|
||||
@extend_schema_field(serializers.BooleanField)
|
||||
class SharedByUser(Filter):
|
||||
def filter(self, qs, value):
|
||||
ctype = ContentType.objects.get_for_model(self.model)
|
||||
UserObjectPermission = get_user_obj_perms_model()
|
||||
GroupObjectPermission = get_group_obj_perms_model()
|
||||
# see https://github.com/paperless-ngx/paperless-ngx/issues/5392, we limit subqueries
|
||||
# to 1 because Postgres doesn't like returning > 1 row, but all we care about is > 0
|
||||
return (
|
||||
qs.filter(
|
||||
owner_id=value,
|
||||
)
|
||||
.annotate(
|
||||
num_shared_users=Count(
|
||||
UserObjectPermission.objects.filter(
|
||||
content_type=ctype,
|
||||
object_pk=Cast(OuterRef("pk"), CharField()),
|
||||
).values("user_id")[:1],
|
||||
),
|
||||
)
|
||||
.annotate(
|
||||
num_shared_groups=Count(
|
||||
GroupObjectPermission.objects.filter(
|
||||
content_type=ctype,
|
||||
object_pk=Cast(OuterRef("pk"), CharField()),
|
||||
).values("group_id")[:1],
|
||||
),
|
||||
)
|
||||
.filter(
|
||||
Q(num_shared_users__gt=0) | Q(num_shared_groups__gt=0),
|
||||
)
|
||||
if value is not None
|
||||
else qs
|
||||
)
|
||||
|
||||
|
||||
class CustomFieldFilterSet(FilterSet):
|
||||
class Meta:
|
||||
model = CustomField
|
||||
fields = {
|
||||
"id": ID_KWARGS,
|
||||
"name": CHAR_KWARGS,
|
||||
}
|
||||
|
||||
|
||||
@extend_schema_field(serializers.CharField)
|
||||
class CustomFieldsFilter(Filter):
|
||||
def filter(self, qs, value):
|
||||
if value:
|
||||
fields_with_matching_selects = CustomField.objects.filter(
|
||||
extra_data__icontains=value,
|
||||
)
|
||||
option_ids = []
|
||||
if fields_with_matching_selects.count() > 0:
|
||||
for field in fields_with_matching_selects:
|
||||
options = field.extra_data.get("select_options", [])
|
||||
for _, option in enumerate(options):
|
||||
if option.get("label").lower().find(value.lower()) != -1:
|
||||
option_ids.extend([option.get("id")])
|
||||
return (
|
||||
qs.filter(custom_fields__field__name__icontains=value)
|
||||
| qs.filter(custom_fields__value_text__icontains=value)
|
||||
| qs.filter(custom_fields__value_bool__icontains=value)
|
||||
| qs.filter(custom_fields__value_int__icontains=value)
|
||||
| qs.filter(custom_fields__value_float__icontains=value)
|
||||
| qs.filter(custom_fields__value_date__icontains=value)
|
||||
| qs.filter(custom_fields__value_url__icontains=value)
|
||||
| qs.filter(custom_fields__value_monetary__icontains=value)
|
||||
| qs.filter(custom_fields__value_document_ids__icontains=value)
|
||||
| qs.filter(custom_fields__value_select__in=option_ids)
|
||||
)
|
||||
else:
|
||||
return qs
|
||||
|
||||
|
||||
class MimeTypeFilter(Filter):
|
||||
def filter(self, qs, value):
|
||||
if value:
|
||||
return qs.filter(mime_type__icontains=value)
|
||||
else:
|
||||
return qs
|
||||
|
||||
|
||||
class SelectField(serializers.CharField):
|
||||
def __init__(self, custom_field: CustomField):
|
||||
self._options = custom_field.extra_data["select_options"]
|
||||
super().__init__(max_length=16)
|
||||
|
||||
def to_internal_value(self, data):
|
||||
# If the supplied value is the option label instead of the ID
|
||||
try:
|
||||
data = next(
|
||||
option.get("id")
|
||||
for option in self._options
|
||||
if option.get("label") == data
|
||||
)
|
||||
except StopIteration:
|
||||
pass
|
||||
return super().to_internal_value(data)
|
||||
|
||||
|
||||
def handle_validation_prefix(func: Callable):
|
||||
"""
|
||||
Catch ValidationErrors raised by the wrapped function
|
||||
and add a prefix to the exception detail to track what causes the exception,
|
||||
similar to nested serializers.
|
||||
"""
|
||||
|
||||
def wrapper(*args, validation_prefix=None, **kwargs):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except serializers.ValidationError as e:
|
||||
raise serializers.ValidationError({validation_prefix: e.detail})
|
||||
|
||||
# Update the signature to include the validation_prefix argument
|
||||
old_sig = inspect.signature(func)
|
||||
new_param = inspect.Parameter("validation_prefix", inspect.Parameter.KEYWORD_ONLY)
|
||||
new_sig = old_sig.replace(parameters=[*old_sig.parameters.values(), new_param])
|
||||
|
||||
# Apply functools.wraps and manually set the new signature
|
||||
functools.update_wrapper(wrapper, func)
|
||||
wrapper.__signature__ = new_sig
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class CustomFieldQueryParser:
|
||||
EXPR_BY_CATEGORY = {
|
||||
"basic": ["exact", "in", "isnull", "exists"],
|
||||
"string": [
|
||||
"icontains",
|
||||
"istartswith",
|
||||
"iendswith",
|
||||
],
|
||||
"arithmetic": [
|
||||
"gt",
|
||||
"gte",
|
||||
"lt",
|
||||
"lte",
|
||||
"range",
|
||||
],
|
||||
"containment": ["contains"],
|
||||
}
|
||||
|
||||
SUPPORTED_EXPR_CATEGORIES = {
|
||||
CustomField.FieldDataType.STRING: ("basic", "string"),
|
||||
CustomField.FieldDataType.URL: ("basic", "string"),
|
||||
CustomField.FieldDataType.DATE: ("basic", "arithmetic"),
|
||||
CustomField.FieldDataType.BOOL: ("basic",),
|
||||
CustomField.FieldDataType.INT: ("basic", "arithmetic"),
|
||||
CustomField.FieldDataType.FLOAT: ("basic", "arithmetic"),
|
||||
CustomField.FieldDataType.MONETARY: ("basic", "string", "arithmetic"),
|
||||
CustomField.FieldDataType.DOCUMENTLINK: ("basic", "containment"),
|
||||
CustomField.FieldDataType.SELECT: ("basic",),
|
||||
}
|
||||
|
||||
DATE_COMPONENTS = [
|
||||
"year",
|
||||
"iso_year",
|
||||
"month",
|
||||
"day",
|
||||
"week",
|
||||
"week_day",
|
||||
"iso_week_day",
|
||||
"quarter",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
validation_prefix,
|
||||
max_query_depth=10,
|
||||
max_atom_count=20,
|
||||
) -> None:
|
||||
"""
|
||||
A helper class that parses the query string into a `django.db.models.Q` for filtering
|
||||
documents based on custom field values.
|
||||
|
||||
The syntax of the query expression is illustrated with the below pseudo code rules:
|
||||
1. parse([`custom_field`, "exists", true]):
|
||||
matches documents with Q(custom_fields__field=`custom_field`)
|
||||
2. parse([`custom_field`, "exists", false]):
|
||||
matches documents with ~Q(custom_fields__field=`custom_field`)
|
||||
3. parse([`custom_field`, `op`, `value`]):
|
||||
matches documents with
|
||||
Q(custom_fields__field=`custom_field`, custom_fields__value_`type`__`op`= `value`)
|
||||
4. parse(["AND", [`q0`, `q1`, ..., `qn`]])
|
||||
-> parse(`q0`) & parse(`q1`) & ... & parse(`qn`)
|
||||
5. parse(["OR", [`q0`, `q1`, ..., `qn`]])
|
||||
-> parse(`q0`) | parse(`q1`) | ... | parse(`qn`)
|
||||
6. parse(["NOT", `q`])
|
||||
-> ~parse(`q`)
|
||||
|
||||
Args:
|
||||
validation_prefix: Used to generate the ValidationError message.
|
||||
max_query_depth: Limits the maximum nesting depth of queries.
|
||||
max_atom_count: Limits the maximum number of atoms (i.e., rule 1, 2, 3) in the query.
|
||||
|
||||
`max_query_depth` and `max_atom_count` can be set to guard against generating arbitrarily
|
||||
complex SQL queries.
|
||||
"""
|
||||
self._custom_fields: dict[int | str, CustomField] = {}
|
||||
self._validation_prefix = validation_prefix
|
||||
# Dummy ModelSerializer used to convert a Django models.Field to serializers.Field.
|
||||
self._model_serializer = serializers.ModelSerializer()
|
||||
# Used for sanity check
|
||||
self._max_query_depth = max_query_depth
|
||||
self._max_atom_count = max_atom_count
|
||||
self._current_depth = 0
|
||||
self._atom_count = 0
|
||||
# The set of annotations that we need to apply to the queryset
|
||||
self._annotations = {}
|
||||
|
||||
def parse(self, query: str) -> tuple[Q, dict[str, Count]]:
|
||||
"""
|
||||
Parses the query string into a `django.db.models.Q`
|
||||
and a set of annotations to be applied to the queryset.
|
||||
"""
|
||||
try:
|
||||
expr = json.loads(query)
|
||||
except json.JSONDecodeError:
|
||||
raise serializers.ValidationError(
|
||||
{self._validation_prefix: [_("Value must be valid JSON.")]},
|
||||
)
|
||||
return (
|
||||
self._parse_expr(expr, validation_prefix=self._validation_prefix),
|
||||
self._annotations,
|
||||
)
|
||||
|
||||
@handle_validation_prefix
|
||||
def _parse_expr(self, expr) -> Q:
|
||||
"""
|
||||
Applies rule (1, 2, 3) or (4, 5, 6) based on the length of the expr.
|
||||
"""
|
||||
with self._track_query_depth():
|
||||
if isinstance(expr, list | tuple):
|
||||
if len(expr) == 2:
|
||||
return self._parse_logical_expr(*expr)
|
||||
elif len(expr) == 3:
|
||||
return self._parse_atom(*expr)
|
||||
raise serializers.ValidationError(
|
||||
[_("Invalid custom field query expression")],
|
||||
)
|
||||
|
||||
@handle_validation_prefix
|
||||
def _parse_expr_list(self, exprs) -> list[Q]:
|
||||
"""
|
||||
Handles [`q0`, `q1`, ..., `qn`] in rule 4 & 5.
|
||||
"""
|
||||
if not isinstance(exprs, list | tuple) or not exprs:
|
||||
raise serializers.ValidationError(
|
||||
[_("Invalid expression list. Must be nonempty.")],
|
||||
)
|
||||
return [
|
||||
self._parse_expr(expr, validation_prefix=i) for i, expr in enumerate(exprs)
|
||||
]
|
||||
|
||||
def _parse_logical_expr(self, op, args) -> Q:
|
||||
"""
|
||||
Handles rule 4, 5, 6.
|
||||
"""
|
||||
op_lower = op.lower()
|
||||
|
||||
if op_lower == "not":
|
||||
return ~self._parse_expr(args, validation_prefix=1)
|
||||
|
||||
if op_lower == "and":
|
||||
op_func = operator.and_
|
||||
elif op_lower == "or":
|
||||
op_func = operator.or_
|
||||
else:
|
||||
raise serializers.ValidationError(
|
||||
{"0": [_("Invalid logical operator {op!r}").format(op=op)]},
|
||||
)
|
||||
|
||||
qs = self._parse_expr_list(args, validation_prefix="1")
|
||||
return functools.reduce(op_func, qs)
|
||||
|
||||
def _parse_atom(self, id_or_name, op, value) -> Q:
|
||||
"""
|
||||
Handles rule 1, 2, 3.
|
||||
"""
|
||||
# Guard against queries with too many conditions.
|
||||
self._atom_count += 1
|
||||
if self._atom_count > self._max_atom_count:
|
||||
raise serializers.ValidationError(
|
||||
[_("Maximum number of query conditions exceeded.")],
|
||||
)
|
||||
|
||||
custom_field = self._get_custom_field(id_or_name, validation_prefix="0")
|
||||
op = self._validate_atom_op(custom_field, op, validation_prefix="1")
|
||||
value = self._validate_atom_value(
|
||||
custom_field,
|
||||
op,
|
||||
value,
|
||||
validation_prefix="2",
|
||||
)
|
||||
|
||||
# Needed because not all DB backends support Array __contains
|
||||
if (
|
||||
custom_field.data_type == CustomField.FieldDataType.DOCUMENTLINK
|
||||
and op == "contains"
|
||||
):
|
||||
return self._parse_atom_doc_link_contains(custom_field, value)
|
||||
|
||||
value_field_name = CustomFieldInstance.get_value_field_name(
|
||||
custom_field.data_type,
|
||||
)
|
||||
if (
|
||||
custom_field.data_type == CustomField.FieldDataType.MONETARY
|
||||
and op in self.EXPR_BY_CATEGORY["arithmetic"]
|
||||
):
|
||||
value_field_name = "value_monetary_amount"
|
||||
has_field = Q(custom_fields__field=custom_field)
|
||||
|
||||
# We need to use an annotation here because different atoms
|
||||
# might be referring to different instances of custom fields.
|
||||
annotation_name = f"_custom_field_filter_{len(self._annotations)}"
|
||||
|
||||
# Our special exists operator.
|
||||
if op == "exists":
|
||||
annotation = Count("custom_fields", filter=has_field)
|
||||
# A Document should have > 0 match if it has this field, or 0 if doesn't.
|
||||
query_op = "gt" if value else "exact"
|
||||
query = Q(**{f"{annotation_name}__{query_op}": 0})
|
||||
else:
|
||||
# Check if 1) custom field name matches, and 2) value satisfies condition
|
||||
field_filter = has_field & Q(
|
||||
**{f"custom_fields__{value_field_name}__{op}": value},
|
||||
)
|
||||
# Annotate how many matching custom fields each document has
|
||||
annotation = Count("custom_fields", filter=field_filter)
|
||||
# Filter document by count
|
||||
query = Q(**{f"{annotation_name}__gt": 0})
|
||||
|
||||
self._annotations[annotation_name] = annotation
|
||||
return query
|
||||
|
||||
@handle_validation_prefix
|
||||
def _get_custom_field(self, id_or_name):
|
||||
"""Get the CustomField instance by id or name."""
|
||||
if id_or_name in self._custom_fields:
|
||||
return self._custom_fields[id_or_name]
|
||||
|
||||
kwargs = (
|
||||
{"id": id_or_name} if isinstance(id_or_name, int) else {"name": id_or_name}
|
||||
)
|
||||
try:
|
||||
custom_field = CustomField.objects.get(**kwargs)
|
||||
except CustomField.DoesNotExist:
|
||||
raise serializers.ValidationError(
|
||||
[_("{name!r} is not a valid custom field.").format(name=id_or_name)],
|
||||
)
|
||||
self._custom_fields[custom_field.id] = custom_field
|
||||
self._custom_fields[custom_field.name] = custom_field
|
||||
return custom_field
|
||||
|
||||
@staticmethod
|
||||
def _split_op(full_op):
|
||||
*prefix, op = str(full_op).rsplit("__", maxsplit=1)
|
||||
prefix = prefix[0] if prefix else None
|
||||
return prefix, op
|
||||
|
||||
@handle_validation_prefix
|
||||
def _validate_atom_op(self, custom_field, raw_op):
|
||||
"""Check if the `op` is compatible with the type of the custom field."""
|
||||
prefix, op = self._split_op(raw_op)
|
||||
|
||||
# Check if the operator is supported for the current data_type.
|
||||
supported = False
|
||||
for category in self.SUPPORTED_EXPR_CATEGORIES[custom_field.data_type]:
|
||||
if op in self.EXPR_BY_CATEGORY[category]:
|
||||
supported = True
|
||||
break
|
||||
|
||||
# Check prefix
|
||||
if prefix is not None:
|
||||
if (
|
||||
prefix in self.DATE_COMPONENTS
|
||||
and custom_field.data_type == CustomField.FieldDataType.DATE
|
||||
):
|
||||
pass # ok - e.g., "year__exact" for date field
|
||||
else:
|
||||
supported = False # anything else is invalid
|
||||
|
||||
if not supported:
|
||||
raise serializers.ValidationError(
|
||||
[
|
||||
_("{data_type} does not support query expr {expr!r}.").format(
|
||||
data_type=custom_field.data_type,
|
||||
expr=raw_op,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
return raw_op
|
||||
|
||||
def _get_serializer_field(self, custom_field, full_op):
|
||||
"""Return a serializers.Field for value validation."""
|
||||
prefix, op = self._split_op(full_op)
|
||||
field = None
|
||||
|
||||
if op in ("isnull", "exists"):
|
||||
# `isnull` takes either True or False regardless of the data_type.
|
||||
field = serializers.BooleanField()
|
||||
elif (
|
||||
custom_field.data_type == CustomField.FieldDataType.DATE
|
||||
and prefix in self.DATE_COMPONENTS
|
||||
):
|
||||
# DateField admits queries in the form of `year__exact`, etc. These take integers.
|
||||
field = serializers.IntegerField()
|
||||
elif custom_field.data_type == CustomField.FieldDataType.DOCUMENTLINK:
|
||||
# We can be more specific here and make sure the value is a list.
|
||||
field = serializers.ListField(child=serializers.IntegerField())
|
||||
elif custom_field.data_type == CustomField.FieldDataType.SELECT:
|
||||
# We use this custom field to permit SELECT option names.
|
||||
field = SelectField(custom_field)
|
||||
elif custom_field.data_type == CustomField.FieldDataType.URL:
|
||||
# For URL fields we don't need to be strict about validation (e.g., for istartswith).
|
||||
field = serializers.CharField()
|
||||
else:
|
||||
# The general case: inferred from the corresponding field in CustomFieldInstance.
|
||||
value_field_name = CustomFieldInstance.get_value_field_name(
|
||||
custom_field.data_type,
|
||||
)
|
||||
model_field = CustomFieldInstance._meta.get_field(value_field_name)
|
||||
field_name = model_field.deconstruct()[0]
|
||||
field_class, field_kwargs = self._model_serializer.build_standard_field(
|
||||
field_name,
|
||||
model_field,
|
||||
)
|
||||
field = field_class(**field_kwargs)
|
||||
field.allow_null = False
|
||||
|
||||
# Need to set allow_blank manually because of the inconsistency in CustomFieldInstance validation.
|
||||
# See https://github.com/paperless-ngx/paperless-ngx/issues/7361.
|
||||
if isinstance(field, serializers.CharField):
|
||||
field.allow_blank = True
|
||||
|
||||
if op == "in":
|
||||
# `in` takes a list of values.
|
||||
field = serializers.ListField(child=field, allow_empty=False)
|
||||
elif op == "range":
|
||||
# `range` takes a list of values, i.e., [start, end].
|
||||
field = serializers.ListField(
|
||||
child=field,
|
||||
min_length=2,
|
||||
max_length=2,
|
||||
)
|
||||
|
||||
return field
|
||||
|
||||
@handle_validation_prefix
|
||||
def _validate_atom_value(self, custom_field, op, value):
|
||||
"""Check if `value` is valid for the custom field and `op`. Returns the validated value."""
|
||||
serializer_field = self._get_serializer_field(custom_field, op)
|
||||
return serializer_field.run_validation(value)
|
||||
|
||||
def _parse_atom_doc_link_contains(self, custom_field, value) -> Q:
|
||||
"""
|
||||
Handles document link `contains` in a way that is supported by all DB backends.
|
||||
"""
|
||||
|
||||
# If the value is an empty set,
|
||||
# this is trivially true for any document with not null document links.
|
||||
if not value:
|
||||
return Q(
|
||||
custom_fields__field=custom_field,
|
||||
custom_fields__value_document_ids__isnull=False,
|
||||
)
|
||||
|
||||
# First we look up reverse links from the requested documents.
|
||||
links = CustomFieldInstance.objects.filter(
|
||||
document_id__in=value,
|
||||
field__data_type=CustomField.FieldDataType.DOCUMENTLINK,
|
||||
)
|
||||
|
||||
# Check if any of the requested IDs are missing.
|
||||
missing_ids = set(value) - set(link.document_id for link in links)
|
||||
if missing_ids:
|
||||
# The result should be an empty set in this case.
|
||||
return Q(id__in=[])
|
||||
|
||||
# Take the intersection of the reverse links - this should be what we are looking for.
|
||||
document_ids_we_want = functools.reduce(
|
||||
operator.and_,
|
||||
(set(link.value_document_ids) for link in links),
|
||||
)
|
||||
|
||||
return Q(id__in=document_ids_we_want)
|
||||
|
||||
@contextmanager
|
||||
def _track_query_depth(self):
|
||||
# guard against queries that are too deeply nested
|
||||
self._current_depth += 1
|
||||
if self._current_depth > self._max_query_depth:
|
||||
raise serializers.ValidationError([_("Maximum nesting depth exceeded.")])
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self._current_depth -= 1
|
||||
|
||||
|
||||
@extend_schema_field(serializers.CharField)
|
||||
class CustomFieldQueryFilter(Filter):
|
||||
def __init__(self, validation_prefix):
|
||||
"""
|
||||
A filter that filters documents based on custom field name and value.
|
||||
|
||||
Args:
|
||||
validation_prefix: Used to generate the ValidationError message.
|
||||
"""
|
||||
super().__init__()
|
||||
self._validation_prefix = validation_prefix
|
||||
|
||||
def filter(self, qs, value):
|
||||
if not value:
|
||||
return qs
|
||||
|
||||
parser = CustomFieldQueryParser(
|
||||
self._validation_prefix,
|
||||
max_query_depth=CUSTOM_FIELD_QUERY_MAX_DEPTH,
|
||||
max_atom_count=CUSTOM_FIELD_QUERY_MAX_ATOMS,
|
||||
)
|
||||
q, annotations = parser.parse(value)
|
||||
|
||||
return qs.annotate(**annotations).filter(q)
|
||||
|
||||
|
||||
class DocumentFilterSet(FilterSet):
|
||||
is_tagged = BooleanFilter(
|
||||
label="Is tagged",
|
||||
field_name="tags",
|
||||
lookup_expr="isnull",
|
||||
exclude=True,
|
||||
)
|
||||
|
||||
tags__id__all = ObjectFilter(field_name="tags")
|
||||
|
||||
tags__id__none = ObjectFilter(field_name="tags", exclude=True)
|
||||
|
||||
tags__id__in = ObjectFilter(field_name="tags", in_list=True)
|
||||
|
||||
correspondent__id__none = ObjectFilter(field_name="correspondent", exclude=True)
|
||||
|
||||
document_type__id__none = ObjectFilter(field_name="document_type", exclude=True)
|
||||
|
||||
storage_path__id__none = ObjectFilter(field_name="storage_path", exclude=True)
|
||||
|
||||
is_in_inbox = InboxFilter()
|
||||
|
||||
title_content = TitleContentFilter()
|
||||
|
||||
owner__id__none = ObjectFilter(field_name="owner", exclude=True)
|
||||
|
||||
custom_fields__icontains = CustomFieldsFilter()
|
||||
|
||||
custom_fields__id__all = ObjectFilter(field_name="custom_fields__field")
|
||||
|
||||
custom_fields__id__none = ObjectFilter(
|
||||
field_name="custom_fields__field",
|
||||
exclude=True,
|
||||
)
|
||||
|
||||
custom_fields__id__in = ObjectFilter(
|
||||
field_name="custom_fields__field",
|
||||
in_list=True,
|
||||
)
|
||||
|
||||
has_custom_fields = BooleanFilter(
|
||||
label="Has custom field",
|
||||
field_name="custom_fields",
|
||||
lookup_expr="isnull",
|
||||
exclude=True,
|
||||
)
|
||||
|
||||
custom_field_query = CustomFieldQueryFilter("custom_field_query")
|
||||
|
||||
shared_by__id = SharedByUser()
|
||||
|
||||
mime_type = MimeTypeFilter()
|
||||
|
||||
class Meta:
|
||||
model = Document
|
||||
fields = {
|
||||
"id": ID_KWARGS,
|
||||
"title": CHAR_KWARGS,
|
||||
"content": CHAR_KWARGS,
|
||||
"archive_serial_number": INT_KWARGS,
|
||||
"created": DATE_KWARGS,
|
||||
"added": DATE_KWARGS,
|
||||
"modified": DATE_KWARGS,
|
||||
"original_filename": CHAR_KWARGS,
|
||||
"checksum": CHAR_KWARGS,
|
||||
"correspondent": ["isnull"],
|
||||
"correspondent__id": ID_KWARGS,
|
||||
"correspondent__name": CHAR_KWARGS,
|
||||
"tags__id": ID_KWARGS,
|
||||
"tags__name": CHAR_KWARGS,
|
||||
"document_type": ["isnull"],
|
||||
"document_type__id": ID_KWARGS,
|
||||
"document_type__name": CHAR_KWARGS,
|
||||
"storage_path": ["isnull"],
|
||||
"storage_path__id": ID_KWARGS,
|
||||
"storage_path__name": CHAR_KWARGS,
|
||||
"owner": ["isnull"],
|
||||
"owner__id": ID_KWARGS,
|
||||
"custom_fields": ["icontains"],
|
||||
}
|
||||
|
||||
|
||||
class ShareLinkFilterSet(FilterSet):
|
||||
class Meta:
|
||||
model = ShareLink
|
||||
fields = {
|
||||
"created": DATE_KWARGS,
|
||||
"expiration": DATE_KWARGS,
|
||||
}
|
||||
|
||||
|
||||
class PaperlessTaskFilterSet(FilterSet):
|
||||
acknowledged = BooleanFilter(
|
||||
label="Acknowledged",
|
||||
field_name="acknowledged",
|
||||
)
|
||||
|
||||
class Meta:
|
||||
model = PaperlessTask
|
||||
fields = {
|
||||
"type": ["exact"],
|
||||
"task_name": ["exact"],
|
||||
"status": ["exact"],
|
||||
}
|
||||
|
||||
|
||||
class ObjectOwnedOrGrantedPermissionsFilter(ObjectPermissionsFilter):
|
||||
"""
|
||||
A filter backend that limits results to those where the requesting user
|
||||
has read object level permissions, owns the objects, or objects without
|
||||
an owner (for backwards compat)
|
||||
"""
|
||||
|
||||
def filter_queryset(self, request, queryset, view):
|
||||
objects_with_perms = super().filter_queryset(request, queryset, view)
|
||||
objects_owned = queryset.filter(owner=request.user)
|
||||
objects_unowned = queryset.filter(owner__isnull=True)
|
||||
return objects_with_perms | objects_owned | objects_unowned
|
||||
|
||||
|
||||
class ObjectOwnedPermissionsFilter(ObjectPermissionsFilter):
|
||||
"""
|
||||
A filter backend that limits results to those where the requesting user
|
||||
owns the objects or objects without an owner (for backwards compat)
|
||||
"""
|
||||
|
||||
def filter_queryset(self, request, queryset, view):
|
||||
if request.user.is_superuser:
|
||||
return queryset
|
||||
objects_owned = queryset.filter(owner=request.user)
|
||||
objects_unowned = queryset.filter(owner__isnull=True)
|
||||
return objects_owned | objects_unowned
|
||||
|
||||
|
||||
class DocumentsOrderingFilter(OrderingFilter):
|
||||
field_name = "ordering"
|
||||
prefix = "custom_field_"
|
||||
|
||||
def filter_queryset(self, request, queryset, view):
|
||||
param = request.query_params.get("ordering")
|
||||
if param and self.prefix in param:
|
||||
custom_field_id = int(param.split(self.prefix)[1])
|
||||
try:
|
||||
field = CustomField.objects.get(pk=custom_field_id)
|
||||
except CustomField.DoesNotExist:
|
||||
raise serializers.ValidationError(
|
||||
{self.prefix + str(custom_field_id): [_("Custom field not found")]},
|
||||
)
|
||||
|
||||
annotation = None
|
||||
match field.data_type:
|
||||
case CustomField.FieldDataType.STRING:
|
||||
annotation = Subquery(
|
||||
CustomFieldInstance.objects.filter(
|
||||
document_id=OuterRef("id"),
|
||||
field_id=custom_field_id,
|
||||
).values("value_text")[:1],
|
||||
)
|
||||
case CustomField.FieldDataType.INT:
|
||||
annotation = Subquery(
|
||||
CustomFieldInstance.objects.filter(
|
||||
document_id=OuterRef("id"),
|
||||
field_id=custom_field_id,
|
||||
).values("value_int")[:1],
|
||||
)
|
||||
case CustomField.FieldDataType.FLOAT:
|
||||
annotation = Subquery(
|
||||
CustomFieldInstance.objects.filter(
|
||||
document_id=OuterRef("id"),
|
||||
field_id=custom_field_id,
|
||||
).values("value_float")[:1],
|
||||
)
|
||||
case CustomField.FieldDataType.DATE:
|
||||
annotation = Subquery(
|
||||
CustomFieldInstance.objects.filter(
|
||||
document_id=OuterRef("id"),
|
||||
field_id=custom_field_id,
|
||||
).values("value_date")[:1],
|
||||
)
|
||||
case CustomField.FieldDataType.MONETARY:
|
||||
annotation = Subquery(
|
||||
CustomFieldInstance.objects.filter(
|
||||
document_id=OuterRef("id"),
|
||||
field_id=custom_field_id,
|
||||
).values("value_monetary_amount")[:1],
|
||||
)
|
||||
case CustomField.FieldDataType.SELECT:
|
||||
# Select options are a little more complicated since the value is the id of the option, not
|
||||
# the label. Additionally, to support sqlite we can't use StringAgg, so we need to create a
|
||||
# case statement for each option, setting the value to the index of the option in a list
|
||||
# sorted by label, and then summing the results to give a single value for the annotation
|
||||
|
||||
select_options = sorted(
|
||||
field.extra_data.get("select_options", []),
|
||||
key=lambda x: x.get("label"),
|
||||
)
|
||||
whens = [
|
||||
When(
|
||||
custom_fields__field_id=custom_field_id,
|
||||
custom_fields__value_select=option.get("id"),
|
||||
then=Value(idx, output_field=IntegerField()),
|
||||
)
|
||||
for idx, option in enumerate(select_options)
|
||||
]
|
||||
whens.append(
|
||||
When(
|
||||
custom_fields__field_id=custom_field_id,
|
||||
custom_fields__value_select__isnull=True,
|
||||
then=Value(
|
||||
len(select_options),
|
||||
output_field=IntegerField(),
|
||||
),
|
||||
),
|
||||
)
|
||||
annotation = Sum(
|
||||
Case(
|
||||
*whens,
|
||||
default=Value(0),
|
||||
output_field=IntegerField(),
|
||||
),
|
||||
)
|
||||
case CustomField.FieldDataType.DOCUMENTLINK:
|
||||
annotation = Subquery(
|
||||
CustomFieldInstance.objects.filter(
|
||||
document_id=OuterRef("id"),
|
||||
field_id=custom_field_id,
|
||||
).values("value_document_ids")[:1],
|
||||
)
|
||||
case CustomField.FieldDataType.URL:
|
||||
annotation = Subquery(
|
||||
CustomFieldInstance.objects.filter(
|
||||
document_id=OuterRef("id"),
|
||||
field_id=custom_field_id,
|
||||
).values("value_url")[:1],
|
||||
)
|
||||
case CustomField.FieldDataType.BOOL:
|
||||
annotation = Subquery(
|
||||
CustomFieldInstance.objects.filter(
|
||||
document_id=OuterRef("id"),
|
||||
field_id=custom_field_id,
|
||||
).values("value_bool")[:1],
|
||||
)
|
||||
|
||||
if not annotation:
|
||||
# Only happens if a new data type is added and not handled here
|
||||
raise ValueError("Invalid custom field data type")
|
||||
|
||||
queryset = (
|
||||
queryset.annotate(
|
||||
# We need to annotate the queryset with the custom field value
|
||||
custom_field_value=annotation,
|
||||
# We also need to annotate the queryset with a boolean for sorting whether the field exists
|
||||
has_field=Exists(
|
||||
CustomFieldInstance.objects.filter(
|
||||
document_id=OuterRef("id"),
|
||||
field_id=custom_field_id,
|
||||
),
|
||||
),
|
||||
)
|
||||
.order_by(
|
||||
"-has_field",
|
||||
param.replace(
|
||||
self.prefix + str(custom_field_id),
|
||||
"custom_field_value",
|
||||
),
|
||||
)
|
||||
.distinct()
|
||||
)
|
||||
|
||||
return super().filter_queryset(request, queryset, view)
|
@ -1,8 +1,70 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
import json
|
||||
import operator
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from django.contrib.auth.models import Group
|
||||
from django.contrib.auth.models import User
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from django.db.models import Case
|
||||
from django.db.models import CharField
|
||||
from django.db.models import Count
|
||||
from django.db.models import Exists
|
||||
from django.db.models import IntegerField
|
||||
from django.db.models import OuterRef
|
||||
from django.db.models import Q
|
||||
from django.db.models import Subquery
|
||||
from django.db.models import Sum
|
||||
from django.db.models import Value
|
||||
from django.db.models import When
|
||||
from django.db.models.functions import Cast
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django_filters.rest_framework import BooleanFilter
|
||||
from django_filters.rest_framework import Filter
|
||||
from django_filters.rest_framework import FilterSet
|
||||
from drf_spectacular.utils import extend_schema_field
|
||||
from guardian.utils import get_group_obj_perms_model
|
||||
from guardian.utils import get_user_obj_perms_model
|
||||
from rest_framework import serializers
|
||||
from rest_framework.filters import OrderingFilter
|
||||
from rest_framework_guardian.filters import ObjectPermissionsFilter
|
||||
|
||||
from documents.filters import CHAR_KWARGS
|
||||
from paperless.models import Correspondent
|
||||
from paperless.models import CustomField
|
||||
from paperless.models import CustomFieldInstance
|
||||
from paperless.models import Document
|
||||
from paperless.models import DocumentType
|
||||
from paperless.models import PaperlessTask
|
||||
from paperless.models import ShareLink
|
||||
from paperless.models import StoragePath
|
||||
from paperless.models import Tag
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
CHAR_KWARGS = ["istartswith", "iendswith", "icontains", "iexact"]
|
||||
ID_KWARGS = ["in", "exact"]
|
||||
INT_KWARGS = ["exact", "gt", "gte", "lt", "lte", "isnull"]
|
||||
DATE_KWARGS = [
|
||||
"year",
|
||||
"month",
|
||||
"day",
|
||||
"date__gt",
|
||||
"date__gte",
|
||||
"gt",
|
||||
"gte",
|
||||
"date__lt",
|
||||
"date__lte",
|
||||
"lt",
|
||||
"lte",
|
||||
]
|
||||
|
||||
CUSTOM_FIELD_QUERY_MAX_DEPTH = 10
|
||||
CUSTOM_FIELD_QUERY_MAX_ATOMS = 20
|
||||
|
||||
|
||||
class UserFilterSet(FilterSet):
|
||||
@ -15,3 +77,888 @@ class GroupFilterSet(FilterSet):
|
||||
class Meta:
|
||||
model = Group
|
||||
fields = {"name": CHAR_KWARGS}
|
||||
|
||||
|
||||
class CorrespondentFilterSet(FilterSet):
|
||||
class Meta:
|
||||
model = Correspondent
|
||||
fields = {
|
||||
"id": ID_KWARGS,
|
||||
"name": CHAR_KWARGS,
|
||||
}
|
||||
|
||||
|
||||
class TagFilterSet(FilterSet):
|
||||
class Meta:
|
||||
model = Tag
|
||||
fields = {
|
||||
"id": ID_KWARGS,
|
||||
"name": CHAR_KWARGS,
|
||||
}
|
||||
|
||||
|
||||
class DocumentTypeFilterSet(FilterSet):
|
||||
class Meta:
|
||||
model = DocumentType
|
||||
fields = {
|
||||
"id": ID_KWARGS,
|
||||
"name": CHAR_KWARGS,
|
||||
}
|
||||
|
||||
|
||||
class StoragePathFilterSet(FilterSet):
|
||||
class Meta:
|
||||
model = StoragePath
|
||||
fields = {
|
||||
"id": ID_KWARGS,
|
||||
"name": CHAR_KWARGS,
|
||||
"path": CHAR_KWARGS,
|
||||
}
|
||||
|
||||
|
||||
class ObjectFilter(Filter):
|
||||
def __init__(self, *, exclude=False, in_list=False, field_name=""):
|
||||
super().__init__()
|
||||
self.exclude = exclude
|
||||
self.in_list = in_list
|
||||
self.field_name = field_name
|
||||
|
||||
def filter(self, qs, value):
|
||||
if not value:
|
||||
return qs
|
||||
|
||||
try:
|
||||
object_ids = [int(x) for x in value.split(",")]
|
||||
except ValueError:
|
||||
return qs
|
||||
|
||||
if self.in_list:
|
||||
qs = qs.filter(**{f"{self.field_name}__id__in": object_ids}).distinct()
|
||||
else:
|
||||
for obj_id in object_ids:
|
||||
if self.exclude:
|
||||
qs = qs.exclude(**{f"{self.field_name}__id": obj_id})
|
||||
else:
|
||||
qs = qs.filter(**{f"{self.field_name}__id": obj_id})
|
||||
|
||||
return qs
|
||||
|
||||
|
||||
@extend_schema_field(serializers.BooleanField)
|
||||
class InboxFilter(Filter):
|
||||
def filter(self, qs, value):
|
||||
if value == "true":
|
||||
return qs.filter(tags__is_inbox_tag=True)
|
||||
elif value == "false":
|
||||
return qs.exclude(tags__is_inbox_tag=True)
|
||||
else:
|
||||
return qs
|
||||
|
||||
|
||||
@extend_schema_field(serializers.CharField)
|
||||
class TitleContentFilter(Filter):
|
||||
def filter(self, qs, value):
|
||||
if value:
|
||||
return qs.filter(Q(title__icontains=value) | Q(content__icontains=value))
|
||||
else:
|
||||
return qs
|
||||
|
||||
|
||||
@extend_schema_field(serializers.BooleanField)
|
||||
class SharedByUser(Filter):
|
||||
def filter(self, qs, value):
|
||||
ctype = ContentType.objects.get_for_model(self.model)
|
||||
UserObjectPermission = get_user_obj_perms_model()
|
||||
GroupObjectPermission = get_group_obj_perms_model()
|
||||
# see https://github.com/paperless-ngx/paperless-ngx/issues/5392, we limit subqueries
|
||||
# to 1 because Postgres doesn't like returning > 1 row, but all we care about is > 0
|
||||
return (
|
||||
qs.filter(
|
||||
owner_id=value,
|
||||
)
|
||||
.annotate(
|
||||
num_shared_users=Count(
|
||||
UserObjectPermission.objects.filter(
|
||||
content_type=ctype,
|
||||
object_pk=Cast(OuterRef("pk"), CharField()),
|
||||
).values("user_id")[:1],
|
||||
),
|
||||
)
|
||||
.annotate(
|
||||
num_shared_groups=Count(
|
||||
GroupObjectPermission.objects.filter(
|
||||
content_type=ctype,
|
||||
object_pk=Cast(OuterRef("pk"), CharField()),
|
||||
).values("group_id")[:1],
|
||||
),
|
||||
)
|
||||
.filter(
|
||||
Q(num_shared_users__gt=0) | Q(num_shared_groups__gt=0),
|
||||
)
|
||||
if value is not None
|
||||
else qs
|
||||
)
|
||||
|
||||
|
||||
class CustomFieldFilterSet(FilterSet):
|
||||
class Meta:
|
||||
model = CustomField
|
||||
fields = {
|
||||
"id": ID_KWARGS,
|
||||
"name": CHAR_KWARGS,
|
||||
}
|
||||
|
||||
|
||||
@extend_schema_field(serializers.CharField)
|
||||
class CustomFieldsFilter(Filter):
|
||||
def filter(self, qs, value):
|
||||
if value:
|
||||
fields_with_matching_selects = CustomField.objects.filter(
|
||||
extra_data__icontains=value,
|
||||
)
|
||||
option_ids = []
|
||||
if fields_with_matching_selects.count() > 0:
|
||||
for field in fields_with_matching_selects:
|
||||
options = field.extra_data.get("select_options", [])
|
||||
for _, option in enumerate(options):
|
||||
if option.get("label").lower().find(value.lower()) != -1:
|
||||
option_ids.extend([option.get("id")])
|
||||
return (
|
||||
qs.filter(custom_fields__field__name__icontains=value)
|
||||
| qs.filter(custom_fields__value_text__icontains=value)
|
||||
| qs.filter(custom_fields__value_bool__icontains=value)
|
||||
| qs.filter(custom_fields__value_int__icontains=value)
|
||||
| qs.filter(custom_fields__value_float__icontains=value)
|
||||
| qs.filter(custom_fields__value_date__icontains=value)
|
||||
| qs.filter(custom_fields__value_url__icontains=value)
|
||||
| qs.filter(custom_fields__value_monetary__icontains=value)
|
||||
| qs.filter(custom_fields__value_document_ids__icontains=value)
|
||||
| qs.filter(custom_fields__value_select__in=option_ids)
|
||||
)
|
||||
else:
|
||||
return qs
|
||||
|
||||
|
||||
class MimeTypeFilter(Filter):
|
||||
def filter(self, qs, value):
|
||||
if value:
|
||||
return qs.filter(mime_type__icontains=value)
|
||||
else:
|
||||
return qs
|
||||
|
||||
|
||||
class SelectField(serializers.CharField):
|
||||
def __init__(self, custom_field: CustomField):
|
||||
self._options = custom_field.extra_data["select_options"]
|
||||
super().__init__(max_length=16)
|
||||
|
||||
def to_internal_value(self, data):
|
||||
# If the supplied value is the option label instead of the ID
|
||||
try:
|
||||
data = next(
|
||||
option.get("id")
|
||||
for option in self._options
|
||||
if option.get("label") == data
|
||||
)
|
||||
except StopIteration:
|
||||
pass
|
||||
return super().to_internal_value(data)
|
||||
|
||||
|
||||
def handle_validation_prefix(func: Callable):
|
||||
"""
|
||||
Catch ValidationErrors raised by the wrapped function
|
||||
and add a prefix to the exception detail to track what causes the exception,
|
||||
similar to nested serializers.
|
||||
"""
|
||||
|
||||
def wrapper(*args, validation_prefix=None, **kwargs):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except serializers.ValidationError as e:
|
||||
raise serializers.ValidationError({validation_prefix: e.detail})
|
||||
|
||||
# Update the signature to include the validation_prefix argument
|
||||
old_sig = inspect.signature(func)
|
||||
new_param = inspect.Parameter("validation_prefix", inspect.Parameter.KEYWORD_ONLY)
|
||||
new_sig = old_sig.replace(parameters=[*old_sig.parameters.values(), new_param])
|
||||
|
||||
# Apply functools.wraps and manually set the new signature
|
||||
functools.update_wrapper(wrapper, func)
|
||||
wrapper.__signature__ = new_sig
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class CustomFieldQueryParser:
|
||||
EXPR_BY_CATEGORY = {
|
||||
"basic": ["exact", "in", "isnull", "exists"],
|
||||
"string": [
|
||||
"icontains",
|
||||
"istartswith",
|
||||
"iendswith",
|
||||
],
|
||||
"arithmetic": [
|
||||
"gt",
|
||||
"gte",
|
||||
"lt",
|
||||
"lte",
|
||||
"range",
|
||||
],
|
||||
"containment": ["contains"],
|
||||
}
|
||||
|
||||
SUPPORTED_EXPR_CATEGORIES = {
|
||||
CustomField.FieldDataType.STRING: ("basic", "string"),
|
||||
CustomField.FieldDataType.URL: ("basic", "string"),
|
||||
CustomField.FieldDataType.DATE: ("basic", "arithmetic"),
|
||||
CustomField.FieldDataType.BOOL: ("basic",),
|
||||
CustomField.FieldDataType.INT: ("basic", "arithmetic"),
|
||||
CustomField.FieldDataType.FLOAT: ("basic", "arithmetic"),
|
||||
CustomField.FieldDataType.MONETARY: ("basic", "string", "arithmetic"),
|
||||
CustomField.FieldDataType.DOCUMENTLINK: ("basic", "containment"),
|
||||
CustomField.FieldDataType.SELECT: ("basic",),
|
||||
}
|
||||
|
||||
DATE_COMPONENTS = [
|
||||
"year",
|
||||
"iso_year",
|
||||
"month",
|
||||
"day",
|
||||
"week",
|
||||
"week_day",
|
||||
"iso_week_day",
|
||||
"quarter",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
validation_prefix,
|
||||
max_query_depth=10,
|
||||
max_atom_count=20,
|
||||
) -> None:
|
||||
"""
|
||||
A helper class that parses the query string into a `django.db.models.Q` for filtering
|
||||
documents based on custom field values.
|
||||
|
||||
The syntax of the query expression is illustrated with the below pseudo code rules:
|
||||
1. parse([`custom_field`, "exists", true]):
|
||||
matches documents with Q(custom_fields__field=`custom_field`)
|
||||
2. parse([`custom_field`, "exists", false]):
|
||||
matches documents with ~Q(custom_fields__field=`custom_field`)
|
||||
3. parse([`custom_field`, `op`, `value`]):
|
||||
matches documents with
|
||||
Q(custom_fields__field=`custom_field`, custom_fields__value_`type`__`op`= `value`)
|
||||
4. parse(["AND", [`q0`, `q1`, ..., `qn`]])
|
||||
-> parse(`q0`) & parse(`q1`) & ... & parse(`qn`)
|
||||
5. parse(["OR", [`q0`, `q1`, ..., `qn`]])
|
||||
-> parse(`q0`) | parse(`q1`) | ... | parse(`qn`)
|
||||
6. parse(["NOT", `q`])
|
||||
-> ~parse(`q`)
|
||||
|
||||
Args:
|
||||
validation_prefix: Used to generate the ValidationError message.
|
||||
max_query_depth: Limits the maximum nesting depth of queries.
|
||||
max_atom_count: Limits the maximum number of atoms (i.e., rule 1, 2, 3) in the query.
|
||||
|
||||
`max_query_depth` and `max_atom_count` can be set to guard against generating arbitrarily
|
||||
complex SQL queries.
|
||||
"""
|
||||
self._custom_fields: dict[int | str, CustomField] = {}
|
||||
self._validation_prefix = validation_prefix
|
||||
# Dummy ModelSerializer used to convert a Django models.Field to serializers.Field.
|
||||
self._model_serializer = serializers.ModelSerializer()
|
||||
# Used for sanity check
|
||||
self._max_query_depth = max_query_depth
|
||||
self._max_atom_count = max_atom_count
|
||||
self._current_depth = 0
|
||||
self._atom_count = 0
|
||||
# The set of annotations that we need to apply to the queryset
|
||||
self._annotations = {}
|
||||
|
||||
def parse(self, query: str) -> tuple[Q, dict[str, Count]]:
|
||||
"""
|
||||
Parses the query string into a `django.db.models.Q`
|
||||
and a set of annotations to be applied to the queryset.
|
||||
"""
|
||||
try:
|
||||
expr = json.loads(query)
|
||||
except json.JSONDecodeError:
|
||||
raise serializers.ValidationError(
|
||||
{self._validation_prefix: [_("Value must be valid JSON.")]},
|
||||
)
|
||||
return (
|
||||
self._parse_expr(expr, validation_prefix=self._validation_prefix),
|
||||
self._annotations,
|
||||
)
|
||||
|
||||
@handle_validation_prefix
|
||||
def _parse_expr(self, expr) -> Q:
|
||||
"""
|
||||
Applies rule (1, 2, 3) or (4, 5, 6) based on the length of the expr.
|
||||
"""
|
||||
with self._track_query_depth():
|
||||
if isinstance(expr, list | tuple):
|
||||
if len(expr) == 2:
|
||||
return self._parse_logical_expr(*expr)
|
||||
elif len(expr) == 3:
|
||||
return self._parse_atom(*expr)
|
||||
raise serializers.ValidationError(
|
||||
[_("Invalid custom field query expression")],
|
||||
)
|
||||
|
||||
@handle_validation_prefix
|
||||
def _parse_expr_list(self, exprs) -> list[Q]:
|
||||
"""
|
||||
Handles [`q0`, `q1`, ..., `qn`] in rule 4 & 5.
|
||||
"""
|
||||
if not isinstance(exprs, list | tuple) or not exprs:
|
||||
raise serializers.ValidationError(
|
||||
[_("Invalid expression list. Must be nonempty.")],
|
||||
)
|
||||
return [
|
||||
self._parse_expr(expr, validation_prefix=i) for i, expr in enumerate(exprs)
|
||||
]
|
||||
|
||||
def _parse_logical_expr(self, op, args) -> Q:
|
||||
"""
|
||||
Handles rule 4, 5, 6.
|
||||
"""
|
||||
op_lower = op.lower()
|
||||
|
||||
if op_lower == "not":
|
||||
return ~self._parse_expr(args, validation_prefix=1)
|
||||
|
||||
if op_lower == "and":
|
||||
op_func = operator.and_
|
||||
elif op_lower == "or":
|
||||
op_func = operator.or_
|
||||
else:
|
||||
raise serializers.ValidationError(
|
||||
{"0": [_("Invalid logical operator {op!r}").format(op=op)]},
|
||||
)
|
||||
|
||||
qs = self._parse_expr_list(args, validation_prefix="1")
|
||||
return functools.reduce(op_func, qs)
|
||||
|
||||
def _parse_atom(self, id_or_name, op, value) -> Q:
|
||||
"""
|
||||
Handles rule 1, 2, 3.
|
||||
"""
|
||||
# Guard against queries with too many conditions.
|
||||
self._atom_count += 1
|
||||
if self._atom_count > self._max_atom_count:
|
||||
raise serializers.ValidationError(
|
||||
[_("Maximum number of query conditions exceeded.")],
|
||||
)
|
||||
|
||||
custom_field = self._get_custom_field(id_or_name, validation_prefix="0")
|
||||
op = self._validate_atom_op(custom_field, op, validation_prefix="1")
|
||||
value = self._validate_atom_value(
|
||||
custom_field,
|
||||
op,
|
||||
value,
|
||||
validation_prefix="2",
|
||||
)
|
||||
|
||||
# Needed because not all DB backends support Array __contains
|
||||
if (
|
||||
custom_field.data_type == CustomField.FieldDataType.DOCUMENTLINK
|
||||
and op == "contains"
|
||||
):
|
||||
return self._parse_atom_doc_link_contains(custom_field, value)
|
||||
|
||||
value_field_name = CustomFieldInstance.get_value_field_name(
|
||||
custom_field.data_type,
|
||||
)
|
||||
if (
|
||||
custom_field.data_type == CustomField.FieldDataType.MONETARY
|
||||
and op in self.EXPR_BY_CATEGORY["arithmetic"]
|
||||
):
|
||||
value_field_name = "value_monetary_amount"
|
||||
has_field = Q(custom_fields__field=custom_field)
|
||||
|
||||
# We need to use an annotation here because different atoms
|
||||
# might be referring to different instances of custom fields.
|
||||
annotation_name = f"_custom_field_filter_{len(self._annotations)}"
|
||||
|
||||
# Our special exists operator.
|
||||
if op == "exists":
|
||||
annotation = Count("custom_fields", filter=has_field)
|
||||
# A Document should have > 0 match if it has this field, or 0 if doesn't.
|
||||
query_op = "gt" if value else "exact"
|
||||
query = Q(**{f"{annotation_name}__{query_op}": 0})
|
||||
else:
|
||||
# Check if 1) custom field name matches, and 2) value satisfies condition
|
||||
field_filter = has_field & Q(
|
||||
**{f"custom_fields__{value_field_name}__{op}": value},
|
||||
)
|
||||
# Annotate how many matching custom fields each document has
|
||||
annotation = Count("custom_fields", filter=field_filter)
|
||||
# Filter document by count
|
||||
query = Q(**{f"{annotation_name}__gt": 0})
|
||||
|
||||
self._annotations[annotation_name] = annotation
|
||||
return query
|
||||
|
||||
@handle_validation_prefix
|
||||
def _get_custom_field(self, id_or_name):
|
||||
"""Get the CustomField instance by id or name."""
|
||||
if id_or_name in self._custom_fields:
|
||||
return self._custom_fields[id_or_name]
|
||||
|
||||
kwargs = (
|
||||
{"id": id_or_name} if isinstance(id_or_name, int) else {"name": id_or_name}
|
||||
)
|
||||
try:
|
||||
custom_field = CustomField.objects.get(**kwargs)
|
||||
except CustomField.DoesNotExist:
|
||||
raise serializers.ValidationError(
|
||||
[_("{name!r} is not a valid custom field.").format(name=id_or_name)],
|
||||
)
|
||||
self._custom_fields[custom_field.id] = custom_field
|
||||
self._custom_fields[custom_field.name] = custom_field
|
||||
return custom_field
|
||||
|
||||
@staticmethod
|
||||
def _split_op(full_op):
|
||||
*prefix, op = str(full_op).rsplit("__", maxsplit=1)
|
||||
prefix = prefix[0] if prefix else None
|
||||
return prefix, op
|
||||
|
||||
@handle_validation_prefix
|
||||
def _validate_atom_op(self, custom_field, raw_op):
|
||||
"""Check if the `op` is compatible with the type of the custom field."""
|
||||
prefix, op = self._split_op(raw_op)
|
||||
|
||||
# Check if the operator is supported for the current data_type.
|
||||
supported = False
|
||||
for category in self.SUPPORTED_EXPR_CATEGORIES[custom_field.data_type]:
|
||||
if op in self.EXPR_BY_CATEGORY[category]:
|
||||
supported = True
|
||||
break
|
||||
|
||||
# Check prefix
|
||||
if prefix is not None:
|
||||
if (
|
||||
prefix in self.DATE_COMPONENTS
|
||||
and custom_field.data_type == CustomField.FieldDataType.DATE
|
||||
):
|
||||
pass # ok - e.g., "year__exact" for date field
|
||||
else:
|
||||
supported = False # anything else is invalid
|
||||
|
||||
if not supported:
|
||||
raise serializers.ValidationError(
|
||||
[
|
||||
_("{data_type} does not support query expr {expr!r}.").format(
|
||||
data_type=custom_field.data_type,
|
||||
expr=raw_op,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
return raw_op
|
||||
|
||||
def _get_serializer_field(self, custom_field, full_op):
|
||||
"""Return a serializers.Field for value validation."""
|
||||
prefix, op = self._split_op(full_op)
|
||||
field = None
|
||||
|
||||
if op in ("isnull", "exists"):
|
||||
# `isnull` takes either True or False regardless of the data_type.
|
||||
field = serializers.BooleanField()
|
||||
elif (
|
||||
custom_field.data_type == CustomField.FieldDataType.DATE
|
||||
and prefix in self.DATE_COMPONENTS
|
||||
):
|
||||
# DateField admits queries in the form of `year__exact`, etc. These take integers.
|
||||
field = serializers.IntegerField()
|
||||
elif custom_field.data_type == CustomField.FieldDataType.DOCUMENTLINK:
|
||||
# We can be more specific here and make sure the value is a list.
|
||||
field = serializers.ListField(child=serializers.IntegerField())
|
||||
elif custom_field.data_type == CustomField.FieldDataType.SELECT:
|
||||
# We use this custom field to permit SELECT option names.
|
||||
field = SelectField(custom_field)
|
||||
elif custom_field.data_type == CustomField.FieldDataType.URL:
|
||||
# For URL fields we don't need to be strict about validation (e.g., for istartswith).
|
||||
field = serializers.CharField()
|
||||
else:
|
||||
# The general case: inferred from the corresponding field in CustomFieldInstance.
|
||||
value_field_name = CustomFieldInstance.get_value_field_name(
|
||||
custom_field.data_type,
|
||||
)
|
||||
model_field = CustomFieldInstance._meta.get_field(value_field_name)
|
||||
field_name = model_field.deconstruct()[0]
|
||||
field_class, field_kwargs = self._model_serializer.build_standard_field(
|
||||
field_name,
|
||||
model_field,
|
||||
)
|
||||
field = field_class(**field_kwargs)
|
||||
field.allow_null = False
|
||||
|
||||
# Need to set allow_blank manually because of the inconsistency in CustomFieldInstance validation.
|
||||
# See https://github.com/paperless-ngx/paperless-ngx/issues/7361.
|
||||
if isinstance(field, serializers.CharField):
|
||||
field.allow_blank = True
|
||||
|
||||
if op == "in":
|
||||
# `in` takes a list of values.
|
||||
field = serializers.ListField(child=field, allow_empty=False)
|
||||
elif op == "range":
|
||||
# `range` takes a list of values, i.e., [start, end].
|
||||
field = serializers.ListField(
|
||||
child=field,
|
||||
min_length=2,
|
||||
max_length=2,
|
||||
)
|
||||
|
||||
return field
|
||||
|
||||
@handle_validation_prefix
|
||||
def _validate_atom_value(self, custom_field, op, value):
|
||||
"""Check if `value` is valid for the custom field and `op`. Returns the validated value."""
|
||||
serializer_field = self._get_serializer_field(custom_field, op)
|
||||
return serializer_field.run_validation(value)
|
||||
|
||||
def _parse_atom_doc_link_contains(self, custom_field, value) -> Q:
|
||||
"""
|
||||
Handles document link `contains` in a way that is supported by all DB backends.
|
||||
"""
|
||||
|
||||
# If the value is an empty set,
|
||||
# this is trivially true for any document with not null document links.
|
||||
if not value:
|
||||
return Q(
|
||||
custom_fields__field=custom_field,
|
||||
custom_fields__value_document_ids__isnull=False,
|
||||
)
|
||||
|
||||
# First we look up reverse links from the requested documents.
|
||||
links = CustomFieldInstance.objects.filter(
|
||||
document_id__in=value,
|
||||
field__data_type=CustomField.FieldDataType.DOCUMENTLINK,
|
||||
)
|
||||
|
||||
# Check if any of the requested IDs are missing.
|
||||
missing_ids = set(value) - set(link.document_id for link in links)
|
||||
if missing_ids:
|
||||
# The result should be an empty set in this case.
|
||||
return Q(id__in=[])
|
||||
|
||||
# Take the intersection of the reverse links - this should be what we are looking for.
|
||||
document_ids_we_want = functools.reduce(
|
||||
operator.and_,
|
||||
(set(link.value_document_ids) for link in links),
|
||||
)
|
||||
|
||||
return Q(id__in=document_ids_we_want)
|
||||
|
||||
@contextmanager
|
||||
def _track_query_depth(self):
|
||||
# guard against queries that are too deeply nested
|
||||
self._current_depth += 1
|
||||
if self._current_depth > self._max_query_depth:
|
||||
raise serializers.ValidationError([_("Maximum nesting depth exceeded.")])
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self._current_depth -= 1
|
||||
|
||||
|
||||
@extend_schema_field(serializers.CharField)
|
||||
class CustomFieldQueryFilter(Filter):
|
||||
def __init__(self, validation_prefix):
|
||||
"""
|
||||
A filter that filters documents based on custom field name and value.
|
||||
|
||||
Args:
|
||||
validation_prefix: Used to generate the ValidationError message.
|
||||
"""
|
||||
super().__init__()
|
||||
self._validation_prefix = validation_prefix
|
||||
|
||||
def filter(self, qs, value):
|
||||
if not value:
|
||||
return qs
|
||||
|
||||
parser = CustomFieldQueryParser(
|
||||
self._validation_prefix,
|
||||
max_query_depth=CUSTOM_FIELD_QUERY_MAX_DEPTH,
|
||||
max_atom_count=CUSTOM_FIELD_QUERY_MAX_ATOMS,
|
||||
)
|
||||
q, annotations = parser.parse(value)
|
||||
|
||||
return qs.annotate(**annotations).filter(q)
|
||||
|
||||
|
||||
class DocumentFilterSet(FilterSet):
|
||||
is_tagged = BooleanFilter(
|
||||
label="Is tagged",
|
||||
field_name="tags",
|
||||
lookup_expr="isnull",
|
||||
exclude=True,
|
||||
)
|
||||
|
||||
tags__id__all = ObjectFilter(field_name="tags")
|
||||
|
||||
tags__id__none = ObjectFilter(field_name="tags", exclude=True)
|
||||
|
||||
tags__id__in = ObjectFilter(field_name="tags", in_list=True)
|
||||
|
||||
correspondent__id__none = ObjectFilter(field_name="correspondent", exclude=True)
|
||||
|
||||
document_type__id__none = ObjectFilter(field_name="document_type", exclude=True)
|
||||
|
||||
storage_path__id__none = ObjectFilter(field_name="storage_path", exclude=True)
|
||||
|
||||
is_in_inbox = InboxFilter()
|
||||
|
||||
title_content = TitleContentFilter()
|
||||
|
||||
owner__id__none = ObjectFilter(field_name="owner", exclude=True)
|
||||
|
||||
custom_fields__icontains = CustomFieldsFilter()
|
||||
|
||||
custom_fields__id__all = ObjectFilter(field_name="custom_fields__field")
|
||||
|
||||
custom_fields__id__none = ObjectFilter(
|
||||
field_name="custom_fields__field",
|
||||
exclude=True,
|
||||
)
|
||||
|
||||
custom_fields__id__in = ObjectFilter(
|
||||
field_name="custom_fields__field",
|
||||
in_list=True,
|
||||
)
|
||||
|
||||
has_custom_fields = BooleanFilter(
|
||||
label="Has custom field",
|
||||
field_name="custom_fields",
|
||||
lookup_expr="isnull",
|
||||
exclude=True,
|
||||
)
|
||||
|
||||
custom_field_query = CustomFieldQueryFilter("custom_field_query")
|
||||
|
||||
shared_by__id = SharedByUser()
|
||||
|
||||
mime_type = MimeTypeFilter()
|
||||
|
||||
class Meta:
|
||||
model = Document
|
||||
fields = {
|
||||
"id": ID_KWARGS,
|
||||
"title": CHAR_KWARGS,
|
||||
"content": CHAR_KWARGS,
|
||||
"archive_serial_number": INT_KWARGS,
|
||||
"created": DATE_KWARGS,
|
||||
"added": DATE_KWARGS,
|
||||
"modified": DATE_KWARGS,
|
||||
"original_filename": CHAR_KWARGS,
|
||||
"checksum": CHAR_KWARGS,
|
||||
"correspondent": ["isnull"],
|
||||
"correspondent__id": ID_KWARGS,
|
||||
"correspondent__name": CHAR_KWARGS,
|
||||
"tags__id": ID_KWARGS,
|
||||
"tags__name": CHAR_KWARGS,
|
||||
"document_type": ["isnull"],
|
||||
"document_type__id": ID_KWARGS,
|
||||
"document_type__name": CHAR_KWARGS,
|
||||
"storage_path": ["isnull"],
|
||||
"storage_path__id": ID_KWARGS,
|
||||
"storage_path__name": CHAR_KWARGS,
|
||||
"owner": ["isnull"],
|
||||
"owner__id": ID_KWARGS,
|
||||
"custom_fields": ["icontains"],
|
||||
}
|
||||
|
||||
|
||||
class ShareLinkFilterSet(FilterSet):
|
||||
class Meta:
|
||||
model = ShareLink
|
||||
fields = {
|
||||
"created": DATE_KWARGS,
|
||||
"expiration": DATE_KWARGS,
|
||||
}
|
||||
|
||||
|
||||
class PaperlessTaskFilterSet(FilterSet):
|
||||
acknowledged = BooleanFilter(
|
||||
label="Acknowledged",
|
||||
field_name="acknowledged",
|
||||
)
|
||||
|
||||
class Meta:
|
||||
model = PaperlessTask
|
||||
fields = {
|
||||
"type": ["exact"],
|
||||
"task_name": ["exact"],
|
||||
"status": ["exact"],
|
||||
}
|
||||
|
||||
|
||||
class ObjectOwnedOrGrantedPermissionsFilter(ObjectPermissionsFilter):
|
||||
"""
|
||||
A filter backend that limits results to those where the requesting user
|
||||
has read object level permissions, owns the objects, or objects without
|
||||
an owner (for backwards compat)
|
||||
"""
|
||||
|
||||
def filter_queryset(self, request, queryset, view):
|
||||
objects_with_perms = super().filter_queryset(request, queryset, view)
|
||||
objects_owned = queryset.filter(owner=request.user)
|
||||
objects_unowned = queryset.filter(owner__isnull=True)
|
||||
return objects_with_perms | objects_owned | objects_unowned
|
||||
|
||||
|
||||
class ObjectOwnedPermissionsFilter(ObjectPermissionsFilter):
|
||||
"""
|
||||
A filter backend that limits results to those where the requesting user
|
||||
owns the objects or objects without an owner (for backwards compat)
|
||||
"""
|
||||
|
||||
def filter_queryset(self, request, queryset, view):
|
||||
if request.user.is_superuser:
|
||||
return queryset
|
||||
objects_owned = queryset.filter(owner=request.user)
|
||||
objects_unowned = queryset.filter(owner__isnull=True)
|
||||
return objects_owned | objects_unowned
|
||||
|
||||
|
||||
class DocumentsOrderingFilter(OrderingFilter):
|
||||
field_name = "ordering"
|
||||
prefix = "custom_field_"
|
||||
|
||||
def filter_queryset(self, request, queryset, view):
|
||||
param = request.query_params.get("ordering")
|
||||
if param and self.prefix in param:
|
||||
custom_field_id = int(param.split(self.prefix)[1])
|
||||
try:
|
||||
field = CustomField.objects.get(pk=custom_field_id)
|
||||
except CustomField.DoesNotExist:
|
||||
raise serializers.ValidationError(
|
||||
{self.prefix + str(custom_field_id): [_("Custom field not found")]},
|
||||
)
|
||||
|
||||
annotation = None
|
||||
match field.data_type:
|
||||
case CustomField.FieldDataType.STRING:
|
||||
annotation = Subquery(
|
||||
CustomFieldInstance.objects.filter(
|
||||
document_id=OuterRef("id"),
|
||||
field_id=custom_field_id,
|
||||
).values("value_text")[:1],
|
||||
)
|
||||
case CustomField.FieldDataType.INT:
|
||||
annotation = Subquery(
|
||||
CustomFieldInstance.objects.filter(
|
||||
document_id=OuterRef("id"),
|
||||
field_id=custom_field_id,
|
||||
).values("value_int")[:1],
|
||||
)
|
||||
case CustomField.FieldDataType.FLOAT:
|
||||
annotation = Subquery(
|
||||
CustomFieldInstance.objects.filter(
|
||||
document_id=OuterRef("id"),
|
||||
field_id=custom_field_id,
|
||||
).values("value_float")[:1],
|
||||
)
|
||||
case CustomField.FieldDataType.DATE:
|
||||
annotation = Subquery(
|
||||
CustomFieldInstance.objects.filter(
|
||||
document_id=OuterRef("id"),
|
||||
field_id=custom_field_id,
|
||||
).values("value_date")[:1],
|
||||
)
|
||||
case CustomField.FieldDataType.MONETARY:
|
||||
annotation = Subquery(
|
||||
CustomFieldInstance.objects.filter(
|
||||
document_id=OuterRef("id"),
|
||||
field_id=custom_field_id,
|
||||
).values("value_monetary_amount")[:1],
|
||||
)
|
||||
case CustomField.FieldDataType.SELECT:
|
||||
# Select options are a little more complicated since the value is the id of the option, not
|
||||
# the label. Additionally, to support sqlite we can't use StringAgg, so we need to create a
|
||||
# case statement for each option, setting the value to the index of the option in a list
|
||||
# sorted by label, and then summing the results to give a single value for the annotation
|
||||
|
||||
select_options = sorted(
|
||||
field.extra_data.get("select_options", []),
|
||||
key=lambda x: x.get("label"),
|
||||
)
|
||||
whens = [
|
||||
When(
|
||||
custom_fields__field_id=custom_field_id,
|
||||
custom_fields__value_select=option.get("id"),
|
||||
then=Value(idx, output_field=IntegerField()),
|
||||
)
|
||||
for idx, option in enumerate(select_options)
|
||||
]
|
||||
whens.append(
|
||||
When(
|
||||
custom_fields__field_id=custom_field_id,
|
||||
custom_fields__value_select__isnull=True,
|
||||
then=Value(
|
||||
len(select_options),
|
||||
output_field=IntegerField(),
|
||||
),
|
||||
),
|
||||
)
|
||||
annotation = Sum(
|
||||
Case(
|
||||
*whens,
|
||||
default=Value(0),
|
||||
output_field=IntegerField(),
|
||||
),
|
||||
)
|
||||
case CustomField.FieldDataType.DOCUMENTLINK:
|
||||
annotation = Subquery(
|
||||
CustomFieldInstance.objects.filter(
|
||||
document_id=OuterRef("id"),
|
||||
field_id=custom_field_id,
|
||||
).values("value_document_ids")[:1],
|
||||
)
|
||||
case CustomField.FieldDataType.URL:
|
||||
annotation = Subquery(
|
||||
CustomFieldInstance.objects.filter(
|
||||
document_id=OuterRef("id"),
|
||||
field_id=custom_field_id,
|
||||
).values("value_url")[:1],
|
||||
)
|
||||
case CustomField.FieldDataType.BOOL:
|
||||
annotation = Subquery(
|
||||
CustomFieldInstance.objects.filter(
|
||||
document_id=OuterRef("id"),
|
||||
field_id=custom_field_id,
|
||||
).values("value_bool")[:1],
|
||||
)
|
||||
|
||||
if not annotation:
|
||||
# Only happens if a new data type is added and not handled here
|
||||
raise ValueError("Invalid custom field data type")
|
||||
|
||||
queryset = (
|
||||
queryset.annotate(
|
||||
# We need to annotate the queryset with the custom field value
|
||||
custom_field_value=annotation,
|
||||
# We also need to annotate the queryset with a boolean for sorting whether the field exists
|
||||
has_field=Exists(
|
||||
CustomFieldInstance.objects.filter(
|
||||
document_id=OuterRef("id"),
|
||||
field_id=custom_field_id,
|
||||
),
|
||||
),
|
||||
)
|
||||
.order_by(
|
||||
"-has_field",
|
||||
param.replace(
|
||||
self.prefix + str(custom_field_id),
|
||||
"custom_field_value",
|
||||
),
|
||||
)
|
||||
.distinct()
|
||||
)
|
||||
|
||||
return super().filter_queryset(request, queryset, view)
|
||||
|
@ -89,17 +89,6 @@ from rest_framework.viewsets import ModelViewSet
|
||||
from rest_framework.viewsets import ReadOnlyModelViewSet
|
||||
from rest_framework.viewsets import ViewSet
|
||||
|
||||
from documents.filters import CorrespondentFilterSet
|
||||
from documents.filters import CustomFieldFilterSet
|
||||
from documents.filters import DocumentFilterSet
|
||||
from documents.filters import DocumentsOrderingFilter
|
||||
from documents.filters import DocumentTypeFilterSet
|
||||
from documents.filters import ObjectOwnedOrGrantedPermissionsFilter
|
||||
from documents.filters import ObjectOwnedPermissionsFilter
|
||||
from documents.filters import PaperlessTaskFilterSet
|
||||
from documents.filters import ShareLinkFilterSet
|
||||
from documents.filters import StoragePathFilterSet
|
||||
from documents.filters import TagFilterSet
|
||||
from documents.schema import generate_object_with_permissions_schema
|
||||
from documents.signals import document_updated
|
||||
from documents.templating.filepath import validate_filepath_template_and_render
|
||||
@ -129,7 +118,18 @@ from paperless.data_models import ConsumableDocument
|
||||
from paperless.data_models import DocumentMetadataOverrides
|
||||
from paperless.data_models import DocumentSource
|
||||
from paperless.db import GnuPG
|
||||
from paperless.filters import CorrespondentFilterSet
|
||||
from paperless.filters import CustomFieldFilterSet
|
||||
from paperless.filters import DocumentFilterSet
|
||||
from paperless.filters import DocumentsOrderingFilter
|
||||
from paperless.filters import DocumentTypeFilterSet
|
||||
from paperless.filters import GroupFilterSet
|
||||
from paperless.filters import ObjectOwnedOrGrantedPermissionsFilter
|
||||
from paperless.filters import ObjectOwnedPermissionsFilter
|
||||
from paperless.filters import PaperlessTaskFilterSet
|
||||
from paperless.filters import ShareLinkFilterSet
|
||||
from paperless.filters import StoragePathFilterSet
|
||||
from paperless.filters import TagFilterSet
|
||||
from paperless.filters import UserFilterSet
|
||||
from paperless.index import DelayedQuery
|
||||
from paperless.mail import send_email
|
||||
|
@ -17,7 +17,7 @@ from rest_framework.permissions import IsAuthenticated
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
|
||||
from documents.filters import ObjectOwnedOrGrantedPermissionsFilter
|
||||
from paperless.filters import ObjectOwnedOrGrantedPermissionsFilter
|
||||
from paperless.permissions import PaperlessObjectPermissions
|
||||
from paperless.views import PassUserMixin
|
||||
from paperless.views import StandardPagination
|
||||
|
Loading…
x
Reference in New Issue
Block a user