mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2026-02-28 01:19:36 -06:00
519 lines
15 KiB
Python
519 lines
15 KiB
Python
"""Tests for PaperlessCommand base class."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import io
|
|
from typing import TYPE_CHECKING
|
|
|
|
import pytest
|
|
from django.core.management import CommandError
|
|
from django.db.models import QuerySet
|
|
from rich.console import Console
|
|
|
|
from documents.management.commands.base import PaperlessCommand
|
|
from documents.management.commands.base import ProcessResult
|
|
|
|
if TYPE_CHECKING:
|
|
from pytest_mock import MockerFixture
|
|
|
|
|
|
# --- Test Commands ---
|
|
# These simulate real command implementations for testing
|
|
|
|
|
|
class SimpleCommand(PaperlessCommand):
    """Test command that leaves the base-class feature flags untouched.

    ``handle`` doubles the integers 0..4 while iterating through
    ``self.track`` and writes the collected list to ``self.stdout``.
    """

    help = "Simple test command"

    def handle(self, *args, **options):
        """Double each tracked item and report the resulting list."""
        doubled = [
            value * 2
            for value in self.track(list(range(5)), description="Processing...")
        ]
        self.stdout.write(f"Results: {doubled}")
|
|
|
|
|
|
class NoProgressBarCommand(PaperlessCommand):
    """Test command that sets ``supports_progress_bar`` to False."""

    supports_progress_bar = False
    help = "No progress bar command"

    def handle(self, *args, **options):
        """Exhaust ``track()`` over three items, then write ``Done``."""
        # Only the iteration itself matters; the yielded items are unused.
        for _unused in self.track(list(range(3))):
            continue
        self.stdout.write("Done")
|
|
|
|
|
|
class MultiprocessCommand(PaperlessCommand):
    """Test command that sets ``supports_multiprocessing`` to True."""

    supports_multiprocessing = True
    help = "Multiprocess test command"

    def handle(self, *args, **options):
        """Run ``_double_value`` over 0..4 and report how many succeeded."""
        outcomes = list(
            self.process_parallel(
                _double_value,
                list(range(5)),
                description="Processing...",
            ),
        )
        # Count only the outcomes whose ``success`` attribute is truthy.
        succeeded = len([outcome for outcome in outcomes if outcome.success])
        self.stdout.write(f"Successes: {succeeded}")
|
|
|
|
|
|
# --- Helper Functions for Multiprocessing ---
|
|
# Must be at module level to be picklable
|
|
|
|
|
|
def _double_value(x: int) -> int:
|
|
"""Double the input value."""
|
|
return x * 2
|
|
|
|
|
|
def _divide_ten_by(x: int) -> float:
|
|
"""Divide 10 by x. Raises ZeroDivisionError if x is 0."""
|
|
return 10 / x
|
|
|
|
|
|
# --- Fixtures ---
|
|
|
|
|
|
@pytest.fixture
def console() -> Console:
    """Provide a Rich console with terminal and interactive modes forced off."""
    console_options = {"force_terminal": False, "force_interactive": False}
    return Console(**console_options)
|
|
|
|
|
|
@pytest.fixture
def simple_command(console: Console) -> SimpleCommand:
    """Provide a SimpleCommand wired to in-memory streams and the test console.

    The progress bar is switched off and the process count is fixed at one.
    """
    cmd = SimpleCommand()
    # Capture output in memory rather than on the real stdout/stderr.
    cmd.stdout, cmd.stderr = io.StringIO(), io.StringIO()
    cmd.console = console
    cmd.no_progress_bar = True
    cmd.process_count = 1
    return cmd
|
|
|
|
|
|
@pytest.fixture
def multiprocess_command(console: Console) -> MultiprocessCommand:
    """Provide a MultiprocessCommand wired to in-memory streams and the test console.

    The progress bar is switched off and the process count starts at one;
    individual tests raise ``process_count`` when they need the parallel path.
    """
    cmd = MultiprocessCommand()
    # Capture output in memory rather than on the real stdout/stderr.
    cmd.stdout, cmd.stderr = io.StringIO(), io.StringIO()
    cmd.console = console
    cmd.no_progress_bar = True
    cmd.process_count = 1
    return cmd
|
|
|
|
|
|
@pytest.fixture
def mock_queryset():
    """
    Provide a QuerySet subclass that records whether ``count()`` was used.

    The fixture returns the class itself; tests instantiate it with a list.
    Calling ``len()`` on an instance fails loudly, which lets tests verify
    that querysets are sized through ``.count()`` instead.
    """

    class MockQuerySet(QuerySet):
        def __init__(self, items: list):
            # QuerySet.__init__ is not invoked; only these attributes are set.
            self.count_called = False
            self._items = items

        def __iter__(self):
            return iter(self._items)

        def __len__(self):
            raise AssertionError("len() should not be called on querysets")

        def count(self) -> int:
            # Record the call so tests can assert on it afterwards.
            self.count_called = True
            return len(self._items)

    return MockQuerySet
|
|
|
|
|
|
# --- Test Classes ---
|
|
|
|
|
|
@pytest.mark.management
class TestProcessResult:
    """Checks on the fields and ``success`` flag of ProcessResult."""

    def test_success_result(self):
        """A result built without an error reports ``success`` as True."""
        outcome = ProcessResult(item=1, result=2, error=None)

        assert outcome.item == 1
        assert outcome.result == 2
        assert outcome.error is None
        assert outcome.success is True

    def test_error_result(self):
        """A result carrying an error reports ``success`` as False."""
        failure = ValueError("test error")
        outcome = ProcessResult(item=1, result=None, error=failure)

        assert outcome.item == 1
        assert outcome.result is None
        # The very same exception object must be stored, not a copy.
        assert outcome.error is failure
        assert outcome.success is False
|
|
|
|
|
|
@pytest.mark.management
class TestPaperlessCommandArguments:
    """Checks on which CLI arguments each command's parser exposes."""

    def test_progress_bar_argument_added_by_default(self):
        """``--no-progress-bar`` parses to True and is False when omitted."""
        parser = SimpleCommand().create_parser("manage.py", "simple")

        with_flag = parser.parse_args(["--no-progress-bar"])
        assert with_flag.no_progress_bar is True

        without_flag = parser.parse_args([])
        assert without_flag.no_progress_bar is False

    def test_progress_bar_argument_not_added_when_disabled(self):
        """No ``no_progress_bar`` attribute exists when the bar is unsupported."""
        parser = NoProgressBarCommand().create_parser("manage.py", "noprogress")

        parsed = parser.parse_args([])

        assert hasattr(parsed, "no_progress_bar") is False

    def test_processes_argument_added_when_multiprocessing_enabled(self):
        """``--processes`` parses to an int and has a default of at least 1."""
        parser = MultiprocessCommand().create_parser("manage.py", "multiprocess")

        explicit = parser.parse_args(["--processes", "4"])
        assert explicit.processes == 4

        defaulted = parser.parse_args([])
        assert defaulted.processes >= 1

    def test_processes_argument_not_added_when_multiprocessing_disabled(self):
        """No ``processes`` attribute exists without multiprocessing support."""
        parser = SimpleCommand().create_parser("manage.py", "simple")

        parsed = parser.parse_args([])

        assert hasattr(parsed, "processes") is False
