Feature: OIDC & social authentication (#5190)

---------

Co-authored-by: Moritz Pflanzer <moritz@chickadee-engineering.com>
Co-authored-by: shamoon <4887959+shamoon@users.noreply.github.com>
This commit is contained in:
Moritz Pflanzer
2024-02-08 17:15:38 +01:00
committed by GitHub
parent cd3d609a50
commit e122a0a141
33 changed files with 1197 additions and 190 deletions

30
src/paperless/adapter.py Normal file
View File

@@ -0,0 +1,30 @@
from allauth.account.adapter import DefaultAccountAdapter
from allauth.socialaccount.adapter import DefaultSocialAccountAdapter
from django.conf import settings
from django.urls import reverse
class CustomAccountAdapter(DefaultAccountAdapter):
def is_open_for_signup(self, request):
allow_signups = super().is_open_for_signup(request)
# Override with setting, otherwise default to super.
return getattr(settings, "ACCOUNT_ALLOW_SIGNUPS", allow_signups)
class CustomSocialAccountAdapter(DefaultSocialAccountAdapter):
def is_open_for_signup(self, request, sociallogin):
allow_signups = super().is_open_for_signup(request, sociallogin)
# Override with setting, otherwise default to super.
return getattr(settings, "SOCIALACCOUNT_ALLOW_SIGNUPS", allow_signups)
def get_connect_redirect_url(self, request, socialaccount):
"""
Returns the default URL to redirect to after successfully
connecting a social account.
"""
url = reverse("base")
return url
def populate_user(self, request, sociallogin, data):
# TODO: If default global permissions are implemented, should also be here
return super().populate_user(request, sociallogin, data) # pragma: no cover

View File

@@ -1,5 +1,6 @@
import logging
from allauth.socialaccount.models import SocialAccount
from django.contrib.auth.models import Group
from django.contrib.auth.models import Permission
from django.contrib.auth.models import User
@@ -105,10 +106,30 @@ class GroupSerializer(serializers.ModelSerializer):
)
class SocialAccountSerializer(serializers.ModelSerializer):
name = serializers.SerializerMethodField()
class Meta:
model = SocialAccount
fields = (
"id",
"provider",
"name",
)
def get_name(self, obj):
return obj.get_provider_account().to_str()
class ProfileSerializer(serializers.ModelSerializer):
email = serializers.EmailField(allow_null=False)
password = ObfuscatedUserPasswordField(required=False, allow_null=False)
auth_token = serializers.SlugRelatedField(read_only=True, slug_field="key")
social_accounts = SocialAccountSerializer(
many=True,
read_only=True,
source="socialaccount_set",
)
class Meta:
model = User
@@ -118,6 +139,8 @@ class ProfileSerializer(serializers.ModelSerializer):
"first_name",
"last_name",
"auth_token",
"social_accounts",
"has_usable_password",
)

View File

