Mirror of https://github.com/paperless-ngx/paperless-ngx.git (synced 2025-08-12 00:19:48 +00:00)

Commit: Support dynamic determining of embedding dimensions

The diff below replaces the hard-coded embedding-dimension table with dynamic detection: the dimension is inferred once from a dummy embedding and cached in a meta.json file alongside the LLM index.
@@ -570,6 +570,8 @@ def llmindex_index(
 
         task.date_done = timezone.now()
         task.save(update_fields=["status", "result", "date_done"])
+    else:
+        logger.info("LLM index is disabled, skipping update.")
 
 
 @shared_task
@@ -1,3 +1,10 @@
+import json
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from pathlib import Path
+
+from django.conf import settings
 from llama_index.core.base.embeddings.base import BaseEmbedding
 from llama_index.embeddings.huggingface import HuggingFaceEmbedding
 from llama_index.embeddings.openai import OpenAIEmbedding
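One detail in the import block above: `Path` is used only in a type annotation, so it is imported under `typing.TYPE_CHECKING`, a guard that type checkers evaluate but the interpreter skips at runtime. A minimal sketch of the pattern (the `meta_file` helper is illustrative, not from the commit):

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Seen by mypy/pyright only; never imported at runtime.
        from pathlib import Path


    def meta_file(base: "Path") -> "Path":
        # Quoted annotations are not resolved at runtime, so the guard is safe.
        return base / "meta.json"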
@@ -7,11 +14,6 @@ from documents.models import Note
 from paperless.config import AIConfig
 from paperless.models import LLMEmbeddingBackend
 
-EMBEDDING_DIMENSIONS = {
-    "text-embedding-3-small": 1536,
-    "sentence-transformers/all-MiniLM-L6-v2": 384,
-}
-
 
 def get_embedding_model() -> BaseEmbedding:
     config = AIConfig()
@@ -34,15 +36,36 @@ def get_embedding_model() -> BaseEmbedding:
 
 
 def get_embedding_dim() -> int:
+    """
+    Loads embedding dimension from meta.json if available, otherwise infers it
+    from a dummy embedding and stores it for future use.
+    """
     config = AIConfig()
     model = config.llm_embedding_model or (
         "text-embedding-3-small"
         if config.llm_embedding_backend == "openai"
         else "sentence-transformers/all-MiniLM-L6-v2"
     )
-    if model not in EMBEDDING_DIMENSIONS:
-        raise ValueError(f"Unknown embedding model: {model}")
-    return EMBEDDING_DIMENSIONS[model]
+
+    meta_path: Path = settings.LLM_INDEX_DIR / "meta.json"
+    if meta_path.exists():
+        with meta_path.open() as f:
+            meta = json.load(f)
+        if meta.get("embedding_model") != model:
+            raise RuntimeError(
+                f"Embedding model changed from {meta.get('embedding_model')} to {model}. "
+                "You must rebuild the index.",
+            )
+        return meta["dim"]
+
+    embedding_model = get_embedding_model()
+    test_embed = embedding_model.get_text_embedding("test")
+    dim = len(test_embed)
+
+    with meta_path.open("w") as f:
+        json.dump({"embedding_model": model, "dim": dim}, f)
+
+    return dim
 
 
 def build_llm_index_text(doc: Document) -> str:
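Taken together with the removal of the `EMBEDDING_DIMENSIONS` table above, `get_embedding_dim()` now follows a detect-once-then-cache pattern: trust meta.json when it matches the configured model, fail loudly when it does not, and otherwise measure the dimension empirically by embedding a dummy string. A self-contained sketch of that logic (`DummyModel` and `detect_dim` are illustrative stand-ins, not names from the commit):

    import json
    from pathlib import Path


    class DummyModel:
        """Stand-in embedding model; only the vector length matters here."""

        def get_text_embedding(self, text: str) -> list[float]:
            return [0.0] * 384


    def detect_dim(index_dir: Path, model_name: str) -> int:
        meta_path = index_dir / "meta.json"
        if meta_path.exists():
            meta = json.loads(meta_path.read_text())
            if meta.get("embedding_model") != model_name:
                # A dimension cached for a different model would corrupt the
                # vector store, so refuse to proceed instead of guessing.
                raise RuntimeError("Embedding model changed; rebuild the index.")
            return meta["dim"]
        # First run: embed a throwaway string and persist the observed dimension.
        dim = len(DummyModel().get_text_embedding("test"))
        meta_path.write_text(json.dumps({"embedding_model": model_name, "dim": dim}))
        return dim

The gain over the deleted lookup table is that any model the configured backend can load now works, without the code knowing its dimension ahead of time.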
@@ -138,6 +138,8 @@ def update_llm_index(*, progress_bar_disable=False, rebuild=False) -> str:
         return msg
 
     if rebuild or not vector_store_file_exists():
+        # remove meta.json to force re-detection of embedding dim
+        (settings.LLM_INDEX_DIR / "meta.json").unlink(missing_ok=True)
         # Rebuild index from scratch
         logger.info("Rebuilding LLM index.")
         embed_model = get_embedding_model()
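The `unlink(missing_ok=True)` call above makes the cleanup idempotent: `missing_ok` (available since Python 3.8) suppresses the `FileNotFoundError` that a bare `unlink()` raises when the file does not exist, so rebuilding a fresh index directory cannot fail on a missing meta.json. A quick illustration:

    import tempfile
    from pathlib import Path

    with tempfile.TemporaryDirectory() as d:
        meta = Path(d) / "meta.json"
        meta.unlink(missing_ok=True)  # no file yet: silently a no-op
        meta.write_text("{}")
        meta.unlink(missing_ok=True)  # file present: removed as usual
        assert not meta.exists()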
@@ -1,3 +1,4 @@
+import json
 from unittest.mock import MagicMock
 from unittest.mock import patch
 
@@ -29,9 +30,16 @@ def real_document(db):
 
 @pytest.fixture
 def mock_embed_model():
-    with patch("paperless_ai.indexing.get_embedding_model") as mock:
-        mock.return_value = FakeEmbedding()
-        yield mock
+    fake = FakeEmbedding()
+    with (
+        patch("paperless_ai.indexing.get_embedding_model") as mock_index,
+        patch(
+            "paperless_ai.embedding.get_embedding_model",
+        ) as mock_embedding,
+    ):
+        mock_index.return_value = fake
+        mock_embedding.return_value = fake
+        yield mock_index
 
 
 class FakeEmbedding(BaseEmbedding):
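A note on the reworked `mock_embed_model` fixture: `get_embedding_model` is now called from both `paperless_ai.indexing` and `paperless_ai.embedding`, and `unittest.mock.patch` replaces a name at the lookup site you give it, not the underlying definition, so each importing module must be patched separately. A runnable sketch of the pitfall, using throwaway in-memory modules (`mod_a` and `mod_b` are hypothetical):

    import sys
    import types
    from unittest.mock import patch

    # mod_a defines helper(); mod_b does the equivalent of `from mod_a import helper`.
    mod_a = types.ModuleType("mod_a")
    mod_a.helper = lambda: "real"
    sys.modules["mod_a"] = mod_a

    mod_b = types.ModuleType("mod_b")
    mod_b.helper = mod_a.helper
    sys.modules["mod_b"] = mod_b

    # Patching only mod_a leaves the name mod_b already bound untouched:
    with patch("mod_a.helper", return_value="fake"):
        assert mod_b.helper() == "real"

    # Patching both lookup sites, as the fixture now does, covers every caller:
    with patch("mod_a.helper", return_value="fake"), patch("mod_b.helper", return_value="fake"):
        assert mod_a.helper() == "fake"
        assert mod_b.helper() == "fake"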
@@ -72,6 +80,36 @@ def test_update_llm_index(
     assert any(temp_llm_index_dir.glob("*.json"))
 
 
+@pytest.mark.django_db
+def test_update_llm_index_removes_meta(
+    temp_llm_index_dir,
+    real_document,
+    mock_embed_model,
+):
+    # Pre-create a meta.json with incorrect data
+    (temp_llm_index_dir / "meta.json").write_text(
+        json.dumps({"embedding_model": "old", "dim": 1}),
+    )
+
+    with patch("documents.models.Document.objects.all") as mock_all:
+        mock_queryset = MagicMock()
+        mock_queryset.exists.return_value = True
+        mock_queryset.__iter__.return_value = iter([real_document])
+        mock_all.return_value = mock_queryset
+        indexing.update_llm_index(rebuild=True)
+
+    meta = json.loads((temp_llm_index_dir / "meta.json").read_text())
+    from paperless.config import AIConfig
+
+    config = AIConfig()
+    expected_model = config.llm_embedding_model or (
+        "text-embedding-3-small"
+        if config.llm_embedding_backend == "openai"
+        else "sentence-transformers/all-MiniLM-L6-v2"
+    )
+    assert meta == {"embedding_model": expected_model, "dim": 384}
+
+
 @pytest.mark.django_db
 def test_update_llm_index_partial_update(
     temp_llm_index_dir,
@@ -137,6 +175,7 @@ def test_get_or_create_storage_context_raises_exception(
 def test_load_or_build_index_builds_when_nodes_given(
     temp_llm_index_dir,
     real_document,
+    mock_embed_model,
 ):
     with (
         patch(
@@ -1,8 +1,10 @@
+import json
 from unittest.mock import MagicMock
 from unittest.mock import patch
 
 import pytest
 
+import paperless_ai.embedding as embedding
 from documents.models import Document
 from paperless.models import LLMEmbeddingBackend
 from paperless_ai.embedding import build_llm_index_text
@@ -16,6 +18,14 @@ def mock_ai_config():
     yield MockAIConfig
 
 
+@pytest.fixture
+def temp_llm_index_dir(tmp_path):
+    original_dir = embedding.settings.LLM_INDEX_DIR
+    embedding.settings.LLM_INDEX_DIR = tmp_path
+    yield tmp_path
+    embedding.settings.LLM_INDEX_DIR = original_dir
+
+
 @pytest.fixture
 def mock_document():
     doc = MagicMock(spec=Document)
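The new `temp_llm_index_dir` fixture above swaps the module-level `settings.LLM_INDEX_DIR` for pytest's `tmp_path` and restores the original afterwards, keeping tests away from the real index directory. The same swap-and-restore shape in isolation (the `cfg` class is a hypothetical stand-in for the settings object):

    import pytest


    class cfg:
        # Hypothetical stand-in for a module-level settings object.
        INDEX_DIR = "/var/real/index"


    @pytest.fixture
    def temp_index_dir(tmp_path):
        original = cfg.INDEX_DIR
        cfg.INDEX_DIR = tmp_path   # redirect code under test into a sandbox
        yield tmp_path
        cfg.INDEX_DIR = original   # restore so later tests see the real value


    def test_uses_sandbox(temp_index_dir):
        assert cfg.INDEX_DIR == temp_index_dir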
@@ -91,25 +101,51 @@ def test_get_embedding_model_invalid_backend(mock_ai_config):
         get_embedding_model()
 
 
-def test_get_embedding_dim_openai(mock_ai_config):
+def test_get_embedding_dim_infers_and_saves(temp_llm_index_dir, mock_ai_config):
     mock_ai_config.return_value.llm_embedding_backend = "openai"
     mock_ai_config.return_value.llm_embedding_model = None
 
-    assert get_embedding_dim() == 1536
+    class DummyEmbedding:
+        def get_text_embedding(self, text):
+            return [0.0] * 7
+
+    with patch(
+        "paperless_ai.embedding.get_embedding_model",
+        return_value=DummyEmbedding(),
+    ) as mock_get:
+        dim = get_embedding_dim()
+        mock_get.assert_called_once()
+
+    assert dim == 7
+    meta = json.loads((temp_llm_index_dir / "meta.json").read_text())
+    assert meta == {"embedding_model": "text-embedding-3-small", "dim": 7}
 
 
-def test_get_embedding_dim_huggingface(mock_ai_config):
-    mock_ai_config.return_value.llm_embedding_backend = "huggingface"
+def test_get_embedding_dim_reads_existing_meta(temp_llm_index_dir, mock_ai_config):
+    mock_ai_config.return_value.llm_embedding_backend = "openai"
     mock_ai_config.return_value.llm_embedding_model = None
 
-    assert get_embedding_dim() == 384
+    (temp_llm_index_dir / "meta.json").write_text(
+        json.dumps({"embedding_model": "text-embedding-3-small", "dim": 11}),
+    )
+
+    with patch("paperless_ai.embedding.get_embedding_model") as mock_get:
+        assert get_embedding_dim() == 11
+        mock_get.assert_not_called()
 
 
-def test_get_embedding_dim_unknown_model(mock_ai_config):
+def test_get_embedding_dim_raises_on_model_change(temp_llm_index_dir, mock_ai_config):
     mock_ai_config.return_value.llm_embedding_backend = "openai"
-    mock_ai_config.return_value.llm_embedding_model = "unknown-model"
+    mock_ai_config.return_value.llm_embedding_model = None
 
-    with pytest.raises(ValueError, match="Unknown embedding model: unknown-model"):
+    (temp_llm_index_dir / "meta.json").write_text(
+        json.dumps({"embedding_model": "old", "dim": 11}),
+    )
+
+    with pytest.raises(
+        RuntimeError,
+        match="Embedding model changed from old to text-embedding-3-small",
+    ):
         get_embedding_dim()
 
 