From 12e89088d7ccded2e3fa0e3c3c97d9823d39e5b6 Mon Sep 17 00:00:00 2001
From: shamoon <4887959+shamoon@users.noreply.github.com>
Date: Mon, 28 Apr 2025 21:39:39 -0700
Subject: [PATCH] Refactor load_or_build_index

---
 src/paperless/ai/indexing.py             | 41 +++++++++----------
 src/paperless/tests/test_ai_indexing.py  | 30 +++++++++---------
 2 files changed, 30 insertions(+), 41 deletions(-)

diff --git a/src/paperless/ai/indexing.py b/src/paperless/ai/indexing.py
index 840d58f37..9a32409ca 100644
--- a/src/paperless/ai/indexing.py
+++ b/src/paperless/ai/indexing.py
@@ -76,11 +76,14 @@ def build_document_node(document: Document) -> list[BaseNode]:
     return parser.get_nodes_from_documents([doc])
 
 
-def load_or_build_index(storage_context: StorageContext, embed_model, nodes=None):
+def load_or_build_index(nodes=None):
     """
     Load an existing VectorStoreIndex if present,
     or build a new one using provided nodes if storage is empty.
     """
+    embed_model = get_embedding_model()
+    llama_settings.Settings.embed_model = embed_model
+    storage_context = get_or_create_storage_context()
     try:
         return load_index_from_storage(storage_context=storage_context)
     except ValueError as e:
@@ -115,10 +118,6 @@ def update_llm_index(*, progress_bar_disable=False, rebuild=False):
     """
     Rebuild or update the LLM index.
     """
-    embed_model = get_embedding_model()
-    llama_settings.Settings.embed_model = embed_model
-    storage_context = get_or_create_storage_context(rebuild=rebuild)
-
     nodes = []
 
     documents = Document.objects.all()
@@ -127,12 +126,15 @@ def update_llm_index(*, progress_bar_disable=False, rebuild=False):
         return
 
     if rebuild:
+        embed_model = get_embedding_model()
+        llama_settings.Settings.embed_model = embed_model
+        storage_context = get_or_create_storage_context(rebuild=rebuild)
         # Rebuild index from scratch
         for document in tqdm.tqdm(documents, disable=progress_bar_disable):
            document_nodes = build_document_node(document)
            nodes.extend(document_nodes)
 
-        VectorStoreIndex(
+        index = VectorStoreIndex(
             nodes=nodes,
             storage_context=storage_context,
             embed_model=embed_model,
@@ -140,7 +142,7 @@ def update_llm_index(*, progress_bar_disable=False, rebuild=False):
         )
     else:
         # Update existing index
-        index = load_or_build_index(storage_context, embed_model)
+        index = load_or_build_index()
         all_node_ids = list(index.docstore.docs.keys())
         existing_nodes = {
             node.metadata.get("document_id"): node
@@ -174,7 +176,7 @@ def update_llm_index(*, progress_bar_disable=False, rebuild=False):
     else:
         logger.info("No changes detected, skipping llm index rebuild.")
 
-    storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
+    index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
 
 
 def llm_index_add_or_update_document(document: Document):
@@ -182,46 +184,33 @@ def llm_index_add_or_update_document(document: Document):
     Adds or updates a document in the LLM index.
     If the document already exists, it will be replaced.
     """
-    embed_model = get_embedding_model()
-    llama_settings.Settings.embed_model = embed_model
-
-    storage_context = get_or_create_storage_context(rebuild=False)
-
     new_nodes = build_document_node(document)
 
-    index = load_or_build_index(storage_context, embed_model, nodes=new_nodes)
+    index = load_or_build_index(nodes=new_nodes)
 
     remove_document_docstore_nodes(document, index)
 
     index.insert_nodes(new_nodes)
 
-    storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
+    index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
 
 
 def llm_index_remove_document(document: Document):
     """
     Removes a document from the LLM index.
     """
-    embed_model = get_embedding_model()
-    llama_settings.embed_model = embed_model
-
-    storage_context = get_or_create_storage_context(rebuild=False)
-
-    index = load_or_build_index(storage_context, embed_model)
+    index = load_or_build_index()
 
     remove_document_docstore_nodes(document, index)
 
-    storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
+    index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
 
 
 def query_similar_documents(document: Document, top_k: int = 5) -> list[Document]:
     """
     Runs a similarity query and returns top-k similar Document objects.
     """
-    storage_context = get_or_create_storage_context(rebuild=False)
-    embed_model = get_embedding_model()
-    llama_settings.embed_model = embed_model
-    index = load_or_build_index(storage_context, embed_model)
+    index = load_or_build_index()
 
     retriever = VectorIndexRetriever(index=index, similarity_top_k=top_k)
     query_text = (document.title or "") + "\n" + (document.content or "")
diff --git a/src/paperless/tests/test_ai_indexing.py b/src/paperless/tests/test_ai_indexing.py
index d7b83316d..101fdfb9e 100644
--- a/src/paperless/tests/test_ai_indexing.py
+++ b/src/paperless/tests/test_ai_indexing.py
@@ -131,12 +131,13 @@ def test_get_or_create_storage_context_raises_exception(
         indexing.get_or_create_storage_context(rebuild=False)
 
 
+@override_settings(
+    LLM_EMBEDDING_BACKEND="huggingface",
+)
 def test_load_or_build_index_builds_when_nodes_given(
     temp_llm_index_dir,
-    mock_embed_model,
     real_document,
 ):
-    storage_context = MagicMock()
     with patch(
         "paperless.ai.indexing.load_index_from_storage",
         side_effect=ValueError("Index not found"),
@@ -145,25 +146,26 @@ def test_load_or_build_index_builds_when_nodes_given(
             "paperless.ai.indexing.VectorStoreIndex",
             return_value=MagicMock(),
         ) as mock_index_cls:
-            indexing.load_or_build_index(
-                storage_context,
-                mock_embed_model,
-                nodes=[indexing.build_document_node(real_document)],
-            )
-            mock_index_cls.assert_called_once()
+            with patch(
+                "paperless.ai.indexing.get_or_create_storage_context",
+                return_value=MagicMock(),
+            ) as mock_storage:
+                mock_storage.return_value.persist_dir = temp_llm_index_dir
+                indexing.load_or_build_index(
+                    nodes=[indexing.build_document_node(real_document)],
+                )
+                mock_index_cls.assert_called_once()
 
 
 def test_load_or_build_index_raises_exception_when_no_nodes(
     temp_llm_index_dir,
-    mock_embed_model,
 ):
-    storage_context = MagicMock()
     with patch(
         "paperless.ai.indexing.load_index_from_storage",
         side_effect=ValueError("Index not found"),
     ):
         with pytest.raises(Exception):
-            indexing.load_or_build_index(storage_context, mock_embed_model)
+            indexing.load_or_build_index()
 
 
 @pytest.mark.django_db
@@ -185,13 +187,11 @@ def test_remove_document_deletes_node_from_docstore(
     mock_embed_model,
 ):
     indexing.update_llm_index(rebuild=True)
-    storage_context = indexing.get_or_create_storage_context()
-    index = indexing.load_or_build_index(storage_context, mock_embed_model)
+    index = indexing.load_or_build_index()
 
     assert len(index.docstore.docs) == 1
 
     indexing.llm_index_remove_document(real_document)
-    storage_context = indexing.get_or_create_storage_context()
-    index = indexing.load_or_build_index(storage_context, mock_embed_model)
+    index = indexing.load_or_build_index()
 
     assert len(index.docstore.docs) == 0