From b25f083687994e37cfa794e84ea552225fe72147 Mon Sep 17 00:00:00 2001
From: Trenton H <797416+stumpylog@users.noreply.github.com>
Date: Thu, 12 Jan 2023 12:46:28 -0800
Subject: [PATCH] Updates the exporter to use pathlib and adds a few more tests
 for coverage

---
 .../management/commands/document_consumer.py  |  2 +-
 .../management/commands/document_exporter.py  | 93 +++++++++----------
 .../tests/test_management_exporter.py         | 59 ++++++++++++
 3 files changed, 106 insertions(+), 48 deletions(-)

diff --git a/src/documents/management/commands/document_consumer.py b/src/documents/management/commands/document_consumer.py
index a405f590c..9107d574a 100644
--- a/src/documents/management/commands/document_consumer.py
+++ b/src/documents/management/commands/document_consumer.py
@@ -19,7 +19,7 @@ from watchdog.observers.polling import PollingObserver
 
 try:
     from inotifyrecursive import INotify, flags
-except ImportError:
+except ImportError:  # pragma: nocover
     INotify = flags = None
 
 logger = logging.getLogger("paperless.management.consumer")
diff --git a/src/documents/management/commands/document_exporter.py b/src/documents/management/commands/document_exporter.py
index 07b4643f2..3cd028f01 100644
--- a/src/documents/management/commands/document_exporter.py
+++ b/src/documents/management/commands/document_exporter.py
@@ -4,6 +4,9 @@ import os
 import shutil
 import tempfile
 import time
+from pathlib import Path
+from typing import List
+from typing import Set
 
 import tqdm
 from django.conf import settings
@@ -96,16 +99,16 @@ class Command(BaseCommand):
 
     def __init__(self, *args, **kwargs):
         BaseCommand.__init__(self, *args, **kwargs)
-        self.target = None
-        self.files_in_export_dir = []
-        self.exported_files = []
+        self.target: Path = None  # resolved export directory, assigned in handle()
+        self.files_in_export_dir: Set[Path] = set()  # filled at the start of dump()
+        self.exported_files: List[Path] = []
         self.compare_checksums = False
         self.use_filename_format = False
         self.delete = False
 
     def handle(self, *args, **options):
 
-        self.target = options["target"]
+        self.target = Path(options["target"]).resolve()
         self.compare_checksums = options["compare_checksums"]
         self.use_filename_format = options["use_filename_format"]
         self.delete = options["delete"]
@@ -121,11 +124,14 @@ class Command(BaseCommand):
                 dir=settings.SCRATCH_DIR,
                 prefix="paperless-export",
             )
-            self.target = temp_dir.name
+            self.target = Path(temp_dir.name).resolve()
 
-        if not os.path.exists(self.target):
+        if not self.target.exists():
             raise CommandError("That path doesn't exist")
 
+        if not self.target.is_dir():
+            raise CommandError("That path isn't a directory")
+
         if not os.access(self.target, os.W_OK):
             raise CommandError("That path doesn't appear to be writable")
 
@@ -152,10 +158,9 @@ class Command(BaseCommand):
 
     def dump(self, progress_bar_disable=False):
         # 1. Take a snapshot of what files exist in the current export folder
-        for root, dirs, files in os.walk(self.target):
-            self.files_in_export_dir.extend(
-                map(lambda f: os.path.abspath(os.path.join(root, f)), files),
-            )
+        for x in self.target.glob("**/*"):
+            if x.is_file():
+                self.files_in_export_dir.add(x.resolve())
 
         # 2. Create manifest, containing all correspondents, types, tags, storage paths
         # comments, documents and ui_settings
@@ -238,16 +243,16 @@ class Command(BaseCommand):
 
             # 3.3. write filenames into manifest
             original_name = base_name
