Compare commits

..

6 Commits

Author SHA1 Message Date
shamoon
da2ac19193
Update ai_classifier.py 2025-07-15 14:42:56 -07:00
shamoon
3583470856
Merge branch 'dev' into feature-ai 2025-07-15 14:36:03 -07:00
shamoon
5bfbe856a6
Fix tests for change to structured output 2025-07-15 14:34:54 -07:00
shamoon
20bae4bd41
Move to structured output 2025-07-15 14:27:29 -07:00
shamoon
fa496dfc8d
Fix: also fix frontend date format in other places
See #10369
2025-07-11 00:46:20 -07:00
shamoon
924471b59c
Fix: fix date format for 'today' in DateComponent (#10369) 2025-07-11 00:43:52 -07:00
8 changed files with 78 additions and 148 deletions

View File

@ -246,7 +246,7 @@ export class CustomFieldsQueryDropdownComponent extends LoadingComponentWithPerm
customFields: CustomField[] = [] customFields: CustomField[] = []
public readonly today: string = new Date().toISOString().split('T')[0] public readonly today: string = new Date().toLocaleDateString('en-CA')
constructor() { constructor() {
super() super()

View File

@ -165,7 +165,7 @@ export class DatesDropdownComponent implements OnInit, OnDestroy {
@Input() @Input()
placement: string = 'bottom-start' placement: string = 'bottom-start'
public readonly today: string = new Date().toISOString().split('T')[0] public readonly today: string = new Date().toLocaleDateString('en-CA')
get isActive(): boolean { get isActive(): boolean {
return ( return (

View File

@ -59,7 +59,7 @@ export class DateComponent
@Output() @Output()
filterDocuments = new EventEmitter<NgbDateStruct[]>() filterDocuments = new EventEmitter<NgbDateStruct[]>()
public readonly today: string = new Date().toISOString().split('T')[0] public readonly today: string = new Date().toLocaleDateString('en-CA')
getSuggestions() { getSuggestions() {
return this.suggestions == null return this.suggestions == null

View File

@ -1,8 +1,6 @@
import json
import logging import logging
from django.contrib.auth.models import User from django.contrib.auth.models import User
from llama_index.core.base.llms.types import CompletionResponse
from documents.models import Document from documents.models import Document
from documents.permissions import get_objects_for_user_owner_aware from documents.permissions import get_objects_for_user_owner_aware
@ -18,58 +16,34 @@ def build_prompt_without_rag(document: Document) -> str:
filename = document.filename or "" filename = document.filename or ""
content = truncate_content(document.content[:4000] or "") content = truncate_content(document.content[:4000] or "")
prompt = f""" return f"""
You are an assistant that extracts structured information from documents. You are a document classification assistant.
Only respond with the JSON object as described below.
Never ask for further information, additional content or ask questions. Never include any other text.
Suggested tags and document types must be strictly based on the content of the document.
Do not change the field names or the JSON structure, only provide the values. Use double quotes and proper JSON syntax.
Each field must be a list of plain strings.
The JSON object must contain the following fields: Analyze the following document and extract the following information:
- title: A short, descriptive title - A short descriptive title
- tags: A list of simple tags like ["insurance", "medical", "receipts"] - Tags that reflect the content
- correspondents: A list of names or organizations mentioned in the document - Names of people or organizations mentioned
- document_types: The type/category of the document (e.g. "invoice", "medical record") - The type or category of the document
- storage_paths: Suggested folder paths (e.g. "Medical/Insurance") - Suggested folder paths for storing the document
- dates: List up to 3 relevant dates in YYYY-MM-DD format - Up to 3 relevant dates in YYYY-MM-DD format
The format of the JSON object is as follows: Filename:
{{
"title": "xxxxx",
"tags": ["xxxx", "xxxx"],
"correspondents": ["xxxx", "xxxx"],
"document_types": ["xxxx", "xxxx"],
"storage_paths": ["xxxx", "xxxx"],
"dates": ["YYYY-MM-DD", "YYYY-MM-DD", "YYYY-MM-DD"],
}}
---------
FILENAME:
{filename} {filename}
CONTENT: Content:
{content} {content}
""" """.strip()
return prompt
def build_prompt_with_rag(document: Document, user: User | None = None) -> str: def build_prompt_with_rag(document: Document, user: User | None = None) -> str:
base_prompt = build_prompt_without_rag(document)
context = truncate_content(get_context_for_document(document, user)) context = truncate_content(get_context_for_document(document, user))
prompt = build_prompt_without_rag(document)
prompt += f""" return f"""{base_prompt}
CONTEXT FROM SIMILAR DOCUMENTS: Additional context from similar documents:
{context} {context}
""".strip()
---------
DO NOT RESPOND WITH ANYTHING OTHER THAN THE JSON OBJECT.
"""
return prompt
def get_context_for_document( def get_context_for_document(
@ -100,36 +74,15 @@ def get_context_for_document(
return "\n\n".join(context_blocks) return "\n\n".join(context_blocks)
def parse_ai_response(response: CompletionResponse) -> dict: def parse_ai_response(raw: dict) -> dict:
try: return {
raw = json.loads(response.text) "title": raw.get("title", ""),
return { "tags": raw.get("tags", []),
"title": raw.get("title"), "correspondents": raw.get("correspondents", []),
"tags": raw.get("tags", []), "document_types": raw.get("document_types", []),
"correspondents": raw.get("correspondents", []), "storage_paths": raw.get("storage_paths", []),
"document_types": raw.get("document_types", []), "dates": raw.get("dates", []),
"storage_paths": raw.get("storage_paths", []), }
"dates": raw.get("dates", []),
}
except json.JSONDecodeError:
logger.warning("Invalid JSON in AI response, attempting modified parsing...")
try:
# search for a valid json string like { ... } in the response
start = response.text.index("{")
end = response.text.rindex("}") + 1
json_str = response.text[start:end]
raw = json.loads(json_str)
return {
"title": raw.get("title"),
"tags": raw.get("tags", []),
"correspondents": raw.get("correspondents", []),
"document_types": raw.get("document_types", []),
"storage_paths": raw.get("storage_paths", []),
"dates": raw.get("dates", []),
}
except (ValueError, json.JSONDecodeError):
logger.exception("Failed to parse AI response")
return {}
def get_ai_document_classification( def get_ai_document_classification(

View File

@ -1,10 +1,12 @@
import logging import logging
from llama_index.core.llms import ChatMessage from llama_index.core.llms import ChatMessage
from llama_index.core.program.function_program import get_function_tool
from llama_index.llms.ollama import Ollama from llama_index.llms.ollama import Ollama
from llama_index.llms.openai import OpenAI from llama_index.llms.openai import OpenAI
from paperless.config import AIConfig from paperless.config import AIConfig
from paperless_ai.tools import DocumentClassifierSchema
logger = logging.getLogger("paperless_ai.client") logger = logging.getLogger("paperless_ai.client")
@ -18,7 +20,7 @@ class AIClient:
self.settings = AIConfig() self.settings = AIConfig()
self.llm = self.get_llm() self.llm = self.get_llm()
def get_llm(self): def get_llm(self) -> Ollama | OpenAI:
if self.settings.llm_backend == "ollama": if self.settings.llm_backend == "ollama":
return Ollama( return Ollama(
model=self.settings.llm_model or "llama3", model=self.settings.llm_model or "llama3",
@ -39,9 +41,21 @@ class AIClient:
self.settings.llm_backend, self.settings.llm_backend,
self.settings.llm_model, self.settings.llm_model,
) )
result = self.llm.complete(prompt)
logger.debug("LLM query result: %s", result) user_msg = ChatMessage(role="user", content=prompt)
return result tool = get_function_tool(DocumentClassifierSchema)
result = self.llm.chat_with_tools(
tools=[tool],
user_msg=user_msg,
chat_history=[],
)
tool_calls = self.llm.get_tool_calls_from_response(
result,
error_on_no_tool_calls=True,
)
logger.debug("LLM query result: %s", tool_calls)
parsed = DocumentClassifierSchema(**tool_calls[0].tool_kwargs)
return parsed.model_dump()
def run_chat(self, messages: list[ChatMessage]) -> str: def run_chat(self, messages: list[ChatMessage]) -> str:
logger.debug( logger.debug(

View File

@ -10,7 +10,6 @@ from paperless_ai.ai_classifier import build_prompt_with_rag
from paperless_ai.ai_classifier import build_prompt_without_rag from paperless_ai.ai_classifier import build_prompt_without_rag
from paperless_ai.ai_classifier import get_ai_document_classification from paperless_ai.ai_classifier import get_ai_document_classification
from paperless_ai.ai_classifier import get_context_for_document from paperless_ai.ai_classifier import get_context_for_document
from paperless_ai.ai_classifier import parse_ai_response
@pytest.fixture @pytest.fixture
@ -75,50 +74,14 @@ def mock_similar_documents():
LLM_MODEL="some_model", LLM_MODEL="some_model",
) )
def test_get_ai_document_classification_success(mock_run_llm_query, mock_document): def test_get_ai_document_classification_success(mock_run_llm_query, mock_document):
mock_run_llm_query.return_value.text = json.dumps( mock_run_llm_query.return_value = {
{
"title": "Test Title",
"tags": ["test", "document"],
"correspondents": ["John Doe"],
"document_types": ["report"],
"storage_paths": ["Reports"],
"dates": ["2023-01-01"],
},
)
result = get_ai_document_classification(mock_document)
assert result["title"] == "Test Title"
assert result["tags"] == ["test", "document"]
assert result["correspondents"] == ["John Doe"]
assert result["document_types"] == ["report"]
assert result["storage_paths"] == ["Reports"]
assert result["dates"] == ["2023-01-01"]
@pytest.mark.django_db
@patch("paperless_ai.client.AIClient.run_llm_query")
@override_settings(
LLM_BACKEND="ollama",
LLM_MODEL="some_model",
)
def test_get_ai_document_classification_fallback_parse_success(
mock_run_llm_query,
mock_document,
):
mock_run_llm_query.return_value.text = """
There is some text before the JSON.
```json
{
"title": "Test Title", "title": "Test Title",
"tags": ["test", "document"], "tags": ["test", "document"],
"correspondents": ["John Doe"], "correspondents": ["John Doe"],
"document_types": ["report"], "document_types": ["report"],
"storage_paths": ["Reports"], "storage_paths": ["Reports"],
"dates": ["2023-01-01"] "dates": ["2023-01-01"],
} }
```
"""
result = get_ai_document_classification(mock_document) result = get_ai_document_classification(mock_document)
@ -130,22 +93,6 @@ def test_get_ai_document_classification_fallback_parse_success(
assert result["dates"] == ["2023-01-01"] assert result["dates"] == ["2023-01-01"]
@pytest.mark.django_db
@patch("paperless_ai.client.AIClient.run_llm_query")
@override_settings(
LLM_BACKEND="ollama",
LLM_MODEL="some_model",
)
def test_get_ai_document_classification_parse_failure(
mock_run_llm_query,
mock_document,
):
mock_run_llm_query.return_value.text = "Invalid JSON response"
result = get_ai_document_classification(mock_document)
assert result == {}
@pytest.mark.django_db @pytest.mark.django_db
@patch("paperless_ai.client.AIClient.run_llm_query") @patch("paperless_ai.client.AIClient.run_llm_query")
def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document): def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document):
@ -156,15 +103,6 @@ def test_get_ai_document_classification_failure(mock_run_llm_query, mock_documen
get_ai_document_classification(mock_document) get_ai_document_classification(mock_document)
def test_parse_llm_classification_response_invalid_json():
mock_response = MagicMock()
mock_response.text = "Invalid JSON response"
result = parse_ai_response(mock_response)
assert result == {}
@pytest.mark.django_db @pytest.mark.django_db
@patch("paperless_ai.client.AIClient.run_llm_query") @patch("paperless_ai.client.AIClient.run_llm_query")
@patch("paperless_ai.ai_classifier.build_prompt_with_rag") @patch("paperless_ai.ai_classifier.build_prompt_with_rag")
@ -218,10 +156,10 @@ def test_prompt_with_without_rag(mock_document):
return_value="Context from similar documents", return_value="Context from similar documents",
): ):
prompt = build_prompt_without_rag(mock_document) prompt = build_prompt_without_rag(mock_document)
assert "CONTEXT FROM SIMILAR DOCUMENTS:" not in prompt assert "Additional context from similar documents:" not in prompt
prompt = build_prompt_with_rag(mock_document) prompt = build_prompt_with_rag(mock_document)
assert "CONTEXT FROM SIMILAR DOCUMENTS:" in prompt assert "Additional context from similar documents:" in prompt
@patch("paperless_ai.ai_classifier.query_similar_documents") @patch("paperless_ai.ai_classifier.query_similar_documents")

View File

@ -3,6 +3,7 @@ from unittest.mock import patch
import pytest import pytest
from llama_index.core.llms import ChatMessage from llama_index.core.llms import ChatMessage
from llama_index.core.llms.llm import ToolSelection
from paperless_ai.client import AIClient from paperless_ai.client import AIClient
@ -69,13 +70,27 @@ def test_run_llm_query(mock_ai_config, mock_ollama_llm):
mock_ai_config.llm_url = "http://test-url" mock_ai_config.llm_url = "http://test-url"
mock_llm_instance = mock_ollama_llm.return_value mock_llm_instance = mock_ollama_llm.return_value
mock_llm_instance.complete.return_value = "test_result"
tool_selection = ToolSelection(
tool_id="call_test",
tool_name="DocumentClassifierSchema",
tool_kwargs={
"title": "Test Title",
"tags": ["test", "document"],
"correspondents": ["John Doe"],
"document_types": ["report"],
"storage_paths": ["Reports"],
"dates": ["2023-01-01"],
},
)
mock_llm_instance.chat_with_tools.return_value = MagicMock()
mock_llm_instance.get_tool_calls_from_response.return_value = [tool_selection]
client = AIClient() client = AIClient()
result = client.run_llm_query("test_prompt") result = client.run_llm_query("test_prompt")
mock_llm_instance.complete.assert_called_once_with("test_prompt") assert result["title"] == "Test Title"
assert result == "test_result"
def test_run_chat(mock_ai_config, mock_ollama_llm): def test_run_chat(mock_ai_config, mock_ollama_llm):

10
src/paperless_ai/tools.py Normal file
View File

@ -0,0 +1,10 @@
from llama_index.core.bridge.pydantic import BaseModel
class DocumentClassifierSchema(BaseModel):
    """Structured classification result requested from the LLM for a document.

    AIClient.run_llm_query wraps this model as a function tool, builds an
    instance from the keyword arguments of the first tool call the LLM
    returns, and hands the caller the ``model_dump()`` dictionary.

    No field declares a default, so every field is required when the model
    is constructed from the tool-call arguments.
    """

    # Short descriptive title suggested for the document.
    title: str
    # Suggested tags, as plain strings.
    tags: list[str]
    # Names of people or organizations suggested as correspondents.
    correspondents: list[str]
    # Suggested document types / categories.
    document_types: list[str]
    # Suggested storage (folder) paths.
    storage_paths: list[str]
    # Relevant dates as strings. NOTE(review): the classification prompt asks
    # for up to 3 dates in YYYY-MM-DD format, but neither the count nor the
    # format is validated here -- confirm whether callers rely on that.
    dates: list[str]