mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-05-23 12:58:18 -05:00
Super basic doc chat
[ci skip]
This commit is contained in:
parent
20506d53c0
commit
a794008649
@ -171,6 +171,7 @@ from documents.tasks import train_classifier
|
|||||||
from documents.templating.filepath import validate_filepath_template_and_render
|
from documents.templating.filepath import validate_filepath_template_and_render
|
||||||
from paperless import version
|
from paperless import version
|
||||||
from paperless.ai.ai_classifier import get_ai_document_classification
|
from paperless.ai.ai_classifier import get_ai_document_classification
|
||||||
|
from paperless.ai.chat import chat_with_documents
|
||||||
from paperless.ai.matching import extract_unmatched_names
|
from paperless.ai.matching import extract_unmatched_names
|
||||||
from paperless.ai.matching import match_correspondents_by_name
|
from paperless.ai.matching import match_correspondents_by_name
|
||||||
from paperless.ai.matching import match_document_types_by_name
|
from paperless.ai.matching import match_document_types_by_name
|
||||||
@ -1134,6 +1135,17 @@ class DocumentViewSet(
|
|||||||
"Error emailing document, check logs for more detail.",
|
"Error emailing document, check logs for more detail.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@action(methods=["post"], detail=False, url_path="chat")
|
||||||
|
def chat(self, request):
|
||||||
|
ai_config = AIConfig()
|
||||||
|
if not ai_config.ai_enabled:
|
||||||
|
return HttpResponseBadRequest("AI is required for this feature")
|
||||||
|
|
||||||
|
question = request.data["q"]
|
||||||
|
result = chat_with_documents(question, request.user)
|
||||||
|
|
||||||
|
return Response({"answer": result})
|
||||||
|
|
||||||
|
|
||||||
@extend_schema_view(
|
@extend_schema_view(
|
||||||
list=extend_schema(
|
list=extend_schema(
|
||||||
|
24
src/paperless/ai/chat.py
Normal file
24
src/paperless/ai/chat.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
from django.contrib.auth.models import User
|
||||||
|
from llama_index.core.query_engine import RetrieverQueryEngine
|
||||||
|
|
||||||
|
from paperless.ai.client import AIClient
|
||||||
|
from paperless.ai.indexing import get_document_retriever
|
||||||
|
|
||||||
|
logger = logging.getLogger("paperless.ai.chat")
|
||||||
|
|
||||||
|
|
||||||
|
def chat_with_documents(prompt: str, user: User) -> str:
|
||||||
|
retriever = get_document_retriever(top_k=5)
|
||||||
|
client = AIClient()
|
||||||
|
|
||||||
|
query_engine = RetrieverQueryEngine.from_args(
|
||||||
|
retriever=retriever,
|
||||||
|
llm=client.llm,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug("Document chat prompt: %s", prompt)
|
||||||
|
response = query_engine.query(prompt)
|
||||||
|
logger.debug("Document chat response: %s", response)
|
||||||
|
return str(response)
|
@ -14,6 +14,10 @@ class AIClient:
|
|||||||
A client for interacting with an LLM backend.
|
A client for interacting with an LLM backend.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.settings = AIConfig()
|
||||||
|
self.llm = self.get_llm()
|
||||||
|
|
||||||
def get_llm(self):
|
def get_llm(self):
|
||||||
if self.settings.llm_backend == "ollama":
|
if self.settings.llm_backend == "ollama":
|
||||||
return OllamaLLM(
|
return OllamaLLM(
|
||||||
@ -28,10 +32,6 @@ class AIClient:
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported LLM backend: {self.settings.llm_backend}")
|
raise ValueError(f"Unsupported LLM backend: {self.settings.llm_backend}")
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.settings = AIConfig()
|
|
||||||
self.llm = self.get_llm()
|
|
||||||
|
|
||||||
def run_llm_query(self, prompt: str) -> str:
|
def run_llm_query(self, prompt: str) -> str:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Running LLM query against %s with model %s",
|
"Running LLM query against %s with model %s",
|
||||||
|
@ -14,6 +14,11 @@ from paperless.ai.embedding import get_embedding_model
|
|||||||
logger = logging.getLogger("paperless.ai.indexing")
|
logger = logging.getLogger("paperless.ai.indexing")
|
||||||
|
|
||||||
|
|
||||||
|
def get_document_retriever(top_k: int = 5) -> VectorIndexRetriever:
|
||||||
|
index = load_index()
|
||||||
|
return VectorIndexRetriever(index=index, similarity_top_k=top_k)
|
||||||
|
|
||||||
|
|
||||||
def load_index() -> VectorStoreIndex:
|
def load_index() -> VectorStoreIndex:
|
||||||
"""Loads the persisted LlamaIndex from disk."""
|
"""Loads the persisted LlamaIndex from disk."""
|
||||||
vector_store = FaissVectorStore.from_persist_dir(settings.LLM_INDEX_DIR)
|
vector_store = FaissVectorStore.from_persist_dir(settings.LLM_INDEX_DIR)
|
||||||
@ -31,10 +36,8 @@ def load_index() -> VectorStoreIndex:
|
|||||||
|
|
||||||
def query_similar_documents(document: Document, top_k: int = 5) -> list[Document]:
|
def query_similar_documents(document: Document, top_k: int = 5) -> list[Document]:
|
||||||
"""Runs a similarity query and returns top-k similar Document objects."""
|
"""Runs a similarity query and returns top-k similar Document objects."""
|
||||||
|
# Load the index
|
||||||
# Load index
|
retriever = get_document_retriever(top_k=top_k)
|
||||||
index = load_index()
|
|
||||||
retriever = VectorIndexRetriever(index=index, similarity_top_k=top_k)
|
|
||||||
|
|
||||||
# Build query from the document text
|
# Build query from the document text
|
||||||
query_text = (document.title or "") + "\n" + (document.content or "")
|
query_text = (document.title or "") + "\n" + (document.content or "")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user