Feature: OIDC & social authentication (#5190)

---------

Co-authored-by: Moritz Pflanzer <moritz@chickadee-engineering.com>
Co-authored-by: shamoon <4887959+shamoon@users.noreply.github.com>
This commit is contained in:
Moritz Pflanzer
2024-02-08 17:15:38 +01:00
committed by GitHub
parent b47f301831
commit c508be6ecd
33 changed files with 1197 additions and 190 deletions

View File

@@ -1,3 +1,7 @@
from unittest import mock
from allauth.socialaccount.models import SocialAccount
from allauth.socialaccount.models import SocialApp
from django.contrib.auth.models import User
from rest_framework import status
from rest_framework.authtoken.models import Token
@@ -6,6 +10,44 @@ from rest_framework.test import APITestCase
from documents.tests.utils import DirectoriesMixin
# see allauth.socialaccount.providers.openid.provider.OpenIDProvider
class MockOpenIDProvider:
id = "openid"
name = "OpenID"
def get_brands(self):
default_servers = [
dict(id="yahoo", name="Yahoo", openid_url="http://me.yahoo.com"),
dict(id="hyves", name="Hyves", openid_url="http://hyves.nl"),
]
return default_servers
def get_login_url(self, request, **kwargs):
return "openid/login/"
# see allauth.socialaccount.providers.openid_connect.provider.OpenIDConnectProviderAccount
class MockOpenIDConnectProviderAccount:
def __init__(self, mock_social_account_dict):
self.account = mock_social_account_dict
def to_str(self):
return self.account["name"]
# see allauth.socialaccount.providers.openid_connect.provider.OpenIDConnectProvider
class MockOpenIDConnectProvider:
id = "openid_connect"
name = "OpenID Connect"
def __init__(self, app=None):
self.app = app
self.name = app.name
def get_login_url(self, request, **kwargs):
return f"{self.app.provider_id}/login/?process=connect"
class TestApiProfile(DirectoriesMixin, APITestCase):
ENDPOINT = "/api/profile/"
@@ -19,6 +61,17 @@ class TestApiProfile(DirectoriesMixin, APITestCase):
)
self.client.force_authenticate(user=self.user)
def setupSocialAccount(self):
SocialApp.objects.create(
name="Keycloak",
provider="openid_connect",
provider_id="keycloak-test",
)
self.user.socialaccount_set.add(
SocialAccount(uid="123456789", provider="keycloak-test"),
bulk=False,
)
def test_get_profile(self):
"""
GIVEN:
@@ -28,7 +81,6 @@ class TestApiProfile(DirectoriesMixin, APITestCase):
THEN:
- Profile is returned
"""
response = self.client.get(self.ENDPOINT)
self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -37,6 +89,52 @@ class TestApiProfile(DirectoriesMixin, APITestCase):
self.assertEqual(response.data["first_name"], self.user.first_name)
self.assertEqual(response.data["last_name"], self.user.last_name)
@mock.patch(
"allauth.socialaccount.models.SocialAccount.get_provider_account",
)
@mock.patch(
"allauth.socialaccount.adapter.DefaultSocialAccountAdapter.list_providers",
)
def test_get_profile_w_social(self, mock_list_providers, mock_get_provider_account):
"""
GIVEN:
- Configured user and setup social account
WHEN:
- API call is made to get profile
THEN:
- Profile is returned with social accounts
"""
self.setupSocialAccount()
openid_provider = (
MockOpenIDConnectProvider(
app=SocialApp.objects.get(provider_id="keycloak-test"),
),
)
mock_list_providers.return_value = [
openid_provider,
]
mock_get_provider_account.return_value = MockOpenIDConnectProviderAccount(
mock_social_account_dict={
"name": openid_provider[0].name,
},
)
response = self.client.get(self.ENDPOINT)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(
response.data["social_accounts"],
[
{
"id": 1,
"provider": "keycloak-test",
"name": "Keycloak",
},
],
)
def test_update_profile(self):
"""
GIVEN:
@@ -103,3 +201,101 @@ class TestApiProfile(DirectoriesMixin, APITestCase):
response = self.client.post(f"{self.ENDPOINT}generate_auth_token/")
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
@mock.patch(
"allauth.socialaccount.adapter.DefaultSocialAccountAdapter.list_providers",
)
def test_get_social_account_providers(
self,
mock_list_providers,
):
"""
GIVEN:
- Configured user
WHEN:
- API call is made to get social account providers
THEN:
- Social account providers are returned
"""
self.setupSocialAccount()
mock_list_providers.return_value = [
MockOpenIDConnectProvider(
app=SocialApp.objects.get(provider_id="keycloak-test"),
),
]
response = self.client.get(f"{self.ENDPOINT}social_account_providers/")
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(
response.data[0]["name"],
"Keycloak",
)
self.assertIn(
"keycloak-test/login/?process=connect",
response.data[0]["login_url"],
)
@mock.patch(
"allauth.socialaccount.adapter.DefaultSocialAccountAdapter.list_providers",
)
def test_get_social_account_providers_openid(
self,
mock_list_providers,
):
"""
GIVEN:
- Configured user and openid social account provider
WHEN:
- API call is made to get social account providers
THEN:
- Brands for openid provider are returned
"""
mock_list_providers.return_value = [
MockOpenIDProvider(),
]
response = self.client.get(f"{self.ENDPOINT}social_account_providers/")
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(
len(response.data),
2,
)
def test_disconnect_social_account(self):
"""
GIVEN:
- Configured user
WHEN:
- API call is made to disconnect a social account
THEN:
- Social account is deleted from the user or request fails
"""
self.setupSocialAccount()
# Test with invalid id
response = self.client.post(
f"{self.ENDPOINT}disconnect_social_account/",
{"id": -1},
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
# Test with valid id
social_account_id = self.user.socialaccount_set.all()[0].pk
response = self.client.post(
f"{self.ENDPOINT}disconnect_social_account/",
{"id": social_account_id},
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, social_account_id)
self.assertEqual(
len(self.user.socialaccount_set.filter(pk=social_account_id)),
0,
)

View File

@@ -177,9 +177,9 @@ class TestExportImport(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
os.path.join(self.dirs.media_dir, "documents"),
)
manifest = self._do_export(use_filename_format=use_filename_format)
num_permission_objects = Permission.objects.count()
self.assertEqual(len(manifest), 190)
manifest = self._do_export(use_filename_format=use_filename_format)
# dont include consumer or AnonymousUser users
self.assertEqual(
@@ -273,7 +273,7 @@ class TestExportImport(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
self.assertEqual(Document.objects.get(id=self.d4.id).title, "wow_dec")
self.assertEqual(GroupObjectPermission.objects.count(), 1)
self.assertEqual(UserObjectPermission.objects.count(), 1)
self.assertEqual(Permission.objects.count(), 136)
self.assertEqual(Permission.objects.count(), num_permission_objects)
messages = check_sanity()
# everything is alright after the test
self.assertEqual(len(messages), 0)
@@ -753,15 +753,15 @@ class TestExportImport(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
os.path.join(self.dirs.media_dir, "documents"),
)
self.assertEqual(ContentType.objects.count(), 34)
self.assertEqual(Permission.objects.count(), 136)
num_content_type_objects = ContentType.objects.count()
num_permission_objects = Permission.objects.count()
manifest = self._do_export()
with paperless_environment():
self.assertEqual(
len(list(filter(lambda e: e["model"] == "auth.permission", manifest))),
136,
num_permission_objects,
)
# add 1 more to db to show objects are not re-created by import
Permission.objects.create(
@@ -769,7 +769,7 @@ class TestExportImport(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
codename="test_perm",
content_type_id=1,
)
self.assertEqual(Permission.objects.count(), 137)
self.assertEqual(Permission.objects.count(), num_permission_objects + 1)
# will cause an import error
self.user.delete()
@@ -778,5 +778,5 @@ class TestExportImport(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
with self.assertRaises(IntegrityError):
call_command("document_importer", "--no-progress-bar", self.target)
self.assertEqual(ContentType.objects.count(), 34)
self.assertEqual(Permission.objects.count(), 137)
self.assertEqual(ContentType.objects.count(), num_content_type_objects)
self.assertEqual(Permission.objects.count(), num_permission_objects + 1)