mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-05-23 12:58:18 -05:00
Backend streaming chat
This commit is contained in:
parent
dd4684170c
commit
0f517a5971
@ -1,4 +1,5 @@
|
|||||||
import itertools
|
import itertools
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
@ -16,6 +17,7 @@ import httpx
|
|||||||
import pathvalidate
|
import pathvalidate
|
||||||
from celery import states
|
from celery import states
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
|
from django.contrib.auth.decorators import login_required
|
||||||
from django.contrib.auth.models import Group
|
from django.contrib.auth.models import Group
|
||||||
from django.contrib.auth.models import User
|
from django.contrib.auth.models import User
|
||||||
from django.db import connections
|
from django.db import connections
|
||||||
@ -38,6 +40,7 @@ from django.http import HttpResponseBadRequest
|
|||||||
from django.http import HttpResponseForbidden
|
from django.http import HttpResponseForbidden
|
||||||
from django.http import HttpResponseRedirect
|
from django.http import HttpResponseRedirect
|
||||||
from django.http import HttpResponseServerError
|
from django.http import HttpResponseServerError
|
||||||
|
from django.http import StreamingHttpResponse
|
||||||
from django.shortcuts import get_object_or_404
|
from django.shortcuts import get_object_or_404
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
from django.utils.decorators import method_decorator
|
from django.utils.decorators import method_decorator
|
||||||
@ -45,6 +48,7 @@ from django.utils.timezone import make_aware
|
|||||||
from django.utils.translation import get_language
|
from django.utils.translation import get_language
|
||||||
from django.views import View
|
from django.views import View
|
||||||
from django.views.decorators.cache import cache_control
|
from django.views.decorators.cache import cache_control
|
||||||
|
from django.views.decorators.csrf import ensure_csrf_cookie
|
||||||
from django.views.decorators.http import condition
|
from django.views.decorators.http import condition
|
||||||
from django.views.decorators.http import last_modified
|
from django.views.decorators.http import last_modified
|
||||||
from django.views.generic import TemplateView
|
from django.views.generic import TemplateView
|
||||||
@ -171,7 +175,7 @@ from documents.tasks import train_classifier
|
|||||||
from documents.templating.filepath import validate_filepath_template_and_render
|
from documents.templating.filepath import validate_filepath_template_and_render
|
||||||
from paperless import version
|
from paperless import version
|
||||||
from paperless.ai.ai_classifier import get_ai_document_classification
|
from paperless.ai.ai_classifier import get_ai_document_classification
|
||||||
from paperless.ai.chat import chat_with_documents
|
from paperless.ai.chat import stream_chat_with_documents
|
||||||
from paperless.ai.matching import extract_unmatched_names
|
from paperless.ai.matching import extract_unmatched_names
|
||||||
from paperless.ai.matching import match_correspondents_by_name
|
from paperless.ai.matching import match_correspondents_by_name
|
||||||
from paperless.ai.matching import match_document_types_by_name
|
from paperless.ai.matching import match_document_types_by_name
|
||||||
@ -1135,19 +1139,27 @@ class DocumentViewSet(
|
|||||||
"Error emailing document, check logs for more detail.",
|
"Error emailing document, check logs for more detail.",
|
||||||
)
|
)
|
||||||
|
|
||||||
@action(methods=["post"], detail=False, url_path="chat")
|
|
||||||
def chat(self, request):
|
@method_decorator(ensure_csrf_cookie, name="dispatch")
|
||||||
|
@method_decorator(login_required, name="dispatch")
|
||||||
|
class ChatStreamingView(View):
|
||||||
|
def post(self, request):
|
||||||
ai_config = AIConfig()
|
ai_config = AIConfig()
|
||||||
if not ai_config.ai_enabled:
|
if not ai_config.ai_enabled:
|
||||||
return HttpResponseBadRequest("AI is required for this feature")
|
return HttpResponseBadRequest("AI is required for this feature")
|
||||||
|
|
||||||
question = request.data["q"]
|
try:
|
||||||
doc_id = request.data.get("document_id", None)
|
data = json.loads(request.body)
|
||||||
|
question = data["q"]
|
||||||
|
doc_id = data.get("document_id", None)
|
||||||
|
except (KeyError, json.JSONDecodeError):
|
||||||
|
return HttpResponseBadRequest("Invalid request")
|
||||||
|
|
||||||
if doc_id:
|
if doc_id:
|
||||||
try:
|
try:
|
||||||
document = Document.objects.get(id=doc_id)
|
document = Document.objects.get(id=doc_id)
|
||||||
except Document.DoesNotExist:
|
except Document.DoesNotExist:
|
||||||
return HttpResponseBadRequest("Invalid document ID")
|
return HttpResponseBadRequest("Document not found")
|
||||||
|
|
||||||
if not has_perms_owner_aware(request.user, "view_document", document):
|
if not has_perms_owner_aware(request.user, "view_document", document):
|
||||||
return HttpResponseForbidden("Insufficient permissions")
|
return HttpResponseForbidden("Insufficient permissions")
|
||||||
@ -1160,9 +1172,12 @@ class DocumentViewSet(
|
|||||||
Document,
|
Document,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = chat_with_documents(question, documents)
|
response = StreamingHttpResponse(
|
||||||
|
stream_chat_with_documents(query_str=question, documents=documents),
|
||||||
return Response({"answer": result})
|
content_type="text/plain",
|
||||||
|
)
|
||||||
|
response["Cache-Control"] = "no-cache"
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
@extend_schema_view(
|
@extend_schema_view(
|
||||||
|
@ -48,7 +48,7 @@ def build_prompt_without_rag(document: Document) -> str:
|
|||||||
{filename}
|
{filename}
|
||||||
|
|
||||||
CONTENT:
|
CONTENT:
|
||||||
{content[:8000]} # Trim to safe size
|
{content[:8000]}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return prompt
|
return prompt
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from llama_index.core import VectorStoreIndex
|
from llama_index.core import VectorStoreIndex
|
||||||
|
from llama_index.core.prompts import PromptTemplate
|
||||||
from llama_index.core.query_engine import RetrieverQueryEngine
|
from llama_index.core.query_engine import RetrieverQueryEngine
|
||||||
|
|
||||||
from documents.models import Document
|
from documents.models import Document
|
||||||
@ -9,10 +10,19 @@ from paperless.ai.indexing import load_index
|
|||||||
|
|
||||||
logger = logging.getLogger("paperless.ai.chat")
|
logger = logging.getLogger("paperless.ai.chat")
|
||||||
|
|
||||||
|
CHAT_PROMPT_TMPL = PromptTemplate(
|
||||||
|
template="""Context information is below.
|
||||||
|
---------------------
|
||||||
|
{context_str}
|
||||||
|
---------------------
|
||||||
|
Given the context information and not prior knowledge, answer the query.
|
||||||
|
Query: {query_str}
|
||||||
|
Answer:""",
|
||||||
|
)
|
||||||
|
|
||||||
def chat_with_documents(prompt: str, documents: list[Document]) -> str:
|
|
||||||
|
def stream_chat_with_documents(query_str: str, documents: list[Document]):
|
||||||
client = AIClient()
|
client = AIClient()
|
||||||
|
|
||||||
index = load_index()
|
index = load_index()
|
||||||
|
|
||||||
doc_ids = [doc.pk for doc in documents]
|
doc_ids = [doc.pk for doc in documents]
|
||||||
@ -28,17 +38,36 @@ def chat_with_documents(prompt: str, documents: list[Document]) -> str:
|
|||||||
logger.warning("No nodes found for the given documents.")
|
logger.warning("No nodes found for the given documents.")
|
||||||
return "Sorry, I couldn't find any content to answer your question."
|
return "Sorry, I couldn't find any content to answer your question."
|
||||||
|
|
||||||
local_index = VectorStoreIndex.from_documents(nodes)
|
local_index = VectorStoreIndex(nodes=nodes)
|
||||||
retriever = local_index.as_retriever(
|
retriever = local_index.as_retriever(
|
||||||
similarity_top_k=3 if len(documents) == 1 else 5,
|
similarity_top_k=3 if len(documents) == 1 else 5,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if len(documents) == 1:
|
||||||
|
# Just one doc — provide full content
|
||||||
|
doc = documents[0]
|
||||||
|
# TODO: include document metadata in the context
|
||||||
|
context = f"TITLE: {doc.title or doc.filename}\n{doc.content or ''}"
|
||||||
|
else:
|
||||||
|
top_nodes = retriever.retrieve(query_str)
|
||||||
|
context = "\n\n".join(
|
||||||
|
f"TITLE: {node.metadata.get('title')}\n{node.text}" for node in top_nodes
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt = CHAT_PROMPT_TMPL.partial_format(
|
||||||
|
context_str=context,
|
||||||
|
query_str=query_str,
|
||||||
|
).format(llm=client.llm)
|
||||||
|
|
||||||
query_engine = RetrieverQueryEngine.from_args(
|
query_engine = RetrieverQueryEngine.from_args(
|
||||||
retriever=retriever,
|
retriever=retriever,
|
||||||
llm=client.llm,
|
llm=client.llm,
|
||||||
|
streaming=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug("Document chat prompt: %s", prompt)
|
logger.debug("Document chat prompt: %s", prompt)
|
||||||
response = query_engine.query(prompt)
|
|
||||||
logger.debug("Document chat response: %s", response)
|
response_stream = query_engine.query(prompt)
|
||||||
return str(response)
|
|
||||||
|
for chunk in response_stream.response_gen:
|
||||||
|
yield chunk.text
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from llama_index.core.base.llms.types import ChatMessage
|
from llama_index.core.base.llms.types import ChatMessage
|
||||||
from llama_index.core.base.llms.types import ChatResponse
|
from llama_index.core.base.llms.types import ChatResponse
|
||||||
@ -6,6 +8,7 @@ from llama_index.core.base.llms.types import CompletionResponse
|
|||||||
from llama_index.core.base.llms.types import CompletionResponseGen
|
from llama_index.core.base.llms.types import CompletionResponseGen
|
||||||
from llama_index.core.base.llms.types import LLMMetadata
|
from llama_index.core.base.llms.types import LLMMetadata
|
||||||
from llama_index.core.llms.llm import LLM
|
from llama_index.core.llms.llm import LLM
|
||||||
|
from llama_index.core.prompts import SelectorPromptTemplate
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
|
|
||||||
@ -37,33 +40,42 @@ class OllamaLLM(LLM):
|
|||||||
data = response.json()
|
data = response.json()
|
||||||
return CompletionResponse(text=data["response"])
|
return CompletionResponse(text=data["response"])
|
||||||
|
|
||||||
def chat(self, messages: list[ChatMessage], **kwargs) -> ChatResponse:
|
def stream(self, prompt: str, **kwargs) -> CompletionResponseGen:
|
||||||
with httpx.Client(timeout=120.0) as client:
|
return self.stream_complete(prompt, **kwargs)
|
||||||
response = client.post(
|
|
||||||
f"{self.base_url}/api/generate",
|
|
||||||
json={
|
|
||||||
"model": self.model,
|
|
||||||
"messages": [
|
|
||||||
{
|
|
||||||
"role": message.role,
|
|
||||||
"content": message.content,
|
|
||||||
}
|
|
||||||
for message in messages
|
|
||||||
],
|
|
||||||
"stream": False,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
data = response.json()
|
|
||||||
return ChatResponse(text=data["response"])
|
|
||||||
|
|
||||||
# -- Required stubs for ABC:
|
|
||||||
def stream_complete(
|
def stream_complete(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: SelectorPromptTemplate,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> CompletionResponseGen: # pragma: no cover
|
) -> CompletionResponseGen:
|
||||||
raise NotImplementedError("stream_complete not supported")
|
headers = {"Content-Type": "application/json"}
|
||||||
|
data = {
|
||||||
|
"model": self.model,
|
||||||
|
"prompt": prompt.format(llm=self),
|
||||||
|
"stream": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
with httpx.stream(
|
||||||
|
"POST",
|
||||||
|
f"{self.base_url}/api/generate",
|
||||||
|
headers=headers,
|
||||||
|
json=data,
|
||||||
|
timeout=60.0,
|
||||||
|
) as response:
|
||||||
|
response.raise_for_status()
|
||||||
|
for line in response.iter_lines():
|
||||||
|
if not line.strip():
|
||||||
|
continue
|
||||||
|
chunk = json.loads(line)
|
||||||
|
if "response" in chunk:
|
||||||
|
yield CompletionResponse(text=chunk["response"])
|
||||||
|
|
||||||
|
def chat(
|
||||||
|
self,
|
||||||
|
messages: list[ChatMessage],
|
||||||
|
**kwargs,
|
||||||
|
) -> ChatResponse: # pragma: no cover
|
||||||
|
raise NotImplementedError("chat not supported")
|
||||||
|
|
||||||
def stream_chat(
|
def stream_chat(
|
||||||
self,
|
self,
|
||||||
|
@ -21,6 +21,7 @@ from rest_framework.routers import DefaultRouter
|
|||||||
from documents.views import BulkDownloadView
|
from documents.views import BulkDownloadView
|
||||||
from documents.views import BulkEditObjectsView
|
from documents.views import BulkEditObjectsView
|
||||||
from documents.views import BulkEditView
|
from documents.views import BulkEditView
|
||||||
|
from documents.views import ChatStreamingView
|
||||||
from documents.views import CorrespondentViewSet
|
from documents.views import CorrespondentViewSet
|
||||||
from documents.views import CustomFieldViewSet
|
from documents.views import CustomFieldViewSet
|
||||||
from documents.views import DocumentTypeViewSet
|
from documents.views import DocumentTypeViewSet
|
||||||
@ -139,6 +140,11 @@ urlpatterns = [
|
|||||||
SelectionDataView.as_view(),
|
SelectionDataView.as_view(),
|
||||||
name="selection_data",
|
name="selection_data",
|
||||||
),
|
),
|
||||||
|
re_path(
|
||||||
|
"^chat/",
|
||||||
|
ChatStreamingView.as_view(),
|
||||||
|
name="chat_streaming_view",
|
||||||
|
),
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
|
Loading…
x
Reference in New Issue
Block a user