@@ -303,6 +303,9 @@ INSTALLED_APPS = [
"django_filters",
"django_celery_results",
"guardian",
"allauth",
"allauth.account",
"allauth.socialaccount",
*env_apps,
]
@@ -339,6 +342,7 @@ MIDDLEWARE = [
"django.contrib.auth.middleware.AuthenticationMiddleware",
"django.contrib.messages.middleware.MessageMiddleware",
"django.middleware.clickjacking.XFrameOptionsMiddleware",
"allauth.account.middleware.AccountMiddleware",
]
# Optional to enable compression
@@ -350,6 +354,7 @@ ROOT_URLCONF = "paperless.urls"
FORCE_SCRIPT_NAME = os.getenv("PAPERLESS_FORCE_SCRIPT_NAME")
BASE_URL = (FORCE_SCRIPT_NAME or "") + "/"
LOGIN_URL = BASE_URL + "accounts/login/"
LOGIN_REDIRECT_URL = "/dashboard"
LOGOUT_REDIRECT_URL = os.getenv("PAPERLESS_LOGOUT_REDIRECT_URL")
WSGI_APPLICATION = "paperless.wsgi.application"
@@ -410,8 +415,28 @@ CHANNEL_LAYERS = {
AUTHENTICATION_BACKENDS = [
"guardian.backends.ObjectPermissionBackend",
"django.contrib.auth.backends.ModelBackend",
"allauth.account.auth_backends.AuthenticationBackend",
]
ACCOUNT_LOGOUT_ON_GET = True
ACCOUNT_DEFAULT_HTTP_PROTOCOL = os.getenv(
"PAPERLESS_ACCOUNT_DEFAULT_HTTP_PROTOCOL",
"https",
)
ACCOUNT_ADAPTER = "paperless.adapter.CustomAccountAdapter"
ACCOUNT_ALLOW_SIGNUPS = __get_boolean("PAPERLESS_ACCOUNT_ALLOW_SIGNUPS")
SOCIALACCOUNT_ADAPTER = "paperless.adapter.CustomSocialAccountAdapter"
SOCIALACCOUNT_ALLOW_SIGNUPS = __get_boolean(
"PAPERLESS_SOCIALACCOUNT_ALLOW_SIGNUPS",
"yes",
)
SOCIALACCOUNT_AUTO_SIGNUP = __get_boolean("PAPERLESS_SOCIAL_AUTO_SIGNUP")
SOCIALACCOUNT_PROVIDERS = json.loads(
os.getenv("PAPERLESS_SOCIALACCOUNT_PROVIDERS", "{}"),
)
AUTO_LOGIN_USERNAME = os.getenv("PAPERLESS_AUTO_LOGIN_USERNAME")
if AUTO_LOGIN_USERNAME:

View File

@@ -0,0 +1,43 @@
from allauth.account.adapter import get_adapter
from allauth.socialaccount.adapter import get_adapter as get_social_adapter
from django.conf import settings
from django.test import TestCase
from django.urls import reverse
class TestCustomAccountAdapter(TestCase):
def test_is_open_for_signup(self):
adapter = get_adapter()
# Test when ACCOUNT_ALLOW_SIGNUPS is True
settings.ACCOUNT_ALLOW_SIGNUPS = True
self.assertTrue(adapter.is_open_for_signup(None))
# Test when ACCOUNT_ALLOW_SIGNUPS is False
settings.ACCOUNT_ALLOW_SIGNUPS = False
self.assertFalse(adapter.is_open_for_signup(None))
class TestCustomSocialAccountAdapter(TestCase):
def test_is_open_for_signup(self):
adapter = get_social_adapter()
# Test when SOCIALACCOUNT_ALLOW_SIGNUPS is True
settings.SOCIALACCOUNT_ALLOW_SIGNUPS = True
self.assertTrue(adapter.is_open_for_signup(None, None))
# Test when SOCIALACCOUNT_ALLOW_SIGNUPS is False
settings.SOCIALACCOUNT_ALLOW_SIGNUPS = False
self.assertFalse(adapter.is_open_for_signup(None, None))
def test_get_connect_redirect_url(self):
adapter = get_social_adapter()
request = None
socialaccount = None
# Test the default URL
expected_url = reverse("base")
self.assertEqual(
adapter.get_connect_redirect_url(request, socialaccount),
expected_url,
)

View File

@@ -41,10 +41,12 @@ from documents.views import WorkflowTriggerViewSet
from documents.views import WorkflowViewSet
from paperless.consumers import StatusConsumer
from paperless.views import ApplicationConfigurationViewSet
from paperless.views import DisconnectSocialAccountView
from paperless.views import FaviconView
from paperless.views import GenerateAuthTokenView
from paperless.views import GroupViewSet
from paperless.views import ProfileView
from paperless.views import SocialAccountProvidersView
from paperless.views import UserViewSet
from paperless_mail.views import MailAccountTestView
from paperless_mail.views import MailAccountViewSet
@@ -132,6 +134,14 @@ urlpatterns = [
name="bulk_edit_object_permissions",
),
path("profile/generate_auth_token/", GenerateAuthTokenView.as_view()),
path(
"profile/disconnect_social_account/",
DisconnectSocialAccountView.as_view(),
),
path(
"profile/social_account_providers/",
SocialAccountProvidersView.as_view(),
),
re_path(
"^profile/",
ProfileView.as_view(),
@@ -192,7 +202,7 @@ urlpatterns = [
),
# TODO: with localization, this is even worse! :/
# login, logout
path("accounts/", include("django.contrib.auth.urls")),
path("accounts/", include("allauth.urls")),
# Root of the Frontend
re_path(
r".*",

View File

@@ -1,10 +1,13 @@
import os
from collections import OrderedDict
from allauth.socialaccount.adapter import get_adapter
from allauth.socialaccount.models import SocialAccount
from django.contrib.auth.models import Group
from django.contrib.auth.models import User
from django.db.models.functions import Lower
from django.http import HttpResponse
from django.http import HttpResponseBadRequest
from django.views.generic import View
from django_filters.rest_framework import DjangoFilterBackend
from rest_framework.authtoken.models import Token
@@ -14,6 +17,7 @@ from rest_framework.pagination import PageNumberPagination
from rest_framework.permissions import DjangoObjectPermissions
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from rest_framework.views import APIView
from rest_framework.viewsets import ModelViewSet
from documents.permissions import PaperlessObjectPermissions
@@ -168,3 +172,54 @@ class ApplicationConfigurationViewSet(ModelViewSet):
serializer_class = ApplicationConfigurationSerializer
permission_classes = (IsAuthenticated, DjangoObjectPermissions)
class DisconnectSocialAccountView(GenericAPIView):
"""
Disconnects a social account provider from the user account
"""
permission_classes = [IsAuthenticated]
def post(self, request, *args, **kwargs):
user = self.request.user
try:
account = user.socialaccount_set.get(pk=request.data["id"])
account_id = account.id
account.delete()
return Response(account_id)
except SocialAccount.DoesNotExist:
return HttpResponseBadRequest("Social account not found")
class SocialAccountProvidersView(APIView):
"""
List of social account providers
"""
permission_classes = [IsAuthenticated]
def get(self, request, *args, **kwargs):
adapter = get_adapter()
providers = adapter.list_providers(request)
resp = [
{"name": p.name, "login_url": p.get_login_url(request, process="connect")}
for p in providers
if p.id != "openid"
]
for openid_provider in filter(lambda p: p.id == "openid", providers):
resp += [
{
"name": b["name"],
"login_url": openid_provider.get_login_url(
request,
process="connect",
openid=b["openid_url"],
),
}
for b in openid_provider.get_brands()
]
return Response(sorted(resp, key=lambda p: p["name"]))