From 1df52fd4c10e1e719f177f9cd7f80d8d98db1186 Mon Sep 17 00:00:00 2001 From: shamoon <4887959+shamoon@users.noreply.github.com> Date: Fri, 25 Apr 2025 00:09:33 -0700 Subject: [PATCH] Unify, respect perms [ci skip] --- src/documents/views.py | 17 +++++++++--- src/paperless/ai/chat.py | 52 +++++++++++++++--------------------- src/paperless/ai/indexing.py | 8 ++---- 3 files changed, 37 insertions(+), 40 deletions(-) diff --git a/src/documents/views.py b/src/documents/views.py index f36d05cac..2e848f44c 100644 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -172,7 +172,6 @@ from documents.templating.filepath import validate_filepath_template_and_render from paperless import version from paperless.ai.ai_classifier import get_ai_document_classification from paperless.ai.chat import chat_with_documents -from paperless.ai.chat import chat_with_single_document from paperless.ai.matching import extract_unmatched_names from paperless.ai.matching import match_correspondents_by_name from paperless.ai.matching import match_document_types_by_name @@ -1145,13 +1144,23 @@ class DocumentViewSet( question = request.data["q"] doc_id = request.data.get("document_id", None) if doc_id: - document = Document.objects.get(id=doc_id) + try: + document = Document.objects.get(id=doc_id) + except Document.DoesNotExist: + return HttpResponseBadRequest("Invalid document ID") + if not has_perms_owner_aware(request.user, "view_document", document): return HttpResponseForbidden("Insufficient permissions") - result = chat_with_single_document(document, question, request.user) + documents = [document] else: - result = chat_with_documents(question, request.user) + documents = get_objects_for_user_owner_aware( + request.user, + "view_document", + Document, + ) + + result = chat_with_documents(question, documents) return Response({"answer": result}) diff --git a/src/paperless/ai/chat.py b/src/paperless/ai/chat.py index 6e75884d9..3ce109a79 100644 --- a/src/paperless/ai/chat.py +++ b/src/paperless/ai/chat.py @@ -1,20 +1,38 @@ import logging -from django.contrib.auth.models import User from llama_index.core import VectorStoreIndex from llama_index.core.query_engine import RetrieverQueryEngine +from documents.models import Document from paperless.ai.client import AIClient -from paperless.ai.indexing import get_document_retriever from paperless.ai.indexing import load_index logger = logging.getLogger("paperless.ai.chat") -def chat_with_documents(prompt: str, user: User) -> str: - retriever = get_document_retriever(top_k=5) +def chat_with_documents(prompt: str, documents: list[Document]) -> str: client = AIClient() + index = load_index() + + doc_ids = [doc.pk for doc in documents] + + # Filter only the node(s) that match the document IDs + nodes = [ + node + for node in index.docstore.docs.values() + if node.metadata.get("document_id") in doc_ids + ] + + if len(nodes) == 0: + logger.warning("No nodes found for the given documents.") + return "Sorry, I couldn't find any content to answer your question." + + local_index = VectorStoreIndex.from_documents(nodes) + retriever = local_index.as_retriever( + similarity_top_k=3 if len(documents) == 1 else 5, + ) + query_engine = RetrieverQueryEngine.from_args( retriever=retriever, llm=client.llm, @@ -24,29 +42,3 @@ def chat_with_documents(prompt: str, user: User) -> str: response = query_engine.query(prompt) logger.debug("Document chat response: %s", response) return str(response) - - -def chat_with_single_document(document, question: str, user): - index = load_index() - - # Filter only the node(s) belonging to this doc - nodes = [ - node - for node in index.docstore.docs.values() - if node.metadata.get("document_id") == str(document.id) - ] - - if not nodes: - raise Exception("This document is not indexed yet.") - - local_index = VectorStoreIndex.from_documents(nodes) - - client = AIClient() - - engine = RetrieverQueryEngine.from_args( - retriever=local_index.as_retriever(similarity_top_k=3), - llm=client.llm, - ) - - response = engine.query(question) - return str(response) diff --git a/src/paperless/ai/indexing.py b/src/paperless/ai/indexing.py index 9ed09daa1..6d9a59e79 100644 --- a/src/paperless/ai/indexing.py +++ b/src/paperless/ai/indexing.py @@ -14,11 +14,6 @@ from paperless.ai.embedding import get_embedding_model logger = logging.getLogger("paperless.ai.indexing") -def get_document_retriever(top_k: int = 5) -> VectorIndexRetriever: - index = load_index() - return VectorIndexRetriever(index=index, similarity_top_k=top_k) - - def load_index() -> VectorStoreIndex: """Loads the persisted LlamaIndex from disk.""" vector_store = FaissVectorStore.from_persist_dir(settings.LLM_INDEX_DIR) @@ -37,7 +32,8 @@ def load_index() -> VectorStoreIndex: def query_similar_documents(document: Document, top_k: int = 5) -> list[Document]: """Runs a similarity query and returns top-k similar Document objects.""" # Load the index - retriever = get_document_retriever(top_k=top_k) + index = load_index() + retriever = VectorIndexRetriever(index=index, similarity_top_k=top_k) # Build query from the document text query_text = (document.title or "") + "\n" + (document.content or "")