Unify, respect perms

[ci skip]
This commit is contained in:
shamoon 2025-04-25 00:09:33 -07:00
parent 74102a8c30
commit 1df52fd4c1
No known key found for this signature in database
3 changed files with 37 additions and 40 deletions

View File

@ -172,7 +172,6 @@ from documents.templating.filepath import validate_filepath_template_and_render
from paperless import version
from paperless.ai.ai_classifier import get_ai_document_classification
from paperless.ai.chat import chat_with_documents
from paperless.ai.chat import chat_with_single_document
from paperless.ai.matching import extract_unmatched_names
from paperless.ai.matching import match_correspondents_by_name
from paperless.ai.matching import match_document_types_by_name
@ -1145,13 +1144,23 @@ class DocumentViewSet(
question = request.data["q"]
doc_id = request.data.get("document_id", None)
if doc_id:
document = Document.objects.get(id=doc_id)
try:
document = Document.objects.get(id=doc_id)
except Document.DoesNotExist:
return HttpResponseBadRequest("Invalid document ID")
if not has_perms_owner_aware(request.user, "view_document", document):
return HttpResponseForbidden("Insufficient permissions")
result = chat_with_single_document(document, question, request.user)
documents = [document]
else:
result = chat_with_documents(question, request.user)
documents = get_objects_for_user_owner_aware(
request.user,
"view_document",
Document,
)
result = chat_with_documents(question, documents)
return Response({"answer": result})

View File

@ -1,20 +1,38 @@
import logging
from django.contrib.auth.models import User
from llama_index.core import VectorStoreIndex
from llama_index.core.query_engine import RetrieverQueryEngine
from documents.models import Document
from paperless.ai.client import AIClient
from paperless.ai.indexing import get_document_retriever
from paperless.ai.indexing import load_index
logger = logging.getLogger("paperless.ai.chat")
def chat_with_documents(prompt: str, user: User) -> str:
retriever = get_document_retriever(top_k=5)
def chat_with_documents(prompt: str, documents: list[Document]) -> str:
client = AIClient()
index = load_index()
doc_ids = [doc.pk for doc in documents]
# Filter only the node(s) that match the document IDs
nodes = [
node
for node in index.docstore.docs.values()
if node.metadata.get("document_id") in doc_ids
]
if len(nodes) == 0:
logger.warning("No nodes found for the given documents.")
return "Sorry, I couldn't find any content to answer your question."
local_index = VectorStoreIndex.from_documents(nodes)
retriever = local_index.as_retriever(
similarity_top_k=3 if len(documents) == 1 else 5,
)
query_engine = RetrieverQueryEngine.from_args(
retriever=retriever,
llm=client.llm,
@ -24,29 +42,3 @@ def chat_with_documents(prompt: str, user: User) -> str:
response = query_engine.query(prompt)
logger.debug("Document chat response: %s", response)
return str(response)
def chat_with_single_document(document, question: str, user):
    """Answer *question* using only the indexed content of *document*.

    Filters the persisted index down to the nodes belonging to this
    document, builds a scoped index over them, and queries it with the
    configured LLM.

    NOTE(review): *user* is accepted but never read here — permission
    checks are presumably done by the caller; confirm before relying on it.
    """
    index = load_index()
    # Keep only the node(s) tagged with this document's id.
    # NOTE(review): node metadata appears to store the id as a string —
    # confirm against the indexing code.
    target_id = str(document.id)
    matching_nodes = []
    for node in index.docstore.docs.values():
        if node.metadata.get("document_id") == target_id:
            matching_nodes.append(node)
    if not matching_nodes:
        raise Exception("This document is not indexed yet.")
    # Build a throwaway index scoped to just this document's nodes.
    scoped_index = VectorStoreIndex.from_documents(matching_nodes)
    client = AIClient()
    engine = RetrieverQueryEngine.from_args(
        retriever=scoped_index.as_retriever(similarity_top_k=3),
        llm=client.llm,
    )
    return str(engine.query(question))

View File

@ -14,11 +14,6 @@ from paperless.ai.embedding import get_embedding_model
logger = logging.getLogger("paperless.ai.indexing")
def get_document_retriever(top_k: int = 5) -> VectorIndexRetriever:
    """Return a retriever over the persisted index yielding up to *top_k* hits."""
    persisted_index = load_index()
    return VectorIndexRetriever(
        index=persisted_index,
        similarity_top_k=top_k,
    )
def load_index() -> VectorStoreIndex:
"""Loads the persisted LlamaIndex from disk."""
vector_store = FaissVectorStore.from_persist_dir(settings.LLM_INDEX_DIR)
@ -37,7 +32,8 @@ def load_index() -> VectorStoreIndex:
def query_similar_documents(document: Document, top_k: int = 5) -> list[Document]:
"""Runs a similarity query and returns top-k similar Document objects."""
# Load the index
retriever = get_document_retriever(top_k=top_k)
index = load_index()
retriever = VectorIndexRetriever(index=index, similarity_top_k=top_k)
# Build query from the document text
query_text = (document.title or "") + "\n" + (document.content or "")