Backend streaming chat

shamoon 2025-04-25 10:06:26 -07:00
parent dd4684170c
commit 0f517a5971
5 changed files with 101 additions and 39 deletions

src/documents/views.py

@@ -1,4 +1,5 @@
 import itertools
+import json
 import logging
 import os
 import platform
@@ -16,6 +17,7 @@ import httpx
 import pathvalidate
 from celery import states
 from django.conf import settings
+from django.contrib.auth.decorators import login_required
 from django.contrib.auth.models import Group
 from django.contrib.auth.models import User
 from django.db import connections
@@ -38,6 +40,7 @@ from django.http import HttpResponseBadRequest
 from django.http import HttpResponseForbidden
 from django.http import HttpResponseRedirect
 from django.http import HttpResponseServerError
+from django.http import StreamingHttpResponse
 from django.shortcuts import get_object_or_404
 from django.utils import timezone
 from django.utils.decorators import method_decorator
@@ -45,6 +48,7 @@ from django.utils.timezone import make_aware
 from django.utils.translation import get_language
 from django.views import View
 from django.views.decorators.cache import cache_control
+from django.views.decorators.csrf import ensure_csrf_cookie
 from django.views.decorators.http import condition
 from django.views.decorators.http import last_modified
 from django.views.generic import TemplateView
@@ -171,7 +175,7 @@ from documents.tasks import train_classifier
 from documents.templating.filepath import validate_filepath_template_and_render
 from paperless import version
 from paperless.ai.ai_classifier import get_ai_document_classification
-from paperless.ai.chat import chat_with_documents
+from paperless.ai.chat import stream_chat_with_documents
 from paperless.ai.matching import extract_unmatched_names
 from paperless.ai.matching import match_correspondents_by_name
 from paperless.ai.matching import match_document_types_by_name
@@ -1135,19 +1139,27 @@ class DocumentViewSet(
             "Error emailing document, check logs for more detail.",
         )

-    @action(methods=["post"], detail=False, url_path="chat")
-    def chat(self, request):
+
+@method_decorator(ensure_csrf_cookie, name="dispatch")
+@method_decorator(login_required, name="dispatch")
+class ChatStreamingView(View):
+    def post(self, request):
         ai_config = AIConfig()
         if not ai_config.ai_enabled:
             return HttpResponseBadRequest("AI is required for this feature")
-        question = request.data["q"]
-        doc_id = request.data.get("document_id", None)
+        try:
+            data = json.loads(request.body)
+            question = data["q"]
+            doc_id = data.get("document_id", None)
+        except (KeyError, json.JSONDecodeError):
+            return HttpResponseBadRequest("Invalid request")
         if doc_id:
             try:
                 document = Document.objects.get(id=doc_id)
             except Document.DoesNotExist:
-                return HttpResponseBadRequest("Invalid document ID")
+                return HttpResponseBadRequest("Document not found")
             if not has_perms_owner_aware(request.user, "view_document", document):
                 return HttpResponseForbidden("Insufficient permissions")
@@ -1160,9 +1172,12 @@ class DocumentViewSet(
                 Document,
             )

-        result = chat_with_documents(question, documents)
-        return Response({"answer": result})
+        response = StreamingHttpResponse(
+            stream_chat_with_documents(query_str=question, documents=documents),
+            content_type="text/plain",
+        )
+        response["Cache-Control"] = "no-cache"
+        return response


 @extend_schema_view(
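
The new `ChatStreamingView` replaces the old `chat` action on `DocumentViewSet`: it parses a JSON body with a required `q` and an optional `document_id`, checks per-document permissions, and streams the answer back as plain text. A minimal consumer sketch, assuming the `^chat/` route from src/paperless/urls.py below lands under the API root as `/api/chat/` (the exact prefix depends on the surrounding `include()`), with placeholder host, session, and CSRF values:

```python
import httpx

BASE_URL = "http://localhost:8000"  # hypothetical deployment URL

with httpx.Client(base_url=BASE_URL) as client:
    # The view is login_required, so the request must also carry a valid
    # session cookie; the CSRF header value here is a placeholder.
    with client.stream(
        "POST",
        "/api/chat/",  # assumed mount point; depends on the include() nesting
        json={"q": "What is this document about?", "document_id": 123},
        headers={"X-CSRFToken": "placeholder-token"},
    ) as response:
        response.raise_for_status()
        for chunk in response.iter_text():
            print(chunk, end="", flush=True)  # tokens print as they arrive
```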

src/paperless/ai/ai_classifier.py

@@ -48,7 +48,7 @@ def build_prompt_without_rag(document: Document) -> str:
     {filename}

     CONTENT:
-    {content[:8000]}  # Trim to safe size
+    {content[:8000]}
     """
     return prompt
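
The dropped `# Trim to safe size` looked like a code comment, but it sat inside a triple-quoted f-string, so it was shipped to the model as literal prompt text. A quick illustration of the pitfall:

```python
content = "x" * 20
prompt = f"""CONTENT:
{content[:8]}  # Trim to safe size
"""
print(prompt)
# CONTENT:
# xxxxxxxx  # Trim to safe size   <- the "comment" is part of the string
```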

src/paperless/ai/chat.py

@@ -1,6 +1,7 @@
 import logging

 from llama_index.core import VectorStoreIndex
+from llama_index.core.prompts import PromptTemplate
 from llama_index.core.query_engine import RetrieverQueryEngine

 from documents.models import Document
@@ -9,10 +10,19 @@ from paperless.ai.indexing import load_index
 logger = logging.getLogger("paperless.ai.chat")

+CHAT_PROMPT_TMPL = PromptTemplate(
+    template="""Context information is below.
+---------------------
+{context_str}
+---------------------
+Given the context information and not prior knowledge, answer the query.
+Query: {query_str}
+Answer:""",
+)
+

-def chat_with_documents(prompt: str, documents: list[Document]) -> str:
+def stream_chat_with_documents(query_str: str, documents: list[Document]):
     client = AIClient()
     index = load_index()

     doc_ids = [doc.pk for doc in documents]
@@ -28,17 +38,36 @@ def chat_with_documents(prompt: str, documents: list[Document]) -> str:
         logger.warning("No nodes found for the given documents.")
         return "Sorry, I couldn't find any content to answer your question."

-    local_index = VectorStoreIndex.from_documents(nodes)
+    local_index = VectorStoreIndex(nodes=nodes)
     retriever = local_index.as_retriever(
         similarity_top_k=3 if len(documents) == 1 else 5,
     )

+    if len(documents) == 1:
+        # Just one doc: provide the full content
+        doc = documents[0]
+        # TODO: include document metadata in the context
+        context = f"TITLE: {doc.title or doc.filename}\n{doc.content or ''}"
+    else:
+        top_nodes = retriever.retrieve(query_str)
+        context = "\n\n".join(
+            f"TITLE: {node.metadata.get('title')}\n{node.text}" for node in top_nodes
+        )
+
+    prompt = CHAT_PROMPT_TMPL.partial_format(
+        context_str=context,
+        query_str=query_str,
+    ).format(llm=client.llm)
+
     query_engine = RetrieverQueryEngine.from_args(
         retriever=retriever,
         llm=client.llm,
+        streaming=True,
     )

     logger.debug("Document chat prompt: %s", prompt)

-    response = query_engine.query(prompt)
-    logger.debug("Document chat response: %s", response)
-    return str(response)
+    response_stream = query_engine.query(prompt)
+
+    for chunk in response_stream.response_gen:
+        yield chunk.text
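
Because `stream_chat_with_documents` is now a generator, the view can hand it straight to `StreamingHttpResponse` and each token is flushed as the LLM produces it. Note that the early `return` of the apology string is a leftover from the non-streaming version: a value returned from inside a generator is attached to `StopIteration` rather than yielded, so that message never reaches the client. A hedged sketch of driving the generator directly (the document IDs are hypothetical and assume the vector index is already built):

```python
from documents.models import Document
from paperless.ai.chat import stream_chat_with_documents

docs = list(Document.objects.filter(pk__in=[1, 2]))  # hypothetical IDs
parts = []
for token in stream_chat_with_documents(query_str="Summarize these", documents=docs):
    parts.append(token)  # each item is one streamed text chunk
print("".join(parts))
```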

src/paperless/ai/client.py

@@ -1,3 +1,5 @@
+import json
+
 import httpx
 from llama_index.core.base.llms.types import ChatMessage
 from llama_index.core.base.llms.types import ChatResponse
@@ -6,6 +8,7 @@ from llama_index.core.base.llms.types import CompletionResponse
 from llama_index.core.base.llms.types import CompletionResponseGen
 from llama_index.core.base.llms.types import LLMMetadata
 from llama_index.core.llms.llm import LLM
+from llama_index.core.prompts import SelectorPromptTemplate
 from pydantic import Field
@@ -37,33 +40,42 @@ class OllamaLLM(LLM):
             data = response.json()
             return CompletionResponse(text=data["response"])

-    def chat(self, messages: list[ChatMessage], **kwargs) -> ChatResponse:
-        with httpx.Client(timeout=120.0) as client:
-            response = client.post(
-                f"{self.base_url}/api/generate",
-                json={
-                    "model": self.model,
-                    "messages": [
-                        {
-                            "role": message.role,
-                            "content": message.content,
-                        }
-                        for message in messages
-                    ],
-                    "stream": False,
-                },
-            )
-            response.raise_for_status()
-            data = response.json()
-            return ChatResponse(text=data["response"])
-
-    def stream(self, prompt: str, **kwargs) -> CompletionResponseGen:
-        return self.stream_complete(prompt, **kwargs)
-
-    # -- Required stubs for ABC:
     def stream_complete(
         self,
-        prompt: str,
+        prompt: SelectorPromptTemplate,
         **kwargs,
-    ) -> CompletionResponseGen:  # pragma: no cover
-        raise NotImplementedError("stream_complete not supported")
+    ) -> CompletionResponseGen:
+        headers = {"Content-Type": "application/json"}
+        data = {
+            "model": self.model,
+            "prompt": prompt.format(llm=self),
+            "stream": True,
+        }
+
+        with httpx.stream(
+            "POST",
+            f"{self.base_url}/api/generate",
+            headers=headers,
+            json=data,
+            timeout=60.0,
+        ) as response:
+            response.raise_for_status()
+            for line in response.iter_lines():
+                if not line.strip():
+                    continue
+                chunk = json.loads(line)
+                if "response" in chunk:
+                    yield CompletionResponse(text=chunk["response"])
+
+    def chat(
+        self,
+        messages: list[ChatMessage],
+        **kwargs,
+    ) -> ChatResponse:  # pragma: no cover
+        raise NotImplementedError("chat not supported")

     def stream_chat(
         self,
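
With `"stream": True`, Ollama's `/api/generate` responds with newline-delimited JSON: each line carries a `response` fragment, and a final record sets `"done": true`. `stream_complete` re-yields those fragments as `CompletionResponse` objects. A standalone sketch of parsing that NDJSON shape (the sample lines are illustrative, not captured output):

```python
import json

# Illustrative lines in the shape Ollama streams back; not real output.
raw_lines = [
    '{"model": "llama3", "response": "Hel", "done": false}',
    '{"model": "llama3", "response": "lo", "done": false}',
    '{"model": "llama3", "response": "", "done": true}',
]

text = ""
for line in raw_lines:
    if not line.strip():
        continue
    chunk = json.loads(line)
    if "response" in chunk:  # each record carries the next text fragment
        text += chunk["response"]
print(text)  # -> Hello
```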

src/paperless/urls.py

@@ -21,6 +21,7 @@ from rest_framework.routers import DefaultRouter
 from documents.views import BulkDownloadView
 from documents.views import BulkEditObjectsView
 from documents.views import BulkEditView
+from documents.views import ChatStreamingView
 from documents.views import CorrespondentViewSet
 from documents.views import CustomFieldViewSet
 from documents.views import DocumentTypeViewSet
@@ -139,6 +140,11 @@ urlpatterns = [
                     SelectionDataView.as_view(),
                     name="selection_data",
                 ),
+                re_path(
+                    "^chat/",
+                    ChatStreamingView.as_view(),
+                    name="chat_streaming_view",
+                ),
             ],
         ),
     ),
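
With the route registered under the name `chat_streaming_view`, the endpoint can be exercised through Django's test client. A small sketch, assuming a test database and a hypothetical account:

```python
from django.contrib.auth.models import User
from django.test import Client
from django.urls import reverse

user = User.objects.create_user("demo", password="demo")  # hypothetical account
client = Client()
client.force_login(user)

response = client.post(
    reverse("chat_streaming_view"),
    data={"q": "What is the total?", "document_id": 42},
    content_type="application/json",
)
assert response.status_code == 200
answer = b"".join(response.streaming_content).decode()  # drain the stream
```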