mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-05-23 12:58:18 -05:00
Chat coverage
This commit is contained in:
parent
c02d9249e7
commit
2cafe4a2c0
@ -37,7 +37,8 @@ def stream_chat_with_documents(query_str: str, documents: list[Document]):
|
|||||||
|
|
||||||
if len(nodes) == 0:
|
if len(nodes) == 0:
|
||||||
logger.warning("No nodes found for the given documents.")
|
logger.warning("No nodes found for the given documents.")
|
||||||
return "Sorry, I couldn't find any content to answer your question."
|
yield "Sorry, I couldn't find any content to answer your question."
|
||||||
|
return
|
||||||
|
|
||||||
local_index = VectorStoreIndex(nodes=nodes)
|
local_index = VectorStoreIndex(nodes=nodes)
|
||||||
retriever = local_index.as_retriever(
|
retriever = local_index.as_retriever(
|
||||||
|
142
src/paperless/tests/test_ai_chat.py
Normal file
142
src/paperless/tests/test_ai_chat.py
Normal file
@ -0,0 +1,142 @@
|
|||||||
|
from unittest.mock import MagicMock
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from llama_index.core import VectorStoreIndex
|
||||||
|
from llama_index.core.schema import TextNode
|
||||||
|
|
||||||
|
from paperless.ai.chat import stream_chat_with_documents
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
def patch_embed_model():
    """Install a mock embedding model on the llama_index ``Settings`` object.

    Runs automatically for every test. The mock's
    ``_get_text_embedding_batch`` is configured to return one fake
    1536-element vector. On teardown the embed model that was set before
    the test is put back, so the override does not leak into (or wipe out
    configuration needed by) other tests.
    """
    from llama_index.core import settings as llama_settings

    mock_embed_model = MagicMock()
    # NOTE: return_value always yields exactly one vector, regardless of how
    # many inputs the caller passes.
    mock_embed_model._get_text_embedding_batch.return_value = [
        [0.1] * 1536,
    ]

    # Remember whatever was configured so teardown restores it instead of
    # unconditionally resetting to None.
    previous_embed_model = llama_settings.Settings._embed_model
    llama_settings.Settings._embed_model = mock_embed_model
    try:
        yield
    finally:
        llama_settings.Settings._embed_model = previous_embed_model
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
def patch_embed_nodes():
    """Swap out ``embed_nodes`` in llama_index's vector store index module.

    Runs automatically for every test. The replacement maps each node's
    ``node_id`` to a constant 1536-element vector and ignores any other
    arguments it is called with.
    """

    def fake_embed_nodes(nodes, *_args, **_kwargs):
        # Same constant vector for every node, keyed by node id.
        return {node.node_id: [0.1] * 1536 for node in nodes}

    target = "llama_index.core.indices.vector_store.base.embed_nodes"
    with patch(target) as mock_embed_nodes:
        mock_embed_nodes.side_effect = fake_embed_nodes
        yield
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_document():
    """Return a ``MagicMock`` with ``pk``, ``title``, ``filename`` and ``content`` set."""
    return MagicMock(
        pk=1,
        title="Test Document",
        filename="test_file.pdf",
        content="This is the document content.",
    )
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_chat_with_one_document_full_content(mock_document):
    """With one document, the generator yields the mocked response stream's chunks."""
    expected_chunks = ["chunk1", "chunk2"]

    # A real TextNode tagged with the document's primary key.
    node = TextNode(
        text="This is node content.",
        metadata={"document_id": mock_document.pk, "title": "Test Document"},
    )

    with (
        patch("paperless.ai.chat.AIClient") as mock_client_cls,
        patch("paperless.ai.chat.load_index") as mock_load_index,
        patch(
            "paperless.ai.chat.RetrieverQueryEngine.from_args",
        ) as mock_query_engine_cls,
    ):
        # AIClient() hands back a client exposing a mock LLM.
        mock_client_cls.return_value = MagicMock(llm=MagicMock())

        # The loaded index's docstore contains only our node.
        index = MagicMock()
        index.docstore.docs.values.return_value = [node]
        mock_load_index.return_value = index

        # The query engine answers with a stream of two chunks.
        response_stream = MagicMock()
        response_stream.response_gen = iter(["chunk1", "chunk2"])
        query_engine = MagicMock()
        query_engine.query.return_value = response_stream
        mock_query_engine_cls.return_value = query_engine

        output = list(stream_chat_with_documents("What is this?", [mock_document]))

        assert output == expected_chunks
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes):
    """With two documents, the generator yields the mocked response stream's chunks."""
    expected_chunks = ["chunk1", "chunk2"]

    # Two real TextNodes, one per fake document.
    node_one = TextNode(
        text="Content for doc 1.",
        metadata={"document_id": 1, "title": "Document 1"},
    )
    node_two = TextNode(
        text="Content for doc 2.",
        metadata={"document_id": 2, "title": "Document 2"},
    )
    all_nodes = [node_one, node_two]

    # Fake documents whose primary keys match the node metadata.
    documents = [MagicMock(pk=1), MagicMock(pk=2)]

    with (
        patch("paperless.ai.chat.AIClient") as mock_client_cls,
        patch("paperless.ai.chat.load_index") as mock_load_index,
        patch(
            "paperless.ai.chat.RetrieverQueryEngine.from_args",
        ) as mock_query_engine_cls,
        patch.object(VectorStoreIndex, "as_retriever") as mock_as_retriever,
    ):
        # AIClient() hands back a client exposing a mock LLM.
        mock_client_cls.return_value = MagicMock(llm=MagicMock())

        # The loaded index's docstore contains both nodes.
        index = MagicMock()
        index.docstore.docs.values.return_value = list(all_nodes)
        mock_load_index.return_value = index

        # as_retriever() gives a retriever whose retrieve() returns both nodes.
        retriever = MagicMock()
        retriever.retrieve.return_value = list(all_nodes)
        mock_as_retriever.return_value = retriever

        # The query engine answers with a stream of two chunks.
        response_stream = MagicMock()
        response_stream.response_gen = iter(["chunk1", "chunk2"])
        query_engine = MagicMock()
        query_engine.query.return_value = response_stream
        mock_query_engine_cls.return_value = query_engine

        output = list(stream_chat_with_documents("What's up?", documents))

        assert output == expected_chunks
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_chat_no_matching_nodes():
    """An empty docstore makes the generator yield only the apology message."""
    apology = "Sorry, I couldn't find any content to answer your question."

    with (
        patch("paperless.ai.chat.AIClient") as mock_client_cls,
        patch("paperless.ai.chat.load_index") as mock_load_index,
    ):
        # AIClient() hands back a client exposing a mock LLM.
        mock_client_cls.return_value = MagicMock(llm=MagicMock())

        # The loaded index has no nodes at all, so nothing can match.
        empty_index = MagicMock()
        empty_index.docstore.docs.values.return_value = []
        mock_load_index.return_value = empty_index

        output = list(stream_chat_with_documents("Any info?", [MagicMock(pk=1)]))

        assert output == [apology]
|
Loading…
x
Reference in New Issue
Block a user