Feature: Paperless AI (#10319)

This commit is contained in:
shamoon
2026-01-13 08:24:42 -08:00
committed by GitHub
parent 4347ba1f9c
commit e940764fe0
78 changed files with 5429 additions and 106 deletions

View File

@@ -11,6 +11,7 @@ class DocumentsConfig(AppConfig):
from documents.signals import document_consumption_finished
from documents.signals import document_updated
from documents.signals.handlers import add_inbox_tags
from documents.signals.handlers import add_or_update_document_in_llm_index
from documents.signals.handlers import add_to_index
from documents.signals.handlers import run_workflows_added
from documents.signals.handlers import run_workflows_updated
@@ -26,6 +27,7 @@ class DocumentsConfig(AppConfig):
document_consumption_finished.connect(set_storage_path)
document_consumption_finished.connect(add_to_index)
document_consumption_finished.connect(run_workflows_added)
document_consumption_finished.connect(add_or_update_document_in_llm_index)
document_updated.connect(run_workflows_updated)
import documents.schema # noqa: F401

View File

@@ -41,6 +41,7 @@ class SuggestionCacheData:
CLASSIFIER_VERSION_KEY: Final[str] = "classifier_version"
CLASSIFIER_HASH_KEY: Final[str] = "classifier_hash"
CLASSIFIER_MODIFIED_KEY: Final[str] = "classifier_modified"
LLM_CACHE_CLASSIFIER_VERSION: Final[int] = 1000 # Marker distinguishing LLM suggestions
CACHE_1_MINUTE: Final[int] = 60
CACHE_5_MINUTES: Final[int] = 5 * CACHE_1_MINUTE
@@ -196,6 +197,54 @@ def refresh_suggestions_cache(
cache.touch(doc_key, timeout)
def get_llm_suggestion_cache(
document_id: int,
backend: str,
) -> SuggestionCacheData | None:
doc_key = get_suggestion_cache_key(document_id)
data: SuggestionCacheData = cache.get(doc_key)
if data and data.classifier_hash == backend:
return data
return None
def set_llm_suggestions_cache(
document_id: int,
suggestions: dict,
*,
backend: str,
timeout: int = CACHE_50_MINUTES,
) -> None:
"""
Cache LLM-generated suggestions using a backend-specific identifier (e.g. 'openai:gpt-4').
"""
doc_key = get_suggestion_cache_key(document_id)
cache.set(
doc_key,
SuggestionCacheData(
classifier_version=LLM_CACHE_CLASSIFIER_VERSION,
classifier_hash=backend,
suggestions=suggestions,
),
timeout,
)
def invalidate_llm_suggestions_cache(
document_id: int,
) -> None:
"""
Invalidate the LLM suggestions cache for a specific document and backend.
"""
doc_key = get_suggestion_cache_key(document_id)
data: SuggestionCacheData = cache.get(doc_key)
if data:
cache.delete(doc_key)
def get_metadata_cache_key(document_id: int) -> str:
"""
Returns the basic key for a document's metadata

View File

@@ -0,0 +1,22 @@
from django.core.management import BaseCommand
from django.db import transaction
from documents.management.commands.mixins import ProgressBarMixin
from documents.tasks import llmindex_index
class Command(ProgressBarMixin, BaseCommand):
help = "Manages the LLM-based vector index for Paperless."
def add_arguments(self, parser):
parser.add_argument("command", choices=["rebuild", "update"])
self.add_argument_progress_bar_mixin(parser)
def handle(self, *args, **options):
self.handle_progress_bar_mixin(**options)
with transaction.atomic():
llmindex_index(
progress_bar_disable=self.no_progress_bar,
rebuild=options["command"] == "rebuild",
scheduled=False,
)

View File

@@ -0,0 +1,30 @@
# Generated by Django 5.1.8 on 2025-04-30 02:38
from django.db import migrations
from django.db import models
class Migration(migrations.Migration):
dependencies = [
("documents", "1074_workflowrun_deleted_at_workflowrun_restored_at_and_more"),
]
operations = [
migrations.AlterField(
model_name="paperlesstask",
name="task_name",
field=models.CharField(
choices=[
("consume_file", "Consume File"),
("train_classifier", "Train Classifier"),
("check_sanity", "Check Sanity"),
("index_optimize", "Index Optimize"),
("llmindex_update", "LLM Index Update"),
],
help_text="Name of the task that was run",
max_length=255,
null=True,
verbose_name="Task Name",
),
),
]

View File

@@ -598,6 +598,7 @@ class PaperlessTask(ModelWithOwner):
TRAIN_CLASSIFIER = ("train_classifier", _("Train Classifier"))
CHECK_SANITY = ("check_sanity", _("Check Sanity"))
INDEX_OPTIMIZE = ("index_optimize", _("Index Optimize"))
LLMINDEX_UPDATE = ("llmindex_update", _("LLM Index Update"))
task_id = models.CharField(
max_length=255,

View File

@@ -26,6 +26,8 @@ from filelock import FileLock
from documents import matching
from documents.caching import clear_document_caches
from documents.caching import invalidate_llm_suggestions_cache
from documents.data_models import ConsumableDocument
from documents.file_handling import create_source_path_directory
from documents.file_handling import delete_empty_directories
from documents.file_handling import generate_filename
@@ -52,6 +54,7 @@ from documents.workflows.mutations import apply_assignment_to_overrides
from documents.workflows.mutations import apply_removal_to_document
from documents.workflows.mutations import apply_removal_to_overrides
from documents.workflows.utils import get_workflows_for_trigger
from paperless.config import AIConfig
if TYPE_CHECKING:
from documents.classifier import DocumentClassifier
@@ -638,6 +641,15 @@ def cleanup_custom_field_deletion(sender, instance: CustomField, **kwargs):
)
@receiver(models.signals.post_save, sender=Document)
def update_llm_suggestions_cache(sender, instance, **kwargs):
"""
Invalidate the LLM suggestions cache when a document is saved.
"""
# Invalidate the cache for the document
invalidate_llm_suggestions_cache(instance.pk)
@receiver(models.signals.post_delete, sender=User)
@receiver(models.signals.post_delete, sender=Group)
def cleanup_user_deletion(sender, instance: User | Group, **kwargs):
@@ -944,3 +956,26 @@ def close_connection_pool_on_worker_init(**kwargs):
for conn in connections.all(initialized_only=True):
if conn.alias == "default" and hasattr(conn, "pool") and conn.pool:
conn.close_pool()
def add_or_update_document_in_llm_index(sender, document, **kwargs):
"""
Add or update a document in the LLM index when it is created or updated.
"""
ai_config = AIConfig()
if ai_config.llm_index_enabled:
from documents.tasks import update_document_in_llm_index
update_document_in_llm_index.delay(document)
@receiver(models.signals.post_delete, sender=Document)
def delete_document_from_llm_index(sender, instance: Document, **kwargs):
"""
Delete a document from the LLM index when it is deleted.
"""
ai_config = AIConfig()
if ai_config.llm_index_enabled:
from documents.tasks import remove_document_from_llm_index
remove_document_from_llm_index.delay(instance)

View File

@@ -54,6 +54,10 @@ from documents.signals import document_updated
from documents.signals.handlers import cleanup_document_deletion
from documents.signals.handlers import run_workflows
from documents.workflows.utils import get_workflows_for_trigger
from paperless.config import AIConfig
from paperless_ai.indexing import llm_index_add_or_update_document
from paperless_ai.indexing import llm_index_remove_document
from paperless_ai.indexing import update_llm_index
if settings.AUDIT_LOG_ENABLED:
from auditlog.models import LogEntry
@@ -242,6 +246,13 @@ def bulk_update_documents(document_ids):
for doc in documents:
index.update_document(writer, doc)
ai_config = AIConfig()
if ai_config.llm_index_enabled:
update_llm_index(
progress_bar_disable=True,
rebuild=False,
)
@shared_task
def update_document_content_maybe_archive_file(document_id):
@@ -341,6 +352,10 @@ def update_document_content_maybe_archive_file(document_id):
with index.open_index_writer() as writer:
index.update_document(writer, document)
ai_config = AIConfig()
if ai_config.llm_index_enabled:
llm_index_add_or_update_document(document)
clear_document_caches(document.pk)
except Exception:
@@ -558,3 +573,55 @@ def update_document_parent_tags(tag: Tag, new_parent: Tag) -> None:
if affected:
bulk_update_documents.delay(document_ids=list(affected))
@shared_task
def llmindex_index(
*,
progress_bar_disable=True,
rebuild=False,
scheduled=True,
auto=False,
):
ai_config = AIConfig()
if ai_config.llm_index_enabled:
task = PaperlessTask.objects.create(
type=PaperlessTask.TaskType.SCHEDULED_TASK
if scheduled
else PaperlessTask.TaskType.AUTO
if auto
else PaperlessTask.TaskType.MANUAL_TASK,
task_id=uuid.uuid4(),
task_name=PaperlessTask.TaskName.LLMINDEX_UPDATE,
status=states.STARTED,
date_created=timezone.now(),
date_started=timezone.now(),
)
from paperless_ai.indexing import update_llm_index
try:
result = update_llm_index(
progress_bar_disable=progress_bar_disable,
rebuild=rebuild,
)
task.status = states.SUCCESS
task.result = result
except Exception as e:
logger.error("LLM index error: " + str(e))
task.status = states.FAILURE
task.result = str(e)
task.date_done = timezone.now()
task.save(update_fields=["status", "result", "date_done"])
else:
logger.info("LLM index is disabled, skipping update.")
@shared_task
def update_document_in_llm_index(document):
llm_index_add_or_update_document(document)
@shared_task
def remove_document_from_llm_index(document):
llm_index_remove_document(document)

View File

@@ -1,6 +1,7 @@
import json
from io import BytesIO
from pathlib import Path
from unittest.mock import patch
from django.contrib.auth.models import User
from django.core.files.uploadedfile import SimpleUploadedFile
@@ -66,6 +67,13 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
"barcode_max_pages": None,
"barcode_enable_tag": None,
"barcode_tag_mapping": None,
"ai_enabled": False,
"llm_embedding_backend": None,
"llm_embedding_model": None,
"llm_backend": None,
"llm_model": None,
"llm_api_key": None,
"llm_endpoint": None,
},
)
@@ -611,3 +619,76 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
)
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
self.assertEqual(ApplicationConfiguration.objects.count(), 1)
def test_update_llm_api_key(self):
"""
GIVEN:
- Existing config with llm_api_key specified
WHEN:
- API to update llm_api_key is called with all *s
- API to update llm_api_key is called with empty string
THEN:
- llm_api_key is unchanged
- llm_api_key is set to None
"""
config = ApplicationConfiguration.objects.first()
config.llm_api_key = "1234567890"
config.save()
# Test with all *
response = self.client.patch(
f"{self.ENDPOINT}1/",
json.dumps(
{
"llm_api_key": "*" * 32,
},
),
content_type="application/json",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
config.refresh_from_db()
self.assertEqual(config.llm_api_key, "1234567890")
# Test with empty string
response = self.client.patch(
f"{self.ENDPOINT}1/",
json.dumps(
{
"llm_api_key": "",
},
),
content_type="application/json",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
config.refresh_from_db()
self.assertEqual(config.llm_api_key, None)
def test_enable_ai_index_triggers_update(self):
"""
GIVEN:
- Existing config with AI disabled
WHEN:
- Config is updated to enable AI with llm_embedding_backend
THEN:
- LLM index is triggered to update
"""
config = ApplicationConfiguration.objects.first()
config.ai_enabled = False
config.llm_embedding_backend = None
config.save()
with (
patch("documents.tasks.llmindex_index.delay") as mock_update,
patch("paperless_ai.indexing.vector_store_file_exists") as mock_exists,
):
mock_exists.return_value = False
self.client.patch(
f"{self.ENDPOINT}1/",
json.dumps(
{
"ai_enabled": True,
"llm_embedding_backend": "openai",
},
),
content_type="application/json",
)
mock_update.assert_called_once()

View File

@@ -310,3 +310,69 @@ class TestSystemStatus(APITestCase):
"ERROR",
)
self.assertIsNotNone(response.data["tasks"]["sanity_check_error"])
def test_system_status_ai_disabled(self):
"""
GIVEN:
- The AI feature is disabled
WHEN:
- The user requests the system status
THEN:
- The response contains the correct AI status
"""
with override_settings(AI_ENABLED=False):
self.client.force_login(self.user)
response = self.client.get(self.ENDPOINT)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["tasks"]["llmindex_status"], "DISABLED")
self.assertIsNone(response.data["tasks"]["llmindex_error"])
def test_system_status_ai_enabled(self):
"""
GIVEN:
- The AI index feature is enabled, but no tasks are found
- The AI index feature is enabled and a task is found
WHEN:
- The user requests the system status
THEN:
- The response contains the correct AI status
"""
with override_settings(AI_ENABLED=True, LLM_EMBEDDING_BACKEND="openai"):
self.client.force_login(self.user)
# No tasks found
response = self.client.get(self.ENDPOINT)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["tasks"]["llmindex_status"], "WARNING")
PaperlessTask.objects.create(
type=PaperlessTask.TaskType.SCHEDULED_TASK,
status=states.SUCCESS,
task_name=PaperlessTask.TaskName.LLMINDEX_UPDATE,
)
response = self.client.get(self.ENDPOINT)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["tasks"]["llmindex_status"], "OK")
self.assertIsNone(response.data["tasks"]["llmindex_error"])
def test_system_status_ai_error(self):
"""
GIVEN:
- The AI index feature is enabled and a task is found with an error
WHEN:
- The user requests the system status
THEN:
- The response contains the correct AI status
"""
with override_settings(AI_ENABLED=True, LLM_EMBEDDING_BACKEND="openai"):
PaperlessTask.objects.create(
type=PaperlessTask.TaskType.SCHEDULED_TASK,
status=states.FAILURE,
task_name=PaperlessTask.TaskName.LLMINDEX_UPDATE,
result="AI index update failed",
)
self.client.force_login(self.user)
response = self.client.get(self.ENDPOINT)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["tasks"]["llmindex_status"], "ERROR")
self.assertIsNotNone(response.data["tasks"]["llmindex_error"])

View File

@@ -49,6 +49,7 @@ class TestApiUiSettings(DirectoriesMixin, APITestCase):
"backend_setting": "default",
},
"email_enabled": False,
"ai_enabled": False,
},
)

View File

@@ -3,14 +3,17 @@ from datetime import timedelta
from pathlib import Path
from unittest import mock
from celery import states
from django.conf import settings
from django.test import TestCase
from django.test import override_settings
from django.utils import timezone
from documents import tasks
from documents.models import Correspondent
from documents.models import Document
from documents.models import DocumentType
from documents.models import PaperlessTask
from documents.models import Tag
from documents.sanity_checker import SanityCheckFailedException
from documents.sanity_checker import SanityCheckMessages
@@ -270,3 +273,103 @@ class TestUpdateContent(DirectoriesMixin, TestCase):
tasks.update_document_content_maybe_archive_file(doc.pk)
self.assertNotEqual(Document.objects.get(pk=doc.pk).content, "test")
class TestAIIndex(DirectoriesMixin, TestCase):
@override_settings(
AI_ENABLED=True,
LLM_EMBEDDING_BACKEND="huggingface",
)
def test_ai_index_success(self):
"""
GIVEN:
- Document exists, AI is enabled, llm index backend is set
WHEN:
- llmindex_index task is called
THEN:
- update_llm_index is called, and the task is marked as success
"""
Document.objects.create(
title="test",
content="my document",
checksum="wow",
)
# lazy-loaded so mock the actual function
with mock.patch("paperless_ai.indexing.update_llm_index") as update_llm_index:
update_llm_index.return_value = "LLM index updated successfully."
tasks.llmindex_index()
update_llm_index.assert_called_once()
task = PaperlessTask.objects.get(
task_name=PaperlessTask.TaskName.LLMINDEX_UPDATE,
)
self.assertEqual(task.status, states.SUCCESS)
self.assertEqual(task.result, "LLM index updated successfully.")
@override_settings(
AI_ENABLED=True,
LLM_EMBEDDING_BACKEND="huggingface",
)
def test_ai_index_failure(self):
"""
GIVEN:
- Document exists, AI is enabled, llm index backend is set
WHEN:
- llmindex_index task is called
THEN:
- update_llm_index raises an exception, and the task is marked as failure
"""
Document.objects.create(
title="test",
content="my document",
checksum="wow",
)
# lazy-loaded so mock the actual function
with mock.patch("paperless_ai.indexing.update_llm_index") as update_llm_index:
update_llm_index.side_effect = Exception("LLM index update failed.")
tasks.llmindex_index()
update_llm_index.assert_called_once()
task = PaperlessTask.objects.get(
task_name=PaperlessTask.TaskName.LLMINDEX_UPDATE,
)
self.assertEqual(task.status, states.FAILURE)
self.assertIn("LLM index update failed.", task.result)
def test_update_document_in_llm_index(self):
"""
GIVEN:
- Nothing
WHEN:
- update_document_in_llm_index task is called
THEN:
- llm_index_add_or_update_document is called
"""
doc = Document.objects.create(
title="test",
content="my document",
checksum="wow",
)
with mock.patch(
"documents.tasks.llm_index_add_or_update_document",
) as llm_index_add_or_update_document:
tasks.update_document_in_llm_index(doc)
llm_index_add_or_update_document.assert_called_once_with(doc)
def test_remove_document_from_llm_index(self):
"""
GIVEN:
- Nothing
WHEN:
- remove_document_from_llm_index task is called
THEN:
- llm_index_remove_document is called
"""
doc = Document.objects.create(
title="test",
content="my document",
checksum="wow",
)
with mock.patch(
"documents.tasks.llm_index_remove_document",
) as llm_index_remove_document:
tasks.remove_document_from_llm_index(doc)
llm_index_remove_document.assert_called_once_with(doc)

View File

@@ -2,6 +2,8 @@ import json
import tempfile
from datetime import timedelta
from pathlib import Path
from unittest.mock import MagicMock
from unittest.mock import patch
from django.conf import settings
from django.contrib.auth.models import Group
@@ -15,9 +17,15 @@ from django.utils import timezone
from guardian.shortcuts import assign_perm
from rest_framework import status
from documents.caching import get_llm_suggestion_cache
from documents.caching import set_llm_suggestions_cache
from documents.models import Correspondent
from documents.models import Document
from documents.models import DocumentType
from documents.models import ShareLink
from documents.models import StoragePath
from documents.models import Tag
from documents.signals.handlers import update_llm_suggestions_cache
from documents.tests.utils import DirectoriesMixin
from paperless.models import ApplicationConfiguration
@@ -270,3 +278,176 @@ class TestViews(DirectoriesMixin, TestCase):
f"Possible N+1 queries detected: {num_queries_small} queries for 2 tags, "
f"but {num_queries_large} queries for 50 tags"
)
class TestAISuggestions(DirectoriesMixin, TestCase):
def setUp(self):
self.user = User.objects.create_superuser(username="testuser")
self.document = Document.objects.create(
title="Test Document",
filename="test.pdf",
mime_type="application/pdf",
)
self.tag1 = Tag.objects.create(name="tag1")
self.correspondent1 = Correspondent.objects.create(name="correspondent1")
self.document_type1 = DocumentType.objects.create(name="type1")
self.path1 = StoragePath.objects.create(name="path1")
super().setUp()
@patch("documents.views.get_llm_suggestion_cache")
@patch("documents.views.refresh_suggestions_cache")
@override_settings(
AI_ENABLED=True,
LLM_BACKEND="mock_backend",
)
def test_suggestions_with_cached_llm(self, mock_refresh_cache, mock_get_cache):
mock_get_cache.return_value = MagicMock(suggestions={"tags": ["tag1", "tag2"]})
self.client.force_login(user=self.user)
response = self.client.get(f"/api/documents/{self.document.pk}/suggestions/")
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.json(), {"tags": ["tag1", "tag2"]})
mock_refresh_cache.assert_called_once_with(self.document.pk)
@patch("documents.views.get_ai_document_classification")
@override_settings(
AI_ENABLED=True,
LLM_BACKEND="mock_backend",
)
def test_suggestions_with_ai_enabled(
self,
mock_get_ai_classification,
):
mock_get_ai_classification.return_value = {
"title": "AI Title",
"tags": ["tag1", "tag2"],
"correspondents": ["correspondent1"],
"document_types": ["type1"],
"storage_paths": ["path1"],
"dates": ["2023-01-01"],
}
self.client.force_login(user=self.user)
response = self.client.get(f"/api/documents/{self.document.pk}/suggestions/")
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(
response.json(),
{
"title": "AI Title",
"tags": [self.tag1.pk],
"suggested_tags": ["tag2"],
"correspondents": [self.correspondent1.pk],
"suggested_correspondents": [],
"document_types": [self.document_type1.pk],
"suggested_document_types": [],
"storage_paths": [self.path1.pk],
"suggested_storage_paths": [],
"dates": ["2023-01-01"],
},
)
def test_invalidate_suggestions_cache(self):
self.client.force_login(user=self.user)
suggestions = {
"title": "AI Title",
"tags": ["tag1", "tag2"],
"correspondents": ["correspondent1"],
"document_types": ["type1"],
"storage_paths": ["path1"],
"dates": ["2023-01-01"],
}
set_llm_suggestions_cache(
self.document.pk,
suggestions,
backend="mock_backend",
)
self.assertEqual(
get_llm_suggestion_cache(
self.document.pk,
backend="mock_backend",
).suggestions,
suggestions,
)
# post_save signal triggered
update_llm_suggestions_cache(
sender=None,
instance=self.document,
)
self.assertIsNone(
get_llm_suggestion_cache(
self.document.pk,
backend="mock_backend",
),
)
class TestAIChatStreamingView(DirectoriesMixin, TestCase):
ENDPOINT = "/api/documents/chat/"
def setUp(self):
self.user = User.objects.create_user(username="testuser", password="pass")
self.client.force_login(user=self.user)
self.document = Document.objects.create(
title="Test Document",
filename="test.pdf",
mime_type="application/pdf",
)
super().setUp()
@override_settings(AI_ENABLED=False)
def test_post_ai_disabled(self):
response = self.client.post(
self.ENDPOINT,
data='{"q": "question"}',
content_type="application/json",
)
self.assertEqual(response.status_code, 400)
self.assertIn(b"AI is required for this feature", response.content)
@patch("documents.views.stream_chat_with_documents")
@patch("documents.views.get_objects_for_user_owner_aware")
@override_settings(AI_ENABLED=True)
def test_post_no_document_id(self, mock_get_objects, mock_stream_chat):
mock_get_objects.return_value = [self.document]
mock_stream_chat.return_value = iter([b"data"])
response = self.client.post(
self.ENDPOINT,
data='{"q": "question"}',
content_type="application/json",
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response["Content-Type"], "text/event-stream")
@patch("documents.views.stream_chat_with_documents")
@override_settings(AI_ENABLED=True)
def test_post_with_document_id(self, mock_stream_chat):
mock_stream_chat.return_value = iter([b"data"])
response = self.client.post(
self.ENDPOINT,
data=f'{{"q": "question", "document_id": {self.document.pk}}}',
content_type="application/json",
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response["Content-Type"], "text/event-stream")
@override_settings(AI_ENABLED=True)
def test_post_with_invalid_document_id(self):
response = self.client.post(
self.ENDPOINT,
data='{"q": "question", "document_id": 999999}',
content_type="application/json",
)
self.assertEqual(response.status_code, 400)
self.assertIn(b"Document not found", response.content)
@patch("documents.views.has_perms_owner_aware")
@override_settings(AI_ENABLED=True)
def test_post_with_document_id_no_permission(self, mock_has_perms):
mock_has_perms.return_value = False
response = self.client.post(
self.ENDPOINT,
data=f'{{"q": "question", "document_id": {self.document.pk}}}',
content_type="application/json",
)
self.assertEqual(response.status_code, 403)
self.assertIn(b"Insufficient permissions", response.content)

View File

@@ -45,6 +45,7 @@ from django.http import HttpResponseBadRequest
from django.http import HttpResponseForbidden
from django.http import HttpResponseRedirect
from django.http import HttpResponseServerError
from django.http import StreamingHttpResponse
from django.shortcuts import get_object_or_404
from django.utils import timezone
from django.utils.decorators import method_decorator
@@ -52,6 +53,7 @@ from django.utils.timezone import make_aware
from django.utils.translation import get_language
from django.views import View
from django.views.decorators.cache import cache_control
from django.views.decorators.csrf import ensure_csrf_cookie
from django.views.decorators.http import condition
from django.views.decorators.http import last_modified
from django.views.generic import TemplateView
@@ -91,10 +93,12 @@ from documents import index
from documents.bulk_download import ArchiveOnlyStrategy
from documents.bulk_download import OriginalAndArchiveStrategy
from documents.bulk_download import OriginalsOnlyStrategy
from documents.caching import get_llm_suggestion_cache
from documents.caching import get_metadata_cache
from documents.caching import get_suggestion_cache
from documents.caching import refresh_metadata_cache
from documents.caching import refresh_suggestions_cache
from documents.caching import set_llm_suggestions_cache
from documents.caching import set_metadata_cache
from documents.caching import set_suggestions_cache
from documents.classifier import load_classifier
@@ -182,18 +186,27 @@ from documents.signals import document_updated
from documents.tasks import consume_file
from documents.tasks import empty_trash
from documents.tasks import index_optimize
from documents.tasks import llmindex_index
from documents.tasks import sanity_check
from documents.tasks import train_classifier
from documents.tasks import update_document_parent_tags
from documents.utils import get_boolean
from paperless import version
from paperless.celery import app as celery_app
from paperless.config import AIConfig
from paperless.config import GeneralConfig
from paperless.db import GnuPG
from paperless.models import ApplicationConfiguration
from paperless.serialisers import GroupSerializer
from paperless.serialisers import UserSerializer
from paperless.views import StandardPagination
from paperless_ai.ai_classifier import get_ai_document_classification
from paperless_ai.chat import stream_chat_with_documents
from paperless_ai.matching import extract_unmatched_names
from paperless_ai.matching import match_correspondents_by_name
from paperless_ai.matching import match_document_types_by_name
from paperless_ai.matching import match_storage_paths_by_name
from paperless_ai.matching import match_tags_by_name
from paperless_mail.models import MailAccount
from paperless_mail.models import MailRule
from paperless_mail.oauth import PaperlessMailOAuth2Manager
@@ -934,37 +947,103 @@ class DocumentViewSet(
):
return HttpResponseForbidden("Insufficient permissions")
document_suggestions = get_suggestion_cache(doc.pk)
ai_config = AIConfig()
if document_suggestions is not None:
refresh_suggestions_cache(doc.pk)
return Response(document_suggestions.suggestions)
classifier = load_classifier()
dates = []
if settings.NUMBER_OF_SUGGESTED_DATES > 0:
gen = parse_date_generator(doc.filename, doc.content)
dates = sorted(
{i for i in itertools.islice(gen, settings.NUMBER_OF_SUGGESTED_DATES)},
if ai_config.ai_enabled:
cached_llm_suggestions = get_llm_suggestion_cache(
doc.pk,
backend=ai_config.llm_backend,
)
resp_data = {
"correspondents": [
c.id for c in match_correspondents(doc, classifier, request.user)
],
"tags": [t.id for t in match_tags(doc, classifier, request.user)],
"document_types": [
dt.id for dt in match_document_types(doc, classifier, request.user)
],
"storage_paths": [
dt.id for dt in match_storage_paths(doc, classifier, request.user)
],
"dates": [date.strftime("%Y-%m-%d") for date in dates if date is not None],
}
if cached_llm_suggestions:
refresh_suggestions_cache(doc.pk)
return Response(cached_llm_suggestions.suggestions)
# Cache the suggestions and the classifier hash for later
set_suggestions_cache(doc.pk, resp_data, classifier)
llm_suggestions = get_ai_document_classification(doc, request.user)
matched_tags = match_tags_by_name(
llm_suggestions.get("tags", []),
request.user,
)
matched_correspondents = match_correspondents_by_name(
llm_suggestions.get("correspondents", []),
request.user,
)
matched_types = match_document_types_by_name(
llm_suggestions.get("document_types", []),
request.user,
)
matched_paths = match_storage_paths_by_name(
llm_suggestions.get("storage_paths", []),
request.user,
)
resp_data = {
"title": llm_suggestions.get("title"),
"tags": [t.id for t in matched_tags],
"suggested_tags": extract_unmatched_names(
llm_suggestions.get("tags", []),
matched_tags,
),
"correspondents": [c.id for c in matched_correspondents],
"suggested_correspondents": extract_unmatched_names(
llm_suggestions.get("correspondents", []),
matched_correspondents,
),
"document_types": [d.id for d in matched_types],
"suggested_document_types": extract_unmatched_names(
llm_suggestions.get("document_types", []),
matched_types,
),
"storage_paths": [s.id for s in matched_paths],
"suggested_storage_paths": extract_unmatched_names(
llm_suggestions.get("storage_paths", []),
matched_paths,
),
"dates": llm_suggestions.get("dates", []),
}
set_llm_suggestions_cache(doc.pk, resp_data, backend=ai_config.llm_backend)
else:
document_suggestions = get_suggestion_cache(doc.pk)
if document_suggestions is not None:
refresh_suggestions_cache(doc.pk)
return Response(document_suggestions.suggestions)
classifier = load_classifier()
dates = []
if settings.NUMBER_OF_SUGGESTED_DATES > 0:
gen = parse_date_generator(doc.filename, doc.content)
dates = sorted(
{
i
for i in itertools.islice(
gen,
settings.NUMBER_OF_SUGGESTED_DATES,
)
},
)
resp_data = {
"correspondents": [
c.id for c in match_correspondents(doc, classifier, request.user)
],
"tags": [t.id for t in match_tags(doc, classifier, request.user)],
"document_types": [
dt.id for dt in match_document_types(doc, classifier, request.user)
],
"storage_paths": [
dt.id for dt in match_storage_paths(doc, classifier, request.user)
],
"dates": [
date.strftime("%Y-%m-%d") for date in dates if date is not None
],
}
# Cache the suggestions and the classifier hash for later
set_suggestions_cache(doc.pk, resp_data, classifier)
return Response(resp_data)
@@ -1288,6 +1367,59 @@ class DocumentViewSet(
)
class ChatStreamingSerializer(serializers.Serializer):
q = serializers.CharField(required=True)
document_id = serializers.IntegerField(required=False, allow_null=True)
@method_decorator(
[
ensure_csrf_cookie,
cache_control(no_cache=True),
],
name="dispatch",
)
class ChatStreamingView(GenericAPIView):
permission_classes = (IsAuthenticated,)
serializer_class = ChatStreamingSerializer
def post(self, request, *args, **kwargs):
request.compress_exempt = True
ai_config = AIConfig()
if not ai_config.ai_enabled:
return HttpResponseBadRequest("AI is required for this feature")
try:
question = request.data["q"]
except KeyError:
return HttpResponseBadRequest("Invalid request")
doc_id = request.data.get("document_id")
if doc_id:
try:
document = Document.objects.get(id=doc_id)
except Document.DoesNotExist:
return HttpResponseBadRequest("Document not found")
if not has_perms_owner_aware(request.user, "view_document", document):
return HttpResponseForbidden("Insufficient permissions")
documents = [document]
else:
documents = get_objects_for_user_owner_aware(
request.user,
"view_document",
Document,
)
response = StreamingHttpResponse(
stream_chat_with_documents(query_str=question, documents=documents),
content_type="text/event-stream",
)
return response
@extend_schema_view(
list=extend_schema(
description="Document views including search",
@@ -2446,6 +2578,10 @@ class UiSettingsView(GenericAPIView):
ui_settings["email_enabled"] = settings.EMAIL_ENABLED
ai_config = AIConfig()
ui_settings["ai_enabled"] = ai_config.ai_enabled
user_resp = {
"id": user.id,
"username": user.username,
@@ -2587,6 +2723,10 @@ class TasksViewSet(ReadOnlyModelViewSet):
sanity_check,
{"scheduled": False, "raise_on_error": False},
),
PaperlessTask.TaskName.LLMINDEX_UPDATE: (
llmindex_index,
{"scheduled": False, "rebuild": False},
),
}
def get_queryset(self):
@@ -3106,6 +3246,31 @@ class SystemStatusView(PassUserMixin):
last_sanity_check.date_done if last_sanity_check else None
)
ai_config = AIConfig()
if not ai_config.llm_index_enabled:
llmindex_status = "DISABLED"
llmindex_error = None
llmindex_last_modified = None
else:
last_llmindex_update = (
PaperlessTask.objects.filter(
task_name=PaperlessTask.TaskName.LLMINDEX_UPDATE,
)
.order_by("-date_done")
.first()
)
llmindex_status = "OK"
llmindex_error = None
if last_llmindex_update is None:
llmindex_status = "WARNING"
llmindex_error = "No LLM index update tasks found"
elif last_llmindex_update and last_llmindex_update.status == states.FAILURE:
llmindex_status = "ERROR"
llmindex_error = last_llmindex_update.result
llmindex_last_modified = (
last_llmindex_update.date_done if last_llmindex_update else None
)
return Response(
{
"pngx_version": current_version,
@@ -3143,6 +3308,9 @@ class SystemStatusView(PassUserMixin):
"sanity_check_status": sanity_check_status,
"sanity_check_last_run": sanity_check_last_run,
"sanity_check_error": sanity_check_error,
"llmindex_status": llmindex_status,
"llmindex_last_modified": llmindex_last_modified,
"llmindex_error": llmindex_error,
},
},
)