Compare commits

...

1 Commits

Author SHA1 Message Date
Trenton H
efd65300a1 Experiments with using adrf for a few views 2026-01-28 11:43:32 -08:00
3 changed files with 145 additions and 96 deletions

View File

@@ -16,6 +16,7 @@ classifiers = [
# This will allow testing to not install a webserver, mysql, etc # This will allow testing to not install a webserver, mysql, etc
dependencies = [ dependencies = [
"adrf~=0.1.12",
"azure-ai-documentintelligence>=1.0.2", "azure-ai-documentintelligence>=1.0.2",
"babel>=2.17", "babel>=2.17",
"bleach~=6.3.0", "bleach~=6.3.0",

View File

@@ -1,7 +1,12 @@
import datetime import datetime
import logging import logging
from datetime import timedelta from datetime import timedelta
from typing import Any
from adrf.views import APIView
from adrf.viewsets import ModelViewSet
from adrf.viewsets import ReadOnlyModelViewSet
from asgiref.sync import sync_to_async
from django.http import HttpResponseBadRequest from django.http import HttpResponseBadRequest
from django.http import HttpResponseForbidden from django.http import HttpResponseForbidden
from django.http import HttpResponseRedirect from django.http import HttpResponseRedirect
@@ -15,11 +20,9 @@ from httpx_oauth.oauth2 import GetAccessTokenError
from rest_framework import serializers from rest_framework import serializers
from rest_framework.decorators import action from rest_framework.decorators import action
from rest_framework.filters import OrderingFilter from rest_framework.filters import OrderingFilter
from rest_framework.generics import GenericAPIView
from rest_framework.permissions import IsAuthenticated from rest_framework.permissions import IsAuthenticated
from rest_framework.request import Request
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.viewsets import ModelViewSet
from rest_framework.viewsets import ReadOnlyModelViewSet
from documents.filters import ObjectOwnedOrGrantedPermissionsFilter from documents.filters import ObjectOwnedOrGrantedPermissionsFilter
from documents.permissions import PaperlessObjectPermissions from documents.permissions import PaperlessObjectPermissions
@@ -39,6 +42,8 @@ from paperless_mail.serialisers import MailRuleSerializer
from paperless_mail.serialisers import ProcessedMailSerializer from paperless_mail.serialisers import ProcessedMailSerializer
from paperless_mail.tasks import process_mail_accounts from paperless_mail.tasks import process_mail_accounts
logger: logging.Logger = logging.getLogger("paperless_mail")
@extend_schema_view( @extend_schema_view(
test=extend_schema( test=extend_schema(
@@ -66,71 +71,75 @@ from paperless_mail.tasks import process_mail_accounts
), ),
) )
class MailAccountViewSet(ModelViewSet, PassUserMixin): class MailAccountViewSet(ModelViewSet, PassUserMixin):
model = MailAccount
queryset = MailAccount.objects.all().order_by("pk") queryset = MailAccount.objects.all().order_by("pk")
serializer_class = MailAccountSerializer serializer_class = MailAccountSerializer
pagination_class = StandardPagination pagination_class = StandardPagination
permission_classes = (IsAuthenticated, PaperlessObjectPermissions) permission_classes = (IsAuthenticated, PaperlessObjectPermissions)
filter_backends = (ObjectOwnedOrGrantedPermissionsFilter,) filter_backends = (ObjectOwnedOrGrantedPermissionsFilter,)
def get_permissions(self): def get_permissions(self) -> list[Any]:
if self.action == "test": if self.action == "test":
# Test action does not require object level permissions return [IsAuthenticated()]
self.permission_classes = (IsAuthenticated,)
return super().get_permissions() return super().get_permissions()
@action(methods=["post"], detail=False) @action(methods=["post"], detail=False)
def test(self, request): async def test(self, request: Request) -> Response | HttpResponseBadRequest:
logger = logging.getLogger("paperless_mail")
request.data["name"] = datetime.datetime.now().isoformat() request.data["name"] = datetime.datetime.now().isoformat()
serializer = self.get_serializer(data=request.data) serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
# account exists, use the password from there instead of *** and refresh_token / expiration # Validation must be wrapped because of sync DB validators
await sync_to_async(serializer.is_valid)(raise_exception=True)
validated_data: dict[str, Any] = serializer.validated_data
if ( if (
len(serializer.validated_data.get("password").replace("*", "")) == 0 len(str(validated_data.get("password", "")).replace("*", "")) == 0
and request.data["id"] is not None and request.data.get("id") is not None
): ):
existing_account = MailAccount.objects.get(pk=request.data["id"]) existing_account = await MailAccount.objects.aget(pk=request.data["id"])
serializer.validated_data["password"] = existing_account.password validated_data.update(
serializer.validated_data["account_type"] = existing_account.account_type {
serializer.validated_data["refresh_token"] = existing_account.refresh_token "password": existing_account.password,
serializer.validated_data["expiration"] = existing_account.expiration "account_type": existing_account.account_type,
"refresh_token": existing_account.refresh_token,
"expiration": existing_account.expiration,
},
)
account = MailAccount(**serializer.validated_data) account = MailAccount(**validated_data)
with get_mailbox(
account.imap_server, def _blocking_imap_test() -> bool:
account.imap_port, with get_mailbox(
account.imap_security, account.imap_server,
) as M: account.imap_port,
try: account.imap_security,
) as m_box:
if ( if (
account.is_token account.is_token
and account.expiration is not None and account.expiration
and account.expiration < timezone.now() and account.expiration < timezone.now()
): ):
oauth_manager = PaperlessMailOAuth2Manager() oauth_manager = PaperlessMailOAuth2Manager()
if oauth_manager.refresh_account_oauth_token(existing_account): if oauth_manager.refresh_account_oauth_token(existing_account):
# User is not changing password and token needs to be refreshed # User is not changing password and token needs to be refreshed
existing_account.refresh_from_db()
account.password = existing_account.password account.password = existing_account.password
else: else:
raise MailError("Unable to refresh oauth token") raise MailError("Unable to refresh oauth token")
mailbox_login(m_box, account)
return True
mailbox_login(M, account) try:
return Response({"success": True}) await sync_to_async(_blocking_imap_test, thread_sensitive=False)()
except MailError as e: return Response({"success": True})
logger.error( except MailError as e:
f"Mail account {account} test failed: {e}", logger.error(f"Mail account {account} test failed: {e}")
) return HttpResponseBadRequest("Unable to connect to server")
return HttpResponseBadRequest("Unable to connect to server")
@action(methods=["post"], detail=True) @action(methods=["post"], detail=True)
def process(self, request, pk=None): async def process(self, request: Request, pk: int | None = None) -> Response:
account = self.get_object() # FIX: Use aget_object() provided by adrf to avoid SynchronousOnlyOperation
account = await self.aget_object()
process_mail_accounts.delay([account.pk]) process_mail_accounts.delay([account.pk])
return Response({"result": "OK"}) return Response({"result": "OK"})
@@ -144,21 +153,38 @@ class ProcessedMailViewSet(ReadOnlyModelViewSet, PassUserMixin):
ObjectOwnedOrGrantedPermissionsFilter, ObjectOwnedOrGrantedPermissionsFilter,
) )
filterset_class = ProcessedMailFilterSet filterset_class = ProcessedMailFilterSet
queryset = ProcessedMail.objects.all().order_by("-processed") queryset = ProcessedMail.objects.all().order_by("-processed")
@action(methods=["post"], detail=False) @action(methods=["post"], detail=False)
def bulk_delete(self, request): async def bulk_delete(
mail_ids = request.data.get("mail_ids", []) self,
request: Request,
) -> Response | HttpResponseBadRequest | HttpResponseForbidden:
mail_ids: list[int] = request.data.get("mail_ids", [])
if not isinstance(mail_ids, list) or not all( if not isinstance(mail_ids, list) or not all(
isinstance(i, int) for i in mail_ids isinstance(i, int) for i in mail_ids
): ):
return HttpResponseBadRequest("mail_ids must be a list of integers") return HttpResponseBadRequest("mail_ids must be a list of integers")
mails = ProcessedMail.objects.filter(id__in=mail_ids)
for mail in mails: # Store objects to delete after verification
if not has_perms_owner_aware(request.user, "delete_processedmail", mail): to_delete: list[ProcessedMail] = []
# We must verify permissions for every requested ID
async for mail in ProcessedMail.objects.filter(id__in=mail_ids):
can_delete = await sync_to_async(has_perms_owner_aware)(
request.user,
"delete_processedmail",
mail,
)
if not can_delete:
# This is what the test is looking for: 403 on permission failure
return HttpResponseForbidden("Insufficient permissions") return HttpResponseForbidden("Insufficient permissions")
mail.delete() to_delete.append(mail)
# Only perform deletions if all items passed the permission check
for mail in to_delete:
await mail.adelete()
return Response({"result": "OK", "deleted_mail_ids": mail_ids}) return Response({"result": "OK", "deleted_mail_ids": mail_ids})
@@ -178,77 +204,74 @@ class MailRuleViewSet(ModelViewSet, PassUserMixin):
responses={200: None}, responses={200: None},
), ),
) )
class OauthCallbackView(GenericAPIView): class OauthCallbackView(APIView):
permission_classes = (IsAuthenticated,) permission_classes = (IsAuthenticated,)
def get(self, request, format=None): async def get(
if not ( self,
request.user and request.user.has_perms(["paperless_mail.add_mailaccount"]) request: Request,
): ) -> Response | HttpResponseBadRequest | HttpResponseRedirect:
has_perm = await sync_to_async(request.user.has_perm)(
"paperless_mail.add_mailaccount",
)
if not has_perm:
return HttpResponseBadRequest( return HttpResponseBadRequest(
"You do not have permission to add mail accounts", "You do not have permission to add mail accounts",
) )
logger = logging.getLogger("paperless_mail") code: str | None = request.query_params.get("code")
code = request.query_params.get("code") state: str | None = request.query_params.get("state")
# Gmail passes scope as a query param, Outlook does not scope: str | None = request.query_params.get("scope")
scope = request.query_params.get("scope")
if code is None: if not code or not state:
logger.error( return HttpResponseBadRequest("Invalid request parameters")
f"Invalid oauth callback request, code: {code}, scope: {scope}",
)
return HttpResponseBadRequest("Invalid request, see logs for more detail")
oauth_manager = PaperlessMailOAuth2Manager( oauth_manager = PaperlessMailOAuth2Manager(
state=request.session.get("oauth_state"), state=request.session.get("oauth_state"),
) )
state = request.query_params.get("state", "")
if not oauth_manager.validate_state(state): if not oauth_manager.validate_state(state):
logger.error( return HttpResponseBadRequest("Invalid OAuth state")
f"Invalid oauth callback request received state: {state}, expected: {oauth_manager.state}",
)
return HttpResponseBadRequest("Invalid request, see logs for more detail")
try: try:
if scope is not None and "google" in scope: defaults: dict[str, Any] = {
# Google "username": "",
"imap_security": MailAccount.ImapSecurity.SSL,
"imap_port": 993,
}
if scope and "google" in scope:
account_type = MailAccount.MailAccountType.GMAIL_OAUTH account_type = MailAccount.MailAccountType.GMAIL_OAUTH
imap_server = "imap.gmail.com" imap_server = "imap.gmail.com"
defaults = { defaults.update(
"name": f"Gmail OAuth {timezone.now()}", {
"username": "", "name": f"Gmail OAuth {timezone.now()}",
"imap_security": MailAccount.ImapSecurity.SSL, "account_type": account_type,
"imap_port": 993, },
"account_type": account_type, )
} result = await sync_to_async(oauth_manager.get_gmail_access_token)(code)
result = oauth_manager.get_gmail_access_token(code) else:
elif scope is None:
# Outlook
account_type = MailAccount.MailAccountType.OUTLOOK_OAUTH account_type = MailAccount.MailAccountType.OUTLOOK_OAUTH
imap_server = "outlook.office365.com" imap_server = "outlook.office365.com"
defaults = { defaults.update(
"name": f"Outlook OAuth {timezone.now()}", {
"username": "", "name": f"Outlook OAuth {timezone.now()}",
"imap_security": MailAccount.ImapSecurity.SSL, "account_type": account_type,
"imap_port": 993, },
"account_type": account_type, )
} result = await sync_to_async(oauth_manager.get_outlook_access_token)(
code,
)
result = oauth_manager.get_outlook_access_token(code) account, _ = await MailAccount.objects.aupdate_or_create(
access_token = result["access_token"]
refresh_token = result["refresh_token"]
expires_in = result["expires_in"]
account, _ = MailAccount.objects.update_or_create(
password=access_token,
is_token=True,
imap_server=imap_server, imap_server=imap_server,
refresh_token=refresh_token, refresh_token=result["refresh_token"],
expiration=timezone.now() + timedelta(seconds=expires_in), defaults={
defaults=defaults, **defaults,
"password": result["access_token"],
"is_token": True,
"expiration": timezone.now()
+ timedelta(seconds=result["expires_in"]),
},
) )
return HttpResponseRedirect( return HttpResponseRedirect(
f"{oauth_manager.oauth_redirect_url}?oauth_success=1&account_id={account.pk}", f"{oauth_manager.oauth_redirect_url}?oauth_success=1&account_id={account.pk}",

25
uv.lock generated
View File

@@ -16,6 +16,20 @@ supported-markers = [
"sys_platform == 'linux'", "sys_platform == 'linux'",
] ]
[[package]]
name = "adrf"
version = "0.1.12"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "async-property", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "django", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "djangorestframework", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/10/fe/573d7ec8805aec8e13459451f81398c6ac819882496e82fb6b4ae96c2762/adrf-0.1.12.tar.gz", hash = "sha256:e7aa49e5406b168f040f1a12cafb606e98fdd5467314240a9c42dbe63200d2c1", size = 17181, upload-time = "2025-11-24T03:25:44.337Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/e7/47/14006c4045f818cf625d2357e823a15b4bb7fab8dd81a40acd23af41ee53/adrf-0.1.12-py3-none-any.whl", hash = "sha256:e9d1f343b82158f4c528c0809c9635a27ceef4c37d3d8e61b8096c8eeded616d", size = 20199, upload-time = "2025-11-24T03:25:43.291Z" },
]
[[package]] [[package]]
name = "aiohappyeyeballs" name = "aiohappyeyeballs"
version = "2.6.1" version = "2.6.1"
@@ -203,6 +217,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/c7/d1/69d02ce34caddb0a7ae088b84c356a625a93cd4ff57b2f97644c03fad905/asgiref-3.9.2-py3-none-any.whl", hash = "sha256:0b61526596219d70396548fc003635056856dba5d0d086f86476f10b33c75960", size = 23788, upload-time = "2025-09-23T15:00:53.627Z" }, { url = "https://files.pythonhosted.org/packages/c7/d1/69d02ce34caddb0a7ae088b84c356a625a93cd4ff57b2f97644c03fad905/asgiref-3.9.2-py3-none-any.whl", hash = "sha256:0b61526596219d70396548fc003635056856dba5d0d086f86476f10b33c75960", size = 23788, upload-time = "2025-09-23T15:00:53.627Z" },
] ]
[[package]]
name = "async-property"
version = "0.2.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/a7/12/900eb34b3af75c11b69d6b78b74ec0fd1ba489376eceb3785f787d1a0a1d/async_property-0.2.2.tar.gz", hash = "sha256:17d9bd6ca67e27915a75d92549df64b5c7174e9dc806b30a3934dc4ff0506380", size = 16523, upload-time = "2023-07-03T17:21:55.688Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c7/80/9f608d13b4b3afcebd1dd13baf9551c95fc424d6390e4b1cfd7b1810cd06/async_property-0.2.2-py2.py3-none-any.whl", hash = "sha256:8924d792b5843994537f8ed411165700b27b2bd966cefc4daeefc1253442a9d7", size = 9546, upload-time = "2023-07-03T17:21:54.293Z" },
]
[[package]] [[package]]
name = "async-timeout" name = "async-timeout"
version = "5.0.1" version = "5.0.1"
@@ -2919,6 +2942,7 @@ name = "paperless-ngx"
version = "2.20.5" version = "2.20.5"
source = { virtual = "." } source = { virtual = "." }
dependencies = [ dependencies = [
{ name = "adrf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "azure-ai-documentintelligence", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "azure-ai-documentintelligence", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "babel", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "babel", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "bleach", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "bleach", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -3067,6 +3091,7 @@ typing = [
[package.metadata] [package.metadata]
requires-dist = [ requires-dist = [
{ name = "adrf", specifier = "~=0.1.12" },
{ name = "azure-ai-documentintelligence", specifier = ">=1.0.2" }, { name = "azure-ai-documentintelligence", specifier = ">=1.0.2" },
{ name = "babel", specifier = ">=2.17" }, { name = "babel", specifier = ">=2.17" },
{ name = "bleach", specifier = "~=6.3.0" }, { name = "bleach", specifier = "~=6.3.0" },