Better encapsulate backends, use llama_index OpenAI

shamoon 2025-04-24 23:20:27 -07:00
parent 7e3a22db3f
commit 20506d53c0
3 changed files with 111 additions and 48 deletions


@@ -1,6 +1,8 @@
import json
import logging
from llama_index.core.base.llms.types import CompletionResponse
from documents.models import Document
from paperless.ai.client import AIClient
from paperless.ai.rag import get_context_for_document
@@ -28,6 +30,8 @@ def build_prompt_without_rag(document: Document) -> str:
- storage_paths: Suggested folder paths (e.g. "Medical/Insurance")
- dates: List up to 3 relevant dates in YYYY-MM-DD format
Respond ONLY in JSON.
Each field must be a list of plain strings.
The format of the JSON object is as follows:
{{
"title": "xxxxx",
@@ -69,6 +73,18 @@ def build_prompt_with_rag(document: Document) -> str:
- storage_paths: Suggested folder paths
- dates: Up to 3 relevant dates in YYYY-MM-DD
Respond ONLY in JSON.
Each field must be a list of plain strings.
The format of the JSON object is as follows:
{{
"title": "xxxxx",
"tags": ["xxxx", "xxxx"],
"correspondents": ["xxxx", "xxxx"],
"document_types": ["xxxx", "xxxx"],
"storage_paths": ["xxxx", "xxxx"],
"dates": ["YYYY-MM-DD", "YYYY-MM-DD", "YYYY-MM-DD"],
}}
Here is the document:
FILENAME:
{filename}
@@ -83,9 +99,9 @@ def build_prompt_with_rag(document: Document) -> str:
return prompt
def parse_ai_response(text: str) -> dict:
def parse_ai_response(response: CompletionResponse) -> dict:
try:
raw = json.loads(text)
raw = json.loads(response.text)
return {
"title": raw.get("title"),
"tags": raw.get("tags", []),
@@ -95,7 +111,7 @@ def parse_ai_response(text: str) -> dict:
"dates": raw.get("dates", []),
}
except json.JSONDecodeError:
logger.exception("Invalid JSON in RAG response")
logger.exception("Invalid JSON in AI response")
return {}
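
For illustration, a minimal sketch of how the new parse_ai_response signature is exercised; the response text below is a hypothetical model output, not part of the commit:

from llama_index.core.base.llms.types import CompletionResponse

# Hypothetical model output, already in the JSON shape the prompt asks for.
resp = CompletionResponse(
    text='{"title": "Dental invoice", "tags": ["invoice", "dental"], "dates": ["2025-04-24"]}',
)

parsed = parse_ai_response(resp)  # parse_ai_response from the module shown above
# parsed["title"] == "Dental invoice"; omitted fields fall back to empty lists,
# e.g. parsed["correspondents"] == []

# Malformed JSON is logged via logger.exception and swallowed:
assert parse_ai_response(CompletionResponse(text="not json")) == {}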


@@ -1,7 +1,9 @@
import logging
import httpx
from llama_index.core.llms import ChatMessage
from llama_index.llms.openai import OpenAI
from paperless.ai.llms import OllamaLLM
from paperless.config import AIConfig
logger = logging.getLogger("paperless.ai.client")
@@ -12,8 +14,23 @@ class AIClient:
A client for interacting with an LLM backend.
"""
def get_llm(self):
if self.settings.llm_backend == "ollama":
return OllamaLLM(
model=self.settings.llm_model or "llama3",
base_url=self.settings.llm_url or "http://localhost:11434",
)
elif self.settings.llm_backend == "openai":
return OpenAI(
model=self.settings.llm_model or "gpt-3.5-turbo",
api_key=self.settings.openai_api_key,
)
else:
raise ValueError(f"Unsupported LLM backend: {self.settings.llm_backend}")
def __init__(self):
self.settings = AIConfig()
self.llm = self.get_llm()
def run_llm_query(self, prompt: str) -> str:
logger.debug(
@@ -21,50 +38,16 @@ class AIClient:
self.settings.llm_backend,
self.settings.llm_model,
)
match self.settings.llm_backend:
case "openai":
result = self._run_openai_query(prompt)
case "ollama":
result = self._run_ollama_query(prompt)
case _:
raise ValueError(
f"Unsupported LLM backend: {self.settings.llm_backend}",
)
result = self.llm.complete(prompt)
logger.debug("LLM query result: %s", result)
return result
def _run_ollama_query(self, prompt: str) -> str:
url = self.settings.llm_url or "http://localhost:11434"
with httpx.Client(timeout=60.0) as client:
response = client.post(
f"{url}/api/generate",
json={
"model": self.settings.llm_model,
"prompt": prompt,
"stream": False,
},
)
response.raise_for_status()
return response.json()["response"]
def _run_openai_query(self, prompt: str) -> str:
if not self.settings.llm_api_key:
raise RuntimeError("PAPERLESS_LLM_API_KEY is not set")
url = self.settings.llm_url or "https://api.openai.com"
with httpx.Client(timeout=30.0) as client:
response = client.post(
f"{url}/v1/chat/completions",
headers={
"Authorization": f"Bearer {self.settings.llm_api_key}",
"Content-Type": "application/json",
},
json={
"model": self.settings.llm_model,
"messages": [{"role": "user", "content": prompt}],
"temperature": 0.3,
},
)
response.raise_for_status()
return response.json()["choices"][0]["message"]["content"]
def run_chat(self, messages: list[ChatMessage]) -> str:
logger.debug(
"Running chat query against %s with model %s",
self.settings.llm_backend,
self.settings.llm_model,
)
result = self.llm.chat(messages)
logger.debug("Chat result: %s", result)
return result
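
A rough usage sketch of the slimmed-down client. Construction assumes AIConfig already carries a valid llm_backend ("ollama" or "openai") plus model/key settings; the prompts here are placeholders:

from llama_index.core.llms import ChatMessage
from paperless.ai.client import AIClient

# AIClient reads its settings from AIConfig once and builds the matching
# llama_index LLM in get_llm(); there is no per-backend branch at query time.
client = AIClient()

# Single-shot completion. The result is whatever llm.complete() returns
# (a llama_index CompletionResponse), which parse_ai_response now accepts.
completion = client.run_llm_query("Suggest tags for this document: ...")

# Multi-turn chat goes through the same LLM object via ChatMessage objects.
reply = client.run_chat([
    ChatMessage(role="user", content="Which dates appear in the document?"),
])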

src/paperless/ai/llms.py (new file, 64 lines added)

@@ -0,0 +1,64 @@
import httpx
from llama_index.core.base.llms.types import ChatMessage
from llama_index.core.base.llms.types import ChatResponse
from llama_index.core.base.llms.types import ChatResponseGen
from llama_index.core.base.llms.types import CompletionResponse
from llama_index.core.base.llms.types import CompletionResponseGen
from llama_index.core.base.llms.types import LLMMetadata
from llama_index.core.llms.llm import LLM
from pydantic import Field
class OllamaLLM(LLM):
model: str = Field(default="llama3")
base_url: str = Field(default="http://localhost:11434")
@property
def metadata(self) -> LLMMetadata:
return LLMMetadata(
model_name=self.model,
is_chat_model=False,
context_window=4096,
num_output=512,
is_function_calling_model=False,
)
def complete(self, prompt: str, **kwargs) -> CompletionResponse:
with httpx.Client(timeout=120.0) as client:
response = client.post(
f"{self.base_url}/api/generate",
json={
"model": self.model,
"prompt": prompt,
"stream": False,
},
)
response.raise_for_status()
data = response.json()
return CompletionResponse(text=data["response"])
# -- Required stubs for ABC:
def stream_complete(self, prompt: str, **kwargs) -> CompletionResponseGen:
raise NotImplementedError("stream_complete not supported")
def chat(self, messages: list[ChatMessage], **kwargs) -> ChatResponse:
raise NotImplementedError("chat not supported")
def stream_chat(self, messages: list[ChatMessage], **kwargs) -> ChatResponseGen:
raise NotImplementedError("stream_chat not supported")
async def achat(self, messages: list[ChatMessage], **kwargs) -> ChatResponse:
raise NotImplementedError("async chat not supported")
async def astream_chat(
self,
messages: list[ChatMessage],
**kwargs,
) -> ChatResponseGen:
raise NotImplementedError("async stream_chat not supported")
async def acomplete(self, prompt: str, **kwargs) -> CompletionResponse:
raise NotImplementedError("async complete not supported")
async def astream_complete(self, prompt: str, **kwargs) -> CompletionResponseGen:
raise NotImplementedError("async stream_complete not supported")
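
And a quick sketch of driving the new wrapper directly, assuming a local Ollama server on the default port; only the synchronous complete() path is implemented in this commit, everything else raises NotImplementedError:

from paperless.ai.llms import OllamaLLM

# Defaults mirror the Field declarations above; adjust for your local setup.
llm = OllamaLLM(model="llama3", base_url="http://localhost:11434")

response = llm.complete("Reply with a single word: ping")
print(response.text)  # the "response" field of Ollama's /api/generate payload

# Chat, streaming and async variants are deliberate stubs for now:
# llm.chat([...]) raises NotImplementedError("chat not supported")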