mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-07-16 17:25:11 -05:00
Move to structured output
This commit is contained in:
parent
b94912a392
commit
20bae4bd41
@ -2,7 +2,6 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from django.contrib.auth.models import User
|
from django.contrib.auth.models import User
|
||||||
from llama_index.core.base.llms.types import CompletionResponse
|
|
||||||
|
|
||||||
from documents.models import Document
|
from documents.models import Document
|
||||||
from documents.permissions import get_objects_for_user_owner_aware
|
from documents.permissions import get_objects_for_user_owner_aware
|
||||||
@ -18,58 +17,34 @@ def build_prompt_without_rag(document: Document) -> str:
|
|||||||
filename = document.filename or ""
|
filename = document.filename or ""
|
||||||
content = truncate_content(document.content[:4000] or "")
|
content = truncate_content(document.content[:4000] or "")
|
||||||
|
|
||||||
prompt = f"""
|
return f"""
|
||||||
You are an assistant that extracts structured information from documents.
|
You are a document classification assistant.
|
||||||
Only respond with the JSON object as described below.
|
|
||||||
Never ask for further information, additional content or ask questions. Never include any other text.
|
|
||||||
Suggested tags and document types must be strictly based on the content of the document.
|
|
||||||
Do not change the field names or the JSON structure, only provide the values. Use double quotes and proper JSON syntax.
|
|
||||||
Each field must be a list of plain strings.
|
|
||||||
|
|
||||||
The JSON object must contain the following fields:
|
Analyze the following document and extract the following information:
|
||||||
- title: A short, descriptive title
|
- A short descriptive title
|
||||||
- tags: A list of simple tags like ["insurance", "medical", "receipts"]
|
- Tags that reflect the content
|
||||||
- correspondents: A list of names or organizations mentioned in the document
|
- Names of people or organizations mentioned
|
||||||
- document_types: The type/category of the document (e.g. "invoice", "medical record")
|
- The type or category of the document
|
||||||
- storage_paths: Suggested folder paths (e.g. "Medical/Insurance")
|
- Suggested folder paths for storing the document
|
||||||
- dates: List up to 3 relevant dates in YYYY-MM-DD format
|
- Up to 3 relevant dates in YYYY-MM-DD format
|
||||||
|
|
||||||
The format of the JSON object is as follows:
|
Filename:
|
||||||
{{
|
|
||||||
"title": "xxxxx",
|
|
||||||
"tags": ["xxxx", "xxxx"],
|
|
||||||
"correspondents": ["xxxx", "xxxx"],
|
|
||||||
"document_types": ["xxxx", "xxxx"],
|
|
||||||
"storage_paths": ["xxxx", "xxxx"],
|
|
||||||
"dates": ["YYYY-MM-DD", "YYYY-MM-DD", "YYYY-MM-DD"],
|
|
||||||
}}
|
|
||||||
---------
|
|
||||||
|
|
||||||
FILENAME:
|
|
||||||
{filename}
|
{filename}
|
||||||
|
|
||||||
CONTENT:
|
Content:
|
||||||
{content}
|
{content}
|
||||||
"""
|
""".strip()
|
||||||
|
|
||||||
return prompt
|
|
||||||
|
|
||||||
|
|
||||||
def build_prompt_with_rag(document: Document, user: User | None = None) -> str:
|
def build_prompt_with_rag(document: Document, user: User | None = None) -> str:
|
||||||
|
base_prompt = build_prompt_without_rag(document)
|
||||||
context = truncate_content(get_context_for_document(document, user))
|
context = truncate_content(get_context_for_document(document, user))
|
||||||
prompt = build_prompt_without_rag(document)
|
|
||||||
|
|
||||||
prompt += f"""
|
return f"""{base_prompt}
|
||||||
|
|
||||||
CONTEXT FROM SIMILAR DOCUMENTS:
|
Additional context from similar documents:
|
||||||
{context}
|
{context}
|
||||||
|
""".strip()
|
||||||
---------
|
|
||||||
|
|
||||||
DO NOT RESPOND WITH ANYTHING OTHER THAN THE JSON OBJECT.
|
|
||||||
"""
|
|
||||||
|
|
||||||
return prompt
|
|
||||||
|
|
||||||
|
|
||||||
def get_context_for_document(
|
def get_context_for_document(
|
||||||
@ -100,36 +75,19 @@ def get_context_for_document(
|
|||||||
return "\n\n".join(context_blocks)
|
return "\n\n".join(context_blocks)
|
||||||
|
|
||||||
|
|
||||||
def parse_ai_response(response: CompletionResponse) -> dict:
|
def parse_ai_response(raw: dict) -> dict:
|
||||||
try:
|
try:
|
||||||
raw = json.loads(response.text)
|
|
||||||
return {
|
return {
|
||||||
"title": raw.get("title"),
|
"title": raw.get("title", ""),
|
||||||
"tags": raw.get("tags", []),
|
"tags": raw.get("tags", []),
|
||||||
"correspondents": raw.get("correspondents", []),
|
"correspondents": raw.get("correspondents", []),
|
||||||
"document_types": raw.get("document_types", []),
|
"document_types": raw.get("document_types", []),
|
||||||
"storage_paths": raw.get("storage_paths", []),
|
"storage_paths": raw.get("storage_paths", []),
|
||||||
"dates": raw.get("dates", []),
|
"dates": raw.get("dates", []),
|
||||||
}
|
}
|
||||||
except json.JSONDecodeError:
|
except (ValueError, json.JSONDecodeError):
|
||||||
logger.warning("Invalid JSON in AI response, attempting modified parsing...")
|
logger.exception("Failed to parse AI response")
|
||||||
try:
|
return {}
|
||||||
# search for a valid json string like { ... } in the response
|
|
||||||
start = response.text.index("{")
|
|
||||||
end = response.text.rindex("}") + 1
|
|
||||||
json_str = response.text[start:end]
|
|
||||||
raw = json.loads(json_str)
|
|
||||||
return {
|
|
||||||
"title": raw.get("title"),
|
|
||||||
"tags": raw.get("tags", []),
|
|
||||||
"correspondents": raw.get("correspondents", []),
|
|
||||||
"document_types": raw.get("document_types", []),
|
|
||||||
"storage_paths": raw.get("storage_paths", []),
|
|
||||||
"dates": raw.get("dates", []),
|
|
||||||
}
|
|
||||||
except (ValueError, json.JSONDecodeError):
|
|
||||||
logger.exception("Failed to parse AI response")
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
def get_ai_document_classification(
|
def get_ai_document_classification(
|
||||||
|
@ -1,10 +1,12 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from llama_index.core.llms import ChatMessage
|
from llama_index.core.llms import ChatMessage
|
||||||
|
from llama_index.core.program.function_program import get_function_tool
|
||||||
from llama_index.llms.ollama import Ollama
|
from llama_index.llms.ollama import Ollama
|
||||||
from llama_index.llms.openai import OpenAI
|
from llama_index.llms.openai import OpenAI
|
||||||
|
|
||||||
from paperless.config import AIConfig
|
from paperless.config import AIConfig
|
||||||
|
from paperless_ai.tools import DocumentClassifierSchema
|
||||||
|
|
||||||
logger = logging.getLogger("paperless_ai.client")
|
logger = logging.getLogger("paperless_ai.client")
|
||||||
|
|
||||||
@ -18,7 +20,7 @@ class AIClient:
|
|||||||
self.settings = AIConfig()
|
self.settings = AIConfig()
|
||||||
self.llm = self.get_llm()
|
self.llm = self.get_llm()
|
||||||
|
|
||||||
def get_llm(self):
|
def get_llm(self) -> Ollama | OpenAI:
|
||||||
if self.settings.llm_backend == "ollama":
|
if self.settings.llm_backend == "ollama":
|
||||||
return Ollama(
|
return Ollama(
|
||||||
model=self.settings.llm_model or "llama3",
|
model=self.settings.llm_model or "llama3",
|
||||||
@ -39,9 +41,21 @@ class AIClient:
|
|||||||
self.settings.llm_backend,
|
self.settings.llm_backend,
|
||||||
self.settings.llm_model,
|
self.settings.llm_model,
|
||||||
)
|
)
|
||||||
result = self.llm.complete(prompt)
|
|
||||||
logger.debug("LLM query result: %s", result)
|
user_msg = ChatMessage(role="user", content=prompt)
|
||||||
return result
|
tool = get_function_tool(DocumentClassifierSchema)
|
||||||
|
result = self.llm.chat_with_tools(
|
||||||
|
tools=[tool],
|
||||||
|
user_msg=user_msg,
|
||||||
|
chat_history=[],
|
||||||
|
)
|
||||||
|
tool_calls = self.llm.get_tool_calls_from_response(
|
||||||
|
result,
|
||||||
|
error_on_no_tool_calls=True,
|
||||||
|
)
|
||||||
|
logger.debug("LLM query result: %s", tool_calls)
|
||||||
|
parsed = DocumentClassifierSchema(**tool_calls[0].tool_kwargs)
|
||||||
|
return parsed.model_dump()
|
||||||
|
|
||||||
def run_chat(self, messages: list[ChatMessage]) -> str:
|
def run_chat(self, messages: list[ChatMessage]) -> str:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
10
src/paperless_ai/tools.py
Normal file
10
src/paperless_ai/tools.py
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
from llama_index.core.bridge.pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentClassifierSchema(BaseModel):
|
||||||
|
title: str
|
||||||
|
tags: list[str]
|
||||||
|
correspondents: list[str]
|
||||||
|
document_types: list[str]
|
||||||
|
storage_paths: list[str]
|
||||||
|
dates: list[str]
|
Loading…
x
Reference in New Issue
Block a user