From efd65300a1272320e140b6ea49aaac6c632a84de Mon Sep 17 00:00:00 2001 From: Trenton H <797416+stumpylog@users.noreply.github.com> Date: Wed, 28 Jan 2026 11:43:32 -0800 Subject: [PATCH] Experiments with using adrf for a few views --- pyproject.toml | 1 + src/paperless_mail/views.py | 215 ++++++++++++++++++++---------------- uv.lock | 25 +++++ 3 files changed, 145 insertions(+), 96 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 34474feda..6d0d339e9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ classifiers = [ # This will allow testing to not install a webserver, mysql, etc dependencies = [ + "adrf~=0.1.12", "azure-ai-documentintelligence>=1.0.2", "babel>=2.17", "bleach~=6.3.0", diff --git a/src/paperless_mail/views.py b/src/paperless_mail/views.py index b54bcb5f7..7a3ef993d 100644 --- a/src/paperless_mail/views.py +++ b/src/paperless_mail/views.py @@ -1,7 +1,12 @@ import datetime import logging from datetime import timedelta +from typing import Any +from adrf.views import APIView +from adrf.viewsets import ModelViewSet +from adrf.viewsets import ReadOnlyModelViewSet +from asgiref.sync import sync_to_async from django.http import HttpResponseBadRequest from django.http import HttpResponseForbidden from django.http import HttpResponseRedirect @@ -15,11 +20,9 @@ from httpx_oauth.oauth2 import GetAccessTokenError from rest_framework import serializers from rest_framework.decorators import action from rest_framework.filters import OrderingFilter -from rest_framework.generics import GenericAPIView from rest_framework.permissions import IsAuthenticated +from rest_framework.request import Request from rest_framework.response import Response -from rest_framework.viewsets import ModelViewSet -from rest_framework.viewsets import ReadOnlyModelViewSet from documents.filters import ObjectOwnedOrGrantedPermissionsFilter from documents.permissions import PaperlessObjectPermissions @@ -39,6 +42,8 @@ from paperless_mail.serialisers import MailRuleSerializer from paperless_mail.serialisers import ProcessedMailSerializer from paperless_mail.tasks import process_mail_accounts +logger: logging.Logger = logging.getLogger("paperless_mail") + @extend_schema_view( test=extend_schema( @@ -66,71 +71,75 @@ from paperless_mail.tasks import process_mail_accounts ), ) class MailAccountViewSet(ModelViewSet, PassUserMixin): - model = MailAccount - queryset = MailAccount.objects.all().order_by("pk") serializer_class = MailAccountSerializer pagination_class = StandardPagination permission_classes = (IsAuthenticated, PaperlessObjectPermissions) filter_backends = (ObjectOwnedOrGrantedPermissionsFilter,) - def get_permissions(self): + def get_permissions(self) -> list[Any]: if self.action == "test": - # Test action does not require object level permissions - self.permission_classes = (IsAuthenticated,) + return [IsAuthenticated()] return super().get_permissions() @action(methods=["post"], detail=False) - def test(self, request): - logger = logging.getLogger("paperless_mail") + async def test(self, request: Request) -> Response | HttpResponseBadRequest: request.data["name"] = datetime.datetime.now().isoformat() serializer = self.get_serializer(data=request.data) - serializer.is_valid(raise_exception=True) - # account exists, use the password from there instead of *** and refresh_token / expiration + # Validation must be wrapped because of sync DB validators + await sync_to_async(serializer.is_valid)(raise_exception=True) + + validated_data: dict[str, Any] = serializer.validated_data + if ( - len(serializer.validated_data.get("password").replace("*", "")) == 0 - and request.data["id"] is not None + len(str(validated_data.get("password", "")).replace("*", "")) == 0 + and request.data.get("id") is not None ): - existing_account = MailAccount.objects.get(pk=request.data["id"]) - serializer.validated_data["password"] = existing_account.password - serializer.validated_data["account_type"] = existing_account.account_type - serializer.validated_data["refresh_token"] = existing_account.refresh_token - serializer.validated_data["expiration"] = existing_account.expiration + existing_account = await MailAccount.objects.aget(pk=request.data["id"]) + validated_data.update( + { + "password": existing_account.password, + "account_type": existing_account.account_type, + "refresh_token": existing_account.refresh_token, + "expiration": existing_account.expiration, + }, + ) - account = MailAccount(**serializer.validated_data) - with get_mailbox( - account.imap_server, - account.imap_port, - account.imap_security, - ) as M: - try: + account = MailAccount(**validated_data) + + def _blocking_imap_test() -> bool: + with get_mailbox( + account.imap_server, + account.imap_port, + account.imap_security, + ) as m_box: if ( account.is_token - and account.expiration is not None + and account.expiration and account.expiration < timezone.now() ): oauth_manager = PaperlessMailOAuth2Manager() if oauth_manager.refresh_account_oauth_token(existing_account): # User is not changing password and token needs to be refreshed - existing_account.refresh_from_db() account.password = existing_account.password else: raise MailError("Unable to refresh oauth token") + mailbox_login(m_box, account) + return True - mailbox_login(M, account) - return Response({"success": True}) - except MailError as e: - logger.error( - f"Mail account {account} test failed: {e}", - ) - return HttpResponseBadRequest("Unable to connect to server") + try: + await sync_to_async(_blocking_imap_test, thread_sensitive=False)() + return Response({"success": True}) + except MailError as e: + logger.error(f"Mail account {account} test failed: {e}") + return HttpResponseBadRequest("Unable to connect to server") @action(methods=["post"], detail=True) - def process(self, request, pk=None): - account = self.get_object() + async def process(self, request: Request, pk: int | None = None) -> Response: + # FIX: Use aget_object() provided by adrf to avoid SynchronousOnlyOperation + account = await self.aget_object() process_mail_accounts.delay([account.pk]) - return Response({"result": "OK"}) @@ -144,21 +153,38 @@ class ProcessedMailViewSet(ReadOnlyModelViewSet, PassUserMixin): ObjectOwnedOrGrantedPermissionsFilter, ) filterset_class = ProcessedMailFilterSet - queryset = ProcessedMail.objects.all().order_by("-processed") @action(methods=["post"], detail=False) - def bulk_delete(self, request): - mail_ids = request.data.get("mail_ids", []) + async def bulk_delete( + self, + request: Request, + ) -> Response | HttpResponseBadRequest | HttpResponseForbidden: + mail_ids: list[int] = request.data.get("mail_ids", []) if not isinstance(mail_ids, list) or not all( isinstance(i, int) for i in mail_ids ): return HttpResponseBadRequest("mail_ids must be a list of integers") - mails = ProcessedMail.objects.filter(id__in=mail_ids) - for mail in mails: - if not has_perms_owner_aware(request.user, "delete_processedmail", mail): + + # Store objects to delete after verification + to_delete: list[ProcessedMail] = [] + + # We must verify permissions for every requested ID + async for mail in ProcessedMail.objects.filter(id__in=mail_ids): + can_delete = await sync_to_async(has_perms_owner_aware)( + request.user, + "delete_processedmail", + mail, + ) + if not can_delete: + # This is what the test is looking for: 403 on permission failure return HttpResponseForbidden("Insufficient permissions") - mail.delete() + to_delete.append(mail) + + # Only perform deletions if all items passed the permission check + for mail in to_delete: + await mail.adelete() + return Response({"result": "OK", "deleted_mail_ids": mail_ids}) @@ -178,77 +204,74 @@ class MailRuleViewSet(ModelViewSet, PassUserMixin): responses={200: None}, ), ) -class OauthCallbackView(GenericAPIView): +class OauthCallbackView(APIView): permission_classes = (IsAuthenticated,) - def get(self, request, format=None): - if not ( - request.user and request.user.has_perms(["paperless_mail.add_mailaccount"]) - ): + async def get( + self, + request: Request, + ) -> Response | HttpResponseBadRequest | HttpResponseRedirect: + has_perm = await sync_to_async(request.user.has_perm)( + "paperless_mail.add_mailaccount", + ) + if not has_perm: return HttpResponseBadRequest( "You do not have permission to add mail accounts", ) - logger = logging.getLogger("paperless_mail") - code = request.query_params.get("code") - # Gmail passes scope as a query param, Outlook does not - scope = request.query_params.get("scope") + code: str | None = request.query_params.get("code") + state: str | None = request.query_params.get("state") + scope: str | None = request.query_params.get("scope") - if code is None: - logger.error( - f"Invalid oauth callback request, code: {code}, scope: {scope}", - ) - return HttpResponseBadRequest("Invalid request, see logs for more detail") + if not code or not state: + return HttpResponseBadRequest("Invalid request parameters") oauth_manager = PaperlessMailOAuth2Manager( state=request.session.get("oauth_state"), ) - - state = request.query_params.get("state", "") if not oauth_manager.validate_state(state): - logger.error( - f"Invalid oauth callback request received state: {state}, expected: {oauth_manager.state}", - ) - return HttpResponseBadRequest("Invalid request, see logs for more detail") + return HttpResponseBadRequest("Invalid OAuth state") try: - if scope is not None and "google" in scope: - # Google + defaults: dict[str, Any] = { + "username": "", + "imap_security": MailAccount.ImapSecurity.SSL, + "imap_port": 993, + } + + if scope and "google" in scope: account_type = MailAccount.MailAccountType.GMAIL_OAUTH imap_server = "imap.gmail.com" - defaults = { - "name": f"Gmail OAuth {timezone.now()}", - "username": "", - "imap_security": MailAccount.ImapSecurity.SSL, - "imap_port": 993, - "account_type": account_type, - } - result = oauth_manager.get_gmail_access_token(code) - - elif scope is None: - # Outlook + defaults.update( + { + "name": f"Gmail OAuth {timezone.now()}", + "account_type": account_type, + }, + ) + result = await sync_to_async(oauth_manager.get_gmail_access_token)(code) + else: account_type = MailAccount.MailAccountType.OUTLOOK_OAUTH imap_server = "outlook.office365.com" - defaults = { - "name": f"Outlook OAuth {timezone.now()}", - "username": "", - "imap_security": MailAccount.ImapSecurity.SSL, - "imap_port": 993, - "account_type": account_type, - } + defaults.update( + { + "name": f"Outlook OAuth {timezone.now()}", + "account_type": account_type, + }, + ) + result = await sync_to_async(oauth_manager.get_outlook_access_token)( + code, + ) - result = oauth_manager.get_outlook_access_token(code) - - access_token = result["access_token"] - refresh_token = result["refresh_token"] - expires_in = result["expires_in"] - account, _ = MailAccount.objects.update_or_create( - password=access_token, - is_token=True, + account, _ = await MailAccount.objects.aupdate_or_create( imap_server=imap_server, - refresh_token=refresh_token, - expiration=timezone.now() + timedelta(seconds=expires_in), - defaults=defaults, + refresh_token=result["refresh_token"], + defaults={ + **defaults, + "password": result["access_token"], + "is_token": True, + "expiration": timezone.now() + + timedelta(seconds=result["expires_in"]), + }, ) return HttpResponseRedirect( f"{oauth_manager.oauth_redirect_url}?oauth_success=1&account_id={account.pk}", diff --git a/uv.lock b/uv.lock index 960b5aaa3..50b6eca8c 100644 --- a/uv.lock +++ b/uv.lock @@ -16,6 +16,20 @@ supported-markers = [ "sys_platform == 'linux'", ] +[[package]] +name = "adrf" +version = "0.1.12" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "async-property", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "django", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "djangorestframework", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/10/fe/573d7ec8805aec8e13459451f81398c6ac819882496e82fb6b4ae96c2762/adrf-0.1.12.tar.gz", hash = "sha256:e7aa49e5406b168f040f1a12cafb606e98fdd5467314240a9c42dbe63200d2c1", size = 17181, upload-time = "2025-11-24T03:25:44.337Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e7/47/14006c4045f818cf625d2357e823a15b4bb7fab8dd81a40acd23af41ee53/adrf-0.1.12-py3-none-any.whl", hash = "sha256:e9d1f343b82158f4c528c0809c9635a27ceef4c37d3d8e61b8096c8eeded616d", size = 20199, upload-time = "2025-11-24T03:25:43.291Z" }, +] + [[package]] name = "aiohappyeyeballs" version = "2.6.1" @@ -203,6 +217,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c7/d1/69d02ce34caddb0a7ae088b84c356a625a93cd4ff57b2f97644c03fad905/asgiref-3.9.2-py3-none-any.whl", hash = "sha256:0b61526596219d70396548fc003635056856dba5d0d086f86476f10b33c75960", size = 23788, upload-time = "2025-09-23T15:00:53.627Z" }, ] +[[package]] +name = "async-property" +version = "0.2.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a7/12/900eb34b3af75c11b69d6b78b74ec0fd1ba489376eceb3785f787d1a0a1d/async_property-0.2.2.tar.gz", hash = "sha256:17d9bd6ca67e27915a75d92549df64b5c7174e9dc806b30a3934dc4ff0506380", size = 16523, upload-time = "2023-07-03T17:21:55.688Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c7/80/9f608d13b4b3afcebd1dd13baf9551c95fc424d6390e4b1cfd7b1810cd06/async_property-0.2.2-py2.py3-none-any.whl", hash = "sha256:8924d792b5843994537f8ed411165700b27b2bd966cefc4daeefc1253442a9d7", size = 9546, upload-time = "2023-07-03T17:21:54.293Z" }, +] + [[package]] name = "async-timeout" version = "5.0.1" @@ -2919,6 +2942,7 @@ name = "paperless-ngx" version = "2.20.5" source = { virtual = "." } dependencies = [ + { name = "adrf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "azure-ai-documentintelligence", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "babel", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "bleach", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, @@ -3067,6 +3091,7 @@ typing = [ [package.metadata] requires-dist = [ + { name = "adrf", specifier = "~=0.1.12" }, { name = "azure-ai-documentintelligence", specifier = ">=1.0.2" }, { name = "babel", specifier = ">=2.17" }, { name = "bleach", specifier = "~=6.3.0" },