Support dynamic determining of embedding dimensions

This commit is contained in:
shamoon
2025-08-08 22:41:14 -04:00
parent 541108688a
commit 7bd9b385aa
5 changed files with 121 additions and 19 deletions

View File

@@ -1,3 +1,4 @@
import json
from unittest.mock import MagicMock
from unittest.mock import patch
@@ -29,9 +30,16 @@ def real_document(db):
@pytest.fixture
def mock_embed_model():
with patch("paperless_ai.indexing.get_embedding_model") as mock:
mock.return_value = FakeEmbedding()
yield mock
fake = FakeEmbedding()
with (
patch("paperless_ai.indexing.get_embedding_model") as mock_index,
patch(
"paperless_ai.embedding.get_embedding_model",
) as mock_embedding,
):
mock_index.return_value = fake
mock_embedding.return_value = fake
yield mock_index
class FakeEmbedding(BaseEmbedding):
@@ -72,6 +80,36 @@ def test_update_llm_index(
assert any(temp_llm_index_dir.glob("*.json"))
@pytest.mark.django_db
def test_update_llm_index_removes_meta(
temp_llm_index_dir,
real_document,
mock_embed_model,
):
# Pre-create a meta.json with incorrect data
(temp_llm_index_dir / "meta.json").write_text(
json.dumps({"embedding_model": "old", "dim": 1}),
)
with patch("documents.models.Document.objects.all") as mock_all:
mock_queryset = MagicMock()
mock_queryset.exists.return_value = True
mock_queryset.__iter__.return_value = iter([real_document])
mock_all.return_value = mock_queryset
indexing.update_llm_index(rebuild=True)
meta = json.loads((temp_llm_index_dir / "meta.json").read_text())
from paperless.config import AIConfig
config = AIConfig()
expected_model = config.llm_embedding_model or (
"text-embedding-3-small"
if config.llm_embedding_backend == "openai"
else "sentence-transformers/all-MiniLM-L6-v2"
)
assert meta == {"embedding_model": expected_model, "dim": 384}
@pytest.mark.django_db
def test_update_llm_index_partial_update(
temp_llm_index_dir,
@@ -137,6 +175,7 @@ def test_get_or_create_storage_context_raises_exception(
def test_load_or_build_index_builds_when_nodes_given(
temp_llm_index_dir,
real_document,
mock_embed_model,
):
with (
patch(