diff --git a/src/documents/views.py b/src/documents/views.py index eb751cec0..fdd04d47c 100644 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -1144,6 +1144,7 @@ class DocumentViewSet( @method_decorator(login_required, name="dispatch") class ChatStreamingView(View): def post(self, request): + request.compress_exempt = True ai_config = AIConfig() if not ai_config.ai_enabled: return HttpResponseBadRequest("AI is required for this feature") @@ -1174,7 +1175,7 @@ class ChatStreamingView(View): response = StreamingHttpResponse( stream_chat_with_documents(query_str=question, documents=documents), - content_type="text/plain", + content_type="text/event-stream", ) response["Cache-Control"] = "no-cache" return response diff --git a/src/paperless/ai/chat.py b/src/paperless/ai/chat.py index 04bc9d2bd..ad14bda4d 100644 --- a/src/paperless/ai/chat.py +++ b/src/paperless/ai/chat.py @@ -1,4 +1,5 @@ import logging +import sys from llama_index.core import VectorStoreIndex from llama_index.core.prompts import PromptTemplate @@ -70,4 +71,6 @@ def stream_chat_with_documents(query_str: str, documents: list[Document]): response_stream = query_engine.query(prompt) - yield from response_stream.response_gen + for chunk in response_stream.response_gen: + yield chunk + sys.stdout.flush() diff --git a/src/paperless/settings.py b/src/paperless/settings.py index 4d720150b..f0d1edeb7 100644 --- a/src/paperless/settings.py +++ b/src/paperless/settings.py @@ -11,6 +11,7 @@ from typing import Final from urllib.parse import urlparse from celery.schedules import crontab +from compression_middleware.middleware import CompressionMiddleware from concurrent_log_handler.queue import setup_logging_queues from django.utils.translation import gettext_lazy as _ from dotenv import load_dotenv @@ -375,6 +376,19 @@ MIDDLEWARE = [ if __get_boolean("PAPERLESS_ENABLE_COMPRESSION", "yes"): # pragma: no cover MIDDLEWARE.insert(0, "compression_middleware.middleware.CompressionMiddleware") +# Workaround to not compress streaming responses (e.g. chat). +# See https://github.com/friedelwolff/django-compression-middleware/pull/7 +original_process_response = CompressionMiddleware.process_response + + +def patched_process_response(self, request, response): + if getattr(request, "compress_exempt", False): + return response + return original_process_response(self, request, response) + + +CompressionMiddleware.process_response = patched_process_response + ROOT_URLCONF = "paperless.urls"