diff --git a/src/documents/bulk_edit.py b/src/documents/bulk_edit.py index aa5b8ea3f..132547704 100644 --- a/src/documents/bulk_edit.py +++ b/src/documents/bulk_edit.py @@ -1,3 +1,5 @@ +import itertools + from django.db.models import Q from django_q.tasks import async_task @@ -66,6 +68,27 @@ def remove_tag(doc_ids, tag): return "OK" +def modify_tags(doc_ids, add_tags, remove_tags): + qs = Document.objects.filter(id__in=doc_ids) + affected_docs = [doc.id for doc in qs] + + DocumentTagRelationship = Document.tags.through + + DocumentTagRelationship.objects.filter( + document_id__in=affected_docs, + tag_id__in=remove_tags, + ).delete() + + DocumentTagRelationship.objects.bulk_create([ + DocumentTagRelationship( + document_id=doc, tag_id=tag) for (doc,tag) in itertools.product(affected_docs, add_tags) + ], ignore_conflicts=True) + + async_task("documents.tasks.bulk_rename_files", document_ids=affected_docs) + + return "OK" + + def delete(doc_ids): Document.objects.filter(id__in=doc_ids).delete() diff --git a/src/documents/serialisers.py b/src/documents/serialisers.py index d9f1833bf..f34176d8a 100644 --- a/src/documents/serialisers.py +++ b/src/documents/serialisers.py @@ -217,6 +217,7 @@ class BulkEditSerializer(serializers.Serializer): "set_document_type", "add_tag", "remove_tag", + "modify_tags", "delete" ], label="Method", @@ -225,11 +226,19 @@ class BulkEditSerializer(serializers.Serializer): parameters = serializers.DictField(allow_empty=True) - def validate_documents(self, documents): + def _validate_document_id_list(self, documents, name="documents"): + if not type(documents) == list: + raise serializers.ValidationError(f"{name} must be a list") + if not all([type(i) == int for i in documents]): + raise serializers.ValidationError(f"{name} must be a list of integers") count = Document.objects.filter(id__in=documents).count() if not count == len(documents): raise serializers.ValidationError( - "Some documents don't exist or were specified twice.") + f"Some documents in {name} don't exist or were specified twice.") + + + def validate_documents(self, documents): + self._validate_document_id_list(documents) return documents def validate_method(self, method): @@ -241,6 +250,8 @@ class BulkEditSerializer(serializers.Serializer): return bulk_edit.add_tag elif method == "remove_tag": return bulk_edit.remove_tag + elif method == "modify_tags": + return bulk_edit.modify_tags elif method == "delete": return bulk_edit.delete else: @@ -283,6 +294,17 @@ class BulkEditSerializer(serializers.Serializer): else: raise serializers.ValidationError("correspondent not specified") + def _validate_parameters_modify_tags(self, parameters): + if "add_tags" in parameters: + self._validate_document_id_list(parameters['add_tags'], "add_tags") + else: + raise serializers.ValidationError("add_tags not specified") + + if "remove_tags" in parameters: + self._validate_document_id_list(parameters['remove_tags'], "remove_tags") + else: + raise serializers.ValidationError("remove_tags not specified") + def validate(self, attrs): method = attrs['method'] @@ -294,6 +316,8 @@ class BulkEditSerializer(serializers.Serializer): self._validate_parameters_document_type(parameters) elif method == bulk_edit.add_tag or method == bulk_edit.remove_tag: self._validate_parameters_tags(parameters) + elif method == bulk_edit.modify_tags: + self._validate_parameters_modify_tags(parameters) return attrs diff --git a/src/documents/tests/test_api.py b/src/documents/tests/test_api.py index 0262b6d6a..030de652d 100644 --- a/src/documents/tests/test_api.py +++ b/src/documents/tests/test_api.py @@ -743,6 +743,15 @@ class TestBulkEdit(DirectoriesMixin, APITestCase): args, kwargs = self.async_task.call_args self.assertCountEqual(kwargs['document_ids'], [self.doc4.id]) + def test_modify_tags(self): + tag_unrelated = Tag.objects.create(name="unrelated") + self.doc2.tags.add(tag_unrelated) + self.doc3.tags.add(tag_unrelated) + bulk_edit.modify_tags([self.doc2.id, self.doc3.id], add_tags=[self.t2.id], remove_tags=[self.t1.id]) + + self.assertCountEqual(list(self.doc2.tags.all()), [self.t2, tag_unrelated]) + self.assertCountEqual(list(self.doc3.tags.all()), [self.t2, tag_unrelated]) + def test_delete(self): self.assertEqual(Document.objects.count(), 5) bulk_edit.delete([self.doc1.id, self.doc2.id])