Enhancement: require totp code for obtain auth token (#8936)

This commit is contained in:
shamoon 2025-01-29 07:23:44 -08:00
parent bf368aadd0
commit c2a9ac332a
No known key found for this signature in database
4 changed files with 107 additions and 2 deletions

View File

@ -3,6 +3,7 @@ import json
from unittest import mock from unittest import mock
from allauth.mfa.models import Authenticator from allauth.mfa.models import Authenticator
from allauth.mfa.totp.internal import auth as totp_auth
from django.contrib.auth.models import Group from django.contrib.auth.models import Group
from django.contrib.auth.models import Permission from django.contrib.auth.models import Permission
from django.contrib.auth.models import User from django.contrib.auth.models import User
@ -488,6 +489,71 @@ class TestApiAuth(DirectoriesMixin, APITestCase):
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
self.assertEqual(response.data["detail"], "MFA required") self.assertEqual(response.data["detail"], "MFA required")
@mock.patch("allauth.mfa.totp.internal.auth.TOTP.validate_code")
def test_get_token_mfa_enabled(self, mock_validate_code):
"""
GIVEN:
- User with MFA enabled
WHEN:
- API request is made to obtain an auth token
THEN:
- MFA code is required
"""
user1 = User.objects.create_user(username="user1")
user1.set_password("password")
user1.save()
response = self.client.post(
"/api/token/",
data={
"username": "user1",
"password": "password",
},
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
secret = totp_auth.generate_totp_secret()
totp_auth.TOTP.activate(
user1,
secret,
)
# no code
response = self.client.post(
"/api/token/",
data={
"username": "user1",
"password": "password",
},
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(response.data["non_field_errors"][0], "MFA code is required")
# invalid code
mock_validate_code.return_value = False
response = self.client.post(
"/api/token/",
data={
"username": "user1",
"password": "password",
"code": "123456",
},
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(response.data["non_field_errors"][0], "Invalid MFA code")
# valid code
mock_validate_code.return_value = True
response = self.client.post(
"/api/token/",
data={
"username": "user1",
"password": "password",
"code": "123456",
},
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
class TestApiUser(DirectoriesMixin, APITestCase): class TestApiUser(DirectoriesMixin, APITestCase):
ENDPOINT = "/api/users/" ENDPOINT = "/api/users/"

View File

@ -1,11 +1,14 @@
import logging import logging
from allauth.mfa.adapter import get_adapter as get_mfa_adapter from allauth.mfa.adapter import get_adapter as get_mfa_adapter
from allauth.mfa.models import Authenticator
from allauth.mfa.totp.internal.auth import TOTP
from allauth.socialaccount.models import SocialAccount from allauth.socialaccount.models import SocialAccount
from django.contrib.auth.models import Group from django.contrib.auth.models import Group
from django.contrib.auth.models import Permission from django.contrib.auth.models import Permission
from django.contrib.auth.models import User from django.contrib.auth.models import User
from rest_framework import serializers from rest_framework import serializers
from rest_framework.authtoken.serializers import AuthTokenSerializer
from paperless.models import ApplicationConfiguration from paperless.models import ApplicationConfiguration
@ -24,6 +27,36 @@ class ObfuscatedUserPasswordField(serializers.Field):
return data return data
class PaperlessAuthTokenSerializer(AuthTokenSerializer):
code = serializers.CharField(
label="MFA Code",
write_only=True,
required=False,
)
def validate(self, attrs):
attrs = super().validate(attrs)
user = attrs.get("user")
code = attrs.get("code")
mfa_adapter = get_mfa_adapter()
if mfa_adapter.is_mfa_enabled(user):
if not code:
raise serializers.ValidationError(
"MFA code is required",
)
authenticator = Authenticator.objects.get(
user=user,
type=Authenticator.Type.TOTP,
)
if not TOTP(instance=authenticator).validate_code(
code,
):
raise serializers.ValidationError(
"Invalid MFA code",
)
return attrs
class UserSerializer(serializers.ModelSerializer): class UserSerializer(serializers.ModelSerializer):
password = ObfuscatedUserPasswordField(required=False) password = ObfuscatedUserPasswordField(required=False)
user_permissions = serializers.SlugRelatedField( user_permissions = serializers.SlugRelatedField(

View File

@ -14,7 +14,6 @@ from django.utils.translation import gettext_lazy as _
from django.views.decorators.csrf import ensure_csrf_cookie from django.views.decorators.csrf import ensure_csrf_cookie
from django.views.generic import RedirectView from django.views.generic import RedirectView
from django.views.static import serve from django.views.static import serve
from rest_framework.authtoken import views
from rest_framework.routers import DefaultRouter from rest_framework.routers import DefaultRouter
from documents.views import BulkDownloadView from documents.views import BulkDownloadView
@ -50,6 +49,7 @@ from paperless.views import DisconnectSocialAccountView
from paperless.views import FaviconView from paperless.views import FaviconView
from paperless.views import GenerateAuthTokenView from paperless.views import GenerateAuthTokenView
from paperless.views import GroupViewSet from paperless.views import GroupViewSet
from paperless.views import PaperlessObtainAuthTokenView
from paperless.views import ProfileView from paperless.views import ProfileView
from paperless.views import SocialAccountProvidersView from paperless.views import SocialAccountProvidersView
from paperless.views import TOTPView from paperless.views import TOTPView
@ -157,7 +157,7 @@ urlpatterns = [
), ),
path( path(
"token/", "token/",
views.obtain_auth_token, PaperlessObtainAuthTokenView.as_view(),
), ),
re_path( re_path(
"^profile/", "^profile/",

View File

@ -19,6 +19,7 @@ from django.http import HttpResponseNotFound
from django.views.generic import View from django.views.generic import View
from django_filters.rest_framework import DjangoFilterBackend from django_filters.rest_framework import DjangoFilterBackend
from rest_framework.authtoken.models import Token from rest_framework.authtoken.models import Token
from rest_framework.authtoken.views import ObtainAuthToken
from rest_framework.decorators import action from rest_framework.decorators import action
from rest_framework.filters import OrderingFilter from rest_framework.filters import OrderingFilter
from rest_framework.generics import GenericAPIView from rest_framework.generics import GenericAPIView
@ -35,10 +36,15 @@ from paperless.filters import UserFilterSet
from paperless.models import ApplicationConfiguration from paperless.models import ApplicationConfiguration
from paperless.serialisers import ApplicationConfigurationSerializer from paperless.serialisers import ApplicationConfigurationSerializer
from paperless.serialisers import GroupSerializer from paperless.serialisers import GroupSerializer
from paperless.serialisers import PaperlessAuthTokenSerializer
from paperless.serialisers import ProfileSerializer from paperless.serialisers import ProfileSerializer
from paperless.serialisers import UserSerializer from paperless.serialisers import UserSerializer
class PaperlessObtainAuthTokenView(ObtainAuthToken):
serializer_class = PaperlessAuthTokenSerializer
class StandardPagination(PageNumberPagination): class StandardPagination(PageNumberPagination):
page_size = 25 page_size = 25
page_size_query_param = "page_size" page_size_query_param = "page_size"