Respect permissions for autocomplete suggestions

This commit is contained in:
shamoon 2023-05-08 14:24:37 -07:00
parent f8f5a77744
commit 66a0783e7b
4 changed files with 124 additions and 34 deletions

View File

@ -1,6 +1,7 @@
import logging import logging
import math import math
import os import os
from collections import Counter
from contextlib import contextmanager from contextlib import contextmanager
from dateutil.parser import isoparse from dateutil.parser import isoparse
@ -17,17 +18,21 @@ from whoosh.fields import NUMERIC
from whoosh.fields import TEXT from whoosh.fields import TEXT
from whoosh.fields import Schema from whoosh.fields import Schema
from whoosh.highlight import HtmlFormatter from whoosh.highlight import HtmlFormatter
from whoosh.index import FileIndex
from whoosh.index import create_in from whoosh.index import create_in
from whoosh.index import exists_in from whoosh.index import exists_in
from whoosh.index import open_dir from whoosh.index import open_dir
from whoosh.qparser import MultifieldParser from whoosh.qparser import MultifieldParser
from whoosh.qparser import QueryParser
from whoosh.qparser.dateparse import DateParserPlugin from whoosh.qparser.dateparse import DateParserPlugin
from whoosh.scoring import TF_IDF
from whoosh.searching import ResultsPage from whoosh.searching import ResultsPage
from whoosh.searching import Searcher from whoosh.searching import Searcher
from whoosh.writing import AsyncWriter from whoosh.writing import AsyncWriter
from documents.models import Document from documents.models import Document
from documents.models import Note from documents.models import Note
from documents.models import User
logger = logging.getLogger("paperless.index") logger = logging.getLogger("paperless.index")
@ -238,15 +243,9 @@ class DelayedQuery:
elif k == "storage_path__isnull": elif k == "storage_path__isnull":
criterias.append(query.Term("has_path", v == "false")) criterias.append(query.Term("has_path", v == "false"))
user_criterias = [query.Term("has_owner", False)] user_criterias = get_permissions_criterias(
if "user" in self.query_params: user=self.user,
if self.query_params["is_superuser"]: # superusers see all docs )
user_criterias = []
else:
user_criterias.append(query.Term("owner_id", self.query_params["user"]))
user_criterias.append(
query.Term("viewer_id", str(self.query_params["user"])),
)
if len(criterias) > 0: if len(criterias) > 0:
if len(user_criterias) > 0: if len(user_criterias) > 0:
criterias.append(query.Or(user_criterias)) criterias.append(query.Or(user_criterias))
@ -282,12 +281,13 @@ class DelayedQuery:
else: else:
return sort_fields_map[field], reverse return sort_fields_map[field], reverse
def __init__(self, searcher: Searcher, query_params, page_size): def __init__(self, searcher: Searcher, query_params, page_size, user):
self.searcher = searcher self.searcher = searcher
self.query_params = query_params self.query_params = query_params
self.page_size = page_size self.page_size = page_size
self.saved_results = dict() self.saved_results = dict()
self.first_score = None self.first_score = None
self.user = user
def __len__(self): def __len__(self):
page = self[0:1] page = self[0:1]
@ -368,13 +368,42 @@ class DelayedMoreLikeThisQuery(DelayedQuery):
return q, mask return q, mask
def autocomplete(ix, term, limit=10): def autocomplete(ix: FileIndex, term: str, limit: int = 10, user: User = None):
with ix.reader() as reader: """
terms = [] Mimics whoosh.reading.IndexReader.most_distinctive_terms with permissions
for score, t in reader.most_distinctive_terms( and without scoring
"content", """
number=limit, terms = []
prefix=term.lower(),
): with ix.searcher(weighting=TF_IDF()) as s:
terms.append(t) qp = QueryParser("content", schema=ix.schema)
return terms q = qp.parse(f"{term.lower()}*")
user_criterias = get_permissions_criterias(user)
results = s.search(
q,
terms=True,
filter=query.Or(user_criterias) if user_criterias is not None else None,
)
termCounts = Counter()
if results.has_matched_terms():
for hit in results:
for _, term in hit.matched_terms():
termCounts[term] += 1
terms = [t for t, _ in termCounts.most_common(limit)]
return terms
def get_permissions_criterias(user: User = None):
user_criterias = [query.Term("has_owner", False)]
if user is not None:
if user.is_superuser: # superusers see all docs
user_criterias = []
else:
user_criterias.append(query.Term("owner_id", user.id))
user_criterias.append(
query.Term("viewer_id", str(user.id)),
)
return user_criterias

View File

