mirror of
				https://github.com/paperless-ngx/paperless-ngx.git
				synced 2025-10-30 03:56:23 -05:00 
			
		
		
		
	Feature: OAuth2 Gmail and Outlook email support (#7866)
This commit is contained in:
		| @@ -2,6 +2,7 @@ import json | ||||
|  | ||||
| from django.contrib.auth.models import Permission | ||||
| from django.contrib.auth.models import User | ||||
| from django.test import override_settings | ||||
| from rest_framework import status | ||||
| from rest_framework.test import APITestCase | ||||
|  | ||||
| @@ -113,3 +114,22 @@ class TestApiUiSettings(DirectoriesMixin, APITestCase): | ||||
|         ) | ||||
|  | ||||
|         self.assertEqual(response.status_code, status.HTTP_200_OK) | ||||
|  | ||||
|     @override_settings( | ||||
|         OAUTH_CALLBACK_BASE_URL="http://localhost:8000", | ||||
|         GMAIL_OAUTH_CLIENT_ID="abc123", | ||||
|         GMAIL_OAUTH_CLIENT_SECRET="def456", | ||||
|         GMAIL_OAUTH_ENABLED=True, | ||||
|         OUTLOOK_OAUTH_CLIENT_ID="ghi789", | ||||
|         OUTLOOK_OAUTH_CLIENT_SECRET="jkl012", | ||||
|         OUTLOOK_OAUTH_ENABLED=True, | ||||
|     ) | ||||
|     def test_settings_includes_oauth_urls_if_enabled(self): | ||||
|         response = self.client.get(self.ENDPOINT, format="json") | ||||
|         self.assertEqual(response.status_code, status.HTTP_200_OK) | ||||
|         self.assertIsNotNone( | ||||
|             response.data["settings"]["gmail_oauth_url"], | ||||
|         ) | ||||
|         self.assertIsNotNone( | ||||
|             response.data["settings"]["outlook_oauth_url"], | ||||
|         ) | ||||
|   | ||||
| @@ -8,7 +8,7 @@ class TestMigrateWorkflow(TestMigrations): | ||||
|     dependencies = ( | ||||
|         ( | ||||
|             "paperless_mail", | ||||
|             "0026_mailrule_enabled", | ||||
|             "0027_mailaccount_expiration_mailaccount_account_type_and_more", | ||||
|         ), | ||||
|     ) | ||||
|  | ||||
|   | ||||
| @@ -162,6 +162,7 @@ from paperless.serialisers import UserSerializer | ||||
| from paperless.views import StandardPagination | ||||
| from paperless_mail.models import MailAccount | ||||
| from paperless_mail.models import MailRule | ||||
| from paperless_mail.oauth import PaperlessMailOAuth2Manager | ||||
| from paperless_mail.serialisers import MailAccountSerializer | ||||
| from paperless_mail.serialisers import MailRuleSerializer | ||||
|  | ||||
| @@ -1605,6 +1606,15 @@ class UiSettingsView(GenericAPIView): | ||||
|  | ||||
|         ui_settings["auditlog_enabled"] = settings.AUDIT_LOG_ENABLED | ||||
|  | ||||
|         if settings.GMAIL_OAUTH_ENABLED or settings.OUTLOOK_OAUTH_ENABLED: | ||||
|             manager = PaperlessMailOAuth2Manager() | ||||
|             if settings.GMAIL_OAUTH_ENABLED: | ||||
|                 ui_settings["gmail_oauth_url"] = manager.get_gmail_authorization_url() | ||||
|             if settings.OUTLOOK_OAUTH_ENABLED: | ||||
|                 ui_settings["outlook_oauth_url"] = ( | ||||
|                     manager.get_outlook_authorization_url() | ||||
|                 ) | ||||
|  | ||||
|         user_resp = { | ||||
|             "id": user.id, | ||||
|             "username": user.username, | ||||
|   | ||||
| @@ -1195,3 +1195,19 @@ EMAIL_ENABLE_GPG_DECRYPTOR: Final[bool] = __get_boolean( | ||||
| # Soft Delete                                                                 # | ||||
| ############################################################################### | ||||
| EMPTY_TRASH_DELAY = max(__get_int("PAPERLESS_EMPTY_TRASH_DELAY", 30), 1) | ||||
|  | ||||
|  | ||||
| ############################################################################### | ||||
| # Oauth Email                                                                 # | ||||
| ############################################################################### | ||||
| OAUTH_CALLBACK_BASE_URL = os.getenv("PAPERLESS_OAUTH_CALLBACK_BASE_URL") | ||||
| GMAIL_OAUTH_CLIENT_ID = os.getenv("PAPERLESS_GMAIL_OAUTH_CLIENT_ID") | ||||
| GMAIL_OAUTH_CLIENT_SECRET = os.getenv("PAPERLESS_GMAIL_OAUTH_CLIENT_SECRET") | ||||
| GMAIL_OAUTH_ENABLED = bool( | ||||
|     OAUTH_CALLBACK_BASE_URL and GMAIL_OAUTH_CLIENT_ID and GMAIL_OAUTH_CLIENT_SECRET, | ||||
| ) | ||||
| OUTLOOK_OAUTH_CLIENT_ID = os.getenv("PAPERLESS_OUTLOOK_OAUTH_CLIENT_ID") | ||||
| OUTLOOK_OAUTH_CLIENT_SECRET = os.getenv("PAPERLESS_OUTLOOK_OAUTH_CLIENT_SECRET") | ||||
| OUTLOOK_OAUTH_ENABLED = bool( | ||||
|     OAUTH_CALLBACK_BASE_URL and OUTLOOK_OAUTH_CLIENT_ID and OUTLOOK_OAUTH_CLIENT_SECRET, | ||||
| ) | ||||
|   | ||||
| @@ -55,6 +55,7 @@ from paperless.views import UserViewSet | ||||
| from paperless_mail.views import MailAccountTestView | ||||
| from paperless_mail.views import MailAccountViewSet | ||||
| from paperless_mail.views import MailRuleViewSet | ||||
| from paperless_mail.views import OauthCallbackView | ||||
|  | ||||
| api_router = DefaultRouter() | ||||
| api_router.register(r"correspondents", CorrespondentViewSet) | ||||
| @@ -171,6 +172,11 @@ urlpatterns = [ | ||||
|                     StoragePathTestView.as_view(), | ||||
|                     name="storage_paths_test", | ||||
|                 ), | ||||
|                 re_path( | ||||
|                     r"^oauth/callback/", | ||||
|                     OauthCallbackView.as_view(), | ||||
|                     name="oauth_callback", | ||||
|                 ), | ||||
|                 *api_router.urls, | ||||
|             ], | ||||
|         ), | ||||
|   | ||||
| @@ -18,6 +18,7 @@ from celery import shared_task | ||||
| from celery.canvas import Signature | ||||
| from django.conf import settings | ||||
| from django.db import DatabaseError | ||||
| from django.utils import timezone | ||||
| from django.utils.timezone import is_naive | ||||
| from django.utils.timezone import make_aware | ||||
| from imap_tools import AND | ||||
| @@ -42,6 +43,7 @@ from documents.tasks import consume_file | ||||
| from paperless_mail.models import MailAccount | ||||
| from paperless_mail.models import MailRule | ||||
| from paperless_mail.models import ProcessedMail | ||||
| from paperless_mail.oauth import PaperlessMailOAuth2Manager | ||||
| from paperless_mail.preprocessor import MailMessageDecryptor | ||||
| from paperless_mail.preprocessor import MailMessagePreprocessor | ||||
|  | ||||
| @@ -530,6 +532,17 @@ class MailAccountHandler(LoggingMixin): | ||||
|                 account.imap_port, | ||||
|                 account.imap_security, | ||||
|             ) as M: | ||||
|                 if ( | ||||
|                     account.is_token | ||||
|                     and account.expiration is not None | ||||
|                     and account.expiration < timezone.now() | ||||
|                 ): | ||||
|                     manager = PaperlessMailOAuth2Manager() | ||||
|                     if manager.refresh_account_oauth_token(account): | ||||
|                         account.refresh_from_db() | ||||
|                     else: | ||||
|                         return total_processed_files | ||||
|  | ||||
|                 supports_gmail_labels = "X-GM-EXT-1" in M.client.capabilities | ||||
|                 supports_auth_plain = "AUTH=PLAIN" in M.client.capabilities | ||||
|  | ||||
|   | ||||
| @@ -0,0 +1,43 @@ | ||||
| # Generated by Django 5.1.1 on 2024-10-05 17:12 | ||||
|  | ||||
| from django.db import migrations | ||||
| from django.db import models | ||||
|  | ||||
|  | ||||
| class Migration(migrations.Migration): | ||||
|     dependencies = [ | ||||
|         ("paperless_mail", "0026_mailrule_enabled"), | ||||
|     ] | ||||
|  | ||||
|     operations = [ | ||||
|         migrations.AddField( | ||||
|             model_name="mailaccount", | ||||
|             name="expiration", | ||||
|             field=models.DateTimeField( | ||||
|                 blank=True, | ||||
|                 help_text="The expiration date of the refresh token. ", | ||||
|                 null=True, | ||||
|                 verbose_name="expiration", | ||||
|             ), | ||||
|         ), | ||||
|         migrations.AddField( | ||||
|             model_name="mailaccount", | ||||
|             name="account_type", | ||||
|             field=models.PositiveIntegerField( | ||||
|                 choices=[(1, "IMAP"), (2, "Gmail OAuth"), (3, "Outlook OAuth")], | ||||
|                 default=1, | ||||
|                 verbose_name="account type", | ||||
|             ), | ||||
|         ), | ||||
|         migrations.AddField( | ||||
|             model_name="mailaccount", | ||||
|             name="refresh_token", | ||||
|             field=models.CharField( | ||||
|                 blank=True, | ||||
|                 help_text="The refresh token to use for token authentication e.g. with oauth2.", | ||||
|                 max_length=2048, | ||||
|                 null=True, | ||||
|                 verbose_name="refresh token", | ||||
|             ), | ||||
|         ), | ||||
|     ] | ||||
| @@ -15,6 +15,11 @@ class MailAccount(document_models.ModelWithOwner): | ||||
|         SSL = 2, _("Use SSL") | ||||
|         STARTTLS = 3, _("Use STARTTLS") | ||||
|  | ||||
|     class MailAccountType(models.IntegerChoices): | ||||
|         IMAP = 1, _("IMAP") | ||||
|         GMAIL_OAUTH = 2, _("Gmail OAuth") | ||||
|         OUTLOOK_OAUTH = 3, _("Outlook OAuth") | ||||
|  | ||||
|     name = models.CharField(_("name"), max_length=256, unique=True) | ||||
|  | ||||
|     imap_server = models.CharField(_("IMAP server"), max_length=256) | ||||
| @@ -51,6 +56,31 @@ class MailAccount(document_models.ModelWithOwner): | ||||
|         ), | ||||
|     ) | ||||
|  | ||||
|     account_type = models.PositiveIntegerField( | ||||
|         _("account type"), | ||||
|         choices=MailAccountType.choices, | ||||
|         default=MailAccountType.IMAP, | ||||
|     ) | ||||
|  | ||||
|     refresh_token = models.CharField( | ||||
|         _("refresh token"), | ||||
|         max_length=2048, | ||||
|         blank=True, | ||||
|         null=True, | ||||
|         help_text=_( | ||||
|             "The refresh token to use for token authentication e.g. with oauth2.", | ||||
|         ), | ||||
|     ) | ||||
|  | ||||
|     expiration = models.DateTimeField( | ||||
|         _("expiration"), | ||||
|         blank=True, | ||||
|         null=True, | ||||
|         help_text=_( | ||||
|             "The expiration date of the refresh token. ", | ||||
|         ), | ||||
|     ) | ||||
|  | ||||
|     def __str__(self): | ||||
|         return self.name | ||||
|  | ||||
|   | ||||
							
								
								
									
										111
									
								
								src/paperless_mail/oauth.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										111
									
								
								src/paperless_mail/oauth.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,111 @@ | ||||
