Feature: Enhanced backend custom field search API (#7589)

commit 910dae8413028f647e6295f30207cb5d4fc6605d
Author: Yichi Yang <yiy067@ucsd.edu>
Date:   Wed Sep 4 12:47:19 2024 -0700

    Fix: correctly handle the case where custom_field_lookup refers to multiple fields

commit e43f70d708b7d6b445f3ca8c8bf9dbdf5ee26085
Author: Yichi Yang <yiy067@ucsd.edu>
Date:   Sat Aug 31 14:06:45 2024 -0700

Co-Authored-By: Yichi Yang <yichiyan@usc.edu>
This commit is contained in:
shamoon
2024-09-23 11:28:31 -07:00
parent f06ff85b7d
commit d7ba6d98d3
7 changed files with 1270 additions and 38 deletions

View File

@@ -1,24 +1,36 @@
import functools
import inspect
import json
import operator
from contextlib import contextmanager
from typing import Callable
from typing import Union
from django.contrib.contenttypes.models import ContentType
from django.db.models import CharField
from django.db.models import Count
from django.db.models import OuterRef
from django.db.models import Q
from django.db.models.functions import Cast
from django.utils.translation import gettext_lazy as _
from django_filters.rest_framework import BooleanFilter
from django_filters.rest_framework import Filter
from django_filters.rest_framework import FilterSet
from guardian.utils import get_group_obj_perms_model
from guardian.utils import get_user_obj_perms_model
from rest_framework import serializers
from rest_framework_guardian.filters import ObjectPermissionsFilter
from documents.models import Correspondent
from documents.models import CustomField
from documents.models import CustomFieldInstance
from documents.models import Document
from documents.models import DocumentType
from documents.models import Log
from documents.models import ShareLink
from documents.models import StoragePath
from documents.models import Tag
from paperless import settings
# Case-insensitive Django string lookup suffixes.
# NOTE(review): presumably the lookups exposed for text columns by the
# FilterSets in this module - usage is not visible here, confirm.
CHAR_KWARGS = ["istartswith", "iendswith", "icontains", "iexact"]
# Django lookup suffixes for matching by identifier.
# NOTE(review): presumably the lookups exposed for id / foreign-key columns - confirm.
ID_KWARGS = ["in", "exact"]
@@ -182,6 +194,488 @@ class CustomFieldsFilter(Filter):
return qs
class SelectField(serializers.IntegerField):
    """
    Serializer field used to validate lookup values for SELECT custom fields.

    The validated value is an integer index into the custom field's
    ``extra_data["select_options"]`` list. For convenience, a value that is
    not an integer is looked up in the option list and translated to its index.
    """

    def __init__(self, custom_field: CustomField):
        """
        Args:
            custom_field: The SELECT custom field whose options define the
                valid index range.
        """
        self._options = custom_field.extra_data["select_options"]
        # `max_value` is an inclusive bound, so the last valid index is
        # len(options) - 1. Using len(options) would accept an index that
        # points one past the end of the option list.
        super().__init__(min_value=0, max_value=len(self._options) - 1)

    def to_internal_value(self, data):
        """
        Convert `data` to an option index.

        Raises:
            serializers.ValidationError: (from IntegerField) if `data` is
                neither a valid integer nor one of the option values.
        """
        if not isinstance(data, int):
            # If the supplied value is not an integer,
            # we will try to map it to an option index.
            try:
                data = self._options.index(data)
            except ValueError:
                # Not a known option: fall through and let IntegerField
                # produce the validation error for the raw value.
                pass
        return super().to_internal_value(data)
def handle_validation_prefix(func: Callable):
    """
    Decorator that tags ValidationErrors with the location that produced them.

    The decorated callable accepts one extra keyword-only argument,
    `validation_prefix`. Any `serializers.ValidationError` raised by `func`
    is re-raised as a new ValidationError whose detail is
    `{validation_prefix: <original detail>}`, which mirrors how nested
    serializers report errors.
    """

    @functools.wraps(func)
    def prefixed(*args, validation_prefix=None, **kwargs):
        try:
            result = func(*args, **kwargs)
        except serializers.ValidationError as e:
            raise serializers.ValidationError({validation_prefix: e.detail})
        return result

    # functools.wraps copies the metadata of `func`, but the wrapper takes one
    # more argument than `func` does; publish a signature that includes it.
    original_signature = inspect.signature(func)
    prefix_parameter = inspect.Parameter(
        "validation_prefix",
        inspect.Parameter.KEYWORD_ONLY,
    )
    prefixed.__signature__ = original_signature.replace(
        parameters=[*original_signature.parameters.values(), prefix_parameter],
    )
    return prefixed
class CustomFieldLookupParser:
    """
    Parses a JSON custom field lookup expression into a `django.db.models.Q`
    plus the queryset annotations that the `Q` refers to.

    See `__init__` for the expression syntax.
    """

    # Lookup expressions, grouped by the kind of value they operate on.
    EXPR_BY_CATEGORY = {
        "basic": ["exact", "in", "isnull", "exists"],
        "string": [
            "iexact",
            "contains",
            "icontains",
            "startswith",
            "istartswith",
            "endswith",
            "iendswith",
            "regex",
            "iregex",
        ],
        "arithmetic": [
            "gt",
            "gte",
            "lt",
            "lte",
            "range",
        ],
        "containment": ["contains"],
    }

    # These string lookup expressions are problematic. We shall disable
    # them by default unless the user explicitly opts in.
    STR_EXPR_DISABLED_BY_DEFAULT = [
        # SQLite: is case-sensitive outside the ASCII range
        "iexact",
        # SQLite: behaves the same as icontains
        "contains",
        # SQLite: behaves the same as istartswith
        "startswith",
        # SQLite: behaves the same as iendswith
        "endswith",
        # Syntax depends on database backends, can be exploited for ReDoS
        "regex",
        # Syntax depends on database backends, can be exploited for ReDoS
        "iregex",
    ]

    # Which categories of EXPR_BY_CATEGORY each custom field data type supports.
    SUPPORTED_EXPR_CATEGORIES = {
        CustomField.FieldDataType.STRING: ("basic", "string"),
        CustomField.FieldDataType.URL: ("basic", "string"),
        CustomField.FieldDataType.DATE: ("basic", "arithmetic"),
        CustomField.FieldDataType.BOOL: ("basic",),
        CustomField.FieldDataType.INT: ("basic", "arithmetic"),
        CustomField.FieldDataType.FLOAT: ("basic", "arithmetic"),
        CustomField.FieldDataType.MONETARY: ("basic", "string"),
        CustomField.FieldDataType.DOCUMENTLINK: ("basic", "containment"),
        CustomField.FieldDataType.SELECT: ("basic",),
    }

    # Prefixes accepted in front of an operator for DATE fields,
    # e.g. "year__exact".
    DATE_COMPONENTS = [
        "year",
        "iso_year",
        "month",
        "day",
        "week",
        "week_day",
        "iso_week_day",
        "quarter",
    ]

    def __init__(
        self,
        validation_prefix,
        max_query_depth=10,
        max_atom_count=20,
    ) -> None:
        """
        A helper class that parses the query string into a `django.db.models.Q` for filtering
        documents based on custom field values.

        The syntax of the query expression is illustrated with the below pseudo code rules:
        1. parse([`custom_field`, "exists", true]):
            matches documents with Q(custom_fields__field=`custom_field`)
        2. parse([`custom_field`, "exists", false]):
            matches documents with ~Q(custom_fields__field=`custom_field`)
        3. parse([`custom_field`, `op`, `value`]):
            matches documents with
            Q(custom_fields__field=`custom_field`, custom_fields__value_`type`__`op`= `value`)
        4. parse(["AND", [`q0`, `q1`, ..., `qn`]])
            -> parse(`q0`) & parse(`q1`) & ... & parse(`qn`)
        5. parse(["OR", [`q0`, `q1`, ..., `qn`]])
            -> parse(`q0`) | parse(`q1`) | ... | parse(`qn`)
        6. parse(["NOT", `q`])
            -> ~parse(`q`)

        Args:
            validation_prefix: Used to generate the ValidationError message.
            max_query_depth: Limits the maximum nesting depth of queries.
            max_atom_count: Limits the maximum number of atoms (i.e., rule 1, 2, 3) in the query.

        `max_query_depth` and `max_atom_count` can be set to guard against generating arbitrarily
        complex SQL queries.
        """
        # Cache of resolved custom fields, keyed by both id and name.
        self._custom_fields: dict[Union[int, str], CustomField] = {}
        self._validation_prefix = validation_prefix
        # Dummy ModelSerializer used to convert a Django models.Field to serializers.Field.
        self._model_serializer = serializers.ModelSerializer()
        # Used for sanity check
        self._max_query_depth = max_query_depth
        self._max_atom_count = max_atom_count
        self._current_depth = 0
        self._atom_count = 0
        # The set of annotations that we need to apply to the queryset
        self._annotations = {}

    def parse(self, query: str) -> tuple[Q, dict[str, Count]]:
        """
        Parses the query string into a `django.db.models.Q`
        and a set of annotations to be applied to the queryset.

        Raises:
            serializers.ValidationError: if `query` is not valid JSON or is
                not a valid lookup expression.
        """
        try:
            expr = json.loads(query)
        except json.JSONDecodeError:
            raise serializers.ValidationError(
                {self._validation_prefix: [_("Value must be valid JSON.")]},
            )
        return (
            self._parse_expr(expr, validation_prefix=self._validation_prefix),
            self._annotations,
        )

    @handle_validation_prefix
    def _parse_expr(self, expr) -> Q:
        """
        Applies rule (1, 2, 3) or (4, 5, 6) based on the length of the expr.
        """
        with self._track_query_depth():
            if isinstance(expr, (list, tuple)):
                if len(expr) == 2:
                    return self._parse_logical_expr(*expr)
                elif len(expr) == 3:
                    return self._parse_atom(*expr)
            raise serializers.ValidationError(
                [_("Invalid custom field lookup expression")],
            )

    @handle_validation_prefix
    def _parse_expr_list(self, exprs) -> list[Q]:
        """
        Handles [`q0`, `q1`, ..., `qn`] in rule 4 & 5.
        """
        if not isinstance(exprs, (list, tuple)) or not exprs:
            raise serializers.ValidationError(
                [_("Invalid expression list. Must be nonempty.")],
            )
        # The index of each sub-expression is used as its validation prefix.
        return [
            self._parse_expr(expr, validation_prefix=i) for i, expr in enumerate(exprs)
        ]

    def _parse_logical_expr(self, op, args) -> Q:
        """
        Handles rule 4, 5, 6.

        Raises:
            serializers.ValidationError: if `op` is not one of "AND", "OR",
                "NOT" (case-insensitive), or if the operands are invalid.
        """
        # `op` comes straight from user-supplied JSON and may be of any type
        # (number, list, null, ...). Only strings can name an operator;
        # anything else must be rejected as a validation error rather than
        # crash on `.lower()`.
        op_lower = op.lower() if isinstance(op, str) else None

        if op_lower == "not":
            return ~self._parse_expr(args, validation_prefix=1)

        if op_lower == "and":
            op_func = operator.and_
        elif op_lower == "or":
            op_func = operator.or_
        else:
            raise serializers.ValidationError(
                {"0": [_("Invalid logical operator {op!r}").format(op=op)]},
            )

        qs = self._parse_expr_list(args, validation_prefix="1")
        return functools.reduce(op_func, qs)

    def _parse_atom(self, id_or_name, op, value) -> Q:
        """
        Handles rule 1, 2, 3.

        Registers an annotation (when needed) in `self._annotations` and
        returns a `Q` that refers to it.
        """
        # Guard against queries with too many conditions.
        self._atom_count += 1
        if self._atom_count > self._max_atom_count:
            raise serializers.ValidationError(
                [
                    _(
                        "Maximum number of query conditions exceeded. You can raise "
                        "the limit by setting PAPERLESS_CUSTOM_FIELD_LOOKUP_MAX_ATOMS "
                        "in your configuration file.",
                    ),
                ],
            )

        custom_field = self._get_custom_field(id_or_name, validation_prefix="0")
        op = self._validate_atom_op(custom_field, op, validation_prefix="1")
        value = self._validate_atom_value(
            custom_field,
            op,
            value,
            validation_prefix="2",
        )

        # Needed because not all DB backends support Array __contains
        if (
            custom_field.data_type == CustomField.FieldDataType.DOCUMENTLINK
            and op == "contains"
        ):
            return self._parse_atom_doc_link_contains(custom_field, value)

        value_field_name = CustomFieldInstance.get_value_field_name(
            custom_field.data_type,
        )
        has_field = Q(custom_fields__field=custom_field)

        # We need to use an annotation here because different atoms
        # might be referring to different instances of custom fields.
        annotation_name = f"_custom_field_filter_{len(self._annotations)}"

        if op == "exists":
            # Our special exists operator.
            # Count the instances of *this* field on each document: the field
            # exists when the count is > 0 and does not exist when it is 0.
            # (Counting instances matching `~has_field` would instead test
            # whether the document has any *other* custom field.)
            annotation = Count("custom_fields", filter=has_field)
            query_op = "gt" if value else "exact"
        else:
            # Count the instances of this field whose value satisfies `op`.
            field_filter = has_field & Q(
                **{f"custom_fields__{value_field_name}__{op}": value},
            )
            annotation = Count("custom_fields", filter=field_filter)
            query_op = "gt"

        self._annotations[annotation_name] = annotation
        return Q(**{f"{annotation_name}__{query_op}": 0})

    @handle_validation_prefix
    def _get_custom_field(self, id_or_name):
        """
        Get the CustomField instance by id or name.

        Raises:
            serializers.ValidationError: if `id_or_name` is not an int or a
                str, or if no such custom field exists.
        """
        # `id_or_name` is user input decoded from JSON. Lists and objects are
        # unhashable and would crash the cache lookup below; booleans are
        # ints in Python and would silently be treated as ids 0 / 1.
        if isinstance(id_or_name, bool) or not isinstance(id_or_name, (int, str)):
            raise serializers.ValidationError(
                [_("{name!r} is not a valid custom field.").format(name=id_or_name)],
            )

        if id_or_name in self._custom_fields:
            return self._custom_fields[id_or_name]

        kwargs = (
            {"id": id_or_name} if isinstance(id_or_name, int) else {"name": id_or_name}
        )
        try:
            custom_field = CustomField.objects.get(**kwargs)
        except CustomField.DoesNotExist:
            raise serializers.ValidationError(
                [_("{name!r} is not a valid custom field.").format(name=id_or_name)],
            )
        # Cache under both keys so either form of reference hits the cache.
        self._custom_fields[custom_field.id] = custom_field
        self._custom_fields[custom_field.name] = custom_field
        return custom_field

    @staticmethod
    def _split_op(full_op):
        """
        Split `full_op` at its last "__" into `(prefix, op)`.

        `prefix` is None when `full_op` contains no "__",
        e.g. "year__exact" -> ("year", "exact") and "exact" -> (None, "exact").
        """
        *prefix, op = str(full_op).rsplit("__", maxsplit=1)
        prefix = prefix[0] if prefix else None
        return prefix, op

    @handle_validation_prefix
    def _validate_atom_op(self, custom_field, raw_op):
        """
        Check if the `op` is compatible with the type of the custom field.

        Returns:
            `raw_op` unchanged, if it is supported.

        Raises:
            serializers.ValidationError: if the operator is disabled by
                default and not opted in, or is not supported by the data type.
        """
        prefix, op = self._split_op(raw_op)

        # Check if the operator is supported for the current data_type.
        supported = False
        for category in self.SUPPORTED_EXPR_CATEGORIES[custom_field.data_type]:
            if (
                category == "string"
                and op in self.STR_EXPR_DISABLED_BY_DEFAULT
                and op not in settings.CUSTOM_FIELD_LOOKUP_OPT_IN
            ):
                raise serializers.ValidationError(
                    [
                        _(
                            "{expr!r} is disabled by default because it does not "
                            "behave consistently across database backends, or can "
                            "cause security risks. If you understand the implications "
                            "you may enabled it by adding it to "
                            "`PAPERLESS_CUSTOM_FIELD_LOOKUP_OPT_IN`.",
                        ).format(expr=op),
                    ],
                )
            if op in self.EXPR_BY_CATEGORY[category]:
                supported = True
                break

        # Check prefix
        if prefix is not None:
            if (
                prefix in self.DATE_COMPONENTS
                and custom_field.data_type == CustomField.FieldDataType.DATE
            ):
                pass  # ok - e.g., "year__exact" for date field
            else:
                supported = False  # anything else is invalid

        if not supported:
            raise serializers.ValidationError(
                [
                    _("{data_type} does not support lookup expr {expr!r}.").format(
                        data_type=custom_field.data_type,
                        expr=raw_op,
                    ),
                ],
            )

        return raw_op

    def _get_serializer_field(self, custom_field, full_op):
        """Return a serializers.Field for value validation."""
        prefix, op = self._split_op(full_op)
        field = None

        if op in ("isnull", "exists"):
            # `isnull` takes either True or False regardless of the data_type.
            field = serializers.BooleanField()
        elif (
            custom_field.data_type == CustomField.FieldDataType.DATE
            and prefix in self.DATE_COMPONENTS
        ):
            # DateField admits lookups in the form of `year__exact`, etc. These take integers.
            field = serializers.IntegerField()
        elif custom_field.data_type == CustomField.FieldDataType.DOCUMENTLINK:
            # We can be more specific here and make sure the value is a list.
            field = serializers.ListField(child=serializers.IntegerField())
        elif custom_field.data_type == CustomField.FieldDataType.SELECT:
            # We use this custom field to permit SELECT option names.
            field = SelectField(custom_field)
        elif custom_field.data_type == CustomField.FieldDataType.URL:
            # For URL fields we don't need to be strict about validation (e.g., for istartswith).
            field = serializers.CharField()
        else:
            # The general case: inferred from the corresponding field in CustomFieldInstance.
            value_field_name = CustomFieldInstance.get_value_field_name(
                custom_field.data_type,
            )
            model_field = CustomFieldInstance._meta.get_field(value_field_name)
            field_name = model_field.deconstruct()[0]
            field_class, field_kwargs = self._model_serializer.build_standard_field(
                field_name,
                model_field,
            )
            field = field_class(**field_kwargs)
            field.allow_null = False

            # Need to set allow_blank manually because of the inconsistency in CustomFieldInstance validation.
            # See https://github.com/paperless-ngx/paperless-ngx/issues/7361.
            if isinstance(field, serializers.CharField):
                field.allow_blank = True

        if op == "in":
            # `in` takes a list of values.
            field = serializers.ListField(child=field, allow_empty=False)
        elif op == "range":
            # `range` takes a list of values, i.e., [start, end].
            field = serializers.ListField(
                child=field,
                min_length=2,
                max_length=2,
            )

        return field

    @handle_validation_prefix
    def _validate_atom_value(self, custom_field, op, value):
        """Check if `value` is valid for the custom field and `op`. Returns the validated value."""
        serializer_field = self._get_serializer_field(custom_field, op)
        return serializer_field.run_validation(value)

    def _parse_atom_doc_link_contains(self, custom_field, value) -> Q:
        """
        Handles document link `contains` in a way that is supported by all DB backends.

        Args:
            custom_field: The DOCUMENTLINK custom field being queried.
            value: Validated list of document ids that must all be linked.
        """
        # If the value is an empty set,
        # this is trivially true for any document with not null document links.
        if not value:
            return Q(
                custom_fields__field=custom_field,
                custom_fields__value_document_ids__isnull=False,
            )

        # First we lookup reverse links from the requested documents.
        # NOTE(review): this relies on document links being stored on both
        # ends, and it is not restricted to `custom_field` - confirm that
        # is intended when several document link fields exist.
        links = CustomFieldInstance.objects.filter(
            document_id__in=value,
            field__data_type=CustomField.FieldDataType.DOCUMENTLINK,
        )

        # Check if any of the requested IDs are missing.
        missing_ids = set(value) - {link.document_id for link in links}
        if missing_ids:
            # The result should be an empty set in this case.
            return Q(id__in=[])

        # Take the intersection of the reverse links - this should be what we are looking for.
        # A link instance may hold a null value (see the isnull check above);
        # treat it as an empty set of links instead of crashing on set(None).
        # `links` is guaranteed nonempty here because `value` is nonempty and
        # no requested id is missing.
        document_ids_we_want = functools.reduce(
            operator.and_,
            (set(link.value_document_ids or ()) for link in links),
        )

        return Q(id__in=document_ids_we_want)

    @contextmanager
    def _track_query_depth(self):
        """
        Context manager that tracks the current nesting depth of the query.

        Raises:
            serializers.ValidationError: on entry, if the nesting depth
                exceeds `max_query_depth`.
        """
        # guard against queries that are too deeply nested
        self._current_depth += 1
        try:
            # The check lives inside the try block so that the counter is
            # restored even when the depth limit is exceeded.
            if self._current_depth > self._max_query_depth:
                raise serializers.ValidationError(
                    [
                        _(
                            "Maximum nesting depth exceeded. You can raise the limit "
                            "by setting PAPERLESS_CUSTOM_FIELD_LOOKUP_MAX_DEPTH in "
                            "your configuration file.",
                        ),
                    ],
                )
            yield
        finally:
            self._current_depth -= 1
class CustomFieldLookupFilter(Filter):
    """
    Filters a queryset with a JSON custom field lookup expression, which is
    parsed by `CustomFieldLookupParser`.
    """

    def __init__(self, validation_prefix):
        """
        A filter that filters documents based on custom field name and value.

        Args:
            validation_prefix: Used to generate the ValidationError message.
        """
        super().__init__()
        self._validation_prefix = validation_prefix

    def filter(self, qs, value):
        """
        Apply the lookup expression in `value` to `qs`.

        An empty `value` leaves the queryset untouched.
        """
        if value:
            # Limits on query complexity come from the project settings.
            lookup_parser = CustomFieldLookupParser(
                self._validation_prefix,
                max_query_depth=settings.CUSTOM_FIELD_LOOKUP_MAX_DEPTH,
                max_atom_count=settings.CUSTOM_FIELD_LOOKUP_MAX_ATOMS,
            )
            condition, extra_annotations = lookup_parser.parse(value)
            # The condition refers to the annotations, so they are applied first.
            annotated = qs.annotate(**extra_annotations)
            return annotated.filter(condition)
        return qs
class DocumentFilterSet(FilterSet):
is_tagged = BooleanFilter(
label="Is tagged",
@@ -229,6 +723,8 @@ class DocumentFilterSet(FilterSet):
exclude=True,
)
custom_field_lookup = CustomFieldLookupFilter("custom_field_lookup")
shared_by__id = SharedByUser()
class Meta: