mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-04-02 13:45:10 -05:00
Fix: use state param with oauth (#8636)
This commit is contained in:
parent
0ab21b6fc5
commit
a899ff16e3
@ -1717,10 +1717,12 @@ class UiSettingsView(GenericAPIView):
|
|||||||
manager = PaperlessMailOAuth2Manager()
|
manager = PaperlessMailOAuth2Manager()
|
||||||
if settings.GMAIL_OAUTH_ENABLED:
|
if settings.GMAIL_OAUTH_ENABLED:
|
||||||
ui_settings["gmail_oauth_url"] = manager.get_gmail_authorization_url()
|
ui_settings["gmail_oauth_url"] = manager.get_gmail_authorization_url()
|
||||||
|
request.session["oauth_state"] = manager.state
|
||||||
if settings.OUTLOOK_OAUTH_ENABLED:
|
if settings.OUTLOOK_OAUTH_ENABLED:
|
||||||
ui_settings["outlook_oauth_url"] = (
|
ui_settings["outlook_oauth_url"] = (
|
||||||
manager.get_outlook_authorization_url()
|
manager.get_outlook_authorization_url()
|
||||||
)
|
)
|
||||||
|
request.session["oauth_state"] = manager.state
|
||||||
|
|
||||||
ui_settings["email_enabled"] = settings.EMAIL_ENABLED
|
ui_settings["email_enabled"] = settings.EMAIL_ENABLED
|
||||||
|
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
import secrets
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
@ -13,9 +14,10 @@ from paperless_mail.models import MailAccount
|
|||||||
|
|
||||||
|
|
||||||
class PaperlessMailOAuth2Manager:
|
class PaperlessMailOAuth2Manager:
|
||||||
def __init__(self):
|
def __init__(self, state: str | None = None):
|
||||||
self._gmail_client = None
|
self._gmail_client = None
|
||||||
self._outlook_client = None
|
self._outlook_client = None
|
||||||
|
self.state = state if state is not None else secrets.token_urlsafe(32)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def gmail_client(self) -> GoogleOAuth2:
|
def gmail_client(self) -> GoogleOAuth2:
|
||||||
@ -49,6 +51,7 @@ class PaperlessMailOAuth2Manager:
|
|||||||
redirect_uri=self.oauth_callback_url,
|
redirect_uri=self.oauth_callback_url,
|
||||||
scope=["https://mail.google.com/"],
|
scope=["https://mail.google.com/"],
|
||||||
extras_params={"prompt": "consent", "access_type": "offline"},
|
extras_params={"prompt": "consent", "access_type": "offline"},
|
||||||
|
state=self.state,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -60,6 +63,7 @@ class PaperlessMailOAuth2Manager:
|
|||||||
"offline_access",
|
"offline_access",
|
||||||
"https://outlook.office.com/IMAP.AccessAsUser.All",
|
"https://outlook.office.com/IMAP.AccessAsUser.All",
|
||||||
],
|
],
|
||||||
|
state=self.state,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -109,3 +113,6 @@ class PaperlessMailOAuth2Manager:
|
|||||||
except RefreshTokenError as e:
|
except RefreshTokenError as e:
|
||||||
logger.error(f"Failed to refresh oauth token for account {account}: {e}")
|
logger.error(f"Failed to refresh oauth token for account {account}: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def validate_state(self, state: str) -> bool:
|
||||||
|
return settings.DEBUG or (len(state) > 0 and state == self.state)
|
||||||
|
@ -118,9 +118,17 @@ class TestMailOAuth(
|
|||||||
"expires_in": 3600,
|
"expires_in": 3600,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
session = self.client.session
|
||||||
|
session.update(
|
||||||
|
{
|
||||||
|
"oauth_state": "test_state",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
session.save()
|
||||||
|
|
||||||
# Test Google OAuth callback
|
# Test Google OAuth callback
|
||||||
response = self.client.get(
|
response = self.client.get(
|
||||||
"/api/oauth/callback/?code=test_code&scope=https://mail.google.com/",
|
"/api/oauth/callback/?code=test_code&scope=https://mail.google.com/&state=test_state",
|
||||||
)
|
)
|
||||||
self.assertEqual(response.status_code, status.HTTP_302_FOUND)
|
self.assertEqual(response.status_code, status.HTTP_302_FOUND)
|
||||||
self.assertIn("oauth_success=1", response.url)
|
self.assertIn("oauth_success=1", response.url)
|
||||||
@ -130,7 +138,9 @@ class TestMailOAuth(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Test Outlook OAuth callback
|
# Test Outlook OAuth callback
|
||||||
response = self.client.get("/api/oauth/callback/?code=test_code")
|
response = self.client.get(
|
||||||
|
"/api/oauth/callback/?code=test_code&state=test_state",
|
||||||
|
)
|
||||||
self.assertEqual(response.status_code, status.HTTP_302_FOUND)
|
self.assertEqual(response.status_code, status.HTTP_302_FOUND)
|
||||||
self.assertIn("oauth_success=1", response.url)
|
self.assertIn("oauth_success=1", response.url)
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
@ -150,10 +160,18 @@ class TestMailOAuth(
|
|||||||
"""
|
"""
|
||||||
mock_get_access_token.side_effect = GetAccessTokenError("test_error")
|
mock_get_access_token.side_effect = GetAccessTokenError("test_error")
|
||||||
|
|
||||||
|
session = self.client.session
|
||||||
|
session.update(
|
||||||
|
{
|
||||||
|
"oauth_state": "test_state",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
session.save()
|
||||||
|
|
||||||
with self.assertLogs("paperless_mail", level="ERROR") as cm:
|
with self.assertLogs("paperless_mail", level="ERROR") as cm:
|
||||||
# Test Google OAuth callback
|
# Test Google OAuth callback
|
||||||
response = self.client.get(
|
response = self.client.get(
|
||||||
"/api/oauth/callback/?code=test_code&scope=https://mail.google.com/",
|
"/api/oauth/callback/?code=test_code&scope=https://mail.google.com/&state=test_state",
|
||||||
)
|
)
|
||||||
self.assertEqual(response.status_code, status.HTTP_302_FOUND)
|
self.assertEqual(response.status_code, status.HTTP_302_FOUND)
|
||||||
self.assertIn("oauth_success=0", response.url)
|
self.assertIn("oauth_success=0", response.url)
|
||||||
@ -162,7 +180,9 @@ class TestMailOAuth(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Test Outlook OAuth callback
|
# Test Outlook OAuth callback
|
||||||
response = self.client.get("/api/oauth/callback/?code=test_code")
|
response = self.client.get(
|
||||||
|
"/api/oauth/callback/?code=test_code&state=test_state",
|
||||||
|
)
|
||||||
self.assertEqual(response.status_code, status.HTTP_302_FOUND)
|
self.assertEqual(response.status_code, status.HTTP_302_FOUND)
|
||||||
self.assertIn("oauth_success=0", response.url)
|
self.assertIn("oauth_success=0", response.url)
|
||||||
self.assertFalse(
|
self.assertFalse(
|
||||||
@ -224,6 +244,27 @@ class TestMailOAuth(
|
|||||||
MailAccount.objects.filter(imap_server="outlook.office365.com").exists(),
|
MailAccount.objects.filter(imap_server="outlook.office365.com").exists(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_oauth_callback_view_invalid_state(self):
|
||||||
|
"""
|
||||||
|
GIVEN:
|
||||||
|
- Mocked settings for Gmail and Outlook OAuth client IDs and secrets
|
||||||
|
WHEN:
|
||||||
|
- OAuth callback is called with an invalid state
|
||||||
|
THEN:
|
||||||
|
- 400 bad request returned, no mail accounts are created
|
||||||
|
"""
|
||||||
|
|
||||||
|
response = self.client.get(
|
||||||
|
"/api/oauth/callback/?code=test_code&state=invalid_state",
|
||||||
|
)
|
||||||
|
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
||||||
|
self.assertFalse(
|
||||||
|
MailAccount.objects.filter(imap_server="imap.gmail.com").exists(),
|
||||||
|
)
|
||||||
|
self.assertFalse(
|
||||||
|
MailAccount.objects.filter(imap_server="outlook.office365.com").exists(),
|
||||||
|
)
|
||||||
|
|
||||||
@mock.patch("paperless_mail.mail.get_mailbox")
|
@mock.patch("paperless_mail.mail.get_mailbox")
|
||||||
@mock.patch(
|
@mock.patch(
|
||||||
"httpx_oauth.oauth2.BaseOAuth2.refresh_token",
|
"httpx_oauth.oauth2.BaseOAuth2.refresh_token",
|
||||||
|
@ -128,7 +128,16 @@ class OauthCallbackView(GenericAPIView):
|
|||||||
)
|
)
|
||||||
return HttpResponseBadRequest("Invalid request, see logs for more detail")
|
return HttpResponseBadRequest("Invalid request, see logs for more detail")
|
||||||
|
|
||||||
oauth_manager = PaperlessMailOAuth2Manager()
|
oauth_manager = PaperlessMailOAuth2Manager(
|
||||||
|
state=request.session.get("oauth_state"),
|
||||||
|
)
|
||||||
|
|
||||||
|
state = request.query_params.get("state", "")
|
||||||
|
if not oauth_manager.validate_state(state):
|
||||||
|
logger.error(
|
||||||
|
f"Invalid oauth callback request received state: {state}, expected: {oauth_manager.state}",
|
||||||
|
)
|
||||||
|
return HttpResponseBadRequest("Invalid request, see logs for more detail")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if scope is not None and "google" in scope:
|
if scope is not None and "google" in scope:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user