Incremental LLM index update, add scheduled LLM index task

This commit is contained in:
shamoon 2025-04-28 10:29:07 -07:00
parent 8852965117
commit 6ee6a37816
No known key found for this signature in database
8 changed files with 154 additions and 48 deletions

View File

@ -1763,3 +1763,10 @@ current backend. This setting is required to be set to use the AI features.
: The URL to use for the AI backend. This is required for the Ollama backend only. : The URL to use for the AI backend. This is required for the Ollama backend only.
Defaults to None. Defaults to None.
#### [`PAPERLESS_LLM_INDEX_TASK_CRON=<cron expression>`](#PAPERLESS_LLM_INDEX_TASK_CRON) {#PAPERLESS_LLM_INDEX_TASK_CRON}
: Configures the schedule to update the AI embeddings for all documents. Only performed if
AI is enabled and the LLM embedding backend is set.
Defaults to `10 2 * * *`, once per day.

View File

@ -2,20 +2,20 @@ from django.core.management import BaseCommand
from django.db import transaction from django.db import transaction
from documents.management.commands.mixins import ProgressBarMixin from documents.management.commands.mixins import ProgressBarMixin
from documents.tasks import llm_index_rebuild from documents.tasks import llmindex_index
class Command(ProgressBarMixin, BaseCommand): class Command(ProgressBarMixin, BaseCommand):
help = "Manages the LLM-based vector index for Paperless." help = "Manages the LLM-based vector index for Paperless."
def add_arguments(self, parser): def add_arguments(self, parser):
parser.add_argument("command", choices=["rebuild"]) parser.add_argument("command", choices=["rebuild", "update"])
self.add_argument_progress_bar_mixin(parser) self.add_argument_progress_bar_mixin(parser)
def handle(self, *args, **options): def handle(self, *args, **options):
self.handle_progress_bar_mixin(**options) self.handle_progress_bar_mixin(**options)
with transaction.atomic(): with transaction.atomic():
llm_index_rebuild( llmindex_index(
progress_bar_disable=self.no_progress_bar, progress_bar_disable=self.no_progress_bar,
rebuild=options["command"] == "rebuild", rebuild=options["command"] == "rebuild",
) )

View File

@ -54,7 +54,7 @@ from documents.signals.handlers import cleanup_document_deletion
from documents.signals.handlers import run_workflows from documents.signals.handlers import run_workflows
from paperless.ai.indexing import llm_index_add_or_update_document from paperless.ai.indexing import llm_index_add_or_update_document
from paperless.ai.indexing import llm_index_remove_document from paperless.ai.indexing import llm_index_remove_document
from paperless.ai.indexing import rebuild_llm_index from paperless.ai.indexing import update_llm_index
from paperless.config import AIConfig from paperless.config import AIConfig
if settings.AUDIT_LOG_ENABLED: if settings.AUDIT_LOG_ENABLED:
@ -511,11 +511,14 @@ def check_scheduled_workflows():
) )
def llm_index_rebuild(*, progress_bar_disable=False, rebuild=False): @shared_task
rebuild_llm_index( def llmindex_index(*, progress_bar_disable=False, rebuild=False):
progress_bar_disable=progress_bar_disable, ai_config = AIConfig()
rebuild=rebuild, if ai_config.llm_index_enabled():
) update_llm_index(
progress_bar_disable=progress_bar_disable,
rebuild=rebuild,
)
@shared_task @shared_task
@ -531,6 +534,6 @@ def remove_document_from_llm_index(document):
# TODO: schedule to run periodically # TODO: schedule to run periodically
@shared_task @shared_task
def rebuild_llm_index_task(): def rebuild_llm_index_task():
from paperless.ai.indexing import rebuild_llm_index from paperless.ai.indexing import update_llm_index
rebuild_llm_index(rebuild=True) update_llm_index(rebuild=True)

View File

