Compare commits

da2ac19193fb5006cdd407477999d0608c533bcb..b94912a392fd261069c526ee3dacac441f9f3375

No commits in common. "da2ac19193fb5006cdd407477999d0608c533bcb" and "b94912a392fd261069c526ee3dacac441f9f3375" have entirely different histories.

8 changed files with 154 additions and 84 deletions

View File

@@ -246,7 +246,7 @@ export class CustomFieldsQueryDropdownComponent extends LoadingComponentWithPerm
   customFields: CustomField[] = []

-  public readonly today: string = new Date().toLocaleDateString('en-CA')
+  public readonly today: string = new Date().toISOString().split('T')[0]

   constructor() {
     super()

View File

@@ -165,7 +165,7 @@ export class DatesDropdownComponent implements OnInit, OnDestroy {
   @Input()
   placement: string = 'bottom-start'

-  public readonly today: string = new Date().toLocaleDateString('en-CA')
+  public readonly today: string = new Date().toISOString().split('T')[0]

   get isActive(): boolean {
     return (

View File

@@ -59,7 +59,7 @@ export class DateComponent
   @Output()
   filterDocuments = new EventEmitter<NgbDateStruct[]>()

-  public readonly today: string = new Date().toLocaleDateString('en-CA')
+  public readonly today: string = new Date().toISOString().split('T')[0]

   getSuggestions() {
     return this.suggestions == null
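
Note on the three frontend changes above: the two expressions differ only in timezone handling. new Date().toISOString() yields the UTC calendar date, which can disagree with the user's local date near midnight, while toLocaleDateString('en-CA') formats the local date (the en-CA locale happens to produce YYYY-MM-DD). A minimal sketch of the same distinction, expressed in Python since the backend examples below use it:

    from datetime import datetime, timezone

    # UTC calendar date: can lag or lead the user's wall clock near midnight
    utc_today = datetime.now(timezone.utc).date().isoformat()

    # Local calendar date: matches what the user sees, already YYYY-MM-DD
    local_today = datetime.now().date().isoformat()

    print(utc_today, local_today)  # differ when local time and UTC fall on different days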

View File

@@ -1,6 +1,8 @@
+import json
 import logging

 from django.contrib.auth.models import User
+from llama_index.core.base.llms.types import CompletionResponse

 from documents.models import Document
 from documents.permissions import get_objects_for_user_owner_aware
@@ -16,34 +18,58 @@ def build_prompt_without_rag(document: Document) -> str:
     filename = document.filename or ""
     content = truncate_content(document.content[:4000] or "")

-    return f"""
-You are a document classification assistant.
+    prompt = f"""
+You are an assistant that extracts structured information from documents.
+Only respond with the JSON object as described below.
+Never ask for further information, additional content or ask questions. Never include any other text.
+Suggested tags and document types must be strictly based on the content of the document.
+Do not change the field names or the JSON structure, only provide the values. Use double quotes and proper JSON syntax.
+Each field must be a list of plain strings.

-Analyze the following document and extract the following information:
-- A short descriptive title
-- Tags that reflect the content
-- Names of people or organizations mentioned
-- The type or category of the document
-- Suggested folder paths for storing the document
-- Up to 3 relevant dates in YYYY-MM-DD format
+The JSON object must contain the following fields:
+- title: A short, descriptive title
+- tags: A list of simple tags like ["insurance", "medical", "receipts"]
+- correspondents: A list of names or organizations mentioned in the document
+- document_types: The type/category of the document (e.g. "invoice", "medical record")
+- storage_paths: Suggested folder paths (e.g. "Medical/Insurance")
+- dates: List up to 3 relevant dates in YYYY-MM-DD format

-Filename:
+The format of the JSON object is as follows:
+{{
+  "title": "xxxxx",
+  "tags": ["xxxx", "xxxx"],
+  "correspondents": ["xxxx", "xxxx"],
+  "document_types": ["xxxx", "xxxx"],
+  "storage_paths": ["xxxx", "xxxx"],
+  "dates": ["YYYY-MM-DD", "YYYY-MM-DD", "YYYY-MM-DD"],
+}}
+
+---------
+FILENAME:
 {filename}

-Content:
+CONTENT:
 {content}
-""".strip()
+"""
+    return prompt


 def build_prompt_with_rag(document: Document, user: User | None = None) -> str:
-    base_prompt = build_prompt_without_rag(document)
     context = truncate_content(get_context_for_document(document, user))
+    prompt = build_prompt_without_rag(document)

-    return f"""{base_prompt}
+    prompt += f"""

-Additional context from similar documents:
+CONTEXT FROM SIMILAR DOCUMENTS:

 {context}
-""".strip()
+
+---------
+DO NOT RESPOND WITH ANYTHING OTHER THAN THE JSON OBJECT.
+"""
+    return prompt


 def get_context_for_document(
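
Note on the prompt rewrite above: the new prompt pins down a fixed JSON shape instead of free-form bullet answers. A quick sanity check of that shape with hypothetical values (one caveat grounded in the hunk itself: the embedded skeleton ends the "dates" line with a trailing comma, which strict json.loads would reject if a model echoed the skeleton verbatim):

    import json

    # Hypothetical reply following the field list the prompt specifies
    reply = """{
      "title": "Health insurance invoice",
      "tags": ["insurance", "medical"],
      "correspondents": ["ACME Insurance"],
      "document_types": ["invoice"],
      "storage_paths": ["Medical/Insurance"],
      "dates": ["2023-01-01"]
    }"""

    data = json.loads(reply)
    assert set(data) == {
        "title", "tags", "correspondents", "document_types", "storage_paths", "dates",
    }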
@@ -74,15 +100,36 @@ def get_context_for_document(
     return "\n\n".join(context_blocks)


-def parse_ai_response(raw: dict) -> dict:
-    return {
-        "title": raw.get("title", ""),
-        "tags": raw.get("tags", []),
-        "correspondents": raw.get("correspondents", []),
-        "document_types": raw.get("document_types", []),
-        "storage_paths": raw.get("storage_paths", []),
-        "dates": raw.get("dates", []),
-    }
+def parse_ai_response(response: CompletionResponse) -> dict:
+    try:
+        raw = json.loads(response.text)
+        return {
+            "title": raw.get("title"),
+            "tags": raw.get("tags", []),
+            "correspondents": raw.get("correspondents", []),
+            "document_types": raw.get("document_types", []),
+            "storage_paths": raw.get("storage_paths", []),
+            "dates": raw.get("dates", []),
+        }
+    except json.JSONDecodeError:
+        logger.warning("Invalid JSON in AI response, attempting modified parsing...")
+        try:
+            # search for a valid json string like { ... } in the response
+            start = response.text.index("{")
+            end = response.text.rindex("}") + 1
+            json_str = response.text[start:end]
+            raw = json.loads(json_str)
+            return {
+                "title": raw.get("title"),
+                "tags": raw.get("tags", []),
+                "correspondents": raw.get("correspondents", []),
+                "document_types": raw.get("document_types", []),
+                "storage_paths": raw.get("storage_paths", []),
+                "dates": raw.get("dates", []),
+            }
+        except (ValueError, json.JSONDecodeError):
+            logger.exception("Failed to parse AI response")
+            return {}


 def get_ai_document_classification(
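
The fallback branch added above scans for the outermost brace pair when the model wraps its JSON in prose or a code fence. A standalone sketch of that strategy (the helper name is hypothetical, not part of this diff):

    import json

    def extract_json_block(text: str) -> dict:
        # First try the whole payload, then fall back to the outermost {...} span.
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            start = text.index("{")           # raises ValueError if there is no brace at all
            end = text.rindex("}") + 1
            return json.loads(text[start:end])

    reply = 'Here you go:\n```json\n{"title": "Invoice", "tags": ["billing"]}\n```'
    print(extract_json_block(reply))  # {'title': 'Invoice', 'tags': ['billing']}

Like the real code, this still fails on trailing commas or stray braces inside prose, which is why the final except returns an empty dict.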

View File

@@ -1,12 +1,10 @@
 import logging

 from llama_index.core.llms import ChatMessage
-from llama_index.core.program.function_program import get_function_tool
 from llama_index.llms.ollama import Ollama
 from llama_index.llms.openai import OpenAI

 from paperless.config import AIConfig
-from paperless_ai.tools import DocumentClassifierSchema

 logger = logging.getLogger("paperless_ai.client")
@@ -20,7 +18,7 @@ class AIClient:
         self.settings = AIConfig()
         self.llm = self.get_llm()

-    def get_llm(self) -> Ollama | OpenAI:
+    def get_llm(self):
         if self.settings.llm_backend == "ollama":
             return Ollama(
                 model=self.settings.llm_model or "llama3",
@@ -41,21 +39,9 @@ class AIClient:
             self.settings.llm_backend,
             self.settings.llm_model,
         )
-        user_msg = ChatMessage(role="user", content=prompt)
-        tool = get_function_tool(DocumentClassifierSchema)
-        result = self.llm.chat_with_tools(
-            tools=[tool],
-            user_msg=user_msg,
-            chat_history=[],
-        )
-        tool_calls = self.llm.get_tool_calls_from_response(
-            result,
-            error_on_no_tool_calls=True,
-        )
-        logger.debug("LLM query result: %s", tool_calls)
-        parsed = DocumentClassifierSchema(**tool_calls[0].tool_kwargs)
-        return parsed.model_dump()
+        result = self.llm.complete(prompt)
+        logger.debug("LLM query result: %s", result)
+        return result

     def run_chat(self, messages: list[ChatMessage]) -> str:
         logger.debug(
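
This hunk swaps llama_index's tool-calling flow for plain text completion, so schema enforcement moves entirely into the prompt plus parse_ai_response. A minimal usage sketch of the retained call, assuming a reachable Ollama server at the default localhost URL (the model name and prompt are placeholders):

    from llama_index.llms.ollama import Ollama

    llm = Ollama(model="llama3", base_url="http://localhost:11434")
    response = llm.complete("Reply with a JSON object describing this document: ...")
    # complete() returns a CompletionResponse; downstream, parse_ai_response()
    # reads response.text and extracts the JSON itself.
    print(response.text)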

View File

@@ -10,6 +10,7 @@ from paperless_ai.ai_classifier import build_prompt_with_rag
 from paperless_ai.ai_classifier import build_prompt_without_rag
 from paperless_ai.ai_classifier import get_ai_document_classification
 from paperless_ai.ai_classifier import get_context_for_document
+from paperless_ai.ai_classifier import parse_ai_response


 @pytest.fixture
@@ -74,14 +75,16 @@ def mock_similar_documents():
     LLM_MODEL="some_model",
 )
 def test_get_ai_document_classification_success(mock_run_llm_query, mock_document):
-    mock_run_llm_query.return_value = {
-        "title": "Test Title",
-        "tags": ["test", "document"],
-        "correspondents": ["John Doe"],
-        "document_types": ["report"],
-        "storage_paths": ["Reports"],
-        "dates": ["2023-01-01"],
-    }
+    mock_run_llm_query.return_value.text = json.dumps(
+        {
+            "title": "Test Title",
+            "tags": ["test", "document"],
+            "correspondents": ["John Doe"],
+            "document_types": ["report"],
+            "storage_paths": ["Reports"],
+            "dates": ["2023-01-01"],
+        },
+    )

     result = get_ai_document_classification(mock_document)
@@ -93,6 +96,56 @@ def test_get_ai_document_classification_success(mock_run_llm_query, mock_documen
     assert result["dates"] == ["2023-01-01"]


+@pytest.mark.django_db
+@patch("paperless_ai.client.AIClient.run_llm_query")
+@override_settings(
+    LLM_BACKEND="ollama",
+    LLM_MODEL="some_model",
+)
+def test_get_ai_document_classification_fallback_parse_success(
+    mock_run_llm_query,
+    mock_document,
+):
+    mock_run_llm_query.return_value.text = """
+    There is some text before the JSON.
+    ```json
+    {
+        "title": "Test Title",
+        "tags": ["test", "document"],
+        "correspondents": ["John Doe"],
+        "document_types": ["report"],
+        "storage_paths": ["Reports"],
+        "dates": ["2023-01-01"]
+    }
+    ```
+    """
+
+    result = get_ai_document_classification(mock_document)
+
+    assert result["title"] == "Test Title"
+    assert result["tags"] == ["test", "document"]
+    assert result["correspondents"] == ["John Doe"]
+    assert result["document_types"] == ["report"]
+    assert result["storage_paths"] == ["Reports"]
+    assert result["dates"] == ["2023-01-01"]
+
+
+@pytest.mark.django_db
+@patch("paperless_ai.client.AIClient.run_llm_query")
+@override_settings(
+    LLM_BACKEND="ollama",
+    LLM_MODEL="some_model",
+)
+def test_get_ai_document_classification_parse_failure(
+    mock_run_llm_query,
+    mock_document,
+):
+    mock_run_llm_query.return_value.text = "Invalid JSON response"
+
+    result = get_ai_document_classification(mock_document)
+    assert result == {}
+
+
 @pytest.mark.django_db
 @patch("paperless_ai.client.AIClient.run_llm_query")
 def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document):
@@ -103,6 +156,15 @@ def test_get_ai_document_classification_failure(mock_run_llm_query, mock_documen
         get_ai_document_classification(mock_document)


+def test_parse_llm_classification_response_invalid_json():
+    mock_response = MagicMock()
+    mock_response.text = "Invalid JSON response"
+
+    result = parse_ai_response(mock_response)
+    assert result == {}
+
+
 @pytest.mark.django_db
 @patch("paperless_ai.client.AIClient.run_llm_query")
 @patch("paperless_ai.ai_classifier.build_prompt_with_rag")
@@ -156,10 +218,10 @@ def test_prompt_with_without_rag(mock_document):
         return_value="Context from similar documents",
     ):
         prompt = build_prompt_without_rag(mock_document)
-        assert "Additional context from similar documents:" not in prompt
+        assert "CONTEXT FROM SIMILAR DOCUMENTS:" not in prompt

         prompt = build_prompt_with_rag(mock_document)
-        assert "Additional context from similar documents:" in prompt
+        assert "CONTEXT FROM SIMILAR DOCUMENTS:" in prompt


 @patch("paperless_ai.ai_classifier.query_similar_documents")
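
The updated tests lean on MagicMock auto-creating attributes: once run_llm_query is patched, its return value is itself a MagicMock, so assigning .text is all it takes to stand in for a CompletionResponse. In isolation:

    import json
    from unittest.mock import MagicMock

    fake_response = MagicMock()                       # stands in for CompletionResponse
    fake_response.text = json.dumps({"title": "Test Title"})

    assert json.loads(fake_response.text)["title"] == "Test Title"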

View File

@@ -3,7 +3,6 @@ from unittest.mock import patch
 import pytest
 from llama_index.core.llms import ChatMessage
-from llama_index.core.llms.llm import ToolSelection

 from paperless_ai.client import AIClient
@@ -70,27 +69,13 @@ def test_run_llm_query(mock_ai_config, mock_ollama_llm):
     mock_ai_config.llm_url = "http://test-url"

     mock_llm_instance = mock_ollama_llm.return_value
-
-    tool_selection = ToolSelection(
-        tool_id="call_test",
-        tool_name="DocumentClassifierSchema",
-        tool_kwargs={
-            "title": "Test Title",
-            "tags": ["test", "document"],
-            "correspondents": ["John Doe"],
-            "document_types": ["report"],
-            "storage_paths": ["Reports"],
-            "dates": ["2023-01-01"],
-        },
-    )
-
-    mock_llm_instance.chat_with_tools.return_value = MagicMock()
-    mock_llm_instance.get_tool_calls_from_response.return_value = [tool_selection]
+    mock_llm_instance.complete.return_value = "test_result"

     client = AIClient()
     result = client.run_llm_query("test_prompt")
-    assert result["title"] == "Test Title"
+
+    mock_llm_instance.complete.assert_called_once_with("test_prompt")
+    assert result == "test_result"


 def test_run_chat(mock_ai_config, mock_ollama_llm):

View File

@ -1,10 +0,0 @@
from llama_index.core.bridge.pydantic import BaseModel
class DocumentClassifierSchema(BaseModel):
title: str
tags: list[str]
correspondents: list[str]
document_types: list[str]
storage_paths: list[str]
dates: list[str]
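
For reference, this deleted file held the pydantic model that the removed chat_with_tools path validated against; with it gone, nothing type-checks the parsed dict. A sketch of how equivalent validation could look with plain pydantic (pydantic v2 assumed; the class name is a hypothetical stand-in, and the original imported llama_index's pydantic bridge instead):

    from pydantic import BaseModel, ValidationError

    class ClassifierPayload(BaseModel):  # hypothetical stand-in for the removed schema
        title: str
        tags: list[str]
        correspondents: list[str]
        document_types: list[str]
        storage_paths: list[str]
        dates: list[str]

    try:
        payload = ClassifierPayload(
            title="Invoice", tags=["billing"], correspondents=[],
            document_types=["invoice"], storage_paths=[], dates=[],
        )
        print(payload.model_dump())
    except ValidationError as exc:
        print(exc)  # raised when the LLM output is missing or mistypes a field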