lots of changes for the new unified search
documents/index.py
@@ -5,12 +5,12 @@ from contextlib import contextmanager
 import math
 from django.conf import settings
 from whoosh import highlight, classify, query
-from whoosh.fields import Schema, TEXT, NUMERIC, KEYWORD, DATETIME
-from whoosh.highlight import Formatter, get_text
+from whoosh.fields import Schema, TEXT, NUMERIC, KEYWORD, DATETIME, BOOLEAN
+from whoosh.highlight import Formatter, get_text, HtmlFormatter
 from whoosh.index import create_in, exists_in, open_dir
 from whoosh.qparser import MultifieldParser
 from whoosh.qparser.dateparse import DateParserPlugin
-from whoosh.searching import ResultsPage
+from whoosh.searching import ResultsPage, Searcher
 from whoosh.writing import AsyncWriter
 
 from documents.models import Document
@@ -18,63 +18,53 @@ from documents.models import Document
 logger = logging.getLogger("paperless.index")
 
 
-class JsonFormatter(Formatter):
-    def __init__(self):
-        self.seen = {}
-
-    def format_token(self, text, token, replace=False):
-        ttext = self._text(get_text(text, token, replace))
-        return {'text': ttext, 'highlight': 'true'}
-
-    def format_fragment(self, fragment, replace=False):
-        output = []
-        index = fragment.startchar
-        text = fragment.text
-        amend_token = None
-        for t in fragment.matches:
-            if t.startchar is None:
-                continue
-            if t.startchar < index:
-                continue
-            if t.startchar > index:
-                text_inbetween = text[index:t.startchar]
-                if amend_token and t.startchar - index < 10:
-                    amend_token['text'] += text_inbetween
-                else:
-                    output.append({'text': text_inbetween,
-                                   'highlight': False})
-                    amend_token = None
-            token = self.format_token(text, t, replace)
-            if amend_token:
-                amend_token['text'] += token['text']
-            else:
-                output.append(token)
-                amend_token = token
-            index = t.endchar
-        if index < fragment.endchar:
-            output.append({'text': text[index:fragment.endchar],
-                           'highlight': False})
-        return output
-
-    def format(self, fragments, replace=False):
-        output = []
-        for fragment in fragments:
-            output.append(self.format_fragment(fragment, replace=replace))
-        return output
-
-
 def get_schema():
     return Schema(
-        id=NUMERIC(stored=True, unique=True, numtype=int),
-        title=TEXT(stored=True),
+        id=NUMERIC(
+            stored=True,
+            unique=True
+        ),
+        title=TEXT(
+            sortable=True
+        ),
         content=TEXT(),
-        correspondent=TEXT(stored=True),
-        correspondent_id=NUMERIC(stored=True, numtype=int),
-        tag=KEYWORD(stored=True, commas=True, scorable=True, lowercase=True),
-        type=TEXT(stored=True),
-        created=DATETIME(stored=True, sortable=True),
-        modified=DATETIME(stored=True, sortable=True),
-        added=DATETIME(stored=True, sortable=True),
+        archive_serial_number=NUMERIC(
+            sortable=True
+        ),
+
+        correspondent=TEXT(
+            sortable=True
+        ),
+        correspondent_id=NUMERIC(),
+        has_correspondent=BOOLEAN(),
+
+        tag=KEYWORD(
+            commas=True,
+            scorable=True,
+            lowercase=True
+        ),
+        tag_id=KEYWORD(
+            commas=True,
+            scorable=True
+        ),
+        has_tag=BOOLEAN(),
+
+        type=TEXT(
+            sortable=True
+        ),
+        type_id=NUMERIC(),
+        has_type=BOOLEAN(),
+
+        created=DATETIME(
+            sortable=True
+        ),
+        modified=DATETIME(
+            sortable=True
+        ),
+        added=DATETIME(
+            sortable=True
+        ),
+
     )
 
 
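The new has_* and *_id fields are what the filtering below builds on. A minimal sketch of the intended use, not part of the commit; field names taken from get_schema() above:

from whoosh import query

# documents that have at least one tag, belong to correspondent 7,
# and carry tag 12 (ids are passed as strings, as in the filter code below)
filter_q = query.And([
    query.Term("has_tag", True),
    query.Term("correspondent_id", "7"),
    query.Term("tag_id", "12"),
])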
@@ -106,18 +96,38 @@ def open_index_writer(ix=None, optimize=False):
         writer.commit(optimize=optimize)
 
 
+@contextmanager
+def open_index_searcher(ix=None):
+    if ix:
+        searcher = ix.searcher()
+    else:
+        searcher = open_index().searcher()
+
+    try:
+        yield searcher
+    finally:
+        searcher.close()
+
+
 def update_document(writer, doc):
     tags = ",".join([t.name for t in doc.tags.all()])
+    tags_ids = ",".join([str(t.id) for t in doc.tags.all()])
     writer.update_document(
         id=doc.pk,
         title=doc.title,
         content=doc.content,
         correspondent=doc.correspondent.name if doc.correspondent else None,
         correspondent_id=doc.correspondent.id if doc.correspondent else None,
+        has_correspondent=doc.correspondent is not None,
         tag=tags if tags else None,
+        tag_id=tags_ids if tags_ids else None,
+        has_tag=len(tags) > 0,
         type=doc.document_type.name if doc.document_type else None,
+        type_id=doc.document_type.id if doc.document_type else None,
+        has_type=doc.document_type is not None,
         created=doc.created,
         added=doc.added,
+        archive_serial_number=doc.archive_serial_number,
         modified=doc.modified,
     )
 
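A usage sketch, not part of the commit: reindex one document with the writer context manager from the hunk above, then resolve it through the new open_index_searcher(). The document id is hypothetical.

from documents import index
from documents.models import Document

doc = Document.objects.get(pk=42)  # hypothetical document id

with index.open_index_writer() as writer:
    index.update_document(writer, doc)

with index.open_index_searcher() as searcher:
    # document_number() mirrors how the more-like-this query below
    # resolves a document in the index
    docnum = searcher.document_number(id=doc.pk)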
@@ -140,78 +150,11 @@ def remove_document_from_index(document):
         remove_document(writer, document)
 
 
-@contextmanager
-def query_page(ix, page, querystring, more_like_doc_id, more_like_doc_content):
-    searcher = ix.searcher()
-    try:
-        if querystring:
-            qp = MultifieldParser(
-                ["content", "title", "correspondent", "tag", "type"],
-                ix.schema)
-            qp.add_plugin(DateParserPlugin())
-            str_q = qp.parse(querystring)
-            corrected = searcher.correct_query(str_q, querystring)
-        else:
-            str_q = None
-            corrected = None
-
-        if more_like_doc_id:
-            docnum = searcher.document_number(id=more_like_doc_id)
-            kts = searcher.key_terms_from_text(
-                'content', more_like_doc_content, numterms=20,
-                model=classify.Bo1Model, normalize=False)
-            more_like_q = query.Or(
-                [query.Term('content', word, boost=weight)
-                 for word, weight in kts])
-            result_page = searcher.search_page(
-                more_like_q, page, filter=str_q, mask={docnum})
-        elif str_q:
-            result_page = searcher.search_page(str_q, page)
-        else:
-            raise ValueError(
-                "Either querystring or more_like_doc_id is required."
-            )
-
-        result_page.results.fragmenter = highlight.ContextFragmenter(
-            surround=50)
-        result_page.results.formatter = JsonFormatter()
-
-        if corrected and corrected.query != str_q:
-            corrected_query = corrected.string
-        else:
-            corrected_query = None
-
-        yield result_page, corrected_query
-    finally:
-        searcher.close()
-
-
 class DelayedQuery:
 
     @property
     def _query(self):
-        if 'query' in self.query_params:
-            qp = MultifieldParser(
-                ["content", "title", "correspondent", "tag", "type"],
-                self.ix.schema)
-            qp.add_plugin(DateParserPlugin())
-            q = qp.parse(self.query_params['query'])
-        elif 'more_like_id' in self.query_params:
-            more_like_doc_id = int(self.query_params['more_like_id'])
-            content = Document.objects.get(id=more_like_doc_id).content
-
-            docnum = self.searcher.document_number(id=more_like_doc_id)
-            kts = self.searcher.key_terms_from_text(
-                'content', content, numterms=20,
-                model=classify.Bo1Model, normalize=False)
-            q = query.Or(
-                [query.Term('content', word, boost=weight)
-                 for word, weight in kts])
-        else:
-            raise ValueError(
-                "Either query or more_like_id is required."
-            )
-        return q
+        raise NotImplementedError()
 
     @property
     def _query_filter(self):
@@ -219,32 +162,114 @@ class DelayedQuery:
         for k, v in self.query_params.items():
             if k == 'correspondent__id':
                 criterias.append(query.Term('correspondent_id', v))
+            elif k == 'tags__id__all':
+                for tag_id in v.split(","):
+                    criterias.append(query.Term('tag_id', tag_id))
+            elif k == 'document_type__id':
+                criterias.append(query.Term('type_id', v))
+            elif k == 'correspondent__isnull':
+                criterias.append(query.Term("has_correspondent", v == "false"))
+            elif k == 'is_tagged':
+                criterias.append(query.Term("has_tag", v == "true"))
+            elif k == 'document_type__isnull':
+                criterias.append(query.Term("has_type", v == "false"))
+            # date filters are accepted but not applied yet
+            elif k == 'created__date__lt':
+                pass
+            elif k == 'created__date__gt':
+                pass
+            elif k == 'added__date__gt':
+                pass
+            elif k == 'added__date__lt':
+                pass
         if len(criterias) > 0:
             return query.And(criterias)
         else:
             return None
 
-    def __init__(self, ix, searcher, query_params, page_size):
-        self.ix = ix
+    @property
+    def _query_sortedby(self):
+        if 'ordering' not in self.query_params:
+            return None, False
+
+        o: str = self.query_params['ordering']
+        if o.startswith('-'):
+            return o[1:], True
+        else:
+            return o, False
+
+    def __init__(self, searcher: Searcher, query_params, page_size):
         self.searcher = searcher
         self.query_params = query_params
         self.page_size = page_size
+        self.saved_results = dict()
 
     def __len__(self):
-        results = self.searcher.search(self._query, limit=1, filter=self._query_filter)
-        return len(results)
+        page = self[0:1]
+        return len(page)
 
     def __getitem__(self, item):
+        if item.start in self.saved_results:
+            return self.saved_results[item.start]
+
+        q, mask = self._query
+        sortedby, reverse = self._query_sortedby
+
         page: ResultsPage = self.searcher.search_page(
-            self._query,
+            q,
+            mask=mask,
             filter=self._query_filter,
             pagenum=math.floor(item.start / self.page_size) + 1,
-            pagelen=self.page_size
+            pagelen=self.page_size,
+            sortedby=sortedby,
+            reverse=reverse
         )
         page.results.fragmenter = highlight.ContextFragmenter(
             surround=50)
+        page.results.formatter = HtmlFormatter(tagname="span", between=" ... ")
+
+        self.saved_results[item.start] = page
+
         return page
 
 
+class DelayedFullTextQuery(DelayedQuery):
+
+    @property
+    def _query(self):
+        q_str = self.query_params['query']
+        qp = MultifieldParser(
+            ["content", "title", "correspondent", "tag", "type"],
+            self.searcher.ixreader.schema)
+        qp.add_plugin(DateParserPlugin())
+        q = qp.parse(q_str)
+
+        corrected = self.searcher.correct_query(q, q_str)
+        if corrected.query != q:
+            # a spell-corrected variant is available as corrected.string,
+            # but it is not passed on to the API response yet
+            corrected_query = corrected.string
+
+        return q, None
+
+
+class DelayedMoreLikeThisQuery(DelayedQuery):
+
+    @property
+    def _query(self):
+        more_like_doc_id = int(self.query_params['more_like_id'])
+        content = Document.objects.get(id=more_like_doc_id).content
+
+        docnum = self.searcher.document_number(id=more_like_doc_id)
+        kts = self.searcher.key_terms_from_text(
+            'content', content, numterms=20,
+            model=classify.Bo1Model, normalize=False)
+        q = query.Or(
+            [query.Term('content', word, boost=weight)
+             for word, weight in kts])
+        mask = {docnum}
+
+        return q, mask
+
+
 def autocomplete(ix, term, limit=10):
     with ix.reader() as reader:
         terms = []
 
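The point of DelayedQuery is to look like a sliced queryset, which is all Django's paginator needs, so whoosh only runs a search for the page actually being rendered. A sketch of that protocol, not part of the commit:

from documents import index

with index.open_index_searcher() as searcher:
    results = index.DelayedFullTextQuery(
        searcher, {"query": "invoice 2021"}, page_size=25)
    total = len(results)        # fetches one page; whoosh's ResultsPage
                                # reports the total hit count
    first_page = results[0:25]  # runs search_page() and caches the result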
documents/models.py
@@ -359,7 +359,10 @@ class SavedView(models.Model):
 
     sort_field = models.CharField(
         _("sort field"),
-        max_length=128)
+        max_length=128,
+        null=True,
+        blank=True
+    )
     sort_reverse = models.BooleanField(
         _("sort reverse"),
         default=False)
@@ -387,6 +390,8 @@ class SavedViewFilterRule(models.Model):
         (17, _("does not have tag")),
         (18, _("does not have ASN")),
         (19, _("title or content contains")),
+        (20, _("fulltext query")),
+        (21, _("more like this"))
     ]
 
     saved_view = models.ForeignKey(
 
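A hedged sketch of what the new rule types and the now-nullable sort_field allow, assuming the existing SavedView and SavedViewFilterRule fields (user, name, show_on_dashboard, show_in_sidebar, rule_type, value):

view = SavedView.objects.create(
    user=user,                 # hypothetical user instance
    name="Tax invoices",
    show_on_dashboard=False,
    show_in_sidebar=True,
    sort_field=None,           # now allowed: fulltext results keep their rank
)
# rule type 20 stores a fulltext query with the saved view
SavedViewFilterRule.objects.create(
    saved_view=view, rule_type=20, value="invoice tax 2021")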
documents/tests/test_index.py
@@ -1,20 +1,10 @@
 from django.test import TestCase
 
 from documents import index
-from documents.index import JsonFormatter
 from documents.models import Document
 from documents.tests.utils import DirectoriesMixin
 
 
-class JsonFormatterTest(TestCase):
-
-    def setUp(self) -> None:
-        self.formatter = JsonFormatter()
-
-    def test_empty_fragments(self):
-        self.assertListEqual(self.formatter.format([]), [])
-
-
 class TestAutoComplete(DirectoriesMixin, TestCase):
 
     def test_auto_complete(self):
 
documents/views.py
@@ -36,7 +36,6 @@ from rest_framework.viewsets import (
 
 from paperless.db import GnuPG
 from paperless.views import StandardPagination
-from . import index
 from .bulk_download import OriginalAndArchiveStrategy, OriginalsOnlyStrategy, \
     ArchiveOnlyStrategy
 from .classifier import load_classifier
@@ -332,15 +331,23 @@ class SearchResultSerializer(DocumentSerializer):
 
     def to_representation(self, instance):
         doc = Document.objects.get(id=instance['id'])
-        # repressentation = super(SearchResultSerializer, self).to_representation(doc)
-        # repressentation['__search_hit__'] = {
-        #     "score": instance.score
-        # }
-        return super(SearchResultSerializer, self).to_representation(doc)
+        representation = super(SearchResultSerializer, self).to_representation(doc)
+        representation['__search_hit__'] = {
+            "score": instance.score,
+            "highlights": instance.highlights(
+                "content", text=doc.content) if doc else None,
+            "rank": instance.rank
+        }
+
+        return representation
 
 
 class UnifiedSearchViewSet(DocumentViewSet):
 
+    def __init__(self, *args, **kwargs):
+        super(UnifiedSearchViewSet, self).__init__(*args, **kwargs)
+        self.searcher = None
+
     def get_serializer_class(self):
         if self._is_search_request():
             return SearchResultSerializer
@@ -348,25 +355,39 @@ class UnifiedSearchViewSet(DocumentViewSet):
             return DocumentSerializer
 
     def _is_search_request(self):
-        return "query" in self.request.query_params
+        return ("query" in self.request.query_params or
+                "more_like_id" in self.request.query_params)
 
     def filter_queryset(self, queryset):
-
         if self._is_search_request():
-            ix = index.open_index()
-            return index.DelayedQuery(ix, self.searcher, self.request.query_params, self.paginator.page_size)
+            from documents import index
+
+            if "query" in self.request.query_params:
+                query_class = index.DelayedFullTextQuery
+            elif "more_like_id" in self.request.query_params:
+                query_class = index.DelayedMoreLikeThisQuery
+            else:
+                raise ValueError()
+
+            return query_class(
+                self.searcher,
+                self.request.query_params,
+                self.paginator.get_page_size(self.request))
         else:
             return super(UnifiedSearchViewSet, self).filter_queryset(queryset)
 
     def list(self, request, *args, **kwargs):
         if self._is_search_request():
-            ix = index.open_index()
-            with ix.searcher() as s:
-                self.searcher = s
-                return super(UnifiedSearchViewSet, self).list(request)
+            from documents import index
+            try:
+                with index.open_index_searcher() as s:
+                    self.searcher = s
+                    return super(UnifiedSearchViewSet, self).list(request)
+            except Exception as e:
+                return HttpResponseBadRequest(str(e))
         else:
             return super(UnifiedSearchViewSet, self).list(request)
 
 
 class LogViewSet(ViewSet):
 
     permission_classes = (IsAuthenticated,)
@@ -518,74 +539,6 @@ class SelectionDataView(GenericAPIView):
         return r
 
 
-class SearchView(APIView):
-
-    permission_classes = (IsAuthenticated,)
-
-    def add_infos_to_hit(self, r):
-        try:
-            doc = Document.objects.get(id=r['id'])
-        except Document.DoesNotExist:
-            logger.warning(
-                f"Search index returned a non-existing document: "
-                f"id: {r['id']}, title: {r['title']}. "
-                f"Search index needs reindex."
-            )
-            doc = None
-
-        return {'id': r['id'],
-                'highlights': r.highlights("content", text=doc.content) if doc else None,  # NOQA: E501
-                'score': r.score,
-                'rank': r.rank,
-                'document': DocumentSerializer(doc).data if doc else None,
-                'title': r['title']
-                }
-
-    def get(self, request, format=None):
-        from documents import index
-
-        if 'query' in request.query_params:
-            query = request.query_params['query']
-        else:
-            query = None
-
-        if 'more_like' in request.query_params:
-            more_like_id = request.query_params['more_like']
-            more_like_content = Document.objects.get(id=more_like_id).content
-        else:
-            more_like_id = None
-            more_like_content = None
-
-        if not query and not more_like_id:
-            return Response({
-                'count': 0,
-                'page': 0,
-                'page_count': 0,
-                'corrected_query': None,
-                'results': []})
-
-        try:
-            page = int(request.query_params.get('page', 1))
-        except (ValueError, TypeError):
-            page = 1
-
-        if page < 1:
-            page = 1
-
-        ix = index.open_index()
-
-        try:
-            with index.query_page(ix, page, query, more_like_id, more_like_content) as (result_page, corrected_query):  # NOQA: E501
-                return Response(
-                    {'count': len(result_page),
-                     'page': result_page.pagenum,
-                     'page_count': result_page.pagecount,
-                     'corrected_query': corrected_query,
-                     'results': list(map(self.add_infos_to_hit, result_page))})
-        except Exception as e:
-            return HttpResponseBadRequest(str(e))
-
-
 class SearchAutoCompleteView(APIView):
 
     permission_classes = (IsAuthenticated,)
 
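With SearchView gone, search is served by the regular documents endpoint. A hedged client-side sketch, endpoint and parameters inferred from UnifiedSearchViewSet above; base URL and token auth are assumptions:

import requests

resp = requests.get(
    "http://localhost:8000/api/documents/",
    params={"query": "invoice", "page": 1},
    headers={"Authorization": "Token <api-token>"},
)
for result in resp.json()["results"]:
    hit = result["__search_hit__"]  # added by SearchResultSerializer
    print(result["title"], hit["score"], hit["rank"], hit["highlights"])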