mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-05-23 12:58:18 -05:00
Refactor load_or_build_index
This commit is contained in:
parent
a957c5f053
commit
12e89088d7
@ -76,11 +76,14 @@ def build_document_node(document: Document) -> list[BaseNode]:
|
||||
return parser.get_nodes_from_documents([doc])
|
||||
|
||||
|
||||
def load_or_build_index(storage_context: StorageContext, embed_model, nodes=None):
|
||||
def load_or_build_index(nodes=None):
|
||||
"""
|
||||
Load an existing VectorStoreIndex if present,
|
||||
or build a new one using provided nodes if storage is empty.
|
||||
"""
|
||||
embed_model = get_embedding_model()
|
||||
llama_settings.Settings.embed_model = embed_model
|
||||
storage_context = get_or_create_storage_context()
|
||||
try:
|
||||
return load_index_from_storage(storage_context=storage_context)
|
||||
except ValueError as e:
|
||||
@ -115,10 +118,6 @@ def update_llm_index(*, progress_bar_disable=False, rebuild=False):
|
||||
"""
|
||||
Rebuild or update the LLM index.
|
||||
"""
|
||||
embed_model = get_embedding_model()
|
||||
llama_settings.Settings.embed_model = embed_model
|
||||
storage_context = get_or_create_storage_context(rebuild=rebuild)
|
||||
|
||||
nodes = []
|
||||
|
||||
documents = Document.objects.all()
|
||||
@ -127,12 +126,15 @@ def update_llm_index(*, progress_bar_disable=False, rebuild=False):
|
||||
return
|
||||
|
||||
if rebuild:
|
||||
embed_model = get_embedding_model()
|
||||
llama_settings.Settings.embed_model = embed_model
|
||||
storage_context = get_or_create_storage_context(rebuild=rebuild)
|
||||
# Rebuild index from scratch
|
||||
for document in tqdm.tqdm(documents, disable=progress_bar_disable):
|
||||
document_nodes = build_document_node(document)
|
||||
nodes.extend(document_nodes)
|
||||
|
||||
VectorStoreIndex(
|
||||
index = VectorStoreIndex(
|
||||
nodes=nodes,
|
||||
storage_context=storage_context,
|
||||
embed_model=embed_model,
|
||||
@ -140,7 +142,7 @@ def update_llm_index(*, progress_bar_disable=False, rebuild=False):
|
||||
)
|
||||
else:
|
||||
# Update existing index
|
||||
index = load_or_build_index(storage_context, embed_model)
|
||||
index = load_or_build_index()
|
||||
all_node_ids = list(index.docstore.docs.keys())
|
||||
existing_nodes = {
|
||||
node.metadata.get("document_id"): node
|
||||
@ -174,7 +176,7 @@ def update_llm_index(*, progress_bar_disable=False, rebuild=False):
|
||||
else:
|
||||
logger.info("No changes detected, skipping llm index rebuild.")
|
||||
|
||||
storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
|
||||
index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
|
||||
|
||||
|
||||
def llm_index_add_or_update_document(document: Document):
|
||||
@ -182,46 +184,33 @@ def llm_index_add_or_update_document(document: Document):
|
||||
Adds or updates a document in the LLM index.
|
||||
If the document already exists, it will be replaced.
|
||||
"""
|
||||
embed_model = get_embedding_model()
|
||||
llama_settings.Settings.embed_model = embed_model
|
||||
|
||||
storage_context = get_or_create_storage_context(rebuild=False)
|
||||
|
||||
new_nodes = build_document_node(document)
|
||||
|
||||
index = load_or_build_index(storage_context, embed_model, nodes=new_nodes)
|
||||
index = load_or_build_index(nodes=new_nodes)
|
||||
|
||||
remove_document_docstore_nodes(document, index)
|
||||
|
||||
index.insert_nodes(new_nodes)
|
||||
|
||||
storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
|
||||
index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
|
||||
|
||||
|
||||
def llm_index_remove_document(document: Document):
|
||||
"""
|
||||
Removes a document from the LLM index.
|
||||
"""
|
||||
embed_model = get_embedding_model()
|
||||
llama_settings.embed_model = embed_model
|
||||
|
||||
storage_context = get_or_create_storage_context(rebuild=False)
|
||||
|
||||
index = load_or_build_index(storage_context, embed_model)
|
||||
index = load_or_build_index()
|
||||
|
||||
remove_document_docstore_nodes(document, index)
|
||||
|
||||
storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
|
||||
index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
|
||||
|
||||
|
||||
def query_similar_documents(document: Document, top_k: int = 5) -> list[Document]:
|
||||
"""
|
||||
Runs a similarity query and returns top-k similar Document objects.
|
||||
"""
|
||||
storage_context = get_or_create_storage_context(rebuild=False)
|
||||
embed_model = get_embedding_model()
|
||||
llama_settings.embed_model = embed_model
|
||||
index = load_or_build_index(storage_context, embed_model)
|
||||
index = load_or_build_index()
|
||||
retriever = VectorIndexRetriever(index=index, similarity_top_k=top_k)
|
||||
|
||||
query_text = (document.title or "") + "\n" + (document.content or "")
|
||||
|
@ -131,12 +131,13 @@ def test_get_or_create_storage_context_raises_exception(
|
||||
indexing.get_or_create_storage_context(rebuild=False)
|
||||
|
||||
|
||||
@override_settings(
|
||||
LLM_EMBEDDING_BACKEND="huggingface",
|
||||
)
|
||||
def test_load_or_build_index_builds_when_nodes_given(
|
||||
temp_llm_index_dir,
|
||||
mock_embed_model,
|
||||
real_document,
|
||||
):
|
||||
storage_context = MagicMock()
|
||||
with patch(
|
||||
"paperless.ai.indexing.load_index_from_storage",
|
||||
side_effect=ValueError("Index not found"),
|
||||
@ -145,25 +146,26 @@ def test_load_or_build_index_builds_when_nodes_given(
|
||||
"paperless.ai.indexing.VectorStoreIndex",
|
||||
return_value=MagicMock(),
|
||||
) as mock_index_cls:
|
||||
indexing.load_or_build_index(
|
||||
storage_context,
|
||||
mock_embed_model,
|
||||
nodes=[indexing.build_document_node(real_document)],
|
||||
)
|
||||
mock_index_cls.assert_called_once()
|
||||
with patch(
|
||||
"paperless.ai.indexing.get_or_create_storage_context",
|
||||
return_value=MagicMock(),
|
||||
) as mock_storage:
|
||||
mock_storage.return_value.persist_dir = temp_llm_index_dir
|
||||
indexing.load_or_build_index(
|
||||
nodes=[indexing.build_document_node(real_document)],
|
||||
)
|
||||
mock_index_cls.assert_called_once()
|
||||
|
||||
|
||||
def test_load_or_build_index_raises_exception_when_no_nodes(
|
||||
temp_llm_index_dir,
|
||||
mock_embed_model,
|
||||
):
|
||||
storage_context = MagicMock()
|
||||
with patch(
|
||||
"paperless.ai.indexing.load_index_from_storage",
|
||||
side_effect=ValueError("Index not found"),
|
||||
):
|
||||
with pytest.raises(Exception):
|
||||
indexing.load_or_build_index(storage_context, mock_embed_model)
|
||||
indexing.load_or_build_index()
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@ -185,13 +187,11 @@ def test_remove_document_deletes_node_from_docstore(
|
||||
mock_embed_model,
|
||||
):
|
||||
indexing.update_llm_index(rebuild=True)
|
||||
storage_context = indexing.get_or_create_storage_context()
|
||||
index = indexing.load_or_build_index(storage_context, mock_embed_model)
|
||||
index = indexing.load_or_build_index()
|
||||
assert len(index.docstore.docs) == 1
|
||||
|
||||
indexing.llm_index_remove_document(real_document)
|
||||
storage_context = indexing.get_or_create_storage_context()
|
||||
index = indexing.load_or_build_index(storage_context, mock_embed_model)
|
||||
index = indexing.load_or_build_index()
|
||||
assert len(index.docstore.docs) == 0
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user