Super basic doc chat

[ci skip]
shamoon
2025-04-24 23:41:31 -07:00
parent 6bdf396083
commit 0807e32278
4 changed files with 47 additions and 8 deletions

src/paperless/ai/chat.py (new file, 24 additions)

@@ -0,0 +1,24 @@
import logging

from django.contrib.auth.models import User
from llama_index.core.query_engine import RetrieverQueryEngine

from paperless.ai.client import AIClient
from paperless.ai.indexing import get_document_retriever

logger = logging.getLogger("paperless.ai.chat")


def chat_with_documents(prompt: str, user: User) -> str:
    retriever = get_document_retriever(top_k=5)
    client = AIClient()

    query_engine = RetrieverQueryEngine.from_args(
        retriever=retriever,
        llm=client.llm,
    )

    logger.debug("Document chat prompt: %s", prompt)
    response = query_engine.query(prompt)
    logger.debug("Document chat response: %s", response)
    return str(response)
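For context, a minimal usage sketch of the new helper. The caller below is hypothetical and not part of this commit; it assumes an index has already been persisted to LLM_INDEX_DIR and that the configured LLM backend is reachable:

# Hypothetical caller, e.g. from a Django shell -- not added by this commit.
from django.contrib.auth.models import User
from paperless.ai.chat import chat_with_documents

user = User.objects.get(username="admin")  # "admin" is a placeholder; use any existing user
answer = chat_with_documents("Which documents mention a warranty?", user=user)
print(answer)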

src/paperless/ai/client.py

@@ -14,6 +14,10 @@ class AIClient:
    A client for interacting with an LLM backend.
    """

+    def __init__(self):
+        self.settings = AIConfig()
+        self.llm = self.get_llm()

    def get_llm(self):
        if self.settings.llm_backend == "ollama":
            return OllamaLLM(
@@ -28,10 +32,6 @@ class AIClient:
        else:
            raise ValueError(f"Unsupported LLM backend: {self.settings.llm_backend}")

-    def __init__(self):
-        self.settings = AIConfig()
-        self.llm = self.get_llm()

    def run_llm_query(self, prompt: str) -> str:
        logger.debug(
            "Running LLM query against %s with model %s",

src/paperless/ai/indexing.py

@@ -14,6 +14,11 @@ from paperless.ai.embedding import get_embedding_model
logger = logging.getLogger("paperless.ai.indexing")


+def get_document_retriever(top_k: int = 5) -> VectorIndexRetriever:
+    index = load_index()
+    return VectorIndexRetriever(index=index, similarity_top_k=top_k)


def load_index() -> VectorStoreIndex:
    """Loads the persisted LlamaIndex from disk."""
    vector_store = FaissVectorStore.from_persist_dir(settings.LLM_INDEX_DIR)
@@ -31,10 +36,8 @@ def load_index() -> VectorStoreIndex:
def query_similar_documents(document: Document, top_k: int = 5) -> list[Document]:
    """Runs a similarity query and returns top-k similar Document objects."""
-    # Load index
-    index = load_index()
-    retriever = VectorIndexRetriever(index=index, similarity_top_k=top_k)
+    # Load the index
+    retriever = get_document_retriever(top_k=top_k)

    # Build query from the document text
    query_text = (document.title or "") + "\n" + (document.content or "")
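The retriever returned by get_document_retriever() can also be queried directly. A hedged sketch, assuming the standard LlamaIndex retriever interface (retrieve() returning NodeWithScore objects); the query string is illustrative:

# Sketch only: reusing the shared retriever outside query_similar_documents().
from paperless.ai.indexing import get_document_retriever

retriever = get_document_retriever(top_k=3)
for result in retriever.retrieve("contracts signed in 2024"):
    print(result.score, result.node.metadata)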