diff --git a/src/documents/index.py b/src/documents/index.py index 0b0493514..0c2bd1bfe 100644 --- a/src/documents/index.py +++ b/src/documents/index.py @@ -1,6 +1,7 @@ import logging import math import os +from collections import Counter from contextlib import contextmanager from dateutil.parser import isoparse @@ -17,17 +18,21 @@ from whoosh.fields import NUMERIC from whoosh.fields import TEXT from whoosh.fields import Schema from whoosh.highlight import HtmlFormatter +from whoosh.index import FileIndex from whoosh.index import create_in from whoosh.index import exists_in from whoosh.index import open_dir from whoosh.qparser import MultifieldParser +from whoosh.qparser import QueryParser from whoosh.qparser.dateparse import DateParserPlugin +from whoosh.scoring import TF_IDF from whoosh.searching import ResultsPage from whoosh.searching import Searcher from whoosh.writing import AsyncWriter from documents.models import Document from documents.models import Note +from documents.models import User logger = logging.getLogger("paperless.index") @@ -238,15 +243,9 @@ class DelayedQuery: elif k == "storage_path__isnull": criterias.append(query.Term("has_path", v == "false")) - user_criterias = [query.Term("has_owner", False)] - if "user" in self.query_params: - if self.query_params["is_superuser"]: # superusers see all docs - user_criterias = [] - else: - user_criterias.append(query.Term("owner_id", self.query_params["user"])) - user_criterias.append( - query.Term("viewer_id", str(self.query_params["user"])), - ) + user_criterias = get_permissions_criterias( + user=self.user, + ) if len(criterias) > 0: if len(user_criterias) > 0: criterias.append(query.Or(user_criterias)) @@ -282,12 +281,13 @@ class DelayedQuery: else: return sort_fields_map[field], reverse - def __init__(self, searcher: Searcher, query_params, page_size): + def __init__(self, searcher: Searcher, query_params, page_size, user): self.searcher = searcher self.query_params = query_params self.page_size = 
page_size self.saved_results = dict() self.first_score = None + self.user = user def __len__(self): page = self[0:1] @@ -368,13 +368,42 @@ class DelayedMoreLikeThisQuery(DelayedQuery): return q, mask -def autocomplete(ix, term, limit=10): - with ix.reader() as reader: - terms = [] - for score, t in reader.most_distinctive_terms( - "content", - number=limit, - prefix=term.lower(), - ): - terms.append(t) - return terms +def autocomplete(ix: FileIndex, term: str, limit: int = 10, user: User = None): + """ + Mimics whoosh.reading.IndexReader.most_distinctive_terms with permissions + and without scoring + """ + terms = [] + + with ix.searcher(weighting=TF_IDF()) as s: + qp = QueryParser("content", schema=ix.schema) + q = qp.parse(f"{term.lower()}*") + user_criterias = get_permissions_criterias(user) + + results = s.search( + q, + terms=True, + filter=query.Or(user_criterias) if user_criterias is not None else None, + ) + + termCounts = Counter() + if results.has_matched_terms(): + for hit in results: + for _, term in hit.matched_terms(): + termCounts[term] += 1 + terms = [t for t, _ in termCounts.most_common(limit)] + + return terms + + +def get_permissions_criterias(user: User = None): + user_criterias = [query.Term("has_owner", False)] + if user is not None: + if user.is_superuser: # superusers see all docs + user_criterias = [] + else: + user_criterias.append(query.Term("owner_id", user.id)) + user_criterias.append( + query.Term("viewer_id", str(user.id)), + ) + return user_criterias diff --git a/src/documents/tests/test_api.py b/src/documents/tests/test_api.py index dd872fe78..a2f92a226 100644 --- a/src/documents/tests/test_api.py +++ b/src/documents/tests/test_api.py @@ -832,7 +832,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase): @mock.patch("documents.index.autocomplete") def test_search_autocomplete(self, m): - m.side_effect = lambda ix, term, limit: [term for _ in range(limit)] + m.side_effect = lambda ix, term, limit, user: 
[term for _ in range(limit)] response = self.client.get("/api/search/autocomplete/?term=test") self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -852,6 +852,66 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase): self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(len(response.data), 10) + def test_search_autocomplete_respect_permissions(self): + """ + GIVEN: + - Multiple users and documents with & without permissions + WHEN: + - API request for autocomplete is made by user with or without permissions + THEN: + - Terms only within docs user has access to are returned + """ + u1 = User.objects.create_user("user1") + u2 = User.objects.create_user("user2") + + self.client.force_authenticate(user=u1) + + d1 = Document.objects.create( + title="doc1", + content="apples", + checksum="1", + owner=u1, + ) + d2 = Document.objects.create( + title="doc2", + content="applebaum", + checksum="2", + owner=u1, + ) + d3 = Document.objects.create( + title="doc3", + content="appletini", + checksum="3", + owner=u1, + ) + + with AsyncWriter(index.open_index()) as writer: + index.update_document(writer, d1) + index.update_document(writer, d2) + index.update_document(writer, d3) + + response = self.client.get("/api/search/autocomplete/?term=app") + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, [b"apples", b"applebaum", b"appletini"]) + + d3.owner = u2 + + with AsyncWriter(index.open_index()) as writer: + index.update_document(writer, d3) + + response = self.client.get("/api/search/autocomplete/?term=app") + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, [b"apples", b"applebaum"]) + + assign_perm("view_document", u1, d3) + + with AsyncWriter(index.open_index()) as writer: + index.update_document(writer, d3) + + response = self.client.get("/api/search/autocomplete/?term=app") + self.assertEqual(response.status_code, 
status.HTTP_200_OK) + self.assertEqual(response.data, [b"apples", b"applebaum", b"appletini"]) + @pytest.mark.skip(reason="Not implemented yet") def test_search_spelling_correction(self): with AsyncWriter(index.open_index()) as writer: diff --git a/src/documents/tests/test_index.py b/src/documents/tests/test_index.py index 11cbee443..24bc26d4c 100644 --- a/src/documents/tests/test_index.py +++ b/src/documents/tests/test_index.py @@ -25,13 +25,13 @@ class TestAutoComplete(DirectoriesMixin, TestCase): self.assertListEqual( index.autocomplete(ix, "tes"), - [b"test3", b"test", b"test2"], + [b"test2", b"test", b"test3"], ) self.assertListEqual( index.autocomplete(ix, "tes", limit=3), - [b"test3", b"test", b"test2"], + [b"test2", b"test", b"test3"], ) - self.assertListEqual(index.autocomplete(ix, "tes", limit=1), [b"test3"]) + self.assertListEqual(index.autocomplete(ix, "tes", limit=1), [b"test2"]) self.assertListEqual(index.autocomplete(ix, "tes", limit=0), []) def test_archive_serial_number_ranging(self): diff --git a/src/documents/views.py b/src/documents/views.py index bfe2b3e6f..ebd6f0f26 100644 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -598,15 +598,6 @@ class UnifiedSearchViewSet(DocumentViewSet): if self._is_search_request(): from documents import index - if hasattr(self.request, "user"): - # pass user to query for perms - self.request.query_params._mutable = True - self.request.query_params["user"] = self.request.user.id - self.request.query_params[ - "is_superuser" - ] = self.request.user.is_superuser - self.request.query_params._mutable = False - if "query" in self.request.query_params: query_class = index.DelayedFullTextQuery elif "more_like_id" in self.request.query_params: @@ -618,6 +609,7 @@ class UnifiedSearchViewSet(DocumentViewSet): self.searcher, self.request.query_params, self.paginator.get_page_size(self.request), + self.request.user, ) else: return super().filter_queryset(queryset) @@ -817,6 +809,8 @@ class 
SearchAutoCompleteView(APIView): permission_classes = (IsAuthenticated,) def get(self, request, format=None): + user = self.request.user if hasattr(self.request, "user") else None + if "term" in request.query_params: term = request.query_params["term"] else: @@ -833,7 +827,14 @@ class SearchAutoCompleteView(APIView): ix = index.open_index() - return Response(index.autocomplete(ix, term, limit)) + return Response( + index.autocomplete( + ix, + term, + limit, + user, + ), + ) class StatisticsView(APIView):