diff --git a/src/documents/serialisers.py b/src/documents/serialisers.py index db282cacd..553669e32 100644 --- a/src/documents/serialisers.py +++ b/src/documents/serialisers.py @@ -74,7 +74,18 @@ class MatchingModelSerializer(serializers.ModelSerializer): return match -class CorrespondentSerializer(MatchingModelSerializer): +class OwnedObjectSerializer(serializers.ModelSerializer): + def __init__(self, *args, **kwargs): + self.user = kwargs.pop("user", None) + return super().__init__(*args, **kwargs) + + def create(self, validated_data): + if self.user and validated_data["owner"] is None: + validated_data["owner"] = self.user + return super().create(validated_data) + + +class CorrespondentSerializer(MatchingModelSerializer, OwnedObjectSerializer): last_correspondence = serializers.DateTimeField(read_only=True) @@ -89,10 +100,11 @@ class CorrespondentSerializer(MatchingModelSerializer): "is_insensitive", "document_count", "last_correspondence", + "owner", ) -class DocumentTypeSerializer(MatchingModelSerializer): +class DocumentTypeSerializer(MatchingModelSerializer, OwnedObjectSerializer): class Meta: model = DocumentType fields = ( @@ -103,6 +115,7 @@ class DocumentTypeSerializer(MatchingModelSerializer): "matching_algorithm", "is_insensitive", "document_count", + "owner", ) @@ -153,10 +166,11 @@ class TagSerializerVersion1(MatchingModelSerializer): "is_insensitive", "is_inbox_tag", "document_count", + "owner", ) -class TagSerializer(MatchingModelSerializer): +class TagSerializer(MatchingModelSerializer, OwnedObjectSerializer): def get_text_color(self, obj): try: h = obj.color.lstrip("#") @@ -214,7 +228,7 @@ class StoragePathField(serializers.PrimaryKeyRelatedField): return StoragePath.objects.all() -class DocumentSerializer(DynamicFieldsModelSerializer): +class DocumentSerializer(DynamicFieldsModelSerializer, OwnedObjectSerializer): correspondent = CorrespondentField(allow_null=True) tags = TagsField(many=True) @@ -265,6 +279,7 @@ class DocumentSerializer(DynamicFieldsModelSerializer): "archive_serial_number", "original_file_name", "archived_file_name", + "owner", ) @@ -274,7 +289,7 @@ class SavedViewFilterRuleSerializer(serializers.ModelSerializer): fields = ["rule_type", "value"] -class SavedViewSerializer(serializers.ModelSerializer): +class SavedViewSerializer(OwnedObjectSerializer): filter_rules = SavedViewFilterRuleSerializer(many=True) @@ -289,6 +304,7 @@ class SavedViewSerializer(serializers.ModelSerializer): "sort_field", "sort_reverse", "filter_rules", + "owner", ] def update(self, instance, validated_data): @@ -562,7 +578,7 @@ class BulkDownloadSerializer(DocumentListSerializer): }[compression] -class StoragePathSerializer(MatchingModelSerializer): +class StoragePathSerializer(MatchingModelSerializer, OwnedObjectSerializer): class Meta: model = StoragePath fields = ( @@ -574,6 +590,7 @@ class StoragePathSerializer(MatchingModelSerializer): "matching_algorithm", "is_insensitive", "document_count", + "owner", ) def validate_path(self, path): diff --git a/src/documents/views.py b/src/documents/views.py index 2a8881376..8b01f0be1 100644 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -42,6 +42,7 @@ from rest_framework.exceptions import NotFound from rest_framework.filters import OrderingFilter from rest_framework.filters import SearchFilter from rest_framework.generics import GenericAPIView +from rest_framework.mixins import CreateModelMixin from rest_framework.mixins import DestroyModelMixin from rest_framework.mixins import ListModelMixin from rest_framework.mixins import RetrieveModelMixin @@ -137,7 +138,17 @@ class IndexView(TemplateView): return context -class CorrespondentViewSet(ModelViewSet): +class PassUserMixin(CreateModelMixin): + """ + Pass a user object to serializer + """ + + def get_serializer(self, *args, **kwargs): + kwargs.setdefault("user", self.request.user) + return super().get_serializer(*args, **kwargs) + + +class CorrespondentViewSet(ModelViewSet, PassUserMixin): model = Correspondent queryset = Correspondent.objects.annotate( @@ -163,7 +174,7 @@ class CorrespondentViewSet(ModelViewSet): ) -class TagViewSet(ModelViewSet): +class TagViewSet(ModelViewSet, PassUserMixin): model = Tag queryset = Tag.objects.annotate(document_count=Count("documents")).order_by( @@ -183,7 +194,7 @@ class TagViewSet(ModelViewSet): ordering_fields = ("name", "matching_algorithm", "match", "document_count") -class DocumentTypeViewSet(ModelViewSet): +class DocumentTypeViewSet(ModelViewSet, PassUserMixin): model = DocumentType queryset = DocumentType.objects.annotate( @@ -204,6 +215,7 @@ class DocumentViewSet( DestroyModelMixin, ListModelMixin, GenericViewSet, + PassUserMixin, ): model = Document queryset = Document.objects.all() @@ -551,7 +563,7 @@ class LogViewSet(ViewSet): return Response(self.log_files) -class SavedViewViewSet(ModelViewSet): +class SavedViewViewSet(ModelViewSet, PassUserMixin): model = SavedView queryset = SavedView.objects.all() @@ -824,7 +836,7 @@ class RemoteVersionView(GenericAPIView): ) -class StoragePathViewSet(ModelViewSet): +class StoragePathViewSet(ModelViewSet, PassUserMixin): model = StoragePath queryset = StoragePath.objects.annotate(document_count=Count("documents")).order_by(