mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-07-16 17:25:11 -05:00
Compare commits
6 Commits
b94912a392
...
da2ac19193
Author | SHA1 | Date | |
---|---|---|---|
![]() |
da2ac19193 | ||
![]() |
3583470856 | ||
![]() |
5bfbe856a6 | ||
![]() |
20bae4bd41 | ||
![]() |
fa496dfc8d | ||
![]() |
924471b59c |
@ -246,7 +246,7 @@ export class CustomFieldsQueryDropdownComponent extends LoadingComponentWithPerm
|
|||||||
|
|
||||||
customFields: CustomField[] = []
|
customFields: CustomField[] = []
|
||||||
|
|
||||||
public readonly today: string = new Date().toISOString().split('T')[0]
|
public readonly today: string = new Date().toLocaleDateString('en-CA')
|
||||||
|
|
||||||
constructor() {
|
constructor() {
|
||||||
super()
|
super()
|
||||||
|
@ -165,7 +165,7 @@ export class DatesDropdownComponent implements OnInit, OnDestroy {
|
|||||||
@Input()
|
@Input()
|
||||||
placement: string = 'bottom-start'
|
placement: string = 'bottom-start'
|
||||||
|
|
||||||
public readonly today: string = new Date().toISOString().split('T')[0]
|
public readonly today: string = new Date().toLocaleDateString('en-CA')
|
||||||
|
|
||||||
get isActive(): boolean {
|
get isActive(): boolean {
|
||||||
return (
|
return (
|
||||||
|
@ -59,7 +59,7 @@ export class DateComponent
|
|||||||
@Output()
|
@Output()
|
||||||
filterDocuments = new EventEmitter<NgbDateStruct[]>()
|
filterDocuments = new EventEmitter<NgbDateStruct[]>()
|
||||||
|
|
||||||
public readonly today: string = new Date().toISOString().split('T')[0]
|
public readonly today: string = new Date().toLocaleDateString('en-CA')
|
||||||
|
|
||||||
getSuggestions() {
|
getSuggestions() {
|
||||||
return this.suggestions == null
|
return this.suggestions == null
|
||||||
|
@ -1,8 +1,6 @@
|
|||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from django.contrib.auth.models import User
|
from django.contrib.auth.models import User
|
||||||
from llama_index.core.base.llms.types import CompletionResponse
|
|
||||||
|
|
||||||
from documents.models import Document
|
from documents.models import Document
|
||||||
from documents.permissions import get_objects_for_user_owner_aware
|
from documents.permissions import get_objects_for_user_owner_aware
|
||||||
@ -18,58 +16,34 @@ def build_prompt_without_rag(document: Document) -> str:
|
|||||||
filename = document.filename or ""
|
filename = document.filename or ""
|
||||||
content = truncate_content(document.content[:4000] or "")
|
content = truncate_content(document.content[:4000] or "")
|
||||||
|
|
||||||
prompt = f"""
|
return f"""
|
||||||
You are an assistant that extracts structured information from documents.
|
You are a document classification assistant.
|
||||||
Only respond with the JSON object as described below.
|
|
||||||
Never ask for further information, additional content or ask questions. Never include any other text.
|
|
||||||
Suggested tags and document types must be strictly based on the content of the document.
|
|
||||||
Do not change the field names or the JSON structure, only provide the values. Use double quotes and proper JSON syntax.
|
|
||||||
Each field must be a list of plain strings.
|
|
||||||
|
|
||||||
The JSON object must contain the following fields:
|
Analyze the following document and extract the following information:
|
||||||
- title: A short, descriptive title
|
- A short descriptive title
|
||||||
- tags: A list of simple tags like ["insurance", "medical", "receipts"]
|
- Tags that reflect the content
|
||||||
- correspondents: A list of names or organizations mentioned in the document
|
- Names of people or organizations mentioned
|
||||||
- document_types: The type/category of the document (e.g. "invoice", "medical record")
|
- The type or category of the document
|
||||||
- storage_paths: Suggested folder paths (e.g. "Medical/Insurance")
|
- Suggested folder paths for storing the document
|
||||||
- dates: List up to 3 relevant dates in YYYY-MM-DD format
|
- Up to 3 relevant dates in YYYY-MM-DD format
|
||||||
|
|
||||||
The format of the JSON object is as follows:
|
Filename:
|
||||||
{{
|
|
||||||
"title": "xxxxx",
|
|
||||||
"tags": ["xxxx", "xxxx"],
|
|
||||||
"correspondents": ["xxxx", "xxxx"],
|
|
||||||
"document_types": ["xxxx", "xxxx"],
|
|
||||||
"storage_paths": ["xxxx", "xxxx"],
|
|
||||||
"dates": ["YYYY-MM-DD", "YYYY-MM-DD", "YYYY-MM-DD"],
|
|
||||||
}}
|
|
||||||
---------
|
|
||||||
|
|
||||||
FILENAME:
|
|
||||||
{filename}
|
{filename}
|
||||||
|
|
||||||
CONTENT:
|
Content:
|
||||||
{content}
|
{content}
|
||||||
"""
|
""".strip()
|
||||||
|
|
||||||
return prompt
|
|
||||||
|
|
||||||
|
|
||||||
def build_prompt_with_rag(document: Document, user: User | None = None) -> str:
|
def build_prompt_with_rag(document: Document, user: User | None = None) -> str:
|
||||||
|
base_prompt = build_prompt_without_rag(document)
|
||||||
context = truncate_content(get_context_for_document(document, user))
|
context = truncate_content(get_context_for_document(document, user))
|
||||||
prompt = build_prompt_without_rag(document)
|
|
||||||
|
|
||||||
prompt += f"""
|
return f"""{base_prompt}
|
||||||
|
|
||||||
CONTEXT FROM SIMILAR DOCUMENTS:
|
Additional context from similar documents:
|
||||||
{context}
|
{context}
|
||||||
|
""".strip()
|
||||||
---------
|
|
||||||
|
|
||||||
DO NOT RESPOND WITH ANYTHING OTHER THAN THE JSON OBJECT.
|
|
||||||
"""
|
|
||||||
|
|
||||||
return prompt
|
|
||||||
|
|
||||||
|
|
||||||
def get_context_for_document(
|
def get_context_for_document(
|
||||||
@ -100,36 +74,15 @@ def get_context_for_document(
|
|||||||
return "\n\n".join(context_blocks)
|
return "\n\n".join(context_blocks)
|
||||||
|
|
||||||
|
|
||||||
def parse_ai_response(response: CompletionResponse) -> dict:
|
def parse_ai_response(raw: dict) -> dict:
|
||||||
try:
|
return {
|
||||||
raw = json.loads(response.text)
|
"title": raw.get("title", ""),
|
||||||
return {
|
"tags": raw.get("tags", []),
|
||||||
"title": raw.get("title"),
|
"correspondents": raw.get("correspondents", []),
|
||||||
"tags": raw.get("tags", []),
|
"document_types": raw.get("document_types", []),
|
||||||
"correspondents": raw.get("correspondents", []),
|
"storage_paths": raw.get("storage_paths", []),
|
||||||
"document_types": raw.get("document_types", []),
|
"dates": raw.get("dates", []),
|
||||||
"storage_paths": raw.get("storage_paths", []),
|
}
|
||||||
"dates": raw.get("dates", []),
|
|
||||||
}
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
logger.warning("Invalid JSON in AI response, attempting modified parsing...")
|
|
||||||
try:
|
|
||||||
# search for a valid json string like { ... } in the response
|
|
||||||
start = response.text.index("{")
|
|
||||||
end = response.text.rindex("}") + 1
|
|
||||||
json_str = response.text[start:end]
|
|
||||||
raw = json.loads(json_str)
|
|
||||||
return {
|
|
||||||
"title": raw.get("title"),
|
|
||||||
"tags": raw.get("tags", []),
|
|
||||||
"correspondents": raw.get("correspondents", []),
|
|
||||||
"document_types": raw.get("document_types", []),
|
|
||||||
"storage_paths": raw.get("storage_paths", []),
|
|
||||||
"dates": raw.get("dates", []),
|
|
||||||
}
|
|
||||||
except (ValueError, json.JSONDecodeError):
|
|
||||||
logger.exception("Failed to parse AI response")
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
def get_ai_document_classification(
|
def get_ai_document_classification(
|
||||||
|
@ -1,10 +1,12 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from llama_index.core.llms import ChatMessage
|
from llama_index.core.llms import ChatMessage
|
||||||
|
from llama_index.core.program.function_program import get_function_tool
|
||||||
from llama_index.llms.ollama import Ollama
|
from llama_index.llms.ollama import Ollama
|
||||||
from llama_index.llms.openai import OpenAI
|
from llama_index.llms.openai import OpenAI
|
||||||
|
|
||||||
from paperless.config import AIConfig
|
from paperless.config import AIConfig
|
||||||
|
from paperless_ai.tools import DocumentClassifierSchema
|
||||||
|
|
||||||
logger = logging.getLogger("paperless_ai.client")
|
logger = logging.getLogger("paperless_ai.client")
|
||||||
|
|
||||||
@ -18,7 +20,7 @@ class AIClient:
|
|||||||
self.settings = AIConfig()
|
self.settings = AIConfig()
|
||||||
self.llm = self.get_llm()
|
self.llm = self.get_llm()
|
||||||
|
|
||||||
def get_llm(self):
|
def get_llm(self) -> Ollama | OpenAI:
|
||||||
if self.settings.llm_backend == "ollama":
|
if self.settings.llm_backend == "ollama":
|
||||||
return Ollama(
|
return Ollama(
|
||||||
model=self.settings.llm_model or "llama3",
|
model=self.settings.llm_model or "llama3",
|
||||||
@ -39,9 +41,21 @@ class AIClient:
|
|||||||
self.settings.llm_backend,
|
self.settings.llm_backend,
|
||||||
self.settings.llm_model,
|
self.settings.llm_model,
|
||||||
)
|
)
|
||||||
result = self.llm.complete(prompt)
|
|
||||||
logger.debug("LLM query result: %s", result)
|
user_msg = ChatMessage(role="user", content=prompt)
|
||||||
return result
|
tool = get_function_tool(DocumentClassifierSchema)
|
||||||
|
result = self.llm.chat_with_tools(
|
||||||
|
tools=[tool],
|
||||||
|
user_msg=user_msg,
|
||||||
|
chat_history=[],
|
||||||
|
)
|
||||||
|
tool_calls = self.llm.get_tool_calls_from_response(
|
||||||
|
result,
|
||||||
|
error_on_no_tool_calls=True,
|
||||||
|
)
|
||||||
|
logger.debug("LLM query result: %s", tool_calls)
|
||||||
|
parsed = DocumentClassifierSchema(**tool_calls[0].tool_kwargs)
|
||||||
|
return parsed.model_dump()
|
||||||
|
|
||||||
def run_chat(self, messages: list[ChatMessage]) -> str:
|
def run_chat(self, messages: list[ChatMessage]) -> str:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
@ -10,7 +10,6 @@ from paperless_ai.ai_classifier import build_prompt_with_rag
|
|||||||
from paperless_ai.ai_classifier import build_prompt_without_rag
|
from paperless_ai.ai_classifier import build_prompt_without_rag
|
||||||
from paperless_ai.ai_classifier import get_ai_document_classification
|
from paperless_ai.ai_classifier import get_ai_document_classification
|
||||||
from paperless_ai.ai_classifier import get_context_for_document
|
from paperless_ai.ai_classifier import get_context_for_document
|
||||||
from paperless_ai.ai_classifier import parse_ai_response
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -75,50 +74,14 @@ def mock_similar_documents():
|
|||||||
LLM_MODEL="some_model",
|
LLM_MODEL="some_model",
|
||||||
)
|
)
|
||||||
def test_get_ai_document_classification_success(mock_run_llm_query, mock_document):
|
def test_get_ai_document_classification_success(mock_run_llm_query, mock_document):
|
||||||
mock_run_llm_query.return_value.text = json.dumps(
|
mock_run_llm_query.return_value = {
|
||||||
{
|
|
||||||
"title": "Test Title",
|
|
||||||
"tags": ["test", "document"],
|
|
||||||
"correspondents": ["John Doe"],
|
|
||||||
"document_types": ["report"],
|
|
||||||
"storage_paths": ["Reports"],
|
|
||||||
"dates": ["2023-01-01"],
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
result = get_ai_document_classification(mock_document)
|
|
||||||
|
|
||||||
assert result["title"] == "Test Title"
|
|
||||||
assert result["tags"] == ["test", "document"]
|
|
||||||
assert result["correspondents"] == ["John Doe"]
|
|
||||||
assert result["document_types"] == ["report"]
|
|
||||||
assert result["storage_paths"] == ["Reports"]
|
|
||||||
assert result["dates"] == ["2023-01-01"]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.django_db
|
|
||||||
@patch("paperless_ai.client.AIClient.run_llm_query")
|
|
||||||
@override_settings(
|
|
||||||
LLM_BACKEND="ollama",
|
|
||||||
LLM_MODEL="some_model",
|
|
||||||
)
|
|
||||||
def test_get_ai_document_classification_fallback_parse_success(
|
|
||||||
mock_run_llm_query,
|
|
||||||
mock_document,
|
|
||||||
):
|
|
||||||
mock_run_llm_query.return_value.text = """
|
|
||||||
There is some text before the JSON.
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"title": "Test Title",
|
"title": "Test Title",
|
||||||
"tags": ["test", "document"],
|
"tags": ["test", "document"],
|
||||||
"correspondents": ["John Doe"],
|
"correspondents": ["John Doe"],
|
||||||
"document_types": ["report"],
|
"document_types": ["report"],
|
||||||
"storage_paths": ["Reports"],
|
"storage_paths": ["Reports"],
|
||||||
"dates": ["2023-01-01"]
|
"dates": ["2023-01-01"],
|
||||||
}
|
}
|
||||||
```
|
|
||||||
"""
|
|
||||||
|
|
||||||
result = get_ai_document_classification(mock_document)
|
result = get_ai_document_classification(mock_document)
|
||||||
|
|
||||||
@ -130,22 +93,6 @@ def test_get_ai_document_classification_fallback_parse_success(
|
|||||||
assert result["dates"] == ["2023-01-01"]
|
assert result["dates"] == ["2023-01-01"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.django_db
|
|
||||||
@patch("paperless_ai.client.AIClient.run_llm_query")
|
|
||||||
@override_settings(
|
|
||||||
LLM_BACKEND="ollama",
|
|
||||||
LLM_MODEL="some_model",
|
|
||||||
)
|
|
||||||
def test_get_ai_document_classification_parse_failure(
|
|
||||||
mock_run_llm_query,
|
|
||||||
mock_document,
|
|
||||||
):
|
|
||||||
mock_run_llm_query.return_value.text = "Invalid JSON response"
|
|
||||||
|
|
||||||
result = get_ai_document_classification(mock_document)
|
|
||||||
assert result == {}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.django_db
|
@pytest.mark.django_db
|
||||||
@patch("paperless_ai.client.AIClient.run_llm_query")
|
@patch("paperless_ai.client.AIClient.run_llm_query")
|
||||||
def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document):
|
def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document):
|
||||||
@ -156,15 +103,6 @@ def test_get_ai_document_classification_failure(mock_run_llm_query, mock_documen
|
|||||||
get_ai_document_classification(mock_document)
|
get_ai_document_classification(mock_document)
|
||||||
|
|
||||||
|
|
||||||
def test_parse_llm_classification_response_invalid_json():
|
|
||||||
mock_response = MagicMock()
|
|
||||||
mock_response.text = "Invalid JSON response"
|
|
||||||
|
|
||||||
result = parse_ai_response(mock_response)
|
|
||||||
|
|
||||||
assert result == {}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.django_db
|
@pytest.mark.django_db
|
||||||
@patch("paperless_ai.client.AIClient.run_llm_query")
|
@patch("paperless_ai.client.AIClient.run_llm_query")
|
||||||
@patch("paperless_ai.ai_classifier.build_prompt_with_rag")
|
@patch("paperless_ai.ai_classifier.build_prompt_with_rag")
|
||||||
@ -218,10 +156,10 @@ def test_prompt_with_without_rag(mock_document):
|
|||||||
return_value="Context from similar documents",
|
return_value="Context from similar documents",
|
||||||
):
|
):
|
||||||
prompt = build_prompt_without_rag(mock_document)
|
prompt = build_prompt_without_rag(mock_document)
|
||||||
assert "CONTEXT FROM SIMILAR DOCUMENTS:" not in prompt
|
assert "Additional context from similar documents:" not in prompt
|
||||||
|
|
||||||
prompt = build_prompt_with_rag(mock_document)
|
prompt = build_prompt_with_rag(mock_document)
|
||||||
assert "CONTEXT FROM SIMILAR DOCUMENTS:" in prompt
|
assert "Additional context from similar documents:" in prompt
|
||||||
|
|
||||||
|
|
||||||
@patch("paperless_ai.ai_classifier.query_similar_documents")
|
@patch("paperless_ai.ai_classifier.query_similar_documents")
|
||||||
|
@ -3,6 +3,7 @@ from unittest.mock import patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from llama_index.core.llms import ChatMessage
|
from llama_index.core.llms import ChatMessage
|
||||||
|
from llama_index.core.llms.llm import ToolSelection
|
||||||
|
|
||||||
from paperless_ai.client import AIClient
|
from paperless_ai.client import AIClient
|
||||||
|
|
||||||
@ -69,13 +70,27 @@ def test_run_llm_query(mock_ai_config, mock_ollama_llm):
|
|||||||
mock_ai_config.llm_url = "http://test-url"
|
mock_ai_config.llm_url = "http://test-url"
|
||||||
|
|
||||||
mock_llm_instance = mock_ollama_llm.return_value
|
mock_llm_instance = mock_ollama_llm.return_value
|
||||||
mock_llm_instance.complete.return_value = "test_result"
|
|
||||||
|
tool_selection = ToolSelection(
|
||||||
|
tool_id="call_test",
|
||||||
|
tool_name="DocumentClassifierSchema",
|
||||||
|
tool_kwargs={
|
||||||
|
"title": "Test Title",
|
||||||
|
"tags": ["test", "document"],
|
||||||
|
"correspondents": ["John Doe"],
|
||||||
|
"document_types": ["report"],
|
||||||
|
"storage_paths": ["Reports"],
|
||||||
|
"dates": ["2023-01-01"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_llm_instance.chat_with_tools.return_value = MagicMock()
|
||||||
|
mock_llm_instance.get_tool_calls_from_response.return_value = [tool_selection]
|
||||||
|
|
||||||
client = AIClient()
|
client = AIClient()
|
||||||
result = client.run_llm_query("test_prompt")
|
result = client.run_llm_query("test_prompt")
|
||||||
|
|
||||||
mock_llm_instance.complete.assert_called_once_with("test_prompt")
|
assert result["title"] == "Test Title"
|
||||||
assert result == "test_result"
|
|
||||||
|
|
||||||
|
|
||||||
def test_run_chat(mock_ai_config, mock_ollama_llm):
|
def test_run_chat(mock_ai_config, mock_ollama_llm):
|
||||||
|
10
src/paperless_ai/tools.py
Normal file
10
src/paperless_ai/tools.py
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
from llama_index.core.bridge.pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentClassifierSchema(BaseModel):
|
||||||
|
title: str
|
||||||
|
tags: list[str]
|
||||||
|
correspondents: list[str]
|
||||||
|
document_types: list[str]
|
||||||
|
storage_paths: list[str]
|
||||||
|
dates: list[str]
|
Loading…
x
Reference in New Issue
Block a user