diff --git a/src/documents/tests/test_views.py b/src/documents/tests/test_views.py index 4c987e3af..ef5a71e01 100644 --- a/src/documents/tests/test_views.py +++ b/src/documents/tests/test_views.py @@ -1,6 +1,8 @@ import tempfile from datetime import timedelta from pathlib import Path +from unittest.mock import MagicMock +from unittest.mock import patch from django.conf import settings from django.contrib.auth.models import Permission @@ -10,8 +12,15 @@ from django.test import override_settings from django.utils import timezone from rest_framework import status +from documents.caching import get_llm_suggestion_cache +from documents.caching import set_llm_suggestions_cache +from documents.models import Correspondent from documents.models import Document +from documents.models import DocumentType from documents.models import ShareLink +from documents.models import StoragePath +from documents.models import Tag +from documents.signals.handlers import update_llm_suggestions_cache from documents.tests.utils import DirectoriesMixin from paperless.models import ApplicationConfiguration @@ -154,3 +163,104 @@ class TestViews(DirectoriesMixin, TestCase): response.render() self.assertEqual(response.request["PATH_INFO"], "/accounts/login/") self.assertContains(response, b"Share link has expired") + + +class TestAISuggestions(DirectoriesMixin, TestCase): + def setUp(self): + self.user = User.objects.create_superuser(username="testuser") + self.document = Document.objects.create( + title="Test Document", + filename="test.pdf", + mime_type="application/pdf", + ) + self.tag1 = Tag.objects.create(name="tag1") + self.correspondent1 = Correspondent.objects.create(name="correspondent1") + self.document_type1 = DocumentType.objects.create(name="type1") + self.path1 = StoragePath.objects.create(name="path1") + super().setUp() + + @patch("documents.views.get_llm_suggestion_cache") + @patch("documents.views.refresh_suggestions_cache") + @override_settings( + AI_ENABLED=True, + LLM_BACKEND="mock_backend", + ) + def test_suggestions_with_cached_llm(self, mock_refresh_cache, mock_get_cache): + mock_get_cache.return_value = MagicMock(suggestions={"tags": ["tag1", "tag2"]}) + + self.client.force_login(user=self.user) + response = self.client.get(f"/api/documents/{self.document.pk}/suggestions/") + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.json(), {"tags": ["tag1", "tag2"]}) + mock_refresh_cache.assert_called_once_with(self.document.pk) + + @patch("documents.views.get_ai_document_classification") + @override_settings( + AI_ENABLED=True, + LLM_BACKEND="mock_backend", + ) + def test_suggestions_with_ai_enabled( + self, + mock_get_ai_classification, + ): + mock_get_ai_classification.return_value = { + "title": "AI Title", + "tags": ["tag1", "tag2"], + "correspondents": ["correspondent1"], + "document_types": ["type1"], + "storage_paths": ["path1"], + "dates": ["2023-01-01"], + } + + self.client.force_login(user=self.user) + response = self.client.get(f"/api/documents/{self.document.pk}/suggestions/") + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual( + response.json(), + { + "title": "AI Title", + "tags": [self.tag1.pk], + "suggested_tags": ["tag2"], + "correspondents": [self.correspondent1.pk], + "suggested_correspondents": [], + "document_types": [self.document_type1.pk], + "suggested_document_types": [], + "storage_paths": [self.path1.pk], + "suggested_storage_paths": [], + "dates": ["2023-01-01"], + }, + ) + + def test_invalidate_suggestions_cache(self): + self.client.force_login(user=self.user) + suggestions = { + "title": "AI Title", + "tags": ["tag1", "tag2"], + "correspondents": ["correspondent1"], + "document_types": ["type1"], + "storage_paths": ["path1"], + "dates": ["2023-01-01"], + } + set_llm_suggestions_cache( + self.document.pk, + suggestions, + backend="mock_backend", + ) + self.assertEqual( + get_llm_suggestion_cache( + self.document.pk, + backend="mock_backend", + ).suggestions, + suggestions, + ) + # post_save signal triggered + update_llm_suggestions_cache( + sender=None, + instance=self.document, + ) + self.assertIsNone( + get_llm_suggestion_cache( + self.document.pk, + backend="mock_backend", + ), + ) diff --git a/src/documents/views.py b/src/documents/views.py index 5ea41025a..8ddfaadf0 100644 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -739,51 +739,57 @@ class DocumentViewSet( return HttpResponseForbidden("Insufficient permissions") if settings.AI_ENABLED: - cached = get_llm_suggestion_cache(doc.pk, backend=settings.LLM_BACKEND) + cached_llm_suggestions = get_llm_suggestion_cache( + doc.pk, + backend=settings.LLM_BACKEND, + ) - if cached: + if cached_llm_suggestions: refresh_suggestions_cache(doc.pk) - return Response(cached.suggestions) + return Response(cached_llm_suggestions.suggestions) - llm_resp = get_ai_document_classification(doc) + llm_suggestions = get_ai_document_classification(doc) - matched_tags = match_tags_by_name(llm_resp.get("tags", []), request.user) + matched_tags = match_tags_by_name( + llm_suggestions.get("tags", []), + request.user, + ) matched_correspondents = match_correspondents_by_name( - llm_resp.get("correspondents", []), + llm_suggestions.get("correspondents", []), request.user, ) matched_types = match_document_types_by_name( - llm_resp.get("document_types", []), + llm_suggestions.get("document_types", []), request.user, ) matched_paths = match_storage_paths_by_name( - llm_resp.get("storage_paths", []), + llm_suggestions.get("storage_paths", []), request.user, ) resp_data = { - "title": llm_resp.get("title"), + "title": llm_suggestions.get("title"), "tags": [t.id for t in matched_tags], "suggested_tags": extract_unmatched_names( - llm_resp.get("tags", []), + llm_suggestions.get("tags", []), matched_tags, ), "correspondents": [c.id for c in matched_correspondents], "suggested_correspondents": extract_unmatched_names( - llm_resp.get("correspondents", []), + llm_suggestions.get("correspondents", []), matched_correspondents, ), "document_types": [d.id for d in matched_types], "suggested_document_types": extract_unmatched_names( - llm_resp.get("document_types", []), + llm_suggestions.get("document_types", []), matched_types, ), "storage_paths": [s.id for s in matched_paths], "suggested_storage_paths": extract_unmatched_names( - llm_resp.get("storage_paths", []), + llm_suggestions.get("storage_paths", []), matched_paths, ), - "dates": llm_resp.get("dates", []), + "dates": llm_suggestions.get("dates", []), } set_llm_suggestions_cache(doc.pk, resp_data, backend=settings.LLM_BACKEND)