Enhancement: support default groups for regular and social account signup (#9039)

This commit is contained in:
shamoon
2025-02-24 09:23:20 -08:00
committed by GitHub
parent a548c32c1f
commit 047f7c3619
7 changed files with 216 additions and 7 deletions

View File

@@ -1,12 +1,17 @@
import logging
from urllib.parse import quote
from allauth.account.adapter import DefaultAccountAdapter
from allauth.core import context
from allauth.socialaccount.adapter import DefaultSocialAccountAdapter
from django.conf import settings
from django.contrib.auth.models import Group
from django.contrib.auth.models import User
from django.forms import ValidationError
from django.urls import reverse
logger = logging.getLogger("paperless.auth")
class CustomAccountAdapter(DefaultAccountAdapter):
def is_open_for_signup(self, request):
@@ -61,6 +66,20 @@ class CustomAccountAdapter(DefaultAccountAdapter):
path = path.replace("UID-KEY", quote(key))
return settings.PAPERLESS_URL + path
def save_user(self, request, user, form, commit=True): # noqa: FBT002
"""
Save the user instance. Default groups are assigned to the user, if
specified in the settings.
"""
user: User = super().save_user(request, user, form, commit)
group_names: list[str] = settings.ACCOUNT_DEFAULT_GROUPS
if len(group_names) > 0:
groups = Group.objects.filter(name__in=group_names)
logger.debug(f"Adding default groups to user `{user}`: {group_names}")
user.groups.add(*groups)
user.save()
return user
class CustomSocialAccountAdapter(DefaultSocialAccountAdapter):
def is_open_for_signup(self, request, sociallogin):
@@ -80,10 +99,19 @@ class CustomSocialAccountAdapter(DefaultSocialAccountAdapter):
url = reverse("base")
return url
def populate_user(self, request, sociallogin, data):
def save_user(self, request, sociallogin, form=None):
"""
Populate the user with data from the social account. Stub is kept in case
global default permissions are implemented in the future.
Save the user instance. Default groups are assigned to the user, if
specified in the settings.
"""
# TODO: If default global permissions are implemented, should also be here
return super().populate_user(request, sociallogin, data) # pragma: no cover
# save_user also calls account_adapter save_user which would set ACCOUNT_DEFAULT_GROUPS
user: User = super().save_user(request, sociallogin, form)
group_names: list[str] = settings.SOCIAL_ACCOUNT_DEFAULT_GROUPS
if len(group_names) > 0:
groups = Group.objects.filter(name__in=group_names)
logger.debug(
f"Adding default social groups to user `{user}`: {group_names}",
)
user.groups.add(*groups)
user.save()
return user

View File

@@ -2,6 +2,7 @@ from django.apps import AppConfig
from django.utils.translation import gettext_lazy as _
from paperless.signals import handle_failed_login
from paperless.signals import handle_social_account_updated
class PaperlessConfig(AppConfig):
@@ -13,4 +14,9 @@ class PaperlessConfig(AppConfig):
from django.contrib.auth.signals import user_login_failed
user_login_failed.connect(handle_failed_login)
from allauth.socialaccount.signals import social_account_updated
social_account_updated.connect(handle_social_account_updated)
AppConfig.ready(self)

View File

@@ -480,6 +480,7 @@ ACCOUNT_DEFAULT_HTTP_PROTOCOL = os.getenv(
ACCOUNT_ADAPTER = "paperless.adapter.CustomAccountAdapter"
ACCOUNT_ALLOW_SIGNUPS = __get_boolean("PAPERLESS_ACCOUNT_ALLOW_SIGNUPS")
ACCOUNT_DEFAULT_GROUPS = __get_list("PAPERLESS_ACCOUNT_DEFAULT_GROUPS")
SOCIALACCOUNT_ADAPTER = "paperless.adapter.CustomSocialAccountAdapter"
SOCIALACCOUNT_ALLOW_SIGNUPS = __get_boolean(
@@ -490,6 +491,8 @@ SOCIALACCOUNT_AUTO_SIGNUP = __get_boolean("PAPERLESS_SOCIAL_AUTO_SIGNUP")
SOCIALACCOUNT_PROVIDERS = json.loads(
os.getenv("PAPERLESS_SOCIALACCOUNT_PROVIDERS", "{}"),
)
SOCIAL_ACCOUNT_DEFAULT_GROUPS = __get_list("PAPERLESS_SOCIAL_ACCOUNT_DEFAULT_GROUPS")
SOCIAL_ACCOUNT_SYNC_GROUPS = __get_boolean("PAPERLESS_SOCIAL_ACCOUNT_SYNC_GROUPS")
MFA_TOTP_ISSUER = "Paperless-ngx"

View File

@@ -30,3 +30,21 @@ def handle_failed_login(sender, credentials, request, **kwargs):
log_output += f" from private IP `{client_ip}`."
logger.info(log_output)
def handle_social_account_updated(sender, request, sociallogin, **kwargs):
"""
Handle the social account update signal.
"""
from django.contrib.auth.models import Group
social_account_groups = sociallogin.account.extra_data.get(
"groups",
[],
) # None if not found
if settings.SOCIAL_ACCOUNT_SYNC_GROUPS and social_account_groups is not None:
groups = Group.objects.filter(name__in=social_account_groups)
logger.debug(
f"Syncing groups for user `{sociallogin.user}`: {social_account_groups}",
)
sociallogin.user.groups.set(groups, clear=True)

View File

@@ -4,6 +4,8 @@ from allauth.account.adapter import get_adapter
from allauth.core import context
from allauth.socialaccount.adapter import get_adapter as get_social_adapter
from django.conf import settings
from django.contrib.auth.models import Group
from django.contrib.auth.models import User
from django.forms import ValidationError
from django.http import HttpRequest
from django.test import TestCase
@@ -81,6 +83,24 @@ class TestCustomAccountAdapter(TestCase):
expected_url,
)
@override_settings(ACCOUNT_DEFAULT_GROUPS=["group1", "group2"])
def test_save_user_adds_groups(self):
Group.objects.create(name="group1")
user = User.objects.create_user("testuser")
adapter = get_adapter()
form = mock.Mock(
cleaned_data={
"username": "testuser",
"email": "user@example.com",
},
)
user = adapter.save_user(HttpRequest(), user, form, commit=True)
self.assertEqual(user.groups.count(), 1)
self.assertTrue(user.groups.filter(name="group1").exists())
self.assertFalse(user.groups.filter(name="group2").exists())
class TestCustomSocialAccountAdapter(TestCase):
def test_is_open_for_signup(self):
@@ -105,3 +125,19 @@ class TestCustomSocialAccountAdapter(TestCase):
adapter.get_connect_redirect_url(request, socialaccount),
expected_url,
)
@override_settings(SOCIAL_ACCOUNT_DEFAULT_GROUPS=["group1", "group2"])
def test_save_user_adds_groups(self):
Group.objects.create(name="group1")
adapter = get_social_adapter()
request = HttpRequest()
user = User.objects.create_user("testuser")
sociallogin = mock.Mock(
user=user,
)
user = adapter.save_user(request, sociallogin, None)
self.assertEqual(user.groups.count(), 1)
self.assertTrue(user.groups.filter(name="group1").exists())
self.assertFalse(user.groups.filter(name="group2").exists())

View File

@@ -1,7 +1,13 @@
from unittest.mock import Mock
from django.contrib.auth.models import Group
from django.contrib.auth.models import User
from django.http import HttpRequest
from django.test import TestCase
from django.test import override_settings
from paperless.signals import handle_failed_login
from paperless.signals import handle_social_account_updated
class TestFailedLoginLogging(TestCase):
@@ -99,3 +105,88 @@ class TestFailedLoginLogging(TestCase):
"INFO:paperless.auth:Login failed for user `john lennon` from private IP `10.0.0.1`.",
],
)
class TestSyncSocialLoginGroups(TestCase):
@override_settings(SOCIAL_ACCOUNT_SYNC_GROUPS=True)
def test_sync_enabled(self):
"""
GIVEN:
- Enabled group syncing, a user, and a social login
WHEN:
- The social login is updated via signal after login
THEN:
- The user's groups are updated to match the social login's groups
"""
group = Group.objects.create(name="group1")
user = User.objects.create_user(username="testuser")
sociallogin = Mock(
user=user,
account=Mock(
extra_data={
"groups": ["group1"],
},
),
)
handle_social_account_updated(
sender=None,
request=HttpRequest(),
sociallogin=sociallogin,
)
self.assertEqual(list(user.groups.all()), [group])
@override_settings(SOCIAL_ACCOUNT_SYNC_GROUPS=False)
def test_sync_disabled(self):
"""
GIVEN:
- Disabled group syncing, a user, and a social login
WHEN:
- The social login is updated via signal after login
THEN:
- The user's groups are not updated
"""
Group.objects.create(name="group1")
user = User.objects.create_user(username="testuser")
sociallogin = Mock(
user=user,
account=Mock(
extra_data={
"groups": ["group1"],
},
),
)
handle_social_account_updated(
sender=None,
request=HttpRequest(),
sociallogin=sociallogin,
)
self.assertEqual(list(user.groups.all()), [])
@override_settings(SOCIAL_ACCOUNT_SYNC_GROUPS=True)
def test_no_groups(self):
"""
GIVEN:
- Enabled group syncing, a user, and a social login with no groups
WHEN:
- The social login is updated via signal after login
THEN:
- The user's groups are cleared to match the social login's groups
"""
group = Group.objects.create(name="group1")
user = User.objects.create_user(username="testuser")
user.groups.add(group)
user.save()
sociallogin = Mock(
user=user,
account=Mock(
extra_data={
"groups": [],
},
),
)
handle_social_account_updated(
sender=None,
request=HttpRequest(),
sociallogin=sociallogin,
)
self.assertEqual(list(user.groups.all()), [])