mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-07-16 17:25:11 -05:00
Fix tests for change to structured output
This commit is contained in:
parent
20bae4bd41
commit
5bfbe856a6
@ -10,7 +10,6 @@ from paperless_ai.ai_classifier import build_prompt_with_rag
|
||||
from paperless_ai.ai_classifier import build_prompt_without_rag
|
||||
from paperless_ai.ai_classifier import get_ai_document_classification
|
||||
from paperless_ai.ai_classifier import get_context_for_document
|
||||
from paperless_ai.ai_classifier import parse_ai_response
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -75,50 +74,14 @@ def mock_similar_documents():
|
||||
LLM_MODEL="some_model",
|
||||
)
|
||||
def test_get_ai_document_classification_success(mock_run_llm_query, mock_document):
|
||||
mock_run_llm_query.return_value.text = json.dumps(
|
||||
{
|
||||
mock_run_llm_query.return_value = {
|
||||
"title": "Test Title",
|
||||
"tags": ["test", "document"],
|
||||
"correspondents": ["John Doe"],
|
||||
"document_types": ["report"],
|
||||
"storage_paths": ["Reports"],
|
||||
"dates": ["2023-01-01"],
|
||||
},
|
||||
)
|
||||
|
||||
result = get_ai_document_classification(mock_document)
|
||||
|
||||
assert result["title"] == "Test Title"
|
||||
assert result["tags"] == ["test", "document"]
|
||||
assert result["correspondents"] == ["John Doe"]
|
||||
assert result["document_types"] == ["report"]
|
||||
assert result["storage_paths"] == ["Reports"]
|
||||
assert result["dates"] == ["2023-01-01"]
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@patch("paperless_ai.client.AIClient.run_llm_query")
|
||||
@override_settings(
|
||||
LLM_BACKEND="ollama",
|
||||
LLM_MODEL="some_model",
|
||||
)
|
||||
def test_get_ai_document_classification_fallback_parse_success(
|
||||
mock_run_llm_query,
|
||||
mock_document,
|
||||
):
|
||||
mock_run_llm_query.return_value.text = """
|
||||
There is some text before the JSON.
|
||||
```json
|
||||
{
|
||||
"title": "Test Title",
|
||||
"tags": ["test", "document"],
|
||||
"correspondents": ["John Doe"],
|
||||
"document_types": ["report"],
|
||||
"storage_paths": ["Reports"],
|
||||
"dates": ["2023-01-01"]
|
||||
}
|
||||
```
|
||||
"""
|
||||
|
||||
result = get_ai_document_classification(mock_document)
|
||||
|
||||
@ -130,22 +93,6 @@ def test_get_ai_document_classification_fallback_parse_success(
|
||||
assert result["dates"] == ["2023-01-01"]
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@patch("paperless_ai.client.AIClient.run_llm_query")
|
||||
@override_settings(
|
||||
LLM_BACKEND="ollama",
|
||||
LLM_MODEL="some_model",
|
||||
)
|
||||
def test_get_ai_document_classification_parse_failure(
|
||||
mock_run_llm_query,
|
||||
mock_document,
|
||||
):
|
||||
mock_run_llm_query.return_value.text = "Invalid JSON response"
|
||||
|
||||
result = get_ai_document_classification(mock_document)
|
||||
assert result == {}
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@patch("paperless_ai.client.AIClient.run_llm_query")
|
||||
def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document):
|
||||
@ -156,15 +103,6 @@ def test_get_ai_document_classification_failure(mock_run_llm_query, mock_documen
|
||||
get_ai_document_classification(mock_document)
|
||||
|
||||
|
||||
def test_parse_llm_classification_response_invalid_json():
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Invalid JSON response"
|
||||
|
||||
result = parse_ai_response(mock_response)
|
||||
|
||||
assert result == {}
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@patch("paperless_ai.client.AIClient.run_llm_query")
|
||||
@patch("paperless_ai.ai_classifier.build_prompt_with_rag")
|
||||
@ -218,10 +156,10 @@ def test_prompt_with_without_rag(mock_document):
|
||||
return_value="Context from similar documents",
|
||||
):
|
||||
prompt = build_prompt_without_rag(mock_document)
|
||||
assert "CONTEXT FROM SIMILAR DOCUMENTS:" not in prompt
|
||||
assert "Additional context from similar documents:" not in prompt
|
||||
|
||||
prompt = build_prompt_with_rag(mock_document)
|
||||
assert "CONTEXT FROM SIMILAR DOCUMENTS:" in prompt
|
||||
assert "Additional context from similar documents:" in prompt
|
||||
|
||||
|
||||
@patch("paperless_ai.ai_classifier.query_similar_documents")
|
||||
|
@ -3,6 +3,7 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from llama_index.core.llms import ChatMessage
|
||||
from llama_index.core.llms.llm import ToolSelection
|
||||
|
||||
from paperless_ai.client import AIClient
|
||||
|
||||
@ -69,13 +70,27 @@ def test_run_llm_query(mock_ai_config, mock_ollama_llm):
|
||||
mock_ai_config.llm_url = "http://test-url"
|
||||
|
||||
mock_llm_instance = mock_ollama_llm.return_value
|
||||
mock_llm_instance.complete.return_value = "test_result"
|
||||
|
||||
tool_selection = ToolSelection(
|
||||
tool_id="call_test",
|
||||
tool_name="DocumentClassifierSchema",
|
||||
tool_kwargs={
|
||||
"title": "Test Title",
|
||||
"tags": ["test", "document"],
|
||||
"correspondents": ["John Doe"],
|
||||
"document_types": ["report"],
|
||||
"storage_paths": ["Reports"],
|
||||
"dates": ["2023-01-01"],
|
||||
},
|
||||
)
|
||||
|
||||
mock_llm_instance.chat_with_tools.return_value = MagicMock()
|
||||
mock_llm_instance.get_tool_calls_from_response.return_value = [tool_selection]
|
||||
|
||||
client = AIClient()
|
||||
result = client.run_llm_query("test_prompt")
|
||||
|
||||
mock_llm_instance.complete.assert_called_once_with("test_prompt")
|
||||
assert result == "test_result"
|
||||
assert result["title"] == "Test Title"
|
||||
|
||||
|
||||
def test_run_chat(mock_ai_config, mock_ollama_llm):
|
||||
|
Loading…
x
Reference in New Issue
Block a user