mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-05-23 12:58:18 -05:00
Refactor and consolidate rag / embedding and tests
This commit is contained in:
parent
cd4540412a
commit
a1fb3ee7de
@ -5,7 +5,7 @@ from llama_index.core.base.llms.types import CompletionResponse
|
||||
|
||||
from documents.models import Document
|
||||
from paperless.ai.client import AIClient
|
||||
from paperless.ai.rag import get_context_for_document
|
||||
from paperless.ai.indexing import query_similar_documents
|
||||
from paperless.config import AIConfig
|
||||
|
||||
logger = logging.getLogger("paperless.ai.rag_classifier")
|
||||
@ -65,6 +65,16 @@ def build_prompt_with_rag(document: Document) -> str:
|
||||
return prompt
|
||||
|
||||
|
||||
def get_context_for_document(doc: Document, max_docs: int = 5) -> str:
|
||||
similar_docs = query_similar_documents(doc)[:max_docs]
|
||||
context_blocks = []
|
||||
for similar in similar_docs:
|
||||
text = similar.content or ""
|
||||
title = similar.title or similar.filename or "Untitled"
|
||||
context_blocks.append(f"TITLE: {title}\n{text}")
|
||||
return "\n\n".join(context_blocks)
|
||||
|
||||
|
||||
def parse_ai_response(response: CompletionResponse) -> dict:
|
||||
try:
|
||||
raw = json.loads(response.text)
|
||||
|
@ -1,3 +1,4 @@
|
||||
from llama_index.core.base.embeddings.base import BaseEmbedding
|
||||
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
|
||||
@ -12,7 +13,7 @@ EMBEDDING_DIMENSIONS = {
|
||||
}
|
||||
|
||||
|
||||
def get_embedding_model():
|
||||
def get_embedding_model() -> BaseEmbedding:
|
||||
config = AIConfig()
|
||||
|
||||
match config.llm_embedding_backend:
|
||||
|
@ -223,7 +223,10 @@ def query_similar_documents(document: Document, top_k: int = 5) -> list[Document
|
||||
"""
|
||||
Runs a similarity query and returns top-k similar Document objects.
|
||||
"""
|
||||
index = load_or_build_index()
|
||||
storage_context = get_or_create_storage_context(rebuild=False)
|
||||
embed_model = get_embedding_model()
|
||||
llama_settings.embed_model = embed_model
|
||||
index = load_or_build_index(storage_context, embed_model)
|
||||
retriever = VectorIndexRetriever(index=index, similarity_top_k=top_k)
|
||||
|
||||
query_text = (document.title or "") + "\n" + (document.content or "")
|
||||
|
@ -1,12 +0,0 @@
|
||||
from documents.models import Document
|
||||
from paperless.ai.indexing import query_similar_documents
|
||||
|
||||
|
||||
def get_context_for_document(doc: Document, max_docs: int = 5) -> str:
|
||||
similar_docs = query_similar_documents(doc)[:max_docs]
|
||||
context_blocks = []
|
||||
for similar in similar_docs:
|
||||
text = similar.content or ""
|
||||
title = similar.title or similar.filename or "Untitled"
|
||||
context_blocks.append(f"TITLE: {title}\n{text}")
|
||||
return "\n\n".join(context_blocks)
|
@ -56,7 +56,7 @@ def test_stream_chat_with_one_document_full_content(mock_document):
|
||||
|
||||
mock_node = TextNode(
|
||||
text="This is node content.",
|
||||
metadata={"document_id": mock_document.pk, "title": "Test Document"},
|
||||
metadata={"document_id": str(mock_document.pk), "title": "Test Document"},
|
||||
)
|
||||
mock_index = MagicMock()
|
||||
mock_index.docstore.docs.values.return_value = [mock_node]
|
||||
@ -90,11 +90,11 @@ def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes):
|
||||
# Create two real TextNodes
|
||||
mock_node1 = TextNode(
|
||||
text="Content for doc 1.",
|
||||
metadata={"document_id": 1, "title": "Document 1"},
|
||||
metadata={"document_id": "1", "title": "Document 1"},
|
||||
)
|
||||
mock_node2 = TextNode(
|
||||
text="Content for doc 2.",
|
||||
metadata={"document_id": 2, "title": "Document 2"},
|
||||
metadata={"document_id": "2", "title": "Document 2"},
|
||||
)
|
||||
mock_index = MagicMock()
|
||||
mock_index.docstore.docs.values.return_value = [mock_node1, mock_node2]
|
||||
|
@ -9,12 +9,43 @@ from documents.models import Document
|
||||
from paperless.ai.ai_classifier import build_prompt_with_rag
|
||||
from paperless.ai.ai_classifier import build_prompt_without_rag
|
||||
from paperless.ai.ai_classifier import get_ai_document_classification
|
||||
from paperless.ai.ai_classifier import get_context_for_document
|
||||
from paperless.ai.ai_classifier import parse_ai_response
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_document():
|
||||
return Document(filename="test.pdf", content="This is a test document content.")
|
||||
doc = MagicMock(spec=Document)
|
||||
doc.title = "Test Title"
|
||||
doc.filename = "test_file.pdf"
|
||||
doc.created = "2023-01-01"
|
||||
doc.added = "2023-01-02"
|
||||
doc.modified = "2023-01-03"
|
||||
|
||||
tag1 = MagicMock()
|
||||
tag1.name = "Tag1"
|
||||
tag2 = MagicMock()
|
||||
tag2.name = "Tag2"
|
||||
doc.tags.all = MagicMock(return_value=[tag1, tag2])
|
||||
|
||||
doc.document_type = MagicMock()
|
||||
doc.document_type.name = "Invoice"
|
||||
doc.correspondent = MagicMock()
|
||||
doc.correspondent.name = "Test Correspondent"
|
||||
doc.archive_serial_number = "12345"
|
||||
doc.content = "This is the document content."
|
||||
|
||||
cf1 = MagicMock(__str__=lambda x: "Value1")
|
||||
cf1.field = MagicMock()
|
||||
cf1.field.name = "Field1"
|
||||
cf1.value = "Value1"
|
||||
cf2 = MagicMock(__str__=lambda x: "Value2")
|
||||
cf2.field = MagicMock()
|
||||
cf2.field.name = "Field2"
|
||||
cf2.value = "Value2"
|
||||
doc.custom_fields.all = MagicMock(return_value=[cf1, cf2])
|
||||
|
||||
return doc
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@ -105,13 +136,63 @@ def test_use_without_rag_if_not_configured(
|
||||
mock_build_prompt_without_rag.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@override_settings(
|
||||
LLM_EMBEDDING_BACKEND="huggingface",
|
||||
LLM_BACKEND="ollama",
|
||||
LLM_MODEL="some_model",
|
||||
)
|
||||
def test_prompt_with_without_rag(mock_document):
|
||||
prompt = build_prompt_without_rag(mock_document)
|
||||
assert "CONTEXT FROM SIMILAR DOCUMENTS:" not in prompt
|
||||
with patch(
|
||||
"paperless.ai.ai_classifier.get_context_for_document",
|
||||
return_value="Context from similar documents",
|
||||
):
|
||||
prompt = build_prompt_without_rag(mock_document)
|
||||
assert "CONTEXT FROM SIMILAR DOCUMENTS:" not in prompt
|
||||
|
||||
prompt = build_prompt_with_rag(mock_document)
|
||||
assert "CONTEXT FROM SIMILAR DOCUMENTS:" in prompt
|
||||
prompt = build_prompt_with_rag(mock_document)
|
||||
assert "CONTEXT FROM SIMILAR DOCUMENTS:" in prompt
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_similar_documents():
|
||||
doc1 = MagicMock()
|
||||
doc1.content = "Content of document 1"
|
||||
doc1.title = "Title 1"
|
||||
doc1.filename = "file1.txt"
|
||||
|
||||
doc2 = MagicMock()
|
||||
doc2.content = "Content of document 2"
|
||||
doc2.title = None
|
||||
doc2.filename = "file2.txt"
|
||||
|
||||
doc3 = MagicMock()
|
||||
doc3.content = None
|
||||
doc3.title = None
|
||||
doc3.filename = None
|
||||
|
||||
return [doc1, doc2, doc3]
|
||||
|
||||
|
||||
@patch("paperless.ai.ai_classifier.query_similar_documents")
|
||||
def test_get_context_for_document(
|
||||
mock_query_similar_documents,
|
||||
mock_document,
|
||||
mock_similar_documents,
|
||||
):
|
||||
mock_query_similar_documents.return_value = mock_similar_documents
|
||||
|
||||
result = get_context_for_document(mock_document, max_docs=2)
|
||||
|
||||
expected_result = (
|
||||
"TITLE: Title 1\nContent of document 1\n\n"
|
||||
"TITLE: file2.txt\nContent of document 2"
|
||||
)
|
||||
assert result == expected_result
|
||||
mock_query_similar_documents.assert_called_once()
|
||||
|
||||
|
||||
def test_get_context_for_document_no_similar_docs(mock_document):
|
||||
with patch("paperless.ai.ai_classifier.query_similar_documents", return_value=[]):
|
||||
result = get_context_for_document(mock_document)
|
||||
assert result == ""
|
||||
|
@ -7,10 +7,15 @@ from documents.models import Document
|
||||
from paperless.ai.embedding import build_llm_index_text
|
||||
from paperless.ai.embedding import get_embedding_dim
|
||||
from paperless.ai.embedding import get_embedding_model
|
||||
from paperless.ai.rag import get_context_for_document
|
||||
from paperless.models import LLMEmbeddingBackend
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ai_config():
|
||||
with patch("paperless.ai.embedding.AIConfig") as MockAIConfig:
|
||||
yield MockAIConfig
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_document():
|
||||
doc = MagicMock(spec=Document)
|
||||
@ -46,59 +51,6 @@ def mock_document():
|
||||
return doc
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_similar_documents():
|
||||
doc1 = MagicMock()
|
||||
doc1.content = "Content of document 1"
|
||||
doc1.title = "Title 1"
|
||||
doc1.filename = "file1.txt"
|
||||
|
||||
doc2 = MagicMock()
|
||||
doc2.content = "Content of document 2"
|
||||
doc2.title = None
|
||||
doc2.filename = "file2.txt"
|
||||
|
||||
doc3 = MagicMock()
|
||||
doc3.content = None
|
||||
doc3.title = None
|
||||
doc3.filename = None
|
||||
|
||||
return [doc1, doc2, doc3]
|
||||
|
||||
|
||||
@patch("paperless.ai.rag.query_similar_documents")
|
||||
def test_get_context_for_document(
|
||||
mock_query_similar_documents,
|
||||
mock_document,
|
||||
mock_similar_documents,
|
||||
):
|
||||
mock_query_similar_documents.return_value = mock_similar_documents
|
||||
|
||||
result = get_context_for_document(mock_document, max_docs=2)
|
||||
|
||||
expected_result = (
|
||||
"TITLE: Title 1\nContent of document 1\n\n"
|
||||
"TITLE: file2.txt\nContent of document 2"
|
||||
)
|
||||
assert result == expected_result
|
||||
mock_query_similar_documents.assert_called_once()
|
||||
|
||||
|
||||
def test_get_context_for_document_no_similar_docs(mock_document):
|
||||
with patch("paperless.ai.rag.query_similar_documents", return_value=[]):
|
||||
result = get_context_for_document(mock_document)
|
||||
assert result == ""
|
||||
|
||||
|
||||
# Embedding
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ai_config():
|
||||
with patch("paperless.ai.embedding.AIConfig") as MockAIConfig:
|
||||
yield MockAIConfig
|
||||
|
||||
|
||||
def test_get_embedding_model_openai(mock_ai_config):
|
||||
mock_ai_config.return_value.llm_embedding_backend = LLMEmbeddingBackend.OPENAI
|
||||
mock_ai_config.return_value.llm_embedding_model = "text-embedding-3-small"
|
@ -2,6 +2,7 @@ from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from django.test import override_settings
|
||||
from django.utils import timezone
|
||||
from llama_index.core.base.embeddings.base import BaseEmbedding
|
||||
|
||||
@ -162,15 +163,23 @@ def test_update_llm_index_no_documents(
|
||||
)
|
||||
|
||||
|
||||
@override_settings(
|
||||
LLM_EMBEDDING_BACKEND="huggingface",
|
||||
LLM_BACKEND="ollama",
|
||||
)
|
||||
def test_query_similar_documents(
|
||||
temp_llm_index_dir,
|
||||
real_document,
|
||||
):
|
||||
with (
|
||||
patch("paperless.ai.indexing.get_or_create_storage_context") as mock_storage,
|
||||
patch("paperless.ai.indexing.load_or_build_index") as mock_load_or_build_index,
|
||||
patch("paperless.ai.indexing.VectorIndexRetriever") as mock_retriever_cls,
|
||||
patch("paperless.ai.indexing.Document.objects.filter") as mock_filter,
|
||||
):
|
||||
mock_storage.return_value = MagicMock()
|
||||
mock_storage.return_value.persist_dir = temp_llm_index_dir
|
||||
|
||||
mock_index = MagicMock()
|
||||
mock_load_or_build_index.return_value = mock_index
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user