diff --git a/src/paperless/ai/ai_classifier.py b/src/paperless/ai/ai_classifier.py
index 949cfaf69..d5ec88323 100644
--- a/src/paperless/ai/ai_classifier.py
+++ b/src/paperless/ai/ai_classifier.py
@@ -3,15 +3,13 @@ import logging
 
 from documents.models import Document
 from paperless.ai.client import AIClient
+from paperless.ai.rag import get_context_for_document
+from paperless.config import AIConfig
 
-logger = logging.getLogger("paperless.ai.ai_classifier")
+logger = logging.getLogger("paperless.ai.rag_classifier")
 
 
-def get_ai_document_classification(document: Document) -> dict:
-    """
-    Returns classification suggestions for a given document using an LLM.
-    Output schema matches the API's expected DocumentClassificationSuggestions format.
-    """
+def build_prompt_without_rag(document: Document) -> str:
     filename = document.filename or ""
     content = document.content or ""
 
@@ -41,6 +39,7 @@ def get_ai_document_classification(document: Document) -> dict:
     }}
 
     ---
+
     FILENAME:
     {filename}
 
@@ -48,39 +47,71 @@ def get_ai_document_classification(document: Document) -> dict:
     {content[:8000]}  # Trim to safe size
     """
 
-    try:
-        client = AIClient()
-        result = client.run_llm_query(prompt)
-        suggestions = parse_ai_classification_response(result)
-        return suggestions or {}
-    except Exception:
-        logger.exception("Error during LLM classification: %s", exc_info=True)
-        return {}
+    return prompt
 
 
-def parse_ai_classification_response(text: str) -> dict:
-    """
-    Parses LLM output and ensures it conforms to expected schema.
+def build_prompt_with_rag(document: Document) -> str:
+    context = get_context_for_document(document)
+    content = document.content or ""
+    filename = document.filename or ""
+
+    prompt = f"""
+    You are a helpful assistant that extracts structured information from documents.
+    You have access to similar documents as context to help improve suggestions.
+
+    Only output valid JSON in the format below. No additional explanations.
+
+    The JSON object must contain:
+    - title: A short, descriptive title
+    - tags: A list of relevant topics
+    - correspondents: People or organizations involved
+    - document_types: Type or category of the document
+    - storage_paths: Suggested folder paths
+    - dates: Up to 3 relevant dates in YYYY-MM-DD
+
+    Here is an example document:
+    FILENAME:
+    {filename}
+
+    CONTENT:
+    {content[:4000]}
+
+    CONTEXT FROM SIMILAR DOCUMENTS:
+    {context[:4000]}
     """
+
+    return prompt
+
+
+def parse_ai_response(text: str) -> dict:
     try:
         raw = json.loads(text)
         return {
             "title": raw.get("title"),
             "tags": raw.get("tags", []),
-            "correspondents": [raw["correspondents"]]
-            if isinstance(raw.get("correspondents"), str)
-            else raw.get("correspondents", []),
-            "document_types": [raw["document_types"]]
-            if isinstance(raw.get("document_types"), str)
-            else raw.get("document_types", []),
+            "correspondents": raw.get("correspondents", []),
+            "document_types": raw.get("document_types", []),
             "storage_paths": raw.get("storage_paths", []),
-            "dates": [d for d in raw.get("dates", []) if d],
+            "dates": raw.get("dates", []),
         }
     except json.JSONDecodeError:
-        # fallback: try to extract JSON manually?
-        logger.exception(
-            "Failed to parse LLM classification response: %s",
-            text,
-            exc_info=True,
-        )
+        logger.exception("Invalid JSON in RAG response")
+        return {}
+
+
+def get_ai_document_classification(document: Document) -> dict:
+    ai_config = AIConfig()
+
+    prompt = (
+        build_prompt_with_rag(document)
+        if ai_config.llm_embedding_backend
+        else build_prompt_without_rag(document)
+    )
+
+    try:
+        client = AIClient()
+        result = client.run_llm_query(prompt)
+        return parse_ai_response(result)
+    except Exception:
+        logger.exception("Failed AI classification")
         return {}
diff --git a/src/paperless/ai/rag.py b/src/paperless/ai/rag.py
new file mode 100644
index 000000000..9b5baf425
--- /dev/null
+++ b/src/paperless/ai/rag.py
@@ -0,0 +1,12 @@
+from documents.models import Document
+from paperless.ai.indexing import query_similar_documents
+
+
+def get_context_for_document(doc: Document, max_docs: int = 5) -> str:
+    similar_docs = query_similar_documents(doc)[:max_docs]
+    context_blocks = []
+    for similar in similar_docs:
+        text = similar.content or ""
+        title = similar.title or similar.filename or "Untitled"
+        context_blocks.append(f"TITLE: {title}\n{text}")
+    return "\n\n".join(context_blocks)
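For reviewers, here is a minimal pytest-style sketch (not part of the diff) of how the new context assembly in `paperless.ai.rag` could be exercised in isolation. It assumes the module paths introduced above, that `pytest` and `unittest.mock` are acceptable in the test suite, and that unsaved `Document` instances with only `title`/`content` set are sufficient; the `test_context_joins_titles_and_content` name and the invoice values are hypothetical.

```python
# Hypothetical test sketch; module paths and Document fields follow the diff above.
from unittest import mock

from documents.models import Document
from paperless.ai.rag import get_context_for_document


def test_context_joins_titles_and_content():
    similar = Document(title="Invoice 2024-01", content="Total due: 100 EUR")
    target = Document(title="Invoice 2024-02", content="Total due: 200 EUR")

    # Patch the index lookup so no populated vector index is required.
    with mock.patch(
        "paperless.ai.rag.query_similar_documents",
        return_value=[similar],
    ):
        context = get_context_for_document(target, max_docs=1)

    # Each context block is "TITLE: <title>\n<content>", blocks joined by blank lines.
    assert "TITLE: Invoice 2024-01" in context
    assert "Total due: 100 EUR" in context
```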