Feature: global search, keyboard shortcuts / hotkey support (#6449)

This commit is contained in:
shamoon
2024-05-02 09:15:56 -07:00
committed by GitHub
parent 40289cd714
commit c6e7d06bb7
51 changed files with 2970 additions and 683 deletions

View File

@@ -796,6 +796,34 @@ class DocumentSerializer(
)
class SearchResultSerializer(DocumentSerializer):
    """
    Serialize one search hit as a full document plus search metadata.

    The incoming ``instance`` is a search hit (indexable by ``"id"`` and
    exposing ``score``, ``rank`` and ``highlights``), not a ``Document``.
    The matching document is loaded from the database and serialized by
    ``DocumentSerializer``; the hit details are attached under the
    ``"__search_hit__"`` key.
    """

    def to_representation(self, instance):
        """
        Return the document representation for ``instance`` with an extra
        ``"__search_hit__"`` dict (score, highlights, note_highlights, rank).

        Propagates whatever ``QuerySet.get`` raises when the hit's id has no
        matching document.
        """
        # Load the document together with every relation the parent
        # serializer is going to touch.
        document_lookup = Document.objects.select_related(
            "correspondent",
            "storage_path",
            "document_type",
            "owner",
        ).prefetch_related("tags", "custom_fields", "notes")
        document = document_lookup.get(id=instance["id"])

        # All notes as one comma separated string, used as highlight source.
        joined_notes = ",".join([str(entry.note) for entry in document.notes.all()])

        representation = super().to_representation(document)

        score = instance.score
        content_highlights = instance.highlights("content", text=document.content)
        # Guard kept as in the original; ``get`` above raises rather than
        # returning a falsy value, so this is expected to always pass.
        if document:
            note_highlights = instance.highlights("notes", text=joined_notes)
        else:
            note_highlights = None
        rank = instance.rank

        representation["__search_hit__"] = {
            "score": score,
            "highlights": content_highlights,
            "note_highlights": note_highlights,
            "rank": rank,
        }
        return representation
class SavedViewFilterRuleSerializer(serializers.ModelSerializer):
class Meta:
model = SavedViewFilterRule

View File

@@ -4,6 +4,7 @@ from unittest import mock
import pytest
from dateutil.relativedelta import relativedelta
from django.contrib.auth.models import Group
from django.contrib.auth.models import Permission
from django.contrib.auth.models import User
from django.test import override_settings
@@ -20,9 +21,13 @@ from documents.models import CustomFieldInstance
from documents.models import Document
from documents.models import DocumentType
from documents.models import Note
from documents.models import SavedView
from documents.models import StoragePath
from documents.models import Tag
from documents.models import Workflow
from documents.tests.utils import DirectoriesMixin
from paperless_mail.models import MailAccount
from paperless_mail.models import MailRule
class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
@@ -1153,3 +1158,110 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
search_query("&ordering=-owner"),
[d3.id, d2.id, d1.id],
)
def test_global_search(self):
    """
    GIVEN:
        - Three indexed documents, two of which mention "bank"
        - For every other searchable object type, one object whose name
          contains "bank" and one whose name does not
    WHEN:
        - Global search query is made for "bank"
    THEN:
        - Only the matching documents and objects are returned
    """
    content_match_doc = Document.objects.create(
        title="invoice doc1",
        content="the thing i bought at a shop and paid with bank account",
        checksum="A",
        pk=1,
    )
    title_match_doc = Document.objects.create(
        title="bank statement doc2",
        content="things i paid for in august",
        checksum="B",
        pk=2,
    )
    unrelated_doc = Document.objects.create(
        title="tax bill doc3",
        content="no b word",
        checksum="C",
        pk=3,
    )

    # The full text index has to know about all three documents.
    with index.open_index_writer() as writer:
        for document in (content_match_doc, title_match_doc, unrelated_doc):
            index.update_document(writer, document)

    # For each object type: first the one that should match, then a decoy.
    matching_correspondent = Correspondent.objects.create(
        name="bank correspondent 1",
    )
    Correspondent.objects.create(name="correspondent 2")

    matching_document_type = DocumentType.objects.create(name="bank invoice")
    DocumentType.objects.create(name="invoice")

    matching_storage_path = StoragePath.objects.create(
        name="bank path 1",
        path="path1",
    )
    StoragePath.objects.create(name="path 2", path="path2")

    matching_tag = Tag.objects.create(name="bank tag1")
    Tag.objects.create(name="tag2")

    matching_user = User.objects.create_superuser("bank user1")
    User.objects.create_user("user2")

    matching_group = Group.objects.create(name="bank group1")
    Group.objects.create(name="group2")

    SavedView.objects.create(
        name="bank view",
        show_on_dashboard=True,
        show_in_sidebar=True,
        sort_field="",
        owner=matching_user,
    )

    matching_mail_account = MailAccount.objects.create(name="bank mail account 1")
    other_mail_account = MailAccount.objects.create(name="mail account 2")

    matching_mail_rule = MailRule.objects.create(
        name="bank mail rule 1",
        account=matching_mail_account,
        action=MailRule.MailAction.MOVE,
    )
    MailRule.objects.create(
        name="mail rule 2",
        account=other_mail_account,
        action=MailRule.MailAction.MOVE,
    )

    matching_custom_field = CustomField.objects.create(
        name="bank custom field 1",
        data_type=CustomField.FieldDataType.STRING,
    )
    CustomField.objects.create(
        name="custom field 2",
        data_type=CustomField.FieldDataType.INT,
    )

    matching_workflow = Workflow.objects.create(name="bank workflow 1")
    Workflow.objects.create(name="workflow 2")

    self.client.force_authenticate(matching_user)

    response = self.client.get("/api/search/?query=bank")

    self.assertEqual(response.status_code, status.HTTP_200_OK)
    payload = response.data

    self.assertEqual(len(payload["documents"]), 2)
    self.assertEqual(len(payload["saved_views"]), 1)

    # Neither of the two returned documents may be the unrelated one.
    for position in (0, 1):
        self.assertNotEqual(payload["documents"][position]["id"], unrelated_doc.id)

    # The first hit of every other section must be the "bank" object.
    expected_first_hits = (
        ("correspondents", matching_correspondent),
        ("document_types", matching_document_type),
        ("storage_paths", matching_storage_path),
        ("tags", matching_tag),
        ("users", matching_user),
        ("groups", matching_group),
        ("mail_accounts", matching_mail_account),
        ("mail_rules", matching_mail_rule),
        ("custom_fields", matching_custom_field),
        ("workflows", matching_workflow),
    )
    for section, expected_object in expected_first_hits:
        self.assertEqual(payload[section][0]["id"], expected_object.id)
