diff --git a/src/documents/tests/test_api_permissions.py b/src/documents/tests/test_api_permissions.py index 5de1887b2..3785c8f2a 100644 --- a/src/documents/tests/test_api_permissions.py +++ b/src/documents/tests/test_api_permissions.py @@ -3,6 +3,7 @@ import json from unittest import mock from allauth.mfa.models import Authenticator +from allauth.mfa.totp.internal import auth as totp_auth from django.contrib.auth.models import Group from django.contrib.auth.models import Permission from django.contrib.auth.models import User @@ -488,6 +489,71 @@ class TestApiAuth(DirectoriesMixin, APITestCase): self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) self.assertEqual(response.data["detail"], "MFA required") + @mock.patch("allauth.mfa.totp.internal.auth.TOTP.validate_code") + def test_get_token_mfa_enabled(self, mock_validate_code): + """ + GIVEN: + - User with MFA enabled + WHEN: + - API request is made to obtain an auth token + THEN: + - MFA code is required + """ + user1 = User.objects.create_user(username="user1") + user1.set_password("password") + user1.save() + + response = self.client.post( + "/api/token/", + data={ + "username": "user1", + "password": "password", + }, + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + + secret = totp_auth.generate_totp_secret() + totp_auth.TOTP.activate( + user1, + secret, + ) + + # no code + response = self.client.post( + "/api/token/", + data={ + "username": "user1", + "password": "password", + }, + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.data["non_field_errors"][0], "MFA code is required") + + # invalid code + mock_validate_code.return_value = False + response = self.client.post( + "/api/token/", + data={ + "username": "user1", + "password": "password", + "code": "123456", + }, + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.data["non_field_errors"][0], "Invalid MFA code") + + # valid code + mock_validate_code.return_value = True + response = self.client.post( + "/api/token/", + data={ + "username": "user1", + "password": "password", + "code": "123456", + }, + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + class TestApiUser(DirectoriesMixin, APITestCase): ENDPOINT = "/api/users/" diff --git a/src/paperless/serialisers.py b/src/paperless/serialisers.py index d5acfe465..fb1f511f7 100644 --- a/src/paperless/serialisers.py +++ b/src/paperless/serialisers.py @@ -1,11 +1,14 @@ import logging from allauth.mfa.adapter import get_adapter as get_mfa_adapter +from allauth.mfa.models import Authenticator +from allauth.mfa.totp.internal.auth import TOTP from allauth.socialaccount.models import SocialAccount from django.contrib.auth.models import Group from django.contrib.auth.models import Permission from django.contrib.auth.models import User from rest_framework import serializers +from rest_framework.authtoken.serializers import AuthTokenSerializer from paperless.models import ApplicationConfiguration @@ -24,6 +27,36 @@ class ObfuscatedUserPasswordField(serializers.Field): return data +class PaperlessAuthTokenSerializer(AuthTokenSerializer): + code = serializers.CharField( + label="MFA Code", + write_only=True, + required=False, + ) + + def validate(self, attrs): + attrs = super().validate(attrs) + user = attrs.get("user") + code = attrs.get("code") + mfa_adapter = get_mfa_adapter() + if mfa_adapter.is_mfa_enabled(user): + if not code: + raise serializers.ValidationError( + "MFA code is required", + ) + authenticator = Authenticator.objects.get( + user=user, + type=Authenticator.Type.TOTP, + ) + if not TOTP(instance=authenticator).validate_code( + code, + ): + raise serializers.ValidationError( + "Invalid MFA code", + ) + return attrs + + class UserSerializer(serializers.ModelSerializer): password = ObfuscatedUserPasswordField(required=False) user_permissions = serializers.SlugRelatedField( diff --git a/src/paperless/urls.py b/src/paperless/urls.py index c528c5e2a..703a72042 100644 --- a/src/paperless/urls.py +++ b/src/paperless/urls.py @@ -14,7 +14,6 @@ from django.utils.translation import gettext_lazy as _ from django.views.decorators.csrf import ensure_csrf_cookie from django.views.generic import RedirectView from django.views.static import serve -from rest_framework.authtoken import views from rest_framework.routers import DefaultRouter from documents.views import BulkDownloadView @@ -50,6 +49,7 @@ from paperless.views import DisconnectSocialAccountView from paperless.views import FaviconView from paperless.views import GenerateAuthTokenView from paperless.views import GroupViewSet +from paperless.views import PaperlessObtainAuthTokenView from paperless.views import ProfileView from paperless.views import SocialAccountProvidersView from paperless.views import TOTPView @@ -157,7 +157,7 @@ urlpatterns = [ ), path( "token/", - views.obtain_auth_token, + PaperlessObtainAuthTokenView.as_view(), ), re_path( "^profile/", diff --git a/src/paperless/views.py b/src/paperless/views.py index 03721adf2..bcabd182f 100644 --- a/src/paperless/views.py +++ b/src/paperless/views.py @@ -19,6 +19,7 @@ from django.http import HttpResponseNotFound from django.views.generic import View from django_filters.rest_framework import DjangoFilterBackend from rest_framework.authtoken.models import Token +from rest_framework.authtoken.views import ObtainAuthToken from rest_framework.decorators import action from rest_framework.filters import OrderingFilter from rest_framework.generics import GenericAPIView @@ -35,10 +36,15 @@ from paperless.filters import UserFilterSet from paperless.models import ApplicationConfiguration from paperless.serialisers import ApplicationConfigurationSerializer from paperless.serialisers import GroupSerializer +from paperless.serialisers import PaperlessAuthTokenSerializer from paperless.serialisers import ProfileSerializer from paperless.serialisers import UserSerializer +class PaperlessObtainAuthTokenView(ObtainAuthToken): + serializer_class = PaperlessAuthTokenSerializer + + class StandardPagination(PageNumberPagination): page_size = 25 page_size_query_param = "page_size"