From a0c227fe55400a3cf09ca51492cd18d69302cf6f Mon Sep 17 00:00:00 2001 From: Yichi Yang Date: Sun, 25 Aug 2024 12:20:43 +0800 Subject: [PATCH] Refactor: Use django-filter logic for filtering full text search queries (#7507) --- src/documents/index.py | 147 ++++++--------------- src/documents/tests/test_api_search.py | 36 ++++- src/documents/tests/test_delayedquery.py | 161 ----------------------- src/documents/views.py | 32 +++-- 4 files changed, 86 insertions(+), 290 deletions(-) diff --git a/src/documents/index.py b/src/documents/index.py index 98c43d1e8..d95a80213 100644 --- a/src/documents/index.py +++ b/src/documents/index.py @@ -8,8 +8,8 @@ from datetime import timezone from shutil import rmtree from typing import Optional -from dateutil.parser import isoparse from django.conf import settings +from django.db.models import QuerySet from django.utils import timezone as django_timezone from guardian.shortcuts import get_users_with_perms from whoosh import classify @@ -22,6 +22,8 @@ from whoosh.fields import NUMERIC from whoosh.fields import TEXT from whoosh.fields import Schema from whoosh.highlight import HtmlFormatter +from whoosh.idsets import BitSet +from whoosh.idsets import DocIdSet from whoosh.index import FileIndex from whoosh.index import create_in from whoosh.index import exists_in @@ -31,6 +33,7 @@ from whoosh.qparser import QueryParser from whoosh.qparser.dateparse import DateParserPlugin from whoosh.qparser.dateparse import English from whoosh.qparser.plugins import FieldsPlugin +from whoosh.reading import IndexReader from whoosh.scoring import TF_IDF from whoosh.searching import ResultsPage from whoosh.searching import Searcher @@ -201,114 +204,32 @@ def remove_document_from_index(document: Document): remove_document(writer, document) +class MappedDocIdSet(DocIdSet): + """ + A DocIdSet backed by a set of `Document` IDs. + Supports efficiently looking up if a whoosh docnum is in the provided `filter_queryset`. + """ + + def __init__(self, filter_queryset: QuerySet, ixreader: IndexReader) -> None: + super().__init__() + document_ids = filter_queryset.order_by("id").values_list("id", flat=True) + max_id = document_ids.last() or 0 + self.document_ids = BitSet(document_ids, size=max_id) + self.ixreader = ixreader + + def __contains__(self, docnum): + document_id = self.ixreader.stored_fields(docnum)["id"] + return document_id in self.document_ids + + def __bool__(self): + # searcher.search ignores a filter if it's "falsy". + # We use this hack so this DocIdSet, when used as a filter, is never ignored. + return True + + class DelayedQuery: - param_map = { - "correspondent": ("correspondent", ["id", "id__in", "id__none", "isnull"]), - "document_type": ("type", ["id", "id__in", "id__none", "isnull"]), - "storage_path": ("path", ["id", "id__in", "id__none", "isnull"]), - "owner": ("owner", ["id", "id__in", "id__none", "isnull"]), - "shared_by": ("shared_by", ["id"]), - "tags": ("tag", ["id__all", "id__in", "id__none"]), - "added": ("added", ["date__lt", "date__gt"]), - "created": ("created", ["date__lt", "date__gt"]), - "checksum": ("checksum", ["icontains", "istartswith"]), - "original_filename": ("original_filename", ["icontains", "istartswith"]), - "custom_fields": ( - "custom_fields", - ["icontains", "istartswith", "id__all", "id__in", "id__none"], - ), - } - def _get_query(self): - raise NotImplementedError - - def _get_query_filter(self): - criterias = [] - for key, value in self.query_params.items(): - # is_tagged is a special case - if key == "is_tagged": - criterias.append(query.Term("has_tag", self.evalBoolean(value))) - continue - - if key == "has_custom_fields": - criterias.append( - query.Term("has_custom_fields", self.evalBoolean(value)), - ) - continue - - # Don't process query params without a filter - if "__" not in key: - continue - - # All other query params consist of a parameter and a query filter - param, query_filter = key.split("__", 1) - try: - field, supported_query_filters = self.param_map[param] - except KeyError: - logger.error(f"Unable to build a query filter for parameter {key}") - continue - - # We only support certain filters per parameter - if query_filter not in supported_query_filters: - logger.info( - f"Query filter {query_filter} not supported for parameter {param}", - ) - continue - - if query_filter == "id": - if param == "shared_by": - criterias.append(query.Term("is_shared", True)) - criterias.append(query.Term("owner_id", value)) - else: - criterias.append(query.Term(f"{field}_id", value)) - elif query_filter == "id__in": - in_filter = [] - for object_id in value.split(","): - in_filter.append( - query.Term(f"{field}_id", object_id), - ) - criterias.append(query.Or(in_filter)) - elif query_filter == "id__none": - for object_id in value.split(","): - criterias.append( - query.Not(query.Term(f"{field}_id", object_id)), - ) - elif query_filter == "isnull": - criterias.append( - query.Term(f"has_{field}", self.evalBoolean(value) is False), - ) - elif query_filter == "id__all": - for object_id in value.split(","): - criterias.append(query.Term(f"{field}_id", object_id)) - elif query_filter == "date__lt": - criterias.append( - query.DateRange(field, start=None, end=isoparse(value)), - ) - elif query_filter == "date__gt": - criterias.append( - query.DateRange(field, start=isoparse(value), end=None), - ) - elif query_filter == "icontains": - criterias.append( - query.Term(field, value), - ) - elif query_filter == "istartswith": - criterias.append( - query.Prefix(field, value), - ) - - user_criterias = get_permissions_criterias( - user=self.user, - ) - if len(criterias) > 0: - if len(user_criterias) > 0: - criterias.append(query.Or(user_criterias)) - return query.And(criterias) - else: - return query.Or(user_criterias) if len(user_criterias) > 0 else None - - def evalBoolean(self, val): - return val.lower() in {"true", "1"} + raise NotImplementedError # pragma: no cover def _get_query_sortedby(self): if "ordering" not in self.query_params: @@ -339,13 +260,19 @@ class DelayedQuery: else: return sort_fields_map[field], reverse - def __init__(self, searcher: Searcher, query_params, page_size, user): + def __init__( + self, + searcher: Searcher, + query_params, + page_size, + filter_queryset: QuerySet, + ): self.searcher = searcher self.query_params = query_params self.page_size = page_size self.saved_results = dict() self.first_score = None - self.user = user + self.filter_queryset = filter_queryset def __len__(self): page = self[0:1] @@ -361,7 +288,7 @@ class DelayedQuery: page: ResultsPage = self.searcher.search_page( q, mask=mask, - filter=self._get_query_filter(), + filter=MappedDocIdSet(self.filter_queryset, self.searcher.ixreader), pagenum=math.floor(item.start / self.page_size) + 1, pagelen=self.page_size, sortedby=sortedby, diff --git a/src/documents/tests/test_api_search.py b/src/documents/tests/test_api_search.py index c10d6c1bb..e524e7b91 100644 --- a/src/documents/tests/test_api_search.py +++ b/src/documents/tests/test_api_search.py @@ -15,6 +15,7 @@ from rest_framework.test import APITestCase from whoosh.writing import AsyncWriter from documents import index +from documents.bulk_edit import set_permissions from documents.models import Correspondent from documents.models import CustomField from documents.models import CustomFieldInstance @@ -1159,7 +1160,8 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase): [d3.id, d2.id, d1.id], ) - def test_global_search(self): + @mock.patch("documents.bulk_edit.bulk_update_documents") + def test_global_search(self, m): """ GIVEN: - Multiple documents and objects @@ -1186,11 +1188,38 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase): checksum="C", pk=3, ) + # The below two documents are owned by user2 and shouldn't show up in results! + d4 = Document.objects.create( + title="doc 4 owned by user2", + content="bank bank bank bank 4", + checksum="D", + pk=4, + ) + d5 = Document.objects.create( + title="doc 5 owned by user2", + content="bank bank bank bank 5", + checksum="E", + pk=5, + ) + + user1 = User.objects.create_user("bank user1") + user2 = User.objects.create_superuser("user2") + group1 = Group.objects.create(name="bank group1") + Group.objects.create(name="group2") + + user1.user_permissions.add( + *Permission.objects.filter(codename__startswith="view_").exclude( + content_type__app_label="admin", + ), + ) + set_permissions([4, 5], set_permissions=[], owner=user2, merge=False) with index.open_index_writer() as writer: index.update_document(writer, d1) index.update_document(writer, d2) index.update_document(writer, d3) + index.update_document(writer, d4) + index.update_document(writer, d5) correspondent1 = Correspondent.objects.create(name="bank correspondent 1") Correspondent.objects.create(name="correspondent 2") @@ -1200,10 +1229,7 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase): StoragePath.objects.create(name="path 2", path="path2") tag1 = Tag.objects.create(name="bank tag1") Tag.objects.create(name="tag2") - user1 = User.objects.create_superuser("bank user1") - User.objects.create_user("user2") - group1 = Group.objects.create(name="bank group1") - Group.objects.create(name="group2") + SavedView.objects.create( name="bank view", show_on_dashboard=True, diff --git a/src/documents/tests/test_delayedquery.py b/src/documents/tests/test_delayedquery.py index b0dfc2ed2..1895bd6c6 100644 --- a/src/documents/tests/test_delayedquery.py +++ b/src/documents/tests/test_delayedquery.py @@ -1,8 +1,6 @@ -from dateutil.parser import isoparse from django.test import TestCase from whoosh import query -from documents.index import DelayedQuery from documents.index import get_permissions_criterias from documents.models import User @@ -58,162 +56,3 @@ class TestDelayedQuery(TestCase): ) for user, expected in tests: self.assertEqual(get_permissions_criterias(user), expected) - - def test_no_query_filters(self): - dq = DelayedQuery(None, {}, None, None) - self.assertEqual(dq._get_query_filter(), self.has_no_owner) - - def test_date_query_filters(self): - def _get_testset(param: str): - date_str = "1970-01-01T02:44" - date_obj = isoparse(date_str) - return ( - ( - {f"{param}__date__lt": date_str}, - query.And( - [ - query.DateRange(param, start=None, end=date_obj), - self.has_no_owner, - ], - ), - ), - ( - {f"{param}__date__gt": date_str}, - query.And( - [ - query.DateRange(param, start=date_obj, end=None), - self.has_no_owner, - ], - ), - ), - ) - - query_params = ["created", "added"] - for param in query_params: - for params, expected in _get_testset(param): - dq = DelayedQuery(None, params, None, None) - got = dq._get_query_filter() - self.assertCountEqual(got, expected) - - def test_is_tagged_query_filter(self): - tests = ( - ("True", True), - ("true", True), - ("1", True), - ("False", False), - ("false", False), - ("0", False), - ("foo", False), - ) - for param, expected in tests: - dq = DelayedQuery(None, {"is_tagged": param}, None, None) - self.assertEqual( - dq._get_query_filter(), - query.And([query.Term("has_tag", expected), self.has_no_owner]), - ) - - def test_tags_query_filters(self): - # tests contains tuples of query_parameter dics and the expected whoosh query - param = "tags" - field, _ = DelayedQuery.param_map[param] - tests = ( - ( - {f"{param}__id__all": "42,43"}, - query.And( - [ - query.Term(f"{field}_id", "42"), - query.Term(f"{field}_id", "43"), - self.has_no_owner, - ], - ), - ), - # tags does not allow __id - ( - {f"{param}__id": "42"}, - self.has_no_owner, - ), - # tags does not allow __isnull - ( - {f"{param}__isnull": "true"}, - self.has_no_owner, - ), - self._get_testset__id__in(param, field), - self._get_testset__id__none(param, field), - ) - - for params, expected in tests: - dq = DelayedQuery(None, params, None, None) - got = dq._get_query_filter() - self.assertCountEqual(got, expected) - - def test_generic_query_filters(self): - def _get_testset(param: str): - field, _ = DelayedQuery.param_map[param] - return ( - ( - {f"{param}__id": "42"}, - query.And( - [ - query.Term(f"{field}_id", "42"), - self.has_no_owner, - ], - ), - ), - self._get_testset__id__in(param, field), - self._get_testset__id__none(param, field), - ( - {f"{param}__isnull": "true"}, - query.And( - [ - query.Term(f"has_{field}", False), - self.has_no_owner, - ], - ), - ), - ( - {f"{param}__isnull": "false"}, - query.And( - [ - query.Term(f"has_{field}", True), - self.has_no_owner, - ], - ), - ), - ) - - query_params = ["correspondent", "document_type", "storage_path", "owner"] - for param in query_params: - for params, expected in _get_testset(param): - dq = DelayedQuery(None, params, None, None) - got = dq._get_query_filter() - self.assertCountEqual(got, expected) - - def test_char_query_filter(self): - def _get_testset(param: str): - return ( - ( - {f"{param}__icontains": "foo"}, - query.And( - [ - query.Term(f"{param}", "foo"), - self.has_no_owner, - ], - ), - ), - ( - {f"{param}__istartswith": "foo"}, - query.And( - [ - query.Prefix(f"{param}", "foo"), - self.has_no_owner, - ], - ), - ), - ) - - query_params = ["checksum", "original_filename"] - for param in query_params: - for params, expected in _get_testset(param): - dq = DelayedQuery(None, params, None, None) - got = dq._get_query_filter() - self.assertCountEqual(got, expected) diff --git a/src/documents/views.py b/src/documents/views.py index df54546e1..c0ceef4a3 100644 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -852,6 +852,8 @@ class UnifiedSearchViewSet(DocumentViewSet): ) def filter_queryset(self, queryset): + filtered_queryset = super().filter_queryset(queryset) + if self._is_search_request(): from documents import index @@ -866,10 +868,10 @@ class UnifiedSearchViewSet(DocumentViewSet): self.searcher, self.request.query_params, self.paginator.get_page_size(self.request), - self.request.user, + filter_queryset=filtered_queryset, ) else: - return super().filter_queryset(queryset) + return filtered_queryset def list(self, request, *args, **kwargs): if self._is_search_request(): @@ -1199,14 +1201,16 @@ class GlobalSearchView(PassUserMixin): from documents import index with index.open_index_searcher() as s: - q, _ = index.DelayedFullTextQuery( + fts_query = index.DelayedFullTextQuery( s, request.query_params, - 10, - request.user, - )._get_query() - results = s.search(q, limit=OBJECT_LIMIT) - docs = docs | all_docs.filter(id__in=[r["id"] for r in results]) + OBJECT_LIMIT, + filter_queryset=all_docs, + ) + results = fts_query[0:1] + docs = docs | Document.objects.filter( + id__in=[r["id"] for r in results], + ) docs = docs[:OBJECT_LIMIT] saved_views = ( SavedView.objects.filter(owner=request.user, name__icontains=query) @@ -1452,12 +1456,12 @@ class StatisticsView(APIView): { "documents_total": documents_total, "documents_inbox": documents_inbox, - "inbox_tag": inbox_tags.first().pk - if inbox_tags.exists() - else None, # backwards compatibility - "inbox_tags": [tag.pk for tag in inbox_tags] - if inbox_tags.exists() - else None, + "inbox_tag": ( + inbox_tags.first().pk if inbox_tags.exists() else None + ), # backwards compatibility + "inbox_tags": ( + [tag.pk for tag in inbox_tags] if inbox_tags.exists() else None + ), "document_file_type_counts": document_file_type_counts, "character_count": character_count, "tag_count": len(tags),