From 422efb2235e33d73c35152019095e8ec7ed17076 Mon Sep 17 00:00:00 2001
From: shamoon <4887959+shamoon@users.noreply.github.com>
Date: Thu, 24 Apr 2025 22:03:21 -0700
Subject: [PATCH] Fix ollama, fix RAG [ci skip]

---
 .../management/commands/document_llmindex.py |  6 +-
 src/documents/tasks.py                       | 72 +++++++++++--------
 src/paperless/ai/ai_classifier.py            |  6 +-
 src/paperless/ai/client.py                   |  6 +-
 4 files changed, 53 insertions(+), 37 deletions(-)

diff --git a/src/documents/management/commands/document_llmindex.py b/src/documents/management/commands/document_llmindex.py
index 2985a61e4..09ea477c2 100644
--- a/src/documents/management/commands/document_llmindex.py
+++ b/src/documents/management/commands/document_llmindex.py
@@ -15,5 +15,7 @@ class Command(ProgressBarMixin, BaseCommand):
     def handle(self, *args, **options):
         self.handle_progress_bar_mixin(**options)
         with transaction.atomic():
-            if options["command"] == "rebuild":
-                llm_index_rebuild(progress_bar_disable=self.no_progress_bar)
+            llm_index_rebuild(
+                progress_bar_disable=self.no_progress_bar,
+                rebuild=options["command"] == "rebuild",
+            )
diff --git a/src/documents/tasks.py b/src/documents/tasks.py
index 903149ba0..b12bcb076 100644
--- a/src/documents/tasks.py
+++ b/src/documents/tasks.py
@@ -7,6 +7,7 @@ from pathlib import Path
 from tempfile import TemporaryDirectory
 
 import faiss
+import llama_index.core.settings as llama_settings
 import tqdm
 from celery import Task
 from celery import shared_task
@@ -21,7 +22,9 @@ from filelock import FileLock
 from llama_index.core import Document as LlamaDocument
 from llama_index.core import StorageContext
 from llama_index.core import VectorStoreIndex
-from llama_index.core.settings import Settings
+from llama_index.core.node_parser import SimpleNodeParser
+from llama_index.core.storage.docstore import SimpleDocumentStore
+from llama_index.core.storage.index_store import SimpleIndexStore
 from llama_index.vector_stores.faiss import FaissVectorStore
 from whoosh.writing import AsyncWriter
 
@@ -512,45 +515,56 @@ def llm_index_rebuild(*, progress_bar_disable=False, rebuild=False):
         shutil.rmtree(settings.LLM_INDEX_DIR, ignore_errors=True)
         settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
 
-    documents = Document.objects.all()
-
     embed_model = get_embedding_model()
+    llama_settings.Settings.embed_model = embed_model
 
     if rebuild or not settings.LLM_INDEX_DIR.exists():
         embedding_dim = get_embedding_dim()
         faiss_index = faiss.IndexFlatL2(embedding_dim)
-        vector_store = FaissVectorStore(faiss_index)
+        vector_store = FaissVectorStore(faiss_index=faiss_index)
     else:
         vector_store = FaissVectorStore.from_persist_dir(settings.LLM_INDEX_DIR)
 
-    storage_context = StorageContext.from_defaults(vector_store=vector_store)
-    Settings.embed_model = embed_model
-    llm_docs = []
-    for document in tqdm.tqdm(documents, disable=progress_bar_disable):
+    docstore = SimpleDocumentStore()
+    index_store = SimpleIndexStore()
+
+    storage_context = StorageContext.from_defaults(
+        docstore=docstore,
+        index_store=index_store,
+        persist_dir=settings.LLM_INDEX_DIR,
+        vector_store=vector_store,
+    )
+
+    parser = SimpleNodeParser()
+    nodes = []
+
+    for document in tqdm.tqdm(Document.objects.all(), disable=progress_bar_disable):
         if not document.content:
             continue
-        llm_docs.append(
-            LlamaDocument(
-                text=build_llm_index_text(document),
-                metadata={
-                    "id": document.id,
-                    "title": document.title,
-                    "tags": [t.name for t in document.tags.all()],
-                    "correspondent": document.correspondent.name
-                    if document.correspondent
-                    else None,
-                    "document_type": document.document_type.name
-                    if document.document_type
-                    else None,
-                    "created": document.created.isoformat(),
-                    "added": document.added.isoformat(),
-                },
-            ),
-        )
-    index = VectorStoreIndex.from_documents(
-        llm_docs,
+        text = build_llm_index_text(document)
+        metadata = {
+            "document_id": document.id,
+            "title": document.title,
+            "tags": [t.name for t in document.tags.all()],
+            "correspondent": document.correspondent.name
+            if document.correspondent
+            else None,
+            "document_type": document.document_type.name
+            if document.document_type
+            else None,
+            "created": document.created.isoformat() if document.created else None,
+            "added": document.added.isoformat() if document.added else None,
+        }
+
+        doc = LlamaDocument(text=text, metadata=metadata)
+        doc_nodes = parser.get_nodes_from_documents([doc])
+        nodes.extend(doc_nodes)
+
+    index = VectorStoreIndex(
+        nodes=nodes,
         storage_context=storage_context,
+        embed_model=embed_model,
     )
-    settings.LLM_INDEX_DIR.mkdir(exist_ok=True)
+
     index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
 
diff --git a/src/paperless/ai/ai_classifier.py b/src/paperless/ai/ai_classifier.py
index d5ec88323..f52548b62 100644
--- a/src/paperless/ai/ai_classifier.py
+++ b/src/paperless/ai/ai_classifier.py
@@ -62,7 +62,7 @@ def build_prompt_with_rag(document: Document) -> str:
     Only output valid JSON in the format below. No additional explanations.
 
     The JSON object must contain:
-    - title: A short, descriptive title
+    - title: A short, descriptive title based on the content
    - tags: A list of relevant topics
     - correspondents: People or organizations involved
     - document_types: Type or category of the document
@@ -112,6 +112,6 @@ def get_ai_document_classification(document: Document) -> dict:
         client = AIClient()
         result = client.run_llm_query(prompt)
         return parse_ai_response(result)
-    except Exception:
+    except Exception as e:
         logger.exception("Failed AI classification")
-        return {}
+        raise e
diff --git a/src/paperless/ai/client.py b/src/paperless/ai/client.py
index 03012844f..d37468b4e 100644
--- a/src/paperless/ai/client.py
+++ b/src/paperless/ai/client.py
@@ -37,15 +37,15 @@ class AIClient:
         url = self.settings.llm_url or "http://localhost:11434"
         with httpx.Client(timeout=30.0) as client:
             response = client.post(
-                f"{url}/api/chat",
+                f"{url}/api/generate",
                 json={
                     "model": self.settings.llm_model,
-                    "messages": [{"role": "user", "content": prompt}],
+                    "prompt": prompt,
                     "stream": False,
                 },
             )
             response.raise_for_status()
-            return response.json()["message"]["content"]
+            return response.json()["response"]
 
     def _run_openai_query(self, prompt: str) -> str:
         if not self.settings.llm_api_key: