sorting for full text queries

This commit is contained in:
jonaswinkler 2021-05-15 13:58:11 +02:00
parent 814d90745b
commit 8ee2e8b23d
2 changed files with 62 additions and 22 deletions

View File

@ -7,7 +7,7 @@ from dateutil.parser import isoparse
from django.conf import settings from django.conf import settings
from whoosh import highlight, classify, query from whoosh import highlight, classify, query
from whoosh.fields import Schema, TEXT, NUMERIC, KEYWORD, DATETIME, BOOLEAN from whoosh.fields import Schema, TEXT, NUMERIC, KEYWORD, DATETIME, BOOLEAN
from whoosh.highlight import Formatter, get_text, HtmlFormatter from whoosh.highlight import HtmlFormatter
from whoosh.index import create_in, exists_in, open_dir from whoosh.index import create_in, exists_in, open_dir
from whoosh.qparser import MultifieldParser from whoosh.qparser import MultifieldParser
from whoosh.qparser.dateparse import DateParserPlugin from whoosh.qparser.dateparse import DateParserPlugin
@ -147,12 +147,10 @@ def remove_document_from_index(document):
class DelayedQuery: class DelayedQuery:
@property def _get_query(self):
def _query(self):
raise NotImplementedError() raise NotImplementedError()
@property def _get_query_filter(self):
def _query_filter(self):
criterias = [] criterias = []
for k, v in self.query_params.items(): for k, v in self.query_params.items():
if k == 'correspondent__id': if k == 'correspondent__id':
@ -185,16 +183,33 @@ class DelayedQuery:
else: else:
return None return None
@property def _get_query_sortedby(self):
def _query_sortedby(self): if not 'ordering' in self.query_params:
# if not 'ordering' in self.query_params: return None, False
return None, False
# o: str = self.query_params['ordering'] field: str = self.query_params['ordering']
# if o.startswith('-'):
# return o[1:], True sort_fields_map = {
# else: "created": "created",
# return o, False "modified": "modified",
"added": "added",
"title": "title",
"correspondent__name": "correspondent",
"document_type__name": "type",
"archive_serial_number": "asn",
"score": None,
}
if field.startswith('-'):
field = field[1:]
reverse = True
else:
reverse = False
if field not in sort_fields_map:
return None, False
else:
return sort_fields_map[field], reverse
def __init__(self, searcher: Searcher, query_params, page_size): def __init__(self, searcher: Searcher, query_params, page_size):
self.searcher = searcher self.searcher = searcher
@ -211,13 +226,13 @@ class DelayedQuery:
if item.start in self.saved_results: if item.start in self.saved_results:
return self.saved_results[item.start] return self.saved_results[item.start]
q, mask = self._query q, mask = self._get_query()
sortedby, reverse = self._query_sortedby sortedby, reverse = self._get_query_sortedby()
page: ResultsPage = self.searcher.search_page( page: ResultsPage = self.searcher.search_page(
q, q,
mask=mask, mask=mask,
filter=self._query_filter, filter=self._get_query_filter(),
pagenum=math.floor(item.start / self.page_size) + 1, pagenum=math.floor(item.start / self.page_size) + 1,
pagelen=self.page_size, pagelen=self.page_size,
sortedby=sortedby, sortedby=sortedby,
@ -227,7 +242,9 @@ class DelayedQuery:
surround=50) surround=50)
page.results.formatter = HtmlFormatter(tagname="span", between=" ... ") page.results.formatter = HtmlFormatter(tagname="span", between=" ... ")
if not self.first_score and len(page.results) > 0: if (not self.first_score and
len(page.results) > 0 and
sortedby is None):
self.first_score = page.results[0].score self.first_score = page.results[0].score
if self.first_score: if self.first_score:
@ -243,8 +260,7 @@ class DelayedQuery:
class DelayedFullTextQuery(DelayedQuery): class DelayedFullTextQuery(DelayedQuery):
@property def _get_query(self):
def _query(self):
q_str = self.query_params['query'] q_str = self.query_params['query']
qp = MultifieldParser( qp = MultifieldParser(
["content", "title", "correspondent", "tag", "type"], ["content", "title", "correspondent", "tag", "type"],
@ -261,8 +277,7 @@ class DelayedFullTextQuery(DelayedQuery):
class DelayedMoreLikeThisQuery(DelayedQuery): class DelayedMoreLikeThisQuery(DelayedQuery):
@property def _get_query(self):
def _query(self):
more_like_doc_id = int(self.query_params['more_like_id']) more_like_doc_id = int(self.query_params['more_like_id'])
content = Document.objects.get(id=more_like_doc_id).content content = Document.objects.get(id=more_like_doc_id).content

View File

@ -471,6 +471,31 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
self.assertNotIn(d5.id, search_query("&added__date__lt=" + datetime.datetime(2020, 1, 2).strftime("%Y-%m-%d"))) self.assertNotIn(d5.id, search_query("&added__date__lt=" + datetime.datetime(2020, 1, 2).strftime("%Y-%m-%d")))
self.assertIn(d5.id, search_query("&added__date__gt=" + datetime.datetime(2020, 1, 2).strftime("%Y-%m-%d"))) self.assertIn(d5.id, search_query("&added__date__gt=" + datetime.datetime(2020, 1, 2).strftime("%Y-%m-%d")))
def test_search_sorting(self):
c1 = Correspondent.objects.create(name="corres Ax")
c2 = Correspondent.objects.create(name="corres Cx")
c3 = Correspondent.objects.create(name="corres Bx")
d1 = Document.objects.create(checksum="1", correspondent=c1, content="test", archive_serial_number=2, title="3")
d2 = Document.objects.create(checksum="2", correspondent=c2, content="test", archive_serial_number=3, title="2")
d3 = Document.objects.create(checksum="3", correspondent=c3, content="test", archive_serial_number=1, title="1")
with AsyncWriter(index.open_index()) as writer:
for doc in Document.objects.all():
index.update_document(writer, doc)
def search_query(q):
r = self.client.get("/api/documents/?query=test" + q)
self.assertEqual(r.status_code, 200)
return [hit['id'] for hit in r.data['results']]
self.assertListEqual(search_query("&ordering=archive_serial_number"), [d3.id, d1.id, d2.id])
self.assertListEqual(search_query("&ordering=-archive_serial_number"), [d2.id, d1.id, d3.id])
self.assertListEqual(search_query("&ordering=title"), [d3.id, d2.id, d1.id])
self.assertListEqual(search_query("&ordering=-title"), [d1.id, d2.id, d3.id])
self.assertListEqual(search_query("&ordering=correspondent__name"), [d1.id, d3.id, d2.id])
self.assertListEqual(search_query("&ordering=-correspondent__name"), [d2.id, d3.id, d1.id])
def test_statistics(self): def test_statistics(self):
doc1 = Document.objects.create(title="none1", checksum="A") doc1 = Document.objects.create(title="none1", checksum="A")