From 2dbca9e9cefb8d4d2083496e77a0429f7c221196 Mon Sep 17 00:00:00 2001 From: shamoon <4887959+shamoon@users.noreply.github.com> Date: Fri, 7 Feb 2025 15:12:09 -0800 Subject: [PATCH] Support syncing groups on login --- docs/configuration.md | 12 ++++ src/paperless/apps.py | 6 ++ src/paperless/settings.py | 1 + src/paperless/signals.py | 18 ++++++ src/paperless/tests/test_signals.py | 91 +++++++++++++++++++++++++++++ 5 files changed, 128 insertions(+) diff --git a/docs/configuration.md b/docs/configuration.md index 86a22640d..3ba0825c1 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -586,6 +586,18 @@ system. See the corresponding Defaults to False +#### [`PAPERLESS_SOCIAL_ACCOUNT_SYNC_GROUPS=`](#PAPERLESS_SOCIAL_ACCOUNT_SYNC_GROUPS) {#PAPERLESS_SOCIAL_ACCOUNT_SYNC_GROUPS} + +: Sync groups from the third party authentication system (e.g. OIDC) to Paperless-ngx. When enabled, users will be added or removed from groups based on their group membership in the third party authentication system. Groups must already exist in Paperless-ngx and have the same name as in the third party authentication system. Groups are updated upon logging in via the third party authentication system, see the corresponding [django-allauth documentation](https://docs.allauth.org/en/dev/socialaccount/signals.html). + +In order to pass groups from the authentication system you will need to update your [PAPERLESS_SOCIALACCOUNT_PROVIDERS](#PAPERLESS_SOCIALACCOUNT_PROVIDERS) setting by adding a top-level "SCOPES" setting which includes "groups", e.g.: + +```json +{"openid_connect":{"SCOPE": ["openid","profile","email","groups"]... +``` + + Defaults to False + #### [`PAPERLESS_ACCOUNT_DEFAULT_GROUPS=`](#PAPERLESS_ACCOUNT_DEFAULT_GROUPS) {#PAPERLESS_ACCOUNT_DEFAULT_GROUPS} : A list of group names that users will be added to when they sign up for a new account. Groups listed here must already exist. diff --git a/src/paperless/apps.py b/src/paperless/apps.py index b4147a2e3..819d8d5ff 100644 --- a/src/paperless/apps.py +++ b/src/paperless/apps.py @@ -2,6 +2,7 @@ from django.apps import AppConfig from django.utils.translation import gettext_lazy as _ from paperless.signals import handle_failed_login +from paperless.signals import handle_social_account_updated class PaperlessConfig(AppConfig): @@ -13,4 +14,9 @@ class PaperlessConfig(AppConfig): from django.contrib.auth.signals import user_login_failed user_login_failed.connect(handle_failed_login) + + from allauth.socialaccount.signals import social_account_updated + + social_account_updated.connect(handle_social_account_updated) + AppConfig.ready(self) diff --git a/src/paperless/settings.py b/src/paperless/settings.py index 401ba04b9..0c8c71ab9 100644 --- a/src/paperless/settings.py +++ b/src/paperless/settings.py @@ -492,6 +492,7 @@ SOCIALACCOUNT_PROVIDERS = json.loads( os.getenv("PAPERLESS_SOCIALACCOUNT_PROVIDERS", "{}"), ) SOCIAL_ACCOUNT_DEFAULT_GROUPS = __get_list("PAPERLESS_SOCIAL_ACCOUNT_DEFAULT_GROUPS") +SOCIAL_ACCOUNT_SYNC_GROUPS = __get_boolean("PAPERLESS_SOCIAL_ACCOUNT_SYNC_GROUPS") MFA_TOTP_ISSUER = "Paperless-ngx" diff --git a/src/paperless/signals.py b/src/paperless/signals.py index fa0298685..a173ccc2e 100644 --- a/src/paperless/signals.py +++ b/src/paperless/signals.py @@ -30,3 +30,21 @@ def handle_failed_login(sender, credentials, request, **kwargs): log_output += f" from private IP `{client_ip}`." logger.info(log_output) + + +def handle_social_account_updated(sender, request, sociallogin, **kwargs): + """ + Handle the social account update signal. + """ + from django.contrib.auth.models import Group + + social_account_groups = sociallogin.account.extra_data.get( + "groups", + [], + ) # None if not found + if settings.SOCIAL_ACCOUNT_SYNC_GROUPS and social_account_groups is not None: + groups = Group.objects.filter(name__in=social_account_groups) + logger.debug( + f"Syncing groups for user `{sociallogin.user}`: {social_account_groups}", + ) + sociallogin.user.groups.set(groups, clear=True) diff --git a/src/paperless/tests/test_signals.py b/src/paperless/tests/test_signals.py index dc425d667..0948ca575 100644 --- a/src/paperless/tests/test_signals.py +++ b/src/paperless/tests/test_signals.py @@ -1,7 +1,13 @@ +from unittest.mock import Mock + +from django.contrib.auth.models import Group +from django.contrib.auth.models import User from django.http import HttpRequest from django.test import TestCase +from django.test import override_settings from paperless.signals import handle_failed_login +from paperless.signals import handle_social_account_updated class TestFailedLoginLogging(TestCase): @@ -99,3 +105,88 @@ class TestFailedLoginLogging(TestCase): "INFO:paperless.auth:Login failed for user `john lennon` from private IP `10.0.0.1`.", ], ) + + +class TestSyncSocialLoginGroups(TestCase): + @override_settings(SOCIAL_ACCOUNT_SYNC_GROUPS=True) + def test_sync_enabled(self): + """ + GIVEN: + - Enabled group syncing, a user, and a social login + WHEN: + - The social login is updated via signal after login + THEN: + - The user's groups are updated to match the social login's groups + """ + group = Group.objects.create(name="group1") + user = User.objects.create_user(username="testuser") + sociallogin = Mock( + user=user, + account=Mock( + extra_data={ + "groups": ["group1"], + }, + ), + ) + handle_social_account_updated( + sender=None, + request=HttpRequest(), + sociallogin=sociallogin, + ) + self.assertEqual(list(user.groups.all()), [group]) + + @override_settings(SOCIAL_ACCOUNT_SYNC_GROUPS=False) + def test_sync_disabled(self): + """ + GIVEN: + - Disabled group syncing, a user, and a social login + WHEN: + - The social login is updated via signal after login + THEN: + - The user's groups are not updated + """ + Group.objects.create(name="group1") + user = User.objects.create_user(username="testuser") + sociallogin = Mock( + user=user, + account=Mock( + extra_data={ + "groups": ["group1"], + }, + ), + ) + handle_social_account_updated( + sender=None, + request=HttpRequest(), + sociallogin=sociallogin, + ) + self.assertEqual(list(user.groups.all()), []) + + @override_settings(SOCIAL_ACCOUNT_SYNC_GROUPS=True) + def test_no_groups(self): + """ + GIVEN: + - Enabled group syncing, a user, and a social login with no groups + WHEN: + - The social login is updated via signal after login + THEN: + - The user's groups are cleared to match the social login's groups + """ + group = Group.objects.create(name="group1") + user = User.objects.create_user(username="testuser") + user.groups.add(group) + user.save() + sociallogin = Mock( + user=user, + account=Mock( + extra_data={ + "groups": [], + }, + ), + ) + handle_social_account_updated( + sender=None, + request=HttpRequest(), + sociallogin=sociallogin, + ) + self.assertEqual(list(user.groups.all()), [])