mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-08-14 00:26:21 +00:00
Backend streaming chat
This commit is contained in:
@@ -48,7 +48,7 @@ def build_prompt_without_rag(document: Document) -> str:
|
||||
{filename}
|
||||
|
||||
CONTENT:
|
||||
{content[:8000]} # Trim to safe size
|
||||
{content[:8000]}
|
||||
"""
|
||||
|
||||
return prompt
|
||||
|
@@ -1,6 +1,7 @@
|
||||
import logging
|
||||
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.core.prompts import PromptTemplate
|
||||
from llama_index.core.query_engine import RetrieverQueryEngine
|
||||
|
||||
from documents.models import Document
|
||||
@@ -9,10 +10,19 @@ from paperless.ai.indexing import load_index
|
||||
|
||||
logger = logging.getLogger("paperless.ai.chat")
|
||||
|
||||
CHAT_PROMPT_TMPL = PromptTemplate(
|
||||
template="""Context information is below.
|
||||
---------------------
|
||||
{context_str}
|
||||
---------------------
|
||||
Given the context information and not prior knowledge, answer the query.
|
||||
Query: {query_str}
|
||||
Answer:""",
|
||||
)
|
||||
|
||||
def chat_with_documents(prompt: str, documents: list[Document]) -> str:
|
||||
|
||||
def stream_chat_with_documents(query_str: str, documents: list[Document]):
|
||||
client = AIClient()
|
||||
|
||||
index = load_index()
|
||||
|
||||
doc_ids = [doc.pk for doc in documents]
|
||||
@@ -28,17 +38,36 @@ def chat_with_documents(prompt: str, documents: list[Document]) -> str:
|
||||
logger.warning("No nodes found for the given documents.")
|
||||
return "Sorry, I couldn't find any content to answer your question."
|
||||
|
||||
local_index = VectorStoreIndex.from_documents(nodes)
|
||||
local_index = VectorStoreIndex(nodes=nodes)
|
||||
retriever = local_index.as_retriever(
|
||||
similarity_top_k=3 if len(documents) == 1 else 5,
|
||||
)
|
||||
|
||||
if len(documents) == 1:
|
||||
# Just one doc — provide full content
|
||||
doc = documents[0]
|
||||
# TODO: include document metadata in the context
|
||||
context = f"TITLE: {doc.title or doc.filename}\n{doc.content or ''}"
|
||||
else:
|
||||
top_nodes = retriever.retrieve(query_str)
|
||||
context = "\n\n".join(
|
||||
f"TITLE: {node.metadata.get('title')}\n{node.text}" for node in top_nodes
|
||||
)
|
||||
|
||||
prompt = CHAT_PROMPT_TMPL.partial_format(
|
||||
context_str=context,
|
||||
query_str=query_str,
|
||||
).format(llm=client.llm)
|
||||
|
||||
query_engine = RetrieverQueryEngine.from_args(
|
||||
retriever=retriever,
|
||||
llm=client.llm,
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
logger.debug("Document chat prompt: %s", prompt)
|
||||
response = query_engine.query(prompt)
|
||||
logger.debug("Document chat response: %s", response)
|
||||
return str(response)
|
||||
|
||||
response_stream = query_engine.query(prompt)
|
||||
|
||||
for chunk in response_stream.response_gen:
|
||||
yield chunk.text
|
||||
|
@@ -1,3 +1,5 @@
|
||||
import json
|
||||
|
||||
import httpx
|
||||
from llama_index.core.base.llms.types import ChatMessage
|
||||
from llama_index.core.base.llms.types import ChatResponse
|
||||
@@ -6,6 +8,7 @@ from llama_index.core.base.llms.types import CompletionResponse
|
||||
from llama_index.core.base.llms.types import CompletionResponseGen
|
||||
from llama_index.core.base.llms.types import LLMMetadata
|
||||
from llama_index.core.llms.llm import LLM
|
||||
from llama_index.core.prompts import SelectorPromptTemplate
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
@@ -37,33 +40,42 @@ class OllamaLLM(LLM):
|
||||
data = response.json()
|
||||
return CompletionResponse(text=data["response"])
|
||||
|
||||
def chat(self, messages: list[ChatMessage], **kwargs) -> ChatResponse:
|
||||
with httpx.Client(timeout=120.0) as client:
|
||||
response = client.post(
|
||||
f"{self.base_url}/api/generate",
|
||||
json={
|
||||
"model": self.model,
|
||||
"messages": [
|
||||
{
|
||||
"role": message.role,
|
||||
"content": message.content,
|
||||
}
|
||||
for message in messages
|
||||
],
|
||||
"stream": False,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return ChatResponse(text=data["response"])
|
||||
def stream(self, prompt: str, **kwargs) -> CompletionResponseGen:
|
||||
return self.stream_complete(prompt, **kwargs)
|
||||
|
||||
# -- Required stubs for ABC:
|
||||
def stream_complete(
|
||||
self,
|
||||
prompt: str,
|
||||
prompt: SelectorPromptTemplate,
|
||||
**kwargs,
|
||||
) -> CompletionResponseGen: # pragma: no cover
|
||||
raise NotImplementedError("stream_complete not supported")
|
||||
) -> CompletionResponseGen:
|
||||
headers = {"Content-Type": "application/json"}
|
||||
data = {
|
||||
"model": self.model,
|
||||
"prompt": prompt.format(llm=self),
|
||||
"stream": True,
|
||||
}
|
||||
|
||||
with httpx.stream(
|
||||
"POST",
|
||||
f"{self.base_url}/api/generate",
|
||||
headers=headers,
|
||||
json=data,
|
||||
timeout=60.0,
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
for line in response.iter_lines():
|
||||
if not line.strip():
|
||||
continue
|
||||
chunk = json.loads(line)
|
||||
if "response" in chunk:
|
||||
yield CompletionResponse(text=chunk["response"])
|
||||
|
||||
def chat(
|
||||
self,
|
||||
messages: list[ChatMessage],
|
||||
**kwargs,
|
||||
) -> ChatResponse: # pragma: no cover
|
||||
raise NotImplementedError("chat not supported")
|
||||
|
||||
def stream_chat(
|
||||
self,
|
||||
|
@@ -21,6 +21,7 @@ from rest_framework.routers import DefaultRouter
|
||||
from documents.views import BulkDownloadView
|
||||
from documents.views import BulkEditObjectsView
|
||||
from documents.views import BulkEditView
|
||||
from documents.views import ChatStreamingView
|
||||
from documents.views import CorrespondentViewSet
|
||||
from documents.views import CustomFieldViewSet
|
||||
from documents.views import DocumentTypeViewSet
|
||||
@@ -139,6 +140,11 @@ urlpatterns = [
|
||||
SelectionDataView.as_view(),
|
||||
name="selection_data",
|
||||
),
|
||||
re_path(
|
||||
"^chat/",
|
||||
ChatStreamingView.as_view(),
|
||||
name="chat_streaming_view",
|
||||
),
|
||||
],
|
||||
),
|
||||
),
|
||||
|
Reference in New Issue
Block a user