From 2cafe4a2c0c070894d557a5a6b350c117f8b10e6 Mon Sep 17 00:00:00 2001
From: shamoon <4887959+shamoon@users.noreply.github.com>
Date: Sat, 26 Apr 2025 01:18:37 -0700
Subject: [PATCH] Chat coverage

---
 src/paperless/ai/chat.py            |   3 +-
 src/paperless/tests/test_ai_chat.py | 142 ++++++++++++++++++++++++++++
 2 files changed, 144 insertions(+), 1 deletion(-)
 create mode 100644 src/paperless/tests/test_ai_chat.py

diff --git a/src/paperless/ai/chat.py b/src/paperless/ai/chat.py
index ad14bda4d..7141177d7 100644
--- a/src/paperless/ai/chat.py
+++ b/src/paperless/ai/chat.py
@@ -37,7 +37,8 @@ def stream_chat_with_documents(query_str: str, documents: list[Document]):
 
     if len(nodes) == 0:
         logger.warning("No nodes found for the given documents.")
-        return "Sorry, I couldn't find any content to answer your question."
+        yield "Sorry, I couldn't find any content to answer your question."
+        return
 
     local_index = VectorStoreIndex(nodes=nodes)
     retriever = local_index.as_retriever(
diff --git a/src/paperless/tests/test_ai_chat.py b/src/paperless/tests/test_ai_chat.py
new file mode 100644
index 000000000..2b792c4c8
--- /dev/null
+++ b/src/paperless/tests/test_ai_chat.py
@@ -0,0 +1,142 @@
+from unittest.mock import MagicMock
+from unittest.mock import patch
+
+import pytest
+from llama_index.core import VectorStoreIndex
+from llama_index.core.schema import TextNode
+
+from paperless.ai.chat import stream_chat_with_documents
+
+
+@pytest.fixture(autouse=True)
+def patch_embed_model():
+    from llama_index.core import settings as llama_settings
+
+    mock_embed_model = MagicMock()
+    mock_embed_model._get_text_embedding_batch.return_value = [
+        [0.1] * 1536,
+    ]  # 1 vector per input
+    llama_settings.Settings._embed_model = mock_embed_model
+    yield
+    llama_settings.Settings._embed_model = None
+
+
+@pytest.fixture(autouse=True)
+def patch_embed_nodes():
+    with patch(
+        "llama_index.core.indices.vector_store.base.embed_nodes",
+    ) as mock_embed_nodes:
+        mock_embed_nodes.side_effect = lambda nodes, *_args, **_kwargs: {
+            node.node_id: [0.1] * 1536 for node in nodes
+        }
+        yield
+
+
+@pytest.fixture
+def mock_document():
+    doc = MagicMock()
+    doc.pk = 1
+    doc.title = "Test Document"
+    doc.filename = "test_file.pdf"
+    doc.content = "This is the document content."
+    return doc
+
+
+def test_stream_chat_with_one_document_full_content(mock_document):
+    with (
+        patch("paperless.ai.chat.AIClient") as mock_client_cls,
+        patch("paperless.ai.chat.load_index") as mock_load_index,
+        patch(
+            "paperless.ai.chat.RetrieverQueryEngine.from_args",
+        ) as mock_query_engine_cls,
+    ):
+        mock_client = MagicMock()
+        mock_client_cls.return_value = mock_client
+        mock_client.llm = MagicMock()
+
+        mock_node = TextNode(
+            text="This is node content.",
+            metadata={"document_id": mock_document.pk, "title": "Test Document"},
+        )
+        mock_index = MagicMock()
+        mock_index.docstore.docs.values.return_value = [mock_node]
+        mock_load_index.return_value = mock_index
+
+        mock_response_stream = MagicMock()
+        mock_response_stream.response_gen = iter(["chunk1", "chunk2"])
+        mock_query_engine = MagicMock()
+        mock_query_engine_cls.return_value = mock_query_engine
+        mock_query_engine.query.return_value = mock_response_stream
+
+        output = list(stream_chat_with_documents("What is this?", [mock_document]))
+
+        assert output == ["chunk1", "chunk2"]
+
+
+def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes):
+    with (
+        patch("paperless.ai.chat.AIClient") as mock_client_cls,
+        patch("paperless.ai.chat.load_index") as mock_load_index,
+        patch(
+            "paperless.ai.chat.RetrieverQueryEngine.from_args",
+        ) as mock_query_engine_cls,
+        patch.object(VectorStoreIndex, "as_retriever") as mock_as_retriever,
+    ):
+        # Mock AIClient and LLM
+        mock_client = MagicMock()
+        mock_client_cls.return_value = mock_client
+        mock_client.llm = MagicMock()
+
+        # Create two real TextNodes
+        mock_node1 = TextNode(
+            text="Content for doc 1.",
+            metadata={"document_id": 1, "title": "Document 1"},
+        )
+        mock_node2 = TextNode(
+            text="Content for doc 2.",
+            metadata={"document_id": 2, "title": "Document 2"},
+        )
+        mock_index = MagicMock()
+        mock_index.docstore.docs.values.return_value = [mock_node1, mock_node2]
+        mock_load_index.return_value = mock_index
+
+        # Patch as_retriever to return a retriever whose retrieve() returns mock_node1 and mock_node2
+        mock_retriever = MagicMock()
+        mock_retriever.retrieve.return_value = [mock_node1, mock_node2]
+        mock_as_retriever.return_value = mock_retriever
+
+        # Mock response stream
+        mock_response_stream = MagicMock()
+        mock_response_stream.response_gen = iter(["chunk1", "chunk2"])
+
+        # Mock RetrieverQueryEngine
+        mock_query_engine = MagicMock()
+        mock_query_engine_cls.return_value = mock_query_engine
+        mock_query_engine.query.return_value = mock_response_stream
+
+        # Fake documents
+        doc1 = MagicMock(pk=1)
+        doc2 = MagicMock(pk=2)
+
+        output = list(stream_chat_with_documents("What's up?", [doc1, doc2]))
+
+        assert output == ["chunk1", "chunk2"]
+
+
+def test_stream_chat_no_matching_nodes():
+    with (
+        patch("paperless.ai.chat.AIClient") as mock_client_cls,
+        patch("paperless.ai.chat.load_index") as mock_load_index,
+    ):
+        mock_client = MagicMock()
+        mock_client_cls.return_value = mock_client
+        mock_client.llm = MagicMock()
+
+        mock_index = MagicMock()
+        # No matching nodes
+        mock_index.docstore.docs.values.return_value = []
+        mock_load_index.return_value = mock_index
+
+        output = list(stream_chat_with_documents("Any info?", [MagicMock(pk=1)]))
+
+        assert output == ["Sorry, I couldn't find any content to answer your question."]