diff --git a/src/documents/consumer.py b/src/documents/consumer.py index 4bf9ab89b..c05072a0b 100644 --- a/src/documents/consumer.py +++ b/src/documents/consumer.py @@ -806,13 +806,18 @@ class ConsumerPlugin( } set_permissions_for_object(permissions=permissions, object=document) - if self.metadata.custom_field_ids: - for field_id in self.metadata.custom_field_ids: + if self.metadata.custom_fields: + for field_id in self.metadata.custom_fields: field = CustomField.objects.get(pk=field_id) - CustomFieldInstance.objects.create( - field=field, - document=document, - ) # adds to document + value_field_name = CustomFieldInstance.get_value_field_name( + data_type=field.data_type, + ) + args = { + "field": field, + "document": document, + value_field_name: self.metadata.custom_fields[field_id], + } + CustomFieldInstance.objects.create(**args) # adds to document def _write(self, storage_type, source, target): with ( diff --git a/src/documents/data_models.py b/src/documents/data_models.py index 406fe6b5a..e56683515 100644 --- a/src/documents/data_models.py +++ b/src/documents/data_models.py @@ -29,7 +29,7 @@ class DocumentMetadataOverrides: view_groups: list[int] | None = None change_users: list[int] | None = None change_groups: list[int] | None = None - custom_field_ids: list[int] | None = None + custom_fields: dict | None = None def update(self, other: "DocumentMetadataOverrides") -> "DocumentMetadataOverrides": """ @@ -81,11 +81,10 @@ class DocumentMetadataOverrides: self.change_groups.extend(other.change_groups) self.change_groups = list(set(self.change_groups)) - if self.custom_field_ids is None: - self.custom_field_ids = other.custom_field_ids - elif other.custom_field_ids is not None: - self.custom_field_ids.extend(other.custom_field_ids) - self.custom_field_ids = list(set(self.custom_field_ids)) + if self.custom_fields is None: + self.custom_fields = other.custom_fields + elif other.custom_fields is not None: + self.custom_fields.update(other.custom_fields) return self @@ -114,9 +113,13 @@ class DocumentMetadataOverrides: only_with_perms_in=["change_document"], ).values_list("id", flat=True), ) - overrides.custom_field_ids = list( - doc.custom_fields.values_list("field", flat=True), - ) + overrides.custom_fields = { + custom_field.id: value + for custom_field, value in doc.custom_fields.all().values_list( + "id", + "value", + ) + } groups_with_perms = get_groups_with_perms( doc, diff --git a/src/documents/signals/handlers.py b/src/documents/signals/handlers.py index 3fe540ac6..337d6020e 100644 --- a/src/documents/signals/handlers.py +++ b/src/documents/signals/handlers.py @@ -769,26 +769,29 @@ def run_workflows( ), ) - if action.assign_custom_fields.exists(): + if action.assign_custom_fields_w_values: if not use_overrides: - for field in action.assign_custom_fields.all(): + for field_id in action.assign_custom_fields_w_values: if not CustomFieldInstance.objects.filter( - field=field, + field_id=field_id, document=document, ).exists(): # can be triggered on existing docs, so only add the field if it doesn't already exist - CustomFieldInstance.objects.create( - field=field, - document=document, + field = CustomField.objects.get(pk=field_id) + value_field_name = CustomFieldInstance.get_value_field_name( + data_type=field.data_type, ) + args = { + "field": field, + "document": document, + value_field_name: action.assign_custom_fields_w_values[ + field_id + ], + } + CustomFieldInstance.objects.create(**args) else: - overrides.custom_field_ids = list( - set( - (overrides.custom_field_ids or []) - + list( - action.assign_custom_fields.values_list("pk", flat=True), - ), - ), + overrides.custom_fields.update( + action.assign_custom_fields_w_values, ) def removal_action(): @@ -946,18 +949,18 @@ def run_workflows( if not use_overrides: CustomFieldInstance.objects.filter(document=document).delete() else: - overrides.custom_field_ids = None + overrides.custom_fields = None elif action.remove_custom_fields.exists(): if not use_overrides: CustomFieldInstance.objects.filter( field__in=action.remove_custom_fields.all(), document=document, ).delete() - elif overrides.custom_field_ids: + elif overrides.custom_fields: for field in action.remove_custom_fields.filter( - pk__in=overrides.custom_field_ids, + pk__in=overrides.custom_fields.keys(), ): - overrides.custom_field_ids.remove(field.pk) + overrides.custom_fields.pop(field.pk, None) def email_action(): if not settings.EMAIL_ENABLED: diff --git a/src/documents/tests/test_api_documents.py b/src/documents/tests/test_api_documents.py index cd923b281..624317d38 100644 --- a/src/documents/tests/test_api_documents.py +++ b/src/documents/tests/test_api_documents.py @@ -1362,7 +1362,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase): self.assertEqual(input_doc.original_file.name, "simple.pdf") self.assertEqual(overrides.filename, "simple.pdf") - self.assertEqual(overrides.custom_field_ids, [custom_field.id]) + self.assertEqual(overrides.custom_fields, {custom_field.id: None}) def test_upload_with_webui_source(self): """ diff --git a/src/documents/tests/test_consumer.py b/src/documents/tests/test_consumer.py index ff684804e..96afa61d3 100644 --- a/src/documents/tests/test_consumer.py +++ b/src/documents/tests/test_consumer.py @@ -408,7 +408,9 @@ class TestConsumer( with self.get_consumer( self.get_test_file(), - DocumentMetadataOverrides(custom_field_ids=[cf1.id, cf3.id]), + DocumentMetadataOverrides( + custom_fields={cf1.id: "value1", cf3.id: "http://example.com"}, + ), ) as consumer: consumer.run() @@ -420,6 +422,11 @@ class TestConsumer( self.assertIn(cf1, fields_used) self.assertNotIn(cf2, fields_used) self.assertIn(cf3, fields_used) + self.assertEqual(document.custom_fields.get(field=cf1).value, "value1") + self.assertEqual( + document.custom_fields.get(field=cf3).value, + "http://example.com", + ) self._assert_first_last_send_progress() def testOverrideAsn(self): diff --git a/src/documents/views.py b/src/documents/views.py index 46a7c0b6f..6cd8de5ec 100644 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -1471,7 +1471,7 @@ class PostDocumentView(GenericAPIView): created=created, asn=archive_serial_number, owner_id=request.user.id, - custom_field_ids=custom_field_ids, + custom_fields={cf_id: None for cf_id in custom_field_ids}, # for now ) async_task = consume_file.delay(