diff --git a/docs/configuration.md b/docs/configuration.md index 441d46105..86a22640d 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -586,6 +586,19 @@ system. See the corresponding Defaults to False +#### [`PAPERLESS_ACCOUNT_DEFAULT_GROUPS=`](#PAPERLESS_ACCOUNT_DEFAULT_GROUPS) {#PAPERLESS_ACCOUNT_DEFAULT_GROUPS} + +: A list of group names that users will be added to when they sign up for a new account. Groups listed here must already exist. + + Defaults to None + +#### [`PAPERLESS_SOCIAL_ACCOUNT_DEFAULT_GROUPS=`](#PAPERLESS_SOCIAL_ACCOUNT_DEFAULT_GROUPS) {#PAPERLESS_SOCIAL_ACCOUNT_DEFAULT_GROUPS} + +: A list of group names that users who signup via social accounts will be added to upon signup. Groups lsited here must already exist. +If both the [PAPERLESS_ACCOUNT_DEFAULT_GROUPS](#PAPERLESS_ACCOUNT_DEFAULT_GROUPS) setting and this setting are used, the user will be added to both sets of groups. + + Defaults to None + #### [`PAPERLESS_ACCOUNT_DEFAULT_HTTP_PROTOCOL=`](#PAPERLESS_ACCOUNT_DEFAULT_HTTP_PROTOCOL) {#PAPERLESS_ACCOUNT_DEFAULT_HTTP_PROTOCOL} : The protocol used when generating URLs, e.g. login callback URLs. See the corresponding diff --git a/src/paperless/adapter.py b/src/paperless/adapter.py index add2bf45d..eda51a2d4 100644 --- a/src/paperless/adapter.py +++ b/src/paperless/adapter.py @@ -4,6 +4,8 @@ from allauth.account.adapter import DefaultAccountAdapter from allauth.core import context from allauth.socialaccount.adapter import DefaultSocialAccountAdapter from django.conf import settings +from django.contrib.auth.models import Group +from django.contrib.auth.models import User from django.forms import ValidationError from django.urls import reverse @@ -61,6 +63,19 @@ class CustomAccountAdapter(DefaultAccountAdapter): path = path.replace("UID-KEY", quote(key)) return settings.PAPERLESS_URL + path + def save_user(self, request, user, form, commit): + """ + Save the user instance. Default groups are assigned to the user, if + specified in the settings. + """ + user: User = super().save_user(request, user, form, commit) + group_names: list[str] = settings.ACCOUNT_DEFAULT_GROUPS + if len(group_names) > 0: + groups = Group.objects.filter(name__in=group_names) + user.groups.add(*groups) + user.save() + return user + class CustomSocialAccountAdapter(DefaultSocialAccountAdapter): def is_open_for_signup(self, request, sociallogin): @@ -80,10 +95,16 @@ class CustomSocialAccountAdapter(DefaultSocialAccountAdapter): url = reverse("base") return url - def populate_user(self, request, sociallogin, data): + def save_user(self, request, sociallogin, form=None): """ - Populate the user with data from the social account. Stub is kept in case - global default permissions are implemented in the future. + Save the user instance. Default groups are assigned to the user, if + specified in the settings. """ - # TODO: If default global permissions are implemented, should also be here - return super().populate_user(request, sociallogin, data) # pragma: no cover + # save_user also calls account_adapter save_user which would set ACCOUNT_DEFAULT_GROUPS + user: User = super().save_user(request, sociallogin, form) + group_names: list[str] = settings.SOCIAL_ACCOUNT_DEFAULT_GROUPS + if len(group_names) > 0: + groups = Group.objects.filter(name__in=group_names) + user.groups.add(*groups) + user.save() + return user diff --git a/src/paperless/settings.py b/src/paperless/settings.py index 8072f694e..401ba04b9 100644 --- a/src/paperless/settings.py +++ b/src/paperless/settings.py @@ -480,6 +480,7 @@ ACCOUNT_DEFAULT_HTTP_PROTOCOL = os.getenv( ACCOUNT_ADAPTER = "paperless.adapter.CustomAccountAdapter" ACCOUNT_ALLOW_SIGNUPS = __get_boolean("PAPERLESS_ACCOUNT_ALLOW_SIGNUPS") +ACCOUNT_DEFAULT_GROUPS = __get_list("PAPERLESS_ACCOUNT_DEFAULT_GROUPS") SOCIALACCOUNT_ADAPTER = "paperless.adapter.CustomSocialAccountAdapter" SOCIALACCOUNT_ALLOW_SIGNUPS = __get_boolean( @@ -490,6 +491,7 @@ SOCIALACCOUNT_AUTO_SIGNUP = __get_boolean("PAPERLESS_SOCIAL_AUTO_SIGNUP") SOCIALACCOUNT_PROVIDERS = json.loads( os.getenv("PAPERLESS_SOCIALACCOUNT_PROVIDERS", "{}"), ) +SOCIAL_ACCOUNT_DEFAULT_GROUPS = __get_list("PAPERLESS_SOCIAL_ACCOUNT_DEFAULT_GROUPS") MFA_TOTP_ISSUER = "Paperless-ngx" diff --git a/src/paperless/tests/test_adapter.py b/src/paperless/tests/test_adapter.py index 5659a279a..be4ad3d90 100644 --- a/src/paperless/tests/test_adapter.py +++ b/src/paperless/tests/test_adapter.py @@ -4,6 +4,8 @@ from allauth.account.adapter import get_adapter from allauth.core import context from allauth.socialaccount.adapter import get_adapter as get_social_adapter from django.conf import settings +from django.contrib.auth.models import Group +from django.contrib.auth.models import User from django.forms import ValidationError from django.http import HttpRequest from django.test import TestCase @@ -81,6 +83,24 @@ class TestCustomAccountAdapter(TestCase): expected_url, ) + @override_settings(ACCOUNT_DEFAULT_GROUPS=["group1", "group2"]) + def test_save_user_adds_groups(self): + Group.objects.create(name="group1") + user = User.objects.create_user("testuser") + adapter = get_adapter() + form = mock.Mock( + cleaned_data={ + "username": "testuser", + "email": "user@example.com", + }, + ) + + user = adapter.save_user(HttpRequest(), user, form, commit=True) + + self.assertEqual(user.groups.count(), 1) + self.assertTrue(user.groups.filter(name="group1").exists()) + self.assertFalse(user.groups.filter(name="group2").exists()) + class TestCustomSocialAccountAdapter(TestCase): def test_is_open_for_signup(self): @@ -105,3 +125,19 @@ class TestCustomSocialAccountAdapter(TestCase): adapter.get_connect_redirect_url(request, socialaccount), expected_url, ) + + @override_settings(SOCIAL_ACCOUNT_DEFAULT_GROUPS=["group1", "group2"]) + def test_save_user_adds_groups(self): + Group.objects.create(name="group1") + adapter = get_social_adapter() + request = HttpRequest() + user = User.objects.create_user("testuser") + sociallogin = mock.Mock( + user=user, + ) + + user = adapter.save_user(request, sociallogin, None) + + self.assertEqual(user.groups.count(), 1) + self.assertTrue(user.groups.filter(name="group1").exists()) + self.assertFalse(user.groups.filter(name="group2").exists())