Chat coverage

This commit is contained in:
shamoon 2025-04-26 01:18:37 -07:00
parent c02d9249e7
commit 2cafe4a2c0
No known key found for this signature in database
2 changed files with 144 additions and 1 deletions

View File

@ -37,7 +37,8 @@ def stream_chat_with_documents(query_str: str, documents: list[Document]):
if len(nodes) == 0: if len(nodes) == 0:
logger.warning("No nodes found for the given documents.") logger.warning("No nodes found for the given documents.")
return "Sorry, I couldn't find any content to answer your question." yield "Sorry, I couldn't find any content to answer your question."
return
local_index = VectorStoreIndex(nodes=nodes) local_index = VectorStoreIndex(nodes=nodes)
retriever = local_index.as_retriever( retriever = local_index.as_retriever(

View File

@ -0,0 +1,142 @@
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from llama_index.core import VectorStoreIndex
from llama_index.core.schema import TextNode
from paperless.ai.chat import stream_chat_with_documents
@pytest.fixture(autouse=True)
def patch_embed_model():
from llama_index.core import settings as llama_settings
mock_embed_model = MagicMock()
mock_embed_model._get_text_embedding_batch.return_value = [
[0.1] * 1536,
] # 1 vector per input
llama_settings.Settings._embed_model = mock_embed_model
yield
llama_settings.Settings._embed_model = None
@pytest.fixture(autouse=True)
def patch_embed_nodes():
with patch(
"llama_index.core.indices.vector_store.base.embed_nodes",
) as mock_embed_nodes:
mock_embed_nodes.side_effect = lambda nodes, *_args, **_kwargs: {
node.node_id: [0.1] * 1536 for node in nodes
}
yield
@pytest.fixture
def mock_document():
doc = MagicMock()
doc.pk = 1
doc.title = "Test Document"
doc.filename = "test_file.pdf"
doc.content = "This is the document content."
return doc
def test_stream_chat_with_one_document_full_content(mock_document):
with (
patch("paperless.ai.chat.AIClient") as mock_client_cls,
patch("paperless.ai.chat.load_index") as mock_load_index,
patch(
"paperless.ai.chat.RetrieverQueryEngine.from_args",
) as mock_query_engine_cls,
):
mock_client = MagicMock()
mock_client_cls.return_value = mock_client
mock_client.llm = MagicMock()
mock_node = TextNode(
text="This is node content.",
metadata={"document_id": mock_document.pk, "title": "Test Document"},
)
mock_index = MagicMock()
mock_index.docstore.docs.values.return_value = [mock_node]
mock_load_index.return_value = mock_index
mock_response_stream = MagicMock()
mock_response_stream.response_gen = iter(["chunk1", "chunk2"])
mock_query_engine = MagicMock()
mock_query_engine_cls.return_value = mock_query_engine
mock_query_engine.query.return_value = mock_response_stream
output = list(stream_chat_with_documents("What is this?", [mock_document]))
assert output == ["chunk1", "chunk2"]
def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes):
with (
patch("paperless.ai.chat.AIClient") as mock_client_cls,
patch("paperless.ai.chat.load_index") as mock_load_index,
patch(
"paperless.ai.chat.RetrieverQueryEngine.from_args",
) as mock_query_engine_cls,
patch.object(VectorStoreIndex, "as_retriever") as mock_as_retriever,
):
# Mock AIClient and LLM
mock_client = MagicMock()
mock_client_cls.return_value = mock_client
mock_client.llm = MagicMock()
# Create two real TextNodes
mock_node1 = TextNode(
text="Content for doc 1.",
metadata={"document_id": 1, "title": "Document 1"},
)
mock_node2 = TextNode(
text="Content for doc 2.",
metadata={"document_id": 2, "title": "Document 2"},
)
mock_index = MagicMock()
mock_index.docstore.docs.values.return_value = [mock_node1, mock_node2]
mock_load_index.return_value = mock_index
# Patch as_retriever to return a retriever whose retrieve() returns mock_node1 and mock_node2
mock_retriever = MagicMock()
mock_retriever.retrieve.return_value = [mock_node1, mock_node2]
mock_as_retriever.return_value = mock_retriever
# Mock response stream
mock_response_stream = MagicMock()
mock_response_stream.response_gen = iter(["chunk1", "chunk2"])
# Mock RetrieverQueryEngine
mock_query_engine = MagicMock()
mock_query_engine_cls.return_value = mock_query_engine
mock_query_engine.query.return_value = mock_response_stream
# Fake documents
doc1 = MagicMock(pk=1)
doc2 = MagicMock(pk=2)
output = list(stream_chat_with_documents("What's up?", [doc1, doc2]))
assert output == ["chunk1", "chunk2"]
def test_stream_chat_no_matching_nodes():
with (
patch("paperless.ai.chat.AIClient") as mock_client_cls,
patch("paperless.ai.chat.load_index") as mock_load_index,
):
mock_client = MagicMock()
mock_client_cls.return_value = mock_client
mock_client.llm = MagicMock()
mock_index = MagicMock()
# No matching nodes
mock_index.docstore.docs.values.return_value = []
mock_load_index.return_value = mock_index
output = list(stream_chat_with_documents("Any info?", [MagicMock(pk=1)]))
assert output == ["Sorry, I couldn't find any content to answer your question."]