Add fallback parsing for invalid ai responses

This commit is contained in:
shamoon 2025-04-30 13:03:31 -07:00
parent 47434dcb72
commit ffca85c146
No known key found for this signature in database
2 changed files with 94 additions and 24 deletions

View File

@ -16,7 +16,7 @@ logger = logging.getLogger("paperless_ai.rag_classifier")
def build_prompt_without_rag(document: Document) -> str:
filename = document.filename or ""
content = truncate_content(document.content or "")
content = truncate_content(document.content[:4000] or "")
prompt = f"""
You are an assistant that extracts structured information from documents.
@ -43,7 +43,7 @@ def build_prompt_without_rag(document: Document) -> str:
"storage_paths": ["xxxx", "xxxx"],
"dates": ["YYYY-MM-DD", "YYYY-MM-DD", "YYYY-MM-DD"],
}}
---
---------
FILENAME:
{filename}
@ -63,6 +63,10 @@ def build_prompt_with_rag(document: Document, user: User | None = None) -> str:
CONTEXT FROM SIMILAR DOCUMENTS:
{context}
---------
DO NOT RESPOND WITH ANYTHING OTHER THAN THE JSON OBJECT.
"""
return prompt
@ -108,8 +112,24 @@ def parse_ai_response(response: CompletionResponse) -> dict:
"dates": raw.get("dates", []),
}
except json.JSONDecodeError:
logger.exception("Invalid JSON in AI response")
return {}
logger.warning("Invalid JSON in AI response, attempting modified parsing...")
try:
# search for a valid json string like { ... } in the response
start = response.text.index("{")
end = response.text.rindex("}") + 1
json_str = response.text[start:end]
raw = json.loads(json_str)
return {
"title": raw.get("title"),
"tags": raw.get("tags", []),
"correspondents": raw.get("correspondents", []),
"document_types": raw.get("document_types", []),
"storage_paths": raw.get("storage_paths", []),
"dates": raw.get("dates", []),
}
except (ValueError, json.JSONDecodeError):
logger.exception("Failed to parse AI response")
return {}
def get_ai_document_classification(

View File

@ -48,6 +48,26 @@ def mock_document():
return doc
@pytest.fixture
def mock_similar_documents():
doc1 = MagicMock()
doc1.content = "Content of document 1"
doc1.title = "Title 1"
doc1.filename = "file1.txt"
doc2 = MagicMock()
doc2.content = "Content of document 2"
doc2.title = None
doc2.filename = "file2.txt"
doc3 = MagicMock()
doc3.content = None
doc3.title = None
doc3.filename = None
return [doc1, doc2, doc3]
@pytest.mark.django_db
@patch("paperless_ai.client.AIClient.run_llm_query")
@override_settings(
@ -76,6 +96,56 @@ def test_get_ai_document_classification_success(mock_run_llm_query, mock_documen
assert result["dates"] == ["2023-01-01"]
@pytest.mark.django_db
@patch("paperless_ai.client.AIClient.run_llm_query")
@override_settings(
LLM_BACKEND="ollama",
LLM_MODEL="some_model",
)
def test_get_ai_document_classification_fallback_parse_success(
mock_run_llm_query,
mock_document,
):
mock_run_llm_query.return_value.text = """
There is some text before the JSON.
```json
{
"title": "Test Title",
"tags": ["test", "document"],
"correspondents": ["John Doe"],
"document_types": ["report"],
"storage_paths": ["Reports"],
"dates": ["2023-01-01"]
}
```
"""
result = get_ai_document_classification(mock_document)
assert result["title"] == "Test Title"
assert result["tags"] == ["test", "document"]
assert result["correspondents"] == ["John Doe"]
assert result["document_types"] == ["report"]
assert result["storage_paths"] == ["Reports"]
assert result["dates"] == ["2023-01-01"]
@pytest.mark.django_db
@patch("paperless_ai.client.AIClient.run_llm_query")
@override_settings(
LLM_BACKEND="ollama",
LLM_MODEL="some_model",
)
def test_get_ai_document_classification_parse_failure(
mock_run_llm_query,
mock_document,
):
mock_run_llm_query.return_value.text = "Invalid JSON response"
result = get_ai_document_classification(mock_document)
assert result == {}
@pytest.mark.django_db
@patch("paperless_ai.client.AIClient.run_llm_query")
def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document):
@ -154,26 +224,6 @@ def test_prompt_with_without_rag(mock_document):
assert "CONTEXT FROM SIMILAR DOCUMENTS:" in prompt
@pytest.fixture
def mock_similar_documents():
doc1 = MagicMock()
doc1.content = "Content of document 1"
doc1.title = "Title 1"
doc1.filename = "file1.txt"
doc2 = MagicMock()
doc2.content = "Content of document 2"
doc2.title = None
doc2.filename = "file2.txt"
doc3 = MagicMock()
doc3.content = None
doc3.title = None
doc3.filename = None
return [doc1, doc2, doc3]
@patch("paperless_ai.ai_classifier.query_similar_documents")
def test_get_context_for_document(
mock_query_similar_documents,