Add fallback parsing for invalid ai responses

This commit is contained in:
shamoon
2025-04-30 13:03:31 -07:00
parent 90bd878cf2
commit d9cbd3652a
2 changed files with 94 additions and 24 deletions

View File

@@ -48,6 +48,26 @@ def mock_document():
return doc
@pytest.fixture
def mock_similar_documents():
doc1 = MagicMock()
doc1.content = "Content of document 1"
doc1.title = "Title 1"
doc1.filename = "file1.txt"
doc2 = MagicMock()
doc2.content = "Content of document 2"
doc2.title = None
doc2.filename = "file2.txt"
doc3 = MagicMock()
doc3.content = None
doc3.title = None
doc3.filename = None
return [doc1, doc2, doc3]
@pytest.mark.django_db
@patch("paperless_ai.client.AIClient.run_llm_query")
@override_settings(
@@ -76,6 +96,56 @@ def test_get_ai_document_classification_success(mock_run_llm_query, mock_documen
assert result["dates"] == ["2023-01-01"]
@pytest.mark.django_db
@patch("paperless_ai.client.AIClient.run_llm_query")
@override_settings(
LLM_BACKEND="ollama",
LLM_MODEL="some_model",
)
def test_get_ai_document_classification_fallback_parse_success(
mock_run_llm_query,
mock_document,
):
mock_run_llm_query.return_value.text = """
There is some text before the JSON.
```json
{
"title": "Test Title",
"tags": ["test", "document"],
"correspondents": ["John Doe"],
"document_types": ["report"],
"storage_paths": ["Reports"],
"dates": ["2023-01-01"]
}
```
"""
result = get_ai_document_classification(mock_document)
assert result["title"] == "Test Title"
assert result["tags"] == ["test", "document"]
assert result["correspondents"] == ["John Doe"]
assert result["document_types"] == ["report"]
assert result["storage_paths"] == ["Reports"]
assert result["dates"] == ["2023-01-01"]
@pytest.mark.django_db
@patch("paperless_ai.client.AIClient.run_llm_query")
@override_settings(
LLM_BACKEND="ollama",
LLM_MODEL="some_model",
)
def test_get_ai_document_classification_parse_failure(
mock_run_llm_query,
mock_document,
):
mock_run_llm_query.return_value.text = "Invalid JSON response"
result = get_ai_document_classification(mock_document)
assert result == {}
@pytest.mark.django_db
@patch("paperless_ai.client.AIClient.run_llm_query")
def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document):
@@ -154,26 +224,6 @@ def test_prompt_with_without_rag(mock_document):
assert "CONTEXT FROM SIMILAR DOCUMENTS:" in prompt
@pytest.fixture
def mock_similar_documents():
doc1 = MagicMock()
doc1.content = "Content of document 1"
doc1.title = "Title 1"
doc1.filename = "file1.txt"
doc2 = MagicMock()
doc2.content = "Content of document 2"
doc2.title = None
doc2.filename = "file2.txt"
doc3 = MagicMock()
doc3.content = None
doc3.title = None
doc3.filename = None
return [doc1, doc2, doc3]
@patch("paperless_ai.ai_classifier.query_similar_documents")
def test_get_context_for_document(
mock_query_similar_documents,