diff --git a/src/documents/index.py b/src/documents/index.py
index 6cd136d80..90f896199 100644
--- a/src/documents/index.py
+++ b/src/documents/index.py
@@ -7,7 +7,7 @@ from dateutil.parser import isoparse
 from django.conf import settings
 from whoosh import highlight, classify, query
 from whoosh.fields import Schema, TEXT, NUMERIC, KEYWORD, DATETIME, BOOLEAN
-from whoosh.highlight import Formatter, get_text, HtmlFormatter
+from whoosh.highlight import HtmlFormatter
 from whoosh.index import create_in, exists_in, open_dir
 from whoosh.qparser import MultifieldParser
 from whoosh.qparser.dateparse import DateParserPlugin
@@ -147,12 +147,10 @@ def remove_document_from_index(document):
 
 class DelayedQuery:
 
-    @property
-    def _query(self):
+    def _get_query(self):
         raise NotImplementedError()
 
-    @property
-    def _query_filter(self):
+    def _get_query_filter(self):
         criterias = []
         for k, v in self.query_params.items():
             if k == 'correspondent__id':
@@ -185,16 +183,44 @@ class DelayedQuery:
         else:
             return None
 
-    @property
-    def _query_sortedby(self):
-        # if not 'ordering' in self.query_params:
-        return None, False
+    def _get_query_sortedby(self):
+        """Translate the ``ordering`` query parameter into sort arguments.
+
+        Returns a ``(sortedby, reverse)`` tuple. ``sortedby`` is the index
+        field name mapped from the requested ordering, or ``None`` when no
+        ordering was requested, the field is unknown, or the ordering is
+        ``score``. ``reverse`` is ``True`` only when a known ordering was
+        prefixed with ``-``.
+        """
+        if 'ordering' not in self.query_params:
+            return None, False
 
-        # o: str = self.query_params['ordering']
-        # if o.startswith('-'):
-        #     return o[1:], True
-        # else:
-        #     return o, False
+        field: str = self.query_params['ordering']
+
+        # Maps API ordering names to the index field names used for sorting.
+        sort_fields_map = {
+            "created": "created",
+            "modified": "modified",
+            "added": "added",
+            "title": "title",
+            "correspondent__name": "correspondent",
+            "document_type__name": "type",
+            "archive_serial_number": "asn",
+            # NOTE(review): "-score" yields (None, True); confirm that a
+            # reversed default ordering is intended and suits first_score.
+            "score": None,
+        }
+
+        if field.startswith('-'):
+            field = field[1:]
+            reverse = True
+        else:
+            reverse = False
+
+        if field not in sort_fields_map:
+            return None, False
+        else:
+            return sort_fields_map[field], reverse
 
     def __init__(self, searcher: Searcher, query_params, page_size):
         self.searcher = searcher
@@ -211,13 +237,13 @@ class DelayedQuery:
         if item.start in self.saved_results:
             return self.saved_results[item.start]
 
-        q, mask = self._query
-        sortedby, reverse = self._query_sortedby
+        q, mask = self._get_query()
+        sortedby, reverse = self._get_query_sortedby()
 
         page: ResultsPage = self.searcher.search_page(
             q,
             mask=mask,
-            filter=self._query_filter,
+            filter=self._get_query_filter(),
             pagenum=math.floor(item.start / self.page_size) + 1,
             pagelen=self.page_size,
             sortedby=sortedby,
@@ -227,7 +253,9 @@ class DelayedQuery:
             surround=50)
         page.results.formatter = HtmlFormatter(tagname="span", between=" ... ")
 
-        if not self.first_score and len(page.results) > 0:
+        if (not self.first_score and
+                len(page.results) > 0 and
+                sortedby is None):
             self.first_score = page.results[0].score
 
         if self.first_score:
@@ -243,8 +271,7 @@ class DelayedQuery:
 
 class DelayedFullTextQuery(DelayedQuery):
 
-    @property
-    def _query(self):
+    def _get_query(self):
         q_str = self.query_params['query']
         qp = MultifieldParser(
             ["content", "title", "correspondent", "tag", "type"],
@@ -261,8 +288,7 @@ class DelayedFullTextQuery(DelayedQuery):
 
 class DelayedMoreLikeThisQuery(DelayedQuery):
 
-    @property
-    def _query(self):
+    def _get_query(self):
         more_like_doc_id = int(self.query_params['more_like_id'])
         content = Document.objects.get(id=more_like_doc_id).content
 
diff --git a/src/documents/tests/test_api.py b/src/documents/tests/test_api.py
index cfde28e2d..2f8dc18da 100644
--- a/src/documents/tests/test_api.py
+++ b/src/documents/tests/test_api.py
@@ -471,6 +471,32 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
         self.assertNotIn(d5.id, search_query("&added__date__lt=" + datetime.datetime(2020, 1, 2).strftime("%Y-%m-%d")))
         self.assertIn(d5.id, search_query("&added__date__gt=" + datetime.datetime(2020, 1, 2).strftime("%Y-%m-%d")))
 
+    def test_search_sorting(self):
+        # All three documents contain "test", so every query below matches
+        # all of them and only the ordering parameter varies.
+        c1 = Correspondent.objects.create(name="corres Ax")
+        c2 = Correspondent.objects.create(name="corres Cx")
+        c3 = Correspondent.objects.create(name="corres Bx")
+        d1 = Document.objects.create(checksum="1", correspondent=c1, content="test", archive_serial_number=2, title="3")
+        d2 = Document.objects.create(checksum="2", correspondent=c2, content="test", archive_serial_number=3, title="2")
+        d3 = Document.objects.create(checksum="3", correspondent=c3, content="test", archive_serial_number=1, title="1")
+
+        with AsyncWriter(index.open_index()) as writer:
+            for doc in Document.objects.all():
+                index.update_document(writer, doc)
+
+        def search_query(q):
+            r = self.client.get("/api/documents/?query=test" + q)
+            self.assertEqual(r.status_code, 200)
+            return [hit['id'] for hit in r.data['results']]
+
+        self.assertListEqual(search_query("&ordering=archive_serial_number"), [d3.id, d1.id, d2.id])
+        self.assertListEqual(search_query("&ordering=-archive_serial_number"), [d2.id, d1.id, d3.id])
+        self.assertListEqual(search_query("&ordering=title"), [d3.id, d2.id, d1.id])
+        self.assertListEqual(search_query("&ordering=-title"), [d1.id, d2.id, d3.id])
+        self.assertListEqual(search_query("&ordering=correspondent__name"), [d1.id, d3.id, d2.id])
+        self.assertListEqual(search_query("&ordering=-correspondent__name"), [d2.id, d3.id, d1.id])
+
     def test_statistics(self):
 
         doc1 = Document.objects.create(title="none1", checksum="A")