Refactor and consolidate rag / embedding and tests

This commit is contained in:
shamoon 2025-04-28 17:36:23 -07:00
parent cd4540412a
commit a1fb3ee7de
No known key found for this signature in database
8 changed files with 121 additions and 77 deletions

View File

@ -5,7 +5,7 @@ from llama_index.core.base.llms.types import CompletionResponse
from documents.models import Document
from paperless.ai.client import AIClient
from paperless.ai.rag import get_context_for_document
from paperless.ai.indexing import query_similar_documents
from paperless.config import AIConfig
logger = logging.getLogger("paperless.ai.rag_classifier")
@ -65,6 +65,16 @@ def build_prompt_with_rag(document: Document) -> str:
return prompt
def get_context_for_document(doc: Document, max_docs: int = 5) -> str:
similar_docs = query_similar_documents(doc)[:max_docs]
context_blocks = []
for similar in similar_docs:
text = similar.content or ""
title = similar.title or similar.filename or "Untitled"
context_blocks.append(f"TITLE: {title}\n{text}")
return "\n\n".join(context_blocks)
def parse_ai_response(response: CompletionResponse) -> dict:
try:
raw = json.loads(response.text)

View File

@ -1,3 +1,4 @@
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.embeddings.openai import OpenAIEmbedding
@ -12,7 +13,7 @@ EMBEDDING_DIMENSIONS = {
}
def get_embedding_model():
def get_embedding_model() -> BaseEmbedding:
config = AIConfig()
match config.llm_embedding_backend:

View File

@ -223,7 +223,10 @@ def query_similar_documents(document: Document, top_k: int = 5) -> list[Document
"""
Runs a similarity query and returns top-k similar Document objects.
"""
index = load_or_build_index()
storage_context = get_or_create_storage_context(rebuild=False)
embed_model = get_embedding_model()
llama_settings.embed_model = embed_model
index = load_or_build_index(storage_context, embed_model)
retriever = VectorIndexRetriever(index=index, similarity_top_k=top_k)
query_text = (document.title or "") + "\n" + (document.content or "")

View File

@ -1,12 +0,0 @@
from documents.models import Document
from paperless.ai.indexing import query_similar_documents
def get_context_for_document(doc: Document, max_docs: int = 5) -> str:
similar_docs = query_similar_documents(doc)[:max_docs]
context_blocks = []
for similar in similar_docs:
text = similar.content or ""
title = similar.title or similar.filename or "Untitled"
context_blocks.append(f"TITLE: {title}\n{text}")
return "\n\n".join(context_blocks)

View File

@ -56,7 +56,7 @@ def test_stream_chat_with_one_document_full_content(mock_document):
mock_node = TextNode(
text="This is node content.",
metadata={"document_id": mock_document.pk, "title": "Test Document"},
metadata={"document_id": str(mock_document.pk), "title": "Test Document"},
)
mock_index = MagicMock()
mock_index.docstore.docs.values.return_value = [mock_node]
@ -90,11 +90,11 @@ def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes):
# Create two real TextNodes
mock_node1 = TextNode(
text="Content for doc 1.",
metadata={"document_id": 1, "title": "Document 1"},
metadata={"document_id": "1", "title": "Document 1"},
)
mock_node2 = TextNode(
text="Content for doc 2.",
metadata={"document_id": 2, "title": "Document 2"},
metadata={"document_id": "2", "title": "Document 2"},
)
mock_index = MagicMock()
mock_index.docstore.docs.values.return_value = [mock_node1, mock_node2]

View File

