Compare commits

...

1 Commits

Author SHA1 Message Date
Trenton H
efd65300a1 Experiments with using adrf for a few views 2026-01-28 11:43:32 -08:00
3 changed files with 145 additions and 96 deletions

View File

@@ -16,6 +16,7 @@ classifiers = [
# This will allow testing to not install a webserver, mysql, etc
dependencies = [
"adrf~=0.1.12",
"azure-ai-documentintelligence>=1.0.2",
"babel>=2.17",
"bleach~=6.3.0",

View File

@@ -1,7 +1,12 @@
import datetime
import logging
from datetime import timedelta
from typing import Any
from adrf.views import APIView
from adrf.viewsets import ModelViewSet
from adrf.viewsets import ReadOnlyModelViewSet
from asgiref.sync import sync_to_async
from django.http import HttpResponseBadRequest
from django.http import HttpResponseForbidden
from django.http import HttpResponseRedirect
@@ -15,11 +20,9 @@ from httpx_oauth.oauth2 import GetAccessTokenError
from rest_framework import serializers
from rest_framework.decorators import action
from rest_framework.filters import OrderingFilter
from rest_framework.generics import GenericAPIView
from rest_framework.permissions import IsAuthenticated
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.viewsets import ModelViewSet
from rest_framework.viewsets import ReadOnlyModelViewSet
from documents.filters import ObjectOwnedOrGrantedPermissionsFilter
from documents.permissions import PaperlessObjectPermissions
@@ -39,6 +42,8 @@ from paperless_mail.serialisers import MailRuleSerializer
from paperless_mail.serialisers import ProcessedMailSerializer
from paperless_mail.tasks import process_mail_accounts
logger: logging.Logger = logging.getLogger("paperless_mail")
@extend_schema_view(
test=extend_schema(
@@ -66,71 +71,75 @@ from paperless_mail.tasks import process_mail_accounts
),
)
class MailAccountViewSet(ModelViewSet, PassUserMixin):
model = MailAccount
queryset = MailAccount.objects.all().order_by("pk")
serializer_class = MailAccountSerializer
pagination_class = StandardPagination
permission_classes = (IsAuthenticated, PaperlessObjectPermissions)
filter_backends = (ObjectOwnedOrGrantedPermissionsFilter,)
def get_permissions(self):
def get_permissions(self) -> list[Any]:
if self.action == "test":
# Test action does not require object level permissions
self.permission_classes = (IsAuthenticated,)
return [IsAuthenticated()]
return super().get_permissions()
@action(methods=["post"], detail=False)
def test(self, request):
logger = logging.getLogger("paperless_mail")
async def test(self, request: Request) -> Response | HttpResponseBadRequest:
request.data["name"] = datetime.datetime.now().isoformat()
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
# account exists, use the password from there instead of *** and refresh_token / expiration
# Validation must be wrapped because of sync DB validators
await sync_to_async(serializer.is_valid)(raise_exception=True)
validated_data: dict[str, Any] = serializer.validated_data
if (
len(serializer.validated_data.get("password").replace("*", "")) == 0
and request.data["id"] is not None
len(str(validated_data.get("password", "")).replace("*", "")) == 0
and request.data.get("id") is not None
):
existing_account = MailAccount.objects.get(pk=request.data["id"])
serializer.validated_data["password"] = existing_account.password
serializer.validated_data["account_type"] = existing_account.account_type
serializer.validated_data["refresh_token"] = existing_account.refresh_token
serializer.validated_data["expiration"] = existing_account.expiration
existing_account = await MailAccount.objects.aget(pk=request.data["id"])
validated_data.update(
{
"password": existing_account.password,
"account_type": existing_account.account_type,
"refresh_token": existing_account.refresh_token,
"expiration": existing_account.expiration,
},
)
account = MailAccount(**serializer.validated_data)
account = MailAccount(**validated_data)
def _blocking_imap_test() -> bool:
with get_mailbox(
account.imap_server,
account.imap_port,
account.imap_security,
) as M:
try:
) as m_box:
if (
account.is_token
and account.expiration is not None
and account.expiration
and account.expiration < timezone.now()
):
oauth_manager = PaperlessMailOAuth2Manager()
if oauth_manager.refresh_account_oauth_token(existing_account):
# User is not changing password and token needs to be refreshed
existing_account.refresh_from_db()
account.password = existing_account.password
else:
raise MailError("Unable to refresh oauth token")
mailbox_login(m_box, account)
return True
mailbox_login(M, account)
try:
await sync_to_async(_blocking_imap_test, thread_sensitive=False)()
return Response({"success": True})
except MailError as e:
logger.error(
f"Mail account {account} test failed: {e}",
)
logger.error(f"Mail account {account} test failed: {e}")
return HttpResponseBadRequest("Unable to connect to server")
@action(methods=["post"], detail=True)
def process(self, request, pk=None):
account = self.get_object()
async def process(self, request: Request, pk: int | None = None) -> Response:
# FIX: Use aget_object() provided by adrf to avoid SynchronousOnlyOperation
account = await self.aget_object()
process_mail_accounts.delay([account.pk])
return Response({"result": "OK"})
@@ -144,21 +153,38 @@ class ProcessedMailViewSet(ReadOnlyModelViewSet, PassUserMixin):
ObjectOwnedOrGrantedPermissionsFilter,
)
filterset_class = ProcessedMailFilterSet
queryset = ProcessedMail.objects.all().order_by("-processed")
@action(methods=["post"], detail=False)
def bulk_delete(self, request):
mail_ids = request.data.get("mail_ids", [])
async def bulk_delete(
self,
request: Request,
) -> Response | HttpResponseBadRequest | HttpResponseForbidden:
mail_ids: list[int] = request.data.get("mail_ids", [])
if not isinstance(mail_ids, list) or not all(
isinstance(i, int) for i in mail_ids
):
return HttpResponseBadRequest("mail_ids must be a list of integers")
mails = ProcessedMail.objects.filter(id__in=mail_ids)
for mail in mails:
if not has_perms_owner_aware(request.user, "delete_processedmail", mail):
# Store objects to delete after verification
to_delete: list[ProcessedMail] = []
# We must verify permissions for every requested ID
async for mail in ProcessedMail.objects.filter(id__in=mail_ids):
can_delete = await sync_to_async(has_perms_owner_aware)(
request.user,
"delete_processedmail",
mail,
)
if not can_delete:
# This is what the test is looking for: 403 on permission failure
return HttpResponseForbidden("Insufficient permissions")
mail.delete()
to_delete.append(mail)
# Only perform deletions if all items passed the permission check
for mail in to_delete:
await mail.adelete()
return Response({"result": "OK", "deleted_mail_ids": mail_ids})
@@ -178,77 +204,74 @@ class MailRuleViewSet(ModelViewSet, PassUserMixin):
responses={200: None},
),
)
class OauthCallbackView(GenericAPIView):
class OauthCallbackView(APIView):
permission_classes = (IsAuthenticated,)
def get(self, request, format=None):
if not (
request.user and request.user.has_perms(["paperless_mail.add_mailaccount"])
):
async def get(
self,
request: Request,
) -> Response | HttpResponseBadRequest | HttpResponseRedirect:
has_perm = await sync_to_async(request.user.has_perm)(
"paperless_mail.add_mailaccount",
)
if not has_perm:
return HttpResponseBadRequest(
"You do not have permission to add mail accounts",
)
logger = logging.getLogger("paperless_mail")
code = request.query_params.get("code")
# Gmail passes scope as a query param, Outlook does not
scope = request.query_params.get("scope")
code: str | None = request.query_params.get("code")
state: str | None = request.query_params.get("state")
scope: str | None = request.query_params.get("scope")
if code is None:
logger.error(
f"Invalid oauth callback request, code: {code}, scope: {scope}",
)
return HttpResponseBadRequest("Invalid request, see logs for more detail")
if not code or not state:
return HttpResponseBadRequest("Invalid request parameters")
oauth_manager = PaperlessMailOAuth2Manager(
state=request.session.get("oauth_state"),
)
state = request.query_params.get("state", "")
if not oauth_manager.validate_state(state):
logger.error(
f"Invalid oauth callback request received state: {state}, expected: {oauth_manager.state}",
)
return HttpResponseBadRequest("Invalid request, see logs for more detail")
return HttpResponseBadRequest("Invalid OAuth state")
try:
if scope is not None and "google" in scope:
# Google
defaults: dict[str, Any] = {
"username": "",
"imap_security": MailAccount.ImapSecurity.SSL,
"imap_port": 993,
}
if scope and "google" in scope:
account_type = MailAccount.MailAccountType.GMAIL_OAUTH
imap_server = "imap.gmail.com"
defaults = {
defaults.update(
{
"name": f"Gmail OAuth {timezone.now()}",
"username": "",
"imap_security": MailAccount.ImapSecurity.SSL,
"imap_port": 993,
"account_type": account_type,
}
result = oauth_manager.get_gmail_access_token(code)
elif scope is None:
# Outlook
},
)
result = await sync_to_async(oauth_manager.get_gmail_access_token)(code)
else:
account_type = MailAccount.MailAccountType.OUTLOOK_OAUTH
imap_server = "outlook.office365.com"
defaults = {
defaults.update(
{
"name": f"Outlook OAuth {timezone.now()}",
"username": "",
"imap_security": MailAccount.ImapSecurity.SSL,
"imap_port": 993,
"account_type": account_type,
}
},
)
result = await sync_to_async(oauth_manager.get_outlook_access_token)(
code,
)
result = oauth_manager.get_outlook_access_token(code)
access_token = result["access_token"]
refresh_token = result["refresh_token"]
expires_in = result["expires_in"]
account, _ = MailAccount.objects.update_or_create(
password=access_token,
is_token=True,
account, _ = await MailAccount.objects.aupdate_or_create(
imap_server=imap_server,
refresh_token=refresh_token,
expiration=timezone.now() + timedelta(seconds=expires_in),
defaults=defaults,
refresh_token=result["refresh_token"],
defaults={
**defaults,
"password": result["access_token"],
"is_token": True,
"expiration": timezone.now()
+ timedelta(seconds=result["expires_in"]),
},
)
return HttpResponseRedirect(
f"{oauth_manager.oauth_redirect_url}?oauth_success=1&account_id={account.pk}",

25
uv.lock generated
View File

@@ -16,6 +16,20 @@ supported-markers = [
"sys_platform == 'linux'",
]
[[package]]
name = "adrf"
version = "0.1.12"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "async-property", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "django", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "djangorestframework", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/10/fe/573d7ec8805aec8e13459451f81398c6ac819882496e82fb6b4ae96c2762/adrf-0.1.12.tar.gz", hash = "sha256:e7aa49e5406b168f040f1a12cafb606e98fdd5467314240a9c42dbe63200d2c1", size = 17181, upload-time = "2025-11-24T03:25:44.337Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/e7/47/14006c4045f818cf625d2357e823a15b4bb7fab8dd81a40acd23af41ee53/adrf-0.1.12-py3-none-any.whl", hash = "sha256:e9d1f343b82158f4c528c0809c9635a27ceef4c37d3d8e61b8096c8eeded616d", size = 20199, upload-time = "2025-11-24T03:25:43.291Z" },
]
[[package]]
name = "aiohappyeyeballs"
version = "2.6.1"
@@ -203,6 +217,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/c7/d1/69d02ce34caddb0a7ae088b84c356a625a93cd4ff57b2f97644c03fad905/asgiref-3.9.2-py3-none-any.whl", hash = "sha256:0b61526596219d70396548fc003635056856dba5d0d086f86476f10b33c75960", size = 23788, upload-time = "2025-09-23T15:00:53.627Z" },
]
[[package]]
name = "async-property"
version = "0.2.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/a7/12/900eb34b3af75c11b69d6b78b74ec0fd1ba489376eceb3785f787d1a0a1d/async_property-0.2.2.tar.gz", hash = "sha256:17d9bd6ca67e27915a75d92549df64b5c7174e9dc806b30a3934dc4ff0506380", size = 16523, upload-time = "2023-07-03T17:21:55.688Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c7/80/9f608d13b4b3afcebd1dd13baf9551c95fc424d6390e4b1cfd7b1810cd06/async_property-0.2.2-py2.py3-none-any.whl", hash = "sha256:8924d792b5843994537f8ed411165700b27b2bd966cefc4daeefc1253442a9d7", size = 9546, upload-time = "2023-07-03T17:21:54.293Z" },
]
[[package]]
name = "async-timeout"
version = "5.0.1"
@@ -2919,6 +2942,7 @@ name = "paperless-ngx"
version = "2.20.5"
source = { virtual = "." }
dependencies = [
{ name = "adrf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "azure-ai-documentintelligence", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "babel", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "bleach", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -3067,6 +3091,7 @@ typing = [
[package.metadata]
requires-dist = [
{ name = "adrf", specifier = "~=0.1.12" },
{ name = "azure-ai-documentintelligence", specifier = ">=1.0.2" },
{ name = "babel", specifier = ">=2.17" },
{ name = "bleach", specifier = "~=6.3.0" },