| import asyncio | ||||
| import logging | ||||
| from datetime import timedelta | ||||
|  | ||||
| from django.conf import settings | ||||
| from django.utils import timezone | ||||
| from httpx_oauth.clients.google import GoogleOAuth2 | ||||
| from httpx_oauth.clients.microsoft import MicrosoftGraphOAuth2 | ||||
| from httpx_oauth.oauth2 import OAuth2Token | ||||
| from httpx_oauth.oauth2 import RefreshTokenError | ||||
|  | ||||
| from paperless_mail.models import MailAccount | ||||
|  | ||||
|  | ||||
| class PaperlessMailOAuth2Manager: | ||||
|     def __init__(self): | ||||
|         self._gmail_client = None | ||||
|         self._outlook_client = None | ||||
|  | ||||
|     @property | ||||
|     def gmail_client(self) -> GoogleOAuth2: | ||||
|         if self._gmail_client is None: | ||||
|             self._gmail_client = GoogleOAuth2( | ||||
|                 settings.GMAIL_OAUTH_CLIENT_ID, | ||||
|                 settings.GMAIL_OAUTH_CLIENT_SECRET, | ||||
|             ) | ||||
|         return self._gmail_client | ||||
|  | ||||
|     @property | ||||
|     def outlook_client(self) -> MicrosoftGraphOAuth2: | ||||
|         if self._outlook_client is None: | ||||
|             self._outlook_client = MicrosoftGraphOAuth2( | ||||
|                 settings.OUTLOOK_OAUTH_CLIENT_ID, | ||||
|                 settings.OUTLOOK_OAUTH_CLIENT_SECRET, | ||||
|             ) | ||||
|         return self._outlook_client | ||||
|  | ||||
|     @property | ||||
|     def oauth_callback_url(self) -> str: | ||||
|         return f"{settings.OAUTH_CALLBACK_BASE_URL if settings.OAUTH_CALLBACK_BASE_URL is not None else settings.PAPERLESS_URL}{settings.BASE_URL}api/oauth/callback/" | ||||
|  | ||||
|     @property | ||||
|     def oauth_redirect_url(self) -> str: | ||||
|         return f"{'http://localhost:4200/' if settings.DEBUG else settings.BASE_URL}mail"  # e.g. "http://localhost:4200/mail" or "/mail" | ||||
|  | ||||
|     def get_gmail_authorization_url(self) -> str: | ||||
|         return asyncio.run( | ||||
|             self.gmail_client.get_authorization_url( | ||||
|                 redirect_uri=self.oauth_callback_url, | ||||
|                 scope=["https://mail.google.com/"], | ||||
|                 extras_params={"prompt": "consent", "access_type": "offline"}, | ||||
|             ), | ||||
|         ) | ||||
|  | ||||
|     def get_outlook_authorization_url(self) -> str: | ||||
|         return asyncio.run( | ||||
|             self.outlook_client.get_authorization_url( | ||||
|                 redirect_uri=self.oauth_callback_url, | ||||
|                 scope=[ | ||||
|                     "offline_access", | ||||
|                     "https://outlook.office.com/IMAP.AccessAsUser.All", | ||||
|                 ], | ||||
|             ), | ||||
|         ) | ||||
|  | ||||
|     def get_gmail_access_token(self, code: str) -> OAuth2Token: | ||||
|         return asyncio.run( | ||||
|             self.gmail_client.get_access_token( | ||||
|                 code=code, | ||||
|                 redirect_uri=self.oauth_callback_url, | ||||
|             ), | ||||
|         ) | ||||
|  | ||||
|     def get_outlook_access_token(self, code: str) -> OAuth2Token: | ||||
|         return asyncio.run( | ||||
|             self.outlook_client.get_access_token( | ||||
|                 code=code, | ||||
|                 redirect_uri=self.oauth_callback_url, | ||||
|             ), | ||||
|         ) | ||||
|  | ||||
|     def refresh_account_oauth_token(self, account: MailAccount) -> bool: | ||||
|         """ | ||||
|         Refreshes the oauth token for the given mail account. | ||||
|         """ | ||||
|         logger = logging.getLogger("paperless_mail") | ||||
|         logger.debug(f"Attempting to refresh oauth token for account {account}") | ||||
|         try: | ||||
|             result: OAuth2Token | ||||
|             if account.account_type == MailAccount.MailAccountType.GMAIL_OAUTH: | ||||
|                 result = asyncio.run( | ||||
|                     self.gmail_client.refresh_token( | ||||
|                         refresh_token=account.refresh_token, | ||||
|                     ), | ||||
|                 ) | ||||
|             elif account.account_type == MailAccount.MailAccountType.OUTLOOK_OAUTH: | ||||
|                 result = asyncio.run( | ||||
|                     self.outlook_client.refresh_token( | ||||
|                         refresh_token=account.refresh_token, | ||||
|                     ), | ||||
|                 ) | ||||
|             account.password = result["access_token"] | ||||
|             account.expiration = timezone.now() + timedelta( | ||||
|                 seconds=result["expires_in"], | ||||
|             ) | ||||
|             account.save() | ||||
|             logger.debug(f"Successfully refreshed oauth token for account {account}") | ||||
|             return True | ||||
|         except RefreshTokenError as e: | ||||
|             logger.error(f"Failed to refresh oauth token for account {account}: {e}") | ||||
|             return False | ||||
| @@ -39,6 +39,8 @@ class MailAccountSerializer(OwnedObjectSerializer): | ||||
|             "user_can_change", | ||||
|             "permissions", | ||||
|             "set_permissions", | ||||
|             "account_type", | ||||
|             "expiration", | ||||
|         ] | ||||
|  | ||||
|     def update(self, instance, validated_data): | ||||
|   | ||||
| @@ -4,9 +4,11 @@ import random | ||||
| import uuid | ||||
| from collections import namedtuple | ||||
| from contextlib import AbstractContextManager | ||||
| from datetime import timedelta | ||||
| from unittest import mock | ||||
|  | ||||
| import pytest | ||||
| from django.contrib.auth.models import User | ||||
| from django.core.management import call_command | ||||
| from django.db import DatabaseError | ||||
| from django.test import TestCase | ||||
| @@ -19,6 +21,8 @@ from imap_tools import MailboxLoginError | ||||
| from imap_tools import MailMessage | ||||
| from imap_tools import MailMessageFlags | ||||
| from imap_tools import errors | ||||
| from rest_framework import status | ||||
| from rest_framework.test import APITestCase | ||||
|  | ||||
| from documents.models import Correspondent | ||||
| from documents.tests.utils import DirectoriesMixin | ||||
| @@ -1590,3 +1594,128 @@ class TestTasks(TestCase): | ||||
|  | ||||
|         tasks.process_mail_accounts() | ||||
|         self.assertEqual(m.call_count, 0) | ||||
|  | ||||
|  | ||||
| class TestMailAccountTestView(APITestCase): | ||||
|     def setUp(self): | ||||
|         self.mailMocker = MailMocker() | ||||
|         self.mailMocker.setUp() | ||||
|         self.user = User.objects.create_user( | ||||
|             username="testuser", | ||||
|             password="testpassword", | ||||
|         ) | ||||
|         self.client.force_authenticate(user=self.user) | ||||
|         self.url = "/api/mail_accounts/test/" | ||||
|  | ||||
|     def test_mail_account_test_view_success(self): | ||||
|         data = { | ||||
|             "imap_server": "imap.example.com", | ||||
|             "imap_port": 993, | ||||
|             "imap_security": MailAccount.ImapSecurity.SSL, | ||||
|             "username": "admin", | ||||
|             "password": "secret", | ||||
|             "account_type": MailAccount.MailAccountType.IMAP, | ||||
|             "is_token": False, | ||||
|         } | ||||
|         response = self.client.post(self.url, data, format="json") | ||||
|         self.assertEqual(response.status_code, status.HTTP_200_OK) | ||||
|         self.assertEqual(response.data, {"success": True}) | ||||
|  | ||||
|     def test_mail_account_test_view_mail_error(self): | ||||
|         data = { | ||||
|             "imap_server": "imap.example.com", | ||||
|             "imap_port": 993, | ||||
|             "imap_security": MailAccount.ImapSecurity.SSL, | ||||
|             "username": "admin", | ||||
|             "password": "wrong", | ||||
|             "account_type": MailAccount.MailAccountType.IMAP, | ||||
|             "is_token": False, | ||||
|         } | ||||
|         response = self.client.post(self.url, data, format="json") | ||||
|         self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) | ||||
|         self.assertEqual(response.content.decode(), "Unable to connect to server") | ||||
|  | ||||
|     @mock.patch( | ||||
|         "paperless_mail.oauth.PaperlessMailOAuth2Manager.refresh_account_oauth_token", | ||||
|     ) | ||||
|     def test_mail_account_test_view_refresh_token( | ||||
|         self, | ||||
|         mock_refresh_account_oauth_token, | ||||
|     ): | ||||
|         """ | ||||
|         GIVEN: | ||||
|             - Mail account with expired token | ||||
|         WHEN: | ||||
|             - Mail account is tested | ||||
|         THEN: | ||||
|             - Should refresh the token | ||||
|         """ | ||||
|         existing_account = MailAccount.objects.create( | ||||
|             imap_server="imap.example.com", | ||||
|             imap_port=993, | ||||
|             imap_security=MailAccount.ImapSecurity.SSL, | ||||
|             username="admin", | ||||
|             password="secret", | ||||
|             account_type=MailAccount.MailAccountType.GMAIL_OAUTH, | ||||
|             refresh_token="oldtoken", | ||||
|             expiration=timezone.now() - timedelta(days=1), | ||||
|             is_token=True, | ||||
|         ) | ||||
|  | ||||
|         mock_refresh_account_oauth_token.return_value = True | ||||
|         data = { | ||||
|             "id": existing_account.id, | ||||
|             "imap_server": "imap.example.com", | ||||
|             "imap_port": 993, | ||||
|             "imap_security": MailAccount.ImapSecurity.SSL, | ||||
|             "username": "admin", | ||||
|             "password": "****", | ||||
|             "is_token": True, | ||||
|         } | ||||
|         self.client.post(self.url, data, format="json") | ||||
|         self.assertEqual(mock_refresh_account_oauth_token.call_count, 1) | ||||
|  | ||||
|     @mock.patch( | ||||
|         "paperless_mail.oauth.PaperlessMailOAuth2Manager.refresh_account_oauth_token", | ||||
|     ) | ||||
|     def test_mail_account_test_view_refresh_token_fails( | ||||
|         self, | ||||
|         mock_mock_refresh_account_oauth_token, | ||||
|     ): | ||||
|         """ | ||||
|         GIVEN: | ||||
|             - Mail account with expired token | ||||
|         WHEN: | ||||
|             - Mail account is tested | ||||
|             - Token refresh fails | ||||
|         THEN: | ||||
|             - Should log an error | ||||
|         """ | ||||
|         existing_account = MailAccount.objects.create( | ||||
|             imap_server="imap.example.com", | ||||
|             imap_port=993, | ||||
|             imap_security=MailAccount.ImapSecurity.SSL, | ||||
|             username="admin", | ||||
|             password="secret", | ||||
|             account_type=MailAccount.MailAccountType.GMAIL_OAUTH, | ||||
|             refresh_token="oldtoken", | ||||
|             expiration=timezone.now() - timedelta(days=1), | ||||
|             is_token=True, | ||||
|         ) | ||||
|  | ||||
|         mock_mock_refresh_account_oauth_token.return_value = False | ||||
|         data = { | ||||
|             "id": existing_account.id, | ||||
|             "imap_server": "imap.example.com", | ||||
|             "imap_port": 993, | ||||
|             "imap_security": MailAccount.ImapSecurity.SSL, | ||||
|             "username": "admin", | ||||
|             "password": "****", | ||||
|             "is_token": True, | ||||
|         } | ||||
|         with self.assertLogs("paperless_mail", level="ERROR") as cm: | ||||
|             response = self.client.post(self.url, data, format="json") | ||||
|             self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) | ||||
|             error_str = cm.output[0] | ||||
|             expected_str = "Unable to refresh oauth token" | ||||
|             self.assertIn(expected_str, error_str) | ||||
|   | ||||
							
								
								
									
										334
									
								
								src/paperless_mail/tests/test_mail_oauth.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										334
									
								
								src/paperless_mail/tests/test_mail_oauth.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,334 @@ | ||||
