Move to structured output

commit 20bae4bd41 (parent b94912a392)
@@ -2,7 +2,6 @@ import json
 import logging

 from django.contrib.auth.models import User
-from llama_index.core.base.llms.types import CompletionResponse

 from documents.models import Document
 from documents.permissions import get_objects_for_user_owner_aware
@@ -18,58 +17,34 @@ def build_prompt_without_rag(document: Document) -> str:
     filename = document.filename or ""
     content = truncate_content(document.content[:4000] or "")

-    prompt = f"""
-    You are an assistant that extracts structured information from documents.
-    Only respond with the JSON object as described below.
-    Never ask for further information, additional content or ask questions. Never include any other text.
-    Suggested tags and document types must be strictly based on the content of the document.
-    Do not change the field names or the JSON structure, only provide the values. Use double quotes and proper JSON syntax.
-    Each field must be a list of plain strings.
-
-    The JSON object must contain the following fields:
-    - title: A short, descriptive title
-    - tags: A list of simple tags like ["insurance", "medical", "receipts"]
-    - correspondents: A list of names or organizations mentioned in the document
-    - document_types: The type/category of the document (e.g. "invoice", "medical record")
-    - storage_paths: Suggested folder paths (e.g. "Medical/Insurance")
-    - dates: List up to 3 relevant dates in YYYY-MM-DD format
-
-    The format of the JSON object is as follows:
-    {{
-        "title": "xxxxx",
-        "tags": ["xxxx", "xxxx"],
-        "correspondents": ["xxxx", "xxxx"],
-        "document_types": ["xxxx", "xxxx"],
-        "storage_paths": ["xxxx", "xxxx"],
-        "dates": ["YYYY-MM-DD", "YYYY-MM-DD", "YYYY-MM-DD"],
-    }}
-    ---------
-    FILENAME:
-    {filename}
-    CONTENT:
-    {content}
-    """
-
-    return prompt
+    return f"""
+    You are a document classification assistant.
+
+    Analyze the following document and extract the following information:
+    - A short descriptive title
+    - Tags that reflect the content
+    - Names of people or organizations mentioned
+    - The type or category of the document
+    - Suggested folder paths for storing the document
+    - Up to 3 relevant dates in YYYY-MM-DD format
+
+    Filename:
+    {filename}
+
+    Content:
+    {content}
+    """.strip()


 def build_prompt_with_rag(document: Document, user: User | None = None) -> str:
+    base_prompt = build_prompt_without_rag(document)
     context = truncate_content(get_context_for_document(document, user))
-    prompt = build_prompt_without_rag(document)
-
-    prompt += f"""
-
-    CONTEXT FROM SIMILAR DOCUMENTS:
-    {context}
-
-    ---------
-
-    DO NOT RESPOND WITH ANYTHING OTHER THAN THE JSON OBJECT.
-    """
-
-    return prompt
+    return f"""{base_prompt}
+
+    Additional context from similar documents:
+    {context}
+    """.strip()


 def get_context_for_document(
@@ -100,36 +75,19 @@ def get_context_for_document(
     return "\n\n".join(context_blocks)


-def parse_ai_response(response: CompletionResponse) -> dict:
+def parse_ai_response(raw: dict) -> dict:
     try:
-        raw = json.loads(response.text)
         return {
-            "title": raw.get("title"),
+            "title": raw.get("title", ""),
             "tags": raw.get("tags", []),
             "correspondents": raw.get("correspondents", []),
             "document_types": raw.get("document_types", []),
             "storage_paths": raw.get("storage_paths", []),
             "dates": raw.get("dates", []),
         }
-    except json.JSONDecodeError:
-        logger.warning("Invalid JSON in AI response, attempting modified parsing...")
-        try:
-            # search for a valid json string like { ... } in the response
-            start = response.text.index("{")
-            end = response.text.rindex("}") + 1
-            json_str = response.text[start:end]
-            raw = json.loads(json_str)
-            return {
-                "title": raw.get("title"),
-                "tags": raw.get("tags", []),
-                "correspondents": raw.get("correspondents", []),
-                "document_types": raw.get("document_types", []),
-                "storage_paths": raw.get("storage_paths", []),
-                "dates": raw.get("dates", []),
-            }
-        except (ValueError, json.JSONDecodeError):
-            logger.exception("Failed to parse AI response")
-            return {}
+    except (ValueError, json.JSONDecodeError):
+        logger.exception("Failed to parse AI response")
+        return {}


 def get_ai_document_classification(
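A minimal usage sketch (values hypothetical, module path assumed from this diff): the client now hands back a plain dict, so parse_ai_response only backfills defaults for missing keys instead of decoding JSON text.

    from paperless_ai.ai_classifier import parse_ai_response  # module path assumed

    parse_ai_response({"title": "Invoice 2025-04", "tags": ["invoice"]})
    # -> {"title": "Invoice 2025-04", "tags": ["invoice"], "correspondents": [],
    #     "document_types": [], "storage_paths": [], "dates": []}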
src/paperless_ai/client.py

@@ -1,10 +1,12 @@
 import logging

 from llama_index.core.llms import ChatMessage
+from llama_index.core.program.function_program import get_function_tool
 from llama_index.llms.ollama import Ollama
 from llama_index.llms.openai import OpenAI

 from paperless.config import AIConfig
+from paperless_ai.tools import DocumentClassifierSchema

 logger = logging.getLogger("paperless_ai.client")

@@ -18,7 +20,7 @@ class AIClient:
         self.settings = AIConfig()
         self.llm = self.get_llm()

-    def get_llm(self):
+    def get_llm(self) -> Ollama | OpenAI:
         if self.settings.llm_backend == "ollama":
             return Ollama(
                 model=self.settings.llm_model or "llama3",
@@ -39,9 +41,21 @@ class AIClient:
             self.settings.llm_backend,
             self.settings.llm_model,
         )
-        result = self.llm.complete(prompt)
-        logger.debug("LLM query result: %s", result)
-        return result
+        user_msg = ChatMessage(role="user", content=prompt)
+        tool = get_function_tool(DocumentClassifierSchema)
+        result = self.llm.chat_with_tools(
+            tools=[tool],
+            user_msg=user_msg,
+            chat_history=[],
+        )
+        tool_calls = self.llm.get_tool_calls_from_response(
+            result,
+            error_on_no_tool_calls=True,
+        )
+        logger.debug("LLM query result: %s", tool_calls)
+        parsed = DocumentClassifierSchema(**tool_calls[0].tool_kwargs)
+        return parsed.model_dump()

     def run_chat(self, messages: list[ChatMessage]) -> str:
         logger.debug(
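Putting the pieces together, a rough end-to-end sketch of the new contract (the query method's name is not visible in this hunk and is assumed here; error handling omitted): the client resolves the tool call into DocumentClassifierSchema and returns its model_dump() dict, which the caller passes straight to parse_ai_response.

    client = AIClient()
    prompt = build_prompt_without_rag(document)   # or build_prompt_with_rag(document, user)
    raw = client.llm_query(prompt)                # method name assumed; now a dict, not a CompletionResponse
    suggestions = parse_ai_response(raw)          # {"title": "...", "tags": [...], ...}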
src/paperless_ai/tools.py (new file, 10 lines)

@@ -0,0 +1,10 @@
+from llama_index.core.bridge.pydantic import BaseModel
+
+
+class DocumentClassifierSchema(BaseModel):
+    title: str
+    tags: list[str]
+    correspondents: list[str]
+    document_types: list[str]
+    storage_paths: list[str]
+    dates: list[str]
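For reference, a small sketch of how this schema behaves on its own (values hypothetical): pydantic validates the keyword arguments taken from the tool call, and model_dump() yields the plain dict the client returns.

    parsed = DocumentClassifierSchema(
        title="Vehicle insurance invoice",
        tags=["insurance", "invoice"],
        correspondents=["ACME Insurance"],
        document_types=["invoice"],
        storage_paths=["Insurance/Vehicle"],
        dates=["2025-04-01"],
    )
    parsed.model_dump()   # -> {"title": "Vehicle insurance invoice", "tags": ["insurance", "invoice"], ...}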