@ -9,12 +9,43 @@ from documents.models import Document
from paperless.ai.ai_classifier import build_prompt_with_rag
from paperless.ai.ai_classifier import build_prompt_without_rag
from paperless.ai.ai_classifier import get_ai_document_classification
from paperless.ai.ai_classifier import get_context_for_document
from paperless.ai.ai_classifier import parse_ai_response
@pytest.fixture
def mock_document():
return Document(filename="test.pdf", content="This is a test document content.")
doc = MagicMock(spec=Document)
doc.title = "Test Title"
doc.filename = "test_file.pdf"
doc.created = "2023-01-01"
doc.added = "2023-01-02"
doc.modified = "2023-01-03"
tag1 = MagicMock()
tag1.name = "Tag1"
tag2 = MagicMock()
tag2.name = "Tag2"
doc.tags.all = MagicMock(return_value=[tag1, tag2])
doc.document_type = MagicMock()
doc.document_type.name = "Invoice"
doc.correspondent = MagicMock()
doc.correspondent.name = "Test Correspondent"
doc.archive_serial_number = "12345"
doc.content = "This is the document content."
cf1 = MagicMock(__str__=lambda x: "Value1")
cf1.field = MagicMock()
cf1.field.name = "Field1"
cf1.value = "Value1"
cf2 = MagicMock(__str__=lambda x: "Value2")
cf2.field = MagicMock()
cf2.field.name = "Field2"
cf2.value = "Value2"
doc.custom_fields.all = MagicMock(return_value=[cf1, cf2])
return doc
@pytest.mark.django_db
@ -105,13 +136,63 @@ def test_use_without_rag_if_not_configured(
mock_build_prompt_without_rag.assert_called_once()
@pytest.mark.django_db
@override_settings(
LLM_EMBEDDING_BACKEND="huggingface",
LLM_BACKEND="ollama",
LLM_MODEL="some_model",
)
def test_prompt_with_without_rag(mock_document):
prompt = build_prompt_without_rag(mock_document)
assert "CONTEXT FROM SIMILAR DOCUMENTS:" not in prompt
with patch(
"paperless.ai.ai_classifier.get_context_for_document",
return_value="Context from similar documents",
):
prompt = build_prompt_without_rag(mock_document)
assert "CONTEXT FROM SIMILAR DOCUMENTS:" not in prompt
prompt = build_prompt_with_rag(mock_document)
assert "CONTEXT FROM SIMILAR DOCUMENTS:" in prompt
prompt = build_prompt_with_rag(mock_document)
assert "CONTEXT FROM SIMILAR DOCUMENTS:" in prompt
@pytest.fixture
def mock_similar_documents():
doc1 = MagicMock()
doc1.content = "Content of document 1"
doc1.title = "Title 1"
doc1.filename = "file1.txt"
doc2 = MagicMock()
doc2.content = "Content of document 2"
doc2.title = None
doc2.filename = "file2.txt"
doc3 = MagicMock()
doc3.content = None
doc3.title = None
doc3.filename = None
return [doc1, doc2, doc3]
@patch("paperless.ai.ai_classifier.query_similar_documents")
def test_get_context_for_document(
mock_query_similar_documents,
mock_document,
mock_similar_documents,
):
mock_query_similar_documents.return_value = mock_similar_documents
result = get_context_for_document(mock_document, max_docs=2)
expected_result = (
"TITLE: Title 1\nContent of document 1\n\n"
"TITLE: file2.txt\nContent of document 2"
)
assert result == expected_result
mock_query_similar_documents.assert_called_once()
def test_get_context_for_document_no_similar_docs(mock_document):
with patch("paperless.ai.ai_classifier.query_similar_documents", return_value=[]):
result = get_context_for_document(mock_document)
assert result == ""

View File

@ -7,10 +7,15 @@ from documents.models import Document
from paperless.ai.embedding import build_llm_index_text
from paperless.ai.embedding import get_embedding_dim
from paperless.ai.embedding import get_embedding_model
from paperless.ai.rag import get_context_for_document
from paperless.models import LLMEmbeddingBackend
@pytest.fixture
def mock_ai_config():
with patch("paperless.ai.embedding.AIConfig") as MockAIConfig:
yield MockAIConfig
@pytest.fixture
def mock_document():
doc = MagicMock(spec=Document)
@ -46,59 +51,6 @@ def mock_document():
return doc
@pytest.fixture
def mock_similar_documents():
doc1 = MagicMock()
doc1.content = "Content of document 1"
doc1.title = "Title 1"
doc1.filename = "file1.txt"
doc2 = MagicMock()
doc2.content = "Content of document 2"
doc2.title = None
doc2.filename = "file2.txt"
doc3 = MagicMock()
doc3.content = None
doc3.title = None
doc3.filename = None
return [doc1, doc2, doc3]
@patch("paperless.ai.rag.query_similar_documents")
def test_get_context_for_document(
mock_query_similar_documents,
mock_document,
mock_similar_documents,
):
mock_query_similar_documents.return_value = mock_similar_documents
result = get_context_for_document(mock_document, max_docs=2)
expected_result = (
"TITLE: Title 1\nContent of document 1\n\n"
"TITLE: file2.txt\nContent of document 2"
)
assert result == expected_result
mock_query_similar_documents.assert_called_once()
def test_get_context_for_document_no_similar_docs(mock_document):
with patch("paperless.ai.rag.query_similar_documents", return_value=[]):
result = get_context_for_document(mock_document)
assert result == ""
# Embedding
@pytest.fixture
def mock_ai_config():
with patch("paperless.ai.embedding.AIConfig") as MockAIConfig:
yield MockAIConfig
def test_get_embedding_model_openai(mock_ai_config):
mock_ai_config.return_value.llm_embedding_backend = LLMEmbeddingBackend.OPENAI
mock_ai_config.return_value.llm_embedding_model = "text-embedding-3-small"

View File

@ -2,6 +2,7 @@ from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from django.test import override_settings
from django.utils import timezone
from llama_index.core.base.embeddings.base import BaseEmbedding
@ -162,15 +163,23 @@ def test_update_llm_index_no_documents(
)
@override_settings(
LLM_EMBEDDING_BACKEND="huggingface",
LLM_BACKEND="ollama",
)
def test_query_similar_documents(
temp_llm_index_dir,
real_document,
):
with (
patch("paperless.ai.indexing.get_or_create_storage_context") as mock_storage,
patch("paperless.ai.indexing.load_or_build_index") as mock_load_or_build_index,
patch("paperless.ai.indexing.VectorIndexRetriever") as mock_retriever_cls,
patch("paperless.ai.indexing.Document.objects.filter") as mock_filter,
):
mock_storage.return_value = MagicMock()
mock_storage.return_value.persist_dir = temp_llm_index_dir
mock_index = MagicMock()
mock_load_or_build_index.return_value = mock_index