mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-05-23 12:58:18 -05:00
Unify, respect perms
[ci skip]
This commit is contained in:
parent
74102a8c30
commit
1df52fd4c1
@ -172,7 +172,6 @@ from documents.templating.filepath import validate_filepath_template_and_render
|
|||||||
from paperless import version
|
from paperless import version
|
||||||
from paperless.ai.ai_classifier import get_ai_document_classification
|
from paperless.ai.ai_classifier import get_ai_document_classification
|
||||||
from paperless.ai.chat import chat_with_documents
|
from paperless.ai.chat import chat_with_documents
|
||||||
from paperless.ai.chat import chat_with_single_document
|
|
||||||
from paperless.ai.matching import extract_unmatched_names
|
from paperless.ai.matching import extract_unmatched_names
|
||||||
from paperless.ai.matching import match_correspondents_by_name
|
from paperless.ai.matching import match_correspondents_by_name
|
||||||
from paperless.ai.matching import match_document_types_by_name
|
from paperless.ai.matching import match_document_types_by_name
|
||||||
@ -1145,13 +1144,23 @@ class DocumentViewSet(
|
|||||||
question = request.data["q"]
|
question = request.data["q"]
|
||||||
doc_id = request.data.get("document_id", None)
|
doc_id = request.data.get("document_id", None)
|
||||||
if doc_id:
|
if doc_id:
|
||||||
document = Document.objects.get(id=doc_id)
|
try:
|
||||||
|
document = Document.objects.get(id=doc_id)
|
||||||
|
except Document.DoesNotExist:
|
||||||
|
return HttpResponseBadRequest("Invalid document ID")
|
||||||
|
|
||||||
if not has_perms_owner_aware(request.user, "view_document", document):
|
if not has_perms_owner_aware(request.user, "view_document", document):
|
||||||
return HttpResponseForbidden("Insufficient permissions")
|
return HttpResponseForbidden("Insufficient permissions")
|
||||||
|
|
||||||
result = chat_with_single_document(document, question, request.user)
|
documents = [document]
|
||||||
else:
|
else:
|
||||||
result = chat_with_documents(question, request.user)
|
documents = get_objects_for_user_owner_aware(
|
||||||
|
request.user,
|
||||||
|
"view_document",
|
||||||
|
Document,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = chat_with_documents(question, documents)
|
||||||
|
|
||||||
return Response({"answer": result})
|
return Response({"answer": result})
|
||||||
|
|
||||||
|
@ -1,20 +1,38 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from django.contrib.auth.models import User
|
|
||||||
from llama_index.core import VectorStoreIndex
|
from llama_index.core import VectorStoreIndex
|
||||||
from llama_index.core.query_engine import RetrieverQueryEngine
|
from llama_index.core.query_engine import RetrieverQueryEngine
|
||||||
|
|
||||||
|
from documents.models import Document
|
||||||
from paperless.ai.client import AIClient
|
from paperless.ai.client import AIClient
|
||||||
from paperless.ai.indexing import get_document_retriever
|
|
||||||
from paperless.ai.indexing import load_index
|
from paperless.ai.indexing import load_index
|
||||||
|
|
||||||
logger = logging.getLogger("paperless.ai.chat")
|
logger = logging.getLogger("paperless.ai.chat")
|
||||||
|
|
||||||
|
|
||||||
def chat_with_documents(prompt: str, user: User) -> str:
|
def chat_with_documents(prompt: str, documents: list[Document]) -> str:
|
||||||
retriever = get_document_retriever(top_k=5)
|
|
||||||
client = AIClient()
|
client = AIClient()
|
||||||
|
|
||||||
|
index = load_index()
|
||||||
|
|
||||||
|
doc_ids = [doc.pk for doc in documents]
|
||||||
|
|
||||||
|
# Filter only the node(s) that match the document IDs
|
||||||
|
nodes = [
|
||||||
|
node
|
||||||
|
for node in index.docstore.docs.values()
|
||||||
|
if node.metadata.get("document_id") in doc_ids
|
||||||
|
]
|
||||||
|
|
||||||
|
if len(nodes) == 0:
|
||||||
|
logger.warning("No nodes found for the given documents.")
|
||||||
|
return "Sorry, I couldn't find any content to answer your question."
|
||||||
|
|
||||||
|
local_index = VectorStoreIndex.from_documents(nodes)
|
||||||
|
retriever = local_index.as_retriever(
|
||||||
|
similarity_top_k=3 if len(documents) == 1 else 5,
|
||||||
|
)
|
||||||
|
|
||||||
query_engine = RetrieverQueryEngine.from_args(
|
query_engine = RetrieverQueryEngine.from_args(
|
||||||
retriever=retriever,
|
retriever=retriever,
|
||||||
llm=client.llm,
|
llm=client.llm,
|
||||||
@ -24,29 +42,3 @@ def chat_with_documents(prompt: str, user: User) -> str:
|
|||||||
response = query_engine.query(prompt)
|
response = query_engine.query(prompt)
|
||||||
logger.debug("Document chat response: %s", response)
|
logger.debug("Document chat response: %s", response)
|
||||||
return str(response)
|
return str(response)
|
||||||
|
|
||||||
|
|
||||||
def chat_with_single_document(document, question: str, user):
|
|
||||||
index = load_index()
|
|
||||||
|
|
||||||
# Filter only the node(s) belonging to this doc
|
|
||||||
nodes = [
|
|
||||||
node
|
|
||||||
for node in index.docstore.docs.values()
|
|
||||||
if node.metadata.get("document_id") == str(document.id)
|
|
||||||
]
|
|
||||||
|
|
||||||
if not nodes:
|
|
||||||
raise Exception("This document is not indexed yet.")
|
|
||||||
|
|
||||||
local_index = VectorStoreIndex.from_documents(nodes)
|
|
||||||
|
|
||||||
client = AIClient()
|
|
||||||
|
|
||||||
engine = RetrieverQueryEngine.from_args(
|
|
||||||
retriever=local_index.as_retriever(similarity_top_k=3),
|
|
||||||
llm=client.llm,
|
|
||||||
)
|
|
||||||
|
|
||||||
response = engine.query(question)
|
|
||||||
return str(response)
|
|
||||||
|
@ -14,11 +14,6 @@ from paperless.ai.embedding import get_embedding_model
|
|||||||
logger = logging.getLogger("paperless.ai.indexing")
|
logger = logging.getLogger("paperless.ai.indexing")
|
||||||
|
|
||||||
|
|
||||||
def get_document_retriever(top_k: int = 5) -> VectorIndexRetriever:
|
|
||||||
index = load_index()
|
|
||||||
return VectorIndexRetriever(index=index, similarity_top_k=top_k)
|
|
||||||
|
|
||||||
|
|
||||||
def load_index() -> VectorStoreIndex:
|
def load_index() -> VectorStoreIndex:
|
||||||
"""Loads the persisted LlamaIndex from disk."""
|
"""Loads the persisted LlamaIndex from disk."""
|
||||||
vector_store = FaissVectorStore.from_persist_dir(settings.LLM_INDEX_DIR)
|
vector_store = FaissVectorStore.from_persist_dir(settings.LLM_INDEX_DIR)
|
||||||
@ -37,7 +32,8 @@ def load_index() -> VectorStoreIndex:
|
|||||||
def query_similar_documents(document: Document, top_k: int = 5) -> list[Document]:
|
def query_similar_documents(document: Document, top_k: int = 5) -> list[Document]:
|
||||||
"""Runs a similarity query and returns top-k similar Document objects."""
|
"""Runs a similarity query and returns top-k similar Document objects."""
|
||||||
# Load the index
|
# Load the index
|
||||||
retriever = get_document_retriever(top_k=top_k)
|
index = load_index()
|
||||||
|
retriever = VectorIndexRetriever(index=index, similarity_top_k=top_k)
|
||||||
|
|
||||||
# Build query from the document text
|
# Build query from the document text
|
||||||
query_text = (document.title or "") + "\n" + (document.content or "")
|
query_text = (document.title or "") + "\n" + (document.content or "")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user