fix python tests for user object perms

This commit is contained in:
Michael Shamoon 2022-12-06 20:14:33 -08:00
parent 18e0012a59
commit 2973e4672a
6 changed files with 56 additions and 39 deletions

View File

@ -19,10 +19,13 @@ class PaperlessObjectPermissions(DjangoObjectPermissions):
} }
def has_object_permission(self, request, view, obj): def has_object_permission(self, request, view, obj):
if hasattr(obj, "owner") and request.user == obj.owner: if hasattr(obj, "owner") and obj.owner is not None:
return True if request.user == obj.owner:
return True
else:
return super().has_object_permission(request, view, obj)
else: else:
return super().has_object_permission(request, view, obj) return True # no owner
class PaperlessAdminPermissions(BasePermission): class PaperlessAdminPermissions(BasePermission):

View File

@ -88,8 +88,8 @@ class OwnedObjectSerializer(serializers.ModelSerializer):
user_object_perms = UserObjectPermission.objects.filter( user_object_perms = UserObjectPermission.objects.filter(
object_pk=obj.pk, object_pk=obj.pk,
content_type=content_type, content_type=content_type,
).values("user", "permission__codename") ).values_list("user", "permission__codename")
return user_object_perms return list(user_object_perms)
permissions = SerializerMethodField() permissions = SerializerMethodField()
@ -164,7 +164,9 @@ class OwnedObjectSerializer(serializers.ModelSerializer):
) )
def create(self, validated_data): def create(self, validated_data):
if self.user and validated_data["owner"] is None: if self.user and (
"owner" not in validated_data or validated_data["owner"] is None
):
validated_data["owner"] = self.user validated_data["owner"] = self.user
instance = super().create(validated_data) instance = super().create(validated_data)
if "grant_permissions" in validated_data: if "grant_permissions" in validated_data:
@ -306,10 +308,6 @@ class TagSerializerVersion1(MatchingModelSerializer):
"is_insensitive", "is_insensitive",
"is_inbox_tag", "is_inbox_tag",
"document_count", "document_count",
"owner",
"permissions",
"grant_permissions",
"revoke_permissions",
) )
@ -342,6 +340,10 @@ class TagSerializer(MatchingModelSerializer, OwnedObjectSerializer):
"is_insensitive", "is_insensitive",
"is_inbox_tag", "is_inbox_tag",
"document_count", "document_count",
"owner",
"permissions",
"grant_permissions",
"revoke_permissions",
) )
def validate_color(self, color): def validate_color(self, color):
@ -461,6 +463,9 @@ class SavedViewSerializer(OwnedObjectSerializer):
rules_data = validated_data.pop("filter_rules") rules_data = validated_data.pop("filter_rules")
else: else:
rules_data = None rules_data = None
if "user" in validated_data:
# backwards compatibility
validated_data["owner"] = validated_data.pop("user")
super().update(instance, validated_data) super().update(instance, validated_data)
if rules_data is not None: if rules_data is not None:
SavedViewFilterRule.objects.filter(saved_view=instance).delete() SavedViewFilterRule.objects.filter(saved_view=instance).delete()
@ -470,6 +475,9 @@ class SavedViewSerializer(OwnedObjectSerializer):
def create(self, validated_data): def create(self, validated_data):
rules_data = validated_data.pop("filter_rules") rules_data = validated_data.pop("filter_rules")
if "user" in validated_data:
# backwards compatibility
validated_data["owner"] = validated_data.pop("user")
saved_view = SavedView.objects.create(**validated_data) saved_view = SavedView.objects.create(**validated_data)
for rule_data in rules_data: for rule_data in rules_data:
SavedViewFilterRule.objects.create(saved_view=saved_view, **rule_data) SavedViewFilterRule.objects.create(saved_view=saved_view, **rule_data)

View File

