class DelayedQuery:
    """Lazily evaluated Whoosh search exposing ``len()`` and slicing.

    An instance can be handed to a paginator in place of a queryset: no
    search is executed until the length or a slice is requested.
    """

    def __init__(self, ix, searcher, query_params, page_size):
        """Store the search context; no search is run here.

        :param ix: index whose schema is used to parse ``query``.
        :param searcher: open searcher used to execute the search.
        :param query_params: mapping of request parameters. ``query`` or
            ``more_like_id`` selects the search, ``correspondent__id``
            optionally restricts it.
        :param page_size: number of hits per result page.
        """
        self.ix = ix
        self.searcher = searcher
        self.query_params = query_params
        self.page_size = page_size
        # (query, mask) pair, built on first use and reused afterwards so
        # that len() followed by a slice does not parse the query (or hit
        # the database) twice.
        self._built = None

    def _build(self):
        """Return the cached ``(query, mask)`` pair, building it if needed.

        :raises ValueError: if neither ``query`` nor ``more_like_id`` is
            present in the query parameters.
        """
        if self._built is None:
            params = self.query_params
            if 'query' in params:
                parser = MultifieldParser(
                    ["content", "title", "correspondent", "tag", "type"],
                    self.ix.schema)
                parser.add_plugin(DateParserPlugin())
                self._built = (parser.parse(params['query']), None)
            elif 'more_like_id' in params:
                self._built = self._build_more_like(
                    int(params['more_like_id']))
            else:
                raise ValueError(
                    "Either query or more_like_id is required."
                )
        return self._built

    def _build_more_like(self, doc_id):
        """Build the "more like this" query and mask for document ``doc_id``.

        The query is an OR over the 20 key terms extracted from the
        document's content, each boosted by its weight.
        """
        content = Document.objects.get(id=doc_id).content
        docnum = self.searcher.document_number(id=doc_id)
        key_terms = self.searcher.key_terms_from_text(
            'content', content, numterms=20,
            model=classify.Bo1Model, normalize=False)
        similar = query.Or(
            [query.Term('content', word, boost=weight)
             for word, weight in key_terms])
        # Exclude the source document from its own results. The document
        # may be missing from the index, in which case nothing is masked.
        mask = {docnum} if docnum is not None else None
        return similar, mask

    @property
    def _query(self):
        """The query object to execute."""
        return self._build()[0]

    @property
    def _mask(self):
        """Set of document numbers to exclude from results, or ``None``."""
        return self._build()[1]

    @property
    def _query_filter(self):
        """Filter query derived from ``correspondent__id``, or ``None``."""
        params = self.query_params
        if 'correspondent__id' not in params:
            return None
        return query.And(
            [query.Term('correspondent_id', params['correspondent__id'])])

    def _search_kwargs(self):
        """Keyword arguments shared by the count and page searches."""
        kwargs = {'filter': self._query_filter}
        mask = self._mask
        if mask is not None:
            kwargs['mask'] = mask
        return kwargs

    def __len__(self):
        """Return the length reported by a ``limit=1`` search's results."""
        results = self.searcher.search(
            self._query, limit=1, **self._search_kwargs())
        return len(results)

    def __getitem__(self, item):
        """Return the result page that contains the start of slice ``item``.

        Only ``item.start`` is used; the page always has ``page_size``
        entries regardless of ``item.stop``. An open start (``[:n]``) is
        treated as 0.
        """
        start = item.start or 0
        return self.searcher.search_page(
            self._query,
            pagenum=start // self.page_size + 1,
            pagelen=self.page_size,
            **self._search_kwargs())
class SearchResultSerializer(DocumentSerializer):
    """Serialize a search hit as the document it refers to."""

    def to_representation(self, instance):
        """Load the document named by the hit's ``id`` and serialize it."""
        # NOTE(review): this issues one database query per search hit.
        doc = Document.objects.get(id=instance['id'])
        return super().to_representation(doc)


class UnifiedSearchViewSet(DocumentViewSet):
    """Document endpoint that serves index searches when asked to.

    A request carrying ``query`` or ``more_like_id`` is answered from the
    search index; every other request is handled by ``DocumentViewSet``.
    """

    def get_serializer_class(self):
        """Pick the hit serializer for searches, the default otherwise."""
        if self._is_search_request():
            return SearchResultSerializer
        return DocumentSerializer

    def _is_search_request(self):
        """Return whether the request asks for an index search."""
        params = self.request.query_params
        # ``more_like_id`` is accepted too: index.DelayedQuery supports it.
        return "query" in params or "more_like_id" in params

    def filter_queryset(self, queryset):
        """Return a lazy index search for search requests.

        Relies on ``self.searcher``, which ``list()`` sets before calling
        into the parent implementation.
        """
        if not self._is_search_request():
            return super().filter_queryset(queryset)

        ix = index.open_index()
        # Ask the paginator for the effective page size of this request so
        # the page number computed by DelayedQuery matches the slice the
        # paginator takes.
        return index.DelayedQuery(
            ix,
            self.searcher,
            self.request.query_params,
            self.paginator.get_page_size(self.request))

    def list(self, request, *args, **kwargs):
        """List documents, keeping a searcher open while a search runs."""
        if not self._is_search_request():
            return super().list(request, *args, **kwargs)

        ix = index.open_index()
        with ix.searcher() as s:
            # The response is built inside the ``with`` block so the
            # searcher is still open while results are serialized.
            self.searcher = s
            return super().list(request, *args, **kwargs)
api_router.register(r"correspondents", CorrespondentViewSet) api_router.register(r"document_types", DocumentTypeViewSet) -api_router.register(r"documents", DocumentViewSet) +api_router.register(r"documents", UnifiedSearchViewSet) api_router.register(r"logs", LogViewSet, basename="logs") api_router.register(r"tags", TagViewSet) api_router.register(r"saved_views", SavedViewViewSet)