mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-04-02 13:45:10 -05:00
Enhancement: require totp code for obtain auth token (#8936)
This commit is contained in:
parent
bf368aadd0
commit
c2a9ac332a
@ -3,6 +3,7 @@ import json
|
|||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
from allauth.mfa.models import Authenticator
|
from allauth.mfa.models import Authenticator
|
||||||
|
from allauth.mfa.totp.internal import auth as totp_auth
|
||||||
from django.contrib.auth.models import Group
|
from django.contrib.auth.models import Group
|
||||||
from django.contrib.auth.models import Permission
|
from django.contrib.auth.models import Permission
|
||||||
from django.contrib.auth.models import User
|
from django.contrib.auth.models import User
|
||||||
@ -488,6 +489,71 @@ class TestApiAuth(DirectoriesMixin, APITestCase):
|
|||||||
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
|
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
|
||||||
self.assertEqual(response.data["detail"], "MFA required")
|
self.assertEqual(response.data["detail"], "MFA required")
|
||||||
|
|
||||||
|
@mock.patch("allauth.mfa.totp.internal.auth.TOTP.validate_code")
|
||||||
|
def test_get_token_mfa_enabled(self, mock_validate_code):
|
||||||
|
"""
|
||||||
|
GIVEN:
|
||||||
|
- User with MFA enabled
|
||||||
|
WHEN:
|
||||||
|
- API request is made to obtain an auth token
|
||||||
|
THEN:
|
||||||
|
- MFA code is required
|
||||||
|
"""
|
||||||
|
user1 = User.objects.create_user(username="user1")
|
||||||
|
user1.set_password("password")
|
||||||
|
user1.save()
|
||||||
|
|
||||||
|
response = self.client.post(
|
||||||
|
"/api/token/",
|
||||||
|
data={
|
||||||
|
"username": "user1",
|
||||||
|
"password": "password",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||||
|
|
||||||
|
secret = totp_auth.generate_totp_secret()
|
||||||
|
totp_auth.TOTP.activate(
|
||||||
|
user1,
|
||||||
|
secret,
|
||||||
|
)
|
||||||
|
|
||||||
|
# no code
|
||||||
|
response = self.client.post(
|
||||||
|
"/api/token/",
|
||||||
|
data={
|
||||||
|
"username": "user1",
|
||||||
|
"password": "password",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
||||||
|
self.assertEqual(response.data["non_field_errors"][0], "MFA code is required")
|
||||||
|
|
||||||
|
# invalid code
|
||||||
|
mock_validate_code.return_value = False
|
||||||
|
response = self.client.post(
|
||||||
|
"/api/token/",
|
||||||
|
data={
|
||||||
|
"username": "user1",
|
||||||
|
"password": "password",
|
||||||
|
"code": "123456",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
||||||
|
self.assertEqual(response.data["non_field_errors"][0], "Invalid MFA code")
|
||||||
|
|
||||||
|
# valid code
|
||||||
|
mock_validate_code.return_value = True
|
||||||
|
response = self.client.post(
|
||||||
|
"/api/token/",
|
||||||
|
data={
|
||||||
|
"username": "user1",
|
||||||
|
"password": "password",
|
||||||
|
"code": "123456",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||||
|
|
||||||
|
|
||||||
class TestApiUser(DirectoriesMixin, APITestCase):
|
class TestApiUser(DirectoriesMixin, APITestCase):
|
||||||
ENDPOINT = "/api/users/"
|
ENDPOINT = "/api/users/"
|
||||||
|
@ -1,11 +1,14 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from allauth.mfa.adapter import get_adapter as get_mfa_adapter
|
from allauth.mfa.adapter import get_adapter as get_mfa_adapter
|
||||||
|
from allauth.mfa.models import Authenticator
|
||||||
|
from allauth.mfa.totp.internal.auth import TOTP
|
||||||
from allauth.socialaccount.models import SocialAccount
|
from allauth.socialaccount.models import SocialAccount
|
||||||
from django.contrib.auth.models import Group
|
from django.contrib.auth.models import Group
|
||||||
from django.contrib.auth.models import Permission
|
from django.contrib.auth.models import Permission
|
||||||
from django.contrib.auth.models import User
|
from django.contrib.auth.models import User
|
||||||
from rest_framework import serializers
|
from rest_framework import serializers
|
||||||
|
from rest_framework.authtoken.serializers import AuthTokenSerializer
|
||||||
|
|
||||||
from paperless.models import ApplicationConfiguration
|
from paperless.models import ApplicationConfiguration
|
||||||
|
|
||||||
@ -24,6 +27,36 @@ class ObfuscatedUserPasswordField(serializers.Field):
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
class PaperlessAuthTokenSerializer(AuthTokenSerializer):
|
||||||
|
code = serializers.CharField(
|
||||||
|
label="MFA Code",
|
||||||
|
write_only=True,
|
||||||
|
required=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def validate(self, attrs):
|
||||||
|
attrs = super().validate(attrs)
|
||||||
|
user = attrs.get("user")
|
||||||
|
code = attrs.get("code")
|
||||||
|
mfa_adapter = get_mfa_adapter()
|
||||||
|
if mfa_adapter.is_mfa_enabled(user):
|
||||||
|
if not code:
|
||||||
|
raise serializers.ValidationError(
|
||||||
|
"MFA code is required",
|
||||||
|
)
|
||||||
|
authenticator = Authenticator.objects.get(
|
||||||
|
user=user,
|
||||||
|
type=Authenticator.Type.TOTP,
|
||||||
|
)
|
||||||
|
if not TOTP(instance=authenticator).validate_code(
|
||||||
|
code,
|
||||||
|
):
|
||||||
|
raise serializers.ValidationError(
|
||||||
|
"Invalid MFA code",
|
||||||
|
)
|
||||||
|
return attrs
|
||||||
|
|
||||||
|
|
||||||
class UserSerializer(serializers.ModelSerializer):
|
class UserSerializer(serializers.ModelSerializer):
|
||||||
password = ObfuscatedUserPasswordField(required=False)
|
password = ObfuscatedUserPasswordField(required=False)
|
||||||
user_permissions = serializers.SlugRelatedField(
|
user_permissions = serializers.SlugRelatedField(
|
||||||
|
@ -14,7 +14,6 @@ from django.utils.translation import gettext_lazy as _
|
|||||||
from django.views.decorators.csrf import ensure_csrf_cookie
|
from django.views.decorators.csrf import ensure_csrf_cookie
|
||||||
from django.views.generic import RedirectView
|
from django.views.generic import RedirectView
|
||||||
from django.views.static import serve
|
from django.views.static import serve
|
||||||
from rest_framework.authtoken import views
|
|
||||||
from rest_framework.routers import DefaultRouter
|
from rest_framework.routers import DefaultRouter
|
||||||
|
|
||||||
from documents.views import BulkDownloadView
|
from documents.views import BulkDownloadView
|
||||||
@ -50,6 +49,7 @@ from paperless.views import DisconnectSocialAccountView
|
|||||||
from paperless.views import FaviconView
|
from paperless.views import FaviconView
|
||||||
from paperless.views import GenerateAuthTokenView
|
from paperless.views import GenerateAuthTokenView
|
||||||
from paperless.views import GroupViewSet
|
from paperless.views import GroupViewSet
|
||||||
|
from paperless.views import PaperlessObtainAuthTokenView
|
||||||
from paperless.views import ProfileView
|
from paperless.views import ProfileView
|
||||||
from paperless.views import SocialAccountProvidersView
|
from paperless.views import SocialAccountProvidersView
|
||||||
from paperless.views import TOTPView
|
from paperless.views import TOTPView
|
||||||
@ -157,7 +157,7 @@ urlpatterns = [
|
|||||||
),
|
),
|
||||||
path(
|
path(
|
||||||
"token/",
|
"token/",
|
||||||
views.obtain_auth_token,
|
PaperlessObtainAuthTokenView.as_view(),
|
||||||
),
|
),
|
||||||
re_path(
|
re_path(
|
||||||
"^profile/",
|
"^profile/",
|
||||||
|
@ -19,6 +19,7 @@ from django.http import HttpResponseNotFound
|
|||||||
from django.views.generic import View
|
from django.views.generic import View
|
||||||
from django_filters.rest_framework import DjangoFilterBackend
|
from django_filters.rest_framework import DjangoFilterBackend
|
||||||
from rest_framework.authtoken.models import Token
|
from rest_framework.authtoken.models import Token
|
||||||
|
from rest_framework.authtoken.views import ObtainAuthToken
|
||||||
from rest_framework.decorators import action
|
from rest_framework.decorators import action
|
||||||
from rest_framework.filters import OrderingFilter
|
from rest_framework.filters import OrderingFilter
|
||||||
from rest_framework.generics import GenericAPIView
|
from rest_framework.generics import GenericAPIView
|
||||||
@ -35,10 +36,15 @@ from paperless.filters import UserFilterSet
|
|||||||
from paperless.models import ApplicationConfiguration
|
from paperless.models import ApplicationConfiguration
|
||||||
from paperless.serialisers import ApplicationConfigurationSerializer
|
from paperless.serialisers import ApplicationConfigurationSerializer
|
||||||
from paperless.serialisers import GroupSerializer
|
from paperless.serialisers import GroupSerializer
|
||||||
|
from paperless.serialisers import PaperlessAuthTokenSerializer
|
||||||
from paperless.serialisers import ProfileSerializer
|
from paperless.serialisers import ProfileSerializer
|
||||||
from paperless.serialisers import UserSerializer
|
from paperless.serialisers import UserSerializer
|
||||||
|
|
||||||
|
|
||||||
|
class PaperlessObtainAuthTokenView(ObtainAuthToken):
|
||||||
|
serializer_class = PaperlessAuthTokenSerializer
|
||||||
|
|
||||||
|
|
||||||
class StandardPagination(PageNumberPagination):
|
class StandardPagination(PageNumberPagination):
|
||||||
page_size = 25
|
page_size = 25
|
||||||
page_size_query_param = "page_size"
|
page_size_query_param = "page_size"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user