Feature: Paperless AI (#10319)

This commit is contained in:
shamoon
2026-01-13 08:24:42 -08:00
committed by GitHub
parent 4347ba1f9c
commit e940764fe0
78 changed files with 5429 additions and 106 deletions

View File

@@ -1,6 +1,7 @@
import json
from io import BytesIO
from pathlib import Path
from unittest.mock import patch
from django.contrib.auth.models import User
from django.core.files.uploadedfile import SimpleUploadedFile
@@ -66,6 +67,13 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
"barcode_max_pages": None,
"barcode_enable_tag": None,
"barcode_tag_mapping": None,
"ai_enabled": False,
"llm_embedding_backend": None,
"llm_embedding_model": None,
"llm_backend": None,
"llm_model": None,
"llm_api_key": None,
"llm_endpoint": None,
},
)
@@ -611,3 +619,76 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
)
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
self.assertEqual(ApplicationConfiguration.objects.count(), 1)
def test_update_llm_api_key(self):
"""
GIVEN:
- Existing config with llm_api_key specified
WHEN:
- API to update llm_api_key is called with all *s
- API to update llm_api_key is called with empty string
THEN:
- llm_api_key is unchanged
- llm_api_key is set to None
"""
config = ApplicationConfiguration.objects.first()
config.llm_api_key = "1234567890"
config.save()
# Test with all *
response = self.client.patch(
f"{self.ENDPOINT}1/",
json.dumps(
{
"llm_api_key": "*" * 32,
},
),
content_type="application/json",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
config.refresh_from_db()
self.assertEqual(config.llm_api_key, "1234567890")
# Test with empty string
response = self.client.patch(
f"{self.ENDPOINT}1/",
json.dumps(
{
"llm_api_key": "",
},
),
content_type="application/json",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
config.refresh_from_db()
self.assertEqual(config.llm_api_key, None)
def test_enable_ai_index_triggers_update(self):
"""
GIVEN:
- Existing config with AI disabled
WHEN:
- Config is updated to enable AI with llm_embedding_backend
THEN:
- LLM index is triggered to update
"""
config = ApplicationConfiguration.objects.first()
config.ai_enabled = False
config.llm_embedding_backend = None
config.save()
with (
patch("documents.tasks.llmindex_index.delay") as mock_update,
patch("paperless_ai.indexing.vector_store_file_exists") as mock_exists,
):
mock_exists.return_value = False
self.client.patch(
f"{self.ENDPOINT}1/",
json.dumps(
{
"ai_enabled": True,
"llm_embedding_backend": "openai",
},
),
content_type="application/json",
)
mock_update.assert_called_once()

View File

@@ -310,3 +310,69 @@ class TestSystemStatus(APITestCase):
"ERROR",
)
self.assertIsNotNone(response.data["tasks"]["sanity_check_error"])
def test_system_status_ai_disabled(self):
"""
GIVEN:
- The AI feature is disabled
WHEN:
- The user requests the system status
THEN:
- The response contains the correct AI status
"""
with override_settings(AI_ENABLED=False):
self.client.force_login(self.user)
response = self.client.get(self.ENDPOINT)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["tasks"]["llmindex_status"], "DISABLED")
self.assertIsNone(response.data["tasks"]["llmindex_error"])
def test_system_status_ai_enabled(self):
"""
GIVEN:
- The AI index feature is enabled, but no tasks are found
- The AI index feature is enabled and a task is found
WHEN:
- The user requests the system status
THEN:
- The response contains the correct AI status
"""
with override_settings(AI_ENABLED=True, LLM_EMBEDDING_BACKEND="openai"):
self.client.force_login(self.user)
# No tasks found
response = self.client.get(self.ENDPOINT)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["tasks"]["llmindex_status"], "WARNING")
PaperlessTask.objects.create(
type=PaperlessTask.TaskType.SCHEDULED_TASK,
status=states.SUCCESS,
task_name=PaperlessTask.TaskName.LLMINDEX_UPDATE,
)
response = self.client.get(self.ENDPOINT)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["tasks"]["llmindex_status"], "OK")
self.assertIsNone(response.data["tasks"]["llmindex_error"])
def test_system_status_ai_error(self):
"""
GIVEN:
- The AI index feature is enabled and a task is found with an error
WHEN:
- The user requests the system status
THEN:
- The response contains the correct AI status
"""
with override_settings(AI_ENABLED=True, LLM_EMBEDDING_BACKEND="openai"):
PaperlessTask.objects.create(
type=PaperlessTask.TaskType.SCHEDULED_TASK,
status=states.FAILURE,
task_name=PaperlessTask.TaskName.LLMINDEX_UPDATE,
result="AI index update failed",
)
self.client.force_login(self.user)
response = self.client.get(self.ENDPOINT)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["tasks"]["llmindex_status"], "ERROR")
self.assertIsNotNone(response.data["tasks"]["llmindex_error"])

View File

@@ -49,6 +49,7 @@ class TestApiUiSettings(DirectoriesMixin, APITestCase):
"backend_setting": "default",
},
"email_enabled": False,
"ai_enabled": False,
},
)

View File

@@ -3,14 +3,17 @@ from datetime import timedelta
from pathlib import Path
from unittest import mock
from celery import states
from django.conf import settings
from django.test import TestCase
from django.test import override_settings
from django.utils import timezone
from documents import tasks
from documents.models import Correspondent
from documents.models import Document
from documents.models import DocumentType
from documents.models import PaperlessTask
from documents.models import Tag
from documents.sanity_checker import SanityCheckFailedException
from documents.sanity_checker import SanityCheckMessages
@@ -270,3 +273,103 @@ class TestUpdateContent(DirectoriesMixin, TestCase):
tasks.update_document_content_maybe_archive_file(doc.pk)
self.assertNotEqual(Document.objects.get(pk=doc.pk).content, "test")
class TestAIIndex(DirectoriesMixin, TestCase):
@override_settings(
AI_ENABLED=True,
LLM_EMBEDDING_BACKEND="huggingface",
)
def test_ai_index_success(self):
"""
GIVEN:
- Document exists, AI is enabled, llm index backend is set
WHEN:
- llmindex_index task is called
THEN:
- update_llm_index is called, and the task is marked as success
"""
Document.objects.create(
title="test",
content="my document",
checksum="wow",
)
# lazy-loaded so mock the actual function
with mock.patch("paperless_ai.indexing.update_llm_index") as update_llm_index:
update_llm_index.return_value = "LLM index updated successfully."
tasks.llmindex_index()
update_llm_index.assert_called_once()
task = PaperlessTask.objects.get(
task_name=PaperlessTask.TaskName.LLMINDEX_UPDATE,
)
self.assertEqual(task.status, states.SUCCESS)
self.assertEqual(task.result, "LLM index updated successfully.")
@override_settings(
AI_ENABLED=True,
LLM_EMBEDDING_BACKEND="huggingface",
)
def test_ai_index_failure(self):
"""
GIVEN:
- Document exists, AI is enabled, llm index backend is set
WHEN:
- llmindex_index task is called
THEN:
- update_llm_index raises an exception, and the task is marked as failure
"""
Document.objects.create(
title="test",
content="my document",
checksum="wow",
)
# lazy-loaded so mock the actual function
with mock.patch("paperless_ai.indexing.update_llm_index") as update_llm_index:
update_llm_index.side_effect = Exception("LLM index update failed.")
tasks.llmindex_index()
update_llm_index.assert_called_once()
task = PaperlessTask.objects.get(
task_name=PaperlessTask.TaskName.LLMINDEX_UPDATE,
)
self.assertEqual(task.status, states.FAILURE)
self.assertIn("LLM index update failed.", task.result)
def test_update_document_in_llm_index(self):
"""
GIVEN:
- Nothing
WHEN:
- update_document_in_llm_index task is called
THEN:
- llm_index_add_or_update_document is called
"""
doc = Document.objects.create(
title="test",
content="my document",
checksum="wow",
)
with mock.patch(
"documents.tasks.llm_index_add_or_update_document",
) as llm_index_add_or_update_document:
tasks.update_document_in_llm_index(doc)
llm_index_add_or_update_document.assert_called_once_with(doc)
def test_remove_document_from_llm_index(self):
"""
GIVEN:
- Nothing
WHEN:
- remove_document_from_llm_index task is called
THEN:
- llm_index_remove_document is called
"""
doc = Document.objects.create(
title="test",
content="my document",
checksum="wow",
)
with mock.patch(
"documents.tasks.llm_index_remove_document",
) as llm_index_remove_document:
tasks.remove_document_from_llm_index(doc)
llm_index_remove_document.assert_called_once_with(doc)

View File

