Fix tests for change to structured output

This commit is contained in:
shamoon 2025-07-15 14:34:54 -07:00
parent 20bae4bd41
commit 5bfbe856a6
No known key found for this signature in database
2 changed files with 22 additions and 69 deletions

View File

@@ -10,7 +10,6 @@ from paperless_ai.ai_classifier import build_prompt_with_rag
from paperless_ai.ai_classifier import build_prompt_without_rag from paperless_ai.ai_classifier import build_prompt_without_rag
from paperless_ai.ai_classifier import get_ai_document_classification from paperless_ai.ai_classifier import get_ai_document_classification
from paperless_ai.ai_classifier import get_context_for_document from paperless_ai.ai_classifier import get_context_for_document
from paperless_ai.ai_classifier import parse_ai_response
@pytest.fixture @pytest.fixture
@@ -75,50 +74,14 @@ def mock_similar_documents():
LLM_MODEL="some_model", LLM_MODEL="some_model",
) )
def test_get_ai_document_classification_success(mock_run_llm_query, mock_document): def test_get_ai_document_classification_success(mock_run_llm_query, mock_document):
mock_run_llm_query.return_value.text = json.dumps( mock_run_llm_query.return_value = {
{
"title": "Test Title", "title": "Test Title",
"tags": ["test", "document"], "tags": ["test", "document"],
"correspondents": ["John Doe"], "correspondents": ["John Doe"],
"document_types": ["report"], "document_types": ["report"],
"storage_paths": ["Reports"], "storage_paths": ["Reports"],
"dates": ["2023-01-01"], "dates": ["2023-01-01"],
},
)
result = get_ai_document_classification(mock_document)
assert result["title"] == "Test Title"
assert result["tags"] == ["test", "document"]
assert result["correspondents"] == ["John Doe"]
assert result["document_types"] == ["report"]
assert result["storage_paths"] == ["Reports"]
assert result["dates"] == ["2023-01-01"]
@pytest.mark.django_db
@patch("paperless_ai.client.AIClient.run_llm_query")
@override_settings(
LLM_BACKEND="ollama",
LLM_MODEL="some_model",
)
def test_get_ai_document_classification_fallback_parse_success(
mock_run_llm_query,
mock_document,
):
mock_run_llm_query.return_value.text = """
There is some text before the JSON.
```json
{
"title": "Test Title",
"tags": ["test", "document"],
"correspondents": ["John Doe"],
"document_types": ["report"],
"storage_paths": ["Reports"],
"dates": ["2023-01-01"]
} }
```
"""
result = get_ai_document_classification(mock_document) result = get_ai_document_classification(mock_document)
@@ -130,22 +93,6 @@ def test_get_ai_document_classification_fallback_parse_success(
assert result["dates"] == ["2023-01-01"] assert result["dates"] == ["2023-01-01"]
@pytest.mark.django_db
@patch("paperless_ai.client.AIClient.run_llm_query")
@override_settings(
LLM_BACKEND="ollama",
LLM_MODEL="some_model",
)
def test_get_ai_document_classification_parse_failure(
mock_run_llm_query,
mock_document,
):
mock_run_llm_query.return_value.text = "Invalid JSON response"
result = get_ai_document_classification(mock_document)
assert result == {}
@pytest.mark.django_db @pytest.mark.django_db
@patch("paperless_ai.client.AIClient.run_llm_query") @patch("paperless_ai.client.AIClient.run_llm_query")
def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document): def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document):
@@ -156,15 +103,6 @@ def test_get_ai_document_classification_failure(mock_run_llm_query, mock_documen
get_ai_document_classification(mock_document) get_ai_document_classification(mock_document)
def test_parse_llm_classification_response_invalid_json():
mock_response = MagicMock()
mock_response.text = "Invalid JSON response"
result = parse_ai_response(mock_response)
assert result == {}
@pytest.mark.django_db @pytest.mark.django_db
@patch("paperless_ai.client.AIClient.run_llm_query") @patch("paperless_ai.client.AIClient.run_llm_query")
@patch("paperless_ai.ai_classifier.build_prompt_with_rag") @patch("paperless_ai.ai_classifier.build_prompt_with_rag")
@@ -218,10 +156,10 @@ def test_prompt_with_without_rag(mock_document):
return_value="Context from similar documents", return_value="Context from similar documents",
): ):
prompt = build_prompt_without_rag(mock_document) prompt = build_prompt_without_rag(mock_document)
assert "CONTEXT FROM SIMILAR DOCUMENTS:" not in prompt assert "Additional context from similar documents:" not in prompt
prompt = build_prompt_with_rag(mock_document) prompt = build_prompt_with_rag(mock_document)
assert "CONTEXT FROM SIMILAR DOCUMENTS:" in prompt assert "Additional context from similar documents:" in prompt
@patch("paperless_ai.ai_classifier.query_similar_documents") @patch("paperless_ai.ai_classifier.query_similar_documents")

View File

@@ -3,6 +3,7 @@ from unittest.mock import patch
import pytest import pytest
from llama_index.core.llms import ChatMessage from llama_index.core.llms import ChatMessage
from llama_index.core.llms.llm import ToolSelection
from paperless_ai.client import AIClient from paperless_ai.client import AIClient
@@ -69,13 +70,27 @@ def test_run_llm_query(mock_ai_config, mock_ollama_llm):
mock_ai_config.llm_url = "http://test-url" mock_ai_config.llm_url = "http://test-url"
mock_llm_instance = mock_ollama_llm.return_value mock_llm_instance = mock_ollama_llm.return_value
mock_llm_instance.complete.return_value = "test_result"
tool_selection = ToolSelection(
tool_id="call_test",
tool_name="DocumentClassifierSchema",
tool_kwargs={
"title": "Test Title",
"tags": ["test", "document"],
"correspondents": ["John Doe"],
"document_types": ["report"],
"storage_paths": ["Reports"],
"dates": ["2023-01-01"],
},
)
mock_llm_instance.chat_with_tools.return_value = MagicMock()
mock_llm_instance.get_tool_calls_from_response.return_value = [tool_selection]
client = AIClient() client = AIClient()
result = client.run_llm_query("test_prompt") result = client.run_llm_query("test_prompt")
mock_llm_instance.complete.assert_called_once_with("test_prompt") assert result["title"] == "Test Title"
assert result == "test_result"
def test_run_chat(mock_ai_config, mock_ollama_llm): def test_run_chat(mock_ai_config, mock_ollama_llm):