mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-08-12 00:19:48 +00:00
Feature: OIDC & social authentication (#5190)
--------- Co-authored-by: Moritz Pflanzer <moritz@chickadee-engineering.com> Co-authored-by: shamoon <4887959+shamoon@users.noreply.github.com>
This commit is contained in:
@@ -1,3 +1,7 @@
|
||||
from unittest import mock
|
||||
|
||||
from allauth.socialaccount.models import SocialAccount
|
||||
from allauth.socialaccount.models import SocialApp
|
||||
from django.contrib.auth.models import User
|
||||
from rest_framework import status
|
||||
from rest_framework.authtoken.models import Token
|
||||
@@ -6,6 +10,44 @@ from rest_framework.test import APITestCase
|
||||
from documents.tests.utils import DirectoriesMixin
|
||||
|
||||
|
||||
# see allauth.socialaccount.providers.openid.provider.OpenIDProvider
|
||||
class MockOpenIDProvider:
|
||||
id = "openid"
|
||||
name = "OpenID"
|
||||
|
||||
def get_brands(self):
|
||||
default_servers = [
|
||||
dict(id="yahoo", name="Yahoo", openid_url="http://me.yahoo.com"),
|
||||
dict(id="hyves", name="Hyves", openid_url="http://hyves.nl"),
|
||||
]
|
||||
return default_servers
|
||||
|
||||
def get_login_url(self, request, **kwargs):
|
||||
return "openid/login/"
|
||||
|
||||
|
||||
# see allauth.socialaccount.providers.openid_connect.provider.OpenIDConnectProviderAccount
|
||||
class MockOpenIDConnectProviderAccount:
|
||||
def __init__(self, mock_social_account_dict):
|
||||
self.account = mock_social_account_dict
|
||||
|
||||
def to_str(self):
|
||||
return self.account["name"]
|
||||
|
||||
|
||||
# see allauth.socialaccount.providers.openid_connect.provider.OpenIDConnectProvider
|
||||
class MockOpenIDConnectProvider:
|
||||
id = "openid_connect"
|
||||
name = "OpenID Connect"
|
||||
|
||||
def __init__(self, app=None):
|
||||
self.app = app
|
||||
self.name = app.name
|
||||
|
||||
def get_login_url(self, request, **kwargs):
|
||||
return f"{self.app.provider_id}/login/?process=connect"
|
||||
|
||||
|
||||
class TestApiProfile(DirectoriesMixin, APITestCase):
|
||||
ENDPOINT = "/api/profile/"
|
||||
|
||||
@@ -19,6 +61,17 @@ class TestApiProfile(DirectoriesMixin, APITestCase):
|
||||
)
|
||||
self.client.force_authenticate(user=self.user)
|
||||
|
||||
def setupSocialAccount(self):
|
||||
SocialApp.objects.create(
|
||||
name="Keycloak",
|
||||
provider="openid_connect",
|
||||
provider_id="keycloak-test",
|
||||
)
|
||||
self.user.socialaccount_set.add(
|
||||
SocialAccount(uid="123456789", provider="keycloak-test"),
|
||||
bulk=False,
|
||||
)
|
||||
|
||||
def test_get_profile(self):
|
||||
"""
|
||||
GIVEN:
|
||||
@@ -28,7 +81,6 @@ class TestApiProfile(DirectoriesMixin, APITestCase):
|
||||
THEN:
|
||||
- Profile is returned
|
||||
"""
|
||||
|
||||
response = self.client.get(self.ENDPOINT)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
@@ -37,6 +89,52 @@ class TestApiProfile(DirectoriesMixin, APITestCase):
|
||||
self.assertEqual(response.data["first_name"], self.user.first_name)
|
||||
self.assertEqual(response.data["last_name"], self.user.last_name)
|
||||
|
||||
@mock.patch(
|
||||
"allauth.socialaccount.models.SocialAccount.get_provider_account",
|
||||
)
|
||||
@mock.patch(
|
||||
"allauth.socialaccount.adapter.DefaultSocialAccountAdapter.list_providers",
|
||||
)
|
||||
def test_get_profile_w_social(self, mock_list_providers, mock_get_provider_account):
|
||||
"""
|
||||
GIVEN:
|
||||
- Configured user and setup social account
|
||||
WHEN:
|
||||
- API call is made to get profile
|
||||
THEN:
|
||||
- Profile is returned with social accounts
|
||||
"""
|
||||
self.setupSocialAccount()
|
||||
|
||||
openid_provider = (
|
||||
MockOpenIDConnectProvider(
|
||||
app=SocialApp.objects.get(provider_id="keycloak-test"),
|
||||
),
|
||||
)
|
||||
mock_list_providers.return_value = [
|
||||
openid_provider,
|
||||
]
|
||||
mock_get_provider_account.return_value = MockOpenIDConnectProviderAccount(
|
||||
mock_social_account_dict={
|
||||
"name": openid_provider[0].name,
|
||||
},
|
||||
)
|
||||
|
||||
response = self.client.get(self.ENDPOINT)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
self.assertEqual(
|
||||
response.data["social_accounts"],
|
||||
[
|
||||
{
|
||||
"id": 1,
|
||||
"provider": "keycloak-test",
|
||||
"name": "Keycloak",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
def test_update_profile(self):
|
||||
"""
|
||||
GIVEN:
|
||||
@@ -103,3 +201,101 @@ class TestApiProfile(DirectoriesMixin, APITestCase):
|
||||
|
||||
response = self.client.post(f"{self.ENDPOINT}generate_auth_token/")
|
||||
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
|
||||
|
||||
@mock.patch(
|
||||
"allauth.socialaccount.adapter.DefaultSocialAccountAdapter.list_providers",
|
||||
)
|
||||
def test_get_social_account_providers(
|
||||
self,
|
||||
mock_list_providers,
|
||||
):
|
||||
"""
|
||||
GIVEN:
|
||||
- Configured user
|
||||
WHEN:
|
||||
- API call is made to get social account providers
|
||||
THEN:
|
||||
- Social account providers are returned
|
||||
"""
|
||||
self.setupSocialAccount()
|
||||
|
||||
mock_list_providers.return_value = [
|
||||
MockOpenIDConnectProvider(
|
||||
app=SocialApp.objects.get(provider_id="keycloak-test"),
|
||||
),
|
||||
]
|
||||
|
||||
response = self.client.get(f"{self.ENDPOINT}social_account_providers/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(
|
||||
response.data[0]["name"],
|
||||
"Keycloak",
|
||||
)
|
||||
self.assertIn(
|
||||
"keycloak-test/login/?process=connect",
|
||||
response.data[0]["login_url"],
|
||||
)
|
||||
|
||||
@mock.patch(
|
||||
"allauth.socialaccount.adapter.DefaultSocialAccountAdapter.list_providers",
|
||||
)
|
||||
def test_get_social_account_providers_openid(
|
||||
self,
|
||||
mock_list_providers,
|
||||
):
|
||||
"""
|
||||
GIVEN:
|
||||
- Configured user and openid social account provider
|
||||
WHEN:
|
||||
- API call is made to get social account providers
|
||||
THEN:
|
||||
- Brands for openid provider are returned
|
||||
"""
|
||||
|
||||
mock_list_providers.return_value = [
|
||||
MockOpenIDProvider(),
|
||||
]
|
||||
|
||||
response = self.client.get(f"{self.ENDPOINT}social_account_providers/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(
|
||||
len(response.data),
|
||||
2,
|
||||
)
|
||||
|
||||
def test_disconnect_social_account(self):
|
||||
"""
|
||||
GIVEN:
|
||||
- Configured user
|
||||
WHEN:
|
||||
- API call is made to disconnect a social account
|
||||
THEN:
|
||||
- Social account is deleted from the user or request fails
|
||||
"""
|
||||
self.setupSocialAccount()
|
||||
|
||||
# Test with invalid id
|
||||
response = self.client.post(
|
||||
f"{self.ENDPOINT}disconnect_social_account/",
|
||||
{"id": -1},
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
# Test with valid id
|
||||
social_account_id = self.user.socialaccount_set.all()[0].pk
|
||||
|
||||
response = self.client.post(
|
||||
f"{self.ENDPOINT}disconnect_social_account/",
|
||||
{"id": social_account_id},
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data, social_account_id)
|
||||
|
||||
self.assertEqual(
|
||||
len(self.user.socialaccount_set.filter(pk=social_account_id)),
|
||||
0,
|
||||
)
|
||||
|
@@ -177,9 +177,9 @@ class TestExportImport(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
|
||||
os.path.join(self.dirs.media_dir, "documents"),
|
||||
)
|
||||
|
||||
manifest = self._do_export(use_filename_format=use_filename_format)
|
||||
num_permission_objects = Permission.objects.count()
|
||||
|
||||
self.assertEqual(len(manifest), 190)
|
||||
manifest = self._do_export(use_filename_format=use_filename_format)
|
||||
|
||||
# dont include consumer or AnonymousUser users
|
||||
self.assertEqual(
|
||||
@@ -273,7 +273,7 @@ class TestExportImport(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
|
||||
self.assertEqual(Document.objects.get(id=self.d4.id).title, "wow_dec")
|
||||
self.assertEqual(GroupObjectPermission.objects.count(), 1)
|
||||
self.assertEqual(UserObjectPermission.objects.count(), 1)
|
||||
self.assertEqual(Permission.objects.count(), 136)
|
||||
self.assertEqual(Permission.objects.count(), num_permission_objects)
|
||||
messages = check_sanity()
|
||||
# everything is alright after the test
|
||||
self.assertEqual(len(messages), 0)
|
||||
@@ -753,15 +753,15 @@ class TestExportImport(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
|
||||
os.path.join(self.dirs.media_dir, "documents"),
|
||||
)
|
||||
|
||||
self.assertEqual(ContentType.objects.count(), 34)
|
||||
self.assertEqual(Permission.objects.count(), 136)
|
||||
num_content_type_objects = ContentType.objects.count()
|
||||
num_permission_objects = Permission.objects.count()
|
||||
|
||||
manifest = self._do_export()
|
||||
|
||||
with paperless_environment():
|
||||
self.assertEqual(
|
||||
len(list(filter(lambda e: e["model"] == "auth.permission", manifest))),
|
||||
136,
|
||||
num_permission_objects,
|
||||
)
|
||||
# add 1 more to db to show objects are not re-created by import
|
||||
Permission.objects.create(
|
||||
@@ -769,7 +769,7 @@ class TestExportImport(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
|
||||
codename="test_perm",
|
||||
content_type_id=1,
|
||||
)
|
||||
self.assertEqual(Permission.objects.count(), 137)
|
||||
self.assertEqual(Permission.objects.count(), num_permission_objects + 1)
|
||||
|
||||
# will cause an import error
|
||||
self.user.delete()
|
||||
@@ -778,5 +778,5 @@ class TestExportImport(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
|
||||
with self.assertRaises(IntegrityError):
|
||||
call_command("document_importer", "--no-progress-bar", self.target)
|
||||
|
||||
self.assertEqual(ContentType.objects.count(), 34)
|
||||
self.assertEqual(Permission.objects.count(), 137)
|
||||
self.assertEqual(ContentType.objects.count(), num_content_type_objects)
|
||||
self.assertEqual(Permission.objects.count(), num_permission_objects + 1)
|
||||
|
Reference in New Issue
Block a user