diff --git a/src/paperless_mail/serialisers.py b/src/paperless_mail/serialisers.py index 0d15f617c..49d61838a 100644 --- a/src/paperless_mail/serialisers.py +++ b/src/paperless_mail/serialisers.py @@ -1,5 +1,6 @@ from documents.serialisers import CorrespondentField from documents.serialisers import DocumentTypeField +from documents.serialisers import OwnedObjectSerializer from documents.serialisers import TagsField from paperless_mail.models import MailAccount from paperless_mail.models import MailRule @@ -18,7 +19,7 @@ class ObfuscatedPasswordField(serializers.Field): return data -class MailAccountSerializer(serializers.ModelSerializer): +class MailAccountSerializer(OwnedObjectSerializer): password = ObfuscatedPasswordField() class Meta: @@ -42,17 +43,13 @@ class MailAccountSerializer(serializers.ModelSerializer): super().update(instance, validated_data) return instance - def create(self, validated_data): - mail_account = MailAccount.objects.create(**validated_data) - return mail_account - class AccountField(serializers.PrimaryKeyRelatedField): def get_queryset(self): return MailAccount.objects.all().order_by("-id") -class MailRuleSerializer(serializers.ModelSerializer): +class MailRuleSerializer(OwnedObjectSerializer): account = AccountField(required=True) action_parameter = serializers.CharField( allow_null=True, @@ -96,7 +93,7 @@ class MailRuleSerializer(serializers.ModelSerializer): def create(self, validated_data): if "assign_tags" in validated_data: assign_tags = validated_data.pop("assign_tags") - mail_rule = MailRule.objects.create(**validated_data) + mail_rule = super().create(validated_data) if assign_tags: mail_rule.assign_tags.set(assign_tags) return mail_rule diff --git a/src/paperless_mail/views.py b/src/paperless_mail/views.py index d86240c7c..edc58ac89 100644 --- a/src/paperless_mail/views.py +++ b/src/paperless_mail/views.py @@ -1,3 +1,4 @@ +from documents.views import PassUserMixin from paperless.views import StandardPagination from paperless_mail.models import MailAccount from paperless_mail.models import MailRule @@ -7,7 +8,7 @@ from rest_framework.permissions import IsAuthenticated from rest_framework.viewsets import ModelViewSet -class MailAccountViewSet(ModelViewSet): +class MailAccountViewSet(ModelViewSet, PassUserMixin): model = MailAccount queryset = MailAccount.objects.all().order_by("pk") @@ -15,27 +16,11 @@ class MailAccountViewSet(ModelViewSet): pagination_class = StandardPagination permission_classes = (IsAuthenticated,) - # TODO: user-scoped - # def get_queryset(self): - # user = self.request.user - # return MailAccount.objects.filter(user=user) - # def perform_create(self, serializer): - # serializer.save(user=self.request.user) - - -class MailRuleViewSet(ModelViewSet): +class MailRuleViewSet(ModelViewSet, PassUserMixin): model = MailRule queryset = MailRule.objects.all().order_by("pk") serializer_class = MailRuleSerializer pagination_class = StandardPagination permission_classes = (IsAuthenticated,) - - # TODO: user-scoped - # def get_queryset(self): - # user = self.request.user - # return MailRule.objects.filter(user=user) - - # def perform_create(self, serializer): - # serializer.save(user=self.request.user)