Refactor load_or_build_index

shamoon 2025-04-28 21:39:39 -07:00
parent a957c5f053
commit 12e89088d7
2 changed files with 30 additions and 41 deletions


@@ -76,11 +76,14 @@ def build_document_node(document: Document) -> list[BaseNode]:
     return parser.get_nodes_from_documents([doc])
 
 
-def load_or_build_index(storage_context: StorageContext, embed_model, nodes=None):
+def load_or_build_index(nodes=None):
     """
     Load an existing VectorStoreIndex if present,
     or build a new one using provided nodes if storage is empty.
     """
+    embed_model = get_embedding_model()
+    llama_settings.Settings.embed_model = embed_model
+    storage_context = get_or_create_storage_context()
     try:
         return load_index_from_storage(storage_context=storage_context)
     except ValueError as e:
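
For reference, the refactored helper reads roughly as follows after this hunk. The diff cuts off inside the except branch, so that fallback body is an assumption inferred from the docstring and from the tests further down (build a fresh index when nodes are supplied, otherwise re-raise):

def load_or_build_index(nodes=None):
    """
    Load an existing VectorStoreIndex if present,
    or build a new one using provided nodes if storage is empty.
    """
    # The helper now resolves its own dependencies instead of taking them as arguments.
    embed_model = get_embedding_model()
    llama_settings.Settings.embed_model = embed_model
    storage_context = get_or_create_storage_context()
    try:
        return load_index_from_storage(storage_context=storage_context)
    except ValueError:
        # Assumed fallback (not shown in the hunk): storage is empty, so build
        # from the provided nodes, or re-raise when there is nothing to build from.
        if not nodes:
            raise
        return VectorStoreIndex(
            nodes=nodes,
            storage_context=storage_context,
            embed_model=embed_model,
        )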
@@ -115,10 +118,6 @@ def update_llm_index(*, progress_bar_disable=False, rebuild=False):
     """
     Rebuild or update the LLM index.
     """
-    embed_model = get_embedding_model()
-    llama_settings.Settings.embed_model = embed_model
-    storage_context = get_or_create_storage_context(rebuild=rebuild)
-
     nodes = []
 
     documents = Document.objects.all()
@@ -127,12 +126,15 @@ def update_llm_index(*, progress_bar_disable=False, rebuild=False):
         return
 
     if rebuild:
+        embed_model = get_embedding_model()
+        llama_settings.Settings.embed_model = embed_model
+        storage_context = get_or_create_storage_context(rebuild=rebuild)
         # Rebuild index from scratch
         for document in tqdm.tqdm(documents, disable=progress_bar_disable):
             document_nodes = build_document_node(document)
             nodes.extend(document_nodes)
 
-        VectorStoreIndex(
+        index = VectorStoreIndex(
            nodes=nodes,
            storage_context=storage_context,
            embed_model=embed_model,
@@ -140,7 +142,7 @@ def update_llm_index(*, progress_bar_disable=False, rebuild=False):
         )
     else:
         # Update existing index
-        index = load_or_build_index(storage_context, embed_model)
+        index = load_or_build_index()
         all_node_ids = list(index.docstore.docs.keys())
         existing_nodes = {
             node.metadata.get("document_id"): node
@@ -174,7 +176,7 @@ def update_llm_index(*, progress_bar_disable=False, rebuild=False):
         else:
             logger.info("No changes detected, skipping llm index rebuild.")
 
-    storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
+    index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
 
 
 def llm_index_add_or_update_document(document: Document):
@@ -182,46 +184,33 @@ def llm_index_add_or_update_document(document: Document):
     Adds or updates a document in the LLM index.
     If the document already exists, it will be replaced.
     """
-    embed_model = get_embedding_model()
-    llama_settings.Settings.embed_model = embed_model
-    storage_context = get_or_create_storage_context(rebuild=False)
-
     new_nodes = build_document_node(document)
 
-    index = load_or_build_index(storage_context, embed_model, nodes=new_nodes)
+    index = load_or_build_index(nodes=new_nodes)
 
     remove_document_docstore_nodes(document, index)
 
     index.insert_nodes(new_nodes)
 
-    storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
+    index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
 
 
 def llm_index_remove_document(document: Document):
     """
     Removes a document from the LLM index.
     """
-    embed_model = get_embedding_model()
-    llama_settings.embed_model = embed_model
-    storage_context = get_or_create_storage_context(rebuild=False)
-
-    index = load_or_build_index(storage_context, embed_model)
+    index = load_or_build_index()
 
     remove_document_docstore_nodes(document, index)
 
-    storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
+    index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
 
 
 def query_similar_documents(document: Document, top_k: int = 5) -> list[Document]:
     """
     Runs a similarity query and returns top-k similar Document objects.
     """
-    storage_context = get_or_create_storage_context(rebuild=False)
-    embed_model = get_embedding_model()
-    llama_settings.embed_model = embed_model
-
-    index = load_or_build_index(storage_context, embed_model)
+    index = load_or_build_index()
 
     retriever = VectorIndexRetriever(index=index, similarity_top_k=top_k)
     query_text = (document.title or "") + "\n" + (document.content or "")
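
Every call site in this file now follows the same short pattern, sketched here from the new body of llm_index_remove_document: the helper resolves the embed model and storage context itself, and persistence goes through the index's own storage context, so the object that was just mutated is the one that gets written out.

# Common call-site shape after the refactor (sketch, not a verbatim copy).
index = load_or_build_index()  # resolves embed model and storage context internally
remove_document_docstore_nodes(document, index)
index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)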


@@ -131,12 +131,13 @@ def test_get_or_create_storage_context_raises_exception(
         indexing.get_or_create_storage_context(rebuild=False)
 
 
+@override_settings(
+    LLM_EMBEDDING_BACKEND="huggingface",
+)
 def test_load_or_build_index_builds_when_nodes_given(
     temp_llm_index_dir,
-    mock_embed_model,
     real_document,
 ):
-    storage_context = MagicMock()
     with patch(
         "paperless.ai.indexing.load_index_from_storage",
         side_effect=ValueError("Index not found"),
@@ -145,25 +146,26 @@ def test_load_or_build_index_builds_when_nodes_given(
             "paperless.ai.indexing.VectorStoreIndex",
             return_value=MagicMock(),
         ) as mock_index_cls:
-            indexing.load_or_build_index(
-                storage_context,
-                mock_embed_model,
-                nodes=[indexing.build_document_node(real_document)],
-            )
-            mock_index_cls.assert_called_once()
+            with patch(
+                "paperless.ai.indexing.get_or_create_storage_context",
+                return_value=MagicMock(),
+            ) as mock_storage:
+                mock_storage.return_value.persist_dir = temp_llm_index_dir
+                indexing.load_or_build_index(
+                    nodes=[indexing.build_document_node(real_document)],
+                )
+                mock_index_cls.assert_called_once()
 
 
 def test_load_or_build_index_raises_exception_when_no_nodes(
     temp_llm_index_dir,
-    mock_embed_model,
 ):
-    storage_context = MagicMock()
     with patch(
         "paperless.ai.indexing.load_index_from_storage",
         side_effect=ValueError("Index not found"),
     ):
         with pytest.raises(Exception):
-            indexing.load_or_build_index(storage_context, mock_embed_model)
+            indexing.load_or_build_index()
 
 
 @pytest.mark.django_db
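
Stitched back together from the new side of the two hunks above, the updated test reads roughly as follows; the imports and the temp_llm_index_dir / real_document fixtures are assumed to already exist in the test module.

# Imports assumed to already be present in the test module.
from unittest.mock import MagicMock, patch

from django.test import override_settings

from paperless.ai import indexing


@override_settings(
    LLM_EMBEDDING_BACKEND="huggingface",
)
def test_load_or_build_index_builds_when_nodes_given(
    temp_llm_index_dir,
    real_document,
):
    with patch(
        "paperless.ai.indexing.load_index_from_storage",
        side_effect=ValueError("Index not found"),
    ):
        with patch(
            "paperless.ai.indexing.VectorStoreIndex",
            return_value=MagicMock(),
        ) as mock_index_cls:
            with patch(
                "paperless.ai.indexing.get_or_create_storage_context",
                return_value=MagicMock(),
            ) as mock_storage:
                # Point the mocked storage context at the temp dir fixture.
                mock_storage.return_value.persist_dir = temp_llm_index_dir
                indexing.load_or_build_index(
                    nodes=[indexing.build_document_node(real_document)],
                )
                # Building from nodes must go through VectorStoreIndex exactly once.
                mock_index_cls.assert_called_once()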
@@ -185,13 +187,11 @@ def test_remove_document_deletes_node_from_docstore(
     mock_embed_model,
 ):
     indexing.update_llm_index(rebuild=True)
-    storage_context = indexing.get_or_create_storage_context()
-    index = indexing.load_or_build_index(storage_context, mock_embed_model)
+    index = indexing.load_or_build_index()
 
     assert len(index.docstore.docs) == 1
 
     indexing.llm_index_remove_document(real_document)
 
-    storage_context = indexing.get_or_create_storage_context()
-    index = indexing.load_or_build_index(storage_context, mock_embed_model)
+    index = indexing.load_or_build_index()
 
     assert len(index.docstore.docs) == 0