diff --git a/src/paperless_ai/tests/test_ai_classifier.py b/src/paperless_ai/tests/test_ai_classifier.py index 548acbc6c..115d51cd4 100644 --- a/src/paperless_ai/tests/test_ai_classifier.py +++ b/src/paperless_ai/tests/test_ai_classifier.py @@ -10,7 +10,6 @@ from paperless_ai.ai_classifier import build_prompt_with_rag from paperless_ai.ai_classifier import build_prompt_without_rag from paperless_ai.ai_classifier import get_ai_document_classification from paperless_ai.ai_classifier import get_context_for_document -from paperless_ai.ai_classifier import parse_ai_response @pytest.fixture @@ -75,50 +74,14 @@ def mock_similar_documents(): LLM_MODEL="some_model", ) def test_get_ai_document_classification_success(mock_run_llm_query, mock_document): - mock_run_llm_query.return_value.text = json.dumps( - { - "title": "Test Title", - "tags": ["test", "document"], - "correspondents": ["John Doe"], - "document_types": ["report"], - "storage_paths": ["Reports"], - "dates": ["2023-01-01"], - }, - ) - - result = get_ai_document_classification(mock_document) - - assert result["title"] == "Test Title" - assert result["tags"] == ["test", "document"] - assert result["correspondents"] == ["John Doe"] - assert result["document_types"] == ["report"] - assert result["storage_paths"] == ["Reports"] - assert result["dates"] == ["2023-01-01"] - - -@pytest.mark.django_db -@patch("paperless_ai.client.AIClient.run_llm_query") -@override_settings( - LLM_BACKEND="ollama", - LLM_MODEL="some_model", -) -def test_get_ai_document_classification_fallback_parse_success( - mock_run_llm_query, - mock_document, -): - mock_run_llm_query.return_value.text = """ - There is some text before the JSON. - ```json - { + mock_run_llm_query.return_value = { "title": "Test Title", "tags": ["test", "document"], "correspondents": ["John Doe"], "document_types": ["report"], "storage_paths": ["Reports"], - "dates": ["2023-01-01"] + "dates": ["2023-01-01"], } - ``` - """ result = get_ai_document_classification(mock_document) @@ -130,22 +93,6 @@ def test_get_ai_document_classification_fallback_parse_success( assert result["dates"] == ["2023-01-01"] -@pytest.mark.django_db -@patch("paperless_ai.client.AIClient.run_llm_query") -@override_settings( - LLM_BACKEND="ollama", - LLM_MODEL="some_model", -) -def test_get_ai_document_classification_parse_failure( - mock_run_llm_query, - mock_document, -): - mock_run_llm_query.return_value.text = "Invalid JSON response" - - result = get_ai_document_classification(mock_document) - assert result == {} - - @pytest.mark.django_db @patch("paperless_ai.client.AIClient.run_llm_query") def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document): @@ -156,15 +103,6 @@ def test_get_ai_document_classification_failure(mock_run_llm_query, mock_documen get_ai_document_classification(mock_document) -def test_parse_llm_classification_response_invalid_json(): - mock_response = MagicMock() - mock_response.text = "Invalid JSON response" - - result = parse_ai_response(mock_response) - - assert result == {} - - @pytest.mark.django_db @patch("paperless_ai.client.AIClient.run_llm_query") @patch("paperless_ai.ai_classifier.build_prompt_with_rag") @@ -218,10 +156,10 @@ def test_prompt_with_without_rag(mock_document): return_value="Context from similar documents", ): prompt = build_prompt_without_rag(mock_document) - assert "CONTEXT FROM SIMILAR DOCUMENTS:" not in prompt + assert "Additional context from similar documents:" not in prompt prompt = build_prompt_with_rag(mock_document) - assert "CONTEXT FROM SIMILAR DOCUMENTS:" in prompt + assert "Additional context from similar documents:" in prompt @patch("paperless_ai.ai_classifier.query_similar_documents") diff --git a/src/paperless_ai/tests/test_client.py b/src/paperless_ai/tests/test_client.py index 7cd2b16b0..6ef7b332b 100644 --- a/src/paperless_ai/tests/test_client.py +++ b/src/paperless_ai/tests/test_client.py @@ -3,6 +3,7 @@ from unittest.mock import patch import pytest from llama_index.core.llms import ChatMessage +from llama_index.core.llms.llm import ToolSelection from paperless_ai.client import AIClient @@ -69,13 +70,27 @@ def test_run_llm_query(mock_ai_config, mock_ollama_llm): mock_ai_config.llm_url = "http://test-url" mock_llm_instance = mock_ollama_llm.return_value - mock_llm_instance.complete.return_value = "test_result" + + tool_selection = ToolSelection( + tool_id="call_test", + tool_name="DocumentClassifierSchema", + tool_kwargs={ + "title": "Test Title", + "tags": ["test", "document"], + "correspondents": ["John Doe"], + "document_types": ["report"], + "storage_paths": ["Reports"], + "dates": ["2023-01-01"], + }, + ) + + mock_llm_instance.chat_with_tools.return_value = MagicMock() + mock_llm_instance.get_tool_calls_from_response.return_value = [tool_selection] client = AIClient() result = client.run_llm_query("test_prompt") - mock_llm_instance.complete.assert_called_once_with("test_prompt") - assert result == "test_result" + assert result["title"] == "Test Title" def test_run_chat(mock_ai_config, mock_ollama_llm):