Fixup some tests

shamoon 2025-04-25 00:59:46 -07:00
parent 88ac3098ef
commit dd4684170c
5 changed files with 167 additions and 128 deletions

View File

@@ -65,6 +65,8 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
                 "barcode_enable_tag": None,
                 "barcode_tag_mapping": None,
                 "ai_enabled": False,
+                "llm_embedding_backend": None,
+                "llm_embedding_model": None,
                 "llm_backend": None,
                 "llm_model": None,
                 "llm_api_key": None,

View File

@@ -37,28 +37,65 @@ class OllamaLLM(LLM):
             data = response.json()
             return CompletionResponse(text=data["response"])
 
+    def chat(self, messages: list[ChatMessage], **kwargs) -> ChatResponse:
+        with httpx.Client(timeout=120.0) as client:
+            response = client.post(
+                f"{self.base_url}/api/generate",
+                json={
+                    "model": self.model,
+                    "messages": [
+                        {
+                            "role": message.role,
+                            "content": message.content,
+                        }
+                        for message in messages
+                    ],
+                    "stream": False,
+                },
+            )
+            response.raise_for_status()
+            data = response.json()
+            return ChatResponse(text=data["response"])
+
     # -- Required stubs for ABC:
-    def stream_complete(self, prompt: str, **kwargs) -> CompletionResponseGen:
+    def stream_complete(
+        self,
+        prompt: str,
+        **kwargs,
+    ) -> CompletionResponseGen:  # pragma: no cover
         raise NotImplementedError("stream_complete not supported")
 
-    def chat(self, messages: list[ChatMessage], **kwargs) -> ChatResponse:
-        raise NotImplementedError("chat not supported")
-
-    def stream_chat(self, messages: list[ChatMessage], **kwargs) -> ChatResponseGen:
+    def stream_chat(
+        self,
+        messages: list[ChatMessage],
+        **kwargs,
+    ) -> ChatResponseGen:  # pragma: no cover
         raise NotImplementedError("stream_chat not supported")
 
-    async def achat(self, messages: list[ChatMessage], **kwargs) -> ChatResponse:
+    async def achat(
+        self,
+        messages: list[ChatMessage],
+        **kwargs,
+    ) -> ChatResponse:  # pragma: no cover
         raise NotImplementedError("async chat not supported")
 
     async def astream_chat(
         self,
         messages: list[ChatMessage],
         **kwargs,
-    ) -> ChatResponseGen:
+    ) -> ChatResponseGen:  # pragma: no cover
         raise NotImplementedError("async stream_chat not supported")
 
-    async def acomplete(self, prompt: str, **kwargs) -> CompletionResponse:
+    async def acomplete(
+        self,
+        prompt: str,
+        **kwargs,
+    ) -> CompletionResponse:  # pragma: no cover
         raise NotImplementedError("async complete not supported")
 
-    async def astream_complete(self, prompt: str, **kwargs) -> CompletionResponseGen:
+    async def astream_complete(
+        self,
+        prompt: str,
+        **kwargs,
+    ) -> CompletionResponseGen:  # pragma: no cover
         raise NotImplementedError("async stream_complete not supported")

View File

@@ -1284,10 +1284,9 @@ OUTLOOK_OAUTH_ENABLED = bool(
 AI_ENABLED = __get_boolean("PAPERLESS_AI_ENABLED", "NO")
 LLM_EMBEDDING_BACKEND = os.getenv(
     "PAPERLESS_LLM_EMBEDDING_BACKEND",
-    "local",
-)  # or "openai"
+)  # "local" or "openai"
 LLM_EMBEDDING_MODEL = os.getenv("PAPERLESS_LLM_EMBEDDING_MODEL")
-LLM_BACKEND = os.getenv("PAPERLESS_LLM_BACKEND", "ollama")  # or "openai"
+LLM_BACKEND = os.getenv("PAPERLESS_LLM_BACKEND")  # "ollama" or "openai"
 LLM_MODEL = os.getenv("PAPERLESS_LLM_MODEL")
 LLM_API_KEY = os.getenv("PAPERLESS_LLM_API_KEY")
 LLM_URL = os.getenv("PAPERLESS_LLM_URL")
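With the implicit defaults removed, both backends are now opt-in: leaving the variables unset yields None instead of silently falling back to "ollama" and "local". A minimal sketch of the environment a deployment might set to enable the feature; the concrete values are illustrative, only the variable names and accepted backends come from the settings above.

import os

# Unset, PAPERLESS_LLM_BACKEND and PAPERLESS_LLM_EMBEDDING_BACKEND now resolve
# to None (previously "ollama" and "local").
os.environ["PAPERLESS_AI_ENABLED"] = "true"                # boolean flag, assumed truthy spelling
os.environ["PAPERLESS_LLM_BACKEND"] = "ollama"             # or "openai"
os.environ["PAPERLESS_LLM_MODEL"] = "llama3"               # model name is a placeholder
os.environ["PAPERLESS_LLM_URL"] = "http://localhost:11434" # assumed local Ollama URL
os.environ["PAPERLESS_LLM_EMBEDDING_BACKEND"] = "local"    # or "openai"; enables the RAG path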

View File

@@ -1,11 +1,13 @@
 import json
+from unittest.mock import MagicMock
 from unittest.mock import patch
 
 import pytest
+from django.test import override_settings
 
 from documents.models import Document
 from paperless.ai.ai_classifier import get_ai_document_classification
-from paperless.ai.ai_classifier import parse_ai_classification_response
+from paperless.ai.ai_classifier import parse_ai_response
 
 
 @pytest.fixture
@@ -15,8 +17,12 @@ def mock_document():
 
 @pytest.mark.django_db
 @patch("paperless.ai.client.AIClient.run_llm_query")
+@override_settings(
+    LLM_BACKEND="ollama",
+    LLM_MODEL="some_model",
+)
 def test_get_ai_document_classification_success(mock_run_llm_query, mock_document):
-    mock_response = json.dumps(
+    mock_run_llm_query.return_value.text = json.dumps(
         {
             "title": "Test Title",
             "tags": ["test", "document"],
@@ -26,7 +32,6 @@ def test_get_ai_document_classification_success(mock_run_llm_query, mock_document):
             "dates": ["2023-01-01"],
         },
     )
-    mock_run_llm_query.return_value = mock_response
 
     result = get_ai_document_classification(mock_document)
 
@@ -43,58 +48,56 @@ def test_get_ai_document_classification_success(mock_run_llm_query, mock_document):
 def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document):
     mock_run_llm_query.side_effect = Exception("LLM query failed")
 
-    result = get_ai_document_classification(mock_document)
-
-    assert result == {}
-
-
-def test_parse_llm_classification_response_valid():
-    mock_response = json.dumps(
-        {
-            "title": "Test Title",
-            "tags": ["test", "document"],
-            "correspondents": ["John Doe"],
-            "document_types": ["report"],
-            "storage_paths": ["Reports"],
-            "dates": ["2023-01-01"],
-        },
-    )
-
-    result = parse_ai_classification_response(mock_response)
-
-    assert result["title"] == "Test Title"
-    assert result["tags"] == ["test", "document"]
-    assert result["correspondents"] == ["John Doe"]
-    assert result["document_types"] == ["report"]
-    assert result["storage_paths"] == ["Reports"]
-    assert result["dates"] == ["2023-01-01"]
+    # assert raises an exception
+    with pytest.raises(Exception):
+        get_ai_document_classification(mock_document)
 
 
 def test_parse_llm_classification_response_invalid_json():
-    mock_response = "Invalid JSON"
+    mock_response = MagicMock()
+    mock_response.text = "Invalid JSON response"
 
-    result = parse_ai_classification_response(mock_response)
+    result = parse_ai_response(mock_response)
 
     assert result == {}
 
 
-def test_parse_llm_classification_response_partial_data():
-    mock_response = json.dumps(
-        {
-            "title": "Partial Data",
-            "tags": ["partial"],
-            "correspondents": "Jane Doe",
-            "document_types": "note",
-            "storage_paths": [],
-            "dates": [],
-        },
-    )
-
-    result = parse_ai_classification_response(mock_response)
-
-    assert result["title"] == "Partial Data"
-    assert result["tags"] == ["partial"]
-    assert result["correspondents"] == ["Jane Doe"]
-    assert result["document_types"] == ["note"]
-    assert result["storage_paths"] == []
-    assert result["dates"] == []
+@pytest.mark.django_db
+@patch("paperless.ai.client.AIClient.run_llm_query")
+@patch("paperless.ai.ai_classifier.build_prompt_with_rag")
+@override_settings(
+    LLM_EMBEDDING_BACKEND="local",
+    LLM_EMBEDDING_MODEL="some_model",
+    LLM_BACKEND="ollama",
+    LLM_MODEL="some_model",
+)
+def test_use_rag_if_configured(
+    mock_build_prompt_with_rag,
+    mock_run_llm_query,
+    mock_document,
+):
+    mock_build_prompt_with_rag.return_value = "Prompt with RAG"
+    mock_run_llm_query.return_value.text = json.dumps({})
+    get_ai_document_classification(mock_document)
+    mock_build_prompt_with_rag.assert_called_once()
+
+
+@pytest.mark.django_db
+@patch("paperless.ai.client.AIClient.run_llm_query")
+@patch("paperless.ai.ai_classifier.build_prompt_without_rag")
+@patch("paperless.config.AIConfig")
+@override_settings(
+    LLM_BACKEND="ollama",
+    LLM_MODEL="some_model",
+)
+def test_use_without_rag_if_not_configured(
+    mock_ai_config,
+    mock_build_prompt_without_rag,
+    mock_run_llm_query,
+    mock_document,
+):
+    mock_ai_config.llm_embedding_backend = None
+    mock_build_prompt_without_rag.return_value = "Prompt without RAG"
+    mock_run_llm_query.return_value.text = json.dumps({})
+    get_ai_document_classification(mock_document)
+    mock_build_prompt_without_rag.assert_called_once()
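The reworked tests treat the LLM reply as an object whose .text attribute carries the JSON payload (mock_run_llm_query.return_value.text), pin the RAG/no-RAG prompt choice to whether an embedding backend is configured, and expect failures to propagate instead of returning {}. A minimal, self-contained sketch of the parsing contract the invalid-JSON test implies; the real parse_ai_response lives in paperless.ai.ai_classifier and is not shown in this commit.

import json
from types import SimpleNamespace

def parse_ai_response(response):
    # Sketch of the behaviour the tests pin down, not the shipped implementation:
    # read the reply's .text and fall back to an empty dict on invalid JSON.
    try:
        return json.loads(response.text)
    except json.JSONDecodeError:
        return {}

assert parse_ai_response(SimpleNamespace(text='{"title": "Test Title"}')) == {"title": "Test Title"}
assert parse_ai_response(SimpleNamespace(text="Invalid JSON response")) == {}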

View File

@@ -1,95 +1,93 @@
-import json
+from unittest.mock import MagicMock
 from unittest.mock import patch
 
 import pytest
-from django.conf import settings
+from llama_index.core.llms import ChatMessage
 
 from paperless.ai.client import AIClient
 
 
 @pytest.fixture
-def mock_settings():
-    settings.LLM_BACKEND = "openai"
-    settings.LLM_MODEL = "gpt-3.5-turbo"
-    settings.LLM_API_KEY = "test-api-key"
-    yield settings
+def mock_ai_config():
+    with patch("paperless.ai.client.AIConfig") as MockAIConfig:
+        mock_config = MagicMock()
+        MockAIConfig.return_value = mock_config
+        yield mock_config
 
 
-@pytest.mark.django_db
-@patch("paperless.ai.client.AIClient._run_openai_query")
-@patch("paperless.ai.client.AIClient._run_ollama_query")
-def test_run_llm_query_openai(mock_ollama_query, mock_openai_query, mock_settings):
-    mock_settings.LLM_BACKEND = "openai"
-    mock_openai_query.return_value = "OpenAI response"
+@pytest.fixture
+def mock_ollama_llm():
+    with patch("paperless.ai.client.OllamaLLM") as MockOllamaLLM:
+        yield MockOllamaLLM
+
+
+@pytest.fixture
+def mock_openai_llm():
+    with patch("paperless.ai.client.OpenAI") as MockOpenAI:
+        yield MockOpenAI
+
+
+def test_get_llm_ollama(mock_ai_config, mock_ollama_llm):
+    mock_ai_config.llm_backend = "ollama"
+    mock_ai_config.llm_model = "test_model"
+    mock_ai_config.llm_url = "http://test-url"
+
     client = AIClient()
-    result = client.run_llm_query("Test prompt")
-    assert result == "OpenAI response"
-    mock_openai_query.assert_called_once_with("Test prompt")
-    mock_ollama_query.assert_not_called()
+
+    mock_ollama_llm.assert_called_once_with(
+        model="test_model",
+        base_url="http://test-url",
+    )
+    assert client.llm == mock_ollama_llm.return_value
 
 
-@pytest.mark.django_db
-@patch("paperless.ai.client.AIClient._run_openai_query")
-@patch("paperless.ai.client.AIClient._run_ollama_query")
-def test_run_llm_query_ollama(mock_ollama_query, mock_openai_query, mock_settings):
-    mock_settings.LLM_BACKEND = "ollama"
-    mock_ollama_query.return_value = "Ollama response"
+def test_get_llm_openai(mock_ai_config, mock_openai_llm):
+    mock_ai_config.llm_backend = "openai"
+    mock_ai_config.llm_model = "test_model"
+    mock_ai_config.openai_api_key = "test_api_key"
+
     client = AIClient()
-    result = client.run_llm_query("Test prompt")
-    assert result == "Ollama response"
-    mock_ollama_query.assert_called_once_with("Test prompt")
-    mock_openai_query.assert_not_called()
+
+    mock_openai_llm.assert_called_once_with(
+        model="test_model",
+        api_key="test_api_key",
+    )
+    assert client.llm == mock_openai_llm.return_value
 
 
-@pytest.mark.django_db
-def test_run_llm_query_unsupported_backend(mock_settings):
-    mock_settings.LLM_BACKEND = "unsupported"
-    client = AIClient()
+def test_get_llm_unsupported_backend(mock_ai_config):
+    mock_ai_config.llm_backend = "unsupported"
+
     with pytest.raises(ValueError, match="Unsupported LLM backend: unsupported"):
-        client.run_llm_query("Test prompt")
+        AIClient()
 
 
-@pytest.mark.django_db
-def test_run_openai_query(httpx_mock, mock_settings):
-    mock_settings.LLM_BACKEND = "openai"
-    httpx_mock.add_response(
-        url="https://api.openai.com/v1/chat/completions",
-        json={
-            "choices": [{"message": {"content": "OpenAI response"}}],
-        },
-    )
+def test_run_llm_query(mock_ai_config, mock_ollama_llm):
+    mock_ai_config.llm_backend = "ollama"
+    mock_ai_config.llm_model = "test_model"
+    mock_ai_config.llm_url = "http://test-url"
+
+    mock_llm_instance = mock_ollama_llm.return_value
+    mock_llm_instance.complete.return_value = "test_result"
+
     client = AIClient()
-    result = client.run_llm_query("Test prompt")
-    assert result == "OpenAI response"
-
-    request = httpx_mock.get_request()
-    assert request.method == "POST"
-    assert request.headers["Authorization"] == f"Bearer {mock_settings.LLM_API_KEY}"
-    assert request.headers["Content-Type"] == "application/json"
-    assert json.loads(request.content) == {
-        "model": mock_settings.LLM_MODEL,
-        "messages": [{"role": "user", "content": "Test prompt"}],
-        "temperature": 0.3,
-    }
+    result = client.run_llm_query("test_prompt")
+
+    mock_llm_instance.complete.assert_called_once_with("test_prompt")
+    assert result == "test_result"
 
 
-@pytest.mark.django_db
-def test_run_ollama_query(httpx_mock, mock_settings):
-    mock_settings.LLM_BACKEND = "ollama"
-    httpx_mock.add_response(
-        url="http://localhost:11434/api/chat",
-        json={"message": {"content": "Ollama response"}},
-    )
+def test_run_chat(mock_ai_config, mock_ollama_llm):
+    mock_ai_config.llm_backend = "ollama"
+    mock_ai_config.llm_model = "test_model"
+    mock_ai_config.llm_url = "http://test-url"
+
+    mock_llm_instance = mock_ollama_llm.return_value
+    mock_llm_instance.chat.return_value = "test_chat_result"
+
     client = AIClient()
-    result = client.run_llm_query("Test prompt")
-    assert result == "Ollama response"
-
-    request = httpx_mock.get_request()
-    assert request.method == "POST"
-    assert json.loads(request.content) == {
-        "model": mock_settings.LLM_MODEL,
-        "messages": [{"role": "user", "content": "Test prompt"}],
-        "stream": False,
-    }
+    messages = [ChatMessage(role="user", content="Hello")]
+    result = client.run_chat(messages)
+
+    mock_llm_instance.chat.assert_called_once_with(messages)
+    assert result == "test_chat_result"
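Taken together, these tests describe the AIClient surface: the constructor reads AIConfig, builds an OllamaLLM(model=, base_url=) or OpenAI(model=, api_key=) and exposes it as .llm, raising ValueError for anything else, while run_llm_query and run_chat delegate to the LLM's complete() and chat(). A rough sketch of that shape, under the assumption that AIConfig, OllamaLLM and OpenAI are imported in paperless.ai.client exactly as the patches above suggest; this is not the shipped code.

class AIClient:
    # Sketch reconstructed from the tests; attribute and method names beyond
    # llm, run_llm_query and run_chat are assumptions.
    def __init__(self):
        self.settings = AIConfig()  # patched as paperless.ai.client.AIConfig in the tests
        self.llm = self.get_llm()

    def get_llm(self):
        if self.settings.llm_backend == "ollama":
            return OllamaLLM(
                model=self.settings.llm_model,
                base_url=self.settings.llm_url,
            )
        if self.settings.llm_backend == "openai":
            return OpenAI(
                model=self.settings.llm_model,
                api_key=self.settings.openai_api_key,
            )
        raise ValueError(f"Unsupported LLM backend: {self.settings.llm_backend}")

    def run_llm_query(self, prompt):
        return self.llm.complete(prompt)

    def run_chat(self, messages):
        return self.llm.chat(messages)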