-            original_target = os.path.join(self.target, original_name)
+            original_target = (self.target / Path(original_name)).resolve()
             document_dict[EXPORTER_FILE_NAME] = original_name
 
             thumbnail_name = base_name + "-thumbnail.webp"
-            thumbnail_target = os.path.join(self.target, thumbnail_name)
+            thumbnail_target = (self.target / Path(thumbnail_name)).resolve()
             document_dict[EXPORTER_THUMBNAIL_NAME] = thumbnail_name
 
             if document.has_archive_version:
                 archive_name = base_name + "-archive.pdf"
-                archive_target = os.path.join(self.target, archive_name)
+                archive_target = (self.target / Path(archive_name)).resolve()
                 document_dict[EXPORTER_ARCHIVE_NAME] = archive_name
             else:
                 archive_target = None
@@ -256,24 +261,21 @@ class Command(BaseCommand):
             t = int(time.mktime(document.created.timetuple()))
             if document.storage_type == Document.STORAGE_TYPE_GPG:
 
-                os.makedirs(os.path.dirname(original_target), exist_ok=True)
-                with open(original_target, "wb") as f:
-                    with document.source_file as out_file:
-                        f.write(GnuPG.decrypted(out_file))
-                        os.utime(original_target, times=(t, t))
+                original_target.parent.mkdir(parents=True, exist_ok=True)
+                with document.source_file as out_file:
+                    original_target.write_bytes(GnuPG.decrypted(out_file))
+                    os.utime(original_target, times=(t, t))
 
-                os.makedirs(os.path.dirname(thumbnail_target), exist_ok=True)
-                with open(thumbnail_target, "wb") as f:
-                    with document.thumbnail_file as out_file:
-                        f.write(GnuPG.decrypted(out_file))
-                        os.utime(thumbnail_target, times=(t, t))
+                thumbnail_target.parent.mkdir(parents=True, exist_ok=True)
+                with document.thumbnail_file as out_file:
+                    thumbnail_target.write_bytes(GnuPG.decrypted(out_file))
+                    os.utime(thumbnail_target, times=(t, t))
 
                 if archive_target:
-                    os.makedirs(os.path.dirname(archive_target), exist_ok=True)
-                    with open(archive_target, "wb") as f:
-                        with document.archive_path as out_file:
-                            f.write(GnuPG.decrypted(out_file))
-                            os.utime(archive_target, times=(t, t))
+                    archive_target.parent.mkdir(parents=True, exist_ok=True)
+                    with document.archive_path as out_file:
+                        archive_target.write_bytes(GnuPG.decrypted(out_file))
+                        os.utime(archive_target, times=(t, t))
             else:
                 self.check_and_copy(
                     document.source_path,
@@ -291,16 +293,14 @@ class Command(BaseCommand):
                     )
 
         # 4.1 write manifest to target folder
-        manifest_path = os.path.abspath(os.path.join(self.target, "manifest.json"))
-
-        with open(manifest_path, "w") as f:
-            json.dump(manifest, f, indent=2)
+        manifest_path = (self.target / Path("manifest.json")).resolve()
+        manifest_path.write_text(json.dumps(manifest, indent=2))
 
         # 4.2 write version information to target folder
-        version_path = os.path.abspath(os.path.join(self.target, "version.json"))
-
-        with open(version_path, "w") as f:
-            json.dump({"version": version.__full_version_str__}, f, indent=2)
+        version_path = (self.target / Path("version.json")).resolve()
+        version_path.write_text(
+            json.dumps({"version": version.__full_version_str__}, indent=2),
+        )
 
         if self.delete:
             # 5. Remove files which we did not explicitly export in this run
@@ -309,25 +309,24 @@ class Command(BaseCommand):
                 self.files_in_export_dir.remove(manifest_path)
 
             for f in self.files_in_export_dir:
-                os.remove(f)
+                f.unlink()
 
                 delete_empty_directories(
-                    os.path.abspath(os.path.dirname(f)),
-                    os.path.abspath(self.target),
+                    f.parent,
+                    self.target,
                 )
 
