Fix tests for change to structured output

This commit is contained in:
shamoon
2025-07-15 14:34:54 -07:00
parent 20bae4bd41
commit 5bfbe856a6
2 changed files with 22 additions and 69 deletions

View File

@@ -10,7 +10,6 @@ from paperless_ai.ai_classifier import build_prompt_with_rag
from paperless_ai.ai_classifier import build_prompt_without_rag
from paperless_ai.ai_classifier import get_ai_document_classification
from paperless_ai.ai_classifier import get_context_for_document
from paperless_ai.ai_classifier import parse_ai_response
@pytest.fixture
@@ -75,50 +74,14 @@ def mock_similar_documents():
LLM_MODEL="some_model",
)
def test_get_ai_document_classification_success(mock_run_llm_query, mock_document):
mock_run_llm_query.return_value.text = json.dumps(
{
"title": "Test Title",
"tags": ["test", "document"],
"correspondents": ["John Doe"],
"document_types": ["report"],
"storage_paths": ["Reports"],
"dates": ["2023-01-01"],
},
)
result = get_ai_document_classification(mock_document)
assert result["title"] == "Test Title"
assert result["tags"] == ["test", "document"]
assert result["correspondents"] == ["John Doe"]
assert result["document_types"] == ["report"]
assert result["storage_paths"] == ["Reports"]
assert result["dates"] == ["2023-01-01"]
@pytest.mark.django_db
@patch("paperless_ai.client.AIClient.run_llm_query")
@override_settings(
LLM_BACKEND="ollama",
LLM_MODEL="some_model",
)
def test_get_ai_document_classification_fallback_parse_success(
mock_run_llm_query,
mock_document,
):
mock_run_llm_query.return_value.text = """
There is some text before the JSON.
```json
{
mock_run_llm_query.return_value = {
"title": "Test Title",
"tags": ["test", "document"],
"correspondents": ["John Doe"],
"document_types": ["report"],
"storage_paths": ["Reports"],
"dates": ["2023-01-01"]
"dates": ["2023-01-01"],
}
```
"""
result = get_ai_document_classification(mock_document)
@@ -130,22 +93,6 @@ def test_get_ai_document_classification_fallback_parse_success(
assert result["dates"] == ["2023-01-01"]
@pytest.mark.django_db
@patch("paperless_ai.client.AIClient.run_llm_query")
@override_settings(
LLM_BACKEND="ollama",
LLM_MODEL="some_model",
)
def test_get_ai_document_classification_parse_failure(
mock_run_llm_query,
mock_document,
):
mock_run_llm_query.return_value.text = "Invalid JSON response"
result = get_ai_document_classification(mock_document)
assert result == {}
@pytest.mark.django_db
@patch("paperless_ai.client.AIClient.run_llm_query")
def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document):
@@ -156,15 +103,6 @@ def test_get_ai_document_classification_failure(mock_run_llm_query, mock_documen
get_ai_document_classification(mock_document)
def test_parse_llm_classification_response_invalid_json():
mock_response = MagicMock()
mock_response.text = "Invalid JSON response"
result = parse_ai_response(mock_response)
assert result == {}
@pytest.mark.django_db
@patch("paperless_ai.client.AIClient.run_llm_query")
@patch("paperless_ai.ai_classifier.build_prompt_with_rag")
@@ -218,10 +156,10 @@ def test_prompt_with_without_rag(mock_document):
return_value="Context from similar documents",
):
prompt = build_prompt_without_rag(mock_document)
assert "CONTEXT FROM SIMILAR DOCUMENTS:" not in prompt
assert "Additional context from similar documents:" not in prompt
prompt = build_prompt_with_rag(mock_document)
assert "CONTEXT FROM SIMILAR DOCUMENTS:" in prompt
assert "Additional context from similar documents:" in prompt
@patch("paperless_ai.ai_classifier.query_similar_documents")