From dd4684170c3b94d2b23ce8670d293a74903f5768 Mon Sep 17 00:00:00 2001
From: shamoon <4887959+shamoon@users.noreply.github.com>
Date: Fri, 25 Apr 2025 00:59:46 -0700
Subject: [PATCH] Fixup some tests

---
 src/documents/tests/test_api_app_config.py |   2 +
 src/paperless/ai/llms.py                   |  60 ++++++++--
 src/paperless/settings.py                  |   5 +-
 src/paperless/tests/test_ai_classifier.py  |  99 +++++++--------
 src/paperless/tests/test_ai_client.py      | 134 ++++++++++-----
 5 files changed, 172 insertions(+), 128 deletions(-)

diff --git a/src/documents/tests/test_api_app_config.py b/src/documents/tests/test_api_app_config.py
index 61d745c76..25bed2bf6 100644
--- a/src/documents/tests/test_api_app_config.py
+++ b/src/documents/tests/test_api_app_config.py
@@ -65,6 +65,8 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
                 "barcode_enable_tag": None,
                 "barcode_tag_mapping": None,
                 "ai_enabled": False,
+                "llm_embedding_backend": None,
+                "llm_embedding_model": None,
                 "llm_backend": None,
                 "llm_model": None,
                 "llm_api_key": None,
diff --git a/src/paperless/ai/llms.py b/src/paperless/ai/llms.py
index b51045d45..c4b56f36d 100644
--- a/src/paperless/ai/llms.py
+++ b/src/paperless/ai/llms.py
@@ -37,28 +37,70 @@ class OllamaLLM(LLM):
             data = response.json()
             return CompletionResponse(text=data["response"])
 
+    def chat(self, messages: list[ChatMessage], **kwargs) -> ChatResponse:
+        with httpx.Client(timeout=120.0) as client:
+            response = client.post(
+                f"{self.base_url}/api/chat",
+                json={
+                    "model": self.model,
+                    "messages": [
+                        {
+                            "role": message.role,
+                            "content": message.content,
+                        }
+                        for message in messages
+                    ],
+                    "stream": False,
+                },
+            )
+            response.raise_for_status()
+            data = response.json()
+            return ChatResponse(
+                message=ChatMessage(
+                    role="assistant",
+                    content=data["message"]["content"],
+                ),
+            )
+
     # -- Required stubs for ABC:
-    def stream_complete(self, prompt: str, **kwargs) -> CompletionResponseGen:
+    def stream_complete(
+        self,
+        prompt: str,
+        **kwargs,
+    ) -> CompletionResponseGen:  # pragma: no cover
         raise NotImplementedError("stream_complete not supported")
 
-    def chat(self, messages: list[ChatMessage], **kwargs) -> ChatResponse:
-        raise NotImplementedError("chat not supported")
-
-    def stream_chat(self, messages: list[ChatMessage], **kwargs) -> ChatResponseGen:
+    def stream_chat(
+        self,
+        messages: list[ChatMessage],
+        **kwargs,
+    ) -> ChatResponseGen:  # pragma: no cover
         raise NotImplementedError("stream_chat not supported")
 
-    async def achat(self, messages: list[ChatMessage], **kwargs) -> ChatResponse:
+    async def achat(
+        self,
+        messages: list[ChatMessage],
+        **kwargs,
+    ) -> ChatResponse:  # pragma: no cover
         raise NotImplementedError("async chat not supported")
 
     async def astream_chat(
         self,
         messages: list[ChatMessage],
         **kwargs,
-    ) -> ChatResponseGen:
+    ) -> ChatResponseGen:  # pragma: no cover
         raise NotImplementedError("async stream_chat not supported")
 
-    async def acomplete(self, prompt: str, **kwargs) -> CompletionResponse:
+    async def acomplete(
+        self,
+        prompt: str,
+        **kwargs,
+    ) -> CompletionResponse:  # pragma: no cover
         raise NotImplementedError("async complete not supported")
 
-    async def astream_complete(self, prompt: str, **kwargs) -> CompletionResponseGen:
+    async def astream_complete(
+        self,
+        prompt: str,
+        **kwargs,
+    ) -> CompletionResponseGen:  # pragma: no cover
         raise NotImplementedError("async stream_complete not supported")
diff --git a/src/paperless/settings.py b/src/paperless/settings.py
index 76a0a8c61..eb1a42fb8 100644
--- a/src/paperless/settings.py
+++ b/src/paperless/settings.py
@@ -1284,10 +1284,9 @@ OUTLOOK_OAUTH_ENABLED = bool(
 AI_ENABLED = __get_boolean("PAPERLESS_AI_ENABLED", "NO")
 LLM_EMBEDDING_BACKEND = os.getenv(
     "PAPERLESS_LLM_EMBEDDING_BACKEND",
-    "local",
-)  # or "openai"
+)  # "local" or "openai"
 LLM_EMBEDDING_MODEL = os.getenv("PAPERLESS_LLM_EMBEDDING_MODEL")
-LLM_BACKEND = os.getenv("PAPERLESS_LLM_BACKEND", "ollama")  # or "openai"
+LLM_BACKEND = os.getenv("PAPERLESS_LLM_BACKEND")  # "ollama" or "openai"
 LLM_MODEL = os.getenv("PAPERLESS_LLM_MODEL")
 LLM_API_KEY = os.getenv("PAPERLESS_LLM_API_KEY")
 LLM_URL = os.getenv("PAPERLESS_LLM_URL")
diff --git a/src/paperless/tests/test_ai_classifier.py b/src/paperless/tests/test_ai_classifier.py
index edb086bbe..a473652fc 100644
--- a/src/paperless/tests/test_ai_classifier.py
+++ b/src/paperless/tests/test_ai_classifier.py
@@ -1,11 +1,13 @@
 import json
+from unittest.mock import MagicMock
 from unittest.mock import patch
 
 import pytest
+from django.test import override_settings
 
 from documents.models import Document
 from paperless.ai.ai_classifier import get_ai_document_classification
-from paperless.ai.ai_classifier import parse_ai_classification_response
+from paperless.ai.ai_classifier import parse_ai_response
 
 
 @pytest.fixture
@@ -15,8 +17,12 @@ def mock_document():
 
 @pytest.mark.django_db
 @patch("paperless.ai.client.AIClient.run_llm_query")
+@override_settings(
+    LLM_BACKEND="ollama",
+    LLM_MODEL="some_model",
+)
 def test_get_ai_document_classification_success(mock_run_llm_query, mock_document):
-    mock_response = json.dumps(
+    mock_run_llm_query.return_value.text = json.dumps(
         {
             "title": "Test Title",
             "tags": ["test", "document"],
@@ -26,7 +32,6 @@ def test_get_ai_document_classification_success(mock_run_llm_query, mock_documen
             "dates": ["2023-01-01"],
         },
     )
-    mock_run_llm_query.return_value = mock_response
 
     result = get_ai_document_classification(mock_document)
 
@@ -43,58 +48,56 @@ def test_get_ai_document_classification_failure(mock_run_llm_query, mock_documen
 def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document):
     mock_run_llm_query.side_effect = Exception("LLM query failed")
 
-    result = get_ai_document_classification(mock_document)
-
-    assert result == {}
-
-
-def test_parse_llm_classification_response_valid():
-    mock_response = json.dumps(
-        {
-            "title": "Test Title",
-            "tags": ["test", "document"],
-            "correspondents": ["John Doe"],
-            "document_types": ["report"],
-            "storage_paths": ["Reports"],
-            "dates": ["2023-01-01"],
-        },
-    )
-
-    result = parse_ai_classification_response(mock_response)
-
-    assert result["title"] == "Test Title"
-    assert result["tags"] == ["test", "document"]
-    assert result["correspondents"] == ["John Doe"]
-    assert result["document_types"] == ["report"]
-    assert result["storage_paths"] == ["Reports"]
-    assert result["dates"] == ["2023-01-01"]
+    # assert raises an exception
+    with pytest.raises(Exception):
+        get_ai_document_classification(mock_document)
 
 
 def test_parse_llm_classification_response_invalid_json():
-    mock_response = "Invalid JSON"
+    mock_response = MagicMock()
+    mock_response.text = "Invalid JSON response"
 
-    result = parse_ai_classification_response(mock_response)
+    result = parse_ai_response(mock_response)
 
     assert result == {}
 
 
-def test_parse_llm_classification_response_partial_data():
-    mock_response = json.dumps(
-        {
-            "title": "Partial Data",
-            "tags": ["partial"],
-            "correspondents": "Jane Doe",
-            "document_types": "note",
-            "storage_paths": [],
-            "dates": [],
-        },
-    )
+@pytest.mark.django_db
+@patch("paperless.ai.client.AIClient.run_llm_query")
+@patch("paperless.ai.ai_classifier.build_prompt_with_rag") +@override_settings( + LLM_EMBEDDING_BACKEND="local", + LLM_EMBEDDING_MODEL="some_model", + LLM_BACKEND="ollama", + LLM_MODEL="some_model", +) +def test_use_rag_if_configured( + mock_build_prompt_with_rag, + mock_run_llm_query, + mock_document, +): + mock_build_prompt_with_rag.return_value = "Prompt with RAG" + mock_run_llm_query.return_value.text = json.dumps({}) + get_ai_document_classification(mock_document) + mock_build_prompt_with_rag.assert_called_once() - result = parse_ai_classification_response(mock_response) - assert result["title"] == "Partial Data" - assert result["tags"] == ["partial"] - assert result["correspondents"] == ["Jane Doe"] - assert result["document_types"] == ["note"] - assert result["storage_paths"] == [] - assert result["dates"] == [] +@pytest.mark.django_db +@patch("paperless.ai.client.AIClient.run_llm_query") +@patch("paperless.ai.ai_classifier.build_prompt_without_rag") +@patch("paperless.config.AIConfig") +@override_settings( + LLM_BACKEND="ollama", + LLM_MODEL="some_model", +) +def test_use_without_rag_if_not_configured( + mock_ai_config, + mock_build_prompt_without_rag, + mock_run_llm_query, + mock_document, +): + mock_ai_config.llm_embedding_backend = None + mock_build_prompt_without_rag.return_value = "Prompt without RAG" + mock_run_llm_query.return_value.text = json.dumps({}) + get_ai_document_classification(mock_document) + mock_build_prompt_without_rag.assert_called_once() diff --git a/src/paperless/tests/test_ai_client.py b/src/paperless/tests/test_ai_client.py index 6a239279e..27b160d23 100644 --- a/src/paperless/tests/test_ai_client.py +++ b/src/paperless/tests/test_ai_client.py @@ -1,95 +1,93 @@ -import json +from unittest.mock import MagicMock from unittest.mock import patch import pytest -from django.conf import settings +from llama_index.core.llms import ChatMessage from paperless.ai.client import AIClient @pytest.fixture -def mock_settings(): - settings.LLM_BACKEND = "openai" - settings.LLM_MODEL = "gpt-3.5-turbo" - settings.LLM_API_KEY = "test-api-key" - yield settings +def mock_ai_config(): + with patch("paperless.ai.client.AIConfig") as MockAIConfig: + mock_config = MagicMock() + MockAIConfig.return_value = mock_config + yield mock_config -@pytest.mark.django_db -@patch("paperless.ai.client.AIClient._run_openai_query") -@patch("paperless.ai.client.AIClient._run_ollama_query") -def test_run_llm_query_openai(mock_ollama_query, mock_openai_query, mock_settings): - mock_settings.LLM_BACKEND = "openai" - mock_openai_query.return_value = "OpenAI response" +@pytest.fixture +def mock_ollama_llm(): + with patch("paperless.ai.client.OllamaLLM") as MockOllamaLLM: + yield MockOllamaLLM + + +@pytest.fixture +def mock_openai_llm(): + with patch("paperless.ai.client.OpenAI") as MockOpenAI: + yield MockOpenAI + + +def test_get_llm_ollama(mock_ai_config, mock_ollama_llm): + mock_ai_config.llm_backend = "ollama" + mock_ai_config.llm_model = "test_model" + mock_ai_config.llm_url = "http://test-url" + client = AIClient() - result = client.run_llm_query("Test prompt") - assert result == "OpenAI response" - mock_openai_query.assert_called_once_with("Test prompt") - mock_ollama_query.assert_not_called() + + mock_ollama_llm.assert_called_once_with( + model="test_model", + base_url="http://test-url", + ) + assert client.llm == mock_ollama_llm.return_value -@pytest.mark.django_db -@patch("paperless.ai.client.AIClient._run_openai_query") -@patch("paperless.ai.client.AIClient._run_ollama_query") -def 
test_run_llm_query_ollama(mock_ollama_query, mock_openai_query, mock_settings): - mock_settings.LLM_BACKEND = "ollama" - mock_ollama_query.return_value = "Ollama response" +def test_get_llm_openai(mock_ai_config, mock_openai_llm): + mock_ai_config.llm_backend = "openai" + mock_ai_config.llm_model = "test_model" + mock_ai_config.openai_api_key = "test_api_key" + client = AIClient() - result = client.run_llm_query("Test prompt") - assert result == "Ollama response" - mock_ollama_query.assert_called_once_with("Test prompt") - mock_openai_query.assert_not_called() + + mock_openai_llm.assert_called_once_with( + model="test_model", + api_key="test_api_key", + ) + assert client.llm == mock_openai_llm.return_value -@pytest.mark.django_db -def test_run_llm_query_unsupported_backend(mock_settings): - mock_settings.LLM_BACKEND = "unsupported" - client = AIClient() +def test_get_llm_unsupported_backend(mock_ai_config): + mock_ai_config.llm_backend = "unsupported" + with pytest.raises(ValueError, match="Unsupported LLM backend: unsupported"): - client.run_llm_query("Test prompt") + AIClient() -@pytest.mark.django_db -def test_run_openai_query(httpx_mock, mock_settings): - mock_settings.LLM_BACKEND = "openai" - httpx_mock.add_response( - url="https://api.openai.com/v1/chat/completions", - json={ - "choices": [{"message": {"content": "OpenAI response"}}], - }, - ) +def test_run_llm_query(mock_ai_config, mock_ollama_llm): + mock_ai_config.llm_backend = "ollama" + mock_ai_config.llm_model = "test_model" + mock_ai_config.llm_url = "http://test-url" + + mock_llm_instance = mock_ollama_llm.return_value + mock_llm_instance.complete.return_value = "test_result" client = AIClient() - result = client.run_llm_query("Test prompt") - assert result == "OpenAI response" + result = client.run_llm_query("test_prompt") - request = httpx_mock.get_request() - assert request.method == "POST" - assert request.headers["Authorization"] == f"Bearer {mock_settings.LLM_API_KEY}" - assert request.headers["Content-Type"] == "application/json" - assert json.loads(request.content) == { - "model": mock_settings.LLM_MODEL, - "messages": [{"role": "user", "content": "Test prompt"}], - "temperature": 0.3, - } + mock_llm_instance.complete.assert_called_once_with("test_prompt") + assert result == "test_result" -@pytest.mark.django_db -def test_run_ollama_query(httpx_mock, mock_settings): - mock_settings.LLM_BACKEND = "ollama" - httpx_mock.add_response( - url="http://localhost:11434/api/chat", - json={"message": {"content": "Ollama response"}}, - ) +def test_run_chat(mock_ai_config, mock_ollama_llm): + mock_ai_config.llm_backend = "ollama" + mock_ai_config.llm_model = "test_model" + mock_ai_config.llm_url = "http://test-url" + + mock_llm_instance = mock_ollama_llm.return_value + mock_llm_instance.chat.return_value = "test_chat_result" client = AIClient() - result = client.run_llm_query("Test prompt") - assert result == "Ollama response" + messages = [ChatMessage(role="user", content="Hello")] + result = client.run_chat(messages) - request = httpx_mock.get_request() - assert request.method == "POST" - assert json.loads(request.content) == { - "model": mock_settings.LLM_MODEL, - "messages": [{"role": "user", "content": "Test prompt"}], - "stream": False, - } + mock_llm_instance.chat.assert_called_once_with(messages) + assert result == "test_chat_result"
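
A minimal usage sketch of the client API these tests pin down, assuming the AIConfig-driven wiring exercised above (the prompt text and message content are placeholders; the behavior of run_llm_query and run_chat is inferred from test_run_llm_query and test_run_chat, not from documentation):

    from llama_index.core.llms import ChatMessage

    from paperless.ai.client import AIClient

    # AIClient() resolves its backend from AIConfig ("ollama" or "openai",
    # via PAPERLESS_LLM_BACKEND); any other value raises ValueError.
    client = AIClient()

    # One-shot completion, per test_run_llm_query: the prompt is forwarded
    # to the underlying llama-index LLM's complete().
    result = client.run_llm_query("Summarize this document: ...")

    # Chat-style call, per test_run_chat: messages are llama-index
    # ChatMessage objects forwarded to the underlying LLM's chat().
    response = client.run_chat([ChatMessage(role="user", content="Hello")])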