Add an optional flag to the PostDocument endpoint to skip auto tagging

This commit is contained in:
Mathieu Braem 2025-02-13 20:34:16 +01:00
parent a9ef7ff58e
commit 9f19ac12dd
7 changed files with 50 additions and 0 deletions

View File

@ -192,6 +192,8 @@ The endpoint supports the following optional form fields:
- `tags`: Similar to correspondent. Specify this multiple times to
  have multiple tags added to the document.
- `archive_serial_number`: An optional archive serial number to set.
- `skip_auto_tagging`: An optional boolean (default `false`). When true, the
  classifier does not attempt to determine and add tags to the document.
- `custom_fields`: An array of custom field ids to assign (with an empty
  value) to the document.

View File

@ -577,6 +577,7 @@ class ConsumerPlugin(
original_file=self.unmodified_original original_file=self.unmodified_original
if self.unmodified_original if self.unmodified_original
else self.working_copy, else self.working_copy,
skip_auto_tagging=self.metadata.skip_auto_tagging,
) )
# After everything is in the database, copy the files into # After everything is in the database, copy the files into

View File

@ -30,6 +30,7 @@ class DocumentMetadataOverrides:
change_users: list[int] | None = None change_users: list[int] | None = None
change_groups: list[int] | None = None change_groups: list[int] | None = None
custom_field_ids: list[int] | None = None custom_field_ids: list[int] | None = None
skip_auto_tagging: bool | None = None
def update(self, other: "DocumentMetadataOverrides") -> "DocumentMetadataOverrides": def update(self, other: "DocumentMetadataOverrides") -> "DocumentMetadataOverrides":
""" """
@ -49,6 +50,8 @@ class DocumentMetadataOverrides:
self.storage_path_id = other.storage_path_id self.storage_path_id = other.storage_path_id
if other.owner_id is not None: if other.owner_id is not None:
self.owner_id = other.owner_id self.owner_id = other.owner_id
if other.skip_auto_tagging is not None:
self.skip_auto_tagging = other.skip_auto_tagging
# merge # merge
if self.tag_ids is None: if self.tag_ids is None:

View File

@ -1536,6 +1536,13 @@ class PostDocumentSerializer(serializers.Serializer):
required=False, required=False,
) )
skip_auto_tagging = serializers.BooleanField(
label="Skip auto tagging",
default=False,
write_only=True,
required=False,
)
def validate_document(self, document): def validate_document(self, document):
document_data = document.file.read() document_data = document.file.read()
mime_type = magic.from_buffer(document_data, mime=True) mime_type = magic.from_buffer(document_data, mime=True)

View File

@ -206,8 +206,12 @@ def set_tags(
base_url=None, base_url=None,
stdout=None, stdout=None,
style_func=None, style_func=None,
skip_auto_tagging=False,
**kwargs, **kwargs,
): ):
if skip_auto_tagging:
return
if replace: if replace:
Document.tags.through.objects.filter(document=document).exclude( Document.tags.through.objects.filter(document=document).exclude(
Q(tag__is_inbox_tag=True), Q(tag__is_inbox_tag=True),

View File

@ -854,6 +854,37 @@ class TestConsumer(
self._assert_first_last_send_progress() self._assert_first_last_send_progress()
@mock.patch("documents.consumer.load_classifier")
def testClassifyDocumentWithSkippedTags(self, m):
    """
    GIVEN:
        - An auto-matching correspondent, document type and two tags
        - A mocked classifier predicting the correspondent, the document
          type and tag "t2"
        - Overrides that assign tag "t1" and set skip_auto_tagging
    WHEN:
        - The test file is consumed
    THEN:
        - The predicted correspondent and document type are set
        - Only the explicitly assigned tag "t1" is on the document; the
          classifier-predicted tag "t2" is not
    """
    predicted_correspondent = Correspondent.objects.create(
        name="test",
        matching_algorithm=Correspondent.MATCH_AUTO,
    )
    predicted_type = DocumentType.objects.create(
        name="test",
        matching_algorithm=DocumentType.MATCH_AUTO,
    )
    explicit_tag = Tag.objects.create(
        name="t1",
        matching_algorithm=Tag.MATCH_AUTO,
    )
    predicted_tag = Tag.objects.create(
        name="t2",
        matching_algorithm=Tag.MATCH_AUTO,
    )

    # Stand-in classifier returned by the patched load_classifier.
    classifier = MagicMock()
    m.return_value = classifier
    classifier.predict_correspondent.return_value = predicted_correspondent.pk
    classifier.predict_document_type.return_value = predicted_type.pk
    classifier.predict_tags.return_value = [predicted_tag.pk]

    overrides = DocumentMetadataOverrides(
        tag_ids=[explicit_tag.pk],
        skip_auto_tagging=True,
    )

    with self.get_consumer(self.get_test_file(), overrides) as consumer:
        consumer.run()

    document = Document.objects.first()

    self.assertEqual(document.correspondent, predicted_correspondent)
    self.assertEqual(document.document_type, predicted_type)

    assigned_tags = document.tags.all()
    self.assertIn(explicit_tag, assigned_tags)
    self.assertNotIn(predicted_tag, assigned_tags)

    self._assert_first_last_send_progress()
@override_settings(CONSUMER_DELETE_DUPLICATES=True) @override_settings(CONSUMER_DELETE_DUPLICATES=True)
def test_delete_duplicate(self): def test_delete_duplicate(self):
dst = self.get_test_file() dst = self.get_test_file()

View File

@ -1385,6 +1385,7 @@ class PostDocumentView(GenericAPIView):
created = serializer.validated_data.get("created") created = serializer.validated_data.get("created")
archive_serial_number = serializer.validated_data.get("archive_serial_number") archive_serial_number = serializer.validated_data.get("archive_serial_number")
custom_field_ids = serializer.validated_data.get("custom_fields") custom_field_ids = serializer.validated_data.get("custom_fields")
skip_auto_tagging = serializer.validated_data.get("skip_auto_tagging")
t = int(mktime(datetime.now().timetuple())) t = int(mktime(datetime.now().timetuple()))
@ -1413,6 +1414,7 @@ class PostDocumentView(GenericAPIView):
asn=archive_serial_number, asn=archive_serial_number,
owner_id=request.user.id, owner_id=request.user.id,
custom_field_ids=custom_field_ids, custom_field_ids=custom_field_ids,
skip_auto_tagging=skip_auto_tagging,
) )
async_task = consume_file.delay( async_task = consume_file.delay(