Mirror of https://github.com/paperless-ngx/paperless-ngx.git (synced 2025-05-21 12:52:13 -05:00)

commit a794008649 (parent 20506d53c0)

    Super basic doc chat
    [ci skip]
src/documents/views.py

@@ -171,6 +171,7 @@ from documents.tasks import train_classifier
 from documents.templating.filepath import validate_filepath_template_and_render
 from paperless import version
 from paperless.ai.ai_classifier import get_ai_document_classification
+from paperless.ai.chat import chat_with_documents
 from paperless.ai.matching import extract_unmatched_names
 from paperless.ai.matching import match_correspondents_by_name
 from paperless.ai.matching import match_document_types_by_name
@@ -1134,6 +1135,17 @@ class DocumentViewSet(
                 "Error emailing document, check logs for more detail.",
             )
 
+    @action(methods=["post"], detail=False, url_path="chat")
+    def chat(self, request):
+        ai_config = AIConfig()
+        if not ai_config.ai_enabled:
+            return HttpResponseBadRequest("AI is required for this feature")
+
+        question = request.data["q"]
+        result = chat_with_documents(question, request.user)
+
+        return Response({"answer": result})
+
 
 @extend_schema_view(
     list=extend_schema(
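Because the action is registered with detail=False and url_path="chat", it should be reachable at /api/documents/chat/ under the default router prefix; that route, and the some_user variable below, are assumptions rather than part of the commit. A minimal sketch of exercising the endpoint with DRF's test client:

    from rest_framework.test import APIClient

    client = APIClient()
    client.force_authenticate(user=some_user)  # some_user: any existing User (assumed)
    response = client.post("/api/documents/chat/", {"q": "What is my rent?"}, format="json")
    print(response.json()["answer"])

Note that the handler reads request.data["q"] without validation, so a request that omits q will raise a KeyError rather than return a clean 400.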
src/paperless/ai/chat.py (new file, 24 lines)
@@ -0,0 +1,24 @@
+import logging
+
+from django.contrib.auth.models import User
+from llama_index.core.query_engine import RetrieverQueryEngine
+
+from paperless.ai.client import AIClient
+from paperless.ai.indexing import get_document_retriever
+
+logger = logging.getLogger("paperless.ai.chat")
+
+
+def chat_with_documents(prompt: str, user: User) -> str:
+    retriever = get_document_retriever(top_k=5)
+    client = AIClient()
+
+    query_engine = RetrieverQueryEngine.from_args(
+        retriever=retriever,
+        llm=client.llm,
+    )
+
+    logger.debug("Document chat prompt: %s", prompt)
+    response = query_engine.query(prompt)
+    logger.debug("Document chat response: %s", response)
+    return str(response)
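RetrieverQueryEngine.from_args wires the FAISS-backed retriever and the configured LLM into a retrieve-then-synthesize pipeline. Roughly, and purely as an illustrative sketch (chat_manually is not part of the commit), the flow amounts to:

    def chat_manually(prompt, retriever, llm):
        # Fetch the top-k nodes most similar to the prompt from the vector index.
        nodes = retriever.retrieve(prompt)
        context = "\n\n".join(node.get_content() for node in nodes)
        # Ask the LLM to answer grounded only in the retrieved context.
        return llm.complete(f"Context:\n{context}\n\nQuestion: {prompt}").text

Worth noting: the user argument is accepted but never used, so retrieval is not yet filtered by the caller's document permissions.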
src/paperless/ai/client.py

@@ -14,6 +14,10 @@ class AIClient:
     A client for interacting with an LLM backend.
     """
 
+    def __init__(self):
+        self.settings = AIConfig()
+        self.llm = self.get_llm()
+
     def get_llm(self):
         if self.settings.llm_backend == "ollama":
             return OllamaLLM(

@@ -28,10 +32,6 @@ class AIClient:
         else:
             raise ValueError(f"Unsupported LLM backend: {self.settings.llm_backend}")
 
-    def __init__(self):
-        self.settings = AIConfig()
-        self.llm = self.get_llm()
-
     def run_llm_query(self, prompt: str) -> str:
         logger.debug(
             "Running LLM query against %s with model %s",
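This pair of hunks only moves __init__ above get_llm; behavior is unchanged, and the backend is still resolved once at construction time and cached on self.llm. An illustrative usage sketch, assuming Django settings are loaded and AIConfig reports an enabled, configured backend:

    from paperless.ai.client import AIClient

    client = AIClient()  # get_llm() runs here, selecting the Ollama or OpenAI backend
    answer = client.run_llm_query("Summarize this correspondent's last letter.")
    print(answer)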
src/paperless/ai/indexing.py

@@ -14,6 +14,11 @@ from paperless.ai.embedding import get_embedding_model
 logger = logging.getLogger("paperless.ai.indexing")
 
 
+def get_document_retriever(top_k: int = 5) -> VectorIndexRetriever:
+    index = load_index()
+    return VectorIndexRetriever(index=index, similarity_top_k=top_k)
+
+
 def load_index() -> VectorStoreIndex:
     """Loads the persisted LlamaIndex from disk."""
     vector_store = FaissVectorStore.from_persist_dir(settings.LLM_INDEX_DIR)

@@ -31,10 +36,8 @@ def load_index() -> VectorStoreIndex:
 def query_similar_documents(document: Document, top_k: int = 5) -> list[Document]:
     """Runs a similarity query and returns top-k similar Document objects."""
 
-    # Load index
-    index = load_index()
-
-    retriever = VectorIndexRetriever(index=index, similarity_top_k=top_k)
+    # Load the index
+    retriever = get_document_retriever(top_k=top_k)
 
     # Build query from the document text
     query_text = (document.title or "") + "\n" + (document.content or "")