diff --git a/src/paperless_ai/ai_classifier.py b/src/paperless_ai/ai_classifier.py index b75ceb1e5..55c7c7704 100644 --- a/src/paperless_ai/ai_classifier.py +++ b/src/paperless_ai/ai_classifier.py @@ -16,7 +16,7 @@ logger = logging.getLogger("paperless_ai.rag_classifier") def build_prompt_without_rag(document: Document) -> str: filename = document.filename or "" - content = truncate_content(document.content or "") + content = truncate_content(document.content[:4000] or "") prompt = f""" You are an assistant that extracts structured information from documents. @@ -43,7 +43,7 @@ def build_prompt_without_rag(document: Document) -> str: "storage_paths": ["xxxx", "xxxx"], "dates": ["YYYY-MM-DD", "YYYY-MM-DD", "YYYY-MM-DD"], }} - --- + --------- FILENAME: {filename} @@ -63,6 +63,10 @@ def build_prompt_with_rag(document: Document, user: User | None = None) -> str: CONTEXT FROM SIMILAR DOCUMENTS: {context} + + --------- + + DO NOT RESPOND WITH ANYTHING OTHER THAN THE JSON OBJECT. """ return prompt @@ -108,8 +112,24 @@ def parse_ai_response(response: CompletionResponse) -> dict: "dates": raw.get("dates", []), } except json.JSONDecodeError: - logger.exception("Invalid JSON in AI response") - return {} + logger.warning("Invalid JSON in AI response, attempting modified parsing...") + try: + # search for a valid json string like { ... } in the response + start = response.text.index("{") + end = response.text.rindex("}") + 1 + json_str = response.text[start:end] + raw = json.loads(json_str) + return { + "title": raw.get("title"), + "tags": raw.get("tags", []), + "correspondents": raw.get("correspondents", []), + "document_types": raw.get("document_types", []), + "storage_paths": raw.get("storage_paths", []), + "dates": raw.get("dates", []), + } + except (ValueError, json.JSONDecodeError): + logger.exception("Failed to parse AI response") + return {} def get_ai_document_classification( diff --git a/src/paperless_ai/tests/test_ai_classifier.py b/src/paperless_ai/tests/test_ai_classifier.py index 408678f7b..548acbc6c 100644 --- a/src/paperless_ai/tests/test_ai_classifier.py +++ b/src/paperless_ai/tests/test_ai_classifier.py @@ -48,6 +48,26 @@ def mock_document(): return doc +@pytest.fixture +def mock_similar_documents(): + doc1 = MagicMock() + doc1.content = "Content of document 1" + doc1.title = "Title 1" + doc1.filename = "file1.txt" + + doc2 = MagicMock() + doc2.content = "Content of document 2" + doc2.title = None + doc2.filename = "file2.txt" + + doc3 = MagicMock() + doc3.content = None + doc3.title = None + doc3.filename = None + + return [doc1, doc2, doc3] + + @pytest.mark.django_db @patch("paperless_ai.client.AIClient.run_llm_query") @override_settings( @@ -76,6 +96,56 @@ def test_get_ai_document_classification_success(mock_run_llm_query, mock_documen assert result["dates"] == ["2023-01-01"] +@pytest.mark.django_db +@patch("paperless_ai.client.AIClient.run_llm_query") +@override_settings( + LLM_BACKEND="ollama", + LLM_MODEL="some_model", +) +def test_get_ai_document_classification_fallback_parse_success( + mock_run_llm_query, + mock_document, +): + mock_run_llm_query.return_value.text = """ + There is some text before the JSON. + ```json + { + "title": "Test Title", + "tags": ["test", "document"], + "correspondents": ["John Doe"], + "document_types": ["report"], + "storage_paths": ["Reports"], + "dates": ["2023-01-01"] + } + ``` + """ + + result = get_ai_document_classification(mock_document) + + assert result["title"] == "Test Title" + assert result["tags"] == ["test", "document"] + assert result["correspondents"] == ["John Doe"] + assert result["document_types"] == ["report"] + assert result["storage_paths"] == ["Reports"] + assert result["dates"] == ["2023-01-01"] + + +@pytest.mark.django_db +@patch("paperless_ai.client.AIClient.run_llm_query") +@override_settings( + LLM_BACKEND="ollama", + LLM_MODEL="some_model", +) +def test_get_ai_document_classification_parse_failure( + mock_run_llm_query, + mock_document, +): + mock_run_llm_query.return_value.text = "Invalid JSON response" + + result = get_ai_document_classification(mock_document) + assert result == {} + + @pytest.mark.django_db @patch("paperless_ai.client.AIClient.run_llm_query") def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document): @@ -154,26 +224,6 @@ def test_prompt_with_without_rag(mock_document): assert "CONTEXT FROM SIMILAR DOCUMENTS:" in prompt -@pytest.fixture -def mock_similar_documents(): - doc1 = MagicMock() - doc1.content = "Content of document 1" - doc1.title = "Title 1" - doc1.filename = "file1.txt" - - doc2 = MagicMock() - doc2.content = "Content of document 2" - doc2.title = None - doc2.filename = "file2.txt" - - doc3 = MagicMock() - doc3.content = None - doc3.title = None - doc3.filename = None - - return [doc1, doc2, doc3] - - @patch("paperless_ai.ai_classifier.query_similar_documents") def test_get_context_for_document( mock_query_similar_documents,