Backend streaming chat

shamoon 2025-04-25 10:06:26 -07:00
parent dd4684170c
commit 0f517a5971
5 changed files with 101 additions and 39 deletions

src/documents/views.py

@@ -1,4 +1,5 @@
 import itertools
+import json
 import logging
 import os
 import platform
@@ -16,6 +17,7 @@ import httpx
 import pathvalidate
 from celery import states
 from django.conf import settings
+from django.contrib.auth.decorators import login_required
 from django.contrib.auth.models import Group
 from django.contrib.auth.models import User
 from django.db import connections
@@ -38,6 +40,7 @@ from django.http import HttpResponseBadRequest
 from django.http import HttpResponseForbidden
 from django.http import HttpResponseRedirect
 from django.http import HttpResponseServerError
+from django.http import StreamingHttpResponse
 from django.shortcuts import get_object_or_404
 from django.utils import timezone
 from django.utils.decorators import method_decorator
@@ -45,6 +48,7 @@ from django.utils.timezone import make_aware
 from django.utils.translation import get_language
 from django.views import View
 from django.views.decorators.cache import cache_control
+from django.views.decorators.csrf import ensure_csrf_cookie
 from django.views.decorators.http import condition
 from django.views.decorators.http import last_modified
 from django.views.generic import TemplateView
@@ -171,7 +175,7 @@ from documents.tasks import train_classifier
 from documents.templating.filepath import validate_filepath_template_and_render
 from paperless import version
 from paperless.ai.ai_classifier import get_ai_document_classification
-from paperless.ai.chat import chat_with_documents
+from paperless.ai.chat import stream_chat_with_documents
 from paperless.ai.matching import extract_unmatched_names
 from paperless.ai.matching import match_correspondents_by_name
 from paperless.ai.matching import match_document_types_by_name
@@ -1135,19 +1139,27 @@ class DocumentViewSet(
             "Error emailing document, check logs for more detail.",
         )

-    @action(methods=["post"], detail=False, url_path="chat")
-    def chat(self, request):
+
+@method_decorator(ensure_csrf_cookie, name="dispatch")
+@method_decorator(login_required, name="dispatch")
+class ChatStreamingView(View):
+    def post(self, request):
         ai_config = AIConfig()
         if not ai_config.ai_enabled:
             return HttpResponseBadRequest("AI is required for this feature")

-        question = request.data["q"]
-        doc_id = request.data.get("document_id", None)
+        try:
+            data = json.loads(request.body)
+            question = data["q"]
+            doc_id = data.get("document_id", None)
+        except (KeyError, json.JSONDecodeError):
+            return HttpResponseBadRequest("Invalid request")

         if doc_id:
             try:
                 document = Document.objects.get(id=doc_id)
             except Document.DoesNotExist:
-                return HttpResponseBadRequest("Invalid document ID")
+                return HttpResponseBadRequest("Document not found")
             if not has_perms_owner_aware(request.user, "view_document", document):
                 return HttpResponseForbidden("Insufficient permissions")
@@ -1160,9 +1172,12 @@ class DocumentViewSet(
                 Document,
             )

-        result = chat_with_documents(question, documents)
-
-        return Response({"answer": result})
+        response = StreamingHttpResponse(
+            stream_chat_with_documents(query_str=question, documents=documents),
+            content_type="text/plain",
+        )
+        response["Cache-Control"] = "no-cache"
+        return response


 @extend_schema_view(
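For reference, the new view reads a JSON body with a required q and an optional document_id, and answers with a plain-text stream. A minimal sketch of exercising it through Django's test client (the user and document id are hypothetical; the route name comes from the urls.py change below):

    from django.contrib.auth.models import User
    from django.test import Client
    from django.urls import reverse

    user = User.objects.create_user("chat-tester")  # hypothetical test user
    client = Client()
    client.force_login(user)  # satisfies the login_required decorator

    response = client.post(
        reverse("chat_streaming_view"),
        data={"q": "What is this document about?", "document_id": 123},
        content_type="application/json",
    )

    # StreamingHttpResponse wraps the generator returned by
    # stream_chat_with_documents; iterating consumes it chunk by chunk.
    for chunk in response.streaming_content:
        print(chunk.decode(), end="", flush=True)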

src/paperless/ai/ai_classifier.py

@@ -48,7 +48,7 @@ def build_prompt_without_rag(document: Document) -> str:
 {filename}

 CONTENT:
-{content[:8000]}  # Trim to safe size
+{content[:8000]}
 """

     return prompt
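The removed trailing text looked like a Python comment, but inside a triple-quoted f-string everything is literal, so "# Trim to safe size" was being sent to the model as part of the prompt. A quick illustration of the pitfall:

    content = "x" * 10
    prompt = f"""CONTENT:
    {content[:8000]}  # Trim to safe size
    """
    # The rendered prompt still contains "# Trim to safe size" verbatim.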

src/paperless/ai/chat.py

@@ -1,6 +1,7 @@
 import logging

 from llama_index.core import VectorStoreIndex
+from llama_index.core.prompts import PromptTemplate
 from llama_index.core.query_engine import RetrieverQueryEngine

 from documents.models import Document
@@ -9,10 +10,19 @@ from paperless.ai.indexing import load_index

 logger = logging.getLogger("paperless.ai.chat")

+CHAT_PROMPT_TMPL = PromptTemplate(
+    template="""Context information is below.
+---------------------
+{context_str}
+---------------------
+Given the context information and not prior knowledge, answer the query.
+Query: {query_str}
+Answer:""",
+)

-def chat_with_documents(prompt: str, documents: list[Document]) -> str:
+def stream_chat_with_documents(query_str: str, documents: list[Document]):
     client = AIClient()
     index = load_index()

     doc_ids = [doc.pk for doc in documents]
@@ -28,17 +38,36 @@ def chat_with_documents(prompt: str, documents: list[Document]) -> str:
         logger.warning("No nodes found for the given documents.")
         return "Sorry, I couldn't find any content to answer your question."

-    local_index = VectorStoreIndex.from_documents(nodes)
+    local_index = VectorStoreIndex(nodes=nodes)
     retriever = local_index.as_retriever(
         similarity_top_k=3 if len(documents) == 1 else 5,
     )

+    if len(documents) == 1:
+        # Just one doc — provide full content
+        doc = documents[0]
+        # TODO: include document metadata in the context
+        context = f"TITLE: {doc.title or doc.filename}\n{doc.content or ''}"
+    else:
+        top_nodes = retriever.retrieve(query_str)
+        context = "\n\n".join(
+            f"TITLE: {node.metadata.get('title')}\n{node.text}" for node in top_nodes
+        )
+
+    prompt = CHAT_PROMPT_TMPL.partial_format(
+        context_str=context,
+        query_str=query_str,
+    ).format(llm=client.llm)
+
     query_engine = RetrieverQueryEngine.from_args(
         retriever=retriever,
         llm=client.llm,
+        streaming=True,
     )

     logger.debug("Document chat prompt: %s", prompt)
-    response = query_engine.query(prompt)
-    logger.debug("Document chat response: %s", response)
-    return str(response)
+
+    response_stream = query_engine.query(prompt)
+
+    for chunk in response_stream.response_gen:
+        yield chunk.text
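Because the rewritten function contains yield, calling it only builds a generator; tokens are produced lazily as the HTTP response iterates it, which is what lets the StreamingHttpResponse above flush output incrementally. A minimal sketch of driving it directly (the document ids are hypothetical):

    from documents.models import Document
    from paperless.ai.chat import stream_chat_with_documents

    docs = list(Document.objects.filter(pk__in=[1, 2]))  # hypothetical ids
    for token in stream_chat_with_documents(
        query_str="Which of these invoices is overdue?",
        documents=docs,
    ):
        print(token, end="", flush=True)

One generator subtlety: the early return of the "Sorry, I couldn't find any content..." string no longer reaches the caller, because a return value inside a generator only ends iteration; delivering that message to the client would require a yield instead.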

src/paperless/ai/client.py

@@ -1,3 +1,5 @@
+import json
+
 import httpx
 from llama_index.core.base.llms.types import ChatMessage
 from llama_index.core.base.llms.types import ChatResponse
@@ -6,6 +8,7 @@ from llama_index.core.base.llms.types import CompletionResponse
 from llama_index.core.base.llms.types import CompletionResponseGen
 from llama_index.core.base.llms.types import LLMMetadata
 from llama_index.core.llms.llm import LLM
+from llama_index.core.prompts import SelectorPromptTemplate
 from pydantic import Field
@@ -37,33 +40,42 @@ class OllamaLLM(LLM):
             data = response.json()
             return CompletionResponse(text=data["response"])

-    def chat(self, messages: list[ChatMessage], **kwargs) -> ChatResponse:
-        with httpx.Client(timeout=120.0) as client:
-            response = client.post(
-                f"{self.base_url}/api/generate",
-                json={
-                    "model": self.model,
-                    "messages": [
-                        {
-                            "role": message.role,
-                            "content": message.content,
-                        }
-                        for message in messages
-                    ],
-                    "stream": False,
-                },
-            )
-            response.raise_for_status()
-            data = response.json()
-            return ChatResponse(text=data["response"])
+    def stream(self, prompt: str, **kwargs) -> CompletionResponseGen:
+        return self.stream_complete(prompt, **kwargs)

-    # -- Required stubs for ABC:
     def stream_complete(
         self,
-        prompt: str,
+        prompt: SelectorPromptTemplate,
         **kwargs,
-    ) -> CompletionResponseGen:  # pragma: no cover
-        raise NotImplementedError("stream_complete not supported")
+    ) -> CompletionResponseGen:
+        headers = {"Content-Type": "application/json"}
+        data = {
+            "model": self.model,
+            "prompt": prompt.format(llm=self),
+            "stream": True,
+        }
+        with httpx.stream(
+            "POST",
+            f"{self.base_url}/api/generate",
+            headers=headers,
+            json=data,
+            timeout=60.0,
+        ) as response:
+            response.raise_for_status()
+            for line in response.iter_lines():
+                if not line.strip():
+                    continue
+                chunk = json.loads(line)
+                if "response" in chunk:
+                    yield CompletionResponse(text=chunk["response"])
+
+    def chat(
+        self,
+        messages: list[ChatMessage],
+        **kwargs,
+    ) -> ChatResponse:  # pragma: no cover
+        raise NotImplementedError("chat not supported")

     def stream_chat(
         self,
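stream_complete consumes Ollama's streaming wire format: with "stream": true, /api/generate answers with one JSON object per line, each carrying a token in its "response" field until a final record with "done": true. A minimal sketch of driving the new generator (model name and URL are illustrative):

    # Each streamed NDJSON line looks roughly like:
    #   {"model": "llama3", "response": "Hel", "done": false}
    #   {"model": "llama3", "response": "lo", "done": false}
    #   {"model": "llama3", "response": "", "done": true}
    llm = OllamaLLM(model="llama3", base_url="http://localhost:11434")
    for part in llm.stream_complete("Why is the sky blue?"):
        print(part.text, end="", flush=True)

Passing a plain string works here despite the SelectorPromptTemplate annotation, since str.format(llm=self) returns a brace-free string unchanged.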

src/paperless/urls.py

@@ -21,6 +21,7 @@ from rest_framework.routers import DefaultRouter
 from documents.views import BulkDownloadView
 from documents.views import BulkEditObjectsView
 from documents.views import BulkEditView
+from documents.views import ChatStreamingView
 from documents.views import CorrespondentViewSet
 from documents.views import CustomFieldViewSet
 from documents.views import DocumentTypeViewSet
@@ -139,6 +140,11 @@ urlpatterns = [
                 SelectionDataView.as_view(),
                 name="selection_data",
             ),
+            re_path(
+                "^chat/",
+                ChatStreamingView.as_view(),
+                name="chat_streaming_view",
+            ),
         ],
     ),
 ),
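Giving the route a name lets callers resolve it instead of hard-coding the nested path (the concrete prefix below is an assumption based on where the pattern sits in the surrounding includes):

    from django.urls import reverse

    url = reverse("chat_streaming_view")  # e.g. "/api/documents/chat/" with this nesting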