Fix ollama, fix RAG
[ci skip]
commit 422efb2235
parent f405b6e7b7
@@ -15,5 +15,7 @@ class Command(ProgressBarMixin, BaseCommand):
     def handle(self, *args, **options):
         self.handle_progress_bar_mixin(**options)
         with transaction.atomic():
-            if options["command"] == "rebuild":
-                llm_index_rebuild(progress_bar_disable=self.no_progress_bar)
+            llm_index_rebuild(
+                progress_bar_disable=self.no_progress_bar,
+                rebuild=options["command"] == "rebuild",
+            )
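
Note: the rebuild/update decision now lives inside llm_index_rebuild rather
than in the management command. A hypothetical direct call showing the new
shape (the command's registered name is not visible in this diff, so the
CLI form is assumed):

    # python manage.py document_llmindex rebuild   <- assumed command name
    llm_index_rebuild(
        progress_bar_disable=False,
        rebuild=True,  # False reuses the persisted index, see the hunks below
    )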

@@ -7,6 +7,7 @@ from pathlib import Path
 from tempfile import TemporaryDirectory
 
 import faiss
+import llama_index.core.settings as llama_settings
 import tqdm
 from celery import Task
 from celery import shared_task

@@ -21,7 +22,9 @@ from filelock import FileLock
 from llama_index.core import Document as LlamaDocument
 from llama_index.core import StorageContext
 from llama_index.core import VectorStoreIndex
-from llama_index.core.settings import Settings
+from llama_index.core.node_parser import SimpleNodeParser
+from llama_index.core.storage.docstore import SimpleDocumentStore
+from llama_index.core.storage.index_store import SimpleIndexStore
 from llama_index.vector_stores.faiss import FaissVectorStore
 from whoosh.writing import AsyncWriter
 

@@ -512,28 +515,36 @@ def llm_index_rebuild(*, progress_bar_disable=False, rebuild=False):
         shutil.rmtree(settings.LLM_INDEX_DIR, ignore_errors=True)
         settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
 
-    documents = Document.objects.all()
-
     embed_model = get_embedding_model()
+    llama_settings.Settings.embed_model = embed_model
 
     if rebuild or not settings.LLM_INDEX_DIR.exists():
         embedding_dim = get_embedding_dim()
         faiss_index = faiss.IndexFlatL2(embedding_dim)
-        vector_store = FaissVectorStore(faiss_index)
+        vector_store = FaissVectorStore(faiss_index=faiss_index)
     else:
         vector_store = FaissVectorStore.from_persist_dir(settings.LLM_INDEX_DIR)
-    storage_context = StorageContext.from_defaults(vector_store=vector_store)
-    Settings.embed_model = embed_model
 
-    llm_docs = []
-    for document in tqdm.tqdm(documents, disable=progress_bar_disable):
+    docstore = SimpleDocumentStore()
+    index_store = SimpleIndexStore()
+
+    storage_context = StorageContext.from_defaults(
+        docstore=docstore,
+        index_store=index_store,
+        persist_dir=settings.LLM_INDEX_DIR,
+        vector_store=vector_store,
+    )
+
+    parser = SimpleNodeParser()
+    nodes = []
+
+    for document in tqdm.tqdm(Document.objects.all(), disable=progress_bar_disable):
         if not document.content:
             continue
-        llm_docs.append(
-            LlamaDocument(
-                text=build_llm_index_text(document),
-                metadata={
-                    "id": document.id,
-                    "title": document.title,
-                    "tags": [t.name for t in document.tags.all()],
-                    "correspondent": document.correspondent.name
+
+        text = build_llm_index_text(document)
+        metadata = {
+            "document_id": document.id,
+            "title": document.title,
+            "tags": [t.name for t in document.tags.all()],
+            "correspondent": document.correspondent.name

@@ -542,15 +553,18 @@ def llm_index_rebuild(*, progress_bar_disable=False, rebuild=False):
-                    "document_type": document.document_type.name
-                    if document.document_type
-                    else None,
-                    "created": document.created.isoformat(),
-                    "added": document.added.isoformat(),
-                },
-            ),
-        )
+            "document_type": document.document_type.name
+            if document.document_type
+            else None,
+            "created": document.created.isoformat() if document.created else None,
+            "added": document.added.isoformat() if document.added else None,
+        }
+
+        doc = LlamaDocument(text=text, metadata=metadata)
+        doc_nodes = parser.get_nodes_from_documents([doc])
+        nodes.extend(doc_nodes)
 
-    index = VectorStoreIndex.from_documents(
-        llm_docs,
-        storage_context=storage_context,
-    )
-    settings.LLM_INDEX_DIR.mkdir(exist_ok=True)
+    index = VectorStoreIndex(
+        nodes=nodes,
+        storage_context=storage_context,
+        embed_model=embed_model,
+    )
+
     index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
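
Note: the rewrite makes the docstore and index store explicit and persists
them next to the FAISS vectors, and it chunks each document into nodes that
inherit the document metadata. A self-contained sketch of that build/reload
round-trip with llama-index; the path, dimension, document text, and the
MockEmbedding stand-in are placeholders, not values from paperless-ngx:

    import faiss
    from llama_index.core import Document as LlamaDocument
    from llama_index.core import MockEmbedding, StorageContext, VectorStoreIndex
    from llama_index.core.node_parser import SimpleNodeParser
    from llama_index.core.storage.docstore import SimpleDocumentStore
    from llama_index.core.storage.index_store import SimpleIndexStore
    from llama_index.vector_stores.faiss import FaissVectorStore

    persist_dir = "/tmp/llm-index"  # placeholder
    dim = 384  # placeholder embedding dimension

    # Chunk one invented document; every node inherits its metadata, so
    # per-document filtering still works after splitting.
    doc = LlamaDocument(
        text="Invoice 42 from ACME Corp for March electricity ...",
        metadata={"document_id": 42, "title": "Invoice 42"},
    )
    nodes = SimpleNodeParser().get_nodes_from_documents([doc])
    assert all(n.metadata["document_id"] == 42 for n in nodes)

    # Build over explicit stores and persist all three parts together.
    storage_context = StorageContext.from_defaults(
        docstore=SimpleDocumentStore(),
        index_store=SimpleIndexStore(),
        vector_store=FaissVectorStore(faiss_index=faiss.IndexFlatL2(dim)),
    )
    index = VectorStoreIndex(
        nodes=nodes,
        storage_context=storage_context,
        embed_model=MockEmbedding(embed_dim=dim),
    )
    index.storage_context.persist(persist_dir=persist_dir)

    # Reload path, mirroring the else-branch above.
    storage_context = StorageContext.from_defaults(
        docstore=SimpleDocumentStore.from_persist_dir(persist_dir),
        index_store=SimpleIndexStore.from_persist_dir(persist_dir),
        vector_store=FaissVectorStore.from_persist_dir(persist_dir),
    )

The new created/added guards also keep documents with a missing timestamp
from aborting the whole rebuild.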

@@ -62,7 +62,7 @@ def build_prompt_with_rag(document: Document) -> str:
     Only output valid JSON in the format below. No additional explanations.
 
     The JSON object must contain:
-    - title: A short, descriptive title
+    - title: A short, descriptive title based on the content
     - tags: A list of relevant topics
     - correspondents: People or organizations involved
     - document_types: Type or category of the document
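
Note: for reference, a parsed response that satisfies the contract this
prompt spells out; only the four fields visible in this hunk are shown, and
the values are invented:

    expected = {
        "title": "Electricity bill March 2025",
        "tags": ["utilities", "billing"],
        "correspondents": ["City Power Co."],
        "document_types": ["invoice"],
    }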

@@ -112,6 +112,6 @@ def get_ai_document_classification(document: Document) -> dict:
         client = AIClient()
         result = client.run_llm_query(prompt)
         return parse_ai_response(result)
-    except Exception:
+    except Exception as e:
         logger.exception("Failed AI classification")
-        return {}
+        raise e
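
Note: replacing "return {}" with a raise changes the contract: callers that
treated an empty dict as the failure signal must now catch the exception
themselves (a bare raise would also re-raise without rebinding e). A
hypothetical caller-side shim; paperless-ngx's real call sites may differ:

    def classify_or_empty(document) -> dict:
        # Restores the old empty-dict fallback for callers that want it.
        try:
            return get_ai_document_classification(document)
        except Exception:
            # The failure is already logged via logger.exception upstream.
            return {}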

@@ -37,15 +37,15 @@ class AIClient:
         url = self.settings.llm_url or "http://localhost:11434"
         with httpx.Client(timeout=30.0) as client:
             response = client.post(
-                f"{url}/api/chat",
+                f"{url}/api/generate",
                 json={
                     "model": self.settings.llm_model,
-                    "messages": [{"role": "user", "content": prompt}],
+                    "prompt": prompt,
                     "stream": False,
                 },
             )
             response.raise_for_status()
-            return response.json()["message"]["content"]
+            return response.json()["response"]
 
     def _run_openai_query(self, prompt: str) -> str:
         if not self.settings.llm_api_key:
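
Note: this is the Ollama fix itself. /api/chat expects a "messages" array
and returns the text under message.content, while /api/generate expects a
single "prompt" string and returns it under "response". A minimal
standalone check of the corrected request shape (host and model are
placeholders):

    import httpx

    url = "http://localhost:11434"  # placeholder host
    payload = {
        "model": "llama3",  # placeholder: any locally pulled model
        "prompt": "Reply with the single word: pong",
        "stream": False,
    }
    with httpx.Client(timeout=30.0) as client:
        r = client.post(f"{url}/api/generate", json=payload)
        r.raise_for_status()
        # Non-streaming responses carry the generated text in "response".
        print(r.json()["response"])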