From bbd4659fbffed2796d555f5f1d75fb2d4e317cb6 Mon Sep 17 00:00:00 2001 From: shamoon <4887959+shamoon@users.noreply.github.com> Date: Fri, 23 Jun 2023 08:56:18 -0700 Subject: [PATCH] Include global and object-level permissions in export / import adds test for transaction --- .../management/commands/document_exporter.py | 20 ++++++ .../management/commands/document_importer.py | 14 +++- .../tests/test_management_exporter.py | 66 ++++++++++++++++++- 3 files changed, 95 insertions(+), 5 deletions(-) diff --git a/src/documents/management/commands/document_exporter.py b/src/documents/management/commands/document_exporter.py index fba89695b..22fb59308 100644 --- a/src/documents/management/commands/document_exporter.py +++ b/src/documents/management/commands/document_exporter.py @@ -11,13 +11,17 @@ from typing import Set import tqdm from django.conf import settings from django.contrib.auth.models import Group +from django.contrib.auth.models import Permission from django.contrib.auth.models import User +from django.contrib.contenttypes.models import ContentType from django.core import serializers from django.core.management.base import BaseCommand from django.core.management.base import CommandError from django.db import transaction from django.utils import timezone from filelock import FileLock +from guardian.models import GroupObjectPermission +from guardian.models import UserObjectPermission from documents.file_handling import delete_empty_directories from documents.file_handling import generate_filename @@ -261,6 +265,22 @@ class Command(BaseCommand): serializers.serialize("json", UiSettings.objects.all()), ) + manifest += json.loads( + serializers.serialize("json", ContentType.objects.all()), + ) + + manifest += json.loads( + serializers.serialize("json", Permission.objects.all()), + ) + + manifest += json.loads( + serializers.serialize("json", UserObjectPermission.objects.all()), + ) + + manifest += json.loads( + serializers.serialize("json", GroupObjectPermission.objects.all()), + ) + # 3. Export files from each document for index, document_dict in tqdm.tqdm( enumerate(document_manifest), diff --git a/src/documents/management/commands/document_importer.py b/src/documents/management/commands/document_importer.py index b00cb45fa..baf6d7528 100644 --- a/src/documents/management/commands/document_importer.py +++ b/src/documents/management/commands/document_importer.py @@ -7,11 +7,15 @@ from pathlib import Path import tqdm from django.conf import settings +from django.contrib.auth.models import Permission +from django.contrib.contenttypes.models import ContentType from django.core.exceptions import FieldDoesNotExist from django.core.management import call_command from django.core.management.base import BaseCommand from django.core.management.base import CommandError from django.core.serializers.base import DeserializationError +from django.db import IntegrityError +from django.db import transaction from django.db.models.signals import m2m_changed from django.db.models.signals import post_save from filelock import FileLock @@ -116,9 +120,13 @@ class Command(BaseCommand): ): # Fill up the database with whatever is in the manifest try: - for manifest_path in manifest_paths: - call_command("loaddata", manifest_path) - except (FieldDoesNotExist, DeserializationError) as e: + with transaction.atomic(): + for manifest_path in manifest_paths: + # delete these since pk can change, re-created from import + ContentType.objects.all().delete() + Permission.objects.all().delete() + call_command("loaddata", manifest_path) + except (FieldDoesNotExist, DeserializationError, IntegrityError) as e: self.stdout.write(self.style.ERROR("Database import failed")) if ( self.version is not None diff --git a/src/documents/tests/test_management_exporter.py b/src/documents/tests/test_management_exporter.py index e7c116caf..421ae51fc 100644 --- a/src/documents/tests/test_management_exporter.py +++ b/src/documents/tests/test_management_exporter.py @@ -7,11 +7,18 @@ from pathlib import Path from unittest import mock from zipfile import ZipFile +from django.contrib.auth.models import Group +from django.contrib.auth.models import Permission +from django.contrib.contenttypes.models import ContentType from django.core.management import call_command from django.core.management.base import CommandError +from django.db import IntegrityError from django.test import TestCase from django.test import override_settings from django.utils import timezone +from guardian.models import GroupObjectPermission +from guardian.models import UserObjectPermission +from guardian.shortcuts import assign_perm from documents.management.commands import document_exporter from documents.models import Correspondent @@ -34,6 +41,8 @@ class TestExportImport(DirectoriesMixin, FileSystemAssertsMixin, TestCase): self.addCleanup(shutil.rmtree, self.target) self.user = User.objects.create(username="temp_admin") + self.user2 = User.objects.create(username="user2") + self.group1 = Group.objects.create(name="group1") self.d1 = Document.objects.create( content="Content", @@ -73,6 +82,9 @@ class TestExportImport(DirectoriesMixin, FileSystemAssertsMixin, TestCase): user=self.user, ) + assign_perm("view_document", self.user2, self.d2) + assign_perm("view_document", self.group1, self.d3) + self.t1 = Tag.objects.create(name="t") self.dt1 = DocumentType.objects.create(name="dt") self.c1 = Correspondent.objects.create(name="c") @@ -141,12 +153,12 @@ class TestExportImport(DirectoriesMixin, FileSystemAssertsMixin, TestCase): manifest = self._do_export(use_filename_format=use_filename_format) - self.assertEqual(len(manifest), 10) + self.assertEqual(len(manifest), 149) # dont include consumer or AnonymousUser users self.assertEqual( len(list(filter(lambda e: e["model"] == "auth.user", manifest))), - 1, + 2, ) self.assertEqual( @@ -218,6 +230,9 @@ class TestExportImport(DirectoriesMixin, FileSystemAssertsMixin, TestCase): Correspondent.objects.all().delete() DocumentType.objects.all().delete() Tag.objects.all().delete() + Permission.objects.all().delete() + UserObjectPermission.objects.all().delete() + GroupObjectPermission.objects.all().delete() self.assertEqual(Document.objects.count(), 0) call_command("document_importer", "--no-progress-bar", self.target) @@ -230,6 +245,9 @@ class TestExportImport(DirectoriesMixin, FileSystemAssertsMixin, TestCase): self.assertEqual(Document.objects.get(id=self.d2.id).title, "wow2") self.assertEqual(Document.objects.get(id=self.d3.id).title, "wow2") self.assertEqual(Document.objects.get(id=self.d4.id).title, "wow_dec") + self.assertEqual(GroupObjectPermission.objects.count(), 1) + self.assertEqual(UserObjectPermission.objects.count(), 1) + self.assertEqual(Permission.objects.count(), 108) messages = check_sanity() # everything is alright after the test self.assertEqual(len(messages), 0) @@ -641,3 +659,47 @@ class TestExportImport(DirectoriesMixin, FileSystemAssertsMixin, TestCase): self.assertEqual(Document.objects.count(), 0) call_command("document_importer", "--no-progress-bar", self.target) self.assertEqual(Document.objects.count(), 4) + + def test_import_db_transaction_failed(self): + """ + GIVEN: + - Import from manifest started + WHEN: + - Import of database fails + THEN: + - ContentType & Permission objects are not deleted, db transaction rolled back + """ + + shutil.rmtree(os.path.join(self.dirs.media_dir, "documents")) + shutil.copytree( + os.path.join(os.path.dirname(__file__), "samples", "documents"), + os.path.join(self.dirs.media_dir, "documents"), + ) + + self.assertEqual(ContentType.objects.count(), 27) + self.assertEqual(Permission.objects.count(), 108) + + manifest = self._do_export() + + with paperless_environment(): + self.assertEqual( + len(list(filter(lambda e: e["model"] == "auth.permission", manifest))), + 108, + ) + # add 1 more to db to show objects are not re-created by import + Permission.objects.create( + name="test", + codename="test_perm", + content_type_id=1, + ) + self.assertEqual(Permission.objects.count(), 109) + + # will cause an import error + self.user.delete() + self.user = User.objects.create(username="temp_admin") + + with self.assertRaises(IntegrityError): + call_command("document_importer", "--no-progress-bar", self.target) + + self.assertEqual(ContentType.objects.count(), 27) + self.assertEqual(Permission.objects.count(), 109)