Super basic doc chat

[ci skip]
This commit is contained in:
shamoon 2025-04-24 23:41:31 -07:00
parent 20506d53c0
commit a794008649
No known key found for this signature in database
4 changed files with 47 additions and 8 deletions

View File

@ -171,6 +171,7 @@ from documents.tasks import train_classifier
from documents.templating.filepath import validate_filepath_template_and_render
from paperless import version
from paperless.ai.ai_classifier import get_ai_document_classification
from paperless.ai.chat import chat_with_documents
from paperless.ai.matching import extract_unmatched_names
from paperless.ai.matching import match_correspondents_by_name
from paperless.ai.matching import match_document_types_by_name
@ -1134,6 +1135,17 @@ class DocumentViewSet(
"Error emailing document, check logs for more detail.",
)
@action(methods=["post"], detail=False, url_path="chat")
def chat(self, request):
ai_config = AIConfig()
if not ai_config.ai_enabled:
return HttpResponseBadRequest("AI is required for this feature")
question = request.data["q"]
result = chat_with_documents(question, request.user)
return Response({"answer": result})
@extend_schema_view(
list=extend_schema(

24
src/paperless/ai/chat.py Normal file
View File

@ -0,0 +1,24 @@
import logging
from django.contrib.auth.models import User
from llama_index.core.query_engine import RetrieverQueryEngine
from paperless.ai.client import AIClient
from paperless.ai.indexing import get_document_retriever
logger = logging.getLogger("paperless.ai.chat")
def chat_with_documents(prompt: str, user: User) -> str:
retriever = get_document_retriever(top_k=5)
client = AIClient()
query_engine = RetrieverQueryEngine.from_args(
retriever=retriever,
llm=client.llm,
)
logger.debug("Document chat prompt: %s", prompt)
response = query_engine.query(prompt)
logger.debug("Document chat response: %s", response)
return str(response)

View File

@ -14,6 +14,10 @@ class AIClient:
A client for interacting with an LLM backend.
"""
def __init__(self):
self.settings = AIConfig()
self.llm = self.get_llm()
def get_llm(self):
if self.settings.llm_backend == "ollama":
return OllamaLLM(
@ -28,10 +32,6 @@ class AIClient:
else:
raise ValueError(f"Unsupported LLM backend: {self.settings.llm_backend}")
def __init__(self):
self.settings = AIConfig()
self.llm = self.get_llm()
def run_llm_query(self, prompt: str) -> str:
logger.debug(
"Running LLM query against %s with model %s",

View File

@ -14,6 +14,11 @@ from paperless.ai.embedding import get_embedding_model
logger = logging.getLogger("paperless.ai.indexing")
def get_document_retriever(top_k: int = 5) -> VectorIndexRetriever:
index = load_index()
return VectorIndexRetriever(index=index, similarity_top_k=top_k)
def load_index() -> VectorStoreIndex:
"""Loads the persisted LlamaIndex from disk."""
vector_store = FaissVectorStore.from_persist_dir(settings.LLM_INDEX_DIR)
@ -31,10 +36,8 @@ def load_index() -> VectorStoreIndex:
def query_similar_documents(document: Document, top_k: int = 5) -> list[Document]:
"""Runs a similarity query and returns top-k similar Document objects."""
# Load index
index = load_index()
retriever = VectorIndexRetriever(index=index, similarity_top_k=top_k)
# Load the index
retriever = get_document_retriever(top_k=top_k)
# Build query from the document text
query_text = (document.title or "") + "\n" + (document.content or "")