2025-07-15 14:34:54 -07:00

110 lines
3.1 KiB
Python

from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from llama_index.core.llms import ChatMessage
from llama_index.core.llms.llm import ToolSelection
from paperless_ai.client import AIClient
@pytest.fixture
def mock_ai_config():
    """Patch ``AIConfig`` in ``paperless_ai.client`` and yield its instance mock.

    While the patch is active, every ``AIConfig()`` call made from
    ``paperless_ai.client`` returns the yielded ``MagicMock``. Tests set
    attributes such as ``llm_backend`` on it before constructing an
    ``AIClient``.
    """
    config_instance = MagicMock()
    with patch("paperless_ai.client.AIConfig") as config_cls:
        config_cls.return_value = config_instance
        yield config_instance
@pytest.fixture
def mock_ollama_llm():
    """Patch the ``Ollama`` class referenced by ``paperless_ai.client``.

    Yields the class mock itself. The object a client constructs from it is
    the mock's ``return_value``.
    """
    with patch("paperless_ai.client.Ollama") as ollama_cls:
        yield ollama_cls
@pytest.fixture
def mock_openai_llm():
    """Patch the ``OpenAI`` class referenced by ``paperless_ai.client``.

    Yields the class mock itself. The object a client constructs from it is
    the mock's ``return_value``.
    """
    with patch("paperless_ai.client.OpenAI") as openai_cls:
        yield openai_cls
def test_get_llm_ollama(mock_ai_config, mock_ollama_llm):
    """Selecting the ``ollama`` backend makes the client build an ``Ollama`` LLM.

    Verifies that the constructor receives the configured model and URL,
    along with ``request_timeout=120``. Also verifies that the client keeps
    the exact instance the constructor returned.
    """
    mock_ai_config.llm_backend = "ollama"
    mock_ai_config.llm_model = "test_model"
    mock_ai_config.llm_url = "http://test-url"

    client = AIClient()

    mock_ollama_llm.assert_called_once_with(
        model="test_model",
        base_url="http://test-url",
        request_timeout=120,
    )
    # Identity, not equality: the client must store the constructed object
    # itself, and ``==`` on a mock can be reconfigured while ``is`` cannot.
    assert client.llm is mock_ollama_llm.return_value
def test_get_llm_openai(mock_ai_config, mock_openai_llm):
    """Selecting the ``openai`` backend makes the client build an ``OpenAI`` LLM.

    Verifies that the constructor receives the configured model and API key.
    Also verifies that the client keeps the exact instance the constructor
    returned.
    """
    mock_ai_config.llm_backend = "openai"
    mock_ai_config.llm_model = "test_model"
    mock_ai_config.llm_api_key = "test_api_key"

    client = AIClient()

    mock_openai_llm.assert_called_once_with(
        model="test_model",
        api_key="test_api_key",
    )
    # Identity, not equality: the client must store the constructed object
    # itself, and ``==`` on a mock can be reconfigured while ``is`` cannot.
    assert client.llm is mock_openai_llm.return_value
def test_get_llm_unsupported_backend(mock_ai_config):
    """An unrecognised ``llm_backend`` makes ``AIClient()`` raise ``ValueError``.

    Only ``AIConfig`` is patched here; no LLM class fixture is requested.
    """
    mock_ai_config.llm_backend = "unsupported"
    # ``match`` is applied as a regular-expression search. This message
    # contains no regex metacharacters, so it matches literally.
    expected_message = "Unsupported LLM backend: unsupported"
    with pytest.raises(ValueError, match=expected_message):
        AIClient()
def test_run_llm_query(mock_ai_config, mock_ollama_llm):
    """Check that ``run_llm_query`` returns the fields of the LLM's tool call.

    The patched Ollama instance answers ``chat_with_tools`` with a canned
    response and reports a single ``DocumentClassifierSchema`` tool selection
    for it. Besides the parsed title, the test asserts two things:

    * the LLM was actually queried;
    * its response is what gets passed to ``get_tool_calls_from_response``.

    Without these checks, the ``chat_with_tools`` stub could go unused and
    the test would still pass.
    """
    mock_ai_config.llm_backend = "ollama"
    mock_ai_config.llm_model = "test_model"
    mock_ai_config.llm_url = "http://test-url"

    mock_llm_instance = mock_ollama_llm.return_value
    tool_selection = ToolSelection(
        tool_id="call_test",
        tool_name="DocumentClassifierSchema",
        tool_kwargs={
            "title": "Test Title",
            "tags": ["test", "document"],
            "correspondents": ["John Doe"],
            "document_types": ["report"],
            "storage_paths": ["Reports"],
            "dates": ["2023-01-01"],
        },
    )
    # Keep a handle on the canned response so we can check where it ends up.
    chat_response = MagicMock()
    mock_llm_instance.chat_with_tools.return_value = chat_response
    mock_llm_instance.get_tool_calls_from_response.return_value = [tool_selection]

    client = AIClient()
    result = client.run_llm_query("test_prompt")

    assert result["title"] == "Test Title"
    mock_llm_instance.chat_with_tools.assert_called_once()
    mock_llm_instance.get_tool_calls_from_response.assert_called_once()
    # Accept the response passed either positionally or by keyword; only
    # that it was handed over matters here.
    parse_call = mock_llm_instance.get_tool_calls_from_response.call_args
    assert chat_response in (*parse_call.args, *parse_call.kwargs.values())
def test_run_chat(mock_ai_config, mock_ollama_llm):
    """Check that ``run_chat`` delegates to the LLM's ``chat`` method.

    The Ollama backend is selected, so the client's LLM is the patched
    ``Ollama`` instance. The test asserts that the message list is passed to
    ``chat`` exactly once and that the stubbed reply is returned.
    """
    mock_ai_config.llm_backend = "ollama"
    mock_ai_config.llm_model = "test_model"
    mock_ai_config.llm_url = "http://test-url"

    llm = mock_ollama_llm.return_value
    llm.chat.return_value = "test_chat_result"
    history = [ChatMessage(role="user", content="Hello")]

    reply = AIClient().run_chat(history)

    llm.chat.assert_called_once_with(history)
    assert reply == "test_chat_result"