mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-07-28 18:24:38 -05:00
Feature: OIDC & social authentication (#5190)
--------- Co-authored-by: Moritz Pflanzer <moritz@chickadee-engineering.com> Co-authored-by: shamoon <4887959+shamoon@users.noreply.github.com>
This commit is contained in:
30
src/paperless/adapter.py
Normal file
30
src/paperless/adapter.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from allauth.account.adapter import DefaultAccountAdapter
|
||||
from allauth.socialaccount.adapter import DefaultSocialAccountAdapter
|
||||
from django.conf import settings
|
||||
from django.urls import reverse
|
||||
|
||||
|
||||
class CustomAccountAdapter(DefaultAccountAdapter):
|
||||
def is_open_for_signup(self, request):
|
||||
allow_signups = super().is_open_for_signup(request)
|
||||
# Override with setting, otherwise default to super.
|
||||
return getattr(settings, "ACCOUNT_ALLOW_SIGNUPS", allow_signups)
|
||||
|
||||
|
||||
class CustomSocialAccountAdapter(DefaultSocialAccountAdapter):
|
||||
def is_open_for_signup(self, request, sociallogin):
|
||||
allow_signups = super().is_open_for_signup(request, sociallogin)
|
||||
# Override with setting, otherwise default to super.
|
||||
return getattr(settings, "SOCIALACCOUNT_ALLOW_SIGNUPS", allow_signups)
|
||||
|
||||
def get_connect_redirect_url(self, request, socialaccount):
|
||||
"""
|
||||
Returns the default URL to redirect to after successfully
|
||||
connecting a social account.
|
||||
"""
|
||||
url = reverse("base")
|
||||
return url
|
||||
|
||||
def populate_user(self, request, sociallogin, data):
|
||||
# TODO: If default global permissions are implemented, should also be here
|
||||
return super().populate_user(request, sociallogin, data) # pragma: no cover
|
@@ -1,5 +1,6 @@
|
||||
import logging
|
||||
|
||||
from allauth.socialaccount.models import SocialAccount
|
||||
from django.contrib.auth.models import Group
|
||||
from django.contrib.auth.models import Permission
|
||||
from django.contrib.auth.models import User
|
||||
@@ -105,10 +106,30 @@ class GroupSerializer(serializers.ModelSerializer):
|
||||
)
|
||||
|
||||
|
||||
class SocialAccountSerializer(serializers.ModelSerializer):
|
||||
name = serializers.SerializerMethodField()
|
||||
|
||||
class Meta:
|
||||
model = SocialAccount
|
||||
fields = (
|
||||
"id",
|
||||
"provider",
|
||||
"name",
|
||||
)
|
||||
|
||||
def get_name(self, obj):
|
||||
return obj.get_provider_account().to_str()
|
||||
|
||||
|
||||
class ProfileSerializer(serializers.ModelSerializer):
|
||||
email = serializers.EmailField(allow_null=False)
|
||||
password = ObfuscatedUserPasswordField(required=False, allow_null=False)
|
||||
auth_token = serializers.SlugRelatedField(read_only=True, slug_field="key")
|
||||
social_accounts = SocialAccountSerializer(
|
||||
many=True,
|
||||
read_only=True,
|
||||
source="socialaccount_set",
|
||||
)
|
||||
|
||||
class Meta:
|
||||
model = User
|
||||
@@ -118,6 +139,8 @@ class ProfileSerializer(serializers.ModelSerializer):
|
||||
"first_name",
|
||||
"last_name",
|
||||
"auth_token",
|
||||
"social_accounts",
|
||||
"has_usable_password",
|
||||
)
|
||||
|
||||
|
||||
|
@@ -303,6 +303,9 @@ INSTALLED_APPS = [
|
||||
"django_filters",
|
||||
"django_celery_results",
|
||||
"guardian",
|
||||
"allauth",
|
||||
"allauth.account",
|
||||
"allauth.socialaccount",
|
||||
*env_apps,
|
||||
]
|
||||
|
||||
@@ -339,6 +342,7 @@ MIDDLEWARE = [
|
||||
"django.contrib.auth.middleware.AuthenticationMiddleware",
|
||||
"django.contrib.messages.middleware.MessageMiddleware",
|
||||
"django.middleware.clickjacking.XFrameOptionsMiddleware",
|
||||
"allauth.account.middleware.AccountMiddleware",
|
||||
]
|
||||
|
||||
# Optional to enable compression
|
||||
@@ -350,6 +354,7 @@ ROOT_URLCONF = "paperless.urls"
|
||||
FORCE_SCRIPT_NAME = os.getenv("PAPERLESS_FORCE_SCRIPT_NAME")
|
||||
BASE_URL = (FORCE_SCRIPT_NAME or "") + "/"
|
||||
LOGIN_URL = BASE_URL + "accounts/login/"
|
||||
LOGIN_REDIRECT_URL = "/dashboard"
|
||||
LOGOUT_REDIRECT_URL = os.getenv("PAPERLESS_LOGOUT_REDIRECT_URL")
|
||||
|
||||
WSGI_APPLICATION = "paperless.wsgi.application"
|
||||
@@ -410,8 +415,28 @@ CHANNEL_LAYERS = {
|
||||
AUTHENTICATION_BACKENDS = [
|
||||
"guardian.backends.ObjectPermissionBackend",
|
||||
"django.contrib.auth.backends.ModelBackend",
|
||||
"allauth.account.auth_backends.AuthenticationBackend",
|
||||
]
|
||||
|
||||
ACCOUNT_LOGOUT_ON_GET = True
|
||||
ACCOUNT_DEFAULT_HTTP_PROTOCOL = os.getenv(
|
||||
"PAPERLESS_ACCOUNT_DEFAULT_HTTP_PROTOCOL",
|
||||
"https",
|
||||
)
|
||||
|
||||
ACCOUNT_ADAPTER = "paperless.adapter.CustomAccountAdapter"
|
||||
ACCOUNT_ALLOW_SIGNUPS = __get_boolean("PAPERLESS_ACCOUNT_ALLOW_SIGNUPS")
|
||||
|
||||
SOCIALACCOUNT_ADAPTER = "paperless.adapter.CustomSocialAccountAdapter"
|
||||
SOCIALACCOUNT_ALLOW_SIGNUPS = __get_boolean(
|
||||
"PAPERLESS_SOCIALACCOUNT_ALLOW_SIGNUPS",
|
||||
"yes",
|
||||
)
|
||||
SOCIALACCOUNT_AUTO_SIGNUP = __get_boolean("PAPERLESS_SOCIAL_AUTO_SIGNUP")
|
||||
SOCIALACCOUNT_PROVIDERS = json.loads(
|
||||
os.getenv("PAPERLESS_SOCIALACCOUNT_PROVIDERS", "{}"),
|
||||
)
|
||||
|
||||
AUTO_LOGIN_USERNAME = os.getenv("PAPERLESS_AUTO_LOGIN_USERNAME")
|
||||
|
||||
if AUTO_LOGIN_USERNAME:
|
||||
|
43
src/paperless/tests/test_adapter.py
Normal file
43
src/paperless/tests/test_adapter.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from allauth.account.adapter import get_adapter
|
||||
from allauth.socialaccount.adapter import get_adapter as get_social_adapter
|
||||
from django.conf import settings
|
||||
from django.test import TestCase
|
||||
from django.urls import reverse
|
||||
|
||||
|
||||
class TestCustomAccountAdapter(TestCase):
|
||||
def test_is_open_for_signup(self):
|
||||
adapter = get_adapter()
|
||||
|
||||
# Test when ACCOUNT_ALLOW_SIGNUPS is True
|
||||
settings.ACCOUNT_ALLOW_SIGNUPS = True
|
||||
self.assertTrue(adapter.is_open_for_signup(None))
|
||||
|
||||
# Test when ACCOUNT_ALLOW_SIGNUPS is False
|
||||
settings.ACCOUNT_ALLOW_SIGNUPS = False
|
||||
self.assertFalse(adapter.is_open_for_signup(None))
|
||||
|
||||
|
||||
class TestCustomSocialAccountAdapter(TestCase):
|
||||
def test_is_open_for_signup(self):
|
||||
adapter = get_social_adapter()
|
||||
|
||||
# Test when SOCIALACCOUNT_ALLOW_SIGNUPS is True
|
||||
settings.SOCIALACCOUNT_ALLOW_SIGNUPS = True
|
||||
self.assertTrue(adapter.is_open_for_signup(None, None))
|
||||
|
||||
# Test when SOCIALACCOUNT_ALLOW_SIGNUPS is False
|
||||
settings.SOCIALACCOUNT_ALLOW_SIGNUPS = False
|
||||
self.assertFalse(adapter.is_open_for_signup(None, None))
|
||||
|
||||
def test_get_connect_redirect_url(self):
|
||||
adapter = get_social_adapter()
|
||||
request = None
|
||||
socialaccount = None
|
||||
|
||||
# Test the default URL
|
||||
expected_url = reverse("base")
|
||||
self.assertEqual(
|
||||
adapter.get_connect_redirect_url(request, socialaccount),
|
||||
expected_url,
|
||||
)
|
@@ -41,10 +41,12 @@ from documents.views import WorkflowTriggerViewSet
|
||||
from documents.views import WorkflowViewSet
|
||||
from paperless.consumers import StatusConsumer
|
||||
from paperless.views import ApplicationConfigurationViewSet
|
||||
from paperless.views import DisconnectSocialAccountView
|
||||
from paperless.views import FaviconView
|
||||
from paperless.views import GenerateAuthTokenView
|
||||
from paperless.views import GroupViewSet
|
||||
from paperless.views import ProfileView
|
||||
from paperless.views import SocialAccountProvidersView
|
||||
from paperless.views import UserViewSet
|
||||
from paperless_mail.views import MailAccountTestView
|
||||
from paperless_mail.views import MailAccountViewSet
|
||||
@@ -132,6 +134,14 @@ urlpatterns = [
|
||||
name="bulk_edit_object_permissions",
|
||||
),
|
||||
path("profile/generate_auth_token/", GenerateAuthTokenView.as_view()),
|
||||
path(
|
||||
"profile/disconnect_social_account/",
|
||||
DisconnectSocialAccountView.as_view(),
|
||||
),
|
||||
path(
|
||||
"profile/social_account_providers/",
|
||||
SocialAccountProvidersView.as_view(),
|
||||
),
|
||||
re_path(
|
||||
"^profile/",
|
||||
ProfileView.as_view(),
|
||||
@@ -192,7 +202,7 @@ urlpatterns = [
|
||||
),
|
||||
# TODO: with localization, this is even worse! :/
|
||||
# login, logout
|
||||
path("accounts/", include("django.contrib.auth.urls")),
|
||||
path("accounts/", include("allauth.urls")),
|
||||
# Root of the Frontend
|
||||
re_path(
|
||||
r".*",
|
||||
|
@@ -1,10 +1,13 @@
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
|
||||
from allauth.socialaccount.adapter import get_adapter
|
||||
from allauth.socialaccount.models import SocialAccount
|
||||
from django.contrib.auth.models import Group
|
||||
from django.contrib.auth.models import User
|
||||
from django.db.models.functions import Lower
|
||||
from django.http import HttpResponse
|
||||
from django.http import HttpResponseBadRequest
|
||||
from django.views.generic import View
|
||||
from django_filters.rest_framework import DjangoFilterBackend
|
||||
from rest_framework.authtoken.models import Token
|
||||
@@ -14,6 +17,7 @@ from rest_framework.pagination import PageNumberPagination
|
||||
from rest_framework.permissions import DjangoObjectPermissions
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.views import APIView
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
|
||||
from documents.permissions import PaperlessObjectPermissions
|
||||
@@ -168,3 +172,54 @@ class ApplicationConfigurationViewSet(ModelViewSet):
|
||||
|
||||
serializer_class = ApplicationConfigurationSerializer
|
||||
permission_classes = (IsAuthenticated, DjangoObjectPermissions)
|
||||
|
||||
|
||||
class DisconnectSocialAccountView(GenericAPIView):
|
||||
"""
|
||||
Disconnects a social account provider from the user account
|
||||
"""
|
||||
|
||||
permission_classes = [IsAuthenticated]
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
user = self.request.user
|
||||
|
||||
try:
|
||||
account = user.socialaccount_set.get(pk=request.data["id"])
|
||||
account_id = account.id
|
||||
account.delete()
|
||||
return Response(account_id)
|
||||
except SocialAccount.DoesNotExist:
|
||||
return HttpResponseBadRequest("Social account not found")
|
||||
|
||||
|
||||
class SocialAccountProvidersView(APIView):
|
||||
"""
|
||||
List of social account providers
|
||||
"""
|
||||
|
||||
permission_classes = [IsAuthenticated]
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
adapter = get_adapter()
|
||||
providers = adapter.list_providers(request)
|
||||
resp = [
|
||||
{"name": p.name, "login_url": p.get_login_url(request, process="connect")}
|
||||
for p in providers
|
||||
if p.id != "openid"
|
||||
]
|
||||
|
||||
for openid_provider in filter(lambda p: p.id == "openid", providers):
|
||||
resp += [
|
||||
{
|
||||
"name": b["name"],
|
||||
"login_url": openid_provider.get_login_url(
|
||||
request,
|
||||
process="connect",
|
||||
openid=b["openid_url"],
|
||||
),
|
||||
}
|
||||
for b in openid_provider.get_brands()
|
||||
]
|
||||
|
||||
return Response(sorted(resp, key=lambda p: p["name"]))
|
||||
|
Reference in New Issue
Block a user