Move to structured output

This commit is contained in:
shamoon
2025-07-15 14:27:29 -07:00
parent b94912a392
commit 20bae4bd41
3 changed files with 49 additions and 67 deletions

View File

@@ -1,10 +1,12 @@
import logging
from llama_index.core.llms import ChatMessage
from llama_index.core.program.function_program import get_function_tool
from llama_index.llms.ollama import Ollama
from llama_index.llms.openai import OpenAI
from paperless.config import AIConfig
from paperless_ai.tools import DocumentClassifierSchema
logger = logging.getLogger("paperless_ai.client")
@@ -18,7 +20,7 @@ class AIClient:
self.settings = AIConfig()
self.llm = self.get_llm()
def get_llm(self):
def get_llm(self) -> Ollama | OpenAI:
if self.settings.llm_backend == "ollama":
return Ollama(
model=self.settings.llm_model or "llama3",
@@ -39,9 +41,21 @@ class AIClient:
self.settings.llm_backend,
self.settings.llm_model,
)
result = self.llm.complete(prompt)
logger.debug("LLM query result: %s", result)
return result
user_msg = ChatMessage(role="user", content=prompt)
tool = get_function_tool(DocumentClassifierSchema)
result = self.llm.chat_with_tools(
tools=[tool],
user_msg=user_msg,
chat_history=[],
)
tool_calls = self.llm.get_tool_calls_from_response(
result,
error_on_no_tool_calls=True,
)
logger.debug("LLM query result: %s", tool_calls)
parsed = DocumentClassifierSchema(**tool_calls[0].tool_kwargs)
return parsed.model_dump()
def run_chat(self, messages: list[ChatMessage]) -> str:
logger.debug(