From 5e3ee3a80dccb5eb0e0da488f39a36c5b7f95485 Mon Sep 17 00:00:00 2001 From: shamoon <4887959+shamoon@users.noreply.github.com> Date: Fri, 17 Jan 2025 19:51:53 -0800 Subject: [PATCH] Fix: disable API basic auth if MFA enabled (#8792) --- src/documents/tests/test_api_permissions.py | 26 +++++++++++++++++++++ src/paperless/auth.py | 13 +++++++++++ src/paperless/settings.py | 2 +- 3 files changed, 40 insertions(+), 1 deletion(-) diff --git a/src/documents/tests/test_api_permissions.py b/src/documents/tests/test_api_permissions.py index eeea830cb..ef50c55f7 100644 --- a/src/documents/tests/test_api_permissions.py +++ b/src/documents/tests/test_api_permissions.py @@ -1,4 +1,6 @@ +import base64 import json +from unittest import mock from allauth.mfa.models import Authenticator from django.contrib.auth.models import Group @@ -462,6 +464,30 @@ class TestApiAuth(DirectoriesMixin, APITestCase): self.assertNotIn("user_can_change", results[0]) self.assertNotIn("is_shared_by_requester", results[0]) + @mock.patch("allauth.mfa.adapter.DefaultMFAAdapter.is_mfa_enabled") + def test_basic_auth_mfa_enabled(self, mock_is_mfa_enabled): + """ + GIVEN: + - User with MFA enabled + WHEN: + - API request is made with basic auth + THEN: + - MFA required error is returned + """ + user1 = User.objects.create_user(username="user1") + user1.set_password("password") + user1.save() + + mock_is_mfa_enabled.return_value = True + + response = self.client.get( + "/api/documents/", + HTTP_AUTHORIZATION="Basic " + base64.b64encode(b"user1:password").decode(), + ) + + self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) + self.assertEqual(response.data["detail"], "MFA required") + class TestApiUser(DirectoriesMixin, APITestCase): ENDPOINT = "/api/users/" diff --git a/src/paperless/auth.py b/src/paperless/auth.py index 6ca97d608..36131847f 100644 --- a/src/paperless/auth.py +++ b/src/paperless/auth.py @@ -1,5 +1,6 @@ import logging +from allauth.mfa.adapter import get_adapter as get_mfa_adapter from django.conf import settings from django.contrib import auth from django.contrib.auth.middleware import PersistentRemoteUserMiddleware @@ -7,6 +8,7 @@ from django.contrib.auth.models import User from django.http import HttpRequest from django.utils.deprecation import MiddlewareMixin from rest_framework import authentication +from rest_framework import exceptions logger = logging.getLogger("paperless.auth") @@ -70,3 +72,14 @@ class PaperlessRemoteUserAuthentication(authentication.RemoteUserAuthentication) """ header = settings.HTTP_REMOTE_USER_HEADER_NAME + + +class PaperlessBasicAuthentication(authentication.BasicAuthentication): + def authenticate(self, request): + user_tuple = super().authenticate(request) + user = user_tuple[0] if user_tuple else None + mfa_adapter = get_mfa_adapter() + if user and mfa_adapter.is_mfa_enabled(user): + raise exceptions.AuthenticationFailed("MFA required") + + return user_tuple diff --git a/src/paperless/settings.py b/src/paperless/settings.py index 3fc9bfdbf..ef842dde6 100644 --- a/src/paperless/settings.py +++ b/src/paperless/settings.py @@ -336,7 +336,7 @@ if DEBUG: REST_FRAMEWORK = { "DEFAULT_AUTHENTICATION_CLASSES": [ - "rest_framework.authentication.BasicAuthentication", + "paperless.auth.PaperlessBasicAuthentication", "rest_framework.authentication.TokenAuthentication", "rest_framework.authentication.SessionAuthentication", ],