mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-05-23 12:58:18 -05:00
Add fallback parsing for invalid ai responses
This commit is contained in:
parent
47434dcb72
commit
ffca85c146
@ -16,7 +16,7 @@ logger = logging.getLogger("paperless_ai.rag_classifier")
|
|||||||
|
|
||||||
def build_prompt_without_rag(document: Document) -> str:
|
def build_prompt_without_rag(document: Document) -> str:
|
||||||
filename = document.filename or ""
|
filename = document.filename or ""
|
||||||
content = truncate_content(document.content or "")
|
content = truncate_content(document.content[:4000] or "")
|
||||||
|
|
||||||
prompt = f"""
|
prompt = f"""
|
||||||
You are an assistant that extracts structured information from documents.
|
You are an assistant that extracts structured information from documents.
|
||||||
@ -43,7 +43,7 @@ def build_prompt_without_rag(document: Document) -> str:
|
|||||||
"storage_paths": ["xxxx", "xxxx"],
|
"storage_paths": ["xxxx", "xxxx"],
|
||||||
"dates": ["YYYY-MM-DD", "YYYY-MM-DD", "YYYY-MM-DD"],
|
"dates": ["YYYY-MM-DD", "YYYY-MM-DD", "YYYY-MM-DD"],
|
||||||
}}
|
}}
|
||||||
---
|
---------
|
||||||
|
|
||||||
FILENAME:
|
FILENAME:
|
||||||
{filename}
|
{filename}
|
||||||
@ -63,6 +63,10 @@ def build_prompt_with_rag(document: Document, user: User | None = None) -> str:
|
|||||||
|
|
||||||
CONTEXT FROM SIMILAR DOCUMENTS:
|
CONTEXT FROM SIMILAR DOCUMENTS:
|
||||||
{context}
|
{context}
|
||||||
|
|
||||||
|
---------
|
||||||
|
|
||||||
|
DO NOT RESPOND WITH ANYTHING OTHER THAN THE JSON OBJECT.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return prompt
|
return prompt
|
||||||
@ -108,8 +112,24 @@ def parse_ai_response(response: CompletionResponse) -> dict:
|
|||||||
"dates": raw.get("dates", []),
|
"dates": raw.get("dates", []),
|
||||||
}
|
}
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
logger.exception("Invalid JSON in AI response")
|
logger.warning("Invalid JSON in AI response, attempting modified parsing...")
|
||||||
return {}
|
try:
|
||||||
|
# search for a valid json string like { ... } in the response
|
||||||
|
start = response.text.index("{")
|
||||||
|
end = response.text.rindex("}") + 1
|
||||||
|
json_str = response.text[start:end]
|
||||||
|
raw = json.loads(json_str)
|
||||||
|
return {
|
||||||
|
"title": raw.get("title"),
|
||||||
|
"tags": raw.get("tags", []),
|
||||||
|
"correspondents": raw.get("correspondents", []),
|
||||||
|
"document_types": raw.get("document_types", []),
|
||||||
|
"storage_paths": raw.get("storage_paths", []),
|
||||||
|
"dates": raw.get("dates", []),
|
||||||
|
}
|
||||||
|
except (ValueError, json.JSONDecodeError):
|
||||||
|
logger.exception("Failed to parse AI response")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
def get_ai_document_classification(
|
def get_ai_document_classification(
|
||||||
|
@ -48,6 +48,26 @@ def mock_document():
|
|||||||
return doc
|
return doc
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_similar_documents():
|
||||||
|
doc1 = MagicMock()
|
||||||
|
doc1.content = "Content of document 1"
|
||||||
|
doc1.title = "Title 1"
|
||||||
|
doc1.filename = "file1.txt"
|
||||||
|
|
||||||
|
doc2 = MagicMock()
|
||||||
|
doc2.content = "Content of document 2"
|
||||||
|
doc2.title = None
|
||||||
|
doc2.filename = "file2.txt"
|
||||||
|
|
||||||
|
doc3 = MagicMock()
|
||||||
|
doc3.content = None
|
||||||
|
doc3.title = None
|
||||||
|
doc3.filename = None
|
||||||
|
|
||||||
|
return [doc1, doc2, doc3]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.django_db
|
@pytest.mark.django_db
|
||||||
@patch("paperless_ai.client.AIClient.run_llm_query")
|
@patch("paperless_ai.client.AIClient.run_llm_query")
|
||||||
@override_settings(
|
@override_settings(
|
||||||
@ -76,6 +96,56 @@ def test_get_ai_document_classification_success(mock_run_llm_query, mock_documen
|
|||||||
assert result["dates"] == ["2023-01-01"]
|
assert result["dates"] == ["2023-01-01"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.django_db
|
||||||
|
@patch("paperless_ai.client.AIClient.run_llm_query")
|
||||||
|
@override_settings(
|
||||||
|
LLM_BACKEND="ollama",
|
||||||
|
LLM_MODEL="some_model",
|
||||||
|
)
|
||||||
|
def test_get_ai_document_classification_fallback_parse_success(
|
||||||
|
mock_run_llm_query,
|
||||||
|
mock_document,
|
||||||
|
):
|
||||||
|
mock_run_llm_query.return_value.text = """
|
||||||
|
There is some text before the JSON.
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"title": "Test Title",
|
||||||
|
"tags": ["test", "document"],
|
||||||
|
"correspondents": ["John Doe"],
|
||||||
|
"document_types": ["report"],
|
||||||
|
"storage_paths": ["Reports"],
|
||||||
|
"dates": ["2023-01-01"]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
result = get_ai_document_classification(mock_document)
|
||||||
|
|
||||||
|
assert result["title"] == "Test Title"
|
||||||
|
assert result["tags"] == ["test", "document"]
|
||||||
|
assert result["correspondents"] == ["John Doe"]
|
||||||
|
assert result["document_types"] == ["report"]
|
||||||
|
assert result["storage_paths"] == ["Reports"]
|
||||||
|
assert result["dates"] == ["2023-01-01"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.django_db
|
||||||
|
@patch("paperless_ai.client.AIClient.run_llm_query")
|
||||||
|
@override_settings(
|
||||||
|
LLM_BACKEND="ollama",
|
||||||
|
LLM_MODEL="some_model",
|
||||||
|
)
|
||||||
|
def test_get_ai_document_classification_parse_failure(
|
||||||
|
mock_run_llm_query,
|
||||||
|
mock_document,
|
||||||
|
):
|
||||||
|
mock_run_llm_query.return_value.text = "Invalid JSON response"
|
||||||
|
|
||||||
|
result = get_ai_document_classification(mock_document)
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.django_db
|
@pytest.mark.django_db
|
||||||
@patch("paperless_ai.client.AIClient.run_llm_query")
|
@patch("paperless_ai.client.AIClient.run_llm_query")
|
||||||
def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document):
|
def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document):
|
||||||
@ -154,26 +224,6 @@ def test_prompt_with_without_rag(mock_document):
|
|||||||
assert "CONTEXT FROM SIMILAR DOCUMENTS:" in prompt
|
assert "CONTEXT FROM SIMILAR DOCUMENTS:" in prompt
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_similar_documents():
|
|
||||||
doc1 = MagicMock()
|
|
||||||
doc1.content = "Content of document 1"
|
|
||||||
doc1.title = "Title 1"
|
|
||||||
doc1.filename = "file1.txt"
|
|
||||||
|
|
||||||
doc2 = MagicMock()
|
|
||||||
doc2.content = "Content of document 2"
|
|
||||||
doc2.title = None
|
|
||||||
doc2.filename = "file2.txt"
|
|
||||||
|
|
||||||
doc3 = MagicMock()
|
|
||||||
doc3.content = None
|
|
||||||
doc3.title = None
|
|
||||||
doc3.filename = None
|
|
||||||
|
|
||||||
return [doc1, doc2, doc3]
|
|
||||||
|
|
||||||
|
|
||||||
@patch("paperless_ai.ai_classifier.query_similar_documents")
|
@patch("paperless_ai.ai_classifier.query_similar_documents")
|
||||||
def test_get_context_for_document(
|
def test_get_context_for_document(
|
||||||
mock_query_similar_documents,
|
mock_query_similar_documents,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user