Fix: handle whoosh query correction errors (#10121)

This commit is contained in:
shamoon 2025-06-05 08:57:25 -07:00 committed by GitHub
parent 422bffe1a6
commit 51e6eed72a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 52 additions and 13 deletions

View File

@ -281,6 +281,7 @@ class DelayedQuery:
self.saved_results = dict() self.saved_results = dict()
self.first_score = None self.first_score = None
self.filter_queryset = filter_queryset self.filter_queryset = filter_queryset
self.suggested_correction = None
def __len__(self) -> int: def __len__(self) -> int:
page = self[0:1] page = self[0:1]
@ -290,7 +291,8 @@ class DelayedQuery:
if item.start in self.saved_results: if item.start in self.saved_results:
return self.saved_results[item.start] return self.saved_results[item.start]
q, mask = self._get_query() q, mask, suggested_correction = self._get_query()
self.suggested_correction = suggested_correction
sortedby, reverse = self._get_query_sortedby() sortedby, reverse = self._get_query_sortedby()
page: ResultsPage = self.searcher.search_page( page: ResultsPage = self.searcher.search_page(
@ -361,12 +363,19 @@ class DelayedFullTextQuery(DelayedQuery):
), ),
) )
q = qp.parse(q_str) q = qp.parse(q_str)
suggested_correction = None
try:
corrected = self.searcher.correct_query(q, q_str) corrected = self.searcher.correct_query(q, q_str)
if corrected.query != q: if corrected.string != q_str:
corrected.query = corrected.string suggested_correction = corrected.string
except Exception as e:
logger.info(
"Error while correcting query %s: %s",
f"{q_str!r}",
e,
)
return q, None return q, None, suggested_correction
class DelayedMoreLikeThisQuery(DelayedQuery): class DelayedMoreLikeThisQuery(DelayedQuery):
@ -387,7 +396,7 @@ class DelayedMoreLikeThisQuery(DelayedQuery):
) )
mask: set = {docnum} mask: set = {docnum}
return q, mask return q, mask, None
def autocomplete( def autocomplete(

View File

@ -2,7 +2,6 @@ import datetime
from datetime import timedelta from datetime import timedelta
from unittest import mock from unittest import mock
import pytest
from dateutil.relativedelta import relativedelta from dateutil.relativedelta import relativedelta
from django.contrib.auth.models import Group from django.contrib.auth.models import Group
from django.contrib.auth.models import Permission from django.contrib.auth.models import Permission
@ -623,8 +622,7 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data[0], b"auto") self.assertEqual(response.data[0], b"auto")
@pytest.mark.skip(reason="Not implemented yet") def test_search_spelling_suggestion(self):
def test_search_spelling_correction(self):
with AsyncWriter(index.open_index()) as writer: with AsyncWriter(index.open_index()) as writer:
for i in range(55): for i in range(55):
doc = Document.objects.create( doc = Document.objects.create(
@ -635,16 +633,36 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
) )
index.update_document(writer, doc) index.update_document(writer, doc)
response = self.client.get("/api/search/?query=thing") response = self.client.get("/api/documents/?query=thing")
correction = response.data["corrected_query"] correction = response.data["corrected_query"]
self.assertEqual(correction, "things") self.assertEqual(correction, "things")
response = self.client.get("/api/search/?query=things") response = self.client.get("/api/documents/?query=things")
correction = response.data["corrected_query"] correction = response.data["corrected_query"]
self.assertEqual(correction, None) self.assertEqual(correction, None)
@mock.patch(
"whoosh.searching.Searcher.correct_query",
side_effect=Exception("Test error"),
)
def test_corrected_query_error(self, mock_correct_query):
"""
GIVEN:
- A query that raises an error on correction
WHEN:
- API request for search with that query
THEN:
- The error is logged and the search proceeds
"""
with self.assertLogs("paperless.index", level="INFO") as cm:
response = self.client.get("/api/documents/?query=2025-06-04")
self.assertEqual(response.status_code, status.HTTP_200_OK)
error_str = cm.output[0]
expected_str = "Error while correcting query '2025-06-04': Test error"
self.assertIn(expected_str, error_str)
def test_search_more_like(self): def test_search_more_like(self):
""" """
GIVEN: GIVEN:

View File

@ -1126,7 +1126,19 @@ class UnifiedSearchViewSet(DocumentViewSet):
try: try:
with index.open_index_searcher() as s: with index.open_index_searcher() as s:
self.searcher = s self.searcher = s
return super().list(request) queryset = self.filter_queryset(self.get_queryset())
page = self.paginate_queryset(queryset)
serializer = self.get_serializer(page, many=True)
response = self.get_paginated_response(serializer.data)
response.data["corrected_query"] = (
queryset.suggested_correction
if hasattr(queryset, "suggested_correction")
else None
)
return response
except NotFound: except NotFound:
raise raise
except Exception as e: except Exception as e: