Mirror of https://github.com/paperless-ngx/paperless-ngx.git (synced 2025-10-30)
Compare commits: 6 commits, b94912a392...da2ac19193
Commits in this range:

- da2ac19193
- 3583470856
- 5bfbe856a6
- 20bae4bd41
- fa496dfc8d
- 924471b59c
| @@ -246,7 +246,7 @@ export class CustomFieldsQueryDropdownComponent extends LoadingComponentWithPerm | ||||
|  | ||||
|   customFields: CustomField[] = [] | ||||
|  | ||||
|   public readonly today: string = new Date().toISOString().split('T')[0] | ||||
|   public readonly today: string = new Date().toLocaleDateString('en-CA') | ||||
|  | ||||
|   constructor() { | ||||
|     super() | ||||
|   | ||||
| @@ -165,7 +165,7 @@ export class DatesDropdownComponent implements OnInit, OnDestroy { | ||||
|   @Input() | ||||
|   placement: string = 'bottom-start' | ||||
|  | ||||
|   public readonly today: string = new Date().toISOString().split('T')[0] | ||||
|   public readonly today: string = new Date().toLocaleDateString('en-CA') | ||||
|  | ||||
|   get isActive(): boolean { | ||||
|     return ( | ||||
|   | ||||
| @@ -59,7 +59,7 @@ export class DateComponent | ||||
|   @Output() | ||||
|   filterDocuments = new EventEmitter<NgbDateStruct[]>() | ||||
|  | ||||
|   public readonly today: string = new Date().toISOString().split('T')[0] | ||||
|   public readonly today: string = new Date().toLocaleDateString('en-CA') | ||||
|  | ||||
|   getSuggestions() { | ||||
|     return this.suggestions == null | ||||
|   | ||||
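The three frontend hunks above (CustomFieldsQueryDropdownComponent, DatesDropdownComponent, DateComponent) replace `new Date().toISOString().split('T')[0]` with `new Date().toLocaleDateString('en-CA')` when computing `today`. Both produce a `YYYY-MM-DD` string, but `toISOString()` reflects the UTC calendar date, which can lag or lead the user's local date around midnight, while the `en-CA` locale formats the local date in the same shape. A minimal Python sketch of that UTC-vs-local pitfall (illustration only, not code from the change):

```python
from datetime import datetime, timezone

# "Today" computed two ways: from the local clock and from UTC. Shortly before
# or after local midnight in a non-UTC timezone these can differ by one day,
# which is the off-by-one the frontend hunks avoid by formatting the local date.
local_today = datetime.now().date().isoformat()            # local calendar date
utc_today = datetime.now(timezone.utc).date().isoformat()  # UTC calendar date

print(local_today, utc_today)
```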
| @@ -1,8 +1,6 @@ | ||||
| import json | ||||
| import logging | ||||
|  | ||||
| from django.contrib.auth.models import User | ||||
| from llama_index.core.base.llms.types import CompletionResponse | ||||
|  | ||||
| from documents.models import Document | ||||
| from documents.permissions import get_objects_for_user_owner_aware | ||||
| @@ -18,58 +16,34 @@ def build_prompt_without_rag(document: Document) -> str: | ||||
|     filename = document.filename or "" | ||||
|     content = truncate_content(document.content[:4000] or "") | ||||
|  | ||||
|     prompt = f""" | ||||
|     You are an assistant that extracts structured information from documents. | ||||
|     Only respond with the JSON object as described below. | ||||
|     Never ask for further information, additional content or ask questions. Never include any other text. | ||||
|     Suggested tags and document types must be strictly based on the content of the document. | ||||
|     Do not change the field names or the JSON structure, only provide the values. Use double quotes and proper JSON syntax. | ||||
|     Each field must be a list of plain strings. | ||||
|     return f""" | ||||
|     You are a document classification assistant. | ||||
|  | ||||
|     The JSON object must contain the following fields: | ||||
|     - title: A short, descriptive title | ||||
|     - tags: A list of simple tags like ["insurance", "medical", "receipts"] | ||||
|     - correspondents: A list of names or organizations mentioned in the document | ||||
|     - document_types: The type/category of the document (e.g. "invoice", "medical record") | ||||
|     - storage_paths: Suggested folder paths (e.g. "Medical/Insurance") | ||||
|     - dates: List up to 3 relevant dates in YYYY-MM-DD format | ||||
|     Analyze the following document and extract the following information: | ||||
|     - A short descriptive title | ||||
|     - Tags that reflect the content | ||||
|     - Names of people or organizations mentioned | ||||
|     - The type or category of the document | ||||
|     - Suggested folder paths for storing the document | ||||
|     - Up to 3 relevant dates in YYYY-MM-DD format | ||||
|  | ||||
|     The format of the JSON object is as follows: | ||||
|     {{ | ||||
|         "title": "xxxxx", | ||||
|         "tags": ["xxxx", "xxxx"], | ||||
|         "correspondents": ["xxxx", "xxxx"], | ||||
|         "document_types": ["xxxx", "xxxx"], | ||||
|         "storage_paths": ["xxxx", "xxxx"], | ||||
|         "dates": ["YYYY-MM-DD", "YYYY-MM-DD", "YYYY-MM-DD"], | ||||
|     }} | ||||
|     --------- | ||||
|  | ||||
|     FILENAME: | ||||
|     Filename: | ||||
|     {filename} | ||||
|  | ||||
|     CONTENT: | ||||
|     Content: | ||||
|     {content} | ||||
|     """ | ||||
|  | ||||
|     return prompt | ||||
|     """.strip() | ||||
|  | ||||
|  | ||||
| def build_prompt_with_rag(document: Document, user: User | None = None) -> str: | ||||
|     base_prompt = build_prompt_without_rag(document) | ||||
|     context = truncate_content(get_context_for_document(document, user)) | ||||
|     prompt = build_prompt_without_rag(document) | ||||
|  | ||||
|     prompt += f""" | ||||
|     return f"""{base_prompt} | ||||
|  | ||||
|     CONTEXT FROM SIMILAR DOCUMENTS: | ||||
|     Additional context from similar documents: | ||||
|     {context} | ||||
|  | ||||
|     --------- | ||||
|  | ||||
|     DO NOT RESPOND WITH ANYTHING OTHER THAN THE JSON OBJECT. | ||||
|     """ | ||||
|  | ||||
|     return prompt | ||||
|     """.strip() | ||||
|  | ||||
|  | ||||
| def get_context_for_document( | ||||
| @@ -100,36 +74,15 @@ def get_context_for_document( | ||||
|     return "\n\n".join(context_blocks) | ||||
|  | ||||
|  | ||||
| def parse_ai_response(response: CompletionResponse) -> dict: | ||||
|     try: | ||||
|         raw = json.loads(response.text) | ||||
|         return { | ||||
|             "title": raw.get("title"), | ||||
|             "tags": raw.get("tags", []), | ||||
|             "correspondents": raw.get("correspondents", []), | ||||
|             "document_types": raw.get("document_types", []), | ||||
|             "storage_paths": raw.get("storage_paths", []), | ||||
|             "dates": raw.get("dates", []), | ||||
|         } | ||||
|     except json.JSONDecodeError: | ||||
|         logger.warning("Invalid JSON in AI response, attempting modified parsing...") | ||||
|         try: | ||||
|             # search for a valid json string like { ... } in the response | ||||
|             start = response.text.index("{") | ||||
|             end = response.text.rindex("}") + 1 | ||||
|             json_str = response.text[start:end] | ||||
|             raw = json.loads(json_str) | ||||
|             return { | ||||
|                 "title": raw.get("title"), | ||||
|                 "tags": raw.get("tags", []), | ||||
|                 "correspondents": raw.get("correspondents", []), | ||||
|                 "document_types": raw.get("document_types", []), | ||||
|                 "storage_paths": raw.get("storage_paths", []), | ||||
|                 "dates": raw.get("dates", []), | ||||
|             } | ||||
|         except (ValueError, json.JSONDecodeError): | ||||
|             logger.exception("Failed to parse AI response") | ||||
|             return {} | ||||
| def parse_ai_response(raw: dict) -> dict: | ||||
|     return { | ||||
|         "title": raw.get("title", ""), | ||||
|         "tags": raw.get("tags", []), | ||||
|         "correspondents": raw.get("correspondents", []), | ||||
|         "document_types": raw.get("document_types", []), | ||||
|         "storage_paths": raw.get("storage_paths", []), | ||||
|         "dates": raw.get("dates", []), | ||||
|     } | ||||
|  | ||||
|  | ||||
| def get_ai_document_classification( | ||||
|   | ||||
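The hunks above (apparently paperless_ai/ai_classifier.py, judging from the test module's imports) rewrite the two prompt builders as plain f-strings returned directly, drop the now-unused `json` and `CompletionResponse` imports, and shrink `parse_ai_response` from a JSON-recovery routine into a normalizer over an already-parsed dict, since the client now hands back structured tool output rather than free-form text. A small sketch of how the simplified parser behaves, copied from the diff with an illustrative call:

```python
# Simplified parse_ai_response as shown in the diff: no JSON decoding or
# brace-hunting fallback, just defaults for any keys the model omitted.
def parse_ai_response(raw: dict) -> dict:
    return {
        "title": raw.get("title", ""),
        "tags": raw.get("tags", []),
        "correspondents": raw.get("correspondents", []),
        "document_types": raw.get("document_types", []),
        "storage_paths": raw.get("storage_paths", []),
        "dates": raw.get("dates", []),
    }


# Illustrative input: a partial result still comes back fully shaped.
print(parse_ai_response({"title": "Invoice March", "tags": ["invoice"]}))
# {'title': 'Invoice March', 'tags': ['invoice'], 'correspondents': [], ...}
```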
| @@ -1,10 +1,12 @@ | ||||
| import logging | ||||
|  | ||||
| from llama_index.core.llms import ChatMessage | ||||
| from llama_index.core.program.function_program import get_function_tool | ||||
| from llama_index.llms.ollama import Ollama | ||||
| from llama_index.llms.openai import OpenAI | ||||
|  | ||||
| from paperless.config import AIConfig | ||||
| from paperless_ai.tools import DocumentClassifierSchema | ||||
|  | ||||
| logger = logging.getLogger("paperless_ai.client") | ||||
|  | ||||
| @@ -18,7 +20,7 @@ class AIClient: | ||||
|         self.settings = AIConfig() | ||||
|         self.llm = self.get_llm() | ||||
|  | ||||
|     def get_llm(self): | ||||
|     def get_llm(self) -> Ollama | OpenAI: | ||||
|         if self.settings.llm_backend == "ollama": | ||||
|             return Ollama( | ||||
|                 model=self.settings.llm_model or "llama3", | ||||
| @@ -39,9 +41,21 @@ class AIClient: | ||||
|             self.settings.llm_backend, | ||||
|             self.settings.llm_model, | ||||
|         ) | ||||
|         result = self.llm.complete(prompt) | ||||
|         logger.debug("LLM query result: %s", result) | ||||
|         return result | ||||
|  | ||||
|         user_msg = ChatMessage(role="user", content=prompt) | ||||
|         tool = get_function_tool(DocumentClassifierSchema) | ||||
|         result = self.llm.chat_with_tools( | ||||
|             tools=[tool], | ||||
|             user_msg=user_msg, | ||||
|             chat_history=[], | ||||
|         ) | ||||
|         tool_calls = self.llm.get_tool_calls_from_response( | ||||
|             result, | ||||
|             error_on_no_tool_calls=True, | ||||
|         ) | ||||
|         logger.debug("LLM query result: %s", tool_calls) | ||||
|         parsed = DocumentClassifierSchema(**tool_calls[0].tool_kwargs) | ||||
|         return parsed.model_dump() | ||||
|  | ||||
|     def run_chat(self, messages: list[ChatMessage]) -> str: | ||||
|         logger.debug( | ||||
|   | ||||
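The client change above (paperless_ai/client.py, per the test imports) swaps the free-form `self.llm.complete(prompt)` call for llama-index function calling: the pydantic `DocumentClassifierSchema` is wrapped in a function tool, the model is asked to call that tool, and the returned tool arguments are validated back through the schema before being handed on as a plain dict. A condensed, self-contained sketch of that flow using the same helpers the diff relies on (`get_function_tool`, `chat_with_tools`, `get_tool_calls_from_response`); the model name and URL are placeholders for values that come from `AIConfig` in the real code:

```python
from llama_index.core.bridge.pydantic import BaseModel
from llama_index.core.llms import ChatMessage
from llama_index.core.program.function_program import get_function_tool
from llama_index.llms.ollama import Ollama


class DocumentClassifierSchema(BaseModel):
    title: str
    tags: list[str]
    correspondents: list[str]
    document_types: list[str]
    storage_paths: list[str]
    dates: list[str]


# Placeholder backend; the diff builds this from AIConfig (Ollama or OpenAI).
llm = Ollama(model="llama3", base_url="http://localhost:11434")

tool = get_function_tool(DocumentClassifierSchema)
response = llm.chat_with_tools(
    tools=[tool],
    user_msg=ChatMessage(role="user", content="...classification prompt..."),
    chat_history=[],
)

# error_on_no_tool_calls=True turns "the model answered in prose" into an
# exception instead of a silently unparsable result.
tool_calls = llm.get_tool_calls_from_response(response, error_on_no_tool_calls=True)

# Validate the arguments and pass a plain dict downstream (what run_llm_query returns).
result = DocumentClassifierSchema(**tool_calls[0].tool_kwargs).model_dump()
print(result)
```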
| @@ -10,7 +10,6 @@ from paperless_ai.ai_classifier import build_prompt_with_rag | ||||
| from paperless_ai.ai_classifier import build_prompt_without_rag | ||||
| from paperless_ai.ai_classifier import get_ai_document_classification | ||||
| from paperless_ai.ai_classifier import get_context_for_document | ||||
| from paperless_ai.ai_classifier import parse_ai_response | ||||
|  | ||||
|  | ||||
| @pytest.fixture | ||||
| @@ -75,50 +74,14 @@ def mock_similar_documents(): | ||||
|     LLM_MODEL="some_model", | ||||
| ) | ||||
| def test_get_ai_document_classification_success(mock_run_llm_query, mock_document): | ||||
|     mock_run_llm_query.return_value.text = json.dumps( | ||||
|         { | ||||
|             "title": "Test Title", | ||||
|             "tags": ["test", "document"], | ||||
|             "correspondents": ["John Doe"], | ||||
|             "document_types": ["report"], | ||||
|             "storage_paths": ["Reports"], | ||||
|             "dates": ["2023-01-01"], | ||||
|         }, | ||||
|     ) | ||||
|  | ||||
|     result = get_ai_document_classification(mock_document) | ||||
|  | ||||
|     assert result["title"] == "Test Title" | ||||
|     assert result["tags"] == ["test", "document"] | ||||
|     assert result["correspondents"] == ["John Doe"] | ||||
|     assert result["document_types"] == ["report"] | ||||
|     assert result["storage_paths"] == ["Reports"] | ||||
|     assert result["dates"] == ["2023-01-01"] | ||||
|  | ||||
|  | ||||
| @pytest.mark.django_db | ||||
| @patch("paperless_ai.client.AIClient.run_llm_query") | ||||
| @override_settings( | ||||
|     LLM_BACKEND="ollama", | ||||
|     LLM_MODEL="some_model", | ||||
| ) | ||||
| def test_get_ai_document_classification_fallback_parse_success( | ||||
|     mock_run_llm_query, | ||||
|     mock_document, | ||||
| ): | ||||
|     mock_run_llm_query.return_value.text = """ | ||||
|     There is some text before the JSON. | ||||
|     ```json | ||||
|     { | ||||
|     mock_run_llm_query.return_value = { | ||||
|         "title": "Test Title", | ||||
|         "tags": ["test", "document"], | ||||
|         "correspondents": ["John Doe"], | ||||
|         "document_types": ["report"], | ||||
|         "storage_paths": ["Reports"], | ||||
|         "dates": ["2023-01-01"] | ||||
|         "dates": ["2023-01-01"], | ||||
|     } | ||||
|     ``` | ||||
|     """ | ||||
|  | ||||
|     result = get_ai_document_classification(mock_document) | ||||
|  | ||||
| @@ -130,22 +93,6 @@ def test_get_ai_document_classification_fallback_parse_success( | ||||
|     assert result["dates"] == ["2023-01-01"] | ||||
|  | ||||
|  | ||||
| @pytest.mark.django_db | ||||
| @patch("paperless_ai.client.AIClient.run_llm_query") | ||||
| @override_settings( | ||||
|     LLM_BACKEND="ollama", | ||||
|     LLM_MODEL="some_model", | ||||
| ) | ||||
| def test_get_ai_document_classification_parse_failure( | ||||
|     mock_run_llm_query, | ||||
|     mock_document, | ||||
| ): | ||||
|     mock_run_llm_query.return_value.text = "Invalid JSON response" | ||||
|  | ||||
|     result = get_ai_document_classification(mock_document) | ||||
|     assert result == {} | ||||
|  | ||||
|  | ||||
| @pytest.mark.django_db | ||||
| @patch("paperless_ai.client.AIClient.run_llm_query") | ||||
| def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document): | ||||
| @@ -156,15 +103,6 @@ def test_get_ai_document_classification_failure(mock_run_llm_query, mock_documen | ||||
|         get_ai_document_classification(mock_document) | ||||
|  | ||||
|  | ||||
| def test_parse_llm_classification_response_invalid_json(): | ||||
|     mock_response = MagicMock() | ||||
|     mock_response.text = "Invalid JSON response" | ||||
|  | ||||
|     result = parse_ai_response(mock_response) | ||||
|  | ||||
|     assert result == {} | ||||
|  | ||||
|  | ||||
| @pytest.mark.django_db | ||||
| @patch("paperless_ai.client.AIClient.run_llm_query") | ||||
| @patch("paperless_ai.ai_classifier.build_prompt_with_rag") | ||||
| @@ -218,10 +156,10 @@ def test_prompt_with_without_rag(mock_document): | ||||
|         return_value="Context from similar documents", | ||||
|     ): | ||||
|         prompt = build_prompt_without_rag(mock_document) | ||||
|         assert "CONTEXT FROM SIMILAR DOCUMENTS:" not in prompt | ||||
|         assert "Additional context from similar documents:" not in prompt | ||||
|  | ||||
|         prompt = build_prompt_with_rag(mock_document) | ||||
|         assert "CONTEXT FROM SIMILAR DOCUMENTS:" in prompt | ||||
|         assert "Additional context from similar documents:" in prompt | ||||
|  | ||||
|  | ||||
| @patch("paperless_ai.ai_classifier.query_similar_documents") | ||||
|   | ||||
| @@ -3,6 +3,7 @@ from unittest.mock import patch | ||||
|  | ||||
| import pytest | ||||
| from llama_index.core.llms import ChatMessage | ||||
| from llama_index.core.llms.llm import ToolSelection | ||||
|  | ||||
| from paperless_ai.client import AIClient | ||||
|  | ||||
| @@ -69,13 +70,27 @@ def test_run_llm_query(mock_ai_config, mock_ollama_llm): | ||||
|     mock_ai_config.llm_url = "http://test-url" | ||||
|  | ||||
|     mock_llm_instance = mock_ollama_llm.return_value | ||||
|     mock_llm_instance.complete.return_value = "test_result" | ||||
|  | ||||
|     tool_selection = ToolSelection( | ||||
|         tool_id="call_test", | ||||
|         tool_name="DocumentClassifierSchema", | ||||
|         tool_kwargs={ | ||||
|             "title": "Test Title", | ||||
|             "tags": ["test", "document"], | ||||
|             "correspondents": ["John Doe"], | ||||
|             "document_types": ["report"], | ||||
|             "storage_paths": ["Reports"], | ||||
|             "dates": ["2023-01-01"], | ||||
|         }, | ||||
|     ) | ||||
|  | ||||
|     mock_llm_instance.chat_with_tools.return_value = MagicMock() | ||||
|     mock_llm_instance.get_tool_calls_from_response.return_value = [tool_selection] | ||||
|  | ||||
|     client = AIClient() | ||||
|     result = client.run_llm_query("test_prompt") | ||||
|  | ||||
|     mock_llm_instance.complete.assert_called_once_with("test_prompt") | ||||
|     assert result == "test_result" | ||||
|     assert result["title"] == "Test Title" | ||||
|  | ||||
|  | ||||
| def test_run_chat(mock_ai_config, mock_ollama_llm): | ||||
|   | ||||
							
								
								
									
src/paperless_ai/tools.py (new file, 10 lines)

| @@ -0,0 +1,10 @@ | ||||
| from llama_index.core.bridge.pydantic import BaseModel | ||||
|  | ||||
|  | ||||
| class DocumentClassifierSchema(BaseModel): | ||||
|     title: str | ||||
|     tags: list[str] | ||||
|     correspondents: list[str] | ||||
|     document_types: list[str] | ||||
|     storage_paths: list[str] | ||||
|     dates: list[str] | ||||
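The new `DocumentClassifierSchema` above is the model the client binds as its function tool; because it is a pydantic model, malformed tool arguments fail validation up front rather than surfacing later as broken JSON. A brief usage sketch with the same illustrative values the tests use:

```python
from paperless_ai.tools import DocumentClassifierSchema

# Arguments as they would arrive in tool_calls[0].tool_kwargs.
kwargs = {
    "title": "Test Title",
    "tags": ["test", "document"],
    "correspondents": ["John Doe"],
    "document_types": ["report"],
    "storage_paths": ["Reports"],
    "dates": ["2023-01-01"],
}

parsed = DocumentClassifierSchema(**kwargs)
print(parsed.model_dump())

# A missing or wrongly-typed field raises pydantic's ValidationError, replacing
# the old "invalid JSON" fallback path that the removed tests exercised.
```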