From 20506d53c03212dbaf0ca48b0c42b4084dde6dc4 Mon Sep 17 00:00:00 2001
From: shamoon <4887959+shamoon@users.noreply.github.com>
Date: Thu, 24 Apr 2025 23:20:27 -0700
Subject: [PATCH] Better encapsulate backends, use llama_index OpenAI

---
 src/paperless/ai/ai_classifier.py | 22 ++++++++--
 src/paperless/ai/client.py        | 73 ++++++++++++-------------------
 src/paperless/ai/llms.py          | 64 +++++++++++++++++++++++++++
 3 files changed, 111 insertions(+), 48 deletions(-)
 create mode 100644 src/paperless/ai/llms.py

diff --git a/src/paperless/ai/ai_classifier.py b/src/paperless/ai/ai_classifier.py
index 704b894a4..69274da56 100644
--- a/src/paperless/ai/ai_classifier.py
+++ b/src/paperless/ai/ai_classifier.py
@@ -1,6 +1,8 @@
 import json
 import logging
 
+from llama_index.core.base.llms.types import CompletionResponse
+
 from documents.models import Document
 from paperless.ai.client import AIClient
 from paperless.ai.rag import get_context_for_document
@@ -28,6 +30,8 @@ def build_prompt_without_rag(document: Document) -> str:
     - storage_paths: Suggested folder paths (e.g. "Medical/Insurance")
     - dates: List up to 3 relevant dates in YYYY-MM-DD format
 
+    Respond ONLY in JSON.
+    Each field must be a list of plain strings.
     The format of the JSON object is as follows:
     {{
       "title": "xxxxx",
@@ -69,6 +73,18 @@ def build_prompt_with_rag(document: Document) -> str:
     - storage_paths: Suggested folder paths
     - dates: Up to 3 relevant dates in YYYY-MM-DD
 
+    Respond ONLY in JSON.
+    Each field must be a list of plain strings.
+    The format of the JSON object is as follows:
+    {{
+      "title": "xxxxx",
+      "tags": ["xxxx", "xxxx"],
+      "correspondents": ["xxxx", "xxxx"],
+      "document_types": ["xxxx", "xxxx"],
+      "storage_paths": ["xxxx", "xxxx"],
+      "dates": ["YYYY-MM-DD", "YYYY-MM-DD", "YYYY-MM-DD"]
+    }}
+
     Here is the document:
     FILENAME: {filename}
 
@@ -83,9 +99,9 @@ def build_prompt_with_rag(document: Document) -> str:
     return prompt
 
 
-def parse_ai_response(text: str) -> dict:
+def parse_ai_response(response: CompletionResponse) -> dict:
     try:
-        raw = json.loads(text)
+        raw = json.loads(response.text)
         return {
             "title": raw.get("title"),
             "tags": raw.get("tags", []),
@@ -95,7 +111,7 @@
             "dates": raw.get("dates", []),
         }
     except json.JSONDecodeError:
-        logger.exception("Invalid JSON in RAG response")
+        logger.exception("Invalid JSON in AI response")
         return {}
 
 
diff --git a/src/paperless/ai/client.py b/src/paperless/ai/client.py
index 514605e91..cf3b0b0eb 100644
--- a/src/paperless/ai/client.py
+++ b/src/paperless/ai/client.py
@@ -1,7 +1,9 @@
 import logging
 
-import httpx
+from llama_index.core.llms import ChatMessage
+from llama_index.llms.openai import OpenAI
 
+from paperless.ai.llms import OllamaLLM
 from paperless.config import AIConfig
 
 logger = logging.getLogger("paperless.ai.client")
@@ -12,8 +14,23 @@ class AIClient:
     A client for interacting with an LLM backend.
""" + def get_llm(self): + if self.settings.llm_backend == "ollama": + return OllamaLLM( + model=self.settings.llm_model or "llama3", + base_url=self.settings.llm_url or "http://localhost:11434", + ) + elif self.settings.llm_backend == "openai": + return OpenAI( + model=self.settings.llm_model or "gpt-3.5-turbo", + api_key=self.settings.openai_api_key, + ) + else: + raise ValueError(f"Unsupported LLM backend: {self.settings.llm_backend}") + def __init__(self): self.settings = AIConfig() + self.llm = self.get_llm() def run_llm_query(self, prompt: str) -> str: logger.debug( @@ -21,50 +38,16 @@ class AIClient: self.settings.llm_backend, self.settings.llm_model, ) - match self.settings.llm_backend: - case "openai": - result = self._run_openai_query(prompt) - case "ollama": - result = self._run_ollama_query(prompt) - case _: - raise ValueError( - f"Unsupported LLM backend: {self.settings.llm_backend}", - ) + result = self.llm.complete(prompt) logger.debug("LLM query result: %s", result) return result - def _run_ollama_query(self, prompt: str) -> str: - url = self.settings.llm_url or "http://localhost:11434" - with httpx.Client(timeout=60.0) as client: - response = client.post( - f"{url}/api/generate", - json={ - "model": self.settings.llm_model, - "prompt": prompt, - "stream": False, - }, - ) - response.raise_for_status() - return response.json()["response"] - - def _run_openai_query(self, prompt: str) -> str: - if not self.settings.llm_api_key: - raise RuntimeError("PAPERLESS_LLM_API_KEY is not set") - - url = self.settings.llm_url or "https://api.openai.com" - - with httpx.Client(timeout=30.0) as client: - response = client.post( - f"{url}/v1/chat/completions", - headers={ - "Authorization": f"Bearer {self.settings.llm_api_key}", - "Content-Type": "application/json", - }, - json={ - "model": self.settings.llm_model, - "messages": [{"role": "user", "content": prompt}], - "temperature": 0.3, - }, - ) - response.raise_for_status() - return response.json()["choices"][0]["message"]["content"] + def run_chat(self, messages: list[ChatMessage]) -> str: + logger.debug( + "Running chat query against %s with model %s", + self.settings.llm_backend, + self.settings.llm_model, + ) + result = self.llm.chat(messages) + logger.debug("Chat result: %s", result) + return result diff --git a/src/paperless/ai/llms.py b/src/paperless/ai/llms.py new file mode 100644 index 000000000..b51045d45 --- /dev/null +++ b/src/paperless/ai/llms.py @@ -0,0 +1,64 @@ +import httpx +from llama_index.core.base.llms.types import ChatMessage +from llama_index.core.base.llms.types import ChatResponse +from llama_index.core.base.llms.types import ChatResponseGen +from llama_index.core.base.llms.types import CompletionResponse +from llama_index.core.base.llms.types import CompletionResponseGen +from llama_index.core.base.llms.types import LLMMetadata +from llama_index.core.llms.llm import LLM +from pydantic import Field + + +class OllamaLLM(LLM): + model: str = Field(default="llama3") + base_url: str = Field(default="http://localhost:11434") + + @property + def metadata(self) -> LLMMetadata: + return LLMMetadata( + model_name=self.model, + is_chat_model=False, + context_window=4096, + num_output=512, + is_function_calling_model=False, + ) + + def complete(self, prompt: str, **kwargs) -> CompletionResponse: + with httpx.Client(timeout=120.0) as client: + response = client.post( + f"{self.base_url}/api/generate", + json={ + "model": self.model, + "prompt": prompt, + "stream": False, + }, + ) + response.raise_for_status() + data = 
+            return CompletionResponse(text=data["response"])
+
+    # -- Required stubs for ABC:
+    def stream_complete(self, prompt: str, **kwargs) -> CompletionResponseGen:
+        raise NotImplementedError("stream_complete not supported")
+
+    def chat(self, messages: list[ChatMessage], **kwargs) -> ChatResponse:
+        raise NotImplementedError("chat not supported")
+
+    def stream_chat(self, messages: list[ChatMessage], **kwargs) -> ChatResponseGen:
+        raise NotImplementedError("stream_chat not supported")
+
+    async def achat(self, messages: list[ChatMessage], **kwargs) -> ChatResponse:
+        raise NotImplementedError("async chat not supported")
+
+    async def astream_chat(
+        self,
+        messages: list[ChatMessage],
+        **kwargs,
+    ) -> ChatResponseGen:
+        raise NotImplementedError("async stream_chat not supported")
+
+    async def acomplete(self, prompt: str, **kwargs) -> CompletionResponse:
+        raise NotImplementedError("async complete not supported")
+
+    async def astream_complete(self, prompt: str, **kwargs) -> CompletionResponseGen:
+        raise NotImplementedError("async stream_complete not supported")
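
For illustration, a minimal sketch of how the refactored pieces fit together end to end. This is not part of the patch: suggest_metadata is a hypothetical caller, and AIConfig is assumed to be set up for one of the two supported backends.

# Hypothetical caller (not in the patch), wiring the refactored client
# to the prompt builder and parser from ai_classifier.py.
from documents.models import Document
from paperless.ai.ai_classifier import build_prompt_without_rag
from paperless.ai.ai_classifier import parse_ai_response
from paperless.ai.client import AIClient


def suggest_metadata(document: Document) -> dict:
    # AIClient now selects the backend once, in get_llm(); an unsupported
    # llm_backend raises ValueError here instead of per request.
    client = AIClient()
    prompt = build_prompt_without_rag(document)
    # run_llm_query() returns a llama_index CompletionResponse, which
    # parse_ai_response() consumes via its .text attribute.
    response = client.run_llm_query(prompt)
    return parse_ai_response(response)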
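
Because OllamaLLM is a regular llama_index LLM subclass, it can also be exercised on its own. A minimal sketch, assuming an Ollama server listening on localhost:11434 with the llama3 model already pulled:

from paperless.ai.llms import OllamaLLM

llm = OllamaLLM(model="llama3", base_url="http://localhost:11434")
# complete() POSTs to Ollama's /api/generate and wraps the "response"
# field of the reply in a llama_index CompletionResponse.
completion = llm.complete("Reply with the single word: pong")
print(completion.text)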
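
One caveat worth noting: OllamaLLM.chat() is left as a NotImplementedError stub, so AIClient.run_chat() only works on the openai backend as committed. If chat support were wanted for Ollama, one possible implementation against Ollama's /api/chat endpoint could look like the sketch below (an assumption-laden sketch, not part of this patch):

import httpx
from llama_index.core.base.llms.types import ChatMessage
from llama_index.core.base.llms.types import ChatResponse


def chat(self, messages: list[ChatMessage], **kwargs) -> ChatResponse:
    # Ollama's /api/chat accepts [{"role": ..., "content": ...}] and, with
    # stream=False, returns {"message": {"role": "assistant", "content": ...}}.
    with httpx.Client(timeout=120.0) as client:
        response = client.post(
            f"{self.base_url}/api/chat",
            json={
                "model": self.model,
                "messages": [
                    # m.role is a MessageRole enum; .value yields "user" etc.
                    {"role": m.role.value, "content": m.content}
                    for m in messages
                ],
                "stream": False,
            },
        )
        response.raise_for_status()
        data = response.json()
    return ChatResponse(
        message=ChatMessage(
            role=data["message"]["role"],
            content=data["message"]["content"],
        ),
    )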