From a794008649a59480fe217285e2f504f7295b80f6 Mon Sep 17 00:00:00 2001
From: shamoon <4887959+shamoon@users.noreply.github.com>
Date: Thu, 24 Apr 2025 23:41:31 -0700
Subject: [PATCH] Super basic doc chat [ci skip]

---
 src/documents/views.py       | 12 ++++++++++++
 src/paperless/ai/chat.py     | 24 ++++++++++++++++++++++++
 src/paperless/ai/client.py   |  8 ++++----
 src/paperless/ai/indexing.py | 11 +++++++----
 4 files changed, 47 insertions(+), 8 deletions(-)
 create mode 100644 src/paperless/ai/chat.py

diff --git a/src/documents/views.py b/src/documents/views.py
index 73d1f7b35..c2bf79f43 100644
--- a/src/documents/views.py
+++ b/src/documents/views.py
@@ -171,6 +171,7 @@ from documents.tasks import train_classifier
 from documents.templating.filepath import validate_filepath_template_and_render
 from paperless import version
 from paperless.ai.ai_classifier import get_ai_document_classification
+from paperless.ai.chat import chat_with_documents
 from paperless.ai.matching import extract_unmatched_names
 from paperless.ai.matching import match_correspondents_by_name
 from paperless.ai.matching import match_document_types_by_name
@@ -1134,6 +1135,17 @@ class DocumentViewSet(
             "Error emailing document, check logs for more detail.",
         )
 
+    @action(methods=["post"], detail=False, url_path="chat")
+    def chat(self, request):
+        ai_config = AIConfig()
+        if not ai_config.ai_enabled:
+            return HttpResponseBadRequest("AI is required for this feature")
+
+        question = request.data["q"]
+        result = chat_with_documents(question, request.user)
+
+        return Response({"answer": result})
+
 
 @extend_schema_view(
     list=extend_schema(
diff --git a/src/paperless/ai/chat.py b/src/paperless/ai/chat.py
new file mode 100644
index 000000000..eb485b641
--- /dev/null
+++ b/src/paperless/ai/chat.py
@@ -0,0 +1,24 @@
+import logging
+
+from django.contrib.auth.models import User
+from llama_index.core.query_engine import RetrieverQueryEngine
+
+from paperless.ai.client import AIClient
+from paperless.ai.indexing import get_document_retriever
+
+logger = logging.getLogger("paperless.ai.chat")
+
+
+def chat_with_documents(prompt: str, user: User) -> str:
+    retriever = get_document_retriever(top_k=5)
+    client = AIClient()
+
+    query_engine = RetrieverQueryEngine.from_args(
+        retriever=retriever,
+        llm=client.llm,
+    )
+
+    logger.debug("Document chat prompt: %s", prompt)
+    response = query_engine.query(prompt)
+    logger.debug("Document chat response: %s", response)
+    return str(response)
diff --git a/src/paperless/ai/client.py b/src/paperless/ai/client.py
index cf3b0b0eb..2ebb2b48d 100644
--- a/src/paperless/ai/client.py
+++ b/src/paperless/ai/client.py
@@ -14,6 +14,10 @@ class AIClient:
     A client for interacting with an LLM backend.
     """
 
+    def __init__(self):
+        self.settings = AIConfig()
+        self.llm = self.get_llm()
+
     def get_llm(self):
         if self.settings.llm_backend == "ollama":
             return OllamaLLM(
@@ -28,10 +32,6 @@ class AIClient:
         else:
             raise ValueError(f"Unsupported LLM backend: {self.settings.llm_backend}")
 
-    def __init__(self):
-        self.settings = AIConfig()
-        self.llm = self.get_llm()
-
     def run_llm_query(self, prompt: str) -> str:
         logger.debug(
             "Running LLM query against %s with model %s",
diff --git a/src/paperless/ai/indexing.py b/src/paperless/ai/indexing.py
index 271b5f3cd..9ed09daa1 100644
--- a/src/paperless/ai/indexing.py
+++ b/src/paperless/ai/indexing.py
@@ -14,6 +14,11 @@ from paperless.ai.embedding import get_embedding_model
 logger = logging.getLogger("paperless.ai.indexing")
 
 
+def get_document_retriever(top_k: int = 5) -> VectorIndexRetriever:
+    index = load_index()
+    return VectorIndexRetriever(index=index, similarity_top_k=top_k)
+
+
 def load_index() -> VectorStoreIndex:
     """Loads the persisted LlamaIndex from disk."""
     vector_store = FaissVectorStore.from_persist_dir(settings.LLM_INDEX_DIR)
@@ -31,10 +36,8 @@
 
 def query_similar_documents(document: Document, top_k: int = 5) -> list[Document]:
     """Runs a similarity query and returns top-k similar Document objects."""
-
-    # Load index
-    index = load_index()
-    retriever = VectorIndexRetriever(index=index, similarity_top_k=top_k)
+    # Load the index
+    retriever = get_document_retriever(top_k=top_k)
 
     # Build query from the document text
     query_text = (document.title or "") + "\n" + (document.content or "")