mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-07-28 18:24:38 -05:00
Feature: update user profile (#4678)
This commit is contained in:
105
src/documents/tests/test_api_profile.py
Normal file
105
src/documents/tests/test_api_profile.py
Normal file
@@ -0,0 +1,105 @@
|
||||
from django.contrib.auth.models import User
|
||||
from rest_framework import status
|
||||
from rest_framework.authtoken.models import Token
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
from documents.tests.utils import DirectoriesMixin
|
||||
|
||||
|
||||
class TestApiProfile(DirectoriesMixin, APITestCase):
|
||||
ENDPOINT = "/api/profile/"
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
self.user = User.objects.create_superuser(
|
||||
username="temp_admin",
|
||||
first_name="firstname",
|
||||
last_name="surname",
|
||||
)
|
||||
self.client.force_authenticate(user=self.user)
|
||||
|
||||
def test_get_profile(self):
|
||||
"""
|
||||
GIVEN:
|
||||
- Configured user
|
||||
WHEN:
|
||||
- API call is made to get profile
|
||||
THEN:
|
||||
- Profile is returned
|
||||
"""
|
||||
|
||||
response = self.client.get(self.ENDPOINT)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
self.assertEqual(response.data["email"], self.user.email)
|
||||
self.assertEqual(response.data["first_name"], self.user.first_name)
|
||||
self.assertEqual(response.data["last_name"], self.user.last_name)
|
||||
|
||||
def test_update_profile(self):
|
||||
"""
|
||||
GIVEN:
|
||||
- Configured user
|
||||
WHEN:
|
||||
- API call is made to update profile
|
||||
THEN:
|
||||
- Profile is updated
|
||||
"""
|
||||
|
||||
user_data = {
|
||||
"email": "new@email.com",
|
||||
"password": "superpassword1234",
|
||||
"first_name": "new first name",
|
||||
"last_name": "new last name",
|
||||
}
|
||||
response = self.client.patch(self.ENDPOINT, user_data)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
user = User.objects.get(username=self.user.username)
|
||||
self.assertTrue(user.check_password(user_data["password"]))
|
||||
self.assertEqual(user.email, user_data["email"])
|
||||
self.assertEqual(user.first_name, user_data["first_name"])
|
||||
self.assertEqual(user.last_name, user_data["last_name"])
|
||||
|
||||
def test_update_auth_token(self):
|
||||
"""
|
||||
GIVEN:
|
||||
- Configured user
|
||||
WHEN:
|
||||
- API call is made to generate auth token
|
||||
THEN:
|
||||
- Token is created the first time, updated the second
|
||||
"""
|
||||
|
||||
self.assertEqual(len(Token.objects.all()), 0)
|
||||
|
||||
response = self.client.post(f"{self.ENDPOINT}generate_auth_token/")
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
token1 = Token.objects.filter(user=self.user).first()
|
||||
self.assertIsNotNone(token1)
|
||||
|
||||
response = self.client.post(f"{self.ENDPOINT}generate_auth_token/")
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
token2 = Token.objects.filter(user=self.user).first()
|
||||
|
||||
self.assertNotEqual(token1.key, token2.key)
|
||||
|
||||
def test_profile_not_logged_in(self):
|
||||
"""
|
||||
GIVEN:
|
||||
- User not logged in
|
||||
WHEN:
|
||||
- API call is made to get profile and update token
|
||||
THEN:
|
||||
- Profile is returned
|
||||
"""
|
||||
|
||||
self.client.logout()
|
||||
|
||||
response = self.client.get(self.ENDPOINT)
|
||||
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
|
||||
|
||||
response = self.client.post(f"{self.ENDPOINT}generate_auth_token/")
|
||||
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
|
@@ -97,3 +97,19 @@ class GroupSerializer(serializers.ModelSerializer):
|
||||
"name",
|
||||
"permissions",
|
||||
)
|
||||
|
||||
|
||||
class ProfileSerializer(serializers.ModelSerializer):
|
||||
email = serializers.EmailField(allow_null=False)
|
||||
password = ObfuscatedUserPasswordField(required=False, allow_null=False)
|
||||
auth_token = serializers.SlugRelatedField(read_only=True, slug_field="key")
|
||||
|
||||
class Meta:
|
||||
model = User
|
||||
fields = (
|
||||
"email",
|
||||
"password",
|
||||
"first_name",
|
||||
"last_name",
|
||||
"auth_token",
|
||||
)
|
||||
|
@@ -35,7 +35,9 @@ from documents.views import UiSettingsView
|
||||
from documents.views import UnifiedSearchViewSet
|
||||
from paperless.consumers import StatusConsumer
|
||||
from paperless.views import FaviconView
|
||||
from paperless.views import GenerateAuthTokenView
|
||||
from paperless.views import GroupViewSet
|
||||
from paperless.views import ProfileView
|
||||
from paperless.views import UserViewSet
|
||||
from paperless_mail.views import MailAccountTestView
|
||||
from paperless_mail.views import MailAccountViewSet
|
||||
@@ -119,6 +121,12 @@ urlpatterns = [
|
||||
BulkEditObjectPermissionsView.as_view(),
|
||||
name="bulk_edit_object_permissions",
|
||||
),
|
||||
path("profile/generate_auth_token/", GenerateAuthTokenView.as_view()),
|
||||
re_path(
|
||||
"^profile/",
|
||||
ProfileView.as_view(),
|
||||
name="profile_view",
|
||||
),
|
||||
*api_router.urls,
|
||||
],
|
||||
),
|
||||
|
@@ -7,7 +7,9 @@ from django.db.models.functions import Lower
|
||||
from django.http import HttpResponse
|
||||
from django.views.generic import View
|
||||
from django_filters.rest_framework import DjangoFilterBackend
|
||||
from rest_framework.authtoken.models import Token
|
||||
from rest_framework.filters import OrderingFilter
|
||||
from rest_framework.generics import GenericAPIView
|
||||
from rest_framework.pagination import PageNumberPagination
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
from rest_framework.response import Response
|
||||
@@ -17,6 +19,7 @@ from documents.permissions import PaperlessObjectPermissions
|
||||
from paperless.filters import GroupFilterSet
|
||||
from paperless.filters import UserFilterSet
|
||||
from paperless.serialisers import GroupSerializer
|
||||
from paperless.serialisers import ProfileSerializer
|
||||
from paperless.serialisers import UserSerializer
|
||||
|
||||
|
||||
@@ -106,3 +109,54 @@ class GroupViewSet(ModelViewSet):
|
||||
filter_backends = (DjangoFilterBackend, OrderingFilter)
|
||||
filterset_class = GroupFilterSet
|
||||
ordering_fields = ("name",)
|
||||
|
||||
|
||||
class ProfileView(GenericAPIView):
|
||||
"""
|
||||
User profile view, only available when logged in
|
||||
"""
|
||||
|
||||
permission_classes = [IsAuthenticated]
|
||||
serializer_class = ProfileSerializer
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
user = self.request.user
|
||||
|
||||
serializer = self.get_serializer(data=request.data)
|
||||
return Response(serializer.to_representation(user))
|
||||
|
||||
def patch(self, request, *args, **kwargs):
|
||||
serializer = self.get_serializer(data=request.data)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
user = self.request.user if hasattr(self.request, "user") else None
|
||||
|
||||
if len(serializer.validated_data.get("password").replace("*", "")) > 0:
|
||||
user.set_password(serializer.validated_data.get("password"))
|
||||
user.save()
|
||||
serializer.validated_data.pop("password")
|
||||
|
||||
for key, value in serializer.validated_data.items():
|
||||
setattr(user, key, value)
|
||||
user.save()
|
||||
|
||||
return Response(serializer.to_representation(user))
|
||||
|
||||
|
||||
class GenerateAuthTokenView(GenericAPIView):
|
||||
"""
|
||||
Generates (or re-generates) an auth token, requires a logged in user
|
||||
unlike the default DRF endpoint
|
||||
"""
|
||||
|
||||
permission_classes = [IsAuthenticated]
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
user = self.request.user
|
||||
|
||||
existing_token = Token.objects.filter(user=user).first()
|
||||
if existing_token is not None:
|
||||
existing_token.delete()
|
||||
token = Token.objects.create(user=user)
|
||||
return Response(
|
||||
token.key,
|
||||
)
|
||||
|
Reference in New Issue
Block a user