mirror of
				https://github.com/paperless-ngx/paperless-ngx.git
				synced 2025-10-30 03:56:23 -05:00 
			
		
		
		
	sorting for full text queries
This commit is contained in:
		| @@ -7,7 +7,7 @@ from dateutil.parser import isoparse | ||||
| from django.conf import settings | ||||
| from whoosh import highlight, classify, query | ||||
| from whoosh.fields import Schema, TEXT, NUMERIC, KEYWORD, DATETIME, BOOLEAN | ||||
| from whoosh.highlight import Formatter, get_text, HtmlFormatter | ||||
| from whoosh.highlight import HtmlFormatter | ||||
| from whoosh.index import create_in, exists_in, open_dir | ||||
| from whoosh.qparser import MultifieldParser | ||||
| from whoosh.qparser.dateparse import DateParserPlugin | ||||
| @@ -147,12 +147,10 @@ def remove_document_from_index(document): | ||||
|  | ||||
| class DelayedQuery: | ||||
|  | ||||
|     @property | ||||
|     def _query(self): | ||||
|     def _get_query(self): | ||||
|         raise NotImplementedError() | ||||
|  | ||||
|     @property | ||||
|     def _query_filter(self): | ||||
|     def _get_query_filter(self): | ||||
|         criterias = [] | ||||
|         for k, v in self.query_params.items(): | ||||
|             if k == 'correspondent__id': | ||||
| @@ -185,16 +183,33 @@ class DelayedQuery: | ||||
|         else: | ||||
|             return None | ||||
|  | ||||
|     @property | ||||
|     def _query_sortedby(self): | ||||
|         # if not 'ordering' in self.query_params: | ||||
|         return None, False | ||||
|     def _get_query_sortedby(self): | ||||
|         if not 'ordering' in self.query_params: | ||||
|             return None, False | ||||
|  | ||||
|         # o: str = self.query_params['ordering'] | ||||
|         # if o.startswith('-'): | ||||
|         #     return o[1:], True | ||||
|         # else: | ||||
|         #     return o, False | ||||
|         field: str = self.query_params['ordering'] | ||||
|  | ||||
|         sort_fields_map = { | ||||
|             "created": "created", | ||||
|             "modified": "modified", | ||||
|             "added": "added", | ||||
|             "title": "title", | ||||
|             "correspondent__name": "correspondent", | ||||
|             "document_type__name": "type", | ||||
|             "archive_serial_number": "asn", | ||||
|             "score": None, | ||||
|         } | ||||
|  | ||||
|         if field.startswith('-'): | ||||
|             field = field[1:] | ||||
|             reverse = True | ||||
|         else: | ||||
|             reverse = False | ||||
|  | ||||
|         if field not in sort_fields_map: | ||||
|             return None, False | ||||
|         else: | ||||
|             return sort_fields_map[field], reverse | ||||
|  | ||||
|     def __init__(self, searcher: Searcher, query_params, page_size): | ||||
|         self.searcher = searcher | ||||
| @@ -211,13 +226,13 @@ class DelayedQuery: | ||||
|         if item.start in self.saved_results: | ||||
|             return self.saved_results[item.start] | ||||
|  | ||||
|         q, mask = self._query | ||||
|         sortedby, reverse = self._query_sortedby | ||||
|         q, mask = self._get_query() | ||||
|         sortedby, reverse = self._get_query_sortedby() | ||||
|  | ||||
|         page: ResultsPage = self.searcher.search_page( | ||||
|             q, | ||||
|             mask=mask, | ||||
|             filter=self._query_filter, | ||||
|             filter=self._get_query_filter(), | ||||
|             pagenum=math.floor(item.start / self.page_size) + 1, | ||||
|             pagelen=self.page_size, | ||||
|             sortedby=sortedby, | ||||
| @@ -227,7 +242,9 @@ class DelayedQuery: | ||||
|             surround=50) | ||||
|         page.results.formatter = HtmlFormatter(tagname="span", between=" ... ") | ||||
|  | ||||
|         if not self.first_score and len(page.results) > 0: | ||||
|         if (not self.first_score and | ||||
|                 len(page.results) > 0 and | ||||
|                 sortedby is None): | ||||
|             self.first_score = page.results[0].score | ||||
|  | ||||
|         if self.first_score: | ||||
| @@ -243,8 +260,7 @@ class DelayedQuery: | ||||
|  | ||||
| class DelayedFullTextQuery(DelayedQuery): | ||||
|  | ||||
|     @property | ||||
|     def _query(self): | ||||
|     def _get_query(self): | ||||
|         q_str = self.query_params['query'] | ||||
|         qp = MultifieldParser( | ||||
|             ["content", "title", "correspondent", "tag", "type"], | ||||
| @@ -261,8 +277,7 @@ class DelayedFullTextQuery(DelayedQuery): | ||||
|  | ||||
| class DelayedMoreLikeThisQuery(DelayedQuery): | ||||
|  | ||||
|     @property | ||||
|     def _query(self): | ||||
|     def _get_query(self): | ||||
|         more_like_doc_id = int(self.query_params['more_like_id']) | ||||
|         content = Document.objects.get(id=more_like_doc_id).content | ||||
|  | ||||
|   | ||||
| @@ -471,6 +471,31 @@ class TestDocumentApi(DirectoriesMixin, APITestCase): | ||||
|         self.assertNotIn(d5.id, search_query("&added__date__lt=" + datetime.datetime(2020, 1, 2).strftime("%Y-%m-%d"))) | ||||
|         self.assertIn(d5.id, search_query("&added__date__gt=" + datetime.datetime(2020, 1, 2).strftime("%Y-%m-%d"))) | ||||
|  | ||||
|     def test_search_sorting(self): | ||||
|         c1 = Correspondent.objects.create(name="corres Ax") | ||||
|         c2 = Correspondent.objects.create(name="corres Cx") | ||||
|         c3 = Correspondent.objects.create(name="corres Bx") | ||||
|         d1 = Document.objects.create(checksum="1", correspondent=c1, content="test", archive_serial_number=2, title="3") | ||||
|         d2 = Document.objects.create(checksum="2", correspondent=c2, content="test", archive_serial_number=3, title="2") | ||||
|         d3 = Document.objects.create(checksum="3", correspondent=c3, content="test", archive_serial_number=1, title="1") | ||||
|  | ||||
|         with AsyncWriter(index.open_index()) as writer: | ||||
|             for doc in Document.objects.all(): | ||||
|                 index.update_document(writer, doc) | ||||
|  | ||||
|         def search_query(q): | ||||
|             r = self.client.get("/api/documents/?query=test" + q) | ||||
|             self.assertEqual(r.status_code, 200) | ||||
|             return [hit['id'] for hit in r.data['results']] | ||||
|  | ||||
|         self.assertListEqual(search_query("&ordering=archive_serial_number"), [d3.id, d1.id, d2.id]) | ||||
|         self.assertListEqual(search_query("&ordering=-archive_serial_number"), [d2.id, d1.id, d3.id]) | ||||
|         self.assertListEqual(search_query("&ordering=title"), [d3.id, d2.id, d1.id]) | ||||
|         self.assertListEqual(search_query("&ordering=-title"), [d1.id, d2.id, d3.id]) | ||||
|         self.assertListEqual(search_query("&ordering=correspondent__name"), [d1.id, d3.id, d2.id]) | ||||
|         self.assertListEqual(search_query("&ordering=-correspondent__name"), [d2.id, d3.id, d1.id]) | ||||
|  | ||||
|  | ||||
|     def test_statistics(self): | ||||
|  | ||||
|         doc1 = Document.objects.create(title="none1", checksum="A") | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 jonaswinkler
					jonaswinkler