Better encapsulate backends, use llama_index OpenAI
commit 20506d53c0
parent 7e3a22db3f
@@ -1,6 +1,8 @@
 import json
 import logging
 
+from llama_index.core.base.llms.types import CompletionResponse
+
 from documents.models import Document
 from paperless.ai.client import AIClient
 from paperless.ai.rag import get_context_for_document
@@ -28,6 +30,8 @@ def build_prompt_without_rag(document: Document) -> str:
 - storage_paths: Suggested folder paths (e.g. "Medical/Insurance")
 - dates: List up to 3 relevant dates in YYYY-MM-DD format
 
+Respond ONLY in JSON.
+Each field must be a list of plain strings.
 The format of the JSON object is as follows:
 {{
 "title": "xxxxx",
@@ -69,6 +73,18 @@ def build_prompt_with_rag(document: Document) -> str:
 - storage_paths: Suggested folder paths
 - dates: Up to 3 relevant dates in YYYY-MM-DD
 
+Respond ONLY in JSON.
+Each field must be a list of plain strings.
+The format of the JSON object is as follows:
+{{
+"title": "xxxxx",
+"tags": ["xxxx", "xxxx"],
+"correspondents": ["xxxx", "xxxx"],
+"document_types": ["xxxx", "xxxx"],
+"storage_paths": ["xxxx", "xxxx"],
+"dates": ["YYYY-MM-DD", "YYYY-MM-DD", "YYYY-MM-DD"],
+}}
 
 Here is the document:
 FILENAME:
 {filename}
@@ -83,9 +99,9 @@ def build_prompt_with_rag(document: Document) -> str:
     return prompt
 
 
-def parse_ai_response(text: str) -> dict:
+def parse_ai_response(response: CompletionResponse) -> dict:
     try:
-        raw = json.loads(text)
+        raw = json.loads(response.text)
         return {
             "title": raw.get("title"),
             "tags": raw.get("tags", []),
@@ -95,7 +111,7 @@ def parse_ai_response(text: str) -> dict:
             "dates": raw.get("dates", []),
         }
     except json.JSONDecodeError:
-        logger.exception("Invalid JSON in RAG response")
+        logger.exception("Invalid JSON in AI response")
         return {}
 
 
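Note (illustration only, not part of the diff): a hedged sketch of the reworked suggestion flow in this module. It assumes a documents.models.Document instance and that run_llm_query() hands back the backend's completion result, which parse_ai_response() now reads via its .text attribute, as the hunks above suggest.

# Hedged sketch only; names come from the hunks above.
client = AIClient()
prompt = build_prompt_with_rag(document)
result = client.run_llm_query(prompt)     # llama_index CompletionResponse per the new backends
suggestions = parse_ai_response(result)   # reads result.text internally
print(suggestions["tags"], suggestions["dates"])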
@@ -1,7 +1,9 @@
 import logging
 
-import httpx
+from llama_index.core.llms import ChatMessage
+from llama_index.llms.openai import OpenAI
 
+from paperless.ai.llms import OllamaLLM
 from paperless.config import AIConfig
 
 logger = logging.getLogger("paperless.ai.client")
@@ -12,8 +14,23 @@ class AIClient:
     A client for interacting with an LLM backend.
     """
 
+    def get_llm(self):
+        if self.settings.llm_backend == "ollama":
+            return OllamaLLM(
+                model=self.settings.llm_model or "llama3",
+                base_url=self.settings.llm_url or "http://localhost:11434",
+            )
+        elif self.settings.llm_backend == "openai":
+            return OpenAI(
+                model=self.settings.llm_model or "gpt-3.5-turbo",
+                api_key=self.settings.openai_api_key,
+            )
+        else:
+            raise ValueError(f"Unsupported LLM backend: {self.settings.llm_backend}")
+
     def __init__(self):
         self.settings = AIConfig()
+        self.llm = self.get_llm()
 
     def run_llm_query(self, prompt: str) -> str:
         logger.debug(
@@ -21,50 +38,16 @@ class AIClient:
             self.settings.llm_backend,
             self.settings.llm_model,
         )
-        match self.settings.llm_backend:
-            case "openai":
-                result = self._run_openai_query(prompt)
-            case "ollama":
-                result = self._run_ollama_query(prompt)
-            case _:
-                raise ValueError(
-                    f"Unsupported LLM backend: {self.settings.llm_backend}",
-                )
+        result = self.llm.complete(prompt)
         logger.debug("LLM query result: %s", result)
         return result
 
-    def _run_ollama_query(self, prompt: str) -> str:
-        url = self.settings.llm_url or "http://localhost:11434"
-        with httpx.Client(timeout=60.0) as client:
-            response = client.post(
-                f"{url}/api/generate",
-                json={
-                    "model": self.settings.llm_model,
-                    "prompt": prompt,
-                    "stream": False,
-                },
-            )
-            response.raise_for_status()
-            return response.json()["response"]
-
-    def _run_openai_query(self, prompt: str) -> str:
-        if not self.settings.llm_api_key:
-            raise RuntimeError("PAPERLESS_LLM_API_KEY is not set")
-
-        url = self.settings.llm_url or "https://api.openai.com"
-
-        with httpx.Client(timeout=30.0) as client:
-            response = client.post(
-                f"{url}/v1/chat/completions",
-                headers={
-                    "Authorization": f"Bearer {self.settings.llm_api_key}",
-                    "Content-Type": "application/json",
-                },
-                json={
-                    "model": self.settings.llm_model,
-                    "messages": [{"role": "user", "content": prompt}],
-                    "temperature": 0.3,
-                },
-            )
-            response.raise_for_status()
-            return response.json()["choices"][0]["message"]["content"]
+    def run_chat(self, messages: list[ChatMessage]) -> str:
+        logger.debug(
+            "Running chat query against %s with model %s",
+            self.settings.llm_backend,
+            self.settings.llm_model,
+        )
+        result = self.llm.chat(messages)
+        logger.debug("Chat result: %s", result)
+        return result
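Note (illustration only, not part of the diff): with the backend encapsulated behind get_llm(), callers use AIClient the same way for either backend. A hedged sketch, assuming AIConfig is set up for one of the two supported backends:

# Hedged sketch only.
client = AIClient()   # get_llm() picks OllamaLLM or OpenAI based on AIConfig.llm_backend
completion = client.run_llm_query("Suggest a title for this document ...")
chat_reply = client.run_chat(
    [ChatMessage(role="user", content="Suggest up to three tags for this document.")],
)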
src/paperless/ai/llms.py (new file, 64 lines)
@@ -0,0 +1,64 @@
+import httpx
+from llama_index.core.base.llms.types import ChatMessage
+from llama_index.core.base.llms.types import ChatResponse
+from llama_index.core.base.llms.types import ChatResponseGen
+from llama_index.core.base.llms.types import CompletionResponse
+from llama_index.core.base.llms.types import CompletionResponseGen
+from llama_index.core.base.llms.types import LLMMetadata
+from llama_index.core.llms.llm import LLM
+from pydantic import Field
+
+
+class OllamaLLM(LLM):
+    model: str = Field(default="llama3")
+    base_url: str = Field(default="http://localhost:11434")
+
+    @property
+    def metadata(self) -> LLMMetadata:
+        return LLMMetadata(
+            model_name=self.model,
+            is_chat_model=False,
+            context_window=4096,
+            num_output=512,
+            is_function_calling_model=False,
+        )
+
+    def complete(self, prompt: str, **kwargs) -> CompletionResponse:
+        with httpx.Client(timeout=120.0) as client:
+            response = client.post(
+                f"{self.base_url}/api/generate",
+                json={
+                    "model": self.model,
+                    "prompt": prompt,
+                    "stream": False,
+                },
+            )
+            response.raise_for_status()
+            data = response.json()
+            return CompletionResponse(text=data["response"])
+
+    # -- Required stubs for ABC:
+    def stream_complete(self, prompt: str, **kwargs) -> CompletionResponseGen:
+        raise NotImplementedError("stream_complete not supported")
+
+    def chat(self, messages: list[ChatMessage], **kwargs) -> ChatResponse:
+        raise NotImplementedError("chat not supported")
+
+    def stream_chat(self, messages: list[ChatMessage], **kwargs) -> ChatResponseGen:
+        raise NotImplementedError("stream_chat not supported")
+
+    async def achat(self, messages: list[ChatMessage], **kwargs) -> ChatResponse:
+        raise NotImplementedError("async chat not supported")
+
+    async def astream_chat(
+        self,
+        messages: list[ChatMessage],
+        **kwargs,
+    ) -> ChatResponseGen:
+        raise NotImplementedError("async stream_chat not supported")
+
+    async def acomplete(self, prompt: str, **kwargs) -> CompletionResponse:
+        raise NotImplementedError("async complete not supported")
+
+    async def astream_complete(self, prompt: str, **kwargs) -> CompletionResponseGen:
+        raise NotImplementedError("async stream_complete not supported")
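Note (illustration only, not part of the diff): the new OllamaLLM wrapper can also be exercised on its own. A hedged sketch, assuming an Ollama server is reachable at the default base_url:

# Hedged sketch only; requires a running Ollama server with the model pulled.
from paperless.ai.llms import OllamaLLM

llm = OllamaLLM(model="llama3", base_url="http://localhost:11434")
response = llm.complete("Reply with one short sentence.")
print(response.text)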