From d7ba6d98d396900fe20be78412385394069dd506 Mon Sep 17 00:00:00 2001 From: shamoon <4887959+shamoon@users.noreply.github.com> Date: Mon, 23 Sep 2024 11:28:31 -0700 Subject: [PATCH] Feature: Enhanced backend custom field search API (#7589) commit 910dae8413028f647e6295f30207cb5d4fc6605d Author: Yichi Yang Date: Wed Sep 4 12:47:19 2024 -0700 Fix: correctly handle the case where custom_field_lookup refers to multiple fields commit e43f70d708b7d6b445f3ca8c8bf9dbdf5ee26085 Author: Yichi Yang Date: Sat Aug 31 14:06:45 2024 -0700 Co-Authored-By: Yichi Yang --- docs/api.md | 67 +- paperless.conf.example | 1 + src/documents/filters.py | 496 +++++++++++++ src/documents/models.py | 40 +- src/documents/serialisers.py | 15 +- .../tests/test_api_filter_by_custom_fields.py | 670 ++++++++++++++++++ src/paperless/settings.py | 19 +- 7 files changed, 1270 insertions(+), 38 deletions(-) create mode 100644 src/documents/tests/test_api_filter_by_custom_fields.py diff --git a/docs/api.md b/docs/api.md index 94ece85ab..057ccaedb 100644 --- a/docs/api.md +++ b/docs/api.md @@ -235,12 +235,6 @@ results: Pagination works exactly the same as it does for normal requests on this endpoint. -Certain limitations apply to full text queries: - -- Results are always sorted by search score. The results matching the - query best will show up first. -- Only a small subset of filtering parameters are supported. - Furthermore, each returned document has an additional `__search_hit__` attribute with various information about the search results: @@ -280,6 +274,67 @@ attribute with various information about the search results: - `rank` is the index of the search results. The first result will have rank 0. +### Filtering by custom fields + +You can filter documents by their custom field values by specifying the +`custom_field_lookup` query parameter. Here are some recipes for common +use cases: + +1. 
Documents with a custom field "due" (date) between Aug 1, 2024 and + Sept 1, 2024 (inclusive): + + `?custom_field_lookup=["due", "range", ["2024-08-01", "2024-09-01"]]` + +2. Documents with a custom field "customer" (text) that equals "bob" + (case sensitive): + + `?custom_field_lookup=["customer", "exact", "bob"]` + +3. Documents with a custom field "answered" (boolean) set to `true`: + + `?custom_field_lookup=["answered", "exact", true]` + +4. Documents with a custom field "favorite animal" (select) set to either + "cat" or "dog": + + `?custom_field_lookup=["favorite animal", "in", ["cat", "dog"]]` + +5. Documents with a custom field "address" (text) that is empty: + + `?custom_field_lookup=["OR", ["address", "isnull", true], ["address", "exact", ""]]` + +6. Documents that don't have a field called "foo": + + `?custom_field_lookup=["foo", "exists", false]` + +7. Documents that have document links "references" to both document 3 and 7: + + `?custom_field_lookup=["references", "contains", [3, 7]]` + +All field types support basic operations including `exact`, `in`, `isnull`, +and `exists`. String, URL, and monetary fields support case-insensitive +substring matching operations including `icontains`, `istartswith`, and +`iendswith`. Integer, float, and date fields support arithmetic comparisons +including `gt` (>), `gte` (>=), `lt` (<), `lte` (<=), and `range`. +Lastly, document link fields support a `contains` operator that behaves +like a "is superset of" check. + +!!! warning + + It is possible to do case-insensitive exact match (i.e., `iexact`) and + case-sensitive substring match (i.e., `contains`, `startswith`, + `endswith`) for string, URL, and monetary fields, but + [they may not work as expected on some database backends](https://docs.djangoproject.com/en/5.1/ref/databases/#substring-matching-and-case-sensitivity). 
+ + It is also possible to use regular expressions to match string, URL, and + monetary fields, but the syntax is database-dependent, and accepting + regular expressions from untrusted sources could make your instance + vulnerable to regular expression denial of service attacks. + + For these reasons the above expressions are disabled by default. + If you understand the implications, you may enable them by uncommenting + `PAPERLESS_CUSTOM_FIELD_LOOKUP_OPT_IN` in your configuration file. + ### `/api/search/autocomplete/` Get auto completions for a partial search term. diff --git a/paperless.conf.example b/paperless.conf.example index 63ee7be22..5fabbf390 100644 --- a/paperless.conf.example +++ b/paperless.conf.example @@ -81,6 +81,7 @@ #PAPERLESS_THUMBNAIL_FONT_NAME= #PAPERLESS_IGNORE_DATES= #PAPERLESS_ENABLE_UPDATE_CHECK= +#PAPERLESS_ALLOW_CUSTOM_FIELD_LOOKUP=iexact,contains,startswith,endswith,regex,iregex # Tika settings diff --git a/src/documents/filters.py b/src/documents/filters.py index 1770f8514..5288bd45c 100644 --- a/src/documents/filters.py +++ b/src/documents/filters.py @@ -1,24 +1,36 @@ +import functools +import inspect +import json +import operator +from contextlib import contextmanager +from typing import Callable +from typing import Union + from django.contrib.contenttypes.models import ContentType from django.db.models import CharField from django.db.models import Count from django.db.models import OuterRef from django.db.models import Q from django.db.models.functions import Cast +from django.utils.translation import gettext_lazy as _ from django_filters.rest_framework import BooleanFilter from django_filters.rest_framework import Filter from django_filters.rest_framework import FilterSet from guardian.utils import get_group_obj_perms_model from guardian.utils import get_user_obj_perms_model +from rest_framework import serializers from rest_framework_guardian.filters import ObjectPermissionsFilter from documents.models import Correspondent from 
documents.models import CustomField +from documents.models import CustomFieldInstance from documents.models import Document from documents.models import DocumentType from documents.models import Log from documents.models import ShareLink from documents.models import StoragePath from documents.models import Tag +from paperless import settings CHAR_KWARGS = ["istartswith", "iendswith", "icontains", "iexact"] ID_KWARGS = ["in", "exact"] @@ -182,6 +194,488 @@ class CustomFieldsFilter(Filter): return qs +class SelectField(serializers.IntegerField): + def __init__(self, custom_field: CustomField): + self._options = custom_field.extra_data["select_options"] + super().__init__(min_value=0, max_value=len(self._options)) + + def to_internal_value(self, data): + if not isinstance(data, int): + # If the supplied value is not an integer, + # we will try to map it to an option index. + try: + data = self._options.index(data) + except ValueError: + pass + return super().to_internal_value(data) + + +def handle_validation_prefix(func: Callable): + """ + Catch ValidationErrors raised by the wrapped function + and add a prefix to the exception detail to track what causes the exception, + similar to nested serializers. 
+ """ + + def wrapper(*args, validation_prefix=None, **kwargs): + try: + return func(*args, **kwargs) + except serializers.ValidationError as e: + raise serializers.ValidationError({validation_prefix: e.detail}) + + # Update the signature to include the validation_prefix argument + old_sig = inspect.signature(func) + new_param = inspect.Parameter("validation_prefix", inspect.Parameter.KEYWORD_ONLY) + new_sig = old_sig.replace(parameters=[*old_sig.parameters.values(), new_param]) + + # Apply functools.wraps and manually set the new signature + functools.update_wrapper(wrapper, func) + wrapper.__signature__ = new_sig + + return wrapper + + +class CustomFieldLookupParser: + EXPR_BY_CATEGORY = { + "basic": ["exact", "in", "isnull", "exists"], + "string": [ + "iexact", + "contains", + "icontains", + "startswith", + "istartswith", + "endswith", + "iendswith", + "regex", + "iregex", + ], + "arithmetic": [ + "gt", + "gte", + "lt", + "lte", + "range", + ], + "containment": ["contains"], + } + + # These string lookup expressions are problematic. We shall disable + # them by default unless the user explicitly opts in. 
+ STR_EXPR_DISABLED_BY_DEFAULT = [ + # SQLite: is case-sensitive outside the ASCII range + "iexact", + # SQLite: behaves the same as icontains + "contains", + # SQLite: behaves the same as istartswith + "startswith", + # SQLite: behaves the same as iendswith + "endswith", + # Syntax depends on database backends, can be exploited for ReDoS + "regex", + # Syntax depends on database backends, can be exploited for ReDoS + "iregex", + ] + + SUPPORTED_EXPR_CATEGORIES = { + CustomField.FieldDataType.STRING: ("basic", "string"), + CustomField.FieldDataType.URL: ("basic", "string"), + CustomField.FieldDataType.DATE: ("basic", "arithmetic"), + CustomField.FieldDataType.BOOL: ("basic",), + CustomField.FieldDataType.INT: ("basic", "arithmetic"), + CustomField.FieldDataType.FLOAT: ("basic", "arithmetic"), + CustomField.FieldDataType.MONETARY: ("basic", "string"), + CustomField.FieldDataType.DOCUMENTLINK: ("basic", "containment"), + CustomField.FieldDataType.SELECT: ("basic",), + } + + DATE_COMPONENTS = [ + "year", + "iso_year", + "month", + "day", + "week", + "week_day", + "iso_week_day", + "quarter", + ] + + def __init__( + self, + validation_prefix, + max_query_depth=10, + max_atom_count=20, + ) -> None: + """ + A helper class that parses the query string into a `django.db.models.Q` for filtering + documents based on custom field values. + + The syntax of the query expression is illustrated with the below pseudo code rules: + 1. parse([`custom_field`, "exists", true]): + matches documents with Q(custom_fields__field=`custom_field`) + 2. parse([`custom_field`, "exists", false]): + matches documents with ~Q(custom_fields__field=`custom_field`) + 3. parse([`custom_field`, `op`, `value`]): + matches documents with + Q(custom_fields__field=`custom_field`, custom_fields__value_`type`__`op`= `value`) + 4. parse(["AND", [`q0`, `q1`, ..., `qn`]]) + -> parse(`q0`) & parse(`q1`) & ... & parse(`qn`) + 5. parse(["OR", [`q0`, `q1`, ..., `qn`]]) + -> parse(`q0`) | parse(`q1`) | ... 
| parse(`qn`) + 6. parse(["NOT", `q`]) + -> ~parse(`q`) + + Args: + validation_prefix: Used to generate the ValidationError message. + max_query_depth: Limits the maximum nesting depth of queries. + max_atom_count: Limits the maximum number of atoms (i.e., rule 1, 2, 3) in the query. + + `max_query_depth` and `max_atom_count` can be set to guard against generating arbitrarily + complex SQL queries. + """ + self._custom_fields: dict[Union[int, str], CustomField] = {} + self._validation_prefix = validation_prefix + # Dummy ModelSerializer used to convert a Django models.Field to serializers.Field. + self._model_serializer = serializers.ModelSerializer() + # Used for sanity check + self._max_query_depth = max_query_depth + self._max_atom_count = max_atom_count + self._current_depth = 0 + self._atom_count = 0 + # The set of annotations that we need to apply to the queryset + self._annotations = {} + + def parse(self, query: str) -> tuple[Q, dict[str, Count]]: + """ + Parses the query string into a `django.db.models.Q` + and a set of annotations to be applied to the queryset. + """ + try: + expr = json.loads(query) + except json.JSONDecodeError: + raise serializers.ValidationError( + {self._validation_prefix: [_("Value must be valid JSON.")]}, + ) + return ( + self._parse_expr(expr, validation_prefix=self._validation_prefix), + self._annotations, + ) + + @handle_validation_prefix + def _parse_expr(self, expr) -> Q: + """ + Applies rule (1, 2, 3) or (4, 5, 6) based on the length of the expr. + """ + with self._track_query_depth(): + if isinstance(expr, (list, tuple)): + if len(expr) == 2: + return self._parse_logical_expr(*expr) + elif len(expr) == 3: + return self._parse_atom(*expr) + raise serializers.ValidationError( + [_("Invalid custom field lookup expression")], + ) + + @handle_validation_prefix + def _parse_expr_list(self, exprs) -> list[Q]: + """ + Handles [`q0`, `q1`, ..., `qn`] in rule 4 & 5. 
+ """ + if not isinstance(exprs, (list, tuple)) or not exprs: + raise serializers.ValidationError( + [_("Invalid expression list. Must be nonempty.")], + ) + return [ + self._parse_expr(expr, validation_prefix=i) for i, expr in enumerate(exprs) + ] + + def _parse_logical_expr(self, op, args) -> Q: + """ + Handles rule 4, 5, 6. + """ + op_lower = op.lower() + + if op_lower == "not": + return ~self._parse_expr(args, validation_prefix=1) + + if op_lower == "and": + op_func = operator.and_ + elif op_lower == "or": + op_func = operator.or_ + else: + raise serializers.ValidationError( + {"0": [_("Invalid logical operator {op!r}").format(op=op)]}, + ) + + qs = self._parse_expr_list(args, validation_prefix="1") + return functools.reduce(op_func, qs) + + def _parse_atom(self, id_or_name, op, value) -> Q: + """ + Handles rule 1, 2, 3. + """ + # Guard against queries with too many conditions. + self._atom_count += 1 + if self._atom_count > self._max_atom_count: + raise serializers.ValidationError( + [ + _( + "Maximum number of query conditions exceeded. You can raise " + "the limit by setting PAPERLESS_CUSTOM_FIELD_LOOKUP_MAX_ATOMS " + "in your configuration file.", + ), + ], + ) + + custom_field = self._get_custom_field(id_or_name, validation_prefix="0") + op = self._validate_atom_op(custom_field, op, validation_prefix="1") + value = self._validate_atom_value( + custom_field, + op, + value, + validation_prefix="2", + ) + + # Needed because not all DB backends support Array __contains + if ( + custom_field.data_type == CustomField.FieldDataType.DOCUMENTLINK + and op == "contains" + ): + return self._parse_atom_doc_link_contains(custom_field, value) + + value_field_name = CustomFieldInstance.get_value_field_name( + custom_field.data_type, + ) + has_field = Q(custom_fields__field=custom_field) + + # Our special exists operator. 
+ if op == "exists": + field_filter = has_field if value else ~has_field + else: + field_filter = has_field & Q( + **{f"custom_fields__{value_field_name}__{op}": value}, + ) + + # We need to use an annotation here because different atoms + # might be referring to different instances of custom fields. + annotation_name = f"_custom_field_filter_{len(self._annotations)}" + self._annotations[annotation_name] = Count("custom_fields", filter=field_filter) + + return Q(**{f"{annotation_name}__gt": 0}) + + @handle_validation_prefix + def _get_custom_field(self, id_or_name): + """Get the CustomField instance by id or name.""" + if id_or_name in self._custom_fields: + return self._custom_fields[id_or_name] + + kwargs = ( + {"id": id_or_name} if isinstance(id_or_name, int) else {"name": id_or_name} + ) + try: + custom_field = CustomField.objects.get(**kwargs) + except CustomField.DoesNotExist: + raise serializers.ValidationError( + [_("{name!r} is not a valid custom field.").format(name=id_or_name)], + ) + self._custom_fields[custom_field.id] = custom_field + self._custom_fields[custom_field.name] = custom_field + return custom_field + + @staticmethod + def _split_op(full_op): + *prefix, op = str(full_op).rsplit("__", maxsplit=1) + prefix = prefix[0] if prefix else None + return prefix, op + + @handle_validation_prefix + def _validate_atom_op(self, custom_field, raw_op): + """Check if the `op` is compatible with the type of the custom field.""" + prefix, op = self._split_op(raw_op) + + # Check if the operator is supported for the current data_type. + supported = False + for category in self.SUPPORTED_EXPR_CATEGORIES[custom_field.data_type]: + if ( + category == "string" + and op in self.STR_EXPR_DISABLED_BY_DEFAULT + and op not in settings.CUSTOM_FIELD_LOOKUP_OPT_IN + ): + raise serializers.ValidationError( + [ + _( + "{expr!r} is disabled by default because it does not " + "behave consistently across database backends, or can " + "cause security risks. 
If you understand the implications " + "you may enabled it by adding it to " + "`PAPERLESS_CUSTOM_FIELD_LOOKUP_OPT_IN`.", + ).format(expr=op), + ], + ) + if op in self.EXPR_BY_CATEGORY[category]: + supported = True + break + + # Check prefix + if prefix is not None: + if ( + prefix in self.DATE_COMPONENTS + and custom_field.data_type == CustomField.FieldDataType.DATE + ): + pass # ok - e.g., "year__exact" for date field + else: + supported = False # anything else is invalid + + if not supported: + raise serializers.ValidationError( + [ + _("{data_type} does not support lookup expr {expr!r}.").format( + data_type=custom_field.data_type, + expr=raw_op, + ), + ], + ) + + return raw_op + + def _get_serializer_field(self, custom_field, full_op): + """Return a serializers.Field for value validation.""" + prefix, op = self._split_op(full_op) + field = None + + if op in ("isnull", "exists"): + # `isnull` takes either True or False regardless of the data_type. + field = serializers.BooleanField() + elif ( + custom_field.data_type == CustomField.FieldDataType.DATE + and prefix in self.DATE_COMPONENTS + ): + # DateField admits lookups in the form of `year__exact`, etc. These take integers. + field = serializers.IntegerField() + elif custom_field.data_type == CustomField.FieldDataType.DOCUMENTLINK: + # We can be more specific here and make sure the value is a list. + field = serializers.ListField(child=serializers.IntegerField()) + elif custom_field.data_type == CustomField.FieldDataType.SELECT: + # We use this custom field to permit SELECT option names. + field = SelectField(custom_field) + elif custom_field.data_type == CustomField.FieldDataType.URL: + # For URL fields we don't need to be strict about validation (e.g., for istartswith). + field = serializers.CharField() + else: + # The general case: inferred from the corresponding field in CustomFieldInstance. 
+ value_field_name = CustomFieldInstance.get_value_field_name( + custom_field.data_type, + ) + model_field = CustomFieldInstance._meta.get_field(value_field_name) + field_name = model_field.deconstruct()[0] + field_class, field_kwargs = self._model_serializer.build_standard_field( + field_name, + model_field, + ) + field = field_class(**field_kwargs) + field.allow_null = False + + # Need to set allow_blank manually because of the inconsistency in CustomFieldInstance validation. + # See https://github.com/paperless-ngx/paperless-ngx/issues/7361. + if isinstance(field, serializers.CharField): + field.allow_blank = True + + if op == "in": + # `in` takes a list of values. + field = serializers.ListField(child=field, allow_empty=False) + elif op == "range": + # `range` takes a list of values, i.e., [start, end]. + field = serializers.ListField( + child=field, + min_length=2, + max_length=2, + ) + + return field + + @handle_validation_prefix + def _validate_atom_value(self, custom_field, op, value): + """Check if `value` is valid for the custom field and `op`. Returns the validated value.""" + serializer_field = self._get_serializer_field(custom_field, op) + return serializer_field.run_validation(value) + + def _parse_atom_doc_link_contains(self, custom_field, value) -> Q: + """ + Handles document link `contains` in a way that is supported by all DB backends. + """ + + # If the value is an empty set, + # this is trivially true for any document with not null document links. + if not value: + return Q( + custom_fields__field=custom_field, + custom_fields__value_document_ids__isnull=False, + ) + + # First we lookup reverse links from the requested documents. + links = CustomFieldInstance.objects.filter( + document_id__in=value, + field__data_type=CustomField.FieldDataType.DOCUMENTLINK, + ) + + # Check if any of the requested IDs are missing. 
+ missing_ids = set(value) - set(link.document_id for link in links) + if missing_ids: + # The result should be an empty set in this case. + return Q(id__in=[]) + + # Take the intersection of the reverse links - this should be what we are looking for. + document_ids_we_want = functools.reduce( + operator.and_, + (set(link.value_document_ids) for link in links), + ) + + return Q(id__in=document_ids_we_want) + + @contextmanager + def _track_query_depth(self): + # guard against queries that are too deeply nested + self._current_depth += 1 + if self._current_depth > self._max_query_depth: + raise serializers.ValidationError( + [ + _( + "Maximum nesting depth exceeded. You can raise the limit " + "by setting PAPERLESS_CUSTOM_FIELD_LOOKUP_MAX_DEPTH in " + "your configuration file.", + ), + ], + ) + try: + yield + finally: + self._current_depth -= 1 + + +class CustomFieldLookupFilter(Filter): + def __init__(self, validation_prefix): + """ + A filter that filters documents based on custom field name and value. + + Args: + validation_prefix: Used to generate the ValidationError message. 
+ """ + super().__init__() + self._validation_prefix = validation_prefix + + def filter(self, qs, value): + if not value: + return qs + + parser = CustomFieldLookupParser( + self._validation_prefix, + max_query_depth=settings.CUSTOM_FIELD_LOOKUP_MAX_DEPTH, + max_atom_count=settings.CUSTOM_FIELD_LOOKUP_MAX_ATOMS, + ) + q, annotations = parser.parse(value) + + return qs.annotate(**annotations).filter(q) + + class DocumentFilterSet(FilterSet): is_tagged = BooleanFilter( label="Is tagged", @@ -229,6 +723,8 @@ class DocumentFilterSet(FilterSet): exclude=True, ) + custom_field_lookup = CustomFieldLookupFilter("custom_field_lookup") + shared_by__id = SharedByUser() class Meta: diff --git a/src/documents/models.py b/src/documents/models.py index 3ee11aeba..24e8c2b26 100644 --- a/src/documents/models.py +++ b/src/documents/models.py @@ -857,6 +857,18 @@ class CustomFieldInstance(models.Model): and attached to a single Document to be metadata for it """ + TYPE_TO_DATA_STORE_NAME_MAP = { + CustomField.FieldDataType.STRING: "value_text", + CustomField.FieldDataType.URL: "value_url", + CustomField.FieldDataType.DATE: "value_date", + CustomField.FieldDataType.BOOL: "value_bool", + CustomField.FieldDataType.INT: "value_int", + CustomField.FieldDataType.FLOAT: "value_float", + CustomField.FieldDataType.MONETARY: "value_monetary", + CustomField.FieldDataType.DOCUMENTLINK: "value_document_ids", + CustomField.FieldDataType.SELECT: "value_select", + } + created = models.DateTimeField( _("created"), default=timezone.now, @@ -923,31 +935,21 @@ class CustomFieldInstance(models.Model): ) return str(self.field.name) + f" : {value}" + @classmethod + def get_value_field_name(cls, data_type: CustomField.FieldDataType): + try: + return cls.TYPE_TO_DATA_STORE_NAME_MAP[data_type] + except KeyError: # pragma: no cover + raise NotImplementedError(data_type) + @property def value(self): """ Based on the data type, access the actual value the instance stores A little shorthand/quick way to get what 
is actually here """ - if self.field.data_type == CustomField.FieldDataType.STRING: - return self.value_text - elif self.field.data_type == CustomField.FieldDataType.URL: - return self.value_url - elif self.field.data_type == CustomField.FieldDataType.DATE: - return self.value_date - elif self.field.data_type == CustomField.FieldDataType.BOOL: - return self.value_bool - elif self.field.data_type == CustomField.FieldDataType.INT: - return self.value_int - elif self.field.data_type == CustomField.FieldDataType.FLOAT: - return self.value_float - elif self.field.data_type == CustomField.FieldDataType.MONETARY: - return self.value_monetary - elif self.field.data_type == CustomField.FieldDataType.DOCUMENTLINK: - return self.value_document_ids - elif self.field.data_type == CustomField.FieldDataType.SELECT: - return self.value_select - raise NotImplementedError(self.field.data_type) + value_field_name = self.get_value_field_name(self.field.data_type) + return getattr(self, value_field_name) if settings.AUDIT_LOG_ENABLED: diff --git a/src/documents/serialisers.py b/src/documents/serialisers.py index 747d744b6..5218cbf8a 100644 --- a/src/documents/serialisers.py +++ b/src/documents/serialisers.py @@ -578,23 +578,14 @@ class CustomFieldInstanceSerializer(serializers.ModelSerializer): value = ReadWriteSerializerMethodField(allow_null=True) def create(self, validated_data): - type_to_data_store_name_map = { - CustomField.FieldDataType.STRING: "value_text", - CustomField.FieldDataType.URL: "value_url", - CustomField.FieldDataType.DATE: "value_date", - CustomField.FieldDataType.BOOL: "value_bool", - CustomField.FieldDataType.INT: "value_int", - CustomField.FieldDataType.FLOAT: "value_float", - CustomField.FieldDataType.MONETARY: "value_monetary", - CustomField.FieldDataType.DOCUMENTLINK: "value_document_ids", - CustomField.FieldDataType.SELECT: "value_select", - } # An instance is attached to a document document: Document = validated_data["document"] # And to a CustomField 
custom_field: CustomField = validated_data["field"] # This key must exist, as it is validated - data_store_name = type_to_data_store_name_map[custom_field.data_type] + data_store_name = CustomFieldInstance.get_value_field_name( + custom_field.data_type, + ) if custom_field.data_type == CustomField.FieldDataType.DOCUMENTLINK: # prior to update so we can look for any docs that are going to be removed diff --git a/src/documents/tests/test_api_filter_by_custom_fields.py b/src/documents/tests/test_api_filter_by_custom_fields.py new file mode 100644 index 000000000..9ab96d068 --- /dev/null +++ b/src/documents/tests/test_api_filter_by_custom_fields.py @@ -0,0 +1,670 @@ +import json +import re +from datetime import date +from typing import Callable +from unittest.mock import Mock +from urllib.parse import quote + +import pytest +from django.contrib.auth.models import User +from rest_framework.test import APITestCase + +from documents.models import CustomField +from documents.models import Document +from documents.serialisers import DocumentSerializer +from documents.tests.utils import DirectoriesMixin +from paperless import settings + + +class DocumentWrapper: + """ + Allows Pythonic access to the custom fields associated with the wrapped document. + """ + + def __init__(self, document: Document) -> None: + self._document = document + + def __contains__(self, custom_field: str) -> bool: + return self._document.custom_fields.filter(field__name=custom_field).exists() + + def __getitem__(self, custom_field: str): + return self._document.custom_fields.get(field__name=custom_field).value + + +def string_expr_opted_in(op): + return op in settings.CUSTOM_FIELD_LOOKUP_OPT_IN + + +class TestDocumentSearchApi(DirectoriesMixin, APITestCase): + def setUp(self): + super().setUp() + + self.user = User.objects.create_superuser(username="temp_admin") + self.client.force_authenticate(user=self.user) + + # Create one custom field per type. The fields are called f"{type}_field". 
+ self.custom_fields = {} + for data_type in CustomField.FieldDataType.values: + name = data_type + "_field" + self.custom_fields[name] = CustomField.objects.create( + name=name, + data_type=data_type, + ) + + # Add some options to the select_field + select = self.custom_fields["select_field"] + select.extra_data = {"select_options": ["A", "B", "C"]} + select.save() + + # Now we will create some test documents + self.documents = [] + + # CustomField.FieldDataType.STRING + self._create_document(string_field=None) + self._create_document(string_field="") + self._create_document(string_field="paperless") + self._create_document(string_field="Paperless") + self._create_document(string_field="PAPERLESS") + self._create_document(string_field="pointless") + self._create_document(string_field="pointy") + + # CustomField.FieldDataType.URL + self._create_document(url_field=None) + self._create_document(url_field="") + self._create_document(url_field="https://docs.paperless-ngx.com/") + self._create_document(url_field="https://www.django-rest-framework.org/") + self._create_document(url_field="http://example.com/") + + # A document to check if the filter correctly associates field names with values. + # E.g., ["url_field", "exact", "https://docs.paperless-ngx.com/"] should not + # yield this document. 
+ self._create_document( + string_field="https://docs.paperless-ngx.com/", + url_field="http://example.com/", + ) + + # CustomField.FieldDataType.DATE + self._create_document(date_field=None) + self._create_document(date_field=date(2023, 8, 22)) + self._create_document(date_field=date(2024, 8, 22)) + self._create_document(date_field=date(2024, 11, 15)) + + # CustomField.FieldDataType.BOOL + self._create_document(boolean_field=None) + self._create_document(boolean_field=True) + self._create_document(boolean_field=False) + + # CustomField.FieldDataType.INT + self._create_document(integer_field=None) + self._create_document(integer_field=-1) + self._create_document(integer_field=0) + self._create_document(integer_field=1) + + # CustomField.FieldDataType.FLOAT + self._create_document(float_field=None) + self._create_document(float_field=-1e9) + self._create_document(float_field=0.05) + self._create_document(float_field=270.0) + + # CustomField.FieldDataType.MONETARY + self._create_document(monetary_field=None) + self._create_document(monetary_field="USD100.00") + self._create_document(monetary_field="USD1.00") + self._create_document(monetary_field="EUR50.00") + + # CustomField.FieldDataType.DOCUMENTLINK + self._create_document(documentlink_field=None) + self._create_document(documentlink_field=[]) + self._create_document( + documentlink_field=[ + self.documents[0].id, + self.documents[1].id, + self.documents[2].id, + ], + ) + self._create_document( + documentlink_field=[self.documents[4].id, self.documents[5].id], + ) + + # CustomField.FieldDataType.SELECT + self._create_document(select_field=None) + self._create_document(select_field=0) + self._create_document(select_field=1) + self._create_document(select_field=2) + + def _create_document(self, **kwargs): + title = str(kwargs) + document = Document.objects.create( + title=title, + checksum=title, + archive_serial_number=len(self.documents) + 1, + ) + data = { + "custom_fields": [ + {"field": 
self.custom_fields[name].id, "value": value} + for name, value in kwargs.items() + ], + } + serializer = DocumentSerializer( + document, + data=data, + partial=True, + context={"request": Mock()}, + ) + serializer.is_valid(raise_exception=True) + serializer.save() + self.documents.append(document) + return document + + def _assert_query_match_predicate( + self, + query: list, + reference_predicate: Callable[[DocumentWrapper], bool], + match_nothing_ok=False, + ): + """ + Checks the results of the query against a callable reference predicate. + """ + reference_document_ids = [ + document.id + for document in self.documents + if reference_predicate(DocumentWrapper(document)) + ] + # First sanity check our test cases + if not match_nothing_ok: + self.assertTrue( + reference_document_ids, + msg="Bad test case - should match at least one document.", + ) + self.assertNotEqual( + len(reference_document_ids), + len(self.documents), + msg="Bad test case - should not match all documents.", + ) + + # Now make the API call. + query_string = quote(json.dumps(query), safe="") + response = self.client.get( + "/api/documents/?" + + "&".join( + ( + f"custom_field_lookup={query_string}", + "ordering=archive_serial_number", + "page=1", + f"page_size={len(self.documents)}", + "truncate_content=true", + ), + ), + ) + self.assertEqual(response.status_code, 200, msg=str(response.json())) + response_document_ids = [ + document["id"] for document in response.json()["results"] + ] + self.assertEqual(reference_document_ids, response_document_ids) + + def _assert_validation_error(self, query: str, path: list, keyword: str): + """ + Asserts that the query raises a validation error. + Checks the message to make sure it points to the right place. + """ + query_string = quote(query, safe="") + response = self.client.get( + "/api/documents/?" 
+ + "&".join( + ( + f"custom_field_lookup={query_string}", + "ordering=archive_serial_number", + "page=1", + f"page_size={len(self.documents)}", + "truncate_content=true", + ), + ), + ) + self.assertEqual(response.status_code, 400) + + exception_path = [] + detail = response.json() + while not isinstance(detail, list): + path_item, detail = next(iter(detail.items())) + exception_path.append(path_item) + + self.assertEqual(path, exception_path) + self.assertIn(keyword, " ".join(detail)) + + # ==========================================================# + # Sanity checks # + # ==========================================================# + def test_name_value_association(self): + """ + GIVEN: + - A document with `{"string_field": "https://docs.paperless-ngx.com/", + "url_field": "http://example.com/"}` + WHEN: + - Filtering by `["url_field", "exact", "https://docs.paperless-ngx.com/"]` + THEN: + - That document should not get matched. + """ + self._assert_query_match_predicate( + ["url_field", "exact", "https://docs.paperless-ngx.com/"], + lambda document: "url_field" in document + and document["url_field"] == "https://docs.paperless-ngx.com/", + ) + + def test_filter_by_multiple_fields(self): + """ + GIVEN: + - A document with `{"string_field": "https://docs.paperless-ngx.com/", + "url_field": "http://example.com/"}` + WHEN: + - Filtering by `['AND', [["string_field", "exists", True], ["url_field", "exists", True]]]` + THEN: + - That document should get matched. 
+ """ + self._assert_query_match_predicate( + ["AND", [["string_field", "exists", True], ["url_field", "exists", True]]], + lambda document: "url_field" in document and "string_field" in document, + ) + + # ==========================================================# + # Basic expressions supported by all custom field types # + # ==========================================================# + def test_exact(self): + self._assert_query_match_predicate( + ["string_field", "exact", "paperless"], + lambda document: "string_field" in document + and document["string_field"] == "paperless", + ) + + def test_in(self): + self._assert_query_match_predicate( + ["string_field", "in", ["paperless", "Paperless"]], + lambda document: "string_field" in document + and document["string_field"] in ("paperless", "Paperless"), + ) + + def test_isnull(self): + self._assert_query_match_predicate( + ["string_field", "isnull", True], + lambda document: "string_field" in document + and document["string_field"] is None, + ) + + def test_exists(self): + self._assert_query_match_predicate( + ["string_field", "exists", True], + lambda document: "string_field" in document, + ) + + def test_select(self): + # For select fields, you can either specify the index + # or the name of the option. They function exactly the same. 
+ self._assert_query_match_predicate( + ["select_field", "exact", 1], + lambda document: "select_field" in document + and document["select_field"] == 1, + ) + # This is the same as: + self._assert_query_match_predicate( + ["select_field", "exact", "B"], + lambda document: "select_field" in document + and document["select_field"] == 1, + ) + + # ==========================================================# + # Expressions for string, URL, and monetary fields # + # ==========================================================# + @pytest.mark.skipif( + not string_expr_opted_in("iexact"), + reason="iexact expr is disabled.", + ) + def test_iexact(self): + self._assert_query_match_predicate( + ["string_field", "iexact", "paperless"], + lambda document: "string_field" in document + and document["string_field"] is not None + and document["string_field"].lower() == "paperless", + ) + + @pytest.mark.skipif( + not string_expr_opted_in("contains"), + reason="contains expr is disabled.", + ) + def test_contains(self): + # WARNING: SQLite treats "contains" as "icontains"! + # You should avoid "contains" unless you know what you are doing! + self._assert_query_match_predicate( + ["string_field", "contains", "aper"], + lambda document: "string_field" in document + and document["string_field"] is not None + and "aper" in document["string_field"], + ) + + def test_icontains(self): + self._assert_query_match_predicate( + ["string_field", "icontains", "aper"], + lambda document: "string_field" in document + and document["string_field"] is not None + and "aper" in document["string_field"].lower(), + ) + + @pytest.mark.skipif( + not string_expr_opted_in("startswith"), + reason="startswith expr is disabled.", + ) + def test_startswith(self): + # WARNING: SQLite treats "startswith" as "istartswith"! + # You should avoid "startswith" unless you know what you are doing! 
self._assert_query_match_predicate( +            ["string_field", "startswith", "paper"], +            lambda document: "string_field" in document +            and document["string_field"] is not None +            and document["string_field"].startswith("paper"), +        ) + +    def test_istartswith(self): +        self._assert_query_match_predicate( +            ["string_field", "istartswith", "paper"], +            lambda document: "string_field" in document +            and document["string_field"] is not None +            and document["string_field"].lower().startswith("paper"), +        ) + +    @pytest.mark.skipif( +        not string_expr_opted_in("endswith"), +        reason="endswith expr is disabled.", +    ) +    def test_endswith(self): +        # WARNING: SQLite treats "endswith" as "iendswith"! +        # You should avoid "endswith" unless you know what you are doing! +        self._assert_query_match_predicate( +            ["string_field", "endswith", "less"], +            lambda document: "string_field" in document +            and document["string_field"] is not None +            and document["string_field"].endswith("less"), +        ) + +    def test_iendswith(self): +        self._assert_query_match_predicate( +            ["string_field", "iendswith", "less"], +            lambda document: "string_field" in document +            and document["string_field"] is not None +            and document["string_field"].lower().endswith("less"), +        ) + +    @pytest.mark.skipif( +        not string_expr_opted_in("regex"), +        reason="regex expr is disabled.", +    ) +    def test_regex(self): +        # WARNING: the regex syntax is database dependent! +        self._assert_query_match_predicate( +            ["string_field", "regex", r"^p.+s$"], +            lambda document: "string_field" in document +            and document["string_field"] is not None +            and re.match(r"^p.+s$", document["string_field"]), +        ) + +    @pytest.mark.skipif( +        not string_expr_opted_in("iregex"), +        reason="iregex expr is disabled.", +    ) +    def test_iregex(self): +        # WARNING: the regex syntax is database dependent! 
+        self._assert_query_match_predicate( +            ["string_field", "iregex", r"^p.+s$"], +            lambda document: "string_field" in document +            and document["string_field"] is not None +            and re.match(r"^p.+s$", document["string_field"], re.IGNORECASE), +        ) + +    def test_url_field_istartswith(self): +        # URL fields support all of the expressions above. +        # Just showing one of them here. +        self._assert_query_match_predicate( +            ["url_field", "istartswith", "http://"], +            lambda document: "url_field" in document +            and document["url_field"] is not None +            and document["url_field"].startswith("http://"), +        ) + +    @pytest.mark.skipif( +        not string_expr_opted_in("iregex"), +        reason="iregex expr is disabled.", +    ) +    def test_monetary_field_iregex(self): +        # Monetary fields support all of the expressions above. +        # Just showing one of them here. +        # +        # Unfortunately we can't do arithmetic comparisons on monetary field, +        # but you are welcome to use regex to do some of that. +        # E.g., USD between 100.00 and 999.99: +        self._assert_query_match_predicate( +            ["monetary_field", "iregex", r"USD[1-9][0-9]{2}\.[0-9]{2}"], +            lambda document: "monetary_field" in document +            and document["monetary_field"] is not None +            and re.match( +                r"USD[1-9][0-9]{2}\.[0-9]{2}", +                document["monetary_field"], +                re.IGNORECASE, +            ), +        ) + +    # ==========================================================# +    # Arithmetic comparisons # +    # ==========================================================# +    def test_gt(self): +        self._assert_query_match_predicate( +            ["date_field", "gt", date(2024, 8, 22).isoformat()], +            lambda document: "date_field" in document +            and document["date_field"] is not None +            and document["date_field"] > date(2024, 8, 22), +        ) + +    def test_gte(self): +        self._assert_query_match_predicate( +            ["date_field", "gte", date(2024, 8, 22).isoformat()], +            lambda document: "date_field" in document +            and document["date_field"] is not None +            and document["date_field"] >= date(2024, 8, 22), +        ) + +    def test_lt(self): + 
self._assert_query_match_predicate( + ["integer_field", "lt", 0], + lambda document: "integer_field" in document + and document["integer_field"] is not None + and document["integer_field"] < 0, + ) + + def test_lte(self): + self._assert_query_match_predicate( + ["integer_field", "lte", 0], + lambda document: "integer_field" in document + and document["integer_field"] is not None + and document["integer_field"] <= 0, + ) + + def test_range(self): + self._assert_query_match_predicate( + ["float_field", "range", [-0.05, 0.05]], + lambda document: "float_field" in document + and document["float_field"] is not None + and -0.05 <= document["float_field"] <= 0.05, + ) + + def test_date_modifier(self): + # For date fields you can optionally prefix the operator + # with the part of the date you are comparing with. + self._assert_query_match_predicate( + ["date_field", "year__gte", 2024], + lambda document: "date_field" in document + and document["date_field"] is not None + and document["date_field"].year >= 2024, + ) + + # ==========================================================# + # Subset check (document link field only) # + # ==========================================================# + def test_document_link_contains(self): + # Document link field "contains" performs a subset check. + self._assert_query_match_predicate( + ["documentlink_field", "contains", [1, 2]], + lambda document: "documentlink_field" in document + and document["documentlink_field"] is not None + and set(document["documentlink_field"]) >= {1, 2}, + ) + # The order of IDs don't matter - this is the same as above. + self._assert_query_match_predicate( + ["documentlink_field", "contains", [2, 1]], + lambda document: "documentlink_field" in document + and document["documentlink_field"] is not None + and set(document["documentlink_field"]) >= {1, 2}, + ) + + def test_document_link_contains_empty_set(self): + # An empty set is a subset of any set. 
+        self._assert_query_match_predicate( +            ["documentlink_field", "contains", []], +            lambda document: "documentlink_field" in document +            and document["documentlink_field"] is not None, +        ) + +    def test_document_link_contains_no_reverse_link(self): +        # An edge case is that the document in the value list +        # doesn't have a document link field and thus has no reverse link. +        self._assert_query_match_predicate( +            ["documentlink_field", "contains", [self.documents[6].id]], +            lambda document: "documentlink_field" in document +            and document["documentlink_field"] is not None +            and set(document["documentlink_field"]) >= {self.documents[6].id}, +            match_nothing_ok=True, +        ) + +    # ==========================================================# +    # Logical expressions # +    # ==========================================================# +    def test_logical_and(self): +        self._assert_query_match_predicate( +            [ +                "AND", +                [["date_field", "year__exact", 2024], ["date_field", "month__lt", 9]], +            ], +            lambda document: "date_field" in document +            and document["date_field"] is not None +            and document["date_field"].year == 2024 +            and document["date_field"].month < 9, +        ) + +    def test_logical_or(self): +        # This is also the recommended way to check for "empty" text, URL, and monetary fields. +        self._assert_query_match_predicate( +            [ +                "OR", +                [["string_field", "exact", ""], ["string_field", "isnull", True]], +            ], +            lambda document: "string_field" in document +            and not bool(document["string_field"]), +        ) + +    def test_logical_not(self): +        # This means `NOT ((document has string_field) AND (string_field exact "paperless"))`, +        # not `(document has string_field) AND (NOT (string_field exact "paperless"))`! 
+ self._assert_query_match_predicate( + [ + "NOT", + ["string_field", "exact", "paperless"], + ], + lambda document: not ( + "string_field" in document and document["string_field"] == "paperless" + ), + ) + + # ==========================================================# + # Tests for invalid queries # + # ==========================================================# + + def test_invalid_json(self): + self._assert_validation_error( + "not valid json", + ["custom_field_lookup"], + "must be valid JSON", + ) + + def test_invalid_expression(self): + self._assert_validation_error( + json.dumps("valid json but not valid expr"), + ["custom_field_lookup"], + "Invalid custom field lookup expression", + ) + + def test_invalid_custom_field_name(self): + self._assert_validation_error( + json.dumps(["invalid name", "iexact", "foo"]), + ["custom_field_lookup", "0"], + "is not a valid custom field", + ) + + def test_invalid_operator(self): + self._assert_validation_error( + json.dumps(["integer_field", "iexact", "foo"]), + ["custom_field_lookup", "1"], + "does not support lookup expr", + ) + + def test_invalid_value(self): + self._assert_validation_error( + json.dumps(["select_field", "exact", "not an option"]), + ["custom_field_lookup", "2"], + "integer", + ) + + def test_invalid_logical_operator(self): + self._assert_validation_error( + json.dumps(["invalid op", ["integer_field", "gt", 0]]), + ["custom_field_lookup", "0"], + "Invalid logical operator", + ) + + def test_invalid_expr_list(self): + self._assert_validation_error( + json.dumps(["AND", "not a list"]), + ["custom_field_lookup", "1"], + "Invalid expression list", + ) + + def test_invalid_operator_prefix(self): + self._assert_validation_error( + json.dumps(["integer_field", "foo__gt", 0]), + ["custom_field_lookup", "1"], + "does not support lookup expr", + ) + + @pytest.mark.skipif( + string_expr_opted_in("regex"), + reason="user opted into allowing regex expr", + ) + def test_disabled_operator(self): + 
self._assert_validation_error( + json.dumps(["string_field", "regex", r"^p.+s$"]), + ["custom_field_lookup", "1"], + "disabled by default", + ) + + def test_query_too_deep(self): + query = ["string_field", "exact", "paperless"] + for _ in range(10): + query = ["NOT", query] + self._assert_validation_error( + json.dumps(query), + ["custom_field_lookup", *(["1"] * 10)], + "Maximum nesting depth exceeded", + ) + + def test_query_too_many_atoms(self): + atom = ["string_field", "exact", "paperless"] + query = ["AND", [atom for _ in range(21)]] + self._assert_validation_error( + json.dumps(query), + ["custom_field_lookup", "1", "20"], + "Maximum number of query conditions exceeded", + ) diff --git a/src/paperless/settings.py b/src/paperless/settings.py index 46a697349..851fe6217 100644 --- a/src/paperless/settings.py +++ b/src/paperless/settings.py @@ -1192,6 +1192,23 @@ EMAIL_ENABLE_GPG_DECRYPTOR: Final[bool] = __get_boolean( ############################################################################### -# Soft Delete +# Soft Delete # ############################################################################### EMPTY_TRASH_DELAY = max(__get_int("PAPERLESS_EMPTY_TRASH_DELAY", 30), 1) + +############################################################################### +# custom_field_lookup Filter Settings # +############################################################################### + +CUSTOM_FIELD_LOOKUP_OPT_IN = __get_list( + "PAPERLESS_CUSTOM_FIELD_LOOKUP_OPT_IN", + default=[], +) +CUSTOM_FIELD_LOOKUP_MAX_DEPTH = __get_int( + "PAPERLESS_CUSTOM_FIELD_LOOKUP_MAX_DEPTH", + default=10, +) +CUSTOM_FIELD_LOOKUP_MAX_ATOMS = __get_int( + "PAPERLESS_CUSTOM_FIELD_LOOKUP_MAX_ATOMS", + default=20, +)