Just use the built-in ollama LLM class of course

This commit is contained in:
shamoon
2025-04-25 12:01:23 -07:00
parent 183d369350
commit 37e1290e00
6 changed files with 38 additions and 379 deletions

View File

@@ -70,5 +70,4 @@ def stream_chat_with_documents(query_str: str, documents: list[Document]):
response_stream = query_engine.query(prompt)
for chunk in response_stream.response_gen:
yield chunk.text
yield from response_stream.response_gen

View File

@@ -1,9 +1,9 @@
import logging
from llama_index.core.llms import ChatMessage
from llama_index.llms.ollama import Ollama
from llama_index.llms.openai import OpenAI
from paperless.ai.llms import OllamaLLM
from paperless.config import AIConfig
logger = logging.getLogger("paperless.ai.client")
@@ -20,9 +20,10 @@ class AIClient:
def get_llm(self):
if self.settings.llm_backend == "ollama":
return OllamaLLM(
return Ollama(
model=self.settings.llm_model or "llama3",
base_url=self.settings.llm_url or "http://localhost:11434",
request_timeout=120,
)
elif self.settings.llm_backend == "openai":
return OpenAI(

View File

@@ -1,113 +0,0 @@
import json
import httpx
from llama_index.core.base.llms.types import ChatMessage
from llama_index.core.base.llms.types import ChatResponse
from llama_index.core.base.llms.types import ChatResponseGen
from llama_index.core.base.llms.types import CompletionResponse
from llama_index.core.base.llms.types import CompletionResponseGen
from llama_index.core.base.llms.types import LLMMetadata
from llama_index.core.llms.llm import LLM
from llama_index.core.prompts import SelectorPromptTemplate
from pydantic import Field
class OllamaLLM(LLM):
model: str = Field(default="llama3")
base_url: str = Field(default="http://localhost:11434")
@property
def metadata(self) -> LLMMetadata:
return LLMMetadata(
model_name=self.model,
is_chat_model=False,
context_window=4096,
num_output=512,
is_function_calling_model=False,
)
def complete(self, prompt: str, **kwargs) -> CompletionResponse:
with httpx.Client(timeout=120.0) as client:
response = client.post(
f"{self.base_url}/api/generate",
json={
"model": self.model,
"prompt": prompt,
"stream": False,
},
)
response.raise_for_status()
data = response.json()
return CompletionResponse(text=data["response"])
def stream(self, prompt: str, **kwargs) -> CompletionResponseGen:
return self.stream_complete(prompt, **kwargs)
def stream_complete(
self,
prompt: SelectorPromptTemplate,
**kwargs,
) -> CompletionResponseGen:
headers = {"Content-Type": "application/json"}
data = {
"model": self.model,
"prompt": prompt.format(llm=self),
"stream": True,
}
with httpx.stream(
"POST",
f"{self.base_url}/api/generate",
headers=headers,
json=data,
timeout=60.0,
) as response:
response.raise_for_status()
for line in response.iter_lines():
if not line.strip():
continue
chunk = json.loads(line)
if "response" in chunk:
yield CompletionResponse(text=chunk["response"])
def chat(
self,
messages: list[ChatMessage],
**kwargs,
) -> ChatResponse: # pragma: no cover
raise NotImplementedError("chat not supported")
def stream_chat(
self,
messages: list[ChatMessage],
**kwargs,
) -> ChatResponseGen: # pragma: no cover
raise NotImplementedError("stream_chat not supported")
async def achat(
self,
messages: list[ChatMessage],
**kwargs,
) -> ChatResponse: # pragma: no cover
raise NotImplementedError("async chat not supported")
async def astream_chat(
self,
messages: list[ChatMessage],
**kwargs,
) -> ChatResponseGen: # pragma: no cover
raise NotImplementedError("async stream_chat not supported")
async def acomplete(
self,
prompt: str,
**kwargs,
) -> CompletionResponse: # pragma: no cover
raise NotImplementedError("async complete not supported")
async def astream_complete(
self,
prompt: str,
**kwargs,
) -> CompletionResponseGen: # pragma: no cover
raise NotImplementedError("async stream_complete not supported")

View File

@@ -17,8 +17,8 @@ def mock_ai_config():
@pytest.fixture
def mock_ollama_llm():
with patch("paperless.ai.client.OllamaLLM") as MockOllamaLLM:
yield MockOllamaLLM
with patch("paperless.ai.client.Ollama") as MockOllama:
yield MockOllama
@pytest.fixture
@@ -37,6 +37,7 @@ def test_get_llm_ollama(mock_ai_config, mock_ollama_llm):
mock_ollama_llm.assert_called_once_with(
model="test_model",
base_url="http://test-url",
request_timeout=120,
)
assert client.llm == mock_ollama_llm.return_value