mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-05-23 12:58:18 -05:00
Refactor and consolidate rag / embedding and tests
This commit is contained in:
parent
cd4540412a
commit
a1fb3ee7de
@ -5,7 +5,7 @@ from llama_index.core.base.llms.types import CompletionResponse
|
|||||||
|
|
||||||
from documents.models import Document
|
from documents.models import Document
|
||||||
from paperless.ai.client import AIClient
|
from paperless.ai.client import AIClient
|
||||||
from paperless.ai.rag import get_context_for_document
|
from paperless.ai.indexing import query_similar_documents
|
||||||
from paperless.config import AIConfig
|
from paperless.config import AIConfig
|
||||||
|
|
||||||
logger = logging.getLogger("paperless.ai.rag_classifier")
|
logger = logging.getLogger("paperless.ai.rag_classifier")
|
||||||
@ -65,6 +65,16 @@ def build_prompt_with_rag(document: Document) -> str:
|
|||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
|
def get_context_for_document(doc: Document, max_docs: int = 5) -> str:
|
||||||
|
similar_docs = query_similar_documents(doc)[:max_docs]
|
||||||
|
context_blocks = []
|
||||||
|
for similar in similar_docs:
|
||||||
|
text = similar.content or ""
|
||||||
|
title = similar.title or similar.filename or "Untitled"
|
||||||
|
context_blocks.append(f"TITLE: {title}\n{text}")
|
||||||
|
return "\n\n".join(context_blocks)
|
||||||
|
|
||||||
|
|
||||||
def parse_ai_response(response: CompletionResponse) -> dict:
|
def parse_ai_response(response: CompletionResponse) -> dict:
|
||||||
try:
|
try:
|
||||||
raw = json.loads(response.text)
|
raw = json.loads(response.text)
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
from llama_index.core.base.embeddings.base import BaseEmbedding
|
||||||
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
||||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||||
|
|
||||||
@ -12,7 +13,7 @@ EMBEDDING_DIMENSIONS = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_embedding_model():
|
def get_embedding_model() -> BaseEmbedding:
|
||||||
config = AIConfig()
|
config = AIConfig()
|
||||||
|
|
||||||
match config.llm_embedding_backend:
|
match config.llm_embedding_backend:
|
||||||
|
@ -223,7 +223,10 @@ def query_similar_documents(document: Document, top_k: int = 5) -> list[Document
|
|||||||
"""
|
"""
|
||||||
Runs a similarity query and returns top-k similar Document objects.
|
Runs a similarity query and returns top-k similar Document objects.
|
||||||
"""
|
"""
|
||||||
index = load_or_build_index()
|
storage_context = get_or_create_storage_context(rebuild=False)
|
||||||
|
embed_model = get_embedding_model()
|
||||||
|
llama_settings.embed_model = embed_model
|
||||||
|
index = load_or_build_index(storage_context, embed_model)
|
||||||
retriever = VectorIndexRetriever(index=index, similarity_top_k=top_k)
|
retriever = VectorIndexRetriever(index=index, similarity_top_k=top_k)
|
||||||
|
|
||||||
query_text = (document.title or "") + "\n" + (document.content or "")
|
query_text = (document.title or "") + "\n" + (document.content or "")
|
||||||
|
@ -1,12 +0,0 @@
|
|||||||
from documents.models import Document
|
|
||||||
from paperless.ai.indexing import query_similar_documents
|
|
||||||
|
|
||||||
|
|
||||||
def get_context_for_document(doc: Document, max_docs: int = 5) -> str:
|
|
||||||
similar_docs = query_similar_documents(doc)[:max_docs]
|
|
||||||
context_blocks = []
|
|
||||||
for similar in similar_docs:
|
|
||||||
text = similar.content or ""
|
|
||||||
title = similar.title or similar.filename or "Untitled"
|
|
||||||
context_blocks.append(f"TITLE: {title}\n{text}")
|
|
||||||
return "\n\n".join(context_blocks)
|
|
@ -56,7 +56,7 @@ def test_stream_chat_with_one_document_full_content(mock_document):
|
|||||||
|
|
||||||
mock_node = TextNode(
|
mock_node = TextNode(
|
||||||
text="This is node content.",
|
text="This is node content.",
|
||||||
metadata={"document_id": mock_document.pk, "title": "Test Document"},
|
metadata={"document_id": str(mock_document.pk), "title": "Test Document"},
|
||||||
)
|
)
|
||||||
mock_index = MagicMock()
|
mock_index = MagicMock()
|
||||||
mock_index.docstore.docs.values.return_value = [mock_node]
|
mock_index.docstore.docs.values.return_value = [mock_node]
|
||||||
@ -90,11 +90,11 @@ def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes):
|
|||||||
# Create two real TextNodes
|
# Create two real TextNodes
|
||||||
mock_node1 = TextNode(
|
mock_node1 = TextNode(
|
||||||
text="Content for doc 1.",
|
text="Content for doc 1.",
|
||||||
metadata={"document_id": 1, "title": "Document 1"},
|
metadata={"document_id": "1", "title": "Document 1"},
|
||||||
)
|
)
|
||||||
mock_node2 = TextNode(
|
mock_node2 = TextNode(
|
||||||
text="Content for doc 2.",
|
text="Content for doc 2.",
|
||||||
metadata={"document_id": 2, "title": "Document 2"},
|
metadata={"document_id": "2", "title": "Document 2"},
|
||||||
)
|
)
|
||||||
mock_index = MagicMock()
|
mock_index = MagicMock()
|
||||||
mock_index.docstore.docs.values.return_value = [mock_node1, mock_node2]
|
mock_index.docstore.docs.values.return_value = [mock_node1, mock_node2]
|
||||||
|
@ -9,12 +9,43 @@ from documents.models import Document
|
|||||||
from paperless.ai.ai_classifier import build_prompt_with_rag
|
from paperless.ai.ai_classifier import build_prompt_with_rag
|
||||||
from paperless.ai.ai_classifier import build_prompt_without_rag
|
from paperless.ai.ai_classifier import build_prompt_without_rag
|
||||||
from paperless.ai.ai_classifier import get_ai_document_classification
|
from paperless.ai.ai_classifier import get_ai_document_classification
|
||||||
|
from paperless.ai.ai_classifier import get_context_for_document
|
||||||
from paperless.ai.ai_classifier import parse_ai_response
|
from paperless.ai.ai_classifier import parse_ai_response
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_document():
|
def mock_document():
|
||||||
return Document(filename="test.pdf", content="This is a test document content.")
|
doc = MagicMock(spec=Document)
|
||||||
|
doc.title = "Test Title"
|
||||||
|
doc.filename = "test_file.pdf"
|
||||||
|
doc.created = "2023-01-01"
|
||||||
|
doc.added = "2023-01-02"
|
||||||
|
doc.modified = "2023-01-03"
|
||||||
|
|
||||||
|
tag1 = MagicMock()
|
||||||
|
tag1.name = "Tag1"
|
||||||
|
tag2 = MagicMock()
|
||||||
|
tag2.name = "Tag2"
|
||||||
|
doc.tags.all = MagicMock(return_value=[tag1, tag2])
|
||||||
|
|
||||||
|
doc.document_type = MagicMock()
|
||||||
|
doc.document_type.name = "Invoice"
|
||||||
|
doc.correspondent = MagicMock()
|
||||||
|
doc.correspondent.name = "Test Correspondent"
|
||||||
|
doc.archive_serial_number = "12345"
|
||||||
|
doc.content = "This is the document content."
|
||||||
|
|
||||||
|
cf1 = MagicMock(__str__=lambda x: "Value1")
|
||||||
|
cf1.field = MagicMock()
|
||||||
|
cf1.field.name = "Field1"
|
||||||
|
cf1.value = "Value1"
|
||||||
|
cf2 = MagicMock(__str__=lambda x: "Value2")
|
||||||
|
cf2.field = MagicMock()
|
||||||
|
cf2.field.name = "Field2"
|
||||||
|
cf2.value = "Value2"
|
||||||
|
doc.custom_fields.all = MagicMock(return_value=[cf1, cf2])
|
||||||
|
|
||||||
|
return doc
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.django_db
|
@pytest.mark.django_db
|
||||||
@ -105,13 +136,63 @@ def test_use_without_rag_if_not_configured(
|
|||||||
mock_build_prompt_without_rag.assert_called_once()
|
mock_build_prompt_without_rag.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.django_db
|
||||||
@override_settings(
|
@override_settings(
|
||||||
|
LLM_EMBEDDING_BACKEND="huggingface",
|
||||||
LLM_BACKEND="ollama",
|
LLM_BACKEND="ollama",
|
||||||
LLM_MODEL="some_model",
|
LLM_MODEL="some_model",
|
||||||
)
|
)
|
||||||
def test_prompt_with_without_rag(mock_document):
|
def test_prompt_with_without_rag(mock_document):
|
||||||
prompt = build_prompt_without_rag(mock_document)
|
with patch(
|
||||||
assert "CONTEXT FROM SIMILAR DOCUMENTS:" not in prompt
|
"paperless.ai.ai_classifier.get_context_for_document",
|
||||||
|
return_value="Context from similar documents",
|
||||||
|
):
|
||||||
|
prompt = build_prompt_without_rag(mock_document)
|
||||||
|
assert "CONTEXT FROM SIMILAR DOCUMENTS:" not in prompt
|
||||||
|
|
||||||
prompt = build_prompt_with_rag(mock_document)
|
prompt = build_prompt_with_rag(mock_document)
|
||||||
assert "CONTEXT FROM SIMILAR DOCUMENTS:" in prompt
|
assert "CONTEXT FROM SIMILAR DOCUMENTS:" in prompt
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_similar_documents():
|
||||||
|
doc1 = MagicMock()
|
||||||
|
doc1.content = "Content of document 1"
|
||||||
|
doc1.title = "Title 1"
|
||||||
|
doc1.filename = "file1.txt"
|
||||||
|
|
||||||
|
doc2 = MagicMock()
|
||||||
|
doc2.content = "Content of document 2"
|
||||||
|
doc2.title = None
|
||||||
|
doc2.filename = "file2.txt"
|
||||||
|
|
||||||
|
doc3 = MagicMock()
|
||||||
|
doc3.content = None
|
||||||
|
doc3.title = None
|
||||||
|
doc3.filename = None
|
||||||
|
|
||||||
|
return [doc1, doc2, doc3]
|
||||||
|
|
||||||
|
|
||||||
|
@patch("paperless.ai.ai_classifier.query_similar_documents")
|
||||||
|
def test_get_context_for_document(
|
||||||
|
mock_query_similar_documents,
|
||||||
|
mock_document,
|
||||||
|
mock_similar_documents,
|
||||||
|
):
|
||||||
|
mock_query_similar_documents.return_value = mock_similar_documents
|
||||||
|
|
||||||
|
result = get_context_for_document(mock_document, max_docs=2)
|
||||||
|
|
||||||
|
expected_result = (
|
||||||
|
"TITLE: Title 1\nContent of document 1\n\n"
|
||||||
|
"TITLE: file2.txt\nContent of document 2"
|
||||||
|
)
|
||||||
|
assert result == expected_result
|
||||||
|
mock_query_similar_documents.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_context_for_document_no_similar_docs(mock_document):
|
||||||
|
with patch("paperless.ai.ai_classifier.query_similar_documents", return_value=[]):
|
||||||
|
result = get_context_for_document(mock_document)
|
||||||
|
assert result == ""
|
||||||
|
@ -7,10 +7,15 @@ from documents.models import Document
|
|||||||
from paperless.ai.embedding import build_llm_index_text
|
from paperless.ai.embedding import build_llm_index_text
|
||||||
from paperless.ai.embedding import get_embedding_dim
|
from paperless.ai.embedding import get_embedding_dim
|
||||||
from paperless.ai.embedding import get_embedding_model
|
from paperless.ai.embedding import get_embedding_model
|
||||||
from paperless.ai.rag import get_context_for_document
|
|
||||||
from paperless.models import LLMEmbeddingBackend
|
from paperless.models import LLMEmbeddingBackend
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_ai_config():
|
||||||
|
with patch("paperless.ai.embedding.AIConfig") as MockAIConfig:
|
||||||
|
yield MockAIConfig
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_document():
|
def mock_document():
|
||||||
doc = MagicMock(spec=Document)
|
doc = MagicMock(spec=Document)
|
||||||
@ -46,59 +51,6 @@ def mock_document():
|
|||||||
return doc
|
return doc
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_similar_documents():
|
|
||||||
doc1 = MagicMock()
|
|
||||||
doc1.content = "Content of document 1"
|
|
||||||
doc1.title = "Title 1"
|
|
||||||
doc1.filename = "file1.txt"
|
|
||||||
|
|
||||||
doc2 = MagicMock()
|
|
||||||
doc2.content = "Content of document 2"
|
|
||||||
doc2.title = None
|
|
||||||
doc2.filename = "file2.txt"
|
|
||||||
|
|
||||||
doc3 = MagicMock()
|
|
||||||
doc3.content = None
|
|
||||||
doc3.title = None
|
|
||||||
doc3.filename = None
|
|
||||||
|
|
||||||
return [doc1, doc2, doc3]
|
|
||||||
|
|
||||||
|
|
||||||
@patch("paperless.ai.rag.query_similar_documents")
|
|
||||||
def test_get_context_for_document(
|
|
||||||
mock_query_similar_documents,
|
|
||||||
mock_document,
|
|
||||||
mock_similar_documents,
|
|
||||||
):
|
|
||||||
mock_query_similar_documents.return_value = mock_similar_documents
|
|
||||||
|
|
||||||
result = get_context_for_document(mock_document, max_docs=2)
|
|
||||||
|
|
||||||
expected_result = (
|
|
||||||
"TITLE: Title 1\nContent of document 1\n\n"
|
|
||||||
"TITLE: file2.txt\nContent of document 2"
|
|
||||||
)
|
|
||||||
assert result == expected_result
|
|
||||||
mock_query_similar_documents.assert_called_once()
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_context_for_document_no_similar_docs(mock_document):
|
|
||||||
with patch("paperless.ai.rag.query_similar_documents", return_value=[]):
|
|
||||||
result = get_context_for_document(mock_document)
|
|
||||||
assert result == ""
|
|
||||||
|
|
||||||
|
|
||||||
# Embedding
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_ai_config():
|
|
||||||
with patch("paperless.ai.embedding.AIConfig") as MockAIConfig:
|
|
||||||
yield MockAIConfig
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_embedding_model_openai(mock_ai_config):
|
def test_get_embedding_model_openai(mock_ai_config):
|
||||||
mock_ai_config.return_value.llm_embedding_backend = LLMEmbeddingBackend.OPENAI
|
mock_ai_config.return_value.llm_embedding_backend = LLMEmbeddingBackend.OPENAI
|
||||||
mock_ai_config.return_value.llm_embedding_model = "text-embedding-3-small"
|
mock_ai_config.return_value.llm_embedding_model = "text-embedding-3-small"
|
@ -2,6 +2,7 @@ from unittest.mock import MagicMock
|
|||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from django.test import override_settings
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
from llama_index.core.base.embeddings.base import BaseEmbedding
|
from llama_index.core.base.embeddings.base import BaseEmbedding
|
||||||
|
|
||||||
@ -162,15 +163,23 @@ def test_update_llm_index_no_documents(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@override_settings(
|
||||||
|
LLM_EMBEDDING_BACKEND="huggingface",
|
||||||
|
LLM_BACKEND="ollama",
|
||||||
|
)
|
||||||
def test_query_similar_documents(
|
def test_query_similar_documents(
|
||||||
temp_llm_index_dir,
|
temp_llm_index_dir,
|
||||||
real_document,
|
real_document,
|
||||||
):
|
):
|
||||||
with (
|
with (
|
||||||
|
patch("paperless.ai.indexing.get_or_create_storage_context") as mock_storage,
|
||||||
patch("paperless.ai.indexing.load_or_build_index") as mock_load_or_build_index,
|
patch("paperless.ai.indexing.load_or_build_index") as mock_load_or_build_index,
|
||||||
patch("paperless.ai.indexing.VectorIndexRetriever") as mock_retriever_cls,
|
patch("paperless.ai.indexing.VectorIndexRetriever") as mock_retriever_cls,
|
||||||
patch("paperless.ai.indexing.Document.objects.filter") as mock_filter,
|
patch("paperless.ai.indexing.Document.objects.filter") as mock_filter,
|
||||||
):
|
):
|
||||||
|
mock_storage.return_value = MagicMock()
|
||||||
|
mock_storage.return_value.persist_dir = temp_llm_index_dir
|
||||||
|
|
||||||
mock_index = MagicMock()
|
mock_index = MagicMock()
|
||||||
mock_load_or_build_index.return_value = mock_index
|
mock_load_or_build_index.return_value = mock_index
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user