diff --git a/src/documents/views.py b/src/documents/views.py index d88105ea7..065233a39 100644 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -1717,10 +1717,12 @@ class UiSettingsView(GenericAPIView): manager = PaperlessMailOAuth2Manager() if settings.GMAIL_OAUTH_ENABLED: ui_settings["gmail_oauth_url"] = manager.get_gmail_authorization_url() + request.session["oauth_state"] = manager.state if settings.OUTLOOK_OAUTH_ENABLED: ui_settings["outlook_oauth_url"] = ( manager.get_outlook_authorization_url() ) + request.session["oauth_state"] = manager.state ui_settings["email_enabled"] = settings.EMAIL_ENABLED diff --git a/src/paperless_mail/oauth.py b/src/paperless_mail/oauth.py index 2bf2245bb..f2050451b 100644 --- a/src/paperless_mail/oauth.py +++ b/src/paperless_mail/oauth.py @@ -1,5 +1,6 @@ import asyncio import logging +import secrets from datetime import timedelta from django.conf import settings @@ -13,9 +14,10 @@ from paperless_mail.models import MailAccount class PaperlessMailOAuth2Manager: - def __init__(self): + def __init__(self, state: str | None = None): self._gmail_client = None self._outlook_client = None + self.state = state if state is not None else secrets.token_urlsafe(32) @property def gmail_client(self) -> GoogleOAuth2: @@ -49,6 +51,7 @@ class PaperlessMailOAuth2Manager: redirect_uri=self.oauth_callback_url, scope=["https://mail.google.com/"], extras_params={"prompt": "consent", "access_type": "offline"}, + state=self.state, ), ) @@ -60,6 +63,7 @@ class PaperlessMailOAuth2Manager: "offline_access", "https://outlook.office.com/IMAP.AccessAsUser.All", ], + state=self.state, ), ) @@ -109,3 +113,6 @@ class PaperlessMailOAuth2Manager: except RefreshTokenError as e: logger.error(f"Failed to refresh oauth token for account {account}: {e}") return False + + def validate_state(self, state: str) -> bool: + return settings.DEBUG or (len(state) > 0 and state == self.state) diff --git a/src/paperless_mail/tests/test_mail_oauth.py b/src/paperless_mail/tests/test_mail_oauth.py index 9eb68d3e5..f8f28df65 100644 --- a/src/paperless_mail/tests/test_mail_oauth.py +++ b/src/paperless_mail/tests/test_mail_oauth.py @@ -118,9 +118,17 @@ class TestMailOAuth( "expires_in": 3600, } + session = self.client.session + session.update( + { + "oauth_state": "test_state", + }, + ) + session.save() + # Test Google OAuth callback response = self.client.get( - "/api/oauth/callback/?code=test_code&scope=https://mail.google.com/", + "/api/oauth/callback/?code=test_code&scope=https://mail.google.com/&state=test_state", ) self.assertEqual(response.status_code, status.HTTP_302_FOUND) self.assertIn("oauth_success=1", response.url) @@ -130,7 +138,9 @@ class TestMailOAuth( ) # Test Outlook OAuth callback - response = self.client.get("/api/oauth/callback/?code=test_code") + response = self.client.get( + "/api/oauth/callback/?code=test_code&state=test_state", + ) self.assertEqual(response.status_code, status.HTTP_302_FOUND) self.assertIn("oauth_success=1", response.url) self.assertTrue( @@ -150,10 +160,18 @@ class TestMailOAuth( """ mock_get_access_token.side_effect = GetAccessTokenError("test_error") + session = self.client.session + session.update( + { + "oauth_state": "test_state", + }, + ) + session.save() + with self.assertLogs("paperless_mail", level="ERROR") as cm: # Test Google OAuth callback response = self.client.get( - "/api/oauth/callback/?code=test_code&scope=https://mail.google.com/", + "/api/oauth/callback/?code=test_code&scope=https://mail.google.com/&state=test_state", ) self.assertEqual(response.status_code, status.HTTP_302_FOUND) self.assertIn("oauth_success=0", response.url) @@ -162,7 +180,9 @@ class TestMailOAuth( ) # Test Outlook OAuth callback - response = self.client.get("/api/oauth/callback/?code=test_code") + response = self.client.get( + "/api/oauth/callback/?code=test_code&state=test_state", + ) self.assertEqual(response.status_code, status.HTTP_302_FOUND) self.assertIn("oauth_success=0", response.url) self.assertFalse( @@ -224,6 +244,27 @@ class TestMailOAuth( MailAccount.objects.filter(imap_server="outlook.office365.com").exists(), ) + def test_oauth_callback_view_invalid_state(self): + """ + GIVEN: + - Mocked settings for Gmail and Outlook OAuth client IDs and secrets + WHEN: + - OAuth callback is called with an invalid state + THEN: + - 400 bad request returned, no mail accounts are created + """ + + response = self.client.get( + "/api/oauth/callback/?code=test_code&state=invalid_state", + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertFalse( + MailAccount.objects.filter(imap_server="imap.gmail.com").exists(), + ) + self.assertFalse( + MailAccount.objects.filter(imap_server="outlook.office365.com").exists(), + ) + @mock.patch("paperless_mail.mail.get_mailbox") @mock.patch( "httpx_oauth.oauth2.BaseOAuth2.refresh_token", diff --git a/src/paperless_mail/views.py b/src/paperless_mail/views.py index 170d5c6c1..1b596452f 100644 --- a/src/paperless_mail/views.py +++ b/src/paperless_mail/views.py @@ -128,7 +128,16 @@ class OauthCallbackView(GenericAPIView): ) return HttpResponseBadRequest("Invalid request, see logs for more detail") - oauth_manager = PaperlessMailOAuth2Manager() + oauth_manager = PaperlessMailOAuth2Manager( + state=request.session.get("oauth_state"), + ) + + state = request.query_params.get("state", "") + if not oauth_manager.validate_state(state): + logger.error( + f"Invalid oauth callback request received state: {state}, expected: {oauth_manager.state}", + ) + return HttpResponseBadRequest("Invalid request, see logs for more detail") try: if scope is not None and "google" in scope: