From 20bae4bd41676b67e996f0a685b20baa0b88e3e7 Mon Sep 17 00:00:00 2001
From: shamoon <4887959+shamoon@users.noreply.github.com>
Date: Tue, 15 Jul 2025 14:27:29 -0700
Subject: [PATCH] Move to structured output

---
 src/paperless_ai/ai_classifier.py | 84 ++++++++-----------------------
 src/paperless_ai/client.py        | 22 ++++++--
 src/paperless_ai/tools.py         | 10 ++++
 3 files changed, 49 insertions(+), 67 deletions(-)
 create mode 100644 src/paperless_ai/tools.py

diff --git a/src/paperless_ai/ai_classifier.py b/src/paperless_ai/ai_classifier.py
index 55c7c7704..3b251da2b 100644
--- a/src/paperless_ai/ai_classifier.py
+++ b/src/paperless_ai/ai_classifier.py
@@ -2,7 +2,6 @@ import json
 import logging
 
 from django.contrib.auth.models import User
-from llama_index.core.base.llms.types import CompletionResponse
 
 from documents.models import Document
 from documents.permissions import get_objects_for_user_owner_aware
@@ -18,58 +17,34 @@ def build_prompt_without_rag(document: Document) -> str:
     filename = document.filename or ""
     content = truncate_content(document.content[:4000] or "")
 
-    prompt = f"""
-    You are an assistant that extracts structured information from documents.
-    Only respond with the JSON object as described below.
-    Never ask for further information, additional content or ask questions. Never include any other text.
-    Suggested tags and document types must be strictly based on the content of the document.
-    Do not change the field names or the JSON structure, only provide the values. Use double quotes and proper JSON syntax.
-    Each field must be a list of plain strings.
+    return f"""
+    You are a document classification assistant.
 
-    The JSON object must contain the following fields:
-    - title: A short, descriptive title
-    - tags: A list of simple tags like ["insurance", "medical", "receipts"]
-    - correspondents: A list of names or organizations mentioned in the document
-    - document_types: The type/category of the document (e.g. "invoice", "medical record")
-    - storage_paths: Suggested folder paths (e.g. "Medical/Insurance")
-    - dates: List up to 3 relevant dates in YYYY-MM-DD format
+    Analyze the following document and extract the following information:
+    - A short descriptive title
+    - Tags that reflect the content
+    - Names of people or organizations mentioned
+    - The type or category of the document
+    - Suggested folder paths for storing the document
+    - Up to 3 relevant dates in YYYY-MM-DD format
 
-    The format of the JSON object is as follows:
-    {{
-        "title": "xxxxx",
-        "tags": ["xxxx", "xxxx"],
-        "correspondents": ["xxxx", "xxxx"],
-        "document_types": ["xxxx", "xxxx"],
-        "storage_paths": ["xxxx", "xxxx"],
-        "dates": ["YYYY-MM-DD", "YYYY-MM-DD", "YYYY-MM-DD"],
-    }}
-    ---------
-
-    FILENAME:
+    Filename:
     {filename}
 
-    CONTENT:
+    Content:
     {content}
 
-    """
-
-    return prompt
+    """.strip()
 
 
 def build_prompt_with_rag(document: Document, user: User | None = None) -> str:
+    base_prompt = build_prompt_without_rag(document)
     context = truncate_content(get_context_for_document(document, user))
-    prompt = build_prompt_without_rag(document)
-    prompt += f"""
+    return f"""{base_prompt}
 
-    CONTEXT FROM SIMILAR DOCUMENTS:
+    Additional context from similar documents:
     {context}
-
-    ---------
-
-    DO NOT RESPOND WITH ANYTHING OTHER THAN THE JSON OBJECT.
-    """
-
-    return prompt
+    """.strip()
 
 
 def get_context_for_document(
@@ -100,36 +75,19 @@
     return "\n\n".join(context_blocks)
 
 
-def parse_ai_response(response: CompletionResponse) -> dict:
+def parse_ai_response(raw: dict) -> dict:
     try:
-        raw = json.loads(response.text)
         return {
-            "title": raw.get("title"),
+            "title": raw.get("title", ""),
             "tags": raw.get("tags", []),
             "correspondents": raw.get("correspondents", []),
             "document_types": raw.get("document_types", []),
             "storage_paths": raw.get("storage_paths", []),
             "dates": raw.get("dates", []),
         }
-    except json.JSONDecodeError:
-        logger.warning("Invalid JSON in AI response, attempting modified parsing...")
-        try:
-            # search for a valid json string like { ... } in the response
-            start = response.text.index("{")
-            end = response.text.rindex("}") + 1
-            json_str = response.text[start:end]
-            raw = json.loads(json_str)
-            return {
-                "title": raw.get("title"),
-                "tags": raw.get("tags", []),
-                "correspondents": raw.get("correspondents", []),
-                "document_types": raw.get("document_types", []),
-                "storage_paths": raw.get("storage_paths", []),
-                "dates": raw.get("dates", []),
-            }
-        except (ValueError, json.JSONDecodeError):
-            logger.exception("Failed to parse AI response")
-            return {}
+    except (ValueError, json.JSONDecodeError):
+        logger.exception("Failed to parse AI response")
+        return {}
 
 
 def get_ai_document_classification(
diff --git a/src/paperless_ai/client.py b/src/paperless_ai/client.py
index 67023cfb5..651ca7022 100644
--- a/src/paperless_ai/client.py
+++ b/src/paperless_ai/client.py
@@ -1,10 +1,12 @@
 import logging
 
 from llama_index.core.llms import ChatMessage
+from llama_index.core.program.function_program import get_function_tool
 from llama_index.llms.ollama import Ollama
 from llama_index.llms.openai import OpenAI
 
 from paperless.config import AIConfig
+from paperless_ai.tools import DocumentClassifierSchema
 
 
 logger = logging.getLogger("paperless_ai.client")
@@ -18,7 +20,7 @@ class AIClient:
         self.settings = AIConfig()
         self.llm = self.get_llm()
 
-    def get_llm(self):
+    def get_llm(self) -> Ollama | OpenAI:
         if self.settings.llm_backend == "ollama":
             return Ollama(
                 model=self.settings.llm_model or "llama3",
@@ -39,9 +41,21 @@ class AIClient:
             self.settings.llm_backend,
             self.settings.llm_model,
         )
-        result = self.llm.complete(prompt)
-        logger.debug("LLM query result: %s", result)
-        return result
+
+        user_msg = ChatMessage(role="user", content=prompt)
+        tool = get_function_tool(DocumentClassifierSchema)
+        result = self.llm.chat_with_tools(
+            tools=[tool],
+            user_msg=user_msg,
+            chat_history=[],
+        )
+        tool_calls = self.llm.get_tool_calls_from_response(
+            result,
+            error_on_no_tool_calls=True,
+        )
+        logger.debug("LLM query result: %s", tool_calls)
+        parsed = DocumentClassifierSchema(**tool_calls[0].tool_kwargs)
+        return parsed.model_dump()
 
     def run_chat(self, messages: list[ChatMessage]) -> str:
         logger.debug(
diff --git a/src/paperless_ai/tools.py b/src/paperless_ai/tools.py
new file mode 100644
index 000000000..2924f2c8c
--- /dev/null
+++ b/src/paperless_ai/tools.py
@@ -0,0 +1,10 @@
+from llama_index.core.bridge.pydantic import BaseModel
+
+
+class DocumentClassifierSchema(BaseModel):
+    title: str
+    tags: list[str]
+    correspondents: list[str]
+    document_types: list[str]
+    storage_paths: list[str]
+    dates: list[str]