@@ -2,6 +2,8 @@ import json
import tempfile
from datetime import timedelta
from pathlib import Path
from unittest.mock import MagicMock
from unittest.mock import patch
from django.conf import settings
from django.contrib.auth.models import Group
@@ -15,9 +17,15 @@ from django.utils import timezone
from guardian.shortcuts import assign_perm
from rest_framework import status
from documents.caching import get_llm_suggestion_cache
from documents.caching import set_llm_suggestions_cache
from documents.models import Correspondent
from documents.models import Document
from documents.models import DocumentType
from documents.models import ShareLink
from documents.models import StoragePath
from documents.models import Tag
from documents.signals.handlers import update_llm_suggestions_cache
from documents.tests.utils import DirectoriesMixin
from paperless.models import ApplicationConfiguration
@@ -270,3 +278,176 @@ class TestViews(DirectoriesMixin, TestCase):
f"Possible N+1 queries detected: {num_queries_small} queries for 2 tags, "
f"but {num_queries_large} queries for 50 tags"
)
class TestAISuggestions(DirectoriesMixin, TestCase):
def setUp(self):
self.user = User.objects.create_superuser(username="testuser")
self.document = Document.objects.create(
title="Test Document",
filename="test.pdf",
mime_type="application/pdf",
)
self.tag1 = Tag.objects.create(name="tag1")
self.correspondent1 = Correspondent.objects.create(name="correspondent1")
self.document_type1 = DocumentType.objects.create(name="type1")
self.path1 = StoragePath.objects.create(name="path1")
super().setUp()
@patch("documents.views.get_llm_suggestion_cache")
@patch("documents.views.refresh_suggestions_cache")
@override_settings(
AI_ENABLED=True,
LLM_BACKEND="mock_backend",
)
def test_suggestions_with_cached_llm(self, mock_refresh_cache, mock_get_cache):
mock_get_cache.return_value = MagicMock(suggestions={"tags": ["tag1", "tag2"]})
self.client.force_login(user=self.user)
response = self.client.get(f"/api/documents/{self.document.pk}/suggestions/")
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.json(), {"tags": ["tag1", "tag2"]})
mock_refresh_cache.assert_called_once_with(self.document.pk)
@patch("documents.views.get_ai_document_classification")
@override_settings(
AI_ENABLED=True,
LLM_BACKEND="mock_backend",
)
def test_suggestions_with_ai_enabled(
self,
mock_get_ai_classification,
):
mock_get_ai_classification.return_value = {
"title": "AI Title",
"tags": ["tag1", "tag2"],
"correspondents": ["correspondent1"],
"document_types": ["type1"],
"storage_paths": ["path1"],
"dates": ["2023-01-01"],
}
self.client.force_login(user=self.user)
response = self.client.get(f"/api/documents/{self.document.pk}/suggestions/")
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(
response.json(),
{
"title": "AI Title",
"tags": [self.tag1.pk],
"suggested_tags": ["tag2"],
"correspondents": [self.correspondent1.pk],
"suggested_correspondents": [],
"document_types": [self.document_type1.pk],
"suggested_document_types": [],
"storage_paths": [self.path1.pk],
"suggested_storage_paths": [],
"dates": ["2023-01-01"],
},
)
def test_invalidate_suggestions_cache(self):
self.client.force_login(user=self.user)
suggestions = {
"title": "AI Title",
"tags": ["tag1", "tag2"],
"correspondents": ["correspondent1"],
"document_types": ["type1"],
"storage_paths": ["path1"],
"dates": ["2023-01-01"],
}
set_llm_suggestions_cache(
self.document.pk,
suggestions,
backend="mock_backend",
)
self.assertEqual(
get_llm_suggestion_cache(
self.document.pk,
backend="mock_backend",
).suggestions,
suggestions,
)
# post_save signal triggered
update_llm_suggestions_cache(
sender=None,
instance=self.document,
)
self.assertIsNone(
get_llm_suggestion_cache(
self.document.pk,
backend="mock_backend",
),
)
class TestAIChatStreamingView(DirectoriesMixin, TestCase):
ENDPOINT = "/api/documents/chat/"
def setUp(self):
self.user = User.objects.create_user(username="testuser", password="pass")
self.client.force_login(user=self.user)
self.document = Document.objects.create(
title="Test Document",
filename="test.pdf",
mime_type="application/pdf",
)
super().setUp()
@override_settings(AI_ENABLED=False)
def test_post_ai_disabled(self):
response = self.client.post(
self.ENDPOINT,
data='{"q": "question"}',
content_type="application/json",
)
self.assertEqual(response.status_code, 400)
self.assertIn(b"AI is required for this feature", response.content)
@patch("documents.views.stream_chat_with_documents")
@patch("documents.views.get_objects_for_user_owner_aware")
@override_settings(AI_ENABLED=True)
def test_post_no_document_id(self, mock_get_objects, mock_stream_chat):
mock_get_objects.return_value = [self.document]
mock_stream_chat.return_value = iter([b"data"])
response = self.client.post(
self.ENDPOINT,
data='{"q": "question"}',
content_type="application/json",
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response["Content-Type"], "text/event-stream")
@patch("documents.views.stream_chat_with_documents")
@override_settings(AI_ENABLED=True)
def test_post_with_document_id(self, mock_stream_chat):
mock_stream_chat.return_value = iter([b"data"])
response = self.client.post(
self.ENDPOINT,
data=f'{{"q": "question", "document_id": {self.document.pk}}}',
content_type="application/json",
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response["Content-Type"], "text/event-stream")
@override_settings(AI_ENABLED=True)
def test_post_with_invalid_document_id(self):
response = self.client.post(
self.ENDPOINT,
data='{"q": "question", "document_id": 999999}',
content_type="application/json",
)
self.assertEqual(response.status_code, 400)
self.assertIn(b"Document not found", response.content)
@patch("documents.views.has_perms_owner_aware")
@override_settings(AI_ENABLED=True)
def test_post_with_document_id_no_permission(self, mock_has_perms):
mock_has_perms.return_value = False
response = self.client.post(
self.ENDPOINT,
data=f'{{"q": "question", "document_id": {self.document.pk}}}',
content_type="application/json",
)
self.assertEqual(response.status_code, 403)
self.assertIn(b"Insufficient permissions", response.content)