mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-04-02 13:45:10 -05:00
sorting for full text queries
This commit is contained in:
parent
814d90745b
commit
8ee2e8b23d
@ -7,7 +7,7 @@ from dateutil.parser import isoparse
|
|||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from whoosh import highlight, classify, query
|
from whoosh import highlight, classify, query
|
||||||
from whoosh.fields import Schema, TEXT, NUMERIC, KEYWORD, DATETIME, BOOLEAN
|
from whoosh.fields import Schema, TEXT, NUMERIC, KEYWORD, DATETIME, BOOLEAN
|
||||||
from whoosh.highlight import Formatter, get_text, HtmlFormatter
|
from whoosh.highlight import HtmlFormatter
|
||||||
from whoosh.index import create_in, exists_in, open_dir
|
from whoosh.index import create_in, exists_in, open_dir
|
||||||
from whoosh.qparser import MultifieldParser
|
from whoosh.qparser import MultifieldParser
|
||||||
from whoosh.qparser.dateparse import DateParserPlugin
|
from whoosh.qparser.dateparse import DateParserPlugin
|
||||||
@ -147,12 +147,10 @@ def remove_document_from_index(document):
|
|||||||
|
|
||||||
class DelayedQuery:
|
class DelayedQuery:
|
||||||
|
|
||||||
@property
|
def _get_query(self):
|
||||||
def _query(self):
|
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@property
|
def _get_query_filter(self):
|
||||||
def _query_filter(self):
|
|
||||||
criterias = []
|
criterias = []
|
||||||
for k, v in self.query_params.items():
|
for k, v in self.query_params.items():
|
||||||
if k == 'correspondent__id':
|
if k == 'correspondent__id':
|
||||||
@ -185,16 +183,33 @@ class DelayedQuery:
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@property
|
def _get_query_sortedby(self):
|
||||||
def _query_sortedby(self):
|
if not 'ordering' in self.query_params:
|
||||||
# if not 'ordering' in self.query_params:
|
return None, False
|
||||||
return None, False
|
|
||||||
|
|
||||||
# o: str = self.query_params['ordering']
|
field: str = self.query_params['ordering']
|
||||||
# if o.startswith('-'):
|
|
||||||
# return o[1:], True
|
sort_fields_map = {
|
||||||
# else:
|
"created": "created",
|
||||||
# return o, False
|
"modified": "modified",
|
||||||
|
"added": "added",
|
||||||
|
"title": "title",
|
||||||
|
"correspondent__name": "correspondent",
|
||||||
|
"document_type__name": "type",
|
||||||
|
"archive_serial_number": "asn",
|
||||||
|
"score": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
if field.startswith('-'):
|
||||||
|
field = field[1:]
|
||||||
|
reverse = True
|
||||||
|
else:
|
||||||
|
reverse = False
|
||||||
|
|
||||||
|
if field not in sort_fields_map:
|
||||||
|
return None, False
|
||||||
|
else:
|
||||||
|
return sort_fields_map[field], reverse
|
||||||
|
|
||||||
def __init__(self, searcher: Searcher, query_params, page_size):
|
def __init__(self, searcher: Searcher, query_params, page_size):
|
||||||
self.searcher = searcher
|
self.searcher = searcher
|
||||||
@ -211,13 +226,13 @@ class DelayedQuery:
|
|||||||
if item.start in self.saved_results:
|
if item.start in self.saved_results:
|
||||||
return self.saved_results[item.start]
|
return self.saved_results[item.start]
|
||||||
|
|
||||||
q, mask = self._query
|
q, mask = self._get_query()
|
||||||
sortedby, reverse = self._query_sortedby
|
sortedby, reverse = self._get_query_sortedby()
|
||||||
|
|
||||||
page: ResultsPage = self.searcher.search_page(
|
page: ResultsPage = self.searcher.search_page(
|
||||||
q,
|
q,
|
||||||
mask=mask,
|
mask=mask,
|
||||||
filter=self._query_filter,
|
filter=self._get_query_filter(),
|
||||||
pagenum=math.floor(item.start / self.page_size) + 1,
|
pagenum=math.floor(item.start / self.page_size) + 1,
|
||||||
pagelen=self.page_size,
|
pagelen=self.page_size,
|
||||||
sortedby=sortedby,
|
sortedby=sortedby,
|
||||||
@ -227,7 +242,9 @@ class DelayedQuery:
|
|||||||
surround=50)
|
surround=50)
|
||||||
page.results.formatter = HtmlFormatter(tagname="span", between=" ... ")
|
page.results.formatter = HtmlFormatter(tagname="span", between=" ... ")
|
||||||
|
|
||||||
if not self.first_score and len(page.results) > 0:
|
if (not self.first_score and
|
||||||
|
len(page.results) > 0 and
|
||||||
|
sortedby is None):
|
||||||
self.first_score = page.results[0].score
|
self.first_score = page.results[0].score
|
||||||
|
|
||||||
if self.first_score:
|
if self.first_score:
|
||||||
@ -243,8 +260,7 @@ class DelayedQuery:
|
|||||||
|
|
||||||
class DelayedFullTextQuery(DelayedQuery):
|
class DelayedFullTextQuery(DelayedQuery):
|
||||||
|
|
||||||
@property
|
def _get_query(self):
|
||||||
def _query(self):
|
|
||||||
q_str = self.query_params['query']
|
q_str = self.query_params['query']
|
||||||
qp = MultifieldParser(
|
qp = MultifieldParser(
|
||||||
["content", "title", "correspondent", "tag", "type"],
|
["content", "title", "correspondent", "tag", "type"],
|
||||||
@ -261,8 +277,7 @@ class DelayedFullTextQuery(DelayedQuery):
|
|||||||
|
|
||||||
class DelayedMoreLikeThisQuery(DelayedQuery):
|
class DelayedMoreLikeThisQuery(DelayedQuery):
|
||||||
|
|
||||||
@property
|
def _get_query(self):
|
||||||
def _query(self):
|
|
||||||
more_like_doc_id = int(self.query_params['more_like_id'])
|
more_like_doc_id = int(self.query_params['more_like_id'])
|
||||||
content = Document.objects.get(id=more_like_doc_id).content
|
content = Document.objects.get(id=more_like_doc_id).content
|
||||||
|
|
||||||
|
@ -471,6 +471,31 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
|
|||||||
self.assertNotIn(d5.id, search_query("&added__date__lt=" + datetime.datetime(2020, 1, 2).strftime("%Y-%m-%d")))
|
self.assertNotIn(d5.id, search_query("&added__date__lt=" + datetime.datetime(2020, 1, 2).strftime("%Y-%m-%d")))
|
||||||
self.assertIn(d5.id, search_query("&added__date__gt=" + datetime.datetime(2020, 1, 2).strftime("%Y-%m-%d")))
|
self.assertIn(d5.id, search_query("&added__date__gt=" + datetime.datetime(2020, 1, 2).strftime("%Y-%m-%d")))
|
||||||
|
|
||||||
|
def test_search_sorting(self):
|
||||||
|
c1 = Correspondent.objects.create(name="corres Ax")
|
||||||
|
c2 = Correspondent.objects.create(name="corres Cx")
|
||||||
|
c3 = Correspondent.objects.create(name="corres Bx")
|
||||||
|
d1 = Document.objects.create(checksum="1", correspondent=c1, content="test", archive_serial_number=2, title="3")
|
||||||
|
d2 = Document.objects.create(checksum="2", correspondent=c2, content="test", archive_serial_number=3, title="2")
|
||||||
|
d3 = Document.objects.create(checksum="3", correspondent=c3, content="test", archive_serial_number=1, title="1")
|
||||||
|
|
||||||
|
with AsyncWriter(index.open_index()) as writer:
|
||||||
|
for doc in Document.objects.all():
|
||||||
|
index.update_document(writer, doc)
|
||||||
|
|
||||||
|
def search_query(q):
|
||||||
|
r = self.client.get("/api/documents/?query=test" + q)
|
||||||
|
self.assertEqual(r.status_code, 200)
|
||||||
|
return [hit['id'] for hit in r.data['results']]
|
||||||
|
|
||||||
|
self.assertListEqual(search_query("&ordering=archive_serial_number"), [d3.id, d1.id, d2.id])
|
||||||
|
self.assertListEqual(search_query("&ordering=-archive_serial_number"), [d2.id, d1.id, d3.id])
|
||||||
|
self.assertListEqual(search_query("&ordering=title"), [d3.id, d2.id, d1.id])
|
||||||
|
self.assertListEqual(search_query("&ordering=-title"), [d1.id, d2.id, d3.id])
|
||||||
|
self.assertListEqual(search_query("&ordering=correspondent__name"), [d1.id, d3.id, d2.id])
|
||||||
|
self.assertListEqual(search_query("&ordering=-correspondent__name"), [d2.id, d3.id, d1.id])
|
||||||
|
|
||||||
|
|
||||||
def test_statistics(self):
|
def test_statistics(self):
|
||||||
|
|
||||||
doc1 = Document.objects.create(title="none1", checksum="A")
|
doc1 = Document.objects.create(title="none1", checksum="A")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user