@ -1158,21 +1158,21 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
u2 = User.objects.create_superuser("user2") u2 = User.objects.create_superuser("user2")
v1 = SavedView.objects.create( v1 = SavedView.objects.create(
user=u1, owner=u1,
name="test1", name="test1",
sort_field="", sort_field="",
show_on_dashboard=False, show_on_dashboard=False,
show_in_sidebar=False, show_in_sidebar=False,
) )
v2 = SavedView.objects.create( v2 = SavedView.objects.create(
user=u2, owner=u2,
name="test2", name="test2",
sort_field="", sort_field="",
show_on_dashboard=False, show_on_dashboard=False,
show_in_sidebar=False, show_in_sidebar=False,
) )
v3 = SavedView.objects.create( v3 = SavedView.objects.create(
user=u2, owner=u2,
name="test3", name="test3",
sort_field="", sort_field="",
show_on_dashboard=False, show_on_dashboard=False,
@ -1219,7 +1219,7 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
v1 = SavedView.objects.get(name="test") v1 = SavedView.objects.get(name="test")
self.assertEqual(v1.sort_field, "created2") self.assertEqual(v1.sort_field, "created2")
self.assertEqual(v1.filter_rules.count(), 1) self.assertEqual(v1.filter_rules.count(), 1)
self.assertEqual(v1.user, self.user) self.assertEqual(v1.owner, self.user)
response = self.client.patch( response = self.client.patch(
f"/api/saved_views/{v1.id}/", f"/api/saved_views/{v1.id}/",
@ -3015,17 +3015,13 @@ class TestApiUser(APITestCase):
response = self.client.get(self.ENDPOINT) response = self.client.get(self.ENDPOINT)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertEqual(response.data["count"], 2) self.assertEqual(response.data["count"], 3) # AnonymousUser
returned_user1 = response.data["results"][1] returned_user2 = response.data["results"][2]
from pprint import pprint self.assertEqual(returned_user2["username"], user1.username)
self.assertEqual(returned_user2["password"], "**********")
pprint(returned_user1) self.assertEqual(returned_user2["first_name"], user1.first_name)
self.assertEqual(returned_user2["last_name"], user1.last_name)
self.assertEqual(returned_user1["username"], user1.username)
self.assertEqual(returned_user1["password"], "**********")
self.assertEqual(returned_user1["first_name"], user1.first_name)
self.assertEqual(returned_user1["last_name"], user1.last_name)
def test_create_user(self): def test_create_user(self):
""" """

View File

@ -124,7 +124,7 @@ class TestExportImport(DirectoriesMixin, TestCase):
manifest = self._do_export(use_filename_format=use_filename_format) manifest = self._do_export(use_filename_format=use_filename_format)
self.assertEqual(len(manifest), 11) self.assertEqual(len(manifest), 12)
self.assertEqual( self.assertEqual(
len(list(filter(lambda e: e["model"] == "documents.document", manifest))), len(list(filter(lambda e: e["model"] == "documents.document", manifest))),
4, 4,

View File

@ -31,8 +31,8 @@ class TestManageSuperUser(DirectoriesMixin, TestCase):
out = self.call_command(environ={}) out = self.call_command(environ={})
# just the consumer user which is created # just the consumer user which is created
# during migration # during migration, and AnonymousUser
self.assertEqual(User.objects.count(), 1) self.assertEqual(User.objects.count(), 2)
self.assertTrue(User.objects.filter(username="consumer").exists()) self.assertTrue(User.objects.filter(username="consumer").exists())
self.assertEqual(User.objects.filter(is_superuser=True).count(), 0) self.assertEqual(User.objects.filter(is_superuser=True).count(), 0)
self.assertEqual( self.assertEqual(
@ -50,10 +50,10 @@ class TestManageSuperUser(DirectoriesMixin, TestCase):
out = self.call_command(environ={"PAPERLESS_ADMIN_PASSWORD": "123456"}) out = self.call_command(environ={"PAPERLESS_ADMIN_PASSWORD": "123456"})
# count is 2 as there's the consumer # count is 3 as there's the consumer
# user already created during migration # user already created during migration, and AnonymousUser
user: User = User.objects.get_by_natural_key("admin") user: User = User.objects.get_by_natural_key("admin")
self.assertEqual(User.objects.count(), 2) self.assertEqual(User.objects.count(), 3)
self.assertTrue(user.is_superuser) self.assertTrue(user.is_superuser)
self.assertEqual(user.email, "root@localhost") self.assertEqual(user.email, "root@localhost")
self.assertEqual(out, 'Created superuser "admin" with provided password.\n') self.assertEqual(out, 'Created superuser "admin" with provided password.\n')
@ -70,7 +70,7 @@ class TestManageSuperUser(DirectoriesMixin, TestCase):
out = self.call_command(environ={"PAPERLESS_ADMIN_PASSWORD": "123456"}) out = self.call_command(environ={"PAPERLESS_ADMIN_PASSWORD": "123456"})
self.assertEqual(User.objects.count(), 2) self.assertEqual(User.objects.count(), 3)
with self.assertRaises(User.DoesNotExist): with self.assertRaises(User.DoesNotExist):
User.objects.get_by_natural_key("admin") User.objects.get_by_natural_key("admin")
self.assertEqual( self.assertEqual(
@ -91,7 +91,7 @@ class TestManageSuperUser(DirectoriesMixin, TestCase):
out = self.call_command(environ={"PAPERLESS_ADMIN_PASSWORD": "123456"}) out = self.call_command(environ={"PAPERLESS_ADMIN_PASSWORD": "123456"})
self.assertEqual(User.objects.count(), 2) self.assertEqual(User.objects.count(), 3)
user: User = User.objects.get_by_natural_key("admin") user: User = User.objects.get_by_natural_key("admin")
self.assertTrue(user.check_password("password")) self.assertTrue(user.check_password("password"))
self.assertEqual(out, "Did not create superuser, a user admin already exists\n") self.assertEqual(out, "Did not create superuser, a user admin already exists\n")
@ -110,7 +110,7 @@ class TestManageSuperUser(DirectoriesMixin, TestCase):
out = self.call_command(environ={"PAPERLESS_ADMIN_PASSWORD": "123456"}) out = self.call_command(environ={"PAPERLESS_ADMIN_PASSWORD": "123456"})
self.assertEqual(User.objects.count(), 2) self.assertEqual(User.objects.count(), 3)
user: User = User.objects.get_by_natural_key("admin") user: User = User.objects.get_by_natural_key("admin")
self.assertTrue(user.check_password("password")) self.assertTrue(user.check_password("password"))
self.assertFalse(user.is_superuser) self.assertFalse(user.is_superuser)
@ -149,7 +149,7 @@ class TestManageSuperUser(DirectoriesMixin, TestCase):
) )
user: User = User.objects.get_by_natural_key("admin") user: User = User.objects.get_by_natural_key("admin")
self.assertEqual(User.objects.count(), 2) self.assertEqual(User.objects.count(), 3)
self.assertTrue(user.is_superuser) self.assertTrue(user.is_superuser)
self.assertEqual(user.email, "hello@world.com") self.assertEqual(user.email, "hello@world.com")
self.assertEqual(user.username, "admin") self.assertEqual(user.username, "admin")
@ -173,7 +173,7 @@ class TestManageSuperUser(DirectoriesMixin, TestCase):
) )
user: User = User.objects.get_by_natural_key("super") user: User = User.objects.get_by_natural_key("super")
self.assertEqual(User.objects.count(), 2) self.assertEqual(User.objects.count(), 3)
self.assertTrue(user.is_superuser) self.assertTrue(user.is_superuser)
self.assertEqual(user.email, "hello@world.com") self.assertEqual(user.email, "hello@world.com")
self.assertEqual(user.username, "super") self.assertEqual(user.username, "super")

View File

@ -174,14 +174,16 @@ class CorrespondentViewSet(ModelViewSet, PassUserMixin):
) )
class TagViewSet(ModelViewSet, PassUserMixin): class TagViewSet(ModelViewSet):
model = Tag model = Tag
queryset = Tag.objects.annotate(document_count=Count("documents")).order_by( queryset = Tag.objects.annotate(document_count=Count("documents")).order_by(
Lower("name"), Lower("name"),
) )
def get_serializer_class(self): def get_serializer_class(self, *args, **kwargs):
# from UserPassMixin
kwargs.setdefault("user", self.request.user)
if int(self.request.version) == 1: if int(self.request.version) == 1:
return TagSerializerVersion1 return TagSerializerVersion1
else: else:
@ -189,7 +191,11 @@ class TagViewSet(ModelViewSet, PassUserMixin):
pagination_class = StandardPagination pagination_class = StandardPagination
permission_classes = (IsAuthenticated, PaperlessObjectPermissions) permission_classes = (IsAuthenticated, PaperlessObjectPermissions)
filter_backends = (DjangoFilterBackend, OrderingFilter) filter_backends = (
DjangoFilterBackend,
OrderingFilter,
ObjectOwnedOrGrandtedPermissionsFilter,
)
filterset_class = TagFilterSet filterset_class = TagFilterSet
ordering_fields = ("name", "matching_algorithm", "match", "document_count") ordering_fields = ("name", "matching_algorithm", "match", "document_count")
@ -204,7 +210,11 @@ class DocumentTypeViewSet(ModelViewSet, PassUserMixin):
serializer_class = DocumentTypeSerializer serializer_class = DocumentTypeSerializer
pagination_class = StandardPagination pagination_class = StandardPagination
permission_classes = (IsAuthenticated, PaperlessObjectPermissions) permission_classes = (IsAuthenticated, PaperlessObjectPermissions)
filter_backends = (DjangoFilterBackend, OrderingFilter) filter_backends = (
DjangoFilterBackend,
OrderingFilter,
ObjectOwnedOrGrandtedPermissionsFilter,
)
filterset_class = DocumentTypeFilterSet filterset_class = DocumentTypeFilterSet
ordering_fields = ("name", "matching_algorithm", "match", "document_count") ordering_fields = ("name", "matching_algorithm", "match", "document_count")