@ -832,7 +832,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
@mock.patch("documents.index.autocomplete") @mock.patch("documents.index.autocomplete")
def test_search_autocomplete(self, m): def test_search_autocomplete(self, m):
m.side_effect = lambda ix, term, limit: [term for _ in range(limit)] m.side_effect = lambda ix, term, limit, user: [term for _ in range(limit)]
response = self.client.get("/api/search/autocomplete/?term=test") response = self.client.get("/api/search/autocomplete/?term=test")
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -852,6 +852,66 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(len(response.data), 10) self.assertEqual(len(response.data), 10)
def test_search_autocomplete_respect_permissions(self):
"""
GIVEN:
- Multiple users and documents with & without permissions
WHEN:
- API reuqest for autocomplete is made by user with or without permissions
THEN:
- Terms only within docs user has access to are returned
"""
u1 = User.objects.create_user("user1")
u2 = User.objects.create_user("user2")
self.client.force_authenticate(user=u1)
d1 = Document.objects.create(
title="doc1",
content="apples",
checksum="1",
owner=u1,
)
d2 = Document.objects.create(
title="doc2",
content="applebaum",
checksum="2",
owner=u1,
)
d3 = Document.objects.create(
title="doc3",
content="appletini",
checksum="3",
owner=u1,
)
with AsyncWriter(index.open_index()) as writer:
index.update_document(writer, d1)
index.update_document(writer, d2)
index.update_document(writer, d3)
response = self.client.get("/api/search/autocomplete/?term=app")
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, [b"apples", b"applebaum", b"appletini"])
d3.owner = u2
with AsyncWriter(index.open_index()) as writer:
index.update_document(writer, d3)
response = self.client.get("/api/search/autocomplete/?term=app")
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, [b"apples", b"applebaum"])
assign_perm("view_document", u1, d3)
with AsyncWriter(index.open_index()) as writer:
index.update_document(writer, d3)
response = self.client.get("/api/search/autocomplete/?term=app")
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, [b"apples", b"applebaum", b"appletini"])
@pytest.mark.skip(reason="Not implemented yet") @pytest.mark.skip(reason="Not implemented yet")
def test_search_spelling_correction(self): def test_search_spelling_correction(self):
with AsyncWriter(index.open_index()) as writer: with AsyncWriter(index.open_index()) as writer:

View File

@ -25,13 +25,13 @@ class TestAutoComplete(DirectoriesMixin, TestCase):
self.assertListEqual( self.assertListEqual(
index.autocomplete(ix, "tes"), index.autocomplete(ix, "tes"),
[b"test3", b"test", b"test2"], [b"test2", b"test", b"test3"],
) )
self.assertListEqual( self.assertListEqual(
index.autocomplete(ix, "tes", limit=3), index.autocomplete(ix, "tes", limit=3),
[b"test3", b"test", b"test2"], [b"test2", b"test", b"test3"],
) )
self.assertListEqual(index.autocomplete(ix, "tes", limit=1), [b"test3"]) self.assertListEqual(index.autocomplete(ix, "tes", limit=1), [b"test2"])
self.assertListEqual(index.autocomplete(ix, "tes", limit=0), []) self.assertListEqual(index.autocomplete(ix, "tes", limit=0), [])
def test_archive_serial_number_ranging(self): def test_archive_serial_number_ranging(self):

View File

@ -598,15 +598,6 @@ class UnifiedSearchViewSet(DocumentViewSet):
if self._is_search_request(): if self._is_search_request():
from documents import index from documents import index
if hasattr(self.request, "user"):
# pass user to query for perms
self.request.query_params._mutable = True
self.request.query_params["user"] = self.request.user.id
self.request.query_params[
"is_superuser"
] = self.request.user.is_superuser
self.request.query_params._mutable = False
if "query" in self.request.query_params: if "query" in self.request.query_params:
query_class = index.DelayedFullTextQuery query_class = index.DelayedFullTextQuery
elif "more_like_id" in self.request.query_params: elif "more_like_id" in self.request.query_params:
@ -618,6 +609,7 @@ class UnifiedSearchViewSet(DocumentViewSet):
self.searcher, self.searcher,
self.request.query_params, self.request.query_params,
self.paginator.get_page_size(self.request), self.paginator.get_page_size(self.request),
self.request.user,
) )
else: else:
return super().filter_queryset(queryset) return super().filter_queryset(queryset)
@ -817,6 +809,8 @@ class SearchAutoCompleteView(APIView):
permission_classes = (IsAuthenticated,) permission_classes = (IsAuthenticated,)
def get(self, request, format=None): def get(self, request, format=None):
user = self.request.user if hasattr(self.request, "user") else None
if "term" in request.query_params: if "term" in request.query_params:
term = request.query_params["term"] term = request.query_params["term"]
else: else:
@ -833,7 +827,14 @@ class SearchAutoCompleteView(APIView):
ix = index.open_index() ix = index.open_index()
return Response(index.autocomplete(ix, term, limit)) return Response(
index.autocomplete(
ix,
term,
limit,
user,
),
)
class StatisticsView(APIView): class StatisticsView(APIView):