@ -8,6 +8,7 @@ from django.conf import settings
from llama_index.core import Document as LlamaDocument from llama_index.core import Document as LlamaDocument
from llama_index.core import StorageContext from llama_index.core import StorageContext
from llama_index.core import VectorStoreIndex from llama_index.core import VectorStoreIndex
from llama_index.core import load_index_from_storage
from llama_index.core.node_parser import SimpleNodeParser from llama_index.core.node_parser import SimpleNodeParser
from llama_index.core.retrievers import VectorIndexRetriever from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.schema import BaseNode from llama_index.core.schema import BaseNode
@ -70,7 +71,7 @@ def build_document_node(document: Document) -> list[BaseNode]:
text = build_llm_index_text(document) text = build_llm_index_text(document)
metadata = { metadata = {
"document_id": document.id, "document_id": str(document.id),
"title": document.title, "title": document.title,
"tags": [t.name for t in document.tags.all()], "tags": [t.name for t in document.tags.all()],
"correspondent": document.correspondent.name "correspondent": document.correspondent.name
@ -81,32 +82,29 @@ def build_document_node(document: Document) -> list[BaseNode]:
else None, else None,
"created": document.created.isoformat() if document.created else None, "created": document.created.isoformat() if document.created else None,
"added": document.added.isoformat() if document.added else None, "added": document.added.isoformat() if document.added else None,
"modified": document.modified.isoformat(),
} }
doc = LlamaDocument(text=text, metadata=metadata) doc = LlamaDocument(text=text, metadata=metadata)
parser = SimpleNodeParser() parser = SimpleNodeParser()
return parser.get_nodes_from_documents([doc]) return parser.get_nodes_from_documents([doc])
def load_or_build_index(storage_context, embed_model, nodes=None): def load_or_build_index(storage_context: StorageContext, embed_model, nodes=None):
""" """
Load an existing VectorStoreIndex if present, Load an existing VectorStoreIndex if present,
or build a new one using provided nodes if storage is empty. or build a new one using provided nodes if storage is empty.
""" """
try: try:
return load_index_from_storage(storage_context=storage_context)
except ValueError as e:
logger.debug("Failed to load index from storage: %s", e)
if not nodes:
return None
return VectorStoreIndex( return VectorStoreIndex(
nodes=nodes,
storage_context=storage_context, storage_context=storage_context,
embed_model=embed_model, embed_model=embed_model,
) )
except ValueError as e:
if "One of nodes, objects, or index_struct must be provided" in str(e):
if not nodes:
return None
return VectorStoreIndex(
nodes=nodes,
storage_context=storage_context,
embed_model=embed_model,
)
raise
def remove_document_docstore_nodes(document: Document, index: VectorStoreIndex): def remove_document_docstore_nodes(document: Document, index: VectorStoreIndex):
@ -125,31 +123,74 @@ def remove_document_docstore_nodes(document: Document, index: VectorStoreIndex):
index.docstore.delete_document(node_id) index.docstore.delete_document(node_id)
def rebuild_llm_index(*, progress_bar_disable=False, rebuild=False): def update_llm_index(*, progress_bar_disable=False, rebuild=False):
""" """
Rebuilds the LLM index from scratch. Rebuild or update the LLM index.
""" """
embed_model = get_embedding_model() embed_model = get_embedding_model()
llama_settings.Settings.embed_model = embed_model llama_settings.Settings.embed_model = embed_model
storage_context = get_or_create_storage_context(rebuild=rebuild) storage_context = get_or_create_storage_context(rebuild=rebuild)
nodes = [] nodes = []
for document in tqdm.tqdm(Document.objects.all(), disable=progress_bar_disable): documents = Document.objects.all()
document_nodes = build_document_node(document) if not documents.exists():
nodes.extend(document_nodes) logger.warning("No documents found to index.")
return
if not nodes: if rebuild:
raise RuntimeError( # Rebuild index from scratch
"No nodes to index — check that documents are available and have content.", for document in tqdm.tqdm(documents, disable=progress_bar_disable):
document_nodes = build_document_node(document)
nodes.extend(document_nodes)
VectorStoreIndex(
nodes=nodes,
storage_context=storage_context,
embed_model=embed_model,
show_progress=not progress_bar_disable,
) )
else:
# Update existing index
index = load_or_build_index(storage_context, embed_model)
all_node_ids = list(index.docstore.docs.keys())
existing_nodes = {
node.metadata.get("document_id"): node
for node in index.docstore.get_nodes(all_node_ids)
}
node_ids_to_remove = []
for document in tqdm.tqdm(documents, disable=progress_bar_disable):
doc_id = str(document.id)
document_modified = document.modified.isoformat()
if doc_id in existing_nodes:
node = existing_nodes[doc_id]
node_modified = node.metadata.get("modified")
if node_modified == document_modified:
continue
node_ids_to_remove.append(node.node_id)
nodes.extend(build_document_node(document))
else:
# New document, add it
nodes.extend(build_document_node(document))
if node_ids_to_remove or nodes:
logger.info(
"Updating LLM index with %d new nodes and removing %d old nodes.",
len(nodes),
len(node_ids_to_remove),
)
if node_ids_to_remove:
index.delete_nodes(node_ids_to_remove)
if nodes:
index.insert_nodes(nodes)
else:
logger.info("No changes detected, skipping llm index rebuild.")
VectorStoreIndex(
nodes=nodes,
storage_context=storage_context,
embed_model=embed_model,
)
storage_context.persist(persist_dir=settings.LLM_INDEX_DIR) storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
@ -187,6 +228,7 @@ def llm_index_remove_document(document: Document):
storage_context = get_or_create_storage_context(rebuild=False) storage_context = get_or_create_storage_context(rebuild=False)
index = load_or_build_index(storage_context, embed_model) index = load_or_build_index(storage_context, embed_model)
if index is None: if index is None:
return return

