Compare commits

..

6 Commits

Author SHA1 Message Date
shamoon
da2ac19193
Update ai_classifier.py 2025-07-15 14:42:56 -07:00
shamoon
3583470856
Merge branch 'dev' into feature-ai 2025-07-15 14:36:03 -07:00
shamoon
5bfbe856a6
Fix tests for change to structured output 2025-07-15 14:34:54 -07:00
shamoon
20bae4bd41
Move to structured output 2025-07-15 14:27:29 -07:00
shamoon
fa496dfc8d
Fix: also fix frontend date format in other places
See #10369
2025-07-11 00:46:20 -07:00
shamoon
924471b59c
Fix: fix date format for 'today' in DateComponent (#10369) 2025-07-11 00:43:52 -07:00
8 changed files with 78 additions and 148 deletions

View File

@ -246,7 +246,7 @@ export class CustomFieldsQueryDropdownComponent extends LoadingComponentWithPerm
customFields: CustomField[] = []
public readonly today: string = new Date().toISOString().split('T')[0]
public readonly today: string = new Date().toLocaleDateString('en-CA')
constructor() {
super()

View File

@ -165,7 +165,7 @@ export class DatesDropdownComponent implements OnInit, OnDestroy {
@Input()
placement: string = 'bottom-start'
public readonly today: string = new Date().toISOString().split('T')[0]
public readonly today: string = new Date().toLocaleDateString('en-CA')
get isActive(): boolean {
return (

View File

@ -59,7 +59,7 @@ export class DateComponent
@Output()
filterDocuments = new EventEmitter<NgbDateStruct[]>()
public readonly today: string = new Date().toISOString().split('T')[0]
public readonly today: string = new Date().toLocaleDateString('en-CA')
getSuggestions() {
return this.suggestions == null

View File

@ -1,8 +1,6 @@
import json
import logging
from django.contrib.auth.models import User
from llama_index.core.base.llms.types import CompletionResponse
from documents.models import Document
from documents.permissions import get_objects_for_user_owner_aware
@ -18,58 +16,34 @@ def build_prompt_without_rag(document: Document) -> str:
filename = document.filename or ""
content = truncate_content(document.content[:4000] or "")
prompt = f"""
You are an assistant that extracts structured information from documents.
Only respond with the JSON object as described below.
Never ask for further information, additional content or ask questions. Never include any other text.
Suggested tags and document types must be strictly based on the content of the document.
Do not change the field names or the JSON structure, only provide the values. Use double quotes and proper JSON syntax.
Each field must be a list of plain strings.
return f"""
You are a document classification assistant.
The JSON object must contain the following fields:
- title: A short, descriptive title
- tags: A list of simple tags like ["insurance", "medical", "receipts"]
- correspondents: A list of names or organizations mentioned in the document
- document_types: The type/category of the document (e.g. "invoice", "medical record")
- storage_paths: Suggested folder paths (e.g. "Medical/Insurance")
- dates: List up to 3 relevant dates in YYYY-MM-DD format
Analyze the following document and extract the following information:
- A short descriptive title
- Tags that reflect the content
- Names of people or organizations mentioned
- The type or category of the document
- Suggested folder paths for storing the document
- Up to 3 relevant dates in YYYY-MM-DD format
The format of the JSON object is as follows:
{{
"title": "xxxxx",
"tags": ["xxxx", "xxxx"],
"correspondents": ["xxxx", "xxxx"],
"document_types": ["xxxx", "xxxx"],
"storage_paths": ["xxxx", "xxxx"],
"dates": ["YYYY-MM-DD", "YYYY-MM-DD", "YYYY-MM-DD"],
}}
---------
FILENAME:
Filename:
{filename}
CONTENT:
Content:
{content}
"""
return prompt
""".strip()
def build_prompt_with_rag(document: Document, user: User | None = None) -> str:
base_prompt = build_prompt_without_rag(document)
context = truncate_content(get_context_for_document(document, user))
prompt = build_prompt_without_rag(document)
prompt += f"""
return f"""{base_prompt}
CONTEXT FROM SIMILAR DOCUMENTS:
Additional context from similar documents:
{context}
---------
DO NOT RESPOND WITH ANYTHING OTHER THAN THE JSON OBJECT.
"""
return prompt
""".strip()
def get_context_for_document(
@ -100,36 +74,15 @@ def get_context_for_document(
return "\n\n".join(context_blocks)
def parse_ai_response(response: CompletionResponse) -> dict:
try:
raw = json.loads(response.text)
return {
"title": raw.get("title"),
"tags": raw.get("tags", []),
"correspondents": raw.get("correspondents", []),
"document_types": raw.get("document_types", []),
"storage_paths": raw.get("storage_paths", []),
"dates": raw.get("dates", []),
}
except json.JSONDecodeError:
logger.warning("Invalid JSON in AI response, attempting modified parsing...")
try:
# search for a valid json string like { ... } in the response
start = response.text.index("{")
end = response.text.rindex("}") + 1
json_str = response.text[start:end]
raw = json.loads(json_str)
return {
"title": raw.get("title"),
"tags": raw.get("tags", []),
"correspondents": raw.get("correspondents", []),
"document_types": raw.get("document_types", []),
"storage_paths": raw.get("storage_paths", []),
"dates": raw.get("dates", []),
}
except (ValueError, json.JSONDecodeError):
logger.exception("Failed to parse AI response")
return {}
def parse_ai_response(raw: dict) -> dict:
return {
"title": raw.get("title", ""),
"tags": raw.get("tags", []),
"correspondents": raw.get("correspondents", []),
"document_types": raw.get("document_types", []),
"storage_paths": raw.get("storage_paths", []),
"dates": raw.get("dates", []),
}
def get_ai_document_classification(

View File

@ -1,10 +1,12 @@
import logging
from llama_index.core.llms import ChatMessage
from llama_index.core.program.function_program import get_function_tool
from llama_index.llms.ollama import Ollama
from llama_index.llms.openai import OpenAI
from paperless.config import AIConfig
from paperless_ai.tools import DocumentClassifierSchema
logger = logging.getLogger("paperless_ai.client")
@ -18,7 +20,7 @@ class AIClient:
self.settings = AIConfig()
self.llm = self.get_llm()
def get_llm(self):
def get_llm(self) -> Ollama | OpenAI:
if self.settings.llm_backend == "ollama":
return Ollama(
model=self.settings.llm_model or "llama3",
@ -39,9 +41,21 @@ class AIClient:
self.settings.llm_backend,
self.settings.llm_model,
)
result = self.llm.complete(prompt)
logger.debug("LLM query result: %s", result)
return result
user_msg = ChatMessage(role="user", content=prompt)
tool = get_function_tool(DocumentClassifierSchema)
result = self.llm.chat_with_tools(
tools=[tool],
user_msg=user_msg,
chat_history=[],
)
tool_calls = self.llm.get_tool_calls_from_response(
result,
error_on_no_tool_calls=True,
)
logger.debug("LLM query result: %s", tool_calls)
parsed = DocumentClassifierSchema(**tool_calls[0].tool_kwargs)
return parsed.model_dump()
def run_chat(self, messages: list[ChatMessage]) -> str:
logger.debug(

View File

@ -10,7 +10,6 @@ from paperless_ai.ai_classifier import build_prompt_with_rag
from paperless_ai.ai_classifier import build_prompt_without_rag
from paperless_ai.ai_classifier import get_ai_document_classification
from paperless_ai.ai_classifier import get_context_for_document
from paperless_ai.ai_classifier import parse_ai_response
@pytest.fixture
@ -75,50 +74,14 @@ def mock_similar_documents():
LLM_MODEL="some_model",
)
def test_get_ai_document_classification_success(mock_run_llm_query, mock_document):
mock_run_llm_query.return_value.text = json.dumps(
{
"title": "Test Title",
"tags": ["test", "document"],
"correspondents": ["John Doe"],
"document_types": ["report"],
"storage_paths": ["Reports"],
"dates": ["2023-01-01"],
},
)
result = get_ai_document_classification(mock_document)
assert result["title"] == "Test Title"
assert result["tags"] == ["test", "document"]
assert result["correspondents"] == ["John Doe"]
assert result["document_types"] == ["report"]
assert result["storage_paths"] == ["Reports"]
assert result["dates"] == ["2023-01-01"]
@pytest.mark.django_db
@patch("paperless_ai.client.AIClient.run_llm_query")
@override_settings(
LLM_BACKEND="ollama",
LLM_MODEL="some_model",
)
def test_get_ai_document_classification_fallback_parse_success(
mock_run_llm_query,
mock_document,
):
mock_run_llm_query.return_value.text = """
There is some text before the JSON.
```json
{
mock_run_llm_query.return_value = {
"title": "Test Title",
"tags": ["test", "document"],
"correspondents": ["John Doe"],
"document_types": ["report"],
"storage_paths": ["Reports"],
"dates": ["2023-01-01"]
"dates": ["2023-01-01"],
}
```
"""
result = get_ai_document_classification(mock_document)
@ -130,22 +93,6 @@ def test_get_ai_document_classification_fallback_parse_success(
assert result["dates"] == ["2023-01-01"]
@pytest.mark.django_db
@patch("paperless_ai.client.AIClient.run_llm_query")
@override_settings(
LLM_BACKEND="ollama",
LLM_MODEL="some_model",
)
def test_get_ai_document_classification_parse_failure(
mock_run_llm_query,
mock_document,
):
mock_run_llm_query.return_value.text = "Invalid JSON response"
result = get_ai_document_classification(mock_document)
assert result == {}
@pytest.mark.django_db
@patch("paperless_ai.client.AIClient.run_llm_query")
def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document):
@ -156,15 +103,6 @@ def test_get_ai_document_classification_failure(mock_run_llm_query, mock_documen
get_ai_document_classification(mock_document)
def test_parse_llm_classification_response_invalid_json():
mock_response = MagicMock()
mock_response.text = "Invalid JSON response"
result = parse_ai_response(mock_response)
assert result == {}
@pytest.mark.django_db
@patch("paperless_ai.client.AIClient.run_llm_query")
@patch("paperless_ai.ai_classifier.build_prompt_with_rag")
@ -218,10 +156,10 @@ def test_prompt_with_without_rag(mock_document):
return_value="Context from similar documents",
):
prompt = build_prompt_without_rag(mock_document)
assert "CONTEXT FROM SIMILAR DOCUMENTS:" not in prompt
assert "Additional context from similar documents:" not in prompt
prompt = build_prompt_with_rag(mock_document)
assert "CONTEXT FROM SIMILAR DOCUMENTS:" in prompt
assert "Additional context from similar documents:" in prompt
@patch("paperless_ai.ai_classifier.query_similar_documents")

View File

@ -3,6 +3,7 @@ from unittest.mock import patch
import pytest
from llama_index.core.llms import ChatMessage
from llama_index.core.llms.llm import ToolSelection
from paperless_ai.client import AIClient
@ -69,13 +70,27 @@ def test_run_llm_query(mock_ai_config, mock_ollama_llm):
mock_ai_config.llm_url = "http://test-url"
mock_llm_instance = mock_ollama_llm.return_value
mock_llm_instance.complete.return_value = "test_result"
tool_selection = ToolSelection(
tool_id="call_test",
tool_name="DocumentClassifierSchema",
tool_kwargs={
"title": "Test Title",
"tags": ["test", "document"],
"correspondents": ["John Doe"],
"document_types": ["report"],
"storage_paths": ["Reports"],
"dates": ["2023-01-01"],
},
)
mock_llm_instance.chat_with_tools.return_value = MagicMock()
mock_llm_instance.get_tool_calls_from_response.return_value = [tool_selection]
client = AIClient()
result = client.run_llm_query("test_prompt")
mock_llm_instance.complete.assert_called_once_with("test_prompt")
assert result == "test_result"
assert result["title"] == "Test Title"
def test_run_chat(mock_ai_config, mock_ollama_llm):

10
src/paperless_ai/tools.py Normal file
View File

@ -0,0 +1,10 @@
from llama_index.core.bridge.pydantic import BaseModel
class DocumentClassifierSchema(BaseModel):
title: str
tags: list[str]
correspondents: list[str]
document_types: list[str]
storage_paths: list[str]
dates: list[str]