|
|
|
|
|
|
@pytest.mark.management
class TestPaperlessCommandExecute:
    """Checks on the state ``execute()`` leaves on the command instance."""

    @pytest.fixture
    def base_options(self) -> dict:
        """Provide the options every ``execute()`` call in this class starts from."""
        return dict(
            verbosity=1,
            no_color=True,
            force_color=False,
            skip_checks=True,
        )

    @pytest.mark.parametrize(
        ("no_progress_bar_flag", "expected"),
        [
            pytest.param(False, False, id="progress-bar-enabled"),
            pytest.param(True, True, id="progress-bar-disabled"),
        ],
    )
    def test_no_progress_bar_state_set(
        self,
        base_options: dict,
        *,
        no_progress_bar_flag: bool,
        expected: bool,
    ):
        """The ``no_progress_bar`` option is mirrored onto the command."""
        cmd = SimpleCommand()
        cmd.stdout, cmd.stderr = io.StringIO(), io.StringIO()

        cmd.execute(**dict(base_options, no_progress_bar=no_progress_bar_flag))

        assert cmd.no_progress_bar is expected

    def test_no_progress_bar_always_true_when_not_supported(self, base_options: dict):
        """Without progress-bar support the flag ends up True after execute()."""
        cmd = NoProgressBarCommand()
        cmd.stdout, cmd.stderr = io.StringIO(), io.StringIO()

        # No "no_progress_bar" key is passed at all in this case.
        cmd.execute(**base_options)

        assert cmd.no_progress_bar is True

    @pytest.mark.parametrize(
        ("processes", "expected"),
        [
            pytest.param(1, 1, id="single-process"),
            pytest.param(4, 4, id="four-processes"),
        ],
    )
    def test_process_count_set(
        self,
        base_options: dict,
        processes: int,
        expected: int,
    ):
        """The ``processes`` option is stored as ``process_count``."""
        cmd = MultiprocessCommand()
        cmd.stdout, cmd.stderr = io.StringIO(), io.StringIO()

        cmd.execute(**dict(base_options, processes=processes, no_progress_bar=True))

        assert cmd.process_count == expected

    @pytest.mark.parametrize(
        "invalid_count",
        [
            pytest.param(0, id="zero"),
            pytest.param(-1, id="negative"),
        ],
    )
    def test_process_count_validation_rejects_invalid(
        self,
        base_options: dict,
        invalid_count: int,
    ):
        """A process count below one makes execute() raise CommandError."""
        cmd = MultiprocessCommand()
        cmd.stdout, cmd.stderr = io.StringIO(), io.StringIO()
        execute_kwargs = dict(
            base_options,
            processes=invalid_count,
            no_progress_bar=True,
        )

        with pytest.raises(CommandError, match="--processes must be at least 1"):
            cmd.execute(**execute_kwargs)

    def test_process_count_defaults_to_one_when_not_supported(self, base_options: dict):
        """Without multiprocessing support ``process_count`` is 1 after execute()."""
        cmd = SimpleCommand()
        cmd.stdout, cmd.stderr = io.StringIO(), io.StringIO()

        cmd.execute(**dict(base_options, no_progress_bar=True))

        assert cmd.process_count == 1
