diff --git a/src/documents/views.py b/src/documents/views.py index 2e848f44c..eb751cec0 100644 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -1,4 +1,5 @@ import itertools +import json import logging import os import platform @@ -16,6 +17,7 @@ import httpx import pathvalidate from celery import states from django.conf import settings +from django.contrib.auth.decorators import login_required from django.contrib.auth.models import Group from django.contrib.auth.models import User from django.db import connections @@ -38,6 +40,7 @@ from django.http import HttpResponseBadRequest from django.http import HttpResponseForbidden from django.http import HttpResponseRedirect from django.http import HttpResponseServerError +from django.http import StreamingHttpResponse from django.shortcuts import get_object_or_404 from django.utils import timezone from django.utils.decorators import method_decorator @@ -45,6 +48,7 @@ from django.utils.timezone import make_aware from django.utils.translation import get_language from django.views import View from django.views.decorators.cache import cache_control +from django.views.decorators.csrf import ensure_csrf_cookie from django.views.decorators.http import condition from django.views.decorators.http import last_modified from django.views.generic import TemplateView @@ -171,7 +175,7 @@ from documents.tasks import train_classifier from documents.templating.filepath import validate_filepath_template_and_render from paperless import version from paperless.ai.ai_classifier import get_ai_document_classification -from paperless.ai.chat import chat_with_documents +from paperless.ai.chat import stream_chat_with_documents from paperless.ai.matching import extract_unmatched_names from paperless.ai.matching import match_correspondents_by_name from paperless.ai.matching import match_document_types_by_name @@ -1135,19 +1139,27 @@ class DocumentViewSet( "Error emailing document, check logs for more detail.", ) - @action(methods=["post"], detail=False, url_path="chat") - def chat(self, request): + +@method_decorator(ensure_csrf_cookie, name="dispatch") +@method_decorator(login_required, name="dispatch") +class ChatStreamingView(View): + def post(self, request): ai_config = AIConfig() if not ai_config.ai_enabled: return HttpResponseBadRequest("AI is required for this feature") - question = request.data["q"] - doc_id = request.data.get("document_id", None) + try: + data = json.loads(request.body) + question = data["q"] + doc_id = data.get("document_id", None) + except (KeyError, json.JSONDecodeError): + return HttpResponseBadRequest("Invalid request") + if doc_id: try: document = Document.objects.get(id=doc_id) except Document.DoesNotExist: - return HttpResponseBadRequest("Invalid document ID") + return HttpResponseBadRequest("Document not found") if not has_perms_owner_aware(request.user, "view_document", document): return HttpResponseForbidden("Insufficient permissions") @@ -1160,9 +1172,12 @@ class DocumentViewSet( Document, ) - result = chat_with_documents(question, documents) - - return Response({"answer": result}) + response = StreamingHttpResponse( + stream_chat_with_documents(query_str=question, documents=documents), + content_type="text/plain", + ) + response["Cache-Control"] = "no-cache" + return response @extend_schema_view( diff --git a/src/paperless/ai/ai_classifier.py b/src/paperless/ai/ai_classifier.py index 69274da56..f5822fa63 100644 --- a/src/paperless/ai/ai_classifier.py +++ b/src/paperless/ai/ai_classifier.py @@ -48,7 +48,7 @@ def build_prompt_without_rag(document: Document) -> str: {filename} CONTENT: - {content[:8000]} # Trim to safe size + {content[:8000]} """ return prompt diff --git a/src/paperless/ai/chat.py b/src/paperless/ai/chat.py index 3ce109a79..6ad67b56f 100644 --- a/src/paperless/ai/chat.py +++ b/src/paperless/ai/chat.py @@ -1,6 +1,7 @@ import logging from llama_index.core import VectorStoreIndex +from llama_index.core.prompts import PromptTemplate from llama_index.core.query_engine import RetrieverQueryEngine from documents.models import Document @@ -9,10 +10,19 @@ from paperless.ai.indexing import load_index logger = logging.getLogger("paperless.ai.chat") +CHAT_PROMPT_TMPL = PromptTemplate( + template="""Context information is below. + --------------------- + {context_str} + --------------------- + Given the context information and not prior knowledge, answer the query. + Query: {query_str} + Answer:""", +) -def chat_with_documents(prompt: str, documents: list[Document]) -> str: + +def stream_chat_with_documents(query_str: str, documents: list[Document]): client = AIClient() - index = load_index() doc_ids = [doc.pk for doc in documents] @@ -28,17 +38,36 @@ def chat_with_documents(prompt: str, documents: list[Document]) -> str: logger.warning("No nodes found for the given documents.") return "Sorry, I couldn't find any content to answer your question." - local_index = VectorStoreIndex.from_documents(nodes) + local_index = VectorStoreIndex(nodes=nodes) retriever = local_index.as_retriever( similarity_top_k=3 if len(documents) == 1 else 5, ) + if len(documents) == 1: + # Just one doc — provide full content + doc = documents[0] + # TODO: include document metadata in the context + context = f"TITLE: {doc.title or doc.filename}\n{doc.content or ''}" + else: + top_nodes = retriever.retrieve(query_str) + context = "\n\n".join( + f"TITLE: {node.metadata.get('title')}\n{node.text}" for node in top_nodes + ) + + prompt = CHAT_PROMPT_TMPL.partial_format( + context_str=context, + query_str=query_str, + ).format(llm=client.llm) + query_engine = RetrieverQueryEngine.from_args( retriever=retriever, llm=client.llm, + streaming=True, ) logger.debug("Document chat prompt: %s", prompt) - response = query_engine.query(prompt) - logger.debug("Document chat response: %s", response) - return str(response) + + response_stream = query_engine.query(prompt) + + for chunk in response_stream.response_gen: + yield chunk.text diff --git a/src/paperless/ai/llms.py b/src/paperless/ai/llms.py index c4b56f36d..4f654b126 100644 --- a/src/paperless/ai/llms.py +++ b/src/paperless/ai/llms.py @@ -1,3 +1,5 @@ +import json + import httpx from llama_index.core.base.llms.types import ChatMessage from llama_index.core.base.llms.types import ChatResponse @@ -6,6 +8,7 @@ from llama_index.core.base.llms.types import CompletionResponse from llama_index.core.base.llms.types import CompletionResponseGen from llama_index.core.base.llms.types import LLMMetadata from llama_index.core.llms.llm import LLM +from llama_index.core.prompts import SelectorPromptTemplate from pydantic import Field @@ -37,33 +40,42 @@ class OllamaLLM(LLM): data = response.json() return CompletionResponse(text=data["response"]) - def chat(self, messages: list[ChatMessage], **kwargs) -> ChatResponse: - with httpx.Client(timeout=120.0) as client: - response = client.post( - f"{self.base_url}/api/generate", - json={ - "model": self.model, - "messages": [ - { - "role": message.role, - "content": message.content, - } - for message in messages - ], - "stream": False, - }, - ) - response.raise_for_status() - data = response.json() - return ChatResponse(text=data["response"]) + def stream(self, prompt: str, **kwargs) -> CompletionResponseGen: + return self.stream_complete(prompt, **kwargs) - # -- Required stubs for ABC: def stream_complete( self, - prompt: str, + prompt: SelectorPromptTemplate, **kwargs, - ) -> CompletionResponseGen: # pragma: no cover - raise NotImplementedError("stream_complete not supported") + ) -> CompletionResponseGen: + headers = {"Content-Type": "application/json"} + data = { + "model": self.model, + "prompt": prompt.format(llm=self), + "stream": True, + } + + with httpx.stream( + "POST", + f"{self.base_url}/api/generate", + headers=headers, + json=data, + timeout=60.0, + ) as response: + response.raise_for_status() + for line in response.iter_lines(): + if not line.strip(): + continue + chunk = json.loads(line) + if "response" in chunk: + yield CompletionResponse(text=chunk["response"]) + + def chat( + self, + messages: list[ChatMessage], + **kwargs, + ) -> ChatResponse: # pragma: no cover + raise NotImplementedError("chat not supported") def stream_chat( self, diff --git a/src/paperless/urls.py b/src/paperless/urls.py index 39819fd3d..ccdf79734 100644 --- a/src/paperless/urls.py +++ b/src/paperless/urls.py @@ -21,6 +21,7 @@ from rest_framework.routers import DefaultRouter from documents.views import BulkDownloadView from documents.views import BulkEditObjectsView from documents.views import BulkEditView +from documents.views import ChatStreamingView from documents.views import CorrespondentViewSet from documents.views import CustomFieldViewSet from documents.views import DocumentTypeViewSet @@ -139,6 +140,11 @@ urlpatterns = [ SelectionDataView.as_view(), name="selection_data", ), + re_path( + "^chat/", + ChatStreamingView.as_view(), + name="chat_streaming_view", + ), ], ), ),