Fix ollama, fix RAG

[ci skip]
shamoon 2025-04-24 22:03:21 -07:00
parent f405b6e7b7
commit 422efb2235
4 changed files with 53 additions and 37 deletions


@@ -15,5 +15,7 @@ class Command(ProgressBarMixin, BaseCommand):
     def handle(self, *args, **options):
         self.handle_progress_bar_mixin(**options)
         with transaction.atomic():
-            if options["command"] == "rebuild":
-                llm_index_rebuild(progress_bar_disable=self.no_progress_bar)
+            llm_index_rebuild(
+                progress_bar_disable=self.no_progress_bar,
+                rebuild=options["command"] == "rebuild",
+            )
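
Note on the hunk above: handle() now always calls llm_index_rebuild() and passes the subcommand down as a boolean rebuild flag instead of gating the call. A runnable sketch of the argument wiring that feeds options["command"]; the diff confirms only the "rebuild" choice, the "update" alternative is an assumption:

    # Hypothetical stand-in for the management command's parser; Django
    # management commands use argparse underneath.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("command", choices=["rebuild", "update"])

    options = vars(parser.parse_args(["update"]))
    # handle() now derives a flag instead of branching on the subcommand:
    rebuild = options["command"] == "rebuild"
    assert rebuild is False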


@@ -7,6 +7,7 @@ from pathlib import Path
 from tempfile import TemporaryDirectory
 import faiss
 import llama_index.core.settings as llama_settings
+import tqdm
 from celery import Task
 from celery import shared_task
@@ -21,7 +22,9 @@ from filelock import FileLock
 from llama_index.core import Document as LlamaDocument
 from llama_index.core import StorageContext
 from llama_index.core import VectorStoreIndex
-from llama_index.core.settings import Settings
+from llama_index.core.node_parser import SimpleNodeParser
+from llama_index.core.storage.docstore import SimpleDocumentStore
+from llama_index.core.storage.index_store import SimpleIndexStore
 from llama_index.vector_stores.faiss import FaissVectorStore
 from whoosh.writing import AsyncWriter
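
The Settings import removed above is replaced by attribute access through llama_settings in the next hunk. Both spellings reach the same llama_index singleton, so the change is behavior-preserving; a minimal check, assuming only that llama_index is installed:

    # Both names refer to one singleton, so assigning embed_model through
    # either has the same effect; attribute access simply stays valid if the
    # module attribute is later rebound (e.g. monkeypatched in tests).
    import llama_index.core.settings as llama_settings
    from llama_index.core.settings import Settings

    assert Settings is llama_settings.Settings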
@@ -512,45 +515,56 @@ def llm_index_rebuild(*, progress_bar_disable=False, rebuild=False):
     shutil.rmtree(settings.LLM_INDEX_DIR, ignore_errors=True)
     settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
-    documents = Document.objects.all()
     embed_model = get_embedding_model()
+    llama_settings.Settings.embed_model = embed_model
-    embedding_dim = get_embedding_dim()
-    faiss_index = faiss.IndexFlatL2(embedding_dim)
-    vector_store = FaissVectorStore(faiss_index)
+    if rebuild or not settings.LLM_INDEX_DIR.exists():
+        embedding_dim = get_embedding_dim()
+        faiss_index = faiss.IndexFlatL2(embedding_dim)
+        vector_store = FaissVectorStore(faiss_index=faiss_index)
+    else:
+        vector_store = FaissVectorStore.from_persist_dir(settings.LLM_INDEX_DIR)
-    storage_context = StorageContext.from_defaults(vector_store=vector_store)
-    Settings.embed_model = embed_model
-    llm_docs = []
-    for document in tqdm.tqdm(documents, disable=progress_bar_disable):
+    docstore = SimpleDocumentStore()
+    index_store = SimpleIndexStore()
+    storage_context = StorageContext.from_defaults(
+        docstore=docstore,
+        index_store=index_store,
+        persist_dir=settings.LLM_INDEX_DIR,
+        vector_store=vector_store,
+    )
+    parser = SimpleNodeParser()
+    nodes = []
+    for document in tqdm.tqdm(Document.objects.all(), disable=progress_bar_disable):
         if not document.content:
             continue
-        llm_docs.append(
-            LlamaDocument(
-                text=build_llm_index_text(document),
-                metadata={
-                    "id": document.id,
-                    "title": document.title,
-                    "tags": [t.name for t in document.tags.all()],
-                    "correspondent": document.correspondent.name
-                    if document.correspondent
-                    else None,
-                    "document_type": document.document_type.name
-                    if document.document_type
-                    else None,
-                    "created": document.created.isoformat(),
-                    "added": document.added.isoformat(),
-                },
-            ),
-        )
-    index = VectorStoreIndex.from_documents(
-        llm_docs,
+        text = build_llm_index_text(document)
+        metadata = {
+            "document_id": document.id,
+            "title": document.title,
+            "tags": [t.name for t in document.tags.all()],
+            "correspondent": document.correspondent.name
+            if document.correspondent
+            else None,
+            "document_type": document.document_type.name
+            if document.document_type
+            else None,
+            "created": document.created.isoformat() if document.created else None,
+            "added": document.added.isoformat() if document.added else None,
+        }
+        doc = LlamaDocument(text=text, metadata=metadata)
+        doc_nodes = parser.get_nodes_from_documents([doc])
+        nodes.extend(doc_nodes)
+    index = VectorStoreIndex(
+        nodes=nodes,
+        storage_context=storage_context,
         embed_model=embed_model,
     )
-    settings.LLM_INDEX_DIR.mkdir(exist_ok=True)
     index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
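
A sketch of the matching read path this persistence layout implies, following llama_index's documented FAISS flow; the directory constant and query string are illustrative, and get_embedding_model() is the helper already used in the diff:

    import llama_index.core.settings as llama_settings
    from llama_index.core import StorageContext, load_index_from_storage
    from llama_index.vector_stores.faiss import FaissVectorStore

    LLM_INDEX_DIR = "/path/to/llm_index"  # stands in for settings.LLM_INDEX_DIR

    # The same embed model must be active at query time as at index time.
    llama_settings.Settings.embed_model = get_embedding_model()

    vector_store = FaissVectorStore.from_persist_dir(LLM_INDEX_DIR)
    storage_context = StorageContext.from_defaults(
        vector_store=vector_store,
        persist_dir=LLM_INDEX_DIR,
    )
    index = load_index_from_storage(storage_context)
    nodes = index.as_retriever(similarity_top_k=3).retrieve("overdue invoices")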


@@ -62,7 +62,7 @@ def build_prompt_with_rag(document: Document) -> str:
 Only output valid JSON in the format below. No additional explanations.
 
 The JSON object must contain:
-- title: A short, descriptive title
+- title: A short, descriptive title based on the content
 - tags: A list of relevant topics
 - correspondents: People or organizations involved
 - document_types: Type or category of the document
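
Illustrative only: with the tightened wording, parse_ai_response() should receive JSON of roughly this shape (field names from the prompt above, values invented):

    # What json.loads() on a compliant completion might yield.
    expected = {
        "title": "Electricity invoice, March 2025",
        "tags": ["utilities", "invoice"],
        "correspondents": ["Acme Power"],
        "document_types": ["invoice"],
    }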
@@ -112,6 +112,6 @@ def get_ai_document_classification(document: Document) -> dict:
         client = AIClient()
         result = client.run_llm_query(prompt)
         return parse_ai_response(result)
-    except Exception:
+    except Exception as e:
         logger.exception("Failed AI classification")
-        return {}
+        raise e
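
Since the except block now re-raises instead of returning {}, callers lose the silent empty-dict fallback and must handle failures themselves; a hypothetical call site:

    # Hypothetical caller: the fallback moves out of
    # get_ai_document_classification() and into whoever wants one.
    try:
        classification = get_ai_document_classification(document)
    except Exception:
        logger.warning("AI classification failed; continuing without suggestions")
        classification = {}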


@@ -37,15 +37,15 @@ class AIClient:
         url = self.settings.llm_url or "http://localhost:11434"
         with httpx.Client(timeout=30.0) as client:
             response = client.post(
-                f"{url}/api/chat",
+                f"{url}/api/generate",
                 json={
                     "model": self.settings.llm_model,
-                    "messages": [{"role": "user", "content": prompt}],
+                    "prompt": prompt,
                     "stream": False,
                 },
             )
             response.raise_for_status()
-            return response.json()["message"]["content"]
+            return response.json()["response"]
 
     def _run_openai_query(self, prompt: str) -> str:
         if not self.settings.llm_api_key:
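
For reference, Ollama's two endpoints differ in both request and response shape, which is what the three changed lines above account for; a standalone sketch against a local server (model name illustrative):

    # /api/generate takes a flat "prompt" and answers under "response";
    # /api/chat takes a "messages" list and answers under "message"/"content".
    import httpx

    payload = {"model": "llama3", "prompt": "Reply with OK", "stream": False}
    with httpx.Client(timeout=30.0) as client:
        r = client.post("http://localhost:11434/api/generate", json=payload)
        r.raise_for_status()
        print(r.json()["response"])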