From 446842ecfc1a73e573ad3ae6e2fb573273ca1d2d Mon Sep 17 00:00:00 2001 From: Michael Shamoon <4887959+shamoon@users.noreply.github.com> Date: Mon, 12 Dec 2022 13:24:59 -0800 Subject: [PATCH] Document uploads should be owned by user --- .../src/app/services/upload-documents.service.ts | 5 ++++- src/documents/consumer.py | 9 +++++++++ src/documents/serialisers.py | 14 ++++++++++++++ src/documents/tasks.py | 2 ++ src/documents/views.py | 2 ++ 5 files changed, 31 insertions(+), 1 deletion(-) diff --git a/src-ui/src/app/services/upload-documents.service.ts b/src-ui/src/app/services/upload-documents.service.ts index 5e7ef7fbe..3ec93ecf6 100644 --- a/src-ui/src/app/services/upload-documents.service.ts +++ b/src-ui/src/app/services/upload-documents.service.ts @@ -7,6 +7,7 @@ import { } from './consumer-status.service' import { DocumentService } from './rest/document.service' import { Subscription } from 'rxjs' +import { SettingsService } from './settings.service' @Injectable({ providedIn: 'root', @@ -16,7 +17,8 @@ export class UploadDocumentsService { constructor( private documentService: DocumentService, - private consumerStatusService: ConsumerStatusService + private consumerStatusService: ConsumerStatusService, + private settings: SettingsService ) {} uploadFiles(files: NgxFileDropEntry[]) { @@ -26,6 +28,7 @@ export class UploadDocumentsService { fileEntry.file((file: File) => { let formData = new FormData() formData.append('document', file, file.name) + formData.set('owner', this.settings.currentUser.id.toString()) let status = this.consumerStatusService.newFileUpload(file.name) status.message = $localize`Connecting...` diff --git a/src/documents/consumer.py b/src/documents/consumer.py index b46b3a683..ad5a8416a 100644 --- a/src/documents/consumer.py +++ b/src/documents/consumer.py @@ -11,6 +11,7 @@ import magic from asgiref.sync import async_to_sync from channels.layers import get_channel_layer from django.conf import settings +from django.contrib.auth.models import User from django.db import transaction from django.db.models import Q from django.utils import timezone @@ -99,6 +100,7 @@ class Consumer(LoggingMixin): self.override_tag_ids = None self.override_document_type_id = None self.task_id = None + self.owner_id = None self.channel_layer = get_channel_layer() @@ -255,6 +257,7 @@ class Consumer(LoggingMixin): override_tag_ids=None, task_id=None, override_created=None, + override_owner_id=None, ) -> Document: """ Return the document object if it was successfully created. @@ -268,6 +271,7 @@ class Consumer(LoggingMixin): self.override_tag_ids = override_tag_ids self.task_id = task_id or str(uuid.uuid4()) self.override_created = override_created + self.override_owner_id = override_owner_id self._send_progress(0, 100, "STARTING", MESSAGE_NEW_FILE) @@ -526,6 +530,11 @@ class Consumer(LoggingMixin): for tag_id in self.override_tag_ids: document.tags.add(Tag.objects.get(pk=tag_id)) + if self.override_owner_id: + document.owner = User.objects.get( + pk=self.override_owner_id, + ) + def _write(self, storage_type, source, target): with open(source, "rb") as read_file: with open(target, "wb") as write_file: diff --git a/src/documents/serialisers.py b/src/documents/serialisers.py index 6470beedc..f3ef49f86 100644 --- a/src/documents/serialisers.py +++ b/src/documents/serialisers.py @@ -681,6 +681,14 @@ class PostDocumentSerializer(serializers.Serializer): required=False, ) + owner = serializers.PrimaryKeyRelatedField( + queryset=User.objects.all(), + label="Owner", + allow_null=True, + write_only=True, + required=False, + ) + def validate_document(self, document): document_data = document.file.read() mime_type = magic.from_buffer(document_data, mime=True) @@ -710,6 +718,12 @@ class PostDocumentSerializer(serializers.Serializer): else: return None + def validate_owner(self, owner): + if owner: + return owner.id + else: + return None + class BulkDownloadSerializer(DocumentListSerializer): diff --git a/src/documents/tasks.py b/src/documents/tasks.py index f40513f9c..e1026dea9 100644 --- a/src/documents/tasks.py +++ b/src/documents/tasks.py @@ -96,6 +96,7 @@ def consume_file( override_tag_ids=None, task_id=None, override_created=None, + override_owner_id=None, ): path = Path(path).resolve() @@ -184,6 +185,7 @@ def consume_file( override_tag_ids=override_tag_ids, task_id=task_id, override_created=override_created, + override_owner_id=override_owner_id, ) if document: diff --git a/src/documents/views.py b/src/documents/views.py index 3ea6885ac..86d607574 100644 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -637,6 +637,7 @@ class PostDocumentView(GenericAPIView): tag_ids = serializer.validated_data.get("tags") title = serializer.validated_data.get("title") created = serializer.validated_data.get("created") + owner_id = serializer.validated_data.get("owner") t = int(mktime(datetime.now().timetuple())) @@ -662,6 +663,7 @@ class PostDocumentView(GenericAPIView): override_tag_ids=tag_ids, task_id=task_id, override_created=created, + override_owner_id=owner_id, ) return Response("OK")