mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-08-26 01:16:16 +00:00
Support dynamic determining of embedding dimensions
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
import paperless_ai.embedding as embedding
|
||||
from documents.models import Document
|
||||
from paperless.models import LLMEmbeddingBackend
|
||||
from paperless_ai.embedding import build_llm_index_text
|
||||
@@ -16,6 +18,14 @@ def mock_ai_config():
|
||||
yield MockAIConfig
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_llm_index_dir(tmp_path):
|
||||
original_dir = embedding.settings.LLM_INDEX_DIR
|
||||
embedding.settings.LLM_INDEX_DIR = tmp_path
|
||||
yield tmp_path
|
||||
embedding.settings.LLM_INDEX_DIR = original_dir
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_document():
|
||||
doc = MagicMock(spec=Document)
|
||||
@@ -91,25 +101,51 @@ def test_get_embedding_model_invalid_backend(mock_ai_config):
|
||||
get_embedding_model()
|
||||
|
||||
|
||||
def test_get_embedding_dim_openai(mock_ai_config):
|
||||
def test_get_embedding_dim_infers_and_saves(temp_llm_index_dir, mock_ai_config):
|
||||
mock_ai_config.return_value.llm_embedding_backend = "openai"
|
||||
mock_ai_config.return_value.llm_embedding_model = None
|
||||
|
||||
assert get_embedding_dim() == 1536
|
||||
class DummyEmbedding:
|
||||
def get_text_embedding(self, text):
|
||||
return [0.0] * 7
|
||||
|
||||
with patch(
|
||||
"paperless_ai.embedding.get_embedding_model",
|
||||
return_value=DummyEmbedding(),
|
||||
) as mock_get:
|
||||
dim = get_embedding_dim()
|
||||
mock_get.assert_called_once()
|
||||
|
||||
assert dim == 7
|
||||
meta = json.loads((temp_llm_index_dir / "meta.json").read_text())
|
||||
assert meta == {"embedding_model": "text-embedding-3-small", "dim": 7}
|
||||
|
||||
|
||||
def test_get_embedding_dim_huggingface(mock_ai_config):
|
||||
mock_ai_config.return_value.llm_embedding_backend = "huggingface"
|
||||
def test_get_embedding_dim_reads_existing_meta(temp_llm_index_dir, mock_ai_config):
|
||||
mock_ai_config.return_value.llm_embedding_backend = "openai"
|
||||
mock_ai_config.return_value.llm_embedding_model = None
|
||||
|
||||
assert get_embedding_dim() == 384
|
||||
(temp_llm_index_dir / "meta.json").write_text(
|
||||
json.dumps({"embedding_model": "text-embedding-3-small", "dim": 11}),
|
||||
)
|
||||
|
||||
with patch("paperless_ai.embedding.get_embedding_model") as mock_get:
|
||||
assert get_embedding_dim() == 11
|
||||
mock_get.assert_not_called()
|
||||
|
||||
|
||||
def test_get_embedding_dim_unknown_model(mock_ai_config):
|
||||
def test_get_embedding_dim_raises_on_model_change(temp_llm_index_dir, mock_ai_config):
|
||||
mock_ai_config.return_value.llm_embedding_backend = "openai"
|
||||
mock_ai_config.return_value.llm_embedding_model = "unknown-model"
|
||||
mock_ai_config.return_value.llm_embedding_model = None
|
||||
|
||||
with pytest.raises(ValueError, match="Unknown embedding model: unknown-model"):
|
||||
(temp_llm_index_dir / "meta.json").write_text(
|
||||
json.dumps({"embedding_model": "old", "dim": 11}),
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
RuntimeError,
|
||||
match="Embedding model changed from old to text-embedding-3-small",
|
||||
):
|
||||
get_embedding_dim()
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user