Move to structured output

This commit is contained in:
shamoon 2025-07-15 14:27:29 -07:00
parent b94912a392
commit 20bae4bd41
No known key found for this signature in database
3 changed files with 49 additions and 67 deletions

View File

@ -2,7 +2,6 @@ import json
import logging import logging
from django.contrib.auth.models import User from django.contrib.auth.models import User
from llama_index.core.base.llms.types import CompletionResponse
from documents.models import Document from documents.models import Document
from documents.permissions import get_objects_for_user_owner_aware from documents.permissions import get_objects_for_user_owner_aware
@ -18,58 +17,34 @@ def build_prompt_without_rag(document: Document) -> str:
filename = document.filename or "" filename = document.filename or ""
content = truncate_content(document.content[:4000] or "") content = truncate_content(document.content[:4000] or "")
prompt = f""" return f"""
You are an assistant that extracts structured information from documents. You are a document classification assistant.
Only respond with the JSON object as described below.
Never ask for further information, additional content or ask questions. Never include any other text.
Suggested tags and document types must be strictly based on the content of the document.
Do not change the field names or the JSON structure, only provide the values. Use double quotes and proper JSON syntax.
Each field must be a list of plain strings.
The JSON object must contain the following fields: Analyze the following document and extract the following information:
- title: A short, descriptive title - A short descriptive title
- tags: A list of simple tags like ["insurance", "medical", "receipts"] - Tags that reflect the content
- correspondents: A list of names or organizations mentioned in the document - Names of people or organizations mentioned
- document_types: The type/category of the document (e.g. "invoice", "medical record") - The type or category of the document
- storage_paths: Suggested folder paths (e.g. "Medical/Insurance") - Suggested folder paths for storing the document
- dates: List up to 3 relevant dates in YYYY-MM-DD format - Up to 3 relevant dates in YYYY-MM-DD format
The format of the JSON object is as follows: Filename:
{{
"title": "xxxxx",
"tags": ["xxxx", "xxxx"],
"correspondents": ["xxxx", "xxxx"],
"document_types": ["xxxx", "xxxx"],
"storage_paths": ["xxxx", "xxxx"],
"dates": ["YYYY-MM-DD", "YYYY-MM-DD", "YYYY-MM-DD"],
}}
---------
FILENAME:
{filename} {filename}
CONTENT: Content:
{content} {content}
""" """.strip()
return prompt
def build_prompt_with_rag(document: Document, user: User | None = None) -> str: def build_prompt_with_rag(document: Document, user: User | None = None) -> str:
base_prompt = build_prompt_without_rag(document)
context = truncate_content(get_context_for_document(document, user)) context = truncate_content(get_context_for_document(document, user))
prompt = build_prompt_without_rag(document)
prompt += f""" return f"""{base_prompt}
CONTEXT FROM SIMILAR DOCUMENTS: Additional context from similar documents:
{context} {context}
""".strip()
---------
DO NOT RESPOND WITH ANYTHING OTHER THAN THE JSON OBJECT.
"""
return prompt
def get_context_for_document( def get_context_for_document(
@ -100,36 +75,19 @@ def get_context_for_document(
return "\n\n".join(context_blocks) return "\n\n".join(context_blocks)
def parse_ai_response(response: CompletionResponse) -> dict: def parse_ai_response(raw: dict) -> dict:
try: try:
raw = json.loads(response.text)
return { return {
"title": raw.get("title"), "title": raw.get("title", ""),
"tags": raw.get("tags", []), "tags": raw.get("tags", []),
"correspondents": raw.get("correspondents", []), "correspondents": raw.get("correspondents", []),
"document_types": raw.get("document_types", []), "document_types": raw.get("document_types", []),
"storage_paths": raw.get("storage_paths", []), "storage_paths": raw.get("storage_paths", []),
"dates": raw.get("dates", []), "dates": raw.get("dates", []),
} }
except json.JSONDecodeError: except (ValueError, json.JSONDecodeError):
logger.warning("Invalid JSON in AI response, attempting modified parsing...") logger.exception("Failed to parse AI response")
try: return {}
# search for a valid json string like { ... } in the response
start = response.text.index("{")
end = response.text.rindex("}") + 1
json_str = response.text[start:end]
raw = json.loads(json_str)
return {
"title": raw.get("title"),
"tags": raw.get("tags", []),
"correspondents": raw.get("correspondents", []),
"document_types": raw.get("document_types", []),
"storage_paths": raw.get("storage_paths", []),
"dates": raw.get("dates", []),
}
except (ValueError, json.JSONDecodeError):
logger.exception("Failed to parse AI response")
return {}
def get_ai_document_classification( def get_ai_document_classification(

View File

@ -1,10 +1,12 @@
import logging import logging
from llama_index.core.llms import ChatMessage from llama_index.core.llms import ChatMessage
from llama_index.core.program.function_program import get_function_tool
from llama_index.llms.ollama import Ollama from llama_index.llms.ollama import Ollama
from llama_index.llms.openai import OpenAI from llama_index.llms.openai import OpenAI
from paperless.config import AIConfig from paperless.config import AIConfig
from paperless_ai.tools import DocumentClassifierSchema
logger = logging.getLogger("paperless_ai.client") logger = logging.getLogger("paperless_ai.client")
@ -18,7 +20,7 @@ class AIClient:
self.settings = AIConfig() self.settings = AIConfig()
self.llm = self.get_llm() self.llm = self.get_llm()
def get_llm(self): def get_llm(self) -> Ollama | OpenAI:
if self.settings.llm_backend == "ollama": if self.settings.llm_backend == "ollama":
return Ollama( return Ollama(
model=self.settings.llm_model or "llama3", model=self.settings.llm_model or "llama3",
@ -39,9 +41,21 @@ class AIClient:
self.settings.llm_backend, self.settings.llm_backend,
self.settings.llm_model, self.settings.llm_model,
) )
result = self.llm.complete(prompt)
logger.debug("LLM query result: %s", result) user_msg = ChatMessage(role="user", content=prompt)
return result tool = get_function_tool(DocumentClassifierSchema)
result = self.llm.chat_with_tools(
tools=[tool],
user_msg=user_msg,
chat_history=[],
)
tool_calls = self.llm.get_tool_calls_from_response(
result,
error_on_no_tool_calls=True,
)
logger.debug("LLM query result: %s", tool_calls)
parsed = DocumentClassifierSchema(**tool_calls[0].tool_kwargs)
return parsed.model_dump()
def run_chat(self, messages: list[ChatMessage]) -> str: def run_chat(self, messages: list[ChatMessage]) -> str:
logger.debug( logger.debug(

10
src/paperless_ai/tools.py Normal file
View File

@ -0,0 +1,10 @@
from llama_index.core.bridge.pydantic import BaseModel
class DocumentClassifierSchema(BaseModel):
title: str
tags: list[str]
correspondents: list[str]
document_types: list[str]
storage_paths: list[str]
dates: list[str]