|
|
|
|
|
|
@pytest.mark.management
class TestGetIterableLength:
    """Checks on how ``_get_iterable_length()`` sizes different iterables."""

    def test_uses_count_for_querysets(
        self,
        simple_command: SimpleCommand,
        mock_queryset,
    ):
        """Querysets are sized through ``.count()``, never ``len()``."""
        fake_qs = mock_queryset([1, 2, 3, 4, 5])

        length = simple_command._get_iterable_length(fake_qs)

        # The mock raises from __len__, so reaching this point rules out len().
        assert fake_qs.count_called is True
        assert length == 5

    def test_uses_len_for_sized(self, simple_command: SimpleCommand):
        """A plain list reports its ``len()``."""
        four_items = [1, 2, 3, 4]

        assert simple_command._get_iterable_length(four_items) == 4

    def test_returns_none_for_unsized_iterables(self, simple_command: SimpleCommand):
        """A generator expression, which has no length, yields ``None``."""
        unsized = (value for value in [1, 2, 3])

        assert simple_command._get_iterable_length(unsized) is None
|
|
|
|
|
|
@pytest.mark.management
class TestTrack:
    """Checks that ``track()`` yields every item unchanged in each mode."""

    def test_with_progress_bar_disabled(self, simple_command: SimpleCommand):
        """Items pass straight through when the progress bar is off."""
        letters = ["a", "b", "c"]
        simple_command.no_progress_bar = True

        yielded = list(simple_command.track(letters, description="Test..."))

        assert yielded == letters

    def test_with_progress_bar_enabled(self, simple_command: SimpleCommand):
        """Items pass through unchanged when the progress bar is on."""
        numbers = [1, 2, 3]
        simple_command.no_progress_bar = False

        yielded = list(simple_command.track(numbers, description="Processing..."))

        assert yielded == numbers

    def test_with_explicit_total(self, simple_command: SimpleCommand):
        """A generator plus an explicit ``total`` is iterated completely."""

        def one_to_three():
            yield from [1, 2, 3]

        simple_command.no_progress_bar = False

        yielded = list(simple_command.track(one_to_three(), total=3))

        assert yielded == [1, 2, 3]

    def test_with_generator_no_total(self, simple_command: SimpleCommand):
        """A generator without ``total`` is still iterated completely."""

        def one_to_three():
            yield from [1, 2, 3]

        yielded = list(simple_command.track(one_to_three()))

        assert yielded == [1, 2, 3]

    def test_empty_iterable(self, simple_command: SimpleCommand):
        """An empty input produces an empty output."""
        yielded = list(simple_command.track([]))

        assert yielded == []

    def test_uses_queryset_count(
        self,
        simple_command: SimpleCommand,
        mock_queryset,
        mocker: MockerFixture,
    ):
        """With the bar enabled, track() sizes a queryset through ``.count()``."""
        fake_qs = mock_queryset([1, 2, 3])
        simple_command.no_progress_bar = False
        length_spy = mocker.spy(simple_command, "_get_iterable_length")

        yielded = list(simple_command.track(fake_qs))

        assert yielded == [1, 2, 3]
        # Exactly one length lookup, made with the queryset itself.
        length_spy.assert_called_once_with(fake_qs)
        assert fake_qs.count_called is True
