From 0e5ab7f3e0ccaf03fbcd90f9b484756f7f1b289b Mon Sep 17 00:00:00 2001 From: shamoon <4887959+shamoon@users.noreply.github.com> Date: Mon, 17 Nov 2025 12:47:55 -0800 Subject: [PATCH] Fix: support for custom field ordering w advanced search (#11383) --- src/documents/index.py | 90 +++++++++++++++++++++++--- src/documents/tests/test_api_search.py | 59 +++++++++++++++++ 2 files changed, 140 insertions(+), 9 deletions(-) diff --git a/src/documents/index.py b/src/documents/index.py index 90cbb8000..9446c7db1 100644 --- a/src/documents/index.py +++ b/src/documents/index.py @@ -287,15 +287,75 @@ class DelayedQuery: self.first_score = None self.filter_queryset = filter_queryset self.suggested_correction = None + self._manual_hits_cache: list | None = None def __len__(self) -> int: + if self._manual_sort_requested(): + manual_hits = self._manual_hits() + return len(manual_hits) + page = self[0:1] return len(page) + def _manual_sort_requested(self): + ordering = self.query_params.get("ordering", "") + return ordering.lstrip("-").startswith("custom_field_") + + def _manual_hits(self): + if self._manual_hits_cache is None: + q, mask, suggested_correction = self._get_query() + self.suggested_correction = suggested_correction + + results = self.searcher.search( + q, + mask=mask, + filter=MappedDocIdSet(self.filter_queryset, self.searcher.ixreader), + limit=None, + ) + results.fragmenter = highlight.ContextFragmenter(surround=50) + results.formatter = HtmlFormatter(tagname="span", between=" ... ") + + if not self.first_score and len(results) > 0: + self.first_score = results[0].score + + if self.first_score: + results.top_n = [ + ( + (hit[0] / self.first_score) if self.first_score else None, + hit[1], + ) + for hit in results.top_n + ] + + hits_by_id = {hit["id"]: hit for hit in results} + matching_ids = list(hits_by_id.keys()) + + ordered_ids = list( + self.filter_queryset.filter(id__in=matching_ids).values_list( + "id", + flat=True, + ), + ) + ordered_ids = list(dict.fromkeys(ordered_ids)) + + self._manual_hits_cache = [ + hits_by_id[_id] for _id in ordered_ids if _id in hits_by_id + ] + return self._manual_hits_cache + def __getitem__(self, item): if item.start in self.saved_results: return self.saved_results[item.start] + if self._manual_sort_requested(): + manual_hits = self._manual_hits() + start = 0 if item.start is None else item.start + stop = item.stop + hits = manual_hits[start:stop] if stop is not None else manual_hits[start:] + page = ManualResultsPage(hits) + self.saved_results[start] = page + return page + q, mask, suggested_correction = self._get_query() self.suggested_correction = suggested_correction sortedby, reverse = self._get_query_sortedby() @@ -315,21 +375,33 @@ class DelayedQuery: if not self.first_score and len(page.results) > 0 and sortedby is None: self.first_score = page.results[0].score - page.results.top_n = list( - map( - lambda hit: ( - (hit[0] / self.first_score) if self.first_score else None, - hit[1], - ), - page.results.top_n, - ), - ) + page.results.top_n = [ + ( + (hit[0] / self.first_score) if self.first_score else None, + hit[1], + ) + for hit in page.results.top_n + ] self.saved_results[item.start] = page return page +class ManualResultsPage(list): + def __init__(self, hits): + super().__init__(hits) + self.results = ManualResults(hits) + + +class ManualResults: + def __init__(self, hits): + self._docnums = [hit.docnum for hit in hits] + + def docs(self): + return self._docnums + + class LocalDateParser(English): def reverse_timezone_offset(self, d): return (d.replace(tzinfo=django_timezone.get_current_timezone())).astimezone( diff --git a/src/documents/tests/test_api_search.py b/src/documents/tests/test_api_search.py index 8f316c145..5a2fc9b52 100644 --- a/src/documents/tests/test_api_search.py +++ b/src/documents/tests/test_api_search.py @@ -89,6 +89,65 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase): self.assertEqual(len(results), 0) self.assertCountEqual(response.data["all"], []) + def test_search_custom_field_ordering(self): + custom_field = CustomField.objects.create( + name="Sortable field", + data_type=CustomField.FieldDataType.INT, + ) + d1 = Document.objects.create( + title="first", + content="match", + checksum="A1", + ) + d2 = Document.objects.create( + title="second", + content="match", + checksum="B2", + ) + d3 = Document.objects.create( + title="third", + content="match", + checksum="C3", + ) + CustomFieldInstance.objects.create( + document=d1, + field=custom_field, + value_int=30, + ) + CustomFieldInstance.objects.create( + document=d2, + field=custom_field, + value_int=10, + ) + CustomFieldInstance.objects.create( + document=d3, + field=custom_field, + value_int=20, + ) + + with AsyncWriter(index.open_index()) as writer: + index.update_document(writer, d1) + index.update_document(writer, d2) + index.update_document(writer, d3) + + response = self.client.get( + f"/api/documents/?query=match&ordering=custom_field_{custom_field.pk}", + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual( + [doc["id"] for doc in response.data["results"]], + [d2.id, d3.id, d1.id], + ) + + response = self.client.get( + f"/api/documents/?query=match&ordering=-custom_field_{custom_field.pk}", + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual( + [doc["id"] for doc in response.data["results"]], + [d1.id, d3.id, d2.id], + ) + def test_search_multi_page(self): with AsyncWriter(index.open_index()) as writer: for i in range(55):