mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-07-16 17:25:11 -05:00
Compare commits
6 Commits
b94912a392
...
da2ac19193
Author | SHA1 | Date | |
---|---|---|---|
![]() |
da2ac19193 | ||
![]() |
3583470856 | ||
![]() |
5bfbe856a6 | ||
![]() |
20bae4bd41 | ||
![]() |
fa496dfc8d | ||
![]() |
924471b59c |
@ -246,7 +246,7 @@ export class CustomFieldsQueryDropdownComponent extends LoadingComponentWithPerm
|
||||
|
||||
customFields: CustomField[] = []
|
||||
|
||||
public readonly today: string = new Date().toISOString().split('T')[0]
|
||||
public readonly today: string = new Date().toLocaleDateString('en-CA')
|
||||
|
||||
constructor() {
|
||||
super()
|
||||
|
@ -165,7 +165,7 @@ export class DatesDropdownComponent implements OnInit, OnDestroy {
|
||||
@Input()
|
||||
placement: string = 'bottom-start'
|
||||
|
||||
public readonly today: string = new Date().toISOString().split('T')[0]
|
||||
public readonly today: string = new Date().toLocaleDateString('en-CA')
|
||||
|
||||
get isActive(): boolean {
|
||||
return (
|
||||
|
@ -59,7 +59,7 @@ export class DateComponent
|
||||
@Output()
|
||||
filterDocuments = new EventEmitter<NgbDateStruct[]>()
|
||||
|
||||
public readonly today: string = new Date().toISOString().split('T')[0]
|
||||
public readonly today: string = new Date().toLocaleDateString('en-CA')
|
||||
|
||||
getSuggestions() {
|
||||
return this.suggestions == null
|
||||
|
@ -1,8 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from django.contrib.auth.models import User
|
||||
from llama_index.core.base.llms.types import CompletionResponse
|
||||
|
||||
from documents.models import Document
|
||||
from documents.permissions import get_objects_for_user_owner_aware
|
||||
@ -18,58 +16,34 @@ def build_prompt_without_rag(document: Document) -> str:
|
||||
filename = document.filename or ""
|
||||
content = truncate_content(document.content[:4000] or "")
|
||||
|
||||
prompt = f"""
|
||||
You are an assistant that extracts structured information from documents.
|
||||
Only respond with the JSON object as described below.
|
||||
Never ask for further information, additional content or ask questions. Never include any other text.
|
||||
Suggested tags and document types must be strictly based on the content of the document.
|
||||
Do not change the field names or the JSON structure, only provide the values. Use double quotes and proper JSON syntax.
|
||||
Each field must be a list of plain strings.
|
||||
return f"""
|
||||
You are a document classification assistant.
|
||||
|
||||
The JSON object must contain the following fields:
|
||||
- title: A short, descriptive title
|
||||
- tags: A list of simple tags like ["insurance", "medical", "receipts"]
|
||||
- correspondents: A list of names or organizations mentioned in the document
|
||||
- document_types: The type/category of the document (e.g. "invoice", "medical record")
|
||||
- storage_paths: Suggested folder paths (e.g. "Medical/Insurance")
|
||||
- dates: List up to 3 relevant dates in YYYY-MM-DD format
|
||||
Analyze the following document and extract the following information:
|
||||
- A short descriptive title
|
||||
- Tags that reflect the content
|
||||
- Names of people or organizations mentioned
|
||||
- The type or category of the document
|
||||
- Suggested folder paths for storing the document
|
||||
- Up to 3 relevant dates in YYYY-MM-DD format
|
||||
|
||||
The format of the JSON object is as follows:
|
||||
{{
|
||||
"title": "xxxxx",
|
||||
"tags": ["xxxx", "xxxx"],
|
||||
"correspondents": ["xxxx", "xxxx"],
|
||||
"document_types": ["xxxx", "xxxx"],
|
||||
"storage_paths": ["xxxx", "xxxx"],
|
||||
"dates": ["YYYY-MM-DD", "YYYY-MM-DD", "YYYY-MM-DD"],
|
||||
}}
|
||||
---------
|
||||
|
||||
FILENAME:
|
||||
Filename:
|
||||
{filename}
|
||||
|
||||
CONTENT:
|
||||
Content:
|
||||
{content}
|
||||
"""
|
||||
|
||||
return prompt
|
||||
""".strip()
|
||||
|
||||
|
||||
def build_prompt_with_rag(document: Document, user: User | None = None) -> str:
|
||||
base_prompt = build_prompt_without_rag(document)
|
||||
context = truncate_content(get_context_for_document(document, user))
|
||||
prompt = build_prompt_without_rag(document)
|
||||
|
||||
prompt += f"""
|
||||
return f"""{base_prompt}
|
||||
|
||||
CONTEXT FROM SIMILAR DOCUMENTS:
|
||||
Additional context from similar documents:
|
||||
{context}
|
||||
|
||||
---------
|
||||
|
||||
DO NOT RESPOND WITH ANYTHING OTHER THAN THE JSON OBJECT.
|
||||
"""
|
||||
|
||||
return prompt
|
||||
""".strip()
|
||||
|
||||
|
||||
def get_context_for_document(
|
||||
@ -100,36 +74,15 @@ def get_context_for_document(
|
||||
return "\n\n".join(context_blocks)
|
||||
|
||||
|
||||
def parse_ai_response(response: CompletionResponse) -> dict:
|
||||
try:
|
||||
raw = json.loads(response.text)
|
||||
return {
|
||||
"title": raw.get("title"),
|
||||
"tags": raw.get("tags", []),
|
||||
"correspondents": raw.get("correspondents", []),
|
||||
"document_types": raw.get("document_types", []),
|
||||
"storage_paths": raw.get("storage_paths", []),
|
||||
"dates": raw.get("dates", []),
|
||||
}
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Invalid JSON in AI response, attempting modified parsing...")
|
||||
try:
|
||||
# search for a valid json string like { ... } in the response
|
||||
start = response.text.index("{")
|
||||
end = response.text.rindex("}") + 1
|
||||
json_str = response.text[start:end]
|
||||
raw = json.loads(json_str)
|
||||
return {
|
||||
"title": raw.get("title"),
|
||||
"tags": raw.get("tags", []),
|
||||
"correspondents": raw.get("correspondents", []),
|
||||
"document_types": raw.get("document_types", []),
|
||||
"storage_paths": raw.get("storage_paths", []),
|
||||
"dates": raw.get("dates", []),
|
||||
}
|
||||
except (ValueError, json.JSONDecodeError):
|
||||
logger.exception("Failed to parse AI response")
|
||||
return {}
|
||||
def parse_ai_response(raw: dict) -> dict:
|
||||
return {
|
||||
"title": raw.get("title", ""),
|
||||
"tags": raw.get("tags", []),
|
||||
"correspondents": raw.get("correspondents", []),
|
||||
"document_types": raw.get("document_types", []),
|
||||
"storage_paths": raw.get("storage_paths", []),
|
||||
"dates": raw.get("dates", []),
|
||||
}
|
||||
|
||||
|
||||
def get_ai_document_classification(
|
||||
|
@ -1,10 +1,12 @@
|
||||
import logging
|
||||
|
||||
from llama_index.core.llms import ChatMessage
|
||||
from llama_index.core.program.function_program import get_function_tool
|
||||
from llama_index.llms.ollama import Ollama
|
||||
from llama_index.llms.openai import OpenAI
|
||||
|
||||
from paperless.config import AIConfig
|
||||
from paperless_ai.tools import DocumentClassifierSchema
|
||||
|
||||
logger = logging.getLogger("paperless_ai.client")
|
||||
|
||||
@ -18,7 +20,7 @@ class AIClient:
|
||||
self.settings = AIConfig()
|
||||
self.llm = self.get_llm()
|
||||
|
||||
def get_llm(self):
|
||||
def get_llm(self) -> Ollama | OpenAI:
|
||||
if self.settings.llm_backend == "ollama":
|
||||
return Ollama(
|
||||
model=self.settings.llm_model or "llama3",
|
||||
@ -39,9 +41,21 @@ class AIClient:
|
||||
self.settings.llm_backend,
|
||||
self.settings.llm_model,
|
||||
)
|
||||
result = self.llm.complete(prompt)
|
||||
logger.debug("LLM query result: %s", result)
|
||||
return result
|
||||
|
||||
user_msg = ChatMessage(role="user", content=prompt)
|
||||
tool = get_function_tool(DocumentClassifierSchema)
|
||||
result = self.llm.chat_with_tools(
|
||||
tools=[tool],
|
||||
user_msg=user_msg,
|
||||
chat_history=[],
|
||||
)
|
||||
tool_calls = self.llm.get_tool_calls_from_response(
|
||||
result,
|
||||
error_on_no_tool_calls=True,
|
||||
)
|
||||
logger.debug("LLM query result: %s", tool_calls)
|
||||
parsed = DocumentClassifierSchema(**tool_calls[0].tool_kwargs)
|
||||
return parsed.model_dump()
|
||||
|
||||
def run_chat(self, messages: list[ChatMessage]) -> str:
|
||||
logger.debug(
|
||||
|
@ -10,7 +10,6 @@ from paperless_ai.ai_classifier import build_prompt_with_rag
|
||||
from paperless_ai.ai_classifier import build_prompt_without_rag
|
||||
from paperless_ai.ai_classifier import get_ai_document_classification
|
||||
from paperless_ai.ai_classifier import get_context_for_document
|
||||
from paperless_ai.ai_classifier import parse_ai_response
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -75,50 +74,14 @@ def mock_similar_documents():
|
||||
LLM_MODEL="some_model",
|
||||
)
|
||||
def test_get_ai_document_classification_success(mock_run_llm_query, mock_document):
|
||||
mock_run_llm_query.return_value.text = json.dumps(
|
||||
{
|
||||
"title": "Test Title",
|
||||
"tags": ["test", "document"],
|
||||
"correspondents": ["John Doe"],
|
||||
"document_types": ["report"],
|
||||
"storage_paths": ["Reports"],
|
||||
"dates": ["2023-01-01"],
|
||||
},
|
||||
)
|
||||
|
||||
result = get_ai_document_classification(mock_document)
|
||||
|
||||
assert result["title"] == "Test Title"
|
||||
assert result["tags"] == ["test", "document"]
|
||||
assert result["correspondents"] == ["John Doe"]
|
||||
assert result["document_types"] == ["report"]
|
||||
assert result["storage_paths"] == ["Reports"]
|
||||
assert result["dates"] == ["2023-01-01"]
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@patch("paperless_ai.client.AIClient.run_llm_query")
|
||||
@override_settings(
|
||||
LLM_BACKEND="ollama",
|
||||
LLM_MODEL="some_model",
|
||||
)
|
||||
def test_get_ai_document_classification_fallback_parse_success(
|
||||
mock_run_llm_query,
|
||||
mock_document,
|
||||
):
|
||||
mock_run_llm_query.return_value.text = """
|
||||
There is some text before the JSON.
|
||||
```json
|
||||
{
|
||||
mock_run_llm_query.return_value = {
|
||||
"title": "Test Title",
|
||||
"tags": ["test", "document"],
|
||||
"correspondents": ["John Doe"],
|
||||
"document_types": ["report"],
|
||||
"storage_paths": ["Reports"],
|
||||
"dates": ["2023-01-01"]
|
||||
"dates": ["2023-01-01"],
|
||||
}
|
||||
```
|
||||
"""
|
||||
|
||||
result = get_ai_document_classification(mock_document)
|
||||
|
||||
@ -130,22 +93,6 @@ def test_get_ai_document_classification_fallback_parse_success(
|
||||
assert result["dates"] == ["2023-01-01"]
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@patch("paperless_ai.client.AIClient.run_llm_query")
|
||||
@override_settings(
|
||||
LLM_BACKEND="ollama",
|
||||
LLM_MODEL="some_model",
|
||||
)
|
||||
def test_get_ai_document_classification_parse_failure(
|
||||
mock_run_llm_query,
|
||||
mock_document,
|
||||
):
|
||||
mock_run_llm_query.return_value.text = "Invalid JSON response"
|
||||
|
||||
result = get_ai_document_classification(mock_document)
|
||||
assert result == {}
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@patch("paperless_ai.client.AIClient.run_llm_query")
|
||||
def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document):
|
||||
@ -156,15 +103,6 @@ def test_get_ai_document_classification_failure(mock_run_llm_query, mock_documen
|
||||
get_ai_document_classification(mock_document)
|
||||
|
||||
|
||||
def test_parse_llm_classification_response_invalid_json():
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Invalid JSON response"
|
||||
|
||||
result = parse_ai_response(mock_response)
|
||||
|
||||
assert result == {}
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@patch("paperless_ai.client.AIClient.run_llm_query")
|
||||
@patch("paperless_ai.ai_classifier.build_prompt_with_rag")
|
||||
@ -218,10 +156,10 @@ def test_prompt_with_without_rag(mock_document):
|
||||
return_value="Context from similar documents",
|
||||
):
|
||||
prompt = build_prompt_without_rag(mock_document)
|
||||
assert "CONTEXT FROM SIMILAR DOCUMENTS:" not in prompt
|
||||
assert "Additional context from similar documents:" not in prompt
|
||||
|
||||
prompt = build_prompt_with_rag(mock_document)
|
||||
assert "CONTEXT FROM SIMILAR DOCUMENTS:" in prompt
|
||||
assert "Additional context from similar documents:" in prompt
|
||||
|
||||
|
||||
@patch("paperless_ai.ai_classifier.query_similar_documents")
|
||||
|
@ -3,6 +3,7 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from llama_index.core.llms import ChatMessage
|
||||
from llama_index.core.llms.llm import ToolSelection
|
||||
|
||||
from paperless_ai.client import AIClient
|
||||
|
||||
@ -69,13 +70,27 @@ def test_run_llm_query(mock_ai_config, mock_ollama_llm):
|
||||
mock_ai_config.llm_url = "http://test-url"
|
||||
|
||||
mock_llm_instance = mock_ollama_llm.return_value
|
||||
mock_llm_instance.complete.return_value = "test_result"
|
||||
|
||||
tool_selection = ToolSelection(
|
||||
tool_id="call_test",
|
||||
tool_name="DocumentClassifierSchema",
|
||||
tool_kwargs={
|
||||
"title": "Test Title",
|
||||
"tags": ["test", "document"],
|
||||
"correspondents": ["John Doe"],
|
||||
"document_types": ["report"],
|
||||
"storage_paths": ["Reports"],
|
||||
"dates": ["2023-01-01"],
|
||||
},
|
||||
)
|
||||
|
||||
mock_llm_instance.chat_with_tools.return_value = MagicMock()
|
||||
mock_llm_instance.get_tool_calls_from_response.return_value = [tool_selection]
|
||||
|
||||
client = AIClient()
|
||||
result = client.run_llm_query("test_prompt")
|
||||
|
||||
mock_llm_instance.complete.assert_called_once_with("test_prompt")
|
||||
assert result == "test_result"
|
||||
assert result["title"] == "Test Title"
|
||||
|
||||
|
||||
def test_run_chat(mock_ai_config, mock_ollama_llm):
|
||||
|
10
src/paperless_ai/tools.py
Normal file
10
src/paperless_ai/tools.py
Normal file
@ -0,0 +1,10 @@
|
||||
from llama_index.core.bridge.pydantic import BaseModel
|
||||
|
||||
|
||||
class DocumentClassifierSchema(BaseModel):
|
||||
title: str
|
||||
tags: list[str]
|
||||
correspondents: list[str]
|
||||
document_types: list[str]
|
||||
storage_paths: list[str]
|
||||
dates: list[str]
|
Loading…
x
Reference in New Issue
Block a user