diff --git a/docs/api.md b/docs/api.md index 39a06f37f..97ccf4c3a 100644 --- a/docs/api.md +++ b/docs/api.md @@ -274,6 +274,7 @@ The endpoint supports the following optional form fields: - `correspondent`: Specify the ID of a correspondent that the consumer should use for the document. - `document_type`: Similar to correspondent. +- `storage_path`: Similar to correspondent. - `tags`: Similar to correspondent. Specify this multiple times to have multiple tags added to the document. - `archive_serial_number`: An optional archive serial number to set. diff --git a/src/documents/serialisers.py b/src/documents/serialisers.py index b1dd9aee9..41d3139b3 100644 --- a/src/documents/serialisers.py +++ b/src/documents/serialisers.py @@ -966,6 +966,14 @@ class PostDocumentSerializer(serializers.Serializer): required=False, ) + storage_path = serializers.PrimaryKeyRelatedField( + queryset=StoragePath.objects.all(), + label="Storage path", + allow_null=True, + write_only=True, + required=False, + ) + tags = serializers.PrimaryKeyRelatedField( many=True, queryset=Tag.objects.all(), @@ -1005,6 +1013,12 @@ class PostDocumentSerializer(serializers.Serializer): else: return None + def validate_storage_path(self, storage_path): + if storage_path: + return storage_path.id + else: + return None + def validate_tags(self, tags): if tags: return [tag.id for tag in tags] diff --git a/src/documents/tests/test_api_documents.py b/src/documents/tests/test_api_documents.py index 8415b9a71..f19711a96 100644 --- a/src/documents/tests/test_api_documents.py +++ b/src/documents/tests/test_api_documents.py @@ -819,7 +819,13 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase): ) as f: response = self.client.post( "/api/documents/post_document/", - {"document": f, "title": "", "correspondent": "", "document_type": ""}, + { + "document": f, + "title": "", + "correspondent": "", + "document_type": "", + "storage_path": "", + }, ) self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -833,6 +839,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase): self.assertIsNone(overrides.title) self.assertIsNone(overrides.correspondent_id) self.assertIsNone(overrides.document_type_id) + self.assertIsNone(overrides.storage_path_id) self.assertIsNone(overrides.tag_ids) def test_upload_invalid_form(self): @@ -975,6 +982,48 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase): self.consume_file_mock.assert_not_called() + def test_upload_with_storage_path(self): + self.consume_file_mock.return_value = celery.result.AsyncResult( + id=str(uuid.uuid4()), + ) + + sp = StoragePath.objects.create(name="invoices") + with open( + os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"), + "rb", + ) as f: + response = self.client.post( + "/api/documents/post_document/", + {"document": f, "storage_path": sp.id}, + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + + self.consume_file_mock.assert_called_once() + + _, overrides = self.get_last_consume_delay_call_args() + + self.assertEqual(overrides.storage_path_id, sp.id) + self.assertIsNone(overrides.correspondent_id) + self.assertIsNone(overrides.title) + self.assertIsNone(overrides.tag_ids) + + def test_upload_with_invalid_storage_path(self): + self.consume_file_mock.return_value = celery.result.AsyncResult( + id=str(uuid.uuid4()), + ) + + with open( + os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"), + "rb", + ) as f: + response = self.client.post( + "/api/documents/post_document/", + {"document": f, "storage_path": 34578}, + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + self.consume_file_mock.assert_not_called() + def test_upload_with_tags(self): self.consume_file_mock.return_value = celery.result.AsyncResult( id=str(uuid.uuid4()), diff --git a/src/documents/views.py b/src/documents/views.py index 84633cc03..83f2fc321 100644 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -850,6 +850,7 @@ class PostDocumentView(GenericAPIView): doc_name, doc_data = serializer.validated_data.get("document") correspondent_id = serializer.validated_data.get("correspondent") document_type_id = serializer.validated_data.get("document_type") + storage_path_id = serializer.validated_data.get("storage_path") tag_ids = serializer.validated_data.get("tags") title = serializer.validated_data.get("title") created = serializer.validated_data.get("created") @@ -876,6 +877,7 @@ class PostDocumentView(GenericAPIView): title=title, correspondent_id=correspondent_id, document_type_id=document_type_id, + storage_path_id=storage_path_id, tag_ids=tag_ids, created=created, asn=archive_serial_number,