Refactor: Use django-filter logic for filtering full text search queries (#7507)

This commit is contained in:
Yichi Yang 2024-08-25 12:20:43 +08:00 committed by GitHub
parent 057ce29676
commit a0c227fe55
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 86 additions and 290 deletions

View File

@ -8,8 +8,8 @@ from datetime import timezone
from shutil import rmtree
from typing import Optional
from dateutil.parser import isoparse
from django.conf import settings
from django.db.models import QuerySet
from django.utils import timezone as django_timezone
from guardian.shortcuts import get_users_with_perms
from whoosh import classify
@ -22,6 +22,8 @@ from whoosh.fields import NUMERIC
from whoosh.fields import TEXT
from whoosh.fields import Schema
from whoosh.highlight import HtmlFormatter
from whoosh.idsets import BitSet
from whoosh.idsets import DocIdSet
from whoosh.index import FileIndex
from whoosh.index import create_in
from whoosh.index import exists_in
@ -31,6 +33,7 @@ from whoosh.qparser import QueryParser
from whoosh.qparser.dateparse import DateParserPlugin
from whoosh.qparser.dateparse import English
from whoosh.qparser.plugins import FieldsPlugin
from whoosh.reading import IndexReader
from whoosh.scoring import TF_IDF
from whoosh.searching import ResultsPage
from whoosh.searching import Searcher
@ -201,114 +204,32 @@ def remove_document_from_index(document: Document):
remove_document(writer, document)
class MappedDocIdSet(DocIdSet):
"""
A DocIdSet backed by a set of `Document` IDs.
Supports efficiently looking up if a whoosh docnum is in the provided `filter_queryset`.
"""
def __init__(self, filter_queryset: QuerySet, ixreader: IndexReader) -> None:
super().__init__()
document_ids = filter_queryset.order_by("id").values_list("id", flat=True)
max_id = document_ids.last() or 0
self.document_ids = BitSet(document_ids, size=max_id)
self.ixreader = ixreader
def __contains__(self, docnum):
document_id = self.ixreader.stored_fields(docnum)["id"]
return document_id in self.document_ids
def __bool__(self):
# searcher.search ignores a filter if it's "falsy".
# We use this hack so this DocIdSet, when used as a filter, is never ignored.
return True
class DelayedQuery:
param_map = {
"correspondent": ("correspondent", ["id", "id__in", "id__none", "isnull"]),
"document_type": ("type", ["id", "id__in", "id__none", "isnull"]),
"storage_path": ("path", ["id", "id__in", "id__none", "isnull"]),
"owner": ("owner", ["id", "id__in", "id__none", "isnull"]),
"shared_by": ("shared_by", ["id"]),
"tags": ("tag", ["id__all", "id__in", "id__none"]),
"added": ("added", ["date__lt", "date__gt"]),
"created": ("created", ["date__lt", "date__gt"]),
"checksum": ("checksum", ["icontains", "istartswith"]),
"original_filename": ("original_filename", ["icontains", "istartswith"]),
"custom_fields": (
"custom_fields",
["icontains", "istartswith", "id__all", "id__in", "id__none"],
),
}
def _get_query(self):
raise NotImplementedError
def _get_query_filter(self):
criterias = []
for key, value in self.query_params.items():
# is_tagged is a special case
if key == "is_tagged":
criterias.append(query.Term("has_tag", self.evalBoolean(value)))
continue
if key == "has_custom_fields":
criterias.append(
query.Term("has_custom_fields", self.evalBoolean(value)),
)
continue
# Don't process query params without a filter
if "__" not in key:
continue
# All other query params consist of a parameter and a query filter
param, query_filter = key.split("__", 1)
try:
field, supported_query_filters = self.param_map[param]
except KeyError:
logger.error(f"Unable to build a query filter for parameter {key}")
continue
# We only support certain filters per parameter
if query_filter not in supported_query_filters:
logger.info(
f"Query filter {query_filter} not supported for parameter {param}",
)
continue
if query_filter == "id":
if param == "shared_by":
criterias.append(query.Term("is_shared", True))
criterias.append(query.Term("owner_id", value))
else:
criterias.append(query.Term(f"{field}_id", value))
elif query_filter == "id__in":
in_filter = []
for object_id in value.split(","):
in_filter.append(
query.Term(f"{field}_id", object_id),
)
criterias.append(query.Or(in_filter))
elif query_filter == "id__none":
for object_id in value.split(","):
criterias.append(
query.Not(query.Term(f"{field}_id", object_id)),
)
elif query_filter == "isnull":
criterias.append(
query.Term(f"has_{field}", self.evalBoolean(value) is False),
)
elif query_filter == "id__all":
for object_id in value.split(","):
criterias.append(query.Term(f"{field}_id", object_id))
elif query_filter == "date__lt":
criterias.append(
query.DateRange(field, start=None, end=isoparse(value)),
)
elif query_filter == "date__gt":
criterias.append(
query.DateRange(field, start=isoparse(value), end=None),
)
elif query_filter == "icontains":
criterias.append(
query.Term(field, value),
)
elif query_filter == "istartswith":
criterias.append(
query.Prefix(field, value),
)
user_criterias = get_permissions_criterias(
user=self.user,
)
if len(criterias) > 0:
if len(user_criterias) > 0:
criterias.append(query.Or(user_criterias))
return query.And(criterias)
else:
return query.Or(user_criterias) if len(user_criterias) > 0 else None
def evalBoolean(self, val):
return val.lower() in {"true", "1"}
raise NotImplementedError # pragma: no cover
def _get_query_sortedby(self):
if "ordering" not in self.query_params:
@ -339,13 +260,19 @@ class DelayedQuery:
else:
return sort_fields_map[field], reverse
def __init__(self, searcher: Searcher, query_params, page_size, user):
def __init__(
self,
searcher: Searcher,
query_params,
page_size,
filter_queryset: QuerySet,
):
self.searcher = searcher
self.query_params = query_params
self.page_size = page_size
self.saved_results = dict()
self.first_score = None
self.user = user
self.filter_queryset = filter_queryset
def __len__(self):
page = self[0:1]
@ -361,7 +288,7 @@ class DelayedQuery:
page: ResultsPage = self.searcher.search_page(
q,
mask=mask,
filter=self._get_query_filter(),
filter=MappedDocIdSet(self.filter_queryset, self.searcher.ixreader),
pagenum=math.floor(item.start / self.page_size) + 1,
pagelen=self.page_size,
sortedby=sortedby,

View File

@ -15,6 +15,7 @@ from rest_framework.test import APITestCase
from whoosh.writing import AsyncWriter
from documents import index
from documents.bulk_edit import set_permissions
from documents.models import Correspondent
from documents.models import CustomField
from documents.models import CustomFieldInstance
@ -1159,7 +1160,8 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
[d3.id, d2.id, d1.id],
)
def test_global_search(self):
@mock.patch("documents.bulk_edit.bulk_update_documents")
def test_global_search(self, m):
"""
GIVEN:
- Multiple documents and objects
@ -1186,11 +1188,38 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
checksum="C",
pk=3,
)
# The below two documents are owned by user2 and shouldn't show up in results!
d4 = Document.objects.create(
title="doc 4 owned by user2",
content="bank bank bank bank 4",
checksum="D",
pk=4,
)
d5 = Document.objects.create(
title="doc 5 owned by user2",
content="bank bank bank bank 5",
checksum="E",
pk=5,
)
user1 = User.objects.create_user("bank user1")
user2 = User.objects.create_superuser("user2")
group1 = Group.objects.create(name="bank group1")
Group.objects.create(name="group2")
user1.user_permissions.add(
*Permission.objects.filter(codename__startswith="view_").exclude(
content_type__app_label="admin",
),
)
set_permissions([4, 5], set_permissions=[], owner=user2, merge=False)
with index.open_index_writer() as writer:
index.update_document(writer, d1)
index.update_document(writer, d2)
index.update_document(writer, d3)
index.update_document(writer, d4)
index.update_document(writer, d5)
correspondent1 = Correspondent.objects.create(name="bank correspondent 1")
Correspondent.objects.create(name="correspondent 2")
@ -1200,10 +1229,7 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
StoragePath.objects.create(name="path 2", path="path2")
tag1 = Tag.objects.create(name="bank tag1")
Tag.objects.create(name="tag2")
user1 = User.objects.create_superuser("bank user1")
User.objects.create_user("user2")
group1 = Group.objects.create(name="bank group1")
Group.objects.create(name="group2")
SavedView.objects.create(
name="bank view",
show_on_dashboard=True,

View File

@ -1,8 +1,6 @@
from dateutil.parser import isoparse
from django.test import TestCase
from whoosh import query
from documents.index import DelayedQuery
from documents.index import get_permissions_criterias
from documents.models import User
@ -58,162 +56,3 @@ class TestDelayedQuery(TestCase):
)
for user, expected in tests:
self.assertEqual(get_permissions_criterias(user), expected)
def test_no_query_filters(self):
dq = DelayedQuery(None, {}, None, None)
self.assertEqual(dq._get_query_filter(), self.has_no_owner)
def test_date_query_filters(self):
def _get_testset(param: str):
date_str = "1970-01-01T02:44"
date_obj = isoparse(date_str)
return (
(
{f"{param}__date__lt": date_str},
query.And(
[
query.DateRange(param, start=None, end=date_obj),
self.has_no_owner,
],
),
),
(
{f"{param}__date__gt": date_str},
query.And(
[
query.DateRange(param, start=date_obj, end=None),
self.has_no_owner,
],
),
),
)
query_params = ["created", "added"]
for param in query_params:
for params, expected in _get_testset(param):
dq = DelayedQuery(None, params, None, None)
got = dq._get_query_filter()
self.assertCountEqual(got, expected)
def test_is_tagged_query_filter(self):
tests = (
("True", True),
("true", True),
("1", True),
("False", False),
("false", False),
("0", False),
("foo", False),
)
for param, expected in tests:
dq = DelayedQuery(None, {"is_tagged": param}, None, None)
self.assertEqual(
dq._get_query_filter(),
query.And([query.Term("has_tag", expected), self.has_no_owner]),
)
def test_tags_query_filters(self):
# tests contains tuples of query_parameter dics and the expected whoosh query
param = "tags"
field, _ = DelayedQuery.param_map[param]
tests = (
(
{f"{param}__id__all": "42,43"},
query.And(
[
query.Term(f"{field}_id", "42"),
query.Term(f"{field}_id", "43"),
self.has_no_owner,
],
),
),
# tags does not allow __id
(
{f"{param}__id": "42"},
self.has_no_owner,
),
# tags does not allow __isnull
(
{f"{param}__isnull": "true"},
self.has_no_owner,
),
self._get_testset__id__in(param, field),
self._get_testset__id__none(param, field),
)
for params, expected in tests:
dq = DelayedQuery(None, params, None, None)
got = dq._get_query_filter()
self.assertCountEqual(got, expected)
def test_generic_query_filters(self):
def _get_testset(param: str):
field, _ = DelayedQuery.param_map[param]
return (
(
{f"{param}__id": "42"},
query.And(
[
query.Term(f"{field}_id", "42"),
self.has_no_owner,
],
),
),
self._get_testset__id__in(param, field),
self._get_testset__id__none(param, field),
(
{f"{param}__isnull": "true"},
query.And(
[
query.Term(f"has_{field}", False),
self.has_no_owner,
],
),
),
(
{f"{param}__isnull": "false"},
query.And(
[
query.Term(f"has_{field}", True),
self.has_no_owner,
],
),
),
)
query_params = ["correspondent", "document_type", "storage_path", "owner"]
for param in query_params:
for params, expected in _get_testset(param):
dq = DelayedQuery(None, params, None, None)
got = dq._get_query_filter()
self.assertCountEqual(got, expected)
def test_char_query_filter(self):
def _get_testset(param: str):
return (
(
{f"{param}__icontains": "foo"},
query.And(
[
query.Term(f"{param}", "foo"),
self.has_no_owner,
],
),
),
(
{f"{param}__istartswith": "foo"},
query.And(
[
query.Prefix(f"{param}", "foo"),
self.has_no_owner,
],
),
),
)
query_params = ["checksum", "original_filename"]
for param in query_params:
for params, expected in _get_testset(param):
dq = DelayedQuery(None, params, None, None)
got = dq._get_query_filter()
self.assertCountEqual(got, expected)

View File

@ -852,6 +852,8 @@ class UnifiedSearchViewSet(DocumentViewSet):
)
def filter_queryset(self, queryset):
filtered_queryset = super().filter_queryset(queryset)
if self._is_search_request():
from documents import index
@ -866,10 +868,10 @@ class UnifiedSearchViewSet(DocumentViewSet):
self.searcher,
self.request.query_params,
self.paginator.get_page_size(self.request),
self.request.user,
filter_queryset=filtered_queryset,
)
else:
return super().filter_queryset(queryset)
return filtered_queryset
def list(self, request, *args, **kwargs):
if self._is_search_request():
@ -1199,14 +1201,16 @@ class GlobalSearchView(PassUserMixin):
from documents import index
with index.open_index_searcher() as s:
q, _ = index.DelayedFullTextQuery(
fts_query = index.DelayedFullTextQuery(
s,
request.query_params,
10,
request.user,
)._get_query()
results = s.search(q, limit=OBJECT_LIMIT)
docs = docs | all_docs.filter(id__in=[r["id"] for r in results])
OBJECT_LIMIT,
filter_queryset=all_docs,
)
results = fts_query[0:1]
docs = docs | Document.objects.filter(
id__in=[r["id"] for r in results],
)
docs = docs[:OBJECT_LIMIT]
saved_views = (
SavedView.objects.filter(owner=request.user, name__icontains=query)
@ -1452,12 +1456,12 @@ class StatisticsView(APIView):
{
"documents_total": documents_total,
"documents_inbox": documents_inbox,
"inbox_tag": inbox_tags.first().pk
if inbox_tags.exists()
else None, # backwards compatibility
"inbox_tags": [tag.pk for tag in inbox_tags]
if inbox_tags.exists()
else None,
"inbox_tag": (
inbox_tags.first().pk if inbox_tags.exists() else None
), # backwards compatibility
"inbox_tags": (
[tag.pk for tag in inbox_tags] if inbox_tags.exists() else None
),
"document_file_type_counts": document_file_type_counts,
"character_count": character_count,
"tag_count": len(tags),