-    def check_and_copy(self, source, source_checksum, target):
-        if os.path.abspath(target) in self.files_in_export_dir:
-            self.files_in_export_dir.remove(os.path.abspath(target))
+    def check_and_copy(self, source, source_checksum, target: Path):
+        if target in self.files_in_export_dir:
+            self.files_in_export_dir.remove(target)
 
         perform_copy = False
 
-        if os.path.exists(target):
+        if target.exists():
             source_stat = os.stat(source)
-            target_stat = os.stat(target)
+            target_stat = target.stat()
             if self.compare_checksums and source_checksum:
-                with open(target, "rb") as f:
-                    target_checksum = hashlib.md5(f.read()).hexdigest()
+                target_checksum = hashlib.md5(target.read_bytes()).hexdigest()
                 perform_copy = target_checksum != source_checksum
             elif source_stat.st_mtime != target_stat.st_mtime:
                 perform_copy = True
@@ -338,5 +337,5 @@ class Command(BaseCommand):
             perform_copy = True
 
         if perform_copy:
-            os.makedirs(os.path.dirname(target), exist_ok=True)
+            target.parent.mkdir(parents=True, exist_ok=True)
             shutil.copy2(source, target)
diff --git a/src/documents/tests/test_management_exporter.py b/src/documents/tests/test_management_exporter.py
index a24b292d7..5aff05793 100644
--- a/src/documents/tests/test_management_exporter.py
+++ b/src/documents/tests/test_management_exporter.py
@@ -8,6 +8,7 @@ from unittest import mock
 from zipfile import ZipFile
 
 from django.core.management import call_command
+from django.core.management.base import CommandError
 from django.test import override_settings
 from django.test import TestCase
 from django.utils import timezone
@@ -438,3 +439,61 @@ class TestExportImport(DirectoriesMixin, TestCase):
             self.assertEqual(len(zip.namelist()), 14)
             self.assertIn("manifest.json", zip.namelist())
             self.assertIn("version.json", zip.namelist())
+
+    def test_export_target_not_exists(self):
+        """
+        GIVEN:
+            - Request to export documents to directory that doesn't exist
+        WHEN:
+            - Export command is called
+        THEN:
+            - CommandError is raised reporting that the path doesn't exist
+        """
+        args = ["document_exporter", "/tmp/foo/bar"]
+
+        with self.assertRaises(CommandError) as e:
+            call_command(*args)
+
+        # Assert after the block: code following the raising call is never run
+        self.assertEqual("That path doesn't exist", str(e.exception))
+
+    def test_export_target_exists_but_is_file(self):
+        """
+        GIVEN:
+            - Request to export documents to file instead of directory
+        WHEN:
+            - Export command is called
+        THEN:
+            - CommandError is raised reporting that the path isn't a directory
+        """
+
+        with tempfile.NamedTemporaryFile() as tmp_file:
+
+            args = ["document_exporter", tmp_file.name]
+
+            with self.assertRaises(CommandError) as e:
+                call_command(*args)
+
+            # Assert outside the block so the check actually executes
+            self.assertEqual("That path isn't a directory", str(e.exception))
+
+    def test_export_target_not_writable(self):
+        """
+        GIVEN:
+            - Request to export documents to directory that's not writeable
+        WHEN:
+            - Export command is called
+        THEN:
+            - CommandError is raised reporting that the path isn't writable
+        """
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            os.chmod(tmp_dir, 0o000)
+            try:
+                with self.assertRaises(CommandError) as e:
+                    call_command("document_exporter", tmp_dir)
+            finally:
+                # Restore access so the temporary directory can be cleaned up
+                os.chmod(tmp_dir, 0o700)
+
+        # Assert outside the block so the check actually executes
+        self.assertEqual("That path doesn't appear to be writable", str(e.exception))