From 9f19ac12dddf099a039213b36ad3d3956d5c468e Mon Sep 17 00:00:00 2001 From: Mathieu Braem Date: Thu, 13 Feb 2025 20:34:16 +0100 Subject: [PATCH] Add an optional flag to the PostDocument endpoint to skip auto tagging --- docs/api.md | 2 ++ src/documents/consumer.py | 1 + src/documents/data_models.py | 3 +++ src/documents/serialisers.py | 7 +++++++ src/documents/signals/handlers.py | 4 ++++ src/documents/tests/test_consumer.py | 31 ++++++++++++++++++++++++++++ src/documents/views.py | 2 ++ 7 files changed, 50 insertions(+) diff --git a/docs/api.md b/docs/api.md index 9c28476c4..0bb661e92 100644 --- a/docs/api.md +++ b/docs/api.md @@ -192,6 +192,8 @@ The endpoint supports the following optional form fields: - `tags`: Similar to correspondent. Specify this multiple times to have multiple tags added to the document. - `archive_serial_number`: An optional archive serial number to set. +- `skip_auto_tagging`: Boolean to indicate that the classifier should not + attempt to determine and add tags to the document. - `custom_fields`: An array of custom field ids to assign (with an empty value) to the document. 
diff --git a/src/documents/consumer.py b/src/documents/consumer.py index 81739fa7a..ec62c176c 100644 --- a/src/documents/consumer.py +++ b/src/documents/consumer.py @@ -577,6 +577,7 @@ class ConsumerPlugin( original_file=self.unmodified_original if self.unmodified_original else self.working_copy, + skip_auto_tagging=self.metadata.skip_auto_tagging, ) # After everything is in the database, copy the files into diff --git a/src/documents/data_models.py b/src/documents/data_models.py index 231e59005..5a6741e1d 100644 --- a/src/documents/data_models.py +++ b/src/documents/data_models.py @@ -30,6 +30,7 @@ class DocumentMetadataOverrides: change_users: list[int] | None = None change_groups: list[int] | None = None custom_field_ids: list[int] | None = None + skip_auto_tagging: bool | None = None def update(self, other: "DocumentMetadataOverrides") -> "DocumentMetadataOverrides": """ @@ -49,6 +50,8 @@ class DocumentMetadataOverrides: self.storage_path_id = other.storage_path_id if other.owner_id is not None: self.owner_id = other.owner_id + if other.skip_auto_tagging is not None: + self.skip_auto_tagging = other.skip_auto_tagging # merge if self.tag_ids is None: diff --git a/src/documents/serialisers.py b/src/documents/serialisers.py index 6a0a1eec1..e863dcde9 100644 --- a/src/documents/serialisers.py +++ b/src/documents/serialisers.py @@ -1536,6 +1536,13 @@ class PostDocumentSerializer(serializers.Serializer): required=False, ) + skip_auto_tagging = serializers.BooleanField( + label="Skip auto tagging", + default=False, + write_only=True, + required=False, + ) + def validate_document(self, document): document_data = document.file.read() mime_type = magic.from_buffer(document_data, mime=True) diff --git a/src/documents/signals/handlers.py b/src/documents/signals/handlers.py index 5da9ef879..d055e0623 100644 --- a/src/documents/signals/handlers.py +++ b/src/documents/signals/handlers.py @@ -206,8 +206,12 @@ def set_tags( base_url=None, stdout=None, style_func=None, + 
skip_auto_tagging=False, **kwargs, ): + if skip_auto_tagging: + return + if replace: Document.tags.through.objects.filter(document=document).exclude( Q(tag__is_inbox_tag=True), diff --git a/src/documents/tests/test_consumer.py b/src/documents/tests/test_consumer.py index 6f576ab24..1ba3cec20 100644 --- a/src/documents/tests/test_consumer.py +++ b/src/documents/tests/test_consumer.py @@ -854,6 +854,37 @@ class TestConsumer( self._assert_first_last_send_progress() + @mock.patch("documents.consumer.load_classifier") + def testClassifyDocumentWithSkippedTags(self, m): + correspondent = Correspondent.objects.create( + name="test", + matching_algorithm=Correspondent.MATCH_AUTO, + ) + dtype = DocumentType.objects.create( + name="test", + matching_algorithm=DocumentType.MATCH_AUTO, + ) + t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO) + t2 = Tag.objects.create(name="t2", matching_algorithm=Tag.MATCH_AUTO) + + m.return_value = MagicMock() + m.return_value.predict_correspondent.return_value = correspondent.pk + m.return_value.predict_document_type.return_value = dtype.pk + m.return_value.predict_tags.return_value = [t2.pk] + + overrides = DocumentMetadataOverrides(tag_ids=[t1.pk], skip_auto_tagging=True) + with self.get_consumer(self.get_test_file(), overrides) as consumer: + consumer.run() + + document = Document.objects.first() + + self.assertEqual(document.correspondent, correspondent) + self.assertEqual(document.document_type, dtype) + self.assertIn(t1, document.tags.all()) + self.assertNotIn(t2, document.tags.all()) + + self._assert_first_last_send_progress() + @override_settings(CONSUMER_DELETE_DUPLICATES=True) def test_delete_duplicate(self): dst = self.get_test_file() diff --git a/src/documents/views.py b/src/documents/views.py index a856883f3..4c91bac50 100644 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -1385,6 +1385,7 @@ class PostDocumentView(GenericAPIView): created = serializer.validated_data.get("created") 
archive_serial_number = serializer.validated_data.get("archive_serial_number") custom_field_ids = serializer.validated_data.get("custom_fields") + skip_auto_tagging = serializer.validated_data.get("skip_auto_tagging") t = int(mktime(datetime.now().timetuple())) @@ -1413,6 +1414,7 @@ class PostDocumentView(GenericAPIView): asn=archive_serial_number, owner_id=request.user.id, custom_field_ids=custom_field_ids, + skip_auto_tagging=skip_auto_tagging, ) async_task = consume_file.delay(