def test_global_search_bad_request(self):
    """
    WHEN:
        - Global search is requested without a query
        - Global search is requested with a query shorter than 3 characters
    THEN:
        - HTTP 400 is returned in both cases
    """
    invalid_requests = (
        "/api/search/",
        "/api/search/?query=no",
    )
    for url in invalid_requests:
        response = self.client.get(url)
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

View File

@@ -17,6 +17,7 @@ from urllib.parse import urlparse
import pathvalidate
from django.apps import apps
from django.conf import settings
from django.contrib.auth.models import Group
from django.contrib.auth.models import User
from django.contrib.contenttypes.models import ContentType
from django.db import connections
@@ -137,6 +138,7 @@ from documents.serialisers import DocumentSerializer
from documents.serialisers import DocumentTypeSerializer
from documents.serialisers import PostDocumentSerializer
from documents.serialisers import SavedViewSerializer
from documents.serialisers import SearchResultSerializer
from documents.serialisers import ShareLinkSerializer
from documents.serialisers import StoragePathSerializer
from documents.serialisers import TagSerializer
@@ -152,7 +154,13 @@ from paperless import version
from paperless.celery import app as celery_app
from paperless.config import GeneralConfig
from paperless.db import GnuPG
from paperless.serialisers import GroupSerializer
from paperless.serialisers import UserSerializer
from paperless.views import StandardPagination
from paperless_mail.models import MailAccount
from paperless_mail.models import MailRule
from paperless_mail.serialisers import MailAccountSerializer
from paperless_mail.serialisers import MailRuleSerializer
if settings.AUDIT_LOG_ENABLED:
from auditlog.models import LogEntry
@@ -813,34 +821,6 @@ class DocumentViewSet(
return Response(sorted(entries, key=lambda x: x["timestamp"], reverse=True))
class SearchResultSerializer(DocumentSerializer, PassUserMixin):
    """
    Turn a search hit into a serialized document enriched with hit details.

    ``instance`` is expected to be a search hit that can be indexed by
    ``"id"`` and offers ``score``, ``rank`` and ``highlights``; the document
    itself is fetched from the database before being serialized.
    """

    def to_representation(self, instance):
        """
        Serialize the document referenced by ``instance["id"]`` and add a
        ``"__search_hit__"`` entry holding score, content highlights, note
        highlights and rank of the hit.
        """
        # Fetch the document with its relations in as few queries as possible.
        doc = (
            Document.objects.select_related(
                "correspondent",
                "storage_path",
                "document_type",
                "owner",
            )
            .prefetch_related("tags", "custom_fields", "notes")
            .get(id=instance["id"])
        )

        # Highlighting for notes runs against all notes joined by commas.
        note_texts = [str(item.note) for item in doc.notes.all()]
        notes = ",".join(note_texts)

        data = super().to_representation(doc)

        hit_details = {}
        hit_details["score"] = instance.score
        hit_details["highlights"] = instance.highlights("content", text=doc.content)
        # Guard kept as in the original; ``get`` raises instead of returning
        # a falsy value, so the ``None`` branch is not expected to be taken.
        hit_details["note_highlights"] = (
            None if not doc else instance.highlights("notes", text=notes)
        )
        hit_details["rank"] = instance.rank

        data["__search_hit__"] = hit_details
        return data
class UnifiedSearchViewSet(DocumentViewSet):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -1158,6 +1138,189 @@ class SearchAutoCompleteView(APIView):
)
class GlobalSearchView(PassUserMixin):
    """
    Cross-object ("global") search endpoint.

    A single query string is looked up in documents (by title, then in the
    full text index) and, by name, in saved views, tags, correspondents,
    document types, storage paths, users, groups, mail rules, mail accounts,
    workflows and custom fields.  An object type is only searched when the
    requesting user holds the corresponding ``view_*`` permission, and every
    type contributes at most ``OBJECT_LIMIT`` entries to the response.
    """

    permission_classes = (IsAuthenticated,)
    serializer_class = SearchResultSerializer

    def get(self, request, *args, **kwargs):
        """
        Handle ``GET ?query=<text>[&db_only=<flag>]``.

        Query parameters:
            query: text to search for; required, at least 3 characters.
            db_only: when set to a truthy value the full text index is not
                consulted and documents are matched by title only.  The
                values ``""``, ``0``, ``false``, ``no`` and ``off``
                (case-insensitive) count as "not set".

        Returns:
            ``HttpResponseBadRequest`` when ``query`` is missing or too
            short, otherwise a ``Response`` with one list per object type
            and a ``total`` count across all of them.
        """
        query = request.query_params.get("query", None)
        if query is None:
            return HttpResponseBadRequest("Query required")
        elif len(query) < 3:
            return HttpResponseBadRequest("Query must be at least 3 characters")

        # Query parameters arrive as strings, so "false" or "0" would be
        # truthy if the raw value were used as a boolean.
        raw_db_only = str(request.query_params.get("db_only", ""))
        db_only = raw_db_only.strip().lower() not in ("", "0", "false", "no", "off")

        OBJECT_LIMIT = 3
        docs = []
        if request.user.has_perm("documents.view_document"):
            all_docs = get_objects_for_user_owner_aware(
                request.user,
                "view_document",
                Document,
            )
            # First search by title.  The limited result is materialised
            # instead of OR-ing a sliced queryset with a second queryset:
            # that combination needs a LIMIT inside an IN-subquery, which
            # not every database backend accepts, and it allowed the merged
            # result to grow beyond OBJECT_LIMIT.
            docs = list(all_docs.filter(title__icontains=query)[:OBJECT_LIMIT])
            if not db_only and len(docs) < OBJECT_LIMIT:
                # If we don't have enough results, search by content
                from documents import index

                with index.open_index_searcher() as s:
                    q, _ = index.DelayedFullTextQuery(
                        s,
                        request.query_params,
                        10,
                        request.user,
                    )._get_query()
                    results = s.search(q, limit=OBJECT_LIMIT)
                    # Read the ids while the searcher is still open and skip
                    # documents already found by title.
                    known_ids = {doc.pk for doc in docs}
                    hit_ids = [r["id"] for r in results if r["id"] not in known_ids]

                if hit_ids:
                    # Restricting to all_docs keeps the permission filter;
                    # iterating hit_ids keeps the order returned by the index.
                    found = {doc.pk: doc for doc in all_docs.filter(id__in=hit_ids)}
                    docs.extend(found[pk] for pk in hit_ids if pk in found)
                    docs = docs[:OBJECT_LIMIT]

        saved_views = (
            SavedView.objects.filter(owner=request.user, name__icontains=query)[
                :OBJECT_LIMIT
            ]
            if request.user.has_perm("documents.view_savedview")
            else []
        )
        tags = (
            get_objects_for_user_owner_aware(request.user, "view_tag", Tag).filter(
                name__icontains=query,
            )[:OBJECT_LIMIT]
            if request.user.has_perm("documents.view_tag")
            else []
        )
        correspondents = (
            get_objects_for_user_owner_aware(
                request.user,
                "view_correspondent",
                Correspondent,
            ).filter(name__icontains=query)[:OBJECT_LIMIT]
            if request.user.has_perm("documents.view_correspondent")
            else []
        )
        document_types = (
            get_objects_for_user_owner_aware(
                request.user,
                "view_documenttype",
                DocumentType,
            ).filter(name__icontains=query)[:OBJECT_LIMIT]
            if request.user.has_perm("documents.view_documenttype")
            else []
        )
        storage_paths = (
            get_objects_for_user_owner_aware(
                request.user,
                "view_storagepath",
                StoragePath,
            ).filter(name__icontains=query)[:OBJECT_LIMIT]
            if request.user.has_perm("documents.view_storagepath")
            else []
        )
        # NOTE(review): the object types below are gated by the global
        # ``view_*`` permission only, without any per-object ownership
        # filtering -- confirm that this is intended for each of them.
        users = (
            User.objects.filter(username__icontains=query)[:OBJECT_LIMIT]
            if request.user.has_perm("auth.view_user")
            else []
        )
        groups = (
            Group.objects.filter(name__icontains=query)[:OBJECT_LIMIT]
            if request.user.has_perm("auth.view_group")
            else []
        )
        mail_rules = (
            MailRule.objects.filter(name__icontains=query)[:OBJECT_LIMIT]
            if request.user.has_perm("paperless_mail.view_mailrule")
            else []
        )
        mail_accounts = (
            MailAccount.objects.filter(name__icontains=query)[:OBJECT_LIMIT]
            if request.user.has_perm("paperless_mail.view_mailaccount")
            else []
        )
        workflows = (
            Workflow.objects.filter(name__icontains=query)[:OBJECT_LIMIT]
            if request.user.has_perm("documents.view_workflow")
            else []
        )
        custom_fields = (
            CustomField.objects.filter(name__icontains=query)[:OBJECT_LIMIT]
            if request.user.has_perm("documents.view_customfield")
            else []
        )

        # Every serializer receives the request through its context.
        context = {
            "request": request,
        }

        docs_serializer = DocumentSerializer(docs, many=True, context=context)
        saved_views_serializer = SavedViewSerializer(
            saved_views,
            many=True,
            context=context,
        )
        tags_serializer = TagSerializer(tags, many=True, context=context)
        correspondents_serializer = CorrespondentSerializer(
            correspondents,
            many=True,
            context=context,
        )
        document_types_serializer = DocumentTypeSerializer(
            document_types,
            many=True,
            context=context,
        )
        storage_paths_serializer = StoragePathSerializer(
            storage_paths,
            many=True,
            context=context,
        )
        users_serializer = UserSerializer(users, many=True, context=context)
        groups_serializer = GroupSerializer(groups, many=True, context=context)
        mail_rules_serializer = MailRuleSerializer(
            mail_rules,
            many=True,
            context=context,
        )
        mail_accounts_serializer = MailAccountSerializer(
            mail_accounts,
            many=True,
            context=context,
        )
        workflows_serializer = WorkflowSerializer(workflows, many=True, context=context)
        custom_fields_serializer = CustomFieldSerializer(
            custom_fields,
            many=True,
            context=context,
        )

        return Response(
            {
                "total": len(docs)
                + len(saved_views)
                + len(tags)
                + len(correspondents)
                + len(document_types)
                + len(storage_paths)
                + len(users)
                + len(groups)
                + len(mail_rules)
                + len(mail_accounts)
                + len(workflows)
                + len(custom_fields),
                "documents": docs_serializer.data,
                "saved_views": saved_views_serializer.data,
                "tags": tags_serializer.data,
                "correspondents": correspondents_serializer.data,
                "document_types": document_types_serializer.data,
                "storage_paths": storage_paths_serializer.data,
                "users": users_serializer.data,
                "groups": groups_serializer.data,
                "mail_rules": mail_rules_serializer.data,
                "mail_accounts": mail_accounts_serializer.data,
                "workflows": workflows_serializer.data,
                "custom_fields": custom_fields_serializer.data,
            },
        )
class StatisticsView(APIView):
permission_classes = (IsAuthenticated,)