| from datetime import timedelta | ||||
| from unittest import mock | ||||
|  | ||||
| from django.conf import settings | ||||
| from django.contrib.auth.models import Permission | ||||
| from django.contrib.auth.models import User | ||||
| from django.test import TestCase | ||||
| from django.test import override_settings | ||||
| from django.utils import timezone | ||||
| from httpx_oauth.oauth2 import GetAccessTokenError | ||||
| from httpx_oauth.oauth2 import RefreshTokenError | ||||
| from rest_framework import status | ||||
|  | ||||
| from paperless_mail.mail import MailAccountHandler | ||||
| from paperless_mail.models import MailAccount | ||||
| from paperless_mail.oauth import PaperlessMailOAuth2Manager | ||||
|  | ||||
|  | ||||
| class TestMailOAuth( | ||||
|     TestCase, | ||||
| ): | ||||
|     def setUp(self) -> None: | ||||
|         self.user = User.objects.create_user("testuser") | ||||
|         self.user.user_permissions.add( | ||||
|             *Permission.objects.filter( | ||||
|                 codename__in=[ | ||||
|                     "add_mailaccount", | ||||
|                 ], | ||||
|             ), | ||||
|         ) | ||||
|         self.user.save() | ||||
|         self.client.force_login(self.user) | ||||
|         self.mail_account_handler = MailAccountHandler() | ||||
|         # Mock settings | ||||
|         settings.OAUTH_CALLBACK_BASE_URL = "http://localhost:8000" | ||||
|         settings.GMAIL_OAUTH_CLIENT_ID = "test_gmail_client_id" | ||||
|         settings.GMAIL_OAUTH_CLIENT_SECRET = "test_gmail_client_secret" | ||||
|         settings.OUTLOOK_OAUTH_CLIENT_ID = "test_outlook_client_id" | ||||
|         settings.OUTLOOK_OAUTH_CLIENT_SECRET = "test_outlook_client_secret" | ||||
|         super().setUp() | ||||
|  | ||||
|     def test_generate_paths(self): | ||||
|         """ | ||||
|         GIVEN: | ||||
|             - Mocked settings for OAuth callback and base URLs | ||||
|         WHEN: | ||||
|             - get_oauth_callback_url and get_oauth_redirect_url are called | ||||
|         THEN: | ||||
|             - Correct URLs are generated | ||||
|         """ | ||||
|         # Callback URL | ||||
|         oauth_manager = PaperlessMailOAuth2Manager() | ||||
|         with override_settings(OAUTH_CALLBACK_BASE_URL="http://paperless.example.com"): | ||||
|             self.assertEqual( | ||||
|                 oauth_manager.oauth_callback_url, | ||||
|                 "http://paperless.example.com/api/oauth/callback/", | ||||
|             ) | ||||
|         with override_settings( | ||||
|             OAUTH_CALLBACK_BASE_URL=None, | ||||
|             PAPERLESS_URL="http://paperless.example.com", | ||||
|         ): | ||||
|             self.assertEqual( | ||||
|                 oauth_manager.oauth_callback_url, | ||||
|                 "http://paperless.example.com/api/oauth/callback/", | ||||
|             ) | ||||
|         with override_settings( | ||||
|             OAUTH_CALLBACK_BASE_URL=None, | ||||
|             PAPERLESS_URL="http://paperless.example.com", | ||||
|             BASE_URL="/paperless/", | ||||
|         ): | ||||
|             self.assertEqual( | ||||
|                 oauth_manager.oauth_callback_url, | ||||
|                 "http://paperless.example.com/paperless/api/oauth/callback/", | ||||
|             ) | ||||
|  | ||||
|         # Redirect URL | ||||
|         with override_settings(DEBUG=True): | ||||
|             self.assertEqual( | ||||
|                 oauth_manager.oauth_redirect_url, | ||||
|                 "http://localhost:4200/mail", | ||||
|             ) | ||||
|         with override_settings(DEBUG=False): | ||||
|             self.assertEqual( | ||||
|                 oauth_manager.oauth_redirect_url, | ||||
|                 "/mail", | ||||
|             ) | ||||
|  | ||||
|     @mock.patch( | ||||
|         "paperless_mail.oauth.PaperlessMailOAuth2Manager.get_gmail_access_token", | ||||
|     ) | ||||
|     @mock.patch( | ||||
|         "paperless_mail.oauth.PaperlessMailOAuth2Manager.get_outlook_access_token", | ||||
|     ) | ||||
|     def test_oauth_callback_view_success( | ||||
|         self, | ||||
|         mock_get_outlook_access_token, | ||||
|         mock_get_gmail_access_token, | ||||
|     ): | ||||
|         """ | ||||
|         GIVEN: | ||||
|             - Mocked settings for Gmail and Outlook OAuth client IDs and secrets | ||||
|         WHEN: | ||||
|             - OAuth callback is called with a code and scope | ||||
|             - OAuth callback is called with a code and no scope | ||||
|         THEN: | ||||
|             - Gmail mail account is created | ||||
|             - Outlook mail account is created | ||||
|         """ | ||||
|  | ||||
|         mock_get_gmail_access_token.return_value = { | ||||
|             "access_token": "test_access_token", | ||||
|             "refresh_token": "test_refresh_token", | ||||
|             "expires_in": 3600, | ||||
|         } | ||||
|         mock_get_outlook_access_token.return_value = { | ||||
|             "access_token": "test_access_token", | ||||
|             "refresh_token": "test_refresh_token", | ||||
|             "expires_in": 3600, | ||||
|         } | ||||
|  | ||||
|         # Test Google OAuth callback | ||||
|         response = self.client.get( | ||||
|             "/api/oauth/callback/?code=test_code&scope=https://mail.google.com/", | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, status.HTTP_302_FOUND) | ||||
|         self.assertIn("oauth_success=1", response.url) | ||||
|         mock_get_gmail_access_token.assert_called_once() | ||||
|         self.assertTrue( | ||||
|             MailAccount.objects.filter(imap_server="imap.gmail.com").exists(), | ||||
|         ) | ||||
|  | ||||
|         # Test Outlook OAuth callback | ||||
|         response = self.client.get("/api/oauth/callback/?code=test_code") | ||||
|         self.assertEqual(response.status_code, status.HTTP_302_FOUND) | ||||
|         self.assertIn("oauth_success=1", response.url) | ||||
|         self.assertTrue( | ||||
|             MailAccount.objects.filter(imap_server="outlook.office365.com").exists(), | ||||
|         ) | ||||
|  | ||||
|     @mock.patch("httpx_oauth.oauth2.BaseOAuth2.get_access_token") | ||||
|     def test_oauth_callback_view_fails(self, mock_get_access_token): | ||||
|         """ | ||||
|         GIVEN: | ||||
|             - Mocked settings for Gmail and Outlook OAuth client IDs and secrets | ||||
|         WHEN: | ||||
|             - OAuth callback is called and get access token returns an error | ||||
|         THEN: | ||||
|             - No mail account is created | ||||
|             - Error is logged | ||||
|         """ | ||||
|         mock_get_access_token.side_effect = GetAccessTokenError("test_error") | ||||
|  | ||||
|         with self.assertLogs("paperless_mail", level="ERROR") as cm: | ||||
|             # Test Google OAuth callback | ||||
|             response = self.client.get( | ||||
|                 "/api/oauth/callback/?code=test_code&scope=https://mail.google.com/", | ||||
|             ) | ||||
|             self.assertEqual(response.status_code, status.HTTP_302_FOUND) | ||||
|             self.assertIn("oauth_success=0", response.url) | ||||
|             self.assertFalse( | ||||
|                 MailAccount.objects.filter(imap_server="imap.gmail.com").exists(), | ||||
|             ) | ||||
|  | ||||
|             # Test Outlook OAuth callback | ||||
|             response = self.client.get("/api/oauth/callback/?code=test_code") | ||||
|             self.assertEqual(response.status_code, status.HTTP_302_FOUND) | ||||
|             self.assertIn("oauth_success=0", response.url) | ||||
|             self.assertFalse( | ||||
|                 MailAccount.objects.filter( | ||||
|                     imap_server="outlook.office365.com", | ||||
|                 ).exists(), | ||||
|             ) | ||||
|  | ||||
|             self.assertIn("Error getting access token: test_error", cm.output[0]) | ||||
|  | ||||
|     def test_oauth_callback_view_insufficient_permissions(self): | ||||
|         """ | ||||
|         GIVEN: | ||||
|             - Mocked settings for Gmail and Outlook OAuth client IDs and secrets | ||||
|             - User without add_mailaccount permission | ||||
|         WHEN: | ||||
|             - OAuth callback is called | ||||
|         THEN: | ||||
|             - 400 bad request returned, no mail accounts are created | ||||
|         """ | ||||
|         self.user.user_permissions.remove( | ||||
|             *Permission.objects.filter( | ||||
|                 codename__in=[ | ||||
|                     "add_mailaccount", | ||||
|                 ], | ||||
|             ), | ||||
|         ) | ||||
|         self.user.save() | ||||
|  | ||||
|         response = self.client.get( | ||||
|             "/api/oauth/callback/?code=test_code&scope=https://mail.google.com/", | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) | ||||
|         self.assertFalse( | ||||
|             MailAccount.objects.filter(imap_server="imap.gmail.com").exists(), | ||||
|         ) | ||||
|         self.assertFalse( | ||||
|             MailAccount.objects.filter(imap_server="outlook.office365.com").exists(), | ||||
|         ) | ||||
|  | ||||
|     def test_oauth_callback_view_no_code(self): | ||||
|         """ | ||||
|         GIVEN: | ||||
|             - Mocked settings for Gmail and Outlook OAuth client IDs and secrets | ||||
|         WHEN: | ||||
|             - OAuth callback is called without a code | ||||
|         THEN: | ||||
|             - 400 bad request returned, no mail accounts are created | ||||
|         """ | ||||
|  | ||||
|         response = self.client.get( | ||||
|             "/api/oauth/callback/", | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) | ||||
|         self.assertFalse( | ||||
|             MailAccount.objects.filter(imap_server="imap.gmail.com").exists(), | ||||
|         ) | ||||
|         self.assertFalse( | ||||
|             MailAccount.objects.filter(imap_server="outlook.office365.com").exists(), | ||||
|         ) | ||||
|  | ||||
|     @mock.patch("paperless_mail.mail.get_mailbox") | ||||
|     @mock.patch( | ||||
|         "httpx_oauth.oauth2.BaseOAuth2.refresh_token", | ||||
|     ) | ||||
|     def test_refresh_token_on_handle_mail_account( | ||||
|         self, | ||||
|         mock_refresh_token, | ||||
|         mock_get_mailbox, | ||||
|     ): | ||||
|         """ | ||||
|         GIVEN: | ||||
|             - Mail account with refresh token and expiration | ||||
|         WHEN: | ||||
|             - handle_mail_account is called | ||||
|         THEN: | ||||
|             - Refresh token is called | ||||
|         """ | ||||
|  | ||||
|         mock_mailbox = mock.MagicMock() | ||||
|         mock_get_mailbox.return_value.__enter__.return_value = mock_mailbox | ||||
|  | ||||
|         mail_account = MailAccount.objects.create( | ||||
|             name="Test Gmail Mail Account", | ||||
|             username="test_username", | ||||
|             imap_security=MailAccount.ImapSecurity.SSL, | ||||
|             imap_port=993, | ||||
|             account_type=MailAccount.MailAccountType.GMAIL_OAUTH, | ||||
|             is_token=True, | ||||
|             refresh_token="test_refresh_token", | ||||
|             expiration=timezone.now() - timedelta(days=1), | ||||
|         ) | ||||
|  | ||||
|         mock_refresh_token.return_value = { | ||||
|             "access_token": "test_access_token", | ||||
|             "refresh_token": "test_refresh_token", | ||||
|             "expires_in": 3600, | ||||
|         } | ||||
|  | ||||
|         self.mail_account_handler.handle_mail_account(mail_account) | ||||
|         mock_refresh_token.assert_called_once() | ||||
|         mock_refresh_token.reset_mock() | ||||
|  | ||||
|         mock_refresh_token.return_value = { | ||||
|             "access_token": "test_access_token", | ||||
|             "refresh_token": "test_refresh", | ||||
|             "expires_in": 3600, | ||||
|         } | ||||
|         outlook_mail_account = MailAccount.objects.create( | ||||
|             name="Test Outlook Mail Account", | ||||
|             username="test_username", | ||||
|             imap_security=MailAccount.ImapSecurity.SSL, | ||||
|             imap_port=993, | ||||
|             account_type=MailAccount.MailAccountType.OUTLOOK_OAUTH, | ||||
|             is_token=True, | ||||
|             refresh_token="test_refresh_token", | ||||
|             expiration=timezone.now() - timedelta(days=1), | ||||
|         ) | ||||
|  | ||||
|         self.mail_account_handler.handle_mail_account(outlook_mail_account) | ||||
|         mock_refresh_token.assert_called_once() | ||||
|  | ||||
|     @mock.patch("paperless_mail.mail.get_mailbox") | ||||
|     @mock.patch( | ||||
|         "httpx_oauth.oauth2.BaseOAuth2.refresh_token", | ||||
|     ) | ||||
|     def test_refresh_token_on_handle_mail_account_fails( | ||||
|         self, | ||||
|         mock_refresh_token, | ||||
|         mock_get_mailbox, | ||||
|     ): | ||||
|         """ | ||||
|         GIVEN: | ||||
|             - Mail account with refresh token and expiration | ||||
|         WHEN: | ||||
|             - handle_mail_account is called | ||||
|             - Refresh token is called but fails | ||||
|         THEN: | ||||
|             - Error is logged | ||||
|             - 0 processed mails is returned | ||||
|         """ | ||||
|  | ||||
|         mock_mailbox = mock.MagicMock() | ||||
|         mock_get_mailbox.return_value.__enter__.return_value = mock_mailbox | ||||
|  | ||||
|         mail_account = MailAccount.objects.create( | ||||
|             name="Test Gmail Mail Account", | ||||
|             username="test_username", | ||||
|             imap_security=MailAccount.ImapSecurity.SSL, | ||||
|             imap_port=993, | ||||
|             account_type=MailAccount.MailAccountType.GMAIL_OAUTH, | ||||
|             is_token=True, | ||||
|             refresh_token="test_refresh_token", | ||||
|             expiration=timezone.now() - timedelta(days=1), | ||||
|         ) | ||||
|  | ||||
|         mock_refresh_token.side_effect = RefreshTokenError("test_error") | ||||
|  | ||||
|         with self.assertLogs("paperless_mail", level="ERROR") as cm: | ||||
|             # returns 0 processed mails | ||||
|             self.assertEqual( | ||||
|                 self.mail_account_handler.handle_mail_account(mail_account), | ||||
|                 0, | ||||
|             ) | ||||
|             mock_refresh_token.assert_called_once() | ||||
|             self.assertIn( | ||||
|                 f"Failed to refresh oauth token for account {mail_account}: test_error", | ||||
|                 cm.output[0], | ||||
|             ) | ||||
| @@ -1,7 +1,11 @@ | ||||
| import datetime | ||||
| import logging | ||||
| from datetime import timedelta | ||||
|  | ||||
| from django.http import HttpResponseBadRequest | ||||
| from django.http import HttpResponseRedirect | ||||
| from django.utils import timezone | ||||
| from httpx_oauth.oauth2 import GetAccessTokenError | ||||
| from rest_framework.generics import GenericAPIView | ||||
| from rest_framework.permissions import IsAuthenticated | ||||
| from rest_framework.response import Response | ||||
| @@ -16,6 +20,7 @@ from paperless_mail.mail import get_mailbox | ||||
| from paperless_mail.mail import mailbox_login | ||||
| from paperless_mail.models import MailAccount | ||||
| from paperless_mail.models import MailRule | ||||
| from paperless_mail.oauth import PaperlessMailOAuth2Manager | ||||
| from paperless_mail.serialisers import MailAccountSerializer | ||||
| from paperless_mail.serialisers import MailRuleSerializer | ||||
|  | ||||
| @@ -50,27 +55,114 @@ class MailAccountTestView(GenericAPIView): | ||||
|         serializer = self.get_serializer(data=request.data) | ||||
|         serializer.is_valid(raise_exception=True) | ||||
|  | ||||
|         # account exists, use the password from there instead of *** | ||||
|         # account exists, use the password from there instead of *** and refresh_token / expiration | ||||
|         if ( | ||||
|             len(serializer.validated_data.get("password").replace("*", "")) == 0 | ||||
|             and request.data["id"] is not None | ||||
|         ): | ||||
|             serializer.validated_data["password"] = MailAccount.objects.get( | ||||
|                 pk=request.data["id"], | ||||
|             ).password | ||||
|             existing_account = MailAccount.objects.get(pk=request.data["id"]) | ||||
|             serializer.validated_data["password"] = existing_account.password | ||||
|             serializer.validated_data["account_type"] = existing_account.account_type | ||||
|             serializer.validated_data["refresh_token"] = existing_account.refresh_token | ||||
|             serializer.validated_data["expiration"] = existing_account.expiration | ||||
|  | ||||
|         account = MailAccount(**serializer.validated_data) | ||||
|  | ||||
|         with get_mailbox( | ||||
|             account.imap_server, | ||||
|             account.imap_port, | ||||
|             account.imap_security, | ||||
|         ) as M: | ||||
|             try: | ||||
|                 if ( | ||||
|                     account.is_token | ||||
|                     and account.expiration is not None | ||||
|                     and account.expiration < timezone.now() | ||||
|                 ): | ||||
|                     oauth_manager = PaperlessMailOAuth2Manager() | ||||
|                     if oauth_manager.refresh_account_oauth_token(existing_account): | ||||
|                         # User is not changing password and token needs to be refreshed | ||||
|                         existing_account.refresh_from_db() | ||||
|                         account.password = existing_account.password | ||||
|                     else: | ||||
|                         raise MailError("Unable to refresh oauth token") | ||||
|  | ||||
|                 mailbox_login(M, account) | ||||
|                 return Response({"success": True}) | ||||
|             except MailError: | ||||
|             except MailError as e: | ||||
|                 logger.error( | ||||
|                     f"Mail account {account} test failed", | ||||
|                     f"Mail account {account} test failed: {e}", | ||||
|                 ) | ||||
|                 return HttpResponseBadRequest("Unable to connect to server") | ||||
|  | ||||
|  | ||||
| class OauthCallbackView(GenericAPIView): | ||||
|     permission_classes = (IsAuthenticated,) | ||||
|  | ||||
|     def get(self, request, format=None): | ||||
|         if not ( | ||||
|             request.user and request.user.has_perms(["paperless_mail.add_mailaccount"]) | ||||
|         ): | ||||
|             return HttpResponseBadRequest( | ||||
|                 "You do not have permission to add mail accounts", | ||||
|             ) | ||||
|  | ||||
|         logger = logging.getLogger("paperless_mail") | ||||
|         code = request.query_params.get("code") | ||||
|         # Gmail passes scope as a query param, Outlook does not | ||||
|         scope = request.query_params.get("scope") | ||||
|  | ||||
|         if code is None: | ||||
|             logger.error( | ||||
|                 f"Invalid oauth callback request, code: {code}, scope: {scope}", | ||||
|             ) | ||||
|             return HttpResponseBadRequest("Invalid request, see logs for more detail") | ||||
|  | ||||
|         oauth_manager = PaperlessMailOAuth2Manager() | ||||
|  | ||||
|         try: | ||||
|             if scope is not None and "google" in scope: | ||||
|                 # Google | ||||
|                 account_type = MailAccount.MailAccountType.GMAIL_OAUTH | ||||
|                 imap_server = "imap.gmail.com" | ||||
|                 defaults = { | ||||
|                     "name": f"Gmail OAuth {timezone.now()}", | ||||
|                     "username": "", | ||||
|                     "imap_security": MailAccount.ImapSecurity.SSL, | ||||
|                     "imap_port": 993, | ||||
|                     "account_type": account_type, | ||||
|                 } | ||||
|                 result = oauth_manager.get_gmail_access_token(code) | ||||
|  | ||||
|             elif scope is None: | ||||
|                 # Outlook | ||||
|                 account_type = MailAccount.MailAccountType.OUTLOOK_OAUTH | ||||
|                 imap_server = "outlook.office365.com" | ||||
|                 defaults = { | ||||
|                     "name": f"Outlook OAuth {timezone.now()}", | ||||
|                     "username": "", | ||||
|                     "imap_security": MailAccount.ImapSecurity.SSL, | ||||
|                     "imap_port": 993, | ||||
|                     "account_type": account_type, | ||||
|                 } | ||||
|  | ||||
|                 result = oauth_manager.get_outlook_access_token(code) | ||||
|  | ||||
|             access_token = result["access_token"] | ||||
|             refresh_token = result["refresh_token"] | ||||
|             expires_in = result["expires_in"] | ||||
|             account, _ = MailAccount.objects.update_or_create( | ||||
|                 password=access_token, | ||||
|                 is_token=True, | ||||
|                 imap_server=imap_server, | ||||
|                 refresh_token=refresh_token, | ||||
|                 expiration=timezone.now() + timedelta(seconds=expires_in), | ||||
|                 defaults=defaults, | ||||
|             ) | ||||
|             return HttpResponseRedirect( | ||||
|                 f"{oauth_manager.oauth_redirect_url}?oauth_success=1&account_id={account.pk}", | ||||
|             ) | ||||
|         except GetAccessTokenError as e: | ||||
|             logger.error(f"Error getting access token: {e}") | ||||
|             return HttpResponseRedirect( | ||||
|                 f"{oauth_manager.oauth_redirect_url}?oauth_success=0", | ||||
|             ) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 shamoon
					shamoon