mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-04-19 10:19:27 -05:00
Add an optional flag to the PostDocument endpoint to skip auto tagging
This commit is contained in:
parent
a9ef7ff58e
commit
9f19ac12dd
@ -192,6 +192,8 @@ The endpoint supports the following optional form fields:
|
|||||||
- `tags`: Similar to correspondent. Specify this multiple times to
|
- `tags`: Similar to correspondent. Specify this multiple times to
|
||||||
have multiple tags added to the document.
|
have multiple tags added to the document.
|
||||||
- `archive_serial_number`: An optional archive serial number to set.
|
- `archive_serial_number`: An optional archive serial number to set.
|
||||||
|
- `skip_auto_tags`: Boolean to indicate that the classifier should not
|
||||||
|
attempt to determine and add tags to the document.
|
||||||
- `custom_fields`: An array of custom field ids to assign (with an empty
|
- `custom_fields`: An array of custom field ids to assign (with an empty
|
||||||
value) to the document.
|
value) to the document.
|
||||||
|
|
||||||
|
@ -577,6 +577,7 @@ class ConsumerPlugin(
|
|||||||
original_file=self.unmodified_original
|
original_file=self.unmodified_original
|
||||||
if self.unmodified_original
|
if self.unmodified_original
|
||||||
else self.working_copy,
|
else self.working_copy,
|
||||||
|
skip_auto_tagging=self.metadata.skip_auto_tagging,
|
||||||
)
|
)
|
||||||
|
|
||||||
# After everything is in the database, copy the files into
|
# After everything is in the database, copy the files into
|
||||||
|
@ -30,6 +30,7 @@ class DocumentMetadataOverrides:
|
|||||||
change_users: list[int] | None = None
|
change_users: list[int] | None = None
|
||||||
change_groups: list[int] | None = None
|
change_groups: list[int] | None = None
|
||||||
custom_field_ids: list[int] | None = None
|
custom_field_ids: list[int] | None = None
|
||||||
|
skip_auto_tagging: bool | None = None
|
||||||
|
|
||||||
def update(self, other: "DocumentMetadataOverrides") -> "DocumentMetadataOverrides":
|
def update(self, other: "DocumentMetadataOverrides") -> "DocumentMetadataOverrides":
|
||||||
"""
|
"""
|
||||||
@ -49,6 +50,8 @@ class DocumentMetadataOverrides:
|
|||||||
self.storage_path_id = other.storage_path_id
|
self.storage_path_id = other.storage_path_id
|
||||||
if other.owner_id is not None:
|
if other.owner_id is not None:
|
||||||
self.owner_id = other.owner_id
|
self.owner_id = other.owner_id
|
||||||
|
if other.skip_auto_tagging is not None:
|
||||||
|
self.skip_auto_tagging = other.skip_auto_tagging
|
||||||
|
|
||||||
# merge
|
# merge
|
||||||
if self.tag_ids is None:
|
if self.tag_ids is None:
|
||||||
|
@ -1536,6 +1536,13 @@ class PostDocumentSerializer(serializers.Serializer):
|
|||||||
required=False,
|
required=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
skip_auto_tagging = serializers.BooleanField(
|
||||||
|
label="Skip auto tagging",
|
||||||
|
default=False,
|
||||||
|
write_only=True,
|
||||||
|
required=False,
|
||||||
|
)
|
||||||
|
|
||||||
def validate_document(self, document):
|
def validate_document(self, document):
|
||||||
document_data = document.file.read()
|
document_data = document.file.read()
|
||||||
mime_type = magic.from_buffer(document_data, mime=True)
|
mime_type = magic.from_buffer(document_data, mime=True)
|
||||||
|
@ -206,8 +206,12 @@ def set_tags(
|
|||||||
base_url=None,
|
base_url=None,
|
||||||
stdout=None,
|
stdout=None,
|
||||||
style_func=None,
|
style_func=None,
|
||||||
|
skip_auto_tagging=False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
if skip_auto_tagging:
|
||||||
|
return
|
||||||
|
|
||||||
if replace:
|
if replace:
|
||||||
Document.tags.through.objects.filter(document=document).exclude(
|
Document.tags.through.objects.filter(document=document).exclude(
|
||||||
Q(tag__is_inbox_tag=True),
|
Q(tag__is_inbox_tag=True),
|
||||||
|
@ -854,6 +854,37 @@ class TestConsumer(
|
|||||||
|
|
||||||
self._assert_first_last_send_progress()
|
self._assert_first_last_send_progress()
|
||||||
|
|
||||||
|
@mock.patch("documents.consumer.load_classifier")
|
||||||
|
def testClassifyDocumentWithSkippedTags(self, m):
|
||||||
|
correspondent = Correspondent.objects.create(
|
||||||
|
name="test",
|
||||||
|
matching_algorithm=Correspondent.MATCH_AUTO,
|
||||||
|
)
|
||||||
|
dtype = DocumentType.objects.create(
|
||||||
|
name="test",
|
||||||
|
matching_algorithm=DocumentType.MATCH_AUTO,
|
||||||
|
)
|
||||||
|
t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO)
|
||||||
|
t2 = Tag.objects.create(name="t2", matching_algorithm=Tag.MATCH_AUTO)
|
||||||
|
|
||||||
|
m.return_value = MagicMock()
|
||||||
|
m.return_value.predict_correspondent.return_value = correspondent.pk
|
||||||
|
m.return_value.predict_document_type.return_value = dtype.pk
|
||||||
|
m.return_value.predict_tags.return_value = [t2.pk]
|
||||||
|
|
||||||
|
overrides = DocumentMetadataOverrides(tag_ids=[t1.pk], skip_auto_tagging=True)
|
||||||
|
with self.get_consumer(self.get_test_file(), overrides) as consumer:
|
||||||
|
consumer.run()
|
||||||
|
|
||||||
|
document = Document.objects.first()
|
||||||
|
|
||||||
|
self.assertEqual(document.correspondent, correspondent)
|
||||||
|
self.assertEqual(document.document_type, dtype)
|
||||||
|
self.assertIn(t1, document.tags.all())
|
||||||
|
self.assertNotIn(t2, document.tags.all())
|
||||||
|
|
||||||
|
self._assert_first_last_send_progress()
|
||||||
|
|
||||||
@override_settings(CONSUMER_DELETE_DUPLICATES=True)
|
@override_settings(CONSUMER_DELETE_DUPLICATES=True)
|
||||||
def test_delete_duplicate(self):
|
def test_delete_duplicate(self):
|
||||||
dst = self.get_test_file()
|
dst = self.get_test_file()
|
||||||
|
@ -1385,6 +1385,7 @@ class PostDocumentView(GenericAPIView):
|
|||||||
created = serializer.validated_data.get("created")
|
created = serializer.validated_data.get("created")
|
||||||
archive_serial_number = serializer.validated_data.get("archive_serial_number")
|
archive_serial_number = serializer.validated_data.get("archive_serial_number")
|
||||||
custom_field_ids = serializer.validated_data.get("custom_fields")
|
custom_field_ids = serializer.validated_data.get("custom_fields")
|
||||||
|
skip_auto_tagging = serializer.validated_data.get("skip_auto_tagging")
|
||||||
|
|
||||||
t = int(mktime(datetime.now().timetuple()))
|
t = int(mktime(datetime.now().timetuple()))
|
||||||
|
|
||||||
@ -1413,6 +1414,7 @@ class PostDocumentView(GenericAPIView):
|
|||||||
asn=archive_serial_number,
|
asn=archive_serial_number,
|
||||||
owner_id=request.user.id,
|
owner_id=request.user.id,
|
||||||
custom_field_ids=custom_field_ids,
|
custom_field_ids=custom_field_ids,
|
||||||
|
skip_auto_tagging=skip_auto_tagging,
|
||||||
)
|
)
|
||||||
|
|
||||||
async_task = consume_file.delay(
|
async_task = consume_file.delay(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user