|
|
|
|
|
|
@pytest.mark.management
class TestProcessParallel:
    """Tests for the process_parallel() method."""

    def test_sequential_processing_single_process(
        self,
        multiprocess_command: MultiprocessCommand,
    ):
        """With one process, every item yields a successful result."""
        multiprocess_command.process_count = 1
        items = [1, 2, 3, 4, 5]

        results = list(multiprocess_command.process_parallel(_double_value, items))

        assert len(results) == 5
        assert all(r.success for r in results)

        # Compare as a mapping so the check does not depend on result order.
        result_map = {r.item: r.result for r in results}
        assert result_map == {1: 2, 2: 4, 3: 6, 4: 8, 5: 10}

    def test_sequential_processing_handles_errors(
        self,
        multiprocess_command: MultiprocessCommand,
    ):
        """With one process, a raising item becomes a failed result."""
        multiprocess_command.process_count = 1
        items = [1, 2, 0, 4]  # 0 causes ZeroDivisionError

        results = list(multiprocess_command.process_parallel(_divide_ten_by, items))

        assert len(results) == 4

        successes = [r for r in results if r.success]
        failures = [r for r in results if not r.success]

        assert len(successes) == 3
        assert len(failures) == 1
        assert failures[0].item == 0
        assert isinstance(failures[0].error, ZeroDivisionError)

    def test_parallel_closes_db_connections(
        self,
        multiprocess_command: MultiprocessCommand,
        mocker: MockerFixture,
    ):
        """With two processes, ``db.connections.close_all`` is called once."""
        multiprocess_command.process_count = 2
        items = [1, 2, 3]

        mock_connections = mocker.patch(
            "documents.management.commands.base.db.connections",
        )

        results = list(multiprocess_command.process_parallel(_double_value, items))

        mock_connections.close_all.assert_called_once()
        assert len(results) == 3

    def test_parallel_processing_handles_errors(
        self,
        multiprocess_command: MultiprocessCommand,
        mocker: MockerFixture,
    ):
        """With two processes, a raising item becomes a failed result."""
        multiprocess_command.process_count = 2
        items = [1, 2, 0, 4]

        mocker.patch("documents.management.commands.base.db.connections")

        results = list(multiprocess_command.process_parallel(_divide_ten_by, items))

        failures = [r for r in results if not r.success]
        assert len(failures) == 1
        assert failures[0].item == 0

    def test_empty_items(self, multiprocess_command: MultiprocessCommand):
        """An empty item list produces no results."""
        results = list(multiprocess_command.process_parallel(_double_value, []))

        assert results == []

    def test_result_contains_original_item(
        self,
        multiprocess_command: MultiprocessCommand,
    ):
        """Each result carries the item it was computed from."""
        items = [10, 20, 30]

        results = list(multiprocess_command.process_parallel(_double_value, items))

        # Guard against a vacuous pass: with no results the loop below would
        # never execute a single assertion.
        assert len(results) == len(items)

        for result in results:
            assert result.item in items
            assert result.result == result.item * 2

    def test_sequential_path_used_for_single_process(
        self,
        multiprocess_command: MultiprocessCommand,
        mocker: MockerFixture,
    ):
        """Verify single process uses sequential path (important for testing)."""
        multiprocess_command.process_count = 1

        spy_sequential = mocker.spy(multiprocess_command, "_process_sequential")
        spy_parallel = mocker.spy(multiprocess_command, "_process_parallel")

        # Consume the iterator fully so the chosen path actually runs.
        list(multiprocess_command.process_parallel(_double_value, [1, 2, 3]))

        spy_sequential.assert_called_once()
        spy_parallel.assert_not_called()

    def test_parallel_path_used_for_multiple_processes(
        self,
        multiprocess_command: MultiprocessCommand,
        mocker: MockerFixture,
    ):
        """Verify multiple processes uses parallel path."""
        multiprocess_command.process_count = 2

        mocker.patch("documents.management.commands.base.db.connections")
        spy_sequential = mocker.spy(multiprocess_command, "_process_sequential")
        spy_parallel = mocker.spy(multiprocess_command, "_process_parallel")

        # Consume the iterator fully so the chosen path actually runs.
        list(multiprocess_command.process_parallel(_double_value, [1, 2, 3]))

        spy_parallel.assert_called_once()
        spy_sequential.assert_not_called()
|