mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-05-23 12:58:18 -05:00
RAG into suggestions
This commit is contained in:
parent
58f3b7be0a
commit
f405b6e7b7
@ -3,15 +3,13 @@ import logging
|
|||||||
|
|
||||||
from documents.models import Document
|
from documents.models import Document
|
||||||
from paperless.ai.client import AIClient
|
from paperless.ai.client import AIClient
|
||||||
|
from paperless.ai.rag import get_context_for_document
|
||||||
|
from paperless.config import AIConfig
|
||||||
|
|
||||||
logger = logging.getLogger("paperless.ai.ai_classifier")
|
logger = logging.getLogger("paperless.ai.rag_classifier")
|
||||||
|
|
||||||
|
|
||||||
def get_ai_document_classification(document: Document) -> dict:
|
def build_prompt_without_rag(document: Document) -> str:
|
||||||
"""
|
|
||||||
Returns classification suggestions for a given document using an LLM.
|
|
||||||
Output schema matches the API's expected DocumentClassificationSuggestions format.
|
|
||||||
"""
|
|
||||||
filename = document.filename or ""
|
filename = document.filename or ""
|
||||||
content = document.content or ""
|
content = document.content or ""
|
||||||
|
|
||||||
@ -41,6 +39,7 @@ def get_ai_document_classification(document: Document) -> dict:
|
|||||||
}}
|
}}
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|
||||||
FILENAME:
|
FILENAME:
|
||||||
{filename}
|
{filename}
|
||||||
|
|
||||||
@ -48,39 +47,71 @@ def get_ai_document_classification(document: Document) -> dict:
|
|||||||
{content[:8000]} # Trim to safe size
|
{content[:8000]} # Trim to safe size
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
return prompt
|
||||||
client = AIClient()
|
|
||||||
result = client.run_llm_query(prompt)
|
|
||||||
suggestions = parse_ai_classification_response(result)
|
|
||||||
return suggestions or {}
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Error during LLM classification: %s", exc_info=True)
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
def parse_ai_classification_response(text: str) -> dict:
|
def build_prompt_with_rag(document: Document) -> str:
    """
    Build the LLM prompt for ``document`` including context from similar documents.

    The prompt contains the document's filename, the first 4000 characters of
    its content and the first 4000 characters of the text returned by
    ``get_context_for_document``. Missing filename/content are treated as
    empty strings.
    """
    # NOTE(review): the leading whitespace of the template lines was
    # reconstructed (4 spaces per line) — confirm against the repository copy.
    similar_context = get_context_for_document(document)
    content_excerpt = (document.content or "")[:4000]
    context_excerpt = similar_context[:4000]
    filename = document.filename or ""

    return f"""
    You are a helpful assistant that extracts structured information from documents.
    You have access to similar documents as context to help improve suggestions.

    Only output valid JSON in the format below. No additional explanations.

    The JSON object must contain:
    - title: A short, descriptive title
    - tags: A list of relevant topics
    - correspondents: People or organizations involved
    - document_types: Type or category of the document
    - storage_paths: Suggested folder paths
    - dates: Up to 3 relevant dates in YYYY-MM-DD

    Here is an example document:
    FILENAME:
    {filename}

    CONTENT:
    {content_excerpt}

    CONTEXT FROM SIMILAR DOCUMENTS:
    {context_excerpt}
    """
|
||||||
|
|
||||||
|
|
||||||
|
def parse_ai_response(text: str) -> dict:
    """
    Parse the raw LLM reply into the suggestions dictionary.

    The reply is expected to be a JSON object with the keys ``title``,
    ``tags``, ``correspondents``, ``document_types``, ``storage_paths`` and
    ``dates``. Every key except ``title`` is returned as a list: a bare
    scalar is wrapped in a one-element list, a missing or ``null`` value
    becomes ``[]``, and empty entries are dropped.

    Returns an empty dict when the reply is not valid JSON or is valid JSON
    that is not an object.
    """
    try:
        raw = json.loads(text)
    except json.JSONDecodeError:
        logger.exception("Invalid JSON in RAG response")
        return {}

    if not isinstance(raw, dict):
        # Valid JSON but not an object (e.g. a bare list or string); calling
        # .get on it would raise AttributeError.
        logger.warning(
            "Unexpected JSON type in RAG response: %s",
            type(raw).__name__,
        )
        return {}

    return {
        "title": raw.get("title"),
        "tags": _as_list(raw.get("tags")),
        "correspondents": _as_list(raw.get("correspondents")),
        "document_types": _as_list(raw.get("document_types")),
        "storage_paths": _as_list(raw.get("storage_paths")),
        "dates": _as_list(raw.get("dates")),
    }


def _as_list(value) -> list:
    """Coerce one field of the LLM reply into a list without empty entries."""
    if value is None:
        return []
    if isinstance(value, (list, tuple)):
        return [item for item in value if item]
    # The model sometimes answers with a single value where a list was asked for.
    return [value] if value else []
|
||||||
"Failed to parse LLM classification response: %s",
|
|
||||||
text,
|
|
||||||
exc_info=True,
|
def get_ai_document_classification(document: Document) -> dict:
    """
    Return AI classification suggestions for ``document``.

    Uses the RAG prompt (``build_prompt_with_rag``) when
    ``AIConfig().llm_embedding_backend`` is truthy and the plain prompt
    (``build_prompt_without_rag``) otherwise, sends it through ``AIClient``
    and returns the parsed reply.

    Any exception raised while building the prompt, querying the LLM or
    parsing the reply is logged and results in an empty dict.
    """
    ai_config = AIConfig()

    try:
        # Prompt construction is inside the guard: the RAG variant calls
        # get_context_for_document, and a failure there should be logged and
        # turned into "no suggestions" like any other failure in this function.
        prompt = (
            build_prompt_with_rag(document)
            if ai_config.llm_embedding_backend
            else build_prompt_without_rag(document)
        )
        client = AIClient()
        result = client.run_llm_query(prompt)
        return parse_ai_response(result)
    except Exception:
        logger.exception("Failed AI classification")
        return {}
|
||||||
|
12
src/paperless/ai/rag.py
Normal file
12
src/paperless/ai/rag.py
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
from documents.models import Document
|
||||||
|
from paperless.ai.indexing import query_similar_documents
|
||||||
|
|
||||||
|
|
||||||
|
def get_context_for_document(doc: Document, max_docs: int = 5) -> str:
    """
    Build a text block describing documents similar to ``doc``.

    Takes at most ``max_docs`` results from ``query_similar_documents(doc)``
    and renders each as ``TITLE: <title>`` followed by its content on the next
    line; blocks are separated by a blank line. The title falls back to the
    filename and then to ``"Untitled"``; missing content is rendered as an
    empty string. A ``max_docs`` of zero or less yields an empty string.
    """
    # A negative slice bound would drop items from the end of the result
    # instead of limiting how many are used, so clamp it at zero.
    limit = max(max_docs, 0)
    similar_docs = query_similar_documents(doc)[:limit]
    return "\n\n".join(
        f"TITLE: {similar.title or similar.filename or 'Untitled'}\n"
        f"{similar.content or ''}"
        for similar in similar_docs
    )
|
Loading…
x
Reference in New Issue
Block a user