Support dynamic determining of embedding dimensions

This commit is contained in:
shamoon
2025-08-08 22:41:14 -04:00
parent 541108688a
commit 7bd9b385aa
5 changed files with 121 additions and 19 deletions

View File

@@ -570,6 +570,8 @@ def llmindex_index(
task.date_done = timezone.now()
task.save(update_fields=["status", "result", "date_done"])
else:
logger.info("LLM index is disabled, skipping update.")
@shared_task

View File

@@ -1,3 +1,10 @@
import json
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from pathlib import Path
from django.conf import settings
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.embeddings.openai import OpenAIEmbedding
@@ -7,11 +14,6 @@ from documents.models import Note
from paperless.config import AIConfig
from paperless.models import LLMEmbeddingBackend
EMBEDDING_DIMENSIONS = {
"text-embedding-3-small": 1536,
"sentence-transformers/all-MiniLM-L6-v2": 384,
}
def get_embedding_model() -> BaseEmbedding:
config = AIConfig()
@@ -34,15 +36,36 @@ def get_embedding_model() -> BaseEmbedding:
def get_embedding_dim() -> int:
"""
Loads embedding dimension from meta.json if available, otherwise infers it
from a dummy embedding and stores it for future use.
"""
config = AIConfig()
model = config.llm_embedding_model or (
"text-embedding-3-small"
if config.llm_embedding_backend == "openai"
else "sentence-transformers/all-MiniLM-L6-v2"
)
if model not in EMBEDDING_DIMENSIONS:
raise ValueError(f"Unknown embedding model: {model}")
return EMBEDDING_DIMENSIONS[model]
meta_path: Path = settings.LLM_INDEX_DIR / "meta.json"
if meta_path.exists():
with meta_path.open() as f:
meta = json.load(f)
if meta.get("embedding_model") != model:
raise RuntimeError(
f"Embedding model changed from {meta.get('embedding_model')} to {model}. "
"You must rebuild the index.",
)
return meta["dim"]
embedding_model = get_embedding_model()
test_embed = embedding_model.get_text_embedding("test")
dim = len(test_embed)
with meta_path.open("w") as f:
json.dump({"embedding_model": model, "dim": dim}, f)
return dim
def build_llm_index_text(doc: Document) -> str:

View File

@@ -138,6 +138,8 @@ def update_llm_index(*, progress_bar_disable=False, rebuild=False) -> str:
return msg
if rebuild or not vector_store_file_exists():
# remove meta.json to force re-detection of embedding dim
(settings.LLM_INDEX_DIR / "meta.json").unlink(missing_ok=True)
# Rebuild index from scratch
logger.info("Rebuilding LLM index.")
embed_model = get_embedding_model()

View File

@@ -1,3 +1,4 @@
import json
from unittest.mock import MagicMock
from unittest.mock import patch
@@ -29,9 +30,16 @@ def real_document(db):
@pytest.fixture
def mock_embed_model():
with patch("paperless_ai.indexing.get_embedding_model") as mock:
mock.return_value = FakeEmbedding()
yield mock
fake = FakeEmbedding()
with (
patch("paperless_ai.indexing.get_embedding_model") as mock_index,
patch(
"paperless_ai.embedding.get_embedding_model",
) as mock_embedding,
):
mock_index.return_value = fake
mock_embedding.return_value = fake
yield mock_index
class FakeEmbedding(BaseEmbedding):
@@ -72,6 +80,36 @@ def test_update_llm_index(
assert any(temp_llm_index_dir.glob("*.json"))
@pytest.mark.django_db
def test_update_llm_index_removes_meta(
temp_llm_index_dir,
real_document,
mock_embed_model,
):
# Pre-create a meta.json with incorrect data
(temp_llm_index_dir / "meta.json").write_text(
json.dumps({"embedding_model": "old", "dim": 1}),
)
with patch("documents.models.Document.objects.all") as mock_all:
mock_queryset = MagicMock()
mock_queryset.exists.return_value = True
mock_queryset.__iter__.return_value = iter([real_document])
mock_all.return_value = mock_queryset
indexing.update_llm_index(rebuild=True)
meta = json.loads((temp_llm_index_dir / "meta.json").read_text())
from paperless.config import AIConfig
config = AIConfig()
expected_model = config.llm_embedding_model or (
"text-embedding-3-small"
if config.llm_embedding_backend == "openai"
else "sentence-transformers/all-MiniLM-L6-v2"
)
assert meta == {"embedding_model": expected_model, "dim": 384}
@pytest.mark.django_db
def test_update_llm_index_partial_update(
temp_llm_index_dir,
@@ -137,6 +175,7 @@ def test_get_or_create_storage_context_raises_exception(
def test_load_or_build_index_builds_when_nodes_given(
temp_llm_index_dir,
real_document,
mock_embed_model,
):
with (
patch(

View File

@@ -1,8 +1,10 @@
import json
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
import paperless_ai.embedding as embedding
from documents.models import Document
from paperless.models import LLMEmbeddingBackend
from paperless_ai.embedding import build_llm_index_text
@@ -16,6 +18,14 @@ def mock_ai_config():
yield MockAIConfig
@pytest.fixture
def temp_llm_index_dir(tmp_path):
original_dir = embedding.settings.LLM_INDEX_DIR
embedding.settings.LLM_INDEX_DIR = tmp_path
yield tmp_path
embedding.settings.LLM_INDEX_DIR = original_dir
@pytest.fixture
def mock_document():
doc = MagicMock(spec=Document)
@@ -91,25 +101,51 @@ def test_get_embedding_model_invalid_backend(mock_ai_config):
get_embedding_model()
def test_get_embedding_dim_openai(mock_ai_config):
def test_get_embedding_dim_infers_and_saves(temp_llm_index_dir, mock_ai_config):
mock_ai_config.return_value.llm_embedding_backend = "openai"
mock_ai_config.return_value.llm_embedding_model = None
assert get_embedding_dim() == 1536
class DummyEmbedding:
def get_text_embedding(self, text):
return [0.0] * 7
with patch(
"paperless_ai.embedding.get_embedding_model",
return_value=DummyEmbedding(),
) as mock_get:
dim = get_embedding_dim()
mock_get.assert_called_once()
assert dim == 7
meta = json.loads((temp_llm_index_dir / "meta.json").read_text())
assert meta == {"embedding_model": "text-embedding-3-small", "dim": 7}
def test_get_embedding_dim_huggingface(mock_ai_config):
mock_ai_config.return_value.llm_embedding_backend = "huggingface"
def test_get_embedding_dim_reads_existing_meta(temp_llm_index_dir, mock_ai_config):
mock_ai_config.return_value.llm_embedding_backend = "openai"
mock_ai_config.return_value.llm_embedding_model = None
assert get_embedding_dim() == 384
(temp_llm_index_dir / "meta.json").write_text(
json.dumps({"embedding_model": "text-embedding-3-small", "dim": 11}),
)
with patch("paperless_ai.embedding.get_embedding_model") as mock_get:
assert get_embedding_dim() == 11
mock_get.assert_not_called()
def test_get_embedding_dim_unknown_model(mock_ai_config):
def test_get_embedding_dim_raises_on_model_change(temp_llm_index_dir, mock_ai_config):
mock_ai_config.return_value.llm_embedding_backend = "openai"
mock_ai_config.return_value.llm_embedding_model = "unknown-model"
mock_ai_config.return_value.llm_embedding_model = None
with pytest.raises(ValueError, match="Unknown embedding model: unknown-model"):
(temp_llm_index_dir / "meta.json").write_text(
json.dumps({"embedding_model": "old", "dim": 11}),
)
with pytest.raises(
RuntimeError,
match="Embedding model changed from old to text-embedding-3-small",
):
get_embedding_dim()