Backend streaming chat

shamoon
2025-04-25 10:06:26 -07:00
parent 46df529c3a
commit b4ea2b7521
5 changed files with 101 additions and 39 deletions

View File

@@ -48,7 +48,7 @@ def build_prompt_without_rag(document: Document) -> str:
{filename}
CONTENT:
-{content[:8000]} # Trim to safe size
+{content[:8000]}
"""
return prompt
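The removed line above sat inside the triple-quoted f-string, so the trailing "# Trim to safe size" was never a Python comment: it was emitted verbatim into the prompt sent to the model. A minimal standalone illustration (not from this commit):

# Inside an f-string literal, a trailing "#" is plain text, not a comment.
content = "x" * 20
prompt = f"""CONTENT:
{content[:8000]} # Trim to safe size
"""
print(prompt)  # the printed prompt literally ends with "# Trim to safe size"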

View File

@@ -1,6 +1,7 @@
import logging
from llama_index.core import VectorStoreIndex
+from llama_index.core.prompts import PromptTemplate
from llama_index.core.query_engine import RetrieverQueryEngine
from documents.models import Document
@@ -9,10 +10,19 @@ from paperless.ai.indexing import load_index
logger = logging.getLogger("paperless.ai.chat")
+CHAT_PROMPT_TMPL = PromptTemplate(
+template="""Context information is below.
+---------------------
+{context_str}
+---------------------
+Given the context information and not prior knowledge, answer the query.
+Query: {query_str}
+Answer:""",
+)
-def chat_with_documents(prompt: str, documents: list[Document]) -> str:
+def stream_chat_with_documents(query_str: str, documents: list[Document]):
client = AIClient()
index = load_index()
doc_ids = [doc.pk for doc in documents]
@@ -28,17 +38,36 @@ def chat_with_documents(prompt: str, documents: list[Document]) -> str:
logger.warning("No nodes found for the given documents.")
return "Sorry, I couldn't find any content to answer your question."
-local_index = VectorStoreIndex.from_documents(nodes)
+local_index = VectorStoreIndex(nodes=nodes)
retriever = local_index.as_retriever(
similarity_top_k=3 if len(documents) == 1 else 5,
)
+if len(documents) == 1:
+# Just one doc — provide full content
+doc = documents[0]
+# TODO: include document metadata in the context
+context = f"TITLE: {doc.title or doc.filename}\n{doc.content or ''}"
+else:
+top_nodes = retriever.retrieve(query_str)
+context = "\n\n".join(
+f"TITLE: {node.metadata.get('title')}\n{node.text}" for node in top_nodes
+)
+prompt = CHAT_PROMPT_TMPL.partial_format(
+context_str=context,
+query_str=query_str,
+).format(llm=client.llm)
query_engine = RetrieverQueryEngine.from_args(
retriever=retriever,
llm=client.llm,
+streaming=True,
)
logger.debug("Document chat prompt: %s", prompt)
-response = query_engine.query(prompt)
-logger.debug("Document chat response: %s", response)
-return str(response)
+response_stream = query_engine.query(prompt)
+for chunk in response_stream.response_gen:
+yield chunk.text
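With streaming=True, query() returns a streaming response object whose response_gen yields chunks incrementally as the LLM produces them, so callers can forward tokens as they arrive instead of waiting for the full answer. A minimal consumption sketch (the document IDs and question are made up for illustration):

# Illustrative only: print tokens to stdout as they arrive.
docs = list(Document.objects.filter(pk__in=[1, 2]))
for token in stream_chat_with_documents("Summarize these documents.", docs):
    print(token, end="", flush=True)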

View File

@@ -1,3 +1,5 @@
+import json
import httpx
from llama_index.core.base.llms.types import ChatMessage
from llama_index.core.base.llms.types import ChatResponse
@@ -6,6 +8,7 @@ from llama_index.core.base.llms.types import CompletionResponse
from llama_index.core.base.llms.types import CompletionResponseGen
from llama_index.core.base.llms.types import LLMMetadata
from llama_index.core.llms.llm import LLM
+from llama_index.core.prompts import SelectorPromptTemplate
from pydantic import Field
@@ -37,33 +40,42 @@ class OllamaLLM(LLM):
data = response.json()
return CompletionResponse(text=data["response"])
-def chat(self, messages: list[ChatMessage], **kwargs) -> ChatResponse:
-with httpx.Client(timeout=120.0) as client:
-response = client.post(
-f"{self.base_url}/api/generate",
-json={
-"model": self.model,
-"messages": [
-{
-"role": message.role,
-"content": message.content,
-}
-for message in messages
-],
-"stream": False,
-},
-)
-response.raise_for_status()
-data = response.json()
-return ChatResponse(text=data["response"])
def stream(self, prompt: str, **kwargs) -> CompletionResponseGen:
return self.stream_complete(prompt, **kwargs)
# -- Required stubs for ABC:
def stream_complete(
self,
-prompt: str,
+prompt: SelectorPromptTemplate,
**kwargs,
-) -> CompletionResponseGen: # pragma: no cover
-raise NotImplementedError("stream_complete not supported")
+) -> CompletionResponseGen:
+headers = {"Content-Type": "application/json"}
+data = {
+"model": self.model,
+"prompt": prompt.format(llm=self),
+"stream": True,
+}
+with httpx.stream(
+"POST",
+f"{self.base_url}/api/generate",
+headers=headers,
+json=data,
+timeout=60.0,
+) as response:
+response.raise_for_status()
+for line in response.iter_lines():
+if not line.strip():
+continue
+chunk = json.loads(line)
+if "response" in chunk:
+yield CompletionResponse(text=chunk["response"])
+def chat(
+self,
+messages: list[ChatMessage],
+**kwargs,
+) -> ChatResponse: # pragma: no cover
+raise NotImplementedError("chat not supported")
def stream_chat(
self,

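The new stream_complete reads Ollama's /api/generate endpoint with "stream": true, which replies with newline-delimited JSON: each line carries a partial "response" fragment, and a final record with "done": true closes the stream. That is why the loop parses line by line with json.loads and yields only the "response" field. The stream looks roughly like this (model name and values are illustrative, fields abridged):

{"model": "llama3", "response": "The", "done": false}
{"model": "llama3", "response": " answer", "done": false}
{"model": "llama3", "response": " is 42.", "done": false}
{"model": "llama3", "done": true, "total_duration": 123456789}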
View File

@@ -21,6 +21,7 @@ from rest_framework.routers import DefaultRouter
from documents.views import BulkDownloadView
from documents.views import BulkEditObjectsView
from documents.views import BulkEditView
+from documents.views import ChatStreamingView
from documents.views import CorrespondentViewSet
from documents.views import CustomFieldViewSet
from documents.views import DocumentTypeViewSet
@@ -139,6 +140,11 @@ urlpatterns = [
SelectionDataView.as_view(),
name="selection_data",
),
+re_path(
+"^chat/",
+ChatStreamingView.as_view(),
+name="chat_streaming_view",
+),
],
),
),
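The ChatStreamingView wired up here comes from documents.views, which is among the five changed files but is not visible in this view. As a rough sketch only (the request parameters, class body, and the chat module's import path are assumptions inferred from the hunks above, not taken from the commit), such a view can hand the generator straight to Django's StreamingHttpResponse:

# Hypothetical sketch; the real ChatStreamingView is not shown in this commit view.
from django.http import StreamingHttpResponse
from rest_framework.views import APIView

from documents.models import Document
from paperless.ai.chat import stream_chat_with_documents


class ChatStreamingView(APIView):
    def post(self, request):
        question = request.data.get("q", "")
        doc_ids = request.data.get("document_ids", [])
        documents = list(Document.objects.filter(pk__in=doc_ids))
        # StreamingHttpResponse iterates the generator and flushes each yielded chunk.
        return StreamingHttpResponse(
            stream_chat_with_documents(question, documents),
            content_type="text/plain",
        )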