mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-11-21 04:36:53 -06:00
Fix: support for custom field ordering w advanced search (#11383)
This commit is contained in:
@@ -287,15 +287,75 @@ class DelayedQuery:
|
|||||||
self.first_score = None
|
self.first_score = None
|
||||||
self.filter_queryset = filter_queryset
|
self.filter_queryset = filter_queryset
|
||||||
self.suggested_correction = None
|
self.suggested_correction = None
|
||||||
|
self._manual_hits_cache: list | None = None
|
||||||
|
|
||||||
def __len__(self) -> int:
    """
    Return the number of results this query reports.

    When ordering by a custom field was requested, this is the size of
    the manually ordered hit list. Otherwise the first page is fetched
    through ``__getitem__`` and its length is returned.
    """
    if not self._manual_sort_requested():
        first_page = self[0:1]
        return len(first_page)

    return len(self._manual_hits())
|
||||||
|
|
||||||
|
def _manual_sort_requested(self):
    """
    Tell whether the request asks for ordering by a custom field.

    Reads the ``ordering`` query parameter (empty string when absent),
    drops any leading ``-`` characters and checks whether what remains
    starts with ``custom_field_``.
    """
    requested = self.query_params.get("ordering", "")
    field_name = requested.lstrip("-")
    return field_name.startswith("custom_field_")
|
||||||
|
|
||||||
|
def _manual_hits(self):
    """
    Return all search hits, arranged in the order ``filter_queryset`` yields them.

    The full (unlimited) search is run once and the resulting list is kept
    in ``_manual_hits_cache``; later calls return that same list. As side
    effects of the first call, ``suggested_correction`` is refreshed and
    ``first_score`` is set if it was not already.
    """
    if self._manual_hits_cache is not None:
        return self._manual_hits_cache

    q, mask, suggested_correction = self._get_query()
    self.suggested_correction = suggested_correction

    # limit=None: every match is needed because the order is decided
    # afterwards, outside of the index.
    all_results = self.searcher.search(
        q,
        mask=mask,
        filter=MappedDocIdSet(self.filter_queryset, self.searcher.ixreader),
        limit=None,
    )
    all_results.fragmenter = highlight.ContextFragmenter(surround=50)
    all_results.formatter = HtmlFormatter(tagname="span", between=" ... ")

    if not self.first_score and len(all_results) > 0:
        self.first_score = all_results[0].score

    # Scale every score relative to the first recorded score; skipped
    # entirely when there is no (non-zero) first score to divide by.
    top_score = self.first_score
    if top_score:
        all_results.top_n = [
            (entry[0] / top_score, entry[1]) for entry in all_results.top_n
        ]

    hits_by_id = {}
    for hit in all_results:
        hits_by_id[hit["id"]] = hit
    matching_ids = list(hits_by_id)

    # NOTE(review): presumably filter_queryset already carries the requested
    # custom field ordering, so its id order is the desired result order.
    database_ids = self.filter_queryset.filter(id__in=matching_ids).values_list(
        "id",
        flat=True,
    )
    # dict.fromkeys keeps the first occurrence of each id and drops repeats.
    ordered_ids = list(dict.fromkeys(database_ids))

    self._manual_hits_cache = [
        hits_by_id[doc_id] for doc_id in ordered_ids if doc_id in hits_by_id
    ]
    return self._manual_hits_cache
|
||||||
|
|
||||||
def __getitem__(self, item):
|
def __getitem__(self, item):
|
||||||
if item.start in self.saved_results:
|
if item.start in self.saved_results:
|
||||||
return self.saved_results[item.start]
|
return self.saved_results[item.start]
|
||||||
|
|
||||||
|
if self._manual_sort_requested():
|
||||||
|
manual_hits = self._manual_hits()
|
||||||
|
start = 0 if item.start is None else item.start
|
||||||
|
stop = item.stop
|
||||||
|
hits = manual_hits[start:stop] if stop is not None else manual_hits[start:]
|
||||||
|
page = ManualResultsPage(hits)
|
||||||
|
self.saved_results[start] = page
|
||||||
|
return page
|
||||||
|
|
||||||
q, mask, suggested_correction = self._get_query()
|
q, mask, suggested_correction = self._get_query()
|
||||||
self.suggested_correction = suggested_correction
|
self.suggested_correction = suggested_correction
|
||||||
sortedby, reverse = self._get_query_sortedby()
|
sortedby, reverse = self._get_query_sortedby()
|
||||||
@@ -315,21 +375,33 @@ class DelayedQuery:
|
|||||||
if not self.first_score and len(page.results) > 0 and sortedby is None:
|
if not self.first_score and len(page.results) > 0 and sortedby is None:
|
||||||
self.first_score = page.results[0].score
|
self.first_score = page.results[0].score
|
||||||
|
|
||||||
page.results.top_n = list(
|
page.results.top_n = [
|
||||||
map(
|
(
|
||||||
lambda hit: (
|
(hit[0] / self.first_score) if self.first_score else None,
|
||||||
(hit[0] / self.first_score) if self.first_score else None,
|
hit[1],
|
||||||
hit[1],
|
)
|
||||||
),
|
for hit in page.results.top_n
|
||||||
page.results.top_n,
|
]
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
self.saved_results[item.start] = page
|
self.saved_results[item.start] = page
|
||||||
|
|
||||||
return page
|
return page
|
||||||
|
|
||||||
|
|
||||||
|
class ManualResultsPage(list):
    """
    One page of manually ordered search hits.

    Behaves as a plain list of the given hits and additionally exposes a
    ``results`` attribute built from those same hits.
    """

    def __init__(self, hits):
        """
        Store ``hits`` as the list contents and build ``results`` from them.

        ``hits`` may be any iterable of hit objects, including a one-shot
        iterator.
        """
        super().__init__(hits)
        # Build from the stored list rather than from ``hits``: if ``hits``
        # is an iterator, ``list.__init__`` has already exhausted it and
        # ``results`` would otherwise end up empty.
        self.results = ManualResults(self)
|
||||||
|
|
||||||
|
|
||||||
|
class ManualResults:
    """
    Holds the document numbers of a sequence of hits, in the given order.
    """

    def __init__(self, hits):
        """Record the ``docnum`` attribute of every hit in ``hits``."""
        collected = []
        for entry in hits:
            collected.append(entry.docnum)
        self._docnums = collected

    def docs(self):
        """Return the recorded document numbers as a list."""
        return self._docnums
|
||||||
|
|
||||||
|
|
||||||
class LocalDateParser(English):
|
class LocalDateParser(English):
|
||||||
def reverse_timezone_offset(self, d):
|
def reverse_timezone_offset(self, d):
|
||||||
return (d.replace(tzinfo=django_timezone.get_current_timezone())).astimezone(
|
return (d.replace(tzinfo=django_timezone.get_current_timezone())).astimezone(
|
||||||
|
|||||||
@@ -89,6 +89,65 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
|||||||
self.assertEqual(len(results), 0)
|
self.assertEqual(len(results), 0)
|
||||||
self.assertCountEqual(response.data["all"], [])
|
self.assertCountEqual(response.data["all"], [])
|
||||||
|
|
||||||
|
def test_search_custom_field_ordering(self):
    """
    Search results for a text query can be ordered by an integer custom
    field, both ascending and descending.
    """
    custom_field = CustomField.objects.create(
        name="Sortable field",
        data_type=CustomField.FieldDataType.INT,
    )

    # (title, checksum, custom field value) -- the values are deliberately
    # not in creation order so that sorting is observable.
    fixtures = [
        ("first", "A1", 30),
        ("second", "B2", 10),
        ("third", "C3", 20),
    ]
    documents = [
        Document.objects.create(
            title=title,
            content="match",
            checksum=checksum,
        )
        for title, checksum, _ in fixtures
    ]
    for document, (_, _, value) in zip(documents, fixtures):
        CustomFieldInstance.objects.create(
            document=document,
            field=custom_field,
            value_int=value,
        )
    d1, d2, d3 = documents

    with AsyncWriter(index.open_index()) as writer:
        for document in documents:
            index.update_document(writer, document)

    # Ascending: 10 (d2), 20 (d3), 30 (d1)
    ascending = self.client.get(
        f"/api/documents/?query=match&ordering=custom_field_{custom_field.pk}",
    )
    self.assertEqual(ascending.status_code, status.HTTP_200_OK)
    ascending_ids = [doc["id"] for doc in ascending.data["results"]]
    self.assertEqual(ascending_ids, [d2.id, d3.id, d1.id])

    # Descending: 30 (d1), 20 (d3), 10 (d2)
    descending = self.client.get(
        f"/api/documents/?query=match&ordering=-custom_field_{custom_field.pk}",
    )
    self.assertEqual(descending.status_code, status.HTTP_200_OK)
    descending_ids = [doc["id"] for doc in descending.data["results"]]
    self.assertEqual(descending_ids, [d1.id, d3.id, d2.id])
|
||||||
|
|
||||||
def test_search_multi_page(self):
|
def test_search_multi_page(self):
|
||||||
with AsyncWriter(index.open_index()) as writer:
|
with AsyncWriter(index.open_index()) as writer:
|
||||||
for i in range(55):
|
for i in range(55):
|
||||||
|
|||||||
Reference in New Issue
Block a user