Feature: Paperless AI (#10319)
@@ -11,6 +11,7 @@ class DocumentsConfig(AppConfig):
        from documents.signals import document_consumption_finished
        from documents.signals import document_updated
        from documents.signals.handlers import add_inbox_tags
        from documents.signals.handlers import add_or_update_document_in_llm_index
        from documents.signals.handlers import add_to_index
        from documents.signals.handlers import run_workflows_added
        from documents.signals.handlers import run_workflows_updated
@@ -26,6 +27,7 @@ class DocumentsConfig(AppConfig):
        document_consumption_finished.connect(set_storage_path)
        document_consumption_finished.connect(add_to_index)
        document_consumption_finished.connect(run_workflows_added)
        document_consumption_finished.connect(add_or_update_document_in_llm_index)
        document_updated.connect(run_workflows_updated)

        import documents.schema  # noqa: F401

@@ -41,6 +41,7 @@ class SuggestionCacheData:
CLASSIFIER_VERSION_KEY: Final[str] = "classifier_version"
CLASSIFIER_HASH_KEY: Final[str] = "classifier_hash"
CLASSIFIER_MODIFIED_KEY: Final[str] = "classifier_modified"
LLM_CACHE_CLASSIFIER_VERSION: Final[int] = 1000  # Marker distinguishing LLM suggestions

CACHE_1_MINUTE: Final[int] = 60
CACHE_5_MINUTES: Final[int] = 5 * CACHE_1_MINUTE
@@ -196,6 +197,54 @@ def refresh_suggestions_cache(
    cache.touch(doc_key, timeout)


def get_llm_suggestion_cache(
    document_id: int,
    backend: str,
) -> SuggestionCacheData | None:
    doc_key = get_suggestion_cache_key(document_id)
    data: SuggestionCacheData = cache.get(doc_key)

    if data and data.classifier_hash == backend:
        return data

    return None


def set_llm_suggestions_cache(
    document_id: int,
    suggestions: dict,
    *,
    backend: str,
    timeout: int = CACHE_50_MINUTES,
) -> None:
    """
    Cache LLM-generated suggestions using a backend-specific identifier (e.g. 'openai:gpt-4').
    """
    doc_key = get_suggestion_cache_key(document_id)
    cache.set(
        doc_key,
        SuggestionCacheData(
            classifier_version=LLM_CACHE_CLASSIFIER_VERSION,
            classifier_hash=backend,
            suggestions=suggestions,
        ),
        timeout,
    )


def invalidate_llm_suggestions_cache(
    document_id: int,
) -> None:
    """
    Invalidate the LLM suggestions cache for a specific document.
    """
    doc_key = get_suggestion_cache_key(document_id)
    data: SuggestionCacheData = cache.get(doc_key)

    if data:
        cache.delete(doc_key)


def get_metadata_cache_key(document_id: int) -> str:
    """
    Returns the basic key for a document's metadata
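
For orientation, a minimal usage sketch of the helpers above (the document id and backend strings are illustrative, not from the commit):

    from documents.caching import get_llm_suggestion_cache
    from documents.caching import invalidate_llm_suggestions_cache
    from documents.caching import set_llm_suggestions_cache

    # Store suggestions keyed to the backend that produced them.
    set_llm_suggestions_cache(42, {"tags": ["invoice"]}, backend="openai:gpt-4")

    # A hit requires the same backend string; any other backend is a miss.
    cached = get_llm_suggestion_cache(42, backend="openai:gpt-4")
    assert cached is not None and cached.suggestions == {"tags": ["invoice"]}
    assert get_llm_suggestion_cache(42, backend="ollama:llama3") is None

    # Dropped explicitly here, or via the post_save handler added in signals/handlers.py.
    invalidate_llm_suggestions_cache(42)

Reusing classifier_hash to carry the backend identifier, with classifier_version pinned to the 1000 marker, lets LLM entries share the same SuggestionCacheData slot and cache key as classifier suggestions.
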
src/documents/management/commands/document_llmindex.py (new file, 22 lines)
@@ -0,0 +1,22 @@
from django.core.management import BaseCommand
from django.db import transaction

from documents.management.commands.mixins import ProgressBarMixin
from documents.tasks import llmindex_index


class Command(ProgressBarMixin, BaseCommand):
    help = "Manages the LLM-based vector index for Paperless."

    def add_arguments(self, parser):
        parser.add_argument("command", choices=["rebuild", "update"])
        self.add_argument_progress_bar_mixin(parser)

    def handle(self, *args, **options):
        self.handle_progress_bar_mixin(**options)
        with transaction.atomic():
            llmindex_index(
                progress_bar_disable=self.no_progress_bar,
                rebuild=options["command"] == "rebuild",
                scheduled=False,
            )
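
With this file in place, the index can be maintained from the shell like any other Django management command:

    python manage.py document_llmindex rebuild   # full rebuild of the vector index
    python manage.py document_llmindex update    # incremental update

Both invocations run the llmindex_index task inline (scheduled=False) inside a single transaction; the progress-bar flag handling comes from ProgressBarMixin (note self.no_progress_bar above).
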
@@ -0,0 +1,30 @@
# Generated by Django 5.1.8 on 2025-04-30 02:38

from django.db import migrations
from django.db import models


class Migration(migrations.Migration):
    dependencies = [
        ("documents", "1074_workflowrun_deleted_at_workflowrun_restored_at_and_more"),
    ]

    operations = [
        migrations.AlterField(
            model_name="paperlesstask",
            name="task_name",
            field=models.CharField(
                choices=[
                    ("consume_file", "Consume File"),
                    ("train_classifier", "Train Classifier"),
                    ("check_sanity", "Check Sanity"),
                    ("index_optimize", "Index Optimize"),
                    ("llmindex_update", "LLM Index Update"),
                ],
                help_text="Name of the task that was run",
                max_length=255,
                null=True,
                verbose_name="Task Name",
            ),
        ),
    ]
@@ -598,6 +598,7 @@ class PaperlessTask(ModelWithOwner):
        TRAIN_CLASSIFIER = ("train_classifier", _("Train Classifier"))
        CHECK_SANITY = ("check_sanity", _("Check Sanity"))
        INDEX_OPTIMIZE = ("index_optimize", _("Index Optimize"))
        LLMINDEX_UPDATE = ("llmindex_update", _("LLM Index Update"))

    task_id = models.CharField(
        max_length=255,
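
The new TaskName member is what both the tasks view and the status endpoint key on; for example, the most recent index run can be looked up the way SystemStatusView does further down:

    last_llmindex_update = (
        PaperlessTask.objects.filter(
            task_name=PaperlessTask.TaskName.LLMINDEX_UPDATE,
        )
        .order_by("-date_done")
        .first()
    )
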
@@ -26,6 +26,8 @@ from filelock import FileLock
from documents import matching
from documents.caching import clear_document_caches
from documents.caching import invalidate_llm_suggestions_cache
from documents.data_models import ConsumableDocument
from documents.file_handling import create_source_path_directory
from documents.file_handling import delete_empty_directories
from documents.file_handling import generate_filename
@@ -52,6 +54,7 @@ from documents.workflows.mutations import apply_assignment_to_overrides
from documents.workflows.mutations import apply_removal_to_document
from documents.workflows.mutations import apply_removal_to_overrides
from documents.workflows.utils import get_workflows_for_trigger
from paperless.config import AIConfig

if TYPE_CHECKING:
    from documents.classifier import DocumentClassifier
@@ -638,6 +641,15 @@ def cleanup_custom_field_deletion(sender, instance: CustomField, **kwargs):
    )


@receiver(models.signals.post_save, sender=Document)
def update_llm_suggestions_cache(sender, instance, **kwargs):
    """
    Invalidate the LLM suggestions cache when a document is saved.
    """
    # Invalidate the cache for the document
    invalidate_llm_suggestions_cache(instance.pk)


@receiver(models.signals.post_delete, sender=User)
@receiver(models.signals.post_delete, sender=Group)
def cleanup_user_deletion(sender, instance: User | Group, **kwargs):
@@ -944,3 +956,26 @@ def close_connection_pool_on_worker_init(**kwargs):
    for conn in connections.all(initialized_only=True):
        if conn.alias == "default" and hasattr(conn, "pool") and conn.pool:
            conn.close_pool()


def add_or_update_document_in_llm_index(sender, document, **kwargs):
    """
    Add or update a document in the LLM index when it is created or updated.
    """
    ai_config = AIConfig()
    if ai_config.llm_index_enabled:
        from documents.tasks import update_document_in_llm_index

        update_document_in_llm_index.delay(document)


@receiver(models.signals.post_delete, sender=Document)
def delete_document_from_llm_index(sender, instance: Document, **kwargs):
    """
    Delete a document from the LLM index when it is deleted.
    """
    ai_config = AIConfig()
    if ai_config.llm_index_enabled:
        from documents.tasks import remove_document_from_llm_index

        remove_document_from_llm_index.delay(instance)
@@ -54,6 +54,10 @@ from documents.signals import document_updated
from documents.signals.handlers import cleanup_document_deletion
from documents.signals.handlers import run_workflows
from documents.workflows.utils import get_workflows_for_trigger
from paperless.config import AIConfig
from paperless_ai.indexing import llm_index_add_or_update_document
from paperless_ai.indexing import llm_index_remove_document
from paperless_ai.indexing import update_llm_index

if settings.AUDIT_LOG_ENABLED:
    from auditlog.models import LogEntry
@@ -242,6 +246,13 @@ def bulk_update_documents(document_ids):
        for doc in documents:
            index.update_document(writer, doc)

    ai_config = AIConfig()
    if ai_config.llm_index_enabled:
        update_llm_index(
            progress_bar_disable=True,
            rebuild=False,
        )


@shared_task
def update_document_content_maybe_archive_file(document_id):
@@ -341,6 +352,10 @@ def update_document_content_maybe_archive_file(document_id):
            with index.open_index_writer() as writer:
                index.update_document(writer, document)

            ai_config = AIConfig()
            if ai_config.llm_index_enabled:
                llm_index_add_or_update_document(document)

            clear_document_caches(document.pk)

    except Exception:
@@ -558,3 +573,55 @@ def update_document_parent_tags(tag: Tag, new_parent: Tag) -> None:

    if affected:
        bulk_update_documents.delay(document_ids=list(affected))


@shared_task
def llmindex_index(
    *,
    progress_bar_disable=True,
    rebuild=False,
    scheduled=True,
    auto=False,
):
    ai_config = AIConfig()
    if ai_config.llm_index_enabled:
        task = PaperlessTask.objects.create(
            type=PaperlessTask.TaskType.SCHEDULED_TASK
            if scheduled
            else PaperlessTask.TaskType.AUTO
            if auto
            else PaperlessTask.TaskType.MANUAL_TASK,
            task_id=uuid.uuid4(),
            task_name=PaperlessTask.TaskName.LLMINDEX_UPDATE,
            status=states.STARTED,
            date_created=timezone.now(),
            date_started=timezone.now(),
        )
        from paperless_ai.indexing import update_llm_index

        try:
            result = update_llm_index(
                progress_bar_disable=progress_bar_disable,
                rebuild=rebuild,
            )
            task.status = states.SUCCESS
            task.result = result
        except Exception as e:
            logger.error("LLM index error: " + str(e))
            task.status = states.FAILURE
            task.result = str(e)

        task.date_done = timezone.now()
        task.save(update_fields=["status", "result", "date_done"])
    else:
        logger.info("LLM index is disabled, skipping update.")


@shared_task
def update_document_in_llm_index(document):
    llm_index_add_or_update_document(document)


@shared_task
def remove_document_from_llm_index(document):
    llm_index_remove_document(document)
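
As the tests further down exercise, the task can also be called synchronously (no Celery worker needed) and still records a PaperlessTask row; a sketch, assuming AI and an embedding backend are enabled in settings:

    from celery import states
    from documents import tasks
    from documents.models import PaperlessTask

    tasks.llmindex_index(rebuild=True, scheduled=False)  # runs inline

    task = PaperlessTask.objects.get(task_name=PaperlessTask.TaskName.LLMINDEX_UPDATE)
    assert task.status in (states.SUCCESS, states.FAILURE)
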
@@ -1,6 +1,7 @@
import json
from io import BytesIO
from pathlib import Path
from unittest.mock import patch

from django.contrib.auth.models import User
from django.core.files.uploadedfile import SimpleUploadedFile
@@ -66,6 +67,13 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
                "barcode_max_pages": None,
                "barcode_enable_tag": None,
                "barcode_tag_mapping": None,
                "ai_enabled": False,
                "llm_embedding_backend": None,
                "llm_embedding_model": None,
                "llm_backend": None,
                "llm_model": None,
                "llm_api_key": None,
                "llm_endpoint": None,
            },
        )
@@ -611,3 +619,76 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
        )
        self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
        self.assertEqual(ApplicationConfiguration.objects.count(), 1)

    def test_update_llm_api_key(self):
        """
        GIVEN:
            - Existing config with llm_api_key specified
        WHEN:
            - API to update llm_api_key is called with all *s
            - API to update llm_api_key is called with empty string
        THEN:
            - llm_api_key is unchanged
            - llm_api_key is set to None
        """
        config = ApplicationConfiguration.objects.first()
        config.llm_api_key = "1234567890"
        config.save()

        # Test with all *
        response = self.client.patch(
            f"{self.ENDPOINT}1/",
            json.dumps(
                {
                    "llm_api_key": "*" * 32,
                },
            ),
            content_type="application/json",
        )
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        config.refresh_from_db()
        self.assertEqual(config.llm_api_key, "1234567890")
        # Test with empty string
        response = self.client.patch(
            f"{self.ENDPOINT}1/",
            json.dumps(
                {
                    "llm_api_key": "",
                },
            ),
            content_type="application/json",
        )
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        config.refresh_from_db()
        self.assertEqual(config.llm_api_key, None)

    def test_enable_ai_index_triggers_update(self):
        """
        GIVEN:
            - Existing config with AI disabled
        WHEN:
            - Config is updated to enable AI with llm_embedding_backend
        THEN:
            - LLM index is triggered to update
        """
        config = ApplicationConfiguration.objects.first()
        config.ai_enabled = False
        config.llm_embedding_backend = None
        config.save()

        with (
            patch("documents.tasks.llmindex_index.delay") as mock_update,
            patch("paperless_ai.indexing.vector_store_file_exists") as mock_exists,
        ):
            mock_exists.return_value = False
            self.client.patch(
                f"{self.ENDPOINT}1/",
                json.dumps(
                    {
                        "ai_enabled": True,
                        "llm_embedding_backend": "openai",
                    },
                ),
                content_type="application/json",
            )
            mock_update.assert_called_once()
@@ -310,3 +310,69 @@ class TestSystemStatus(APITestCase):
            "ERROR",
        )
        self.assertIsNotNone(response.data["tasks"]["sanity_check_error"])

    def test_system_status_ai_disabled(self):
        """
        GIVEN:
            - The AI feature is disabled
        WHEN:
            - The user requests the system status
        THEN:
            - The response contains the correct AI status
        """
        with override_settings(AI_ENABLED=False):
            self.client.force_login(self.user)
            response = self.client.get(self.ENDPOINT)
            self.assertEqual(response.status_code, status.HTTP_200_OK)
            self.assertEqual(response.data["tasks"]["llmindex_status"], "DISABLED")
            self.assertIsNone(response.data["tasks"]["llmindex_error"])

    def test_system_status_ai_enabled(self):
        """
        GIVEN:
            - The AI index feature is enabled, but no tasks are found
            - The AI index feature is enabled and a task is found
        WHEN:
            - The user requests the system status
        THEN:
            - The response contains the correct AI status
        """
        with override_settings(AI_ENABLED=True, LLM_EMBEDDING_BACKEND="openai"):
            self.client.force_login(self.user)

            # No tasks found
            response = self.client.get(self.ENDPOINT)
            self.assertEqual(response.status_code, status.HTTP_200_OK)
            self.assertEqual(response.data["tasks"]["llmindex_status"], "WARNING")

            PaperlessTask.objects.create(
                type=PaperlessTask.TaskType.SCHEDULED_TASK,
                status=states.SUCCESS,
                task_name=PaperlessTask.TaskName.LLMINDEX_UPDATE,
            )
            response = self.client.get(self.ENDPOINT)
            self.assertEqual(response.status_code, status.HTTP_200_OK)
            self.assertEqual(response.data["tasks"]["llmindex_status"], "OK")
            self.assertIsNone(response.data["tasks"]["llmindex_error"])

    def test_system_status_ai_error(self):
        """
        GIVEN:
            - The AI index feature is enabled and a task is found with an error
        WHEN:
            - The user requests the system status
        THEN:
            - The response contains the correct AI status
        """
        with override_settings(AI_ENABLED=True, LLM_EMBEDDING_BACKEND="openai"):
            PaperlessTask.objects.create(
                type=PaperlessTask.TaskType.SCHEDULED_TASK,
                status=states.FAILURE,
                task_name=PaperlessTask.TaskName.LLMINDEX_UPDATE,
                result="AI index update failed",
            )
            self.client.force_login(self.user)
            response = self.client.get(self.ENDPOINT)
            self.assertEqual(response.status_code, status.HTTP_200_OK)
            self.assertEqual(response.data["tasks"]["llmindex_status"], "ERROR")
            self.assertIsNotNone(response.data["tasks"]["llmindex_error"])
@@ -49,6 +49,7 @@ class TestApiUiSettings(DirectoriesMixin, APITestCase):
                    "backend_setting": "default",
                },
                "email_enabled": False,
                "ai_enabled": False,
            },
        )
@@ -3,14 +3,17 @@ from datetime import timedelta
from pathlib import Path
from unittest import mock

from celery import states
from django.conf import settings
from django.test import TestCase
from django.test import override_settings
from django.utils import timezone

from documents import tasks
from documents.models import Correspondent
from documents.models import Document
from documents.models import DocumentType
from documents.models import PaperlessTask
from documents.models import Tag
from documents.sanity_checker import SanityCheckFailedException
from documents.sanity_checker import SanityCheckMessages
@@ -270,3 +273,103 @@ class TestUpdateContent(DirectoriesMixin, TestCase):

        tasks.update_document_content_maybe_archive_file(doc.pk)
        self.assertNotEqual(Document.objects.get(pk=doc.pk).content, "test")


class TestAIIndex(DirectoriesMixin, TestCase):
    @override_settings(
        AI_ENABLED=True,
        LLM_EMBEDDING_BACKEND="huggingface",
    )
    def test_ai_index_success(self):
        """
        GIVEN:
            - Document exists, AI is enabled, llm index backend is set
        WHEN:
            - llmindex_index task is called
        THEN:
            - update_llm_index is called, and the task is marked as success
        """
        Document.objects.create(
            title="test",
            content="my document",
            checksum="wow",
        )
        # lazy-loaded, so mock the actual function
        with mock.patch("paperless_ai.indexing.update_llm_index") as update_llm_index:
            update_llm_index.return_value = "LLM index updated successfully."
            tasks.llmindex_index()
            update_llm_index.assert_called_once()
        task = PaperlessTask.objects.get(
            task_name=PaperlessTask.TaskName.LLMINDEX_UPDATE,
        )
        self.assertEqual(task.status, states.SUCCESS)
        self.assertEqual(task.result, "LLM index updated successfully.")

    @override_settings(
        AI_ENABLED=True,
        LLM_EMBEDDING_BACKEND="huggingface",
    )
    def test_ai_index_failure(self):
        """
        GIVEN:
            - Document exists, AI is enabled, llm index backend is set
        WHEN:
            - llmindex_index task is called
        THEN:
            - update_llm_index raises an exception, and the task is marked as failure
        """
        Document.objects.create(
            title="test",
            content="my document",
            checksum="wow",
        )
        # lazy-loaded, so mock the actual function
        with mock.patch("paperless_ai.indexing.update_llm_index") as update_llm_index:
            update_llm_index.side_effect = Exception("LLM index update failed.")
            tasks.llmindex_index()
            update_llm_index.assert_called_once()
        task = PaperlessTask.objects.get(
            task_name=PaperlessTask.TaskName.LLMINDEX_UPDATE,
        )
        self.assertEqual(task.status, states.FAILURE)
        self.assertIn("LLM index update failed.", task.result)

    def test_update_document_in_llm_index(self):
        """
        GIVEN:
            - Nothing
        WHEN:
            - update_document_in_llm_index task is called
        THEN:
            - llm_index_add_or_update_document is called
        """
        doc = Document.objects.create(
            title="test",
            content="my document",
            checksum="wow",
        )
        with mock.patch(
            "documents.tasks.llm_index_add_or_update_document",
        ) as llm_index_add_or_update_document:
            tasks.update_document_in_llm_index(doc)
            llm_index_add_or_update_document.assert_called_once_with(doc)

    def test_remove_document_from_llm_index(self):
        """
        GIVEN:
            - Nothing
        WHEN:
            - remove_document_from_llm_index task is called
        THEN:
            - llm_index_remove_document is called
        """
        doc = Document.objects.create(
            title="test",
            content="my document",
            checksum="wow",
        )
        with mock.patch(
            "documents.tasks.llm_index_remove_document",
        ) as llm_index_remove_document:
            tasks.remove_document_from_llm_index(doc)
            llm_index_remove_document.assert_called_once_with(doc)
@@ -2,6 +2,8 @@ import json
import tempfile
from datetime import timedelta
from pathlib import Path
from unittest.mock import MagicMock
from unittest.mock import patch

from django.conf import settings
from django.contrib.auth.models import Group
@@ -15,9 +17,15 @@ from django.utils import timezone
from guardian.shortcuts import assign_perm
from rest_framework import status

from documents.caching import get_llm_suggestion_cache
from documents.caching import set_llm_suggestions_cache
from documents.models import Correspondent
from documents.models import Document
from documents.models import DocumentType
from documents.models import ShareLink
from documents.models import StoragePath
from documents.models import Tag
from documents.signals.handlers import update_llm_suggestions_cache
from documents.tests.utils import DirectoriesMixin
from paperless.models import ApplicationConfiguration
@@ -270,3 +278,176 @@ class TestViews(DirectoriesMixin, TestCase):
            f"Possible N+1 queries detected: {num_queries_small} queries for 2 tags, "
            f"but {num_queries_large} queries for 50 tags"
        )


class TestAISuggestions(DirectoriesMixin, TestCase):
    def setUp(self):
        self.user = User.objects.create_superuser(username="testuser")
        self.document = Document.objects.create(
            title="Test Document",
            filename="test.pdf",
            mime_type="application/pdf",
        )
        self.tag1 = Tag.objects.create(name="tag1")
        self.correspondent1 = Correspondent.objects.create(name="correspondent1")
        self.document_type1 = DocumentType.objects.create(name="type1")
        self.path1 = StoragePath.objects.create(name="path1")
        super().setUp()

    @patch("documents.views.get_llm_suggestion_cache")
    @patch("documents.views.refresh_suggestions_cache")
    @override_settings(
        AI_ENABLED=True,
        LLM_BACKEND="mock_backend",
    )
    def test_suggestions_with_cached_llm(self, mock_refresh_cache, mock_get_cache):
        mock_get_cache.return_value = MagicMock(suggestions={"tags": ["tag1", "tag2"]})

        self.client.force_login(user=self.user)
        response = self.client.get(f"/api/documents/{self.document.pk}/suggestions/")
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(response.json(), {"tags": ["tag1", "tag2"]})
        mock_refresh_cache.assert_called_once_with(self.document.pk)

    @patch("documents.views.get_ai_document_classification")
    @override_settings(
        AI_ENABLED=True,
        LLM_BACKEND="mock_backend",
    )
    def test_suggestions_with_ai_enabled(
        self,
        mock_get_ai_classification,
    ):
        mock_get_ai_classification.return_value = {
            "title": "AI Title",
            "tags": ["tag1", "tag2"],
            "correspondents": ["correspondent1"],
            "document_types": ["type1"],
            "storage_paths": ["path1"],
            "dates": ["2023-01-01"],
        }

        self.client.force_login(user=self.user)
        response = self.client.get(f"/api/documents/{self.document.pk}/suggestions/")
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(
            response.json(),
            {
                "title": "AI Title",
                "tags": [self.tag1.pk],
                "suggested_tags": ["tag2"],
                "correspondents": [self.correspondent1.pk],
                "suggested_correspondents": [],
                "document_types": [self.document_type1.pk],
                "suggested_document_types": [],
                "storage_paths": [self.path1.pk],
                "suggested_storage_paths": [],
                "dates": ["2023-01-01"],
            },
        )

    def test_invalidate_suggestions_cache(self):
        self.client.force_login(user=self.user)
        suggestions = {
            "title": "AI Title",
            "tags": ["tag1", "tag2"],
            "correspondents": ["correspondent1"],
            "document_types": ["type1"],
            "storage_paths": ["path1"],
            "dates": ["2023-01-01"],
        }
        set_llm_suggestions_cache(
            self.document.pk,
            suggestions,
            backend="mock_backend",
        )
        self.assertEqual(
            get_llm_suggestion_cache(
                self.document.pk,
                backend="mock_backend",
            ).suggestions,
            suggestions,
        )
        # post_save signal triggered
        update_llm_suggestions_cache(
            sender=None,
            instance=self.document,
        )
        self.assertIsNone(
            get_llm_suggestion_cache(
                self.document.pk,
                backend="mock_backend",
            ),
        )


class TestAIChatStreamingView(DirectoriesMixin, TestCase):
    ENDPOINT = "/api/documents/chat/"

    def setUp(self):
        self.user = User.objects.create_user(username="testuser", password="pass")
        self.client.force_login(user=self.user)
        self.document = Document.objects.create(
            title="Test Document",
            filename="test.pdf",
            mime_type="application/pdf",
        )
        super().setUp()

    @override_settings(AI_ENABLED=False)
    def test_post_ai_disabled(self):
        response = self.client.post(
            self.ENDPOINT,
            data='{"q": "question"}',
            content_type="application/json",
        )
        self.assertEqual(response.status_code, 400)
        self.assertIn(b"AI is required for this feature", response.content)

    @patch("documents.views.stream_chat_with_documents")
    @patch("documents.views.get_objects_for_user_owner_aware")
    @override_settings(AI_ENABLED=True)
    def test_post_no_document_id(self, mock_get_objects, mock_stream_chat):
        mock_get_objects.return_value = [self.document]
        mock_stream_chat.return_value = iter([b"data"])
        response = self.client.post(
            self.ENDPOINT,
            data='{"q": "question"}',
            content_type="application/json",
        )
        self.assertEqual(response.status_code, 200)
        self.assertEqual(response["Content-Type"], "text/event-stream")

    @patch("documents.views.stream_chat_with_documents")
    @override_settings(AI_ENABLED=True)
    def test_post_with_document_id(self, mock_stream_chat):
        mock_stream_chat.return_value = iter([b"data"])
        response = self.client.post(
            self.ENDPOINT,
            data=f'{{"q": "question", "document_id": {self.document.pk}}}',
            content_type="application/json",
        )
        self.assertEqual(response.status_code, 200)
        self.assertEqual(response["Content-Type"], "text/event-stream")

    @override_settings(AI_ENABLED=True)
    def test_post_with_invalid_document_id(self):
        response = self.client.post(
            self.ENDPOINT,
            data='{"q": "question", "document_id": 999999}',
            content_type="application/json",
        )
        self.assertEqual(response.status_code, 400)
        self.assertIn(b"Document not found", response.content)

    @patch("documents.views.has_perms_owner_aware")
    @override_settings(AI_ENABLED=True)
    def test_post_with_document_id_no_permission(self, mock_has_perms):
        mock_has_perms.return_value = False
        response = self.client.post(
            self.ENDPOINT,
            data=f'{{"q": "question", "document_id": {self.document.pk}}}',
            content_type="application/json",
        )
        self.assertEqual(response.status_code, 403)
        self.assertIn(b"Insufficient permissions", response.content)
@@ -45,6 +45,7 @@ from django.http import HttpResponseBadRequest
from django.http import HttpResponseForbidden
from django.http import HttpResponseRedirect
from django.http import HttpResponseServerError
from django.http import StreamingHttpResponse
from django.shortcuts import get_object_or_404
from django.utils import timezone
from django.utils.decorators import method_decorator
@@ -52,6 +53,7 @@ from django.utils.timezone import make_aware
from django.utils.translation import get_language
from django.views import View
from django.views.decorators.cache import cache_control
from django.views.decorators.csrf import ensure_csrf_cookie
from django.views.decorators.http import condition
from django.views.decorators.http import last_modified
from django.views.generic import TemplateView
@@ -91,10 +93,12 @@ from documents import index
from documents.bulk_download import ArchiveOnlyStrategy
from documents.bulk_download import OriginalAndArchiveStrategy
from documents.bulk_download import OriginalsOnlyStrategy
from documents.caching import get_llm_suggestion_cache
from documents.caching import get_metadata_cache
from documents.caching import get_suggestion_cache
from documents.caching import refresh_metadata_cache
from documents.caching import refresh_suggestions_cache
from documents.caching import set_llm_suggestions_cache
from documents.caching import set_metadata_cache
from documents.caching import set_suggestions_cache
from documents.classifier import load_classifier
@@ -182,18 +186,27 @@ from documents.signals import document_updated
from documents.tasks import consume_file
from documents.tasks import empty_trash
from documents.tasks import index_optimize
from documents.tasks import llmindex_index
from documents.tasks import sanity_check
from documents.tasks import train_classifier
from documents.tasks import update_document_parent_tags
from documents.utils import get_boolean
from paperless import version
from paperless.celery import app as celery_app
from paperless.config import AIConfig
from paperless.config import GeneralConfig
from paperless.db import GnuPG
from paperless.models import ApplicationConfiguration
from paperless.serialisers import GroupSerializer
from paperless.serialisers import UserSerializer
from paperless.views import StandardPagination
from paperless_ai.ai_classifier import get_ai_document_classification
from paperless_ai.chat import stream_chat_with_documents
from paperless_ai.matching import extract_unmatched_names
from paperless_ai.matching import match_correspondents_by_name
from paperless_ai.matching import match_document_types_by_name
from paperless_ai.matching import match_storage_paths_by_name
from paperless_ai.matching import match_tags_by_name
from paperless_mail.models import MailAccount
from paperless_mail.models import MailRule
from paperless_mail.oauth import PaperlessMailOAuth2Manager
@@ -934,37 +947,103 @@ class DocumentViewSet(
        ):
            return HttpResponseForbidden("Insufficient permissions")

        ai_config = AIConfig()

        if ai_config.ai_enabled:
            cached_llm_suggestions = get_llm_suggestion_cache(
                doc.pk,
                backend=ai_config.llm_backend,
            )

            if cached_llm_suggestions:
                refresh_suggestions_cache(doc.pk)
                return Response(cached_llm_suggestions.suggestions)

            llm_suggestions = get_ai_document_classification(doc, request.user)

            matched_tags = match_tags_by_name(
                llm_suggestions.get("tags", []),
                request.user,
            )
            matched_correspondents = match_correspondents_by_name(
                llm_suggestions.get("correspondents", []),
                request.user,
            )
            matched_types = match_document_types_by_name(
                llm_suggestions.get("document_types", []),
                request.user,
            )
            matched_paths = match_storage_paths_by_name(
                llm_suggestions.get("storage_paths", []),
                request.user,
            )

            resp_data = {
                "title": llm_suggestions.get("title"),
                "tags": [t.id for t in matched_tags],
                "suggested_tags": extract_unmatched_names(
                    llm_suggestions.get("tags", []),
                    matched_tags,
                ),
                "correspondents": [c.id for c in matched_correspondents],
                "suggested_correspondents": extract_unmatched_names(
                    llm_suggestions.get("correspondents", []),
                    matched_correspondents,
                ),
                "document_types": [d.id for d in matched_types],
                "suggested_document_types": extract_unmatched_names(
                    llm_suggestions.get("document_types", []),
                    matched_types,
                ),
                "storage_paths": [s.id for s in matched_paths],
                "suggested_storage_paths": extract_unmatched_names(
                    llm_suggestions.get("storage_paths", []),
                    matched_paths,
                ),
                "dates": llm_suggestions.get("dates", []),
            }

            set_llm_suggestions_cache(doc.pk, resp_data, backend=ai_config.llm_backend)
        else:
            document_suggestions = get_suggestion_cache(doc.pk)

            if document_suggestions is not None:
                refresh_suggestions_cache(doc.pk)
                return Response(document_suggestions.suggestions)

            classifier = load_classifier()

            dates = []
            if settings.NUMBER_OF_SUGGESTED_DATES > 0:
                gen = parse_date_generator(doc.filename, doc.content)
                dates = sorted(
                    {
                        i
                        for i in itertools.islice(
                            gen,
                            settings.NUMBER_OF_SUGGESTED_DATES,
                        )
                    },
                )

            resp_data = {
                "correspondents": [
                    c.id for c in match_correspondents(doc, classifier, request.user)
                ],
                "tags": [t.id for t in match_tags(doc, classifier, request.user)],
                "document_types": [
                    dt.id for dt in match_document_types(doc, classifier, request.user)
                ],
                "storage_paths": [
                    dt.id for dt in match_storage_paths(doc, classifier, request.user)
                ],
                "dates": [
                    date.strftime("%Y-%m-%d") for date in dates if date is not None
                ],
            }

            # Cache the suggestions and the classifier hash for later
            set_suggestions_cache(doc.pk, resp_data, classifier)

        return Response(resp_data)
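
Summarizing the control flow above: with AI enabled, serve the backend-keyed LLM cache if present, otherwise classify via get_ai_document_classification, map the returned names onto existing objects with the match_*_by_name helpers, report leftovers under suggested_*, and cache under the backend key; with AI disabled, the pre-existing classifier path runs unchanged. A minimal sketch of consuming the AI-branch response (client and doc as in the tests above):

    response = client.get(f"/api/documents/{doc.pk}/suggestions/")
    data = response.json()
    existing_tag_ids = data["tags"]                      # ids of matched, pre-existing tags
    proposed_tag_names = data.get("suggested_tags", [])  # LLM names with no existing match
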
@@ -1288,6 +1367,59 @@
        )


class ChatStreamingSerializer(serializers.Serializer):
    q = serializers.CharField(required=True)
    document_id = serializers.IntegerField(required=False, allow_null=True)


@method_decorator(
    [
        ensure_csrf_cookie,
        cache_control(no_cache=True),
    ],
    name="dispatch",
)
class ChatStreamingView(GenericAPIView):
    permission_classes = (IsAuthenticated,)
    serializer_class = ChatStreamingSerializer

    def post(self, request, *args, **kwargs):
        request.compress_exempt = True
        ai_config = AIConfig()
        if not ai_config.ai_enabled:
            return HttpResponseBadRequest("AI is required for this feature")

        try:
            question = request.data["q"]
        except KeyError:
            return HttpResponseBadRequest("Invalid request")

        doc_id = request.data.get("document_id")

        if doc_id:
            try:
                document = Document.objects.get(id=doc_id)
            except Document.DoesNotExist:
                return HttpResponseBadRequest("Document not found")

            if not has_perms_owner_aware(request.user, "view_document", document):
                return HttpResponseForbidden("Insufficient permissions")

            documents = [document]
        else:
            documents = get_objects_for_user_owner_aware(
                request.user,
                "view_document",
                Document,
            )

        response = StreamingHttpResponse(
            stream_chat_with_documents(query_str=question, documents=documents),
            content_type="text/event-stream",
        )
        return response


@extend_schema_view(
    list=extend_schema(
        description="Document views including search",
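
A hedged example of exercising the endpoint with Django's test client (mirrors the tests above; the question text and document id are illustrative):

    response = client.post(
        "/api/documents/chat/",
        data='{"q": "What is this document about?", "document_id": 3}',
        content_type="application/json",
    )
    assert response["Content-Type"] == "text/event-stream"
    for chunk in response.streaming_content:  # chunks produced by stream_chat_with_documents
        print(chunk)

Omitting document_id widens the chat to every document the user is allowed to view, via get_objects_for_user_owner_aware.
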
@@ -2446,6 +2578,10 @@ class UiSettingsView(GenericAPIView):

        ui_settings["email_enabled"] = settings.EMAIL_ENABLED

        ai_config = AIConfig()

        ui_settings["ai_enabled"] = ai_config.ai_enabled

        user_resp = {
            "id": user.id,
            "username": user.username,
@@ -2587,6 +2723,10 @@ class TasksViewSet(ReadOnlyModelViewSet):
            sanity_check,
            {"scheduled": False, "raise_on_error": False},
        ),
        PaperlessTask.TaskName.LLMINDEX_UPDATE: (
            llmindex_index,
            {"scheduled": False, "rebuild": False},
        ),
    }

    def get_queryset(self):
@@ -3106,6 +3246,31 @@ class SystemStatusView(PassUserMixin):
            last_sanity_check.date_done if last_sanity_check else None
        )

        ai_config = AIConfig()
        if not ai_config.llm_index_enabled:
            llmindex_status = "DISABLED"
            llmindex_error = None
            llmindex_last_modified = None
        else:
            last_llmindex_update = (
                PaperlessTask.objects.filter(
                    task_name=PaperlessTask.TaskName.LLMINDEX_UPDATE,
                )
                .order_by("-date_done")
                .first()
            )
            llmindex_status = "OK"
            llmindex_error = None
            if last_llmindex_update is None:
                llmindex_status = "WARNING"
                llmindex_error = "No LLM index update tasks found"
            elif last_llmindex_update and last_llmindex_update.status == states.FAILURE:
                llmindex_status = "ERROR"
                llmindex_error = last_llmindex_update.result
            llmindex_last_modified = (
                last_llmindex_update.date_done if last_llmindex_update else None
            )

        return Response(
            {
                "pngx_version": current_version,
@@ -3143,6 +3308,9 @@ class SystemStatusView(PassUserMixin):
                    "sanity_check_status": sanity_check_status,
                    "sanity_check_last_run": sanity_check_last_run,
                    "sanity_check_error": sanity_check_error,
                    "llmindex_status": llmindex_status,
                    "llmindex_last_modified": llmindex_last_modified,
                    "llmindex_error": llmindex_error,
                },
            },
        )
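
Taken together, the status block gains three fields: llmindex_status is DISABLED when the index is off, WARNING (with "No LLM index update tasks found") before the first run, ERROR carrying the failed task's result, and OK otherwise; llmindex_last_modified echoes the most recent task's date_done. A sketch of reading it, assuming the status endpoint path used by the tests (self.ENDPOINT):

    status_data = self.client.get(self.ENDPOINT).data
    tasks_block = status_data["tasks"]
    tasks_block["llmindex_status"]         # "DISABLED" | "WARNING" | "OK" | "ERROR"
    tasks_block["llmindex_last_modified"]  # date_done of the most recent index task
    tasks_block["llmindex_error"]          # set for WARNING/ERROR, otherwise None
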