Better encapsulate backends, use llama_index OpenAI
parent 7e3a22db3f, commit 20506d53c0
@@ -1,6 +1,8 @@
 import json
 import logging

+from llama_index.core.base.llms.types import CompletionResponse
+
 from documents.models import Document
 from paperless.ai.client import AIClient
 from paperless.ai.rag import get_context_for_document
@@ -28,6 +30,8 @@ def build_prompt_without_rag(document: Document) -> str:
 - storage_paths: Suggested folder paths (e.g. "Medical/Insurance")
 - dates: List up to 3 relevant dates in YYYY-MM-DD format

+Respond ONLY in JSON.
+Each field must be a list of plain strings.
 The format of the JSON object is as follows:
 {{
 "title": "xxxxx",
@@ -69,6 +73,18 @@ def build_prompt_with_rag(document: Document) -> str:
 - storage_paths: Suggested folder paths
 - dates: Up to 3 relevant dates in YYYY-MM-DD

+Respond ONLY in JSON.
+Each field must be a list of plain strings.
+The format of the JSON object is as follows:
+{{
+"title": "xxxxx",
+"tags": ["xxxx", "xxxx"],
+"correspondents": ["xxxx", "xxxx"],
+"document_types": ["xxxx", "xxxx"],
+"storage_paths": ["xxxx", "xxxx"],
+"dates": ["YYYY-MM-DD", "YYYY-MM-DD", "YYYY-MM-DD"],
+}}
+
 Here is the document:
 FILENAME:
 {filename}
@ -83,9 +99,9 @@ def build_prompt_with_rag(document: Document) -> str:
|
||||
return prompt
|
||||
|
||||
|
||||
def parse_ai_response(text: str) -> dict:
|
||||
def parse_ai_response(response: CompletionResponse) -> dict:
|
||||
try:
|
||||
raw = json.loads(text)
|
||||
raw = json.loads(response.text)
|
||||
return {
|
||||
"title": raw.get("title"),
|
||||
"tags": raw.get("tags", []),
|
||||
@ -95,7 +111,7 @@ def parse_ai_response(text: str) -> dict:
|
||||
"dates": raw.get("dates", []),
|
||||
}
|
||||
except json.JSONDecodeError:
|
||||
logger.exception("Invalid JSON in RAG response")
|
||||
logger.exception("Invalid JSON in AI response")
|
||||
return {}
|
||||
|
||||
|
||||
|
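
A minimal usage sketch of the reworked parse_ai_response, which now takes a llama_index CompletionResponse; the reply payload below is made up and only mirrors the JSON shape the prompts above request:

    import json

    from llama_index.core.base.llms.types import CompletionResponse

    # Hypothetical reply that follows the "Respond ONLY in JSON" contract above.
    fake_reply = CompletionResponse(
        text=json.dumps(
            {
                "title": "Dental invoice",
                "tags": ["invoice", "dental"],
                "correspondents": ["Smile Dental"],
                "document_types": ["Invoice"],
                "storage_paths": ["Medical/Dental"],
                "dates": ["2025-01-15"],
            },
        ),
    )
    suggestions = parse_ai_response(fake_reply)
    # suggestions["tags"] -> ["invoice", "dental"]; invalid JSON would yield {}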
@@ -1,7 +1,9 @@
 import logging

-import httpx
+from llama_index.core.llms import ChatMessage
+from llama_index.llms.openai import OpenAI

+from paperless.ai.llms import OllamaLLM
 from paperless.config import AIConfig

 logger = logging.getLogger("paperless.ai.client")
@@ -12,8 +14,23 @@ class AIClient:
     A client for interacting with an LLM backend.
     """

+    def get_llm(self):
+        if self.settings.llm_backend == "ollama":
+            return OllamaLLM(
+                model=self.settings.llm_model or "llama3",
+                base_url=self.settings.llm_url or "http://localhost:11434",
+            )
+        elif self.settings.llm_backend == "openai":
+            return OpenAI(
+                model=self.settings.llm_model or "gpt-3.5-turbo",
+                api_key=self.settings.openai_api_key,
+            )
+        else:
+            raise ValueError(f"Unsupported LLM backend: {self.settings.llm_backend}")
+
     def __init__(self):
         self.settings = AIConfig()
+        self.llm = self.get_llm()

     def run_llm_query(self, prompt: str) -> str:
         logger.debug(
@@ -21,50 +38,16 @@ class AIClient:
             self.settings.llm_backend,
             self.settings.llm_model,
         )
-        match self.settings.llm_backend:
-            case "openai":
-                result = self._run_openai_query(prompt)
-            case "ollama":
-                result = self._run_ollama_query(prompt)
-            case _:
-                raise ValueError(
-                    f"Unsupported LLM backend: {self.settings.llm_backend}",
-                )
+        result = self.llm.complete(prompt)
         logger.debug("LLM query result: %s", result)
         return result

-    def _run_ollama_query(self, prompt: str) -> str:
-        url = self.settings.llm_url or "http://localhost:11434"
-        with httpx.Client(timeout=60.0) as client:
-            response = client.post(
-                f"{url}/api/generate",
-                json={
-                    "model": self.settings.llm_model,
-                    "prompt": prompt,
-                    "stream": False,
-                },
-            )
-            response.raise_for_status()
-            return response.json()["response"]
-
-    def _run_openai_query(self, prompt: str) -> str:
-        if not self.settings.llm_api_key:
-            raise RuntimeError("PAPERLESS_LLM_API_KEY is not set")
-
-        url = self.settings.llm_url or "https://api.openai.com"
-
-        with httpx.Client(timeout=30.0) as client:
-            response = client.post(
-                f"{url}/v1/chat/completions",
-                headers={
-                    "Authorization": f"Bearer {self.settings.llm_api_key}",
-                    "Content-Type": "application/json",
-                },
-                json={
-                    "model": self.settings.llm_model,
-                    "messages": [{"role": "user", "content": prompt}],
-                    "temperature": 0.3,
-                },
-            )
-            response.raise_for_status()
-            return response.json()["choices"][0]["message"]["content"]
+    def run_chat(self, messages: list[ChatMessage]) -> str:
+        logger.debug(
+            "Running chat query against %s with model %s",
+            self.settings.llm_backend,
+            self.settings.llm_model,
+        )
+        result = self.llm.chat(messages)
+        logger.debug("Chat result: %s", result)
+        return result
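
A rough usage sketch of the refactored client, assuming AIConfig resolves the backend, model, and URL from the PAPERLESS LLM settings; the prompt text is made up:

    from llama_index.core.llms import ChatMessage

    from paperless.ai.client import AIClient

    client = AIClient()  # get_llm() picks OllamaLLM or the llama_index OpenAI LLM from AIConfig

    # Single-shot completion; delegates to self.llm.complete(prompt)
    result = client.run_llm_query("Suggest tags for this scanned electricity bill ...")

    # Chat-style call; delegates to self.llm.chat(messages). Note that the OllamaLLM
    # wrapper added in this commit raises NotImplementedError for chat().
    answer = client.run_chat(
        [ChatMessage(role="user", content="Which correspondent fits this document?")],
    )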
src/paperless/ai/llms.py (new file, 64 lines)
@@ -0,0 +1,64 @@
+import httpx
+from llama_index.core.base.llms.types import ChatMessage
+from llama_index.core.base.llms.types import ChatResponse
+from llama_index.core.base.llms.types import ChatResponseGen
+from llama_index.core.base.llms.types import CompletionResponse
+from llama_index.core.base.llms.types import CompletionResponseGen
+from llama_index.core.base.llms.types import LLMMetadata
+from llama_index.core.llms.llm import LLM
+from pydantic import Field
+
+
+class OllamaLLM(LLM):
+    model: str = Field(default="llama3")
+    base_url: str = Field(default="http://localhost:11434")
+
+    @property
+    def metadata(self) -> LLMMetadata:
+        return LLMMetadata(
+            model_name=self.model,
+            is_chat_model=False,
+            context_window=4096,
+            num_output=512,
+            is_function_calling_model=False,
+        )
+
+    def complete(self, prompt: str, **kwargs) -> CompletionResponse:
+        with httpx.Client(timeout=120.0) as client:
+            response = client.post(
+                f"{self.base_url}/api/generate",
+                json={
+                    "model": self.model,
+                    "prompt": prompt,
+                    "stream": False,
+                },
+            )
+            response.raise_for_status()
+            data = response.json()
+            return CompletionResponse(text=data["response"])
+
+    # -- Required stubs for ABC:
+    def stream_complete(self, prompt: str, **kwargs) -> CompletionResponseGen:
+        raise NotImplementedError("stream_complete not supported")
+
+    def chat(self, messages: list[ChatMessage], **kwargs) -> ChatResponse:
+        raise NotImplementedError("chat not supported")
+
+    def stream_chat(self, messages: list[ChatMessage], **kwargs) -> ChatResponseGen:
+        raise NotImplementedError("stream_chat not supported")
+
+    async def achat(self, messages: list[ChatMessage], **kwargs) -> ChatResponse:
+        raise NotImplementedError("async chat not supported")
+
+    async def astream_chat(
+        self,
+        messages: list[ChatMessage],
+        **kwargs,
+    ) -> ChatResponseGen:
+        raise NotImplementedError("async stream_chat not supported")
+
+    async def acomplete(self, prompt: str, **kwargs) -> CompletionResponse:
+        raise NotImplementedError("async complete not supported")
+
+    async def astream_complete(self, prompt: str, **kwargs) -> CompletionResponseGen:
+        raise NotImplementedError("async stream_complete not supported")
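
A small sketch of exercising the new wrapper directly; the host, model name, and prompt are assumptions, and a local Ollama server must be reachable:

    from paperless.ai.llms import OllamaLLM

    llm = OllamaLLM(model="llama3", base_url="http://localhost:11434")

    # complete() POSTs to Ollama's /api/generate and wraps the reply text
    response = llm.complete("Reply with a one-line summary of: 'Invoice 2025-01-15, 120 EUR'")
    print(response.text)

    # Chat and streaming entry points are stubbed, so calls such as
    # llm.chat([...]) or llm.stream_complete("...") raise NotImplementedError.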