mirror of https://github.com/paperless-ngx/paperless-ngx.git (synced 2025-05-21 12:52:13 -05:00)
RAG into suggestions
This commit is contained in:
parent 58f3b7be0a
commit f405b6e7b7
src/paperless/ai/ai_classifier.py
@@ -3,15 +3,13 @@ import logging
 
 from documents.models import Document
 from paperless.ai.client import AIClient
+from paperless.ai.rag import get_context_for_document
+from paperless.config import AIConfig
 
-logger = logging.getLogger("paperless.ai.ai_classifier")
+logger = logging.getLogger("paperless.ai.rag_classifier")
 
 
-def get_ai_document_classification(document: Document) -> dict:
-    """
-    Returns classification suggestions for a given document using an LLM.
-    Output schema matches the API's expected DocumentClassificationSuggestions format.
-    """
+def build_prompt_without_rag(document: Document) -> str:
     filename = document.filename or ""
     content = document.content or ""
 
@@ -41,6 +39,7 @@ def get_ai_document_classification(document: Document) -> dict:
     }}
     ---
 
+
     FILENAME:
     {filename}
 
@@ -48,39 +47,71 @@ def get_ai_document_classification(document: Document) -> dict:
     {content[:8000]} # Trim to safe size
     """
 
-    try:
-        client = AIClient()
-        result = client.run_llm_query(prompt)
-        suggestions = parse_ai_classification_response(result)
-        return suggestions or {}
-    except Exception:
-        logger.exception("Error during LLM classification: %s", exc_info=True)
-        return {}
+    return prompt
 
 
-def parse_ai_classification_response(text: str) -> dict:
-    """
-    Parses LLM output and ensures it conforms to expected schema.
+def build_prompt_with_rag(document: Document) -> str:
+    context = get_context_for_document(document)
+    content = document.content or ""
+    filename = document.filename or ""
+
+    prompt = f"""
+    You are a helpful assistant that extracts structured information from documents.
+    You have access to similar documents as context to help improve suggestions.
+
+    Only output valid JSON in the format below. No additional explanations.
+
+    The JSON object must contain:
+    - title: A short, descriptive title
+    - tags: A list of relevant topics
+    - correspondents: People or organizations involved
+    - document_types: Type or category of the document
+    - storage_paths: Suggested folder paths
+    - dates: Up to 3 relevant dates in YYYY-MM-DD
+
+    Here is an example document:
+    FILENAME:
+    {filename}
+
+    CONTENT:
+    {content[:4000]}
+
+    CONTEXT FROM SIMILAR DOCUMENTS:
+    {context[:4000]}
     """
+
+    return prompt
+
+
+def parse_ai_response(text: str) -> dict:
     try:
         raw = json.loads(text)
         return {
             "title": raw.get("title"),
             "tags": raw.get("tags", []),
-            "correspondents": [raw["correspondents"]]
-            if isinstance(raw.get("correspondents"), str)
-            else raw.get("correspondents", []),
-            "document_types": [raw["document_types"]]
-            if isinstance(raw.get("document_types"), str)
-            else raw.get("document_types", []),
+            "correspondents": raw.get("correspondents", []),
+            "document_types": raw.get("document_types", []),
             "storage_paths": raw.get("storage_paths", []),
-            "dates": [d for d in raw.get("dates", []) if d],
+            "dates": raw.get("dates", []),
         }
     except json.JSONDecodeError:
-        # fallback: try to extract JSON manually?
-        logger.exception(
-            "Failed to parse LLM classification response: %s",
-            text,
-            exc_info=True,
-        )
+        logger.exception("Invalid JSON in RAG response")
         return {}
+
+
+def get_ai_document_classification(document: Document) -> dict:
+    ai_config = AIConfig()
+
+    prompt = (
+        build_prompt_with_rag(document)
+        if ai_config.llm_embedding_backend
+        else build_prompt_without_rag(document)
+    )
+
+    try:
+        client = AIClient()
+        result = client.run_llm_query(prompt)
+        return parse_ai_response(result)
+    except Exception:
+        logger.exception("Failed AI classification")
+        return {}
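For reference, the JSON that both prompts ask for is exactly what parse_ai_response expects: a flat object with the six keys named in the prompt, list-valued except for title. A minimal standalone sketch of that round trip follows; the sample payload is invented for illustration, and note that the new parser no longer coerces a bare string into a single-item list the way the removed parse_ai_classification_response did.

import json

# Hypothetical LLM output that follows the schema requested in the prompt.
sample_response = json.dumps(
    {
        "title": "Electricity Invoice March 2025",
        "tags": ["utilities", "invoice"],
        "correspondents": ["ACME Energy"],
        "document_types": ["Invoice"],
        "storage_paths": ["Finance/Utilities"],
        "dates": ["2025-03-31"],
    },
)

# Mirrors what parse_ai_response does: json.loads plus list defaults for missing keys.
raw = json.loads(sample_response)
suggestions = {
    "title": raw.get("title"),
    "tags": raw.get("tags", []),
    "correspondents": raw.get("correspondents", []),
    "document_types": raw.get("document_types", []),
    "storage_paths": raw.get("storage_paths", []),
    "dates": raw.get("dates", []),
}
print(suggestions["title"])  # -> Electricity Invoice March 2025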
src/paperless/ai/rag.py (new file, 12 lines)
@@ -0,0 +1,12 @@
+from documents.models import Document
+from paperless.ai.indexing import query_similar_documents
+
+
+def get_context_for_document(doc: Document, max_docs: int = 5) -> str:
+    similar_docs = query_similar_documents(doc)[:max_docs]
+    context_blocks = []
+    for similar in similar_docs:
+        text = similar.content or ""
+        title = similar.title or similar.filename or "Untitled"
+        context_blocks.append(f"TITLE: {title}\n{text}")
+    return "\n\n".join(context_blocks)
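As a usage sketch (assuming a configured Django/paperless environment and that the index behind query_similar_documents has already been built), the string returned here is what build_prompt_with_rag trims to 4000 characters and places under CONTEXT FROM SIMILAR DOCUMENTS:

# Hedged sketch; the document id below is hypothetical.
from documents.models import Document
from paperless.ai.rag import get_context_for_document

doc = Document.objects.get(pk=1)
context = get_context_for_document(doc, max_docs=3)

# Each similar document contributes a "TITLE: <title>\n<content>" block,
# and the blocks are joined with blank lines.
print(context[:500])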