diff --git a/src/documents/index.py b/src/documents/index.py index 3a2b2cb58..10de04245 100644 --- a/src/documents/index.py +++ b/src/documents/index.py @@ -281,6 +281,7 @@ class DelayedQuery: self.saved_results = dict() self.first_score = None self.filter_queryset = filter_queryset + self.suggested_correction = None def __len__(self) -> int: page = self[0:1] @@ -290,7 +291,8 @@ class DelayedQuery: if item.start in self.saved_results: return self.saved_results[item.start] - q, mask = self._get_query() + q, mask, suggested_correction = self._get_query() + self.suggested_correction = suggested_correction sortedby, reverse = self._get_query_sortedby() page: ResultsPage = self.searcher.search_page( @@ -361,12 +363,19 @@ class DelayedFullTextQuery(DelayedQuery): ), ) q = qp.parse(q_str) + suggested_correction = None + try: + corrected = self.searcher.correct_query(q, q_str) + if corrected.string != q_str: + suggested_correction = corrected.string + except Exception as e: + logger.info( + "Error while correcting query %s: %s", + f"{q_str!r}", + e, + ) - corrected = self.searcher.correct_query(q, q_str) - if corrected.query != q: - corrected.query = corrected.string - - return q, None + return q, None, suggested_correction class DelayedMoreLikeThisQuery(DelayedQuery): @@ -387,7 +396,7 @@ class DelayedMoreLikeThisQuery(DelayedQuery): ) mask: set = {docnum} - return q, mask + return q, mask, None def autocomplete( diff --git a/src/documents/tests/test_api_search.py b/src/documents/tests/test_api_search.py index 7ffce06de..8f316c145 100644 --- a/src/documents/tests/test_api_search.py +++ b/src/documents/tests/test_api_search.py @@ -2,7 +2,6 @@ import datetime from datetime import timedelta from unittest import mock -import pytest from dateutil.relativedelta import relativedelta from django.contrib.auth.models import Group from django.contrib.auth.models import Permission @@ -623,8 +622,7 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase): self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data[0], b"auto") - @pytest.mark.skip(reason="Not implemented yet") - def test_search_spelling_correction(self): + def test_search_spelling_suggestion(self): with AsyncWriter(index.open_index()) as writer: for i in range(55): doc = Document.objects.create( @@ -635,16 +633,36 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase): ) index.update_document(writer, doc) - response = self.client.get("/api/search/?query=thing") + response = self.client.get("/api/documents/?query=thing") correction = response.data["corrected_query"] self.assertEqual(correction, "things") - response = self.client.get("/api/search/?query=things") + response = self.client.get("/api/documents/?query=things") correction = response.data["corrected_query"] self.assertEqual(correction, None) + @mock.patch( + "whoosh.searching.Searcher.correct_query", + side_effect=Exception("Test error"), + ) + def test_corrected_query_error(self, mock_correct_query): + """ + GIVEN: + - A query that raises an error on correction + WHEN: + - API request for search with that query + THEN: + - The error is logged and the search proceeds + """ + with self.assertLogs("paperless.index", level="INFO") as cm: + response = self.client.get("/api/documents/?query=2025-06-04") + self.assertEqual(response.status_code, status.HTTP_200_OK) + error_str = cm.output[0] + expected_str = "Error while correcting query '2025-06-04': Test error" + self.assertIn(expected_str, error_str) + def test_search_more_like(self): """ GIVEN: diff --git a/src/documents/views.py b/src/documents/views.py index 4cd100c2d..6e7f814b3 100644 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -1126,7 +1126,19 @@ class UnifiedSearchViewSet(DocumentViewSet): try: with index.open_index_searcher() as s: self.searcher = s - return super().list(request) + queryset = self.filter_queryset(self.get_queryset()) + page = self.paginate_queryset(queryset) + + serializer = self.get_serializer(page, many=True) + response = self.get_paginated_response(serializer.data) + + response.data["corrected_query"] = ( + queryset.suggested_correction + if hasattr(queryset, "suggested_correction") + else None + ) + + return response except NotFound: raise except Exception as e: