# paperless-ngx/src/paperless/tests/test_ai_client.py
# 2025-06-09 15:21:15 -07:00
#
# 94 lines
# 2.6 KiB
# Python

from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from llama_index.core.llms import ChatMessage
from paperless.ai.client import AIClient
@pytest.fixture
def mock_ai_config():
    """Patch ``paperless.ai.client.AIConfig`` and yield the mocked config.

    The yielded ``MagicMock`` is the object returned by ``AIConfig()`` while
    the patch is active, so a test can set attributes on it (for example
    ``llm_backend``) before constructing an ``AIClient``.
    """
    with patch("paperless.ai.client.AIConfig") as MockAIConfig:
        mock_config = MagicMock()
        MockAIConfig.return_value = mock_config
        # Yield inside the ``with`` so the patch stays in place for the
        # whole duration of the test using this fixture.
        yield mock_config
@pytest.fixture
def mock_ollama_llm():
    """Patch ``paperless.ai.client.OllamaLLM`` and yield the replacement mock.

    The yielded object stands in for the ``OllamaLLM`` class itself; its
    ``return_value`` is what a call to ``OllamaLLM(...)`` produces.
    """
    with patch("paperless.ai.client.OllamaLLM") as MockOllamaLLM:
        # Yield inside the ``with`` so the patch is active during the test.
        yield MockOllamaLLM
@pytest.fixture
def mock_openai_llm():
    """Patch ``paperless.ai.client.OpenAI`` and yield the replacement mock.

    The yielded object stands in for the ``OpenAI`` class itself; its
    ``return_value`` is what a call to ``OpenAI(...)`` produces.
    """
    with patch("paperless.ai.client.OpenAI") as MockOpenAI:
        # Yield inside the ``with`` so the patch is active during the test.
        yield MockOpenAI
def test_get_llm_ollama(mock_ai_config, mock_ollama_llm):
    """With the ``ollama`` backend, ``AIClient`` builds its LLM via ``OllamaLLM``.

    Checks that the configured model and URL are passed as ``model`` and
    ``base_url``, and that the constructed instance is exposed as
    ``client.llm``.
    """
    mock_ai_config.llm_backend = "ollama"
    mock_ai_config.llm_model = "test_model"
    mock_ai_config.llm_url = "http://test-url"

    client = AIClient()

    mock_ollama_llm.assert_called_once_with(
        model="test_model",
        base_url="http://test-url",
    )
    assert client.llm == mock_ollama_llm.return_value
def test_get_llm_openai(mock_ai_config, mock_openai_llm):
    """With the ``openai`` backend, ``AIClient`` builds its LLM via ``OpenAI``.

    Checks that the configured model and API key are passed as ``model`` and
    ``api_key``, and that the constructed instance is exposed as
    ``client.llm``.
    """
    mock_ai_config.llm_backend = "openai"
    mock_ai_config.llm_model = "test_model"
    mock_ai_config.openai_api_key = "test_api_key"

    client = AIClient()

    mock_openai_llm.assert_called_once_with(
        model="test_model",
        api_key="test_api_key",
    )
    assert client.llm == mock_openai_llm.return_value
def test_get_llm_unsupported_backend(mock_ai_config):
    """Constructing ``AIClient`` with an unknown backend raises ``ValueError``.

    The error message is expected to name the offending backend.
    """
    mock_ai_config.llm_backend = "unsupported"

    # The constructor call must sit inside the context manager so the
    # exception it raises is the one being checked.
    with pytest.raises(ValueError, match="Unsupported LLM backend: unsupported"):
        AIClient()
def test_run_llm_query(mock_ai_config, mock_ollama_llm):
    """``run_llm_query`` passes the prompt to ``llm.complete`` and returns its result."""
    mock_ai_config.llm_backend = "ollama"
    mock_ai_config.llm_model = "test_model"
    mock_ai_config.llm_url = "http://test-url"

    # The instance AIClient receives when it calls the patched OllamaLLM class.
    mock_llm_instance = mock_ollama_llm.return_value
    mock_llm_instance.complete.return_value = "test_result"

    client = AIClient()
    result = client.run_llm_query("test_prompt")

    mock_llm_instance.complete.assert_called_once_with("test_prompt")
    assert result == "test_result"
def test_run_chat(mock_ai_config, mock_ollama_llm):
    """``run_chat`` passes the message list to ``llm.chat`` and returns its result."""
    mock_ai_config.llm_backend = "ollama"
    mock_ai_config.llm_model = "test_model"
    mock_ai_config.llm_url = "http://test-url"

    # The instance AIClient receives when it calls the patched OllamaLLM class.
    mock_llm_instance = mock_ollama_llm.return_value
    mock_llm_instance.chat.return_value = "test_chat_result"

    client = AIClient()
    messages = [ChatMessage(role="user", content="Hello")]
    result = client.run_chat(messages)

    mock_llm_instance.chat.assert_called_once_with(messages)
    assert result == "test_chat_result"