Backend streaming chat

This commit is contained in:
shamoon
2025-04-25 10:06:26 -07:00
parent 4a28be233e
commit e864a51497
5 changed files with 101 additions and 39 deletions

View File

@@ -48,7 +48,7 @@ def build_prompt_without_rag(document: Document) -> str:
{filename}
CONTENT:
{content[:8000]} # Trim to safe size
{content[:8000]}
"""
return prompt

View File

@@ -1,6 +1,7 @@
import logging
from llama_index.core import VectorStoreIndex
from llama_index.core.prompts import PromptTemplate
from llama_index.core.query_engine import RetrieverQueryEngine
from documents.models import Document
@@ -9,10 +10,19 @@ from paperless.ai.indexing import load_index
logger = logging.getLogger("paperless.ai.chat")
CHAT_PROMPT_TMPL = PromptTemplate(
template="""Context information is below.
---------------------
{context_str}
---------------------
Given the context information and not prior knowledge, answer the query.
Query: {query_str}
Answer:""",
)
def chat_with_documents(prompt: str, documents: list[Document]) -> str:
def stream_chat_with_documents(query_str: str, documents: list[Document]):
client = AIClient()
index = load_index()
doc_ids = [doc.pk for doc in documents]
@@ -28,17 +38,36 @@ def chat_with_documents(prompt: str, documents: list[Document]) -> str:
logger.warning("No nodes found for the given documents.")
return "Sorry, I couldn't find any content to answer your question."
local_index = VectorStoreIndex.from_documents(nodes)
local_index = VectorStoreIndex(nodes=nodes)
retriever = local_index.as_retriever(
similarity_top_k=3 if len(documents) == 1 else 5,
)
if len(documents) == 1:
# Just one doc — provide full content
doc = documents[0]
# TODO: include document metadata in the context
context = f"TITLE: {doc.title or doc.filename}\n{doc.content or ''}"
else:
top_nodes = retriever.retrieve(query_str)
context = "\n\n".join(
f"TITLE: {node.metadata.get('title')}\n{node.text}" for node in top_nodes
)
prompt = CHAT_PROMPT_TMPL.partial_format(
context_str=context,
query_str=query_str,
).format(llm=client.llm)
query_engine = RetrieverQueryEngine.from_args(
retriever=retriever,
llm=client.llm,
streaming=True,
)
logger.debug("Document chat prompt: %s", prompt)
response = query_engine.query(prompt)
logger.debug("Document chat response: %s", response)
return str(response)
response_stream = query_engine.query(prompt)
for chunk in response_stream.response_gen:
yield chunk.text

View File

@@ -1,3 +1,5 @@
import json
import httpx
from llama_index.core.base.llms.types import ChatMessage
from llama_index.core.base.llms.types import ChatResponse
@@ -6,6 +8,7 @@ from llama_index.core.base.llms.types import CompletionResponse
from llama_index.core.base.llms.types import CompletionResponseGen
from llama_index.core.base.llms.types import LLMMetadata
from llama_index.core.llms.llm import LLM
from llama_index.core.prompts import SelectorPromptTemplate
from pydantic import Field
@@ -37,33 +40,42 @@ class OllamaLLM(LLM):
data = response.json()
return CompletionResponse(text=data["response"])
def chat(self, messages: list[ChatMessage], **kwargs) -> ChatResponse:
with httpx.Client(timeout=120.0) as client:
response = client.post(
f"{self.base_url}/api/generate",
json={
"model": self.model,
"messages": [
{
"role": message.role,
"content": message.content,
}
for message in messages
],
"stream": False,
},
)
response.raise_for_status()
data = response.json()
return ChatResponse(text=data["response"])
def stream(self, prompt: str, **kwargs) -> CompletionResponseGen:
return self.stream_complete(prompt, **kwargs)
# -- Required stubs for ABC:
def stream_complete(
self,
prompt: str,
prompt: SelectorPromptTemplate,
**kwargs,
) -> CompletionResponseGen: # pragma: no cover
raise NotImplementedError("stream_complete not supported")
) -> CompletionResponseGen:
headers = {"Content-Type": "application/json"}
data = {
"model": self.model,
"prompt": prompt.format(llm=self),
"stream": True,
}
with httpx.stream(
"POST",
f"{self.base_url}/api/generate",
headers=headers,
json=data,
timeout=60.0,
) as response:
response.raise_for_status()
for line in response.iter_lines():
if not line.strip():
continue
chunk = json.loads(line)
if "response" in chunk:
yield CompletionResponse(text=chunk["response"])
def chat(
self,
messages: list[ChatMessage],
**kwargs,
) -> ChatResponse: # pragma: no cover
raise NotImplementedError("chat not supported")
def stream_chat(
self,