View File

@ -201,6 +201,4 @@ class AIConfig(BaseConfig):
self.llm_url = app_config.llm_url or settings.LLM_URL self.llm_url = app_config.llm_url or settings.LLM_URL
def llm_index_enabled(self) -> bool: def llm_index_enabled(self) -> bool:
return ( return self.ai_enabled and self.llm_embedding_backend
self.ai_enabled and self.llm_embedding_backend and self.llm_embedding_model
)

View File

@ -227,6 +227,20 @@ def _parse_beat_schedule() -> dict:
"expires": 59.0 * 60.0, "expires": 59.0 * 60.0,
}, },
}, },
{
"name": "Rebuild LLM index",
"env_key": "PAPERLESS_LLM_INDEX_TASK_CRON",
# Default daily at 02:10
"env_default": "10 2 * * *",
"task": "documents.tasks.llmindex_index",
"options": {
# 1 hour before default schedule sends again
"expires": 23.0 * 60.0 * 60.0,
"kwargs": {
"progress_bar_disable": True,
},
},
},
] ]
for task in tasks: for task in tasks:
# Either get the environment setting or use the default # Either get the environment setting or use the default

View File

@ -53,7 +53,7 @@ class FakeEmbedding(BaseEmbedding):
def test_build_document_node(real_document): def test_build_document_node(real_document):
nodes = indexing.build_document_node(real_document) nodes = indexing.build_document_node(real_document)
assert len(nodes) > 0 assert len(nodes) > 0
assert nodes[0].metadata["document_id"] == real_document.id assert nodes[0].metadata["document_id"] == str(real_document.id)
@pytest.mark.django_db @pytest.mark.django_db
@ -63,8 +63,11 @@ def test_rebuild_llm_index(
mock_embed_model, mock_embed_model,
): ):
with patch("documents.models.Document.objects.all") as mock_all: with patch("documents.models.Document.objects.all") as mock_all:
mock_all.return_value = [real_document] mock_queryset = MagicMock()
indexing.rebuild_llm_index(rebuild=True) mock_queryset.exists.return_value = True
mock_queryset.__iter__.return_value = iter([real_document])
mock_all.return_value = mock_queryset
indexing.update_llm_index(rebuild=True)
assert any(temp_llm_index_dir.glob("*.json")) assert any(temp_llm_index_dir.glob("*.json"))
@ -75,7 +78,7 @@ def test_add_or_update_document_updates_existing_entry(
real_document, real_document,
mock_embed_model, mock_embed_model,
): ):
indexing.rebuild_llm_index(rebuild=True) indexing.update_llm_index(rebuild=True)
indexing.llm_index_add_or_update_document(real_document) indexing.llm_index_add_or_update_document(real_document)
assert any(temp_llm_index_dir.glob("*.json")) assert any(temp_llm_index_dir.glob("*.json"))
@ -87,7 +90,7 @@ def test_remove_document_deletes_node_from_docstore(
real_document, real_document,
mock_embed_model, mock_embed_model,
): ):
indexing.rebuild_llm_index(rebuild=True) indexing.update_llm_index(rebuild=True)
indexing.llm_index_add_or_update_document(real_document) indexing.llm_index_add_or_update_document(real_document)
indexing.llm_index_remove_document(real_document) indexing.llm_index_remove_document(real_document)
@ -100,10 +103,17 @@ def test_rebuild_llm_index_no_documents(
mock_embed_model, mock_embed_model,
): ):
with patch("documents.models.Document.objects.all") as mock_all: with patch("documents.models.Document.objects.all") as mock_all:
mock_all.return_value = [] mock_queryset = MagicMock()
mock_queryset.exists.return_value = False
mock_queryset.__iter__.return_value = iter([])
mock_all.return_value = mock_queryset
with pytest.raises(RuntimeError, match="No nodes to index"): # check log message
indexing.rebuild_llm_index(rebuild=True) with patch("paperless.ai.indexing.logger") as mock_logger:
indexing.update_llm_index(rebuild=True)
mock_logger.warning.assert_called_once_with(
"No documents found to index.",
)
def test_query_similar_documents( def test_query_similar_documents(

View File

@ -158,6 +158,7 @@ class TestCeleryScheduleParsing(TestCase):
SANITY_EXPIRE_TIME = ((7.0 * 24.0) - 1.0) * 60.0 * 60.0 SANITY_EXPIRE_TIME = ((7.0 * 24.0) - 1.0) * 60.0 * 60.0
EMPTY_TRASH_EXPIRE_TIME = 23.0 * 60.0 * 60.0 EMPTY_TRASH_EXPIRE_TIME = 23.0 * 60.0 * 60.0
RUN_SCHEDULED_WORKFLOWS_EXPIRE_TIME = 59.0 * 60.0 RUN_SCHEDULED_WORKFLOWS_EXPIRE_TIME = 59.0 * 60.0
LLM_INDEX_EXPIRE_TIME = 23.0 * 60.0 * 60.0
def test_schedule_configuration_default(self): def test_schedule_configuration_default(self):
""" """
@ -202,6 +203,16 @@ class TestCeleryScheduleParsing(TestCase):
"schedule": crontab(minute="5", hour="*/1"), "schedule": crontab(minute="5", hour="*/1"),
"options": {"expires": self.RUN_SCHEDULED_WORKFLOWS_EXPIRE_TIME}, "options": {"expires": self.RUN_SCHEDULED_WORKFLOWS_EXPIRE_TIME},
}, },
"Rebuild LLM index": {
"task": "documents.tasks.llmindex_index",
"schedule": crontab(minute=10, hour=2),
"options": {
"expires": self.LLM_INDEX_EXPIRE_TIME,
"kwargs": {
"progress_bar_disable": True,
},
},
},
}, },
schedule, schedule,
) )
@ -254,6 +265,16 @@ class TestCeleryScheduleParsing(TestCase):
"schedule": crontab(minute="5", hour="*/1"), "schedule": crontab(minute="5", hour="*/1"),
"options": {"expires": self.RUN_SCHEDULED_WORKFLOWS_EXPIRE_TIME}, "options": {"expires": self.RUN_SCHEDULED_WORKFLOWS_EXPIRE_TIME},
}, },
"Rebuild LLM index": {
"task": "documents.tasks.llmindex_index",
"schedule": crontab(minute=10, hour=2),
"options": {
"expires": self.LLM_INDEX_EXPIRE_TIME,
"kwargs": {
"progress_bar_disable": True,
},
},
},
}, },
schedule, schedule,
) )
@ -298,6 +319,16 @@ class TestCeleryScheduleParsing(TestCase):
"schedule": crontab(minute="5", hour="*/1"), "schedule": crontab(minute="5", hour="*/1"),
"options": {"expires": self.RUN_SCHEDULED_WORKFLOWS_EXPIRE_TIME}, "options": {"expires": self.RUN_SCHEDULED_WORKFLOWS_EXPIRE_TIME},
}, },
"Rebuild LLM index": {
"task": "documents.tasks.llmindex_index",
"schedule": crontab(minute=10, hour=2),
"options": {
"expires": self.LLM_INDEX_EXPIRE_TIME,
"kwargs": {
"progress_bar_disable": True,
},
},
},
}, },
schedule, schedule,
) )
@ -320,6 +351,7 @@ class TestCeleryScheduleParsing(TestCase):
"PAPERLESS_INDEX_TASK_CRON": "disable", "PAPERLESS_INDEX_TASK_CRON": "disable",
"PAPERLESS_EMPTY_TRASH_TASK_CRON": "disable", "PAPERLESS_EMPTY_TRASH_TASK_CRON": "disable",
"PAPERLESS_WORKFLOW_SCHEDULED_TASK_CRON": "disable", "PAPERLESS_WORKFLOW_SCHEDULED_TASK_CRON": "disable",
"PAPERLESS_LLM_INDEX_TASK_CRON": "disable",
}, },
): ):
schedule = _parse_beat_schedule() schedule = _parse_beat_schedule()