Better encapsulate backends, use llama_index OpenAI

shamoon 2025-04-24 23:20:27 -07:00
parent 7e3a22db3f
commit 20506d53c0
3 changed files with 111 additions and 48 deletions


@@ -1,6 +1,8 @@
 import json
 import logging
+from llama_index.core.base.llms.types import CompletionResponse
 from documents.models import Document
 from paperless.ai.client import AIClient
 from paperless.ai.rag import get_context_for_document
@@ -28,6 +30,8 @@ def build_prompt_without_rag(document: Document) -> str:
     - storage_paths: Suggested folder paths (e.g. "Medical/Insurance")
     - dates: List up to 3 relevant dates in YYYY-MM-DD format
+    Respond ONLY in JSON.
+    Each field must be a list of plain strings.
     The format of the JSON object is as follows:
     {{
       "title": "xxxxx",
@@ -69,6 +73,18 @@ def build_prompt_with_rag(document: Document) -> str:
     - storage_paths: Suggested folder paths
     - dates: Up to 3 relevant dates in YYYY-MM-DD
+    Respond ONLY in JSON.
+    Each field must be a list of plain strings.
+    The format of the JSON object is as follows:
+    {{
+      "title": "xxxxx",
+      "tags": ["xxxx", "xxxx"],
+      "correspondents": ["xxxx", "xxxx"],
+      "document_types": ["xxxx", "xxxx"],
+      "storage_paths": ["xxxx", "xxxx"],
+      "dates": ["YYYY-MM-DD", "YYYY-MM-DD", "YYYY-MM-DD"],
+    }}
     Here is the document:
     FILENAME:
     {filename}
@@ -83,9 +99,9 @@ def build_prompt_with_rag(document: Document) -> str:
     return prompt


-def parse_ai_response(text: str) -> dict:
+def parse_ai_response(response: CompletionResponse) -> dict:
     try:
-        raw = json.loads(text)
+        raw = json.loads(response.text)
         return {
             "title": raw.get("title"),
             "tags": raw.get("tags", []),
@@ -95,7 +111,7 @@ def parse_ai_response(text: str) -> dict:
             "dates": raw.get("dates", []),
         }
     except json.JSONDecodeError:
-        logger.exception("Invalid JSON in RAG response")
+        logger.exception("Invalid JSON in AI response")
         return {}
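Not part of the commit, but for orientation: a minimal sketch of how the reworked parse_ai_response can be exercised on its own, with parse_ai_response from the module above in scope. The CompletionResponse is hand-built here purely for illustration.

from llama_index.core.base.llms.types import CompletionResponse

# Stand-in for what the LLM backend would return.
fake = CompletionResponse(
    text='{"title": "Dental invoice", "tags": ["invoice"], "correspondents": [], '
    '"document_types": ["invoice"], "storage_paths": [], "dates": ["2025-04-24"]}',
)

suggestions = parse_ai_response(fake)
# suggestions["tags"] -> ["invoice"]; malformed JSON is logged and yields {} instead.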


@@ -1,7 +1,9 @@
 import logging
-import httpx
+from llama_index.core.llms import ChatMessage
+from llama_index.llms.openai import OpenAI
+from paperless.ai.llms import OllamaLLM
 from paperless.config import AIConfig

 logger = logging.getLogger("paperless.ai.client")
@@ -12,8 +14,23 @@ class AIClient:
     A client for interacting with an LLM backend.
     """

+    def get_llm(self):
+        if self.settings.llm_backend == "ollama":
+            return OllamaLLM(
+                model=self.settings.llm_model or "llama3",
+                base_url=self.settings.llm_url or "http://localhost:11434",
+            )
+        elif self.settings.llm_backend == "openai":
+            return OpenAI(
+                model=self.settings.llm_model or "gpt-3.5-turbo",
+                api_key=self.settings.openai_api_key,
+            )
+        else:
+            raise ValueError(f"Unsupported LLM backend: {self.settings.llm_backend}")
+
     def __init__(self):
         self.settings = AIConfig()
+        self.llm = self.get_llm()

     def run_llm_query(self, prompt: str) -> str:
         logger.debug(
@@ -21,50 +38,16 @@ class AIClient:
             self.settings.llm_backend,
             self.settings.llm_model,
         )
-        match self.settings.llm_backend:
-            case "openai":
-                result = self._run_openai_query(prompt)
-            case "ollama":
-                result = self._run_ollama_query(prompt)
-            case _:
-                raise ValueError(
-                    f"Unsupported LLM backend: {self.settings.llm_backend}",
-                )
+        result = self.llm.complete(prompt)
         logger.debug("LLM query result: %s", result)
         return result

-    def _run_ollama_query(self, prompt: str) -> str:
-        url = self.settings.llm_url or "http://localhost:11434"
-        with httpx.Client(timeout=60.0) as client:
-            response = client.post(
-                f"{url}/api/generate",
-                json={
-                    "model": self.settings.llm_model,
-                    "prompt": prompt,
-                    "stream": False,
-                },
-            )
-            response.raise_for_status()
-            return response.json()["response"]
-
-    def _run_openai_query(self, prompt: str) -> str:
-        if not self.settings.llm_api_key:
-            raise RuntimeError("PAPERLESS_LLM_API_KEY is not set")
-        url = self.settings.llm_url or "https://api.openai.com"
-        with httpx.Client(timeout=30.0) as client:
-            response = client.post(
-                f"{url}/v1/chat/completions",
-                headers={
-                    "Authorization": f"Bearer {self.settings.llm_api_key}",
-                    "Content-Type": "application/json",
-                },
-                json={
-                    "model": self.settings.llm_model,
-                    "messages": [{"role": "user", "content": prompt}],
-                    "temperature": 0.3,
-                },
-            )
-            response.raise_for_status()
-            return response.json()["choices"][0]["message"]["content"]
+    def run_chat(self, messages: list[ChatMessage]) -> str:
+        logger.debug(
+            "Running chat query against %s with model %s",
+            self.settings.llm_backend,
+            self.settings.llm_model,
+        )
+        result = self.llm.chat(messages)
+        logger.debug("Chat result: %s", result)
+        return result
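A short usage sketch of the consolidated client (not from this commit; the prompt text and settings values are illustrative, and the backend is whichever AIConfig.llm_backend selects):

from llama_index.core.llms import ChatMessage
from paperless.ai.client import AIClient

# __init__ calls get_llm(), so the OllamaLLM or OpenAI instance is ready immediately.
client = AIClient()

# Single-shot completion goes through llm.complete() regardless of backend.
result = client.run_llm_query("Suggest tags for this document: ...")

# Chat-style usage goes through llm.chat() with llama_index ChatMessage objects.
reply = client.run_chat([ChatMessage(role="user", content="Hello")])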

src/paperless/ai/llms.py (new file, 64 additions)

@@ -0,0 +1,64 @@
+import httpx
+from llama_index.core.base.llms.types import ChatMessage
+from llama_index.core.base.llms.types import ChatResponse
+from llama_index.core.base.llms.types import ChatResponseGen
+from llama_index.core.base.llms.types import CompletionResponse
+from llama_index.core.base.llms.types import CompletionResponseGen
+from llama_index.core.base.llms.types import LLMMetadata
+from llama_index.core.llms.llm import LLM
+from pydantic import Field
+
+
+class OllamaLLM(LLM):
+    model: str = Field(default="llama3")
+    base_url: str = Field(default="http://localhost:11434")
+
+    @property
+    def metadata(self) -> LLMMetadata:
+        return LLMMetadata(
+            model_name=self.model,
+            is_chat_model=False,
+            context_window=4096,
+            num_output=512,
+            is_function_calling_model=False,
+        )
+
+    def complete(self, prompt: str, **kwargs) -> CompletionResponse:
+        with httpx.Client(timeout=120.0) as client:
+            response = client.post(
+                f"{self.base_url}/api/generate",
+                json={
+                    "model": self.model,
+                    "prompt": prompt,
+                    "stream": False,
+                },
+            )
+            response.raise_for_status()
+            data = response.json()
+            return CompletionResponse(text=data["response"])
+
+    # -- Required stubs for ABC:
+    def stream_complete(self, prompt: str, **kwargs) -> CompletionResponseGen:
+        raise NotImplementedError("stream_complete not supported")
+
+    def chat(self, messages: list[ChatMessage], **kwargs) -> ChatResponse:
+        raise NotImplementedError("chat not supported")
+
+    def stream_chat(self, messages: list[ChatMessage], **kwargs) -> ChatResponseGen:
+        raise NotImplementedError("stream_chat not supported")
+
+    async def achat(self, messages: list[ChatMessage], **kwargs) -> ChatResponse:
+        raise NotImplementedError("async chat not supported")
+
+    async def astream_chat(
+        self,
+        messages: list[ChatMessage],
+        **kwargs,
+    ) -> ChatResponseGen:
+        raise NotImplementedError("async stream_chat not supported")
+
+    async def acomplete(self, prompt: str, **kwargs) -> CompletionResponse:
+        raise NotImplementedError("async complete not supported")
+
+    async def astream_complete(self, prompt: str, **kwargs) -> CompletionResponseGen:
+        raise NotImplementedError("async stream_complete not supported")
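Finally, a sketch (not part of the commit) of driving the new OllamaLLM wrapper directly, assuming an Ollama server is reachable at the default URL and the model name matches one pulled locally:

from paperless.ai.llms import OllamaLLM

llm = OllamaLLM(model="llama3", base_url="http://localhost:11434")

# complete() POSTs to /api/generate with streaming disabled and wraps the reply text.
response = llm.complete("Summarize: the quick brown fox jumps over the lazy dog.")
print(response.text)

# chat(), stream_complete(), and the async entry points are stubs and raise NotImplementedError.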