mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2026-02-26 01:09:34 -06:00
Initial conversion to new base class
This commit is contained in:
@@ -37,6 +37,7 @@ dependencies = [
|
||||
"django-filter~=25.1",
|
||||
"django-guardian~=3.2.0",
|
||||
"django-multiselectfield~=1.0.1",
|
||||
"django-rich~=2.2.0",
|
||||
"django-soft-delete~=1.0.18",
|
||||
"django-treenode>=0.23.2",
|
||||
"djangorestframework~=3.16",
|
||||
@@ -76,7 +77,6 @@ dependencies = [
|
||||
"setproctitle~=1.3.4",
|
||||
"tika-client~=0.10.0",
|
||||
"torch~=2.10.0",
|
||||
"tqdm~=4.67.1",
|
||||
"watchfiles>=1.1.1",
|
||||
"whitenoise~=6.11",
|
||||
"whoosh-reloaded>=2.7.5",
|
||||
@@ -304,6 +304,7 @@ markers = [
|
||||
"tika: Tests requiring Tika service",
|
||||
"greenmail: Tests requiring Greenmail service",
|
||||
"date_parsing: Tests which cover date parsing from content or filename",
|
||||
"management: Tests which cover management commands/functionality",
|
||||
]
|
||||
|
||||
[tool.pytest_env]
|
||||
|
||||
316
src/documents/management/commands/base.py
Normal file
316
src/documents/management/commands/base.py
Normal file
@@ -0,0 +1,316 @@
|
||||
"""
|
||||
Base command class for Paperless-ngx management commands.
|
||||
|
||||
Provides automatic progress bar and multiprocessing support with minimal boilerplate.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from concurrent.futures import as_completed
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import ClassVar
|
||||
from typing import Generic
|
||||
from typing import TypeVar
|
||||
|
||||
from django import db
|
||||
from django.core.management import CommandError
|
||||
from django_rich.management import RichCommand
|
||||
from rich.progress import BarColumn
|
||||
from rich.progress import MofNCompleteColumn
|
||||
from rich.progress import Progress
|
||||
from rich.progress import SpinnerColumn
|
||||
from rich.progress import TextColumn
|
||||
from rich.progress import TimeElapsedColumn
|
||||
from rich.progress import TimeRemainingColumn
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from argparse import ArgumentParser
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterable
|
||||
from collections.abc import Sequence
|
||||
|
||||
T = TypeVar("T")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ProcessResult(Generic[T, R]):
|
||||
"""
|
||||
Result of processing a single item in parallel.
|
||||
|
||||
Attributes:
|
||||
item: The input item that was processed.
|
||||
result: The return value from the processing function, or None if an error occurred.
|
||||
error: The exception if processing failed, or None on success.
|
||||
"""
|
||||
|
||||
item: T
|
||||
result: R | None
|
||||
error: BaseException | None
|
||||
|
||||
@property
|
||||
def success(self) -> bool:
|
||||
"""Return True if the item was processed successfully."""
|
||||
return self.error is None
|
||||
|
||||
|
||||
class PaperlessCommand(RichCommand):
|
||||
"""
|
||||
Base command class with automatic progress bar and multiprocessing support.
|
||||
|
||||
Features are opt-in via class attributes:
|
||||
supports_progress_bar: Adds --no-progress-bar argument (default: True)
|
||||
supports_multiprocessing: Adds --processes argument (default: False)
|
||||
|
||||
Example usage:
|
||||
|
||||
class Command(PaperlessCommand):
|
||||
help = "Process all documents"
|
||||
|
||||
def handle(self, *args, **options):
|
||||
documents = Document.objects.all()
|
||||
for doc in self.track(documents, description="Processing..."):
|
||||
process_document(doc)
|
||||
|
||||
class Command(PaperlessCommand):
|
||||
help = "Regenerate thumbnails"
|
||||
supports_multiprocessing = True
|
||||
|
||||
def handle(self, *args, **options):
|
||||
ids = list(Document.objects.values_list("id", flat=True))
|
||||
for result in self.process_parallel(process_doc, ids):
|
||||
if result.error:
|
||||
self.console.print(f"[red]Failed: {result.error}[/red]")
|
||||
"""
|
||||
|
||||
supports_progress_bar: ClassVar[bool] = True
|
||||
supports_multiprocessing: ClassVar[bool] = False
|
||||
|
||||
# Instance attributes set by execute() before handle() runs
|
||||
no_progress_bar: bool
|
||||
process_count: int
|
||||
|
||||
def add_arguments(self, parser: ArgumentParser) -> None:
|
||||
"""Add arguments based on supported features."""
|
||||
super().add_arguments(parser)
|
||||
|
||||
if self.supports_progress_bar:
|
||||
parser.add_argument(
|
||||
"--no-progress-bar",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Disable the progress bar",
|
||||
)
|
||||
|
||||
if self.supports_multiprocessing:
|
||||
default_processes = max(1, (os.cpu_count() or 1) // 4)
|
||||
parser.add_argument(
|
||||
"--processes",
|
||||
default=default_processes,
|
||||
type=int,
|
||||
help=f"Number of processes to use (default: {default_processes})",
|
||||
)
|
||||
|
||||
def execute(self, *args, **options) -> int:
|
||||
"""
|
||||
Set up instance state before handle() is called.
|
||||
|
||||
This is called by Django's command infrastructure after argument parsing
|
||||
but before handle(). We use it to set instance attributes from options.
|
||||
"""
|
||||
# Set progress bar state
|
||||
if self.supports_progress_bar:
|
||||
self.no_progress_bar = options.get("no_progress_bar", False)
|
||||
else:
|
||||
self.no_progress_bar = True
|
||||
|
||||
# Set multiprocessing state
|
||||
if self.supports_multiprocessing:
|
||||
self.process_count = options.get("processes", 1)
|
||||
if self.process_count < 1:
|
||||
raise CommandError("--processes must be at least 1")
|
||||
else:
|
||||
self.process_count = 1
|
||||
|
||||
return super().execute(*args, **options)
|
||||
|
||||
def _create_progress(self, description: str) -> Progress:
|
||||
"""
|
||||
Create a configured Progress instance.
|
||||
|
||||
Args:
|
||||
description: Text to display alongside the progress bar.
|
||||
|
||||
Returns:
|
||||
A Progress instance configured with appropriate columns.
|
||||
"""
|
||||
return Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
BarColumn(),
|
||||
MofNCompleteColumn(),
|
||||
TimeElapsedColumn(),
|
||||
TimeRemainingColumn(),
|
||||
console=self.console,
|
||||
transient=False,
|
||||
)
|
||||
|
||||
def _get_iterable_length(self, iterable: Iterable[T]) -> int | None:
|
||||
"""
|
||||
Attempt to determine the length of an iterable without consuming it.
|
||||
|
||||
Tries .count() first (for Django querysets - executes SELECT COUNT(*)),
|
||||
then falls back to len() for sequences.
|
||||
|
||||
Args:
|
||||
iterable: The iterable to measure.
|
||||
|
||||
Returns:
|
||||
The length if determinable, None otherwise.
|
||||
"""
|
||||
# Django querysets have .count() which is a SELECT COUNT(*)
|
||||
# This is much more efficient than len() which evaluates the queryset
|
||||
# Note: list.count(value) requires an argument, so we catch TypeError
|
||||
if hasattr(iterable, "count") and callable(iterable.count):
|
||||
try:
|
||||
return iterable.count()
|
||||
except TypeError:
|
||||
# list.count() requires an argument, fall through to len()
|
||||
pass
|
||||
|
||||
# Fall back to len() for sequences
|
||||
try:
|
||||
return len(iterable) # type: ignore[arg-type]
|
||||
except TypeError:
|
||||
return None
|
||||
|
||||
def track(
|
||||
self,
|
||||
iterable: Iterable[T],
|
||||
*,
|
||||
description: str = "Processing...",
|
||||
total: int | None = None,
|
||||
) -> Generator[T, None, None]:
|
||||
"""
|
||||
Iterate over items with an optional progress bar.
|
||||
|
||||
Respects --no-progress-bar flag. When disabled, simply yields items
|
||||
without any progress display.
|
||||
|
||||
Args:
|
||||
iterable: The items to iterate over.
|
||||
description: Text to display alongside the progress bar.
|
||||
total: Total number of items. If None, attempts to determine
|
||||
automatically via .count() (for querysets) or len().
|
||||
|
||||
Yields:
|
||||
Items from the iterable.
|
||||
|
||||
Example:
|
||||
for doc in self.track(documents, description="Renaming..."):
|
||||
process(doc)
|
||||
"""
|
||||
if self.no_progress_bar:
|
||||
yield from iterable
|
||||
return
|
||||
|
||||
# Attempt to determine total if not provided
|
||||
if total is None:
|
||||
total = self._get_iterable_length(iterable)
|
||||
|
||||
with self._create_progress(description) as progress:
|
||||
task_id = progress.add_task(description, total=total)
|
||||
for item in iterable:
|
||||
yield item
|
||||
progress.advance(task_id)
|
||||
|
||||
def process_parallel(
|
||||
self,
|
||||
fn: Callable[[T], R],
|
||||
items: Sequence[T],
|
||||
*,
|
||||
description: str = "Processing...",
|
||||
) -> Generator[ProcessResult[T, R], None, None]:
|
||||
"""
|
||||
Process items in parallel with progress tracking.
|
||||
|
||||
When --processes=1, runs sequentially in the main process without
|
||||
spawning subprocesses. This is critical for testing, as multiprocessing
|
||||
breaks fixtures, mocks, and database transactions.
|
||||
|
||||
When --processes > 1, uses ProcessPoolExecutor and automatically closes
|
||||
database connections before spawning workers (required for PostgreSQL).
|
||||
|
||||
Args:
|
||||
fn: Function to apply to each item. Must be picklable for parallel
|
||||
execution (i.e., defined at module level, not a lambda or closure).
|
||||
items: Sequence of items to process.
|
||||
description: Text to display alongside the progress bar.
|
||||
|
||||
Yields:
|
||||
ProcessResult for each item, containing the item, result, and any error.
|
||||
|
||||
Example:
|
||||
def regenerate_thumbnail(doc_id: int) -> Path:
|
||||
...
|
||||
|
||||
for result in self.process_parallel(regenerate_thumbnail, doc_ids):
|
||||
if result.error:
|
||||
self.console.print(f"[red]Failed {result.item}[/red]")
|
||||
"""
|
||||
total = len(items)
|
||||
|
||||
if self.process_count == 1:
|
||||
# Sequential execution in main process - critical for testing
|
||||
yield from self._process_sequential(fn, items, description, total)
|
||||
else:
|
||||
# Parallel execution with ProcessPoolExecutor
|
||||
yield from self._process_parallel(fn, items, description, total)
|
||||
|
||||
def _process_sequential(
|
||||
self,
|
||||
fn: Callable[[T], R],
|
||||
items: Sequence[T],
|
||||
description: str,
|
||||
total: int,
|
||||
) -> Generator[ProcessResult[T, R], None, None]:
|
||||
"""Process items sequentially in the main process."""
|
||||
for item in self.track(items, description=description, total=total):
|
||||
try:
|
||||
result = fn(item)
|
||||
yield ProcessResult(item=item, result=result, error=None)
|
||||
except Exception as e:
|
||||
yield ProcessResult(item=item, result=None, error=e)
|
||||
|
||||
def _process_parallel(
|
||||
self,
|
||||
fn: Callable[[T], R],
|
||||
items: Sequence[T],
|
||||
description: str,
|
||||
total: int,
|
||||
) -> Generator[ProcessResult[T, R], None, None]:
|
||||
"""Process items in parallel using ProcessPoolExecutor."""
|
||||
# Close database connections before forking - required for PostgreSQL
|
||||
db.connections.close_all()
|
||||
|
||||
with self._create_progress(description) as progress:
|
||||
task_id = progress.add_task(description, total=total)
|
||||
|
||||
with ProcessPoolExecutor(max_workers=self.process_count) as executor:
|
||||
# Submit all tasks and map futures back to items
|
||||
future_to_item = {executor.submit(fn, item): item for item in items}
|
||||
|
||||
# Yield results as they complete
|
||||
for future in as_completed(future_to_item):
|
||||
item = future_to_item[future]
|
||||
try:
|
||||
result = future.result()
|
||||
yield ProcessResult(item=item, result=result, error=None)
|
||||
except Exception as e:
|
||||
yield ProcessResult(item=item, result=None, error=e)
|
||||
finally:
|
||||
progress.advance(task_id)
|
||||
@@ -1,20 +1,15 @@
|
||||
import logging
|
||||
import multiprocessing
|
||||
|
||||
import tqdm
|
||||
from django import db
|
||||
from django.conf import settings
|
||||
from django.core.management.base import BaseCommand
|
||||
|
||||
from documents.management.commands.mixins import MultiProcessMixin
|
||||
from documents.management.commands.mixins import ProgressBarMixin
|
||||
from documents.management.commands.base import PaperlessCommand
|
||||
from documents.models import Document
|
||||
from documents.tasks import update_document_content_maybe_archive_file
|
||||
|
||||
logger = logging.getLogger("paperless.management.archiver")
|
||||
|
||||
|
||||
class Command(MultiProcessMixin, ProgressBarMixin, BaseCommand):
|
||||
class Command(PaperlessCommand):
|
||||
help = (
|
||||
"Using the current classification model, assigns correspondents, tags "
|
||||
"and document types to all documents, effectively allowing you to "
|
||||
@@ -22,7 +17,10 @@ class Command(MultiProcessMixin, ProgressBarMixin, BaseCommand):
|
||||
"modified) after their initial import."
|
||||
)
|
||||
|
||||
supports_multiprocessing = True
|
||||
|
||||
def add_arguments(self, parser):
|
||||
super().add_arguments(parser)
|
||||
parser.add_argument(
|
||||
"-f",
|
||||
"--overwrite",
|
||||
@@ -44,13 +42,8 @@ class Command(MultiProcessMixin, ProgressBarMixin, BaseCommand):
|
||||
"run on this specific document."
|
||||
),
|
||||
)
|
||||
self.add_argument_progress_bar_mixin(parser)
|
||||
self.add_argument_processes_mixin(parser)
|
||||
|
||||
def handle(self, *args, **options):
|
||||
self.handle_processes_mixin(**options)
|
||||
self.handle_progress_bar_mixin(**options)
|
||||
|
||||
settings.SCRATCH_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
overwrite = options["overwrite"]
|
||||
@@ -60,35 +53,21 @@ class Command(MultiProcessMixin, ProgressBarMixin, BaseCommand):
|
||||
else:
|
||||
documents = Document.objects.all()
|
||||
|
||||
document_ids = list(
|
||||
map(
|
||||
lambda doc: doc.id,
|
||||
filter(lambda d: overwrite or not d.has_archive_version, documents),
|
||||
),
|
||||
)
|
||||
|
||||
# Note to future self: this prevents django from reusing database
|
||||
# connections between processes, which is bad and does not work
|
||||
# with postgres.
|
||||
db.connections.close_all()
|
||||
document_ids = [
|
||||
doc.id for doc in documents if overwrite or not doc.has_archive_version
|
||||
]
|
||||
|
||||
try:
|
||||
logging.getLogger().handlers[0].level = logging.ERROR
|
||||
|
||||
if self.process_count == 1:
|
||||
for doc_id in document_ids:
|
||||
update_document_content_maybe_archive_file(doc_id)
|
||||
else: # pragma: no cover
|
||||
with multiprocessing.Pool(self.process_count) as pool:
|
||||
list(
|
||||
tqdm.tqdm(
|
||||
pool.imap_unordered(
|
||||
update_document_content_maybe_archive_file,
|
||||
document_ids,
|
||||
),
|
||||
total=len(document_ids),
|
||||
disable=self.no_progress_bar,
|
||||
),
|
||||
for result in self.process_parallel(
|
||||
update_document_content_maybe_archive_file,
|
||||
document_ids,
|
||||
description="Archiving...",
|
||||
):
|
||||
if result.error:
|
||||
self.console.print(
|
||||
f"[red]Failed document {result.item}: {result.error}[/red]",
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
self.stdout.write(self.style.NOTICE("Aborting..."))
|
||||
self.console.print("[yellow]Aborting...[/yellow]")
|
||||
|
||||
@@ -1,24 +1,20 @@
|
||||
import dataclasses
|
||||
import multiprocessing
|
||||
from typing import Final
|
||||
|
||||
import rapidfuzz
|
||||
import tqdm
|
||||
from django.core.management import BaseCommand
|
||||
from django.core.management import CommandError
|
||||
|
||||
from documents.management.commands.mixins import MultiProcessMixin
|
||||
from documents.management.commands.mixins import ProgressBarMixin
|
||||
from documents.management.commands.base import PaperlessCommand
|
||||
from documents.models import Document
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
@dataclasses.dataclass(frozen=True, slots=True)
|
||||
class _WorkPackage:
|
||||
first_doc: Document
|
||||
second_doc: Document
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
@dataclasses.dataclass(frozen=True, slots=True)
|
||||
class _WorkResult:
|
||||
doc_one_pk: int
|
||||
doc_two_pk: int
|
||||
@@ -31,22 +27,23 @@ class _WorkResult:
|
||||
def _process_and_match(work: _WorkPackage) -> _WorkResult:
|
||||
"""
|
||||
Does basic processing of document content, gets the basic ratio
|
||||
and returns the result package
|
||||
and returns the result package.
|
||||
"""
|
||||
# Normalize the string some, lower case, whitespace, etc
|
||||
first_string = rapidfuzz.utils.default_process(work.first_doc.content)
|
||||
second_string = rapidfuzz.utils.default_process(work.second_doc.content)
|
||||
|
||||
# Basic matching ratio
|
||||
match = rapidfuzz.fuzz.ratio(first_string, second_string)
|
||||
|
||||
return _WorkResult(work.first_doc.pk, work.second_doc.pk, match)
|
||||
|
||||
|
||||
class Command(MultiProcessMixin, ProgressBarMixin, BaseCommand):
|
||||
class Command(PaperlessCommand):
|
||||
help = "Searches for documents where the content almost matches"
|
||||
|
||||
supports_multiprocessing = True
|
||||
|
||||
def add_arguments(self, parser):
|
||||
super().add_arguments(parser)
|
||||
parser.add_argument(
|
||||
"--ratio",
|
||||
default=85.0,
|
||||
@@ -59,91 +56,72 @@ class Command(MultiProcessMixin, ProgressBarMixin, BaseCommand):
|
||||
action="store_true",
|
||||
help="If set, one document of matches above the ratio WILL BE DELETED",
|
||||
)
|
||||
self.add_argument_progress_bar_mixin(parser)
|
||||
self.add_argument_processes_mixin(parser)
|
||||
|
||||
def handle(self, *args, **options):
|
||||
RATIO_MIN: Final[float] = 0.0
|
||||
RATIO_MAX: Final[float] = 100.0
|
||||
|
||||
self.handle_processes_mixin(**options)
|
||||
self.handle_progress_bar_mixin(**options)
|
||||
|
||||
if options["delete"]:
|
||||
self.stdout.write(
|
||||
self.style.WARNING(
|
||||
"The command is configured to delete documents. Use with caution",
|
||||
),
|
||||
self.console.print(
|
||||
"[yellow]The command is configured to delete documents. "
|
||||
"Use with caution.[/yellow]",
|
||||
)
|
||||
|
||||
opt_ratio = options["ratio"]
|
||||
checked_pairs: set[tuple[int, int]] = set()
|
||||
work_pkgs: list[_WorkPackage] = []
|
||||
|
||||
# Ratio is a float from 0.0 to 100.0
|
||||
if opt_ratio < RATIO_MIN or opt_ratio > RATIO_MAX:
|
||||
raise CommandError("The ratio must be between 0 and 100")
|
||||
|
||||
all_docs = Document.objects.all().order_by("id")
|
||||
|
||||
# Build work packages for processing
|
||||
for first_doc in all_docs:
|
||||
for second_doc in all_docs:
|
||||
# doc to doc is obviously not useful
|
||||
if first_doc.pk == second_doc.pk:
|
||||
continue
|
||||
# Skip empty documents (e.g. password-protected)
|
||||
if first_doc.content.strip() == "" or second_doc.content.strip() == "":
|
||||
continue
|
||||
# Skip matching which have already been matched together
|
||||
# doc 1 to doc 2 is the same as doc 2 to doc 1
|
||||
doc_1_to_doc_2 = (first_doc.pk, second_doc.pk)
|
||||
doc_2_to_doc_1 = doc_1_to_doc_2[::-1]
|
||||
if doc_1_to_doc_2 in checked_pairs or doc_2_to_doc_1 in checked_pairs:
|
||||
continue
|
||||
checked_pairs.update([doc_1_to_doc_2, doc_2_to_doc_1])
|
||||
# Actually something useful to work on now
|
||||
work_pkgs.append(_WorkPackage(first_doc, second_doc))
|
||||
|
||||
# Don't spin up a pool of 1 process
|
||||
if self.process_count == 1:
|
||||
results = []
|
||||
for work in tqdm.tqdm(work_pkgs, disable=self.no_progress_bar):
|
||||
results.append(_process_and_match(work))
|
||||
else: # pragma: no cover
|
||||
with multiprocessing.Pool(processes=self.process_count) as pool:
|
||||
results = list(
|
||||
tqdm.tqdm(
|
||||
pool.imap_unordered(_process_and_match, work_pkgs),
|
||||
total=len(work_pkgs),
|
||||
disable=self.no_progress_bar,
|
||||
),
|
||||
results: list[_WorkResult] = []
|
||||
for result in self.process_parallel(
|
||||
_process_and_match,
|
||||
work_pkgs,
|
||||
description="Matching...",
|
||||
):
|
||||
if result.error:
|
||||
self.console.print(
|
||||
f"[red]Failed: {result.error}[/red]",
|
||||
)
|
||||
elif result.result is not None:
|
||||
results.append(result.result)
|
||||
|
||||
# Check results
|
||||
messages = []
|
||||
maybe_delete_ids = []
|
||||
for result in sorted(results):
|
||||
if result.ratio >= opt_ratio:
|
||||
messages: list[str] = []
|
||||
maybe_delete_ids: list[int] = []
|
||||
for match_result in sorted(results):
|
||||
if match_result.ratio >= opt_ratio:
|
||||
messages.append(
|
||||
self.style.NOTICE(
|
||||
f"Document {result.doc_one_pk} fuzzy match"
|
||||
f" to {result.doc_two_pk} (confidence {result.ratio:.3f})\n",
|
||||
f"Document {match_result.doc_one_pk} fuzzy match"
|
||||
f" to {match_result.doc_two_pk}"
|
||||
f" (confidence {match_result.ratio:.3f})\n",
|
||||
),
|
||||
)
|
||||
maybe_delete_ids.append(result.doc_two_pk)
|
||||
maybe_delete_ids.append(match_result.doc_two_pk)
|
||||
|
||||
if len(messages) == 0:
|
||||
messages.append(
|
||||
self.style.SUCCESS("No matches found\n"),
|
||||
)
|
||||
self.stdout.writelines(
|
||||
messages,
|
||||
)
|
||||
messages.append(self.style.SUCCESS("No matches found\n"))
|
||||
self.stdout.writelines(messages)
|
||||
|
||||
if options["delete"]:
|
||||
self.stdout.write(
|
||||
self.style.NOTICE(
|
||||
f"Deleting {len(maybe_delete_ids)} documents based on ratio matches",
|
||||
),
|
||||
self.console.print(
|
||||
f"[yellow]Deleting {len(maybe_delete_ids)} documents "
|
||||
f"based on ratio matches[/yellow]",
|
||||
)
|
||||
Document.objects.filter(pk__in=maybe_delete_ids).delete()
|
||||
|
||||
@@ -1,25 +1,12 @@
|
||||
import logging
|
||||
|
||||
import tqdm
|
||||
from django.core.management.base import BaseCommand
|
||||
from django.db.models.signals import post_save
|
||||
|
||||
from documents.management.commands.mixins import ProgressBarMixin
|
||||
from documents.management.commands.base import PaperlessCommand
|
||||
from documents.models import Document
|
||||
|
||||
|
||||
class Command(ProgressBarMixin, BaseCommand):
|
||||
help = "This will rename all documents to match the latest filename format."
|
||||
|
||||
def add_arguments(self, parser):
|
||||
self.add_argument_progress_bar_mixin(parser)
|
||||
class Command(PaperlessCommand):
|
||||
help = "Rename all documents"
|
||||
|
||||
def handle(self, *args, **options):
|
||||
self.handle_progress_bar_mixin(**options)
|
||||
logging.getLogger().handlers[0].level = logging.ERROR
|
||||
|
||||
for document in tqdm.tqdm(
|
||||
Document.objects.all(),
|
||||
disable=self.no_progress_bar,
|
||||
):
|
||||
for document in self.track(Document.objects.all(), description="Renaming..."):
|
||||
post_save.send(Document, instance=document, created=False)
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
import logging
|
||||
|
||||
import tqdm
|
||||
from django.core.management.base import BaseCommand
|
||||
|
||||
from documents.classifier import load_classifier
|
||||
from documents.management.commands.mixins import ProgressBarMixin
|
||||
from documents.management.commands.base import PaperlessCommand
|
||||
from documents.models import Document
|
||||
from documents.signals.handlers import set_correspondent
|
||||
from documents.signals.handlers import set_document_type
|
||||
@@ -14,7 +11,7 @@ from documents.signals.handlers import set_tags
|
||||
logger = logging.getLogger("paperless.management.retagger")
|
||||
|
||||
|
||||
class Command(ProgressBarMixin, BaseCommand):
|
||||
class Command(PaperlessCommand):
|
||||
help = (
|
||||
"Using the current classification model, assigns correspondents, tags "
|
||||
"and document types to all documents, effectively allowing you to "
|
||||
@@ -23,6 +20,7 @@ class Command(ProgressBarMixin, BaseCommand):
|
||||
)
|
||||
|
||||
def add_arguments(self, parser):
|
||||
super().add_arguments(parser)
|
||||
parser.add_argument("-c", "--correspondent", default=False, action="store_true")
|
||||
parser.add_argument("-T", "--tags", default=False, action="store_true")
|
||||
parser.add_argument("-t", "--document_type", default=False, action="store_true")
|
||||
@@ -34,7 +32,7 @@ class Command(ProgressBarMixin, BaseCommand):
|
||||
action="store_true",
|
||||
help=(
|
||||
"By default this command won't try to assign a correspondent "
|
||||
"if more than one matches the document. Use this flag if "
|
||||
"if more than one matches the document. Use this flag if "
|
||||
"you'd rather it just pick the first one it finds."
|
||||
),
|
||||
)
|
||||
@@ -49,7 +47,6 @@ class Command(ProgressBarMixin, BaseCommand):
|
||||
"and tags that do not match anymore due to changed rules."
|
||||
),
|
||||
)
|
||||
self.add_argument_progress_bar_mixin(parser)
|
||||
parser.add_argument(
|
||||
"--suggest",
|
||||
default=False,
|
||||
@@ -68,8 +65,6 @@ class Command(ProgressBarMixin, BaseCommand):
|
||||
)
|
||||
|
||||
def handle(self, *args, **options):
|
||||
self.handle_progress_bar_mixin(**options)
|
||||
|
||||
if options["inbox_only"]:
|
||||
queryset = Document.objects.filter(tags__is_inbox_tag=True)
|
||||
else:
|
||||
@@ -84,7 +79,7 @@ class Command(ProgressBarMixin, BaseCommand):
|
||||
|
||||
classifier = load_classifier()
|
||||
|
||||
for document in tqdm.tqdm(documents, disable=self.no_progress_bar):
|
||||
for document in self.track(documents, description="Retagging..."):
|
||||
if options["correspondent"]:
|
||||
set_correspondent(
|
||||
sender=None,
|
||||
@@ -122,6 +117,7 @@ class Command(ProgressBarMixin, BaseCommand):
|
||||
stdout=self.stdout,
|
||||
style_func=self.style,
|
||||
)
|
||||
|
||||
if options["storage_path"]:
|
||||
set_storage_path(
|
||||
sender=None,
|
||||
|
||||
@@ -1,43 +1,39 @@
|
||||
import logging
|
||||
import multiprocessing
|
||||
import shutil
|
||||
|
||||
import tqdm
|
||||
from django import db
|
||||
from django.core.management.base import BaseCommand
|
||||
|
||||
from documents.management.commands.mixins import MultiProcessMixin
|
||||
from documents.management.commands.mixins import ProgressBarMixin
|
||||
from documents.management.commands.base import PaperlessCommand
|
||||
from documents.models import Document
|
||||
from documents.parsers import get_parser_class_for_mime_type
|
||||
|
||||
|
||||
def _process_document(doc_id) -> None:
|
||||
def _process_document(doc_id: int) -> None:
|
||||
document: Document = Document.objects.get(id=doc_id)
|
||||
parser_class = get_parser_class_for_mime_type(document.mime_type)
|
||||
|
||||
if parser_class:
|
||||
parser = parser_class(logging_group=None)
|
||||
else:
|
||||
if parser_class is None:
|
||||
print(f"{document} No parser for mime type {document.mime_type}") # noqa: T201
|
||||
return
|
||||
|
||||
parser = parser_class(logging_group=None)
|
||||
|
||||
try:
|
||||
thumb = parser.get_thumbnail(
|
||||
document.source_path,
|
||||
document.mime_type,
|
||||
document.get_public_filename(),
|
||||
)
|
||||
|
||||
shutil.move(thumb, document.thumbnail_path)
|
||||
finally:
|
||||
parser.cleanup()
|
||||
|
||||
|
||||
class Command(MultiProcessMixin, ProgressBarMixin, BaseCommand):
|
||||
class Command(PaperlessCommand):
|
||||
help = "This will regenerate the thumbnails for all documents."
|
||||
|
||||
supports_multiprocessing = True
|
||||
|
||||
def add_arguments(self, parser) -> None:
|
||||
super().add_arguments(parser)
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--document",
|
||||
@@ -49,36 +45,23 @@ class Command(MultiProcessMixin, ProgressBarMixin, BaseCommand):
|
||||
"run on this specific document."
|
||||
),
|
||||
)
|
||||
self.add_argument_progress_bar_mixin(parser)
|
||||
self.add_argument_processes_mixin(parser)
|
||||
|
||||
def handle(self, *args, **options):
|
||||
logging.getLogger().handlers[0].level = logging.ERROR
|
||||
|
||||
self.handle_processes_mixin(**options)
|
||||
self.handle_progress_bar_mixin(**options)
|
||||
|
||||
if options["document"]:
|
||||
documents = Document.objects.filter(pk=options["document"])
|
||||
else:
|
||||
documents = Document.objects.all()
|
||||
|
||||
ids = [doc.id for doc in documents]
|
||||
ids = list(documents.values_list("id", flat=True))
|
||||
|
||||
# Note to future self: this prevents django from reusing database
|
||||
# connections between processes, which is bad and does not work
|
||||
# with postgres.
|
||||
db.connections.close_all()
|
||||
|
||||
if self.process_count == 1:
|
||||
for doc_id in ids:
|
||||
_process_document(doc_id)
|
||||
else: # pragma: no cover
|
||||
with multiprocessing.Pool(processes=self.process_count) as pool:
|
||||
list(
|
||||
tqdm.tqdm(
|
||||
pool.imap_unordered(_process_document, ids),
|
||||
total=len(ids),
|
||||
disable=self.no_progress_bar,
|
||||
),
|
||||
for result in self.process_parallel(
|
||||
_process_document,
|
||||
ids,
|
||||
description="Regenerating thumbnails...",
|
||||
):
|
||||
if result.error:
|
||||
self.console.print(
|
||||
f"[red]Failed document {result.item}: {result.error}[/red]",
|
||||
)
|
||||
|
||||
@@ -1,27 +1,21 @@
|
||||
from auditlog.models import LogEntry
|
||||
from django.core.management.base import BaseCommand
|
||||
from django.db import transaction
|
||||
from tqdm import tqdm
|
||||
|
||||
from documents.management.commands.mixins import ProgressBarMixin
|
||||
from documents.management.commands.base import PaperlessCommand
|
||||
|
||||
|
||||
class Command(BaseCommand, ProgressBarMixin):
|
||||
"""
|
||||
Prune the audit logs of objects that no longer exist.
|
||||
"""
|
||||
class Command(PaperlessCommand):
|
||||
"""Prune the audit logs of objects that no longer exist."""
|
||||
|
||||
help = "Prunes the audit logs of objects that no longer exist."
|
||||
|
||||
def add_arguments(self, parser):
|
||||
self.add_argument_progress_bar_mixin(parser)
|
||||
|
||||
def handle(self, **options):
|
||||
self.handle_progress_bar_mixin(**options)
|
||||
def handle(self, *args, **options):
|
||||
with transaction.atomic():
|
||||
for log_entry in tqdm(LogEntry.objects.all(), disable=self.no_progress_bar):
|
||||
for log_entry in self.track(
|
||||
LogEntry.objects.all(),
|
||||
description="Pruning audit logs...",
|
||||
):
|
||||
model_class = log_entry.content_type.model_class()
|
||||
# use global_objects for SoftDeleteModel
|
||||
objects = (
|
||||
model_class.global_objects
|
||||
if hasattr(model_class, "global_objects")
|
||||
@@ -32,8 +26,8 @@ class Command(BaseCommand, ProgressBarMixin):
|
||||
and not objects.filter(pk=log_entry.object_id).exists()
|
||||
):
|
||||
log_entry.delete()
|
||||
tqdm.write(
|
||||
self.style.NOTICE(
|
||||
f"Deleted audit log entry for {model_class.__name__} #{log_entry.object_id}",
|
||||
),
|
||||
self.console.print(
|
||||
f"Deleted audit log entry for "
|
||||
f"{model_class.__name__} #{log_entry.object_id}",
|
||||
style="yellow",
|
||||
)
|
||||
|
||||
0
src/documents/tests/management/__init__.py
Normal file
0
src/documents/tests/management/__init__.py
Normal file
579
src/documents/tests/management/test_management_base_cmd.py
Normal file
579
src/documents/tests/management/test_management_base_cmd.py
Normal file
@@ -0,0 +1,579 @@
|
||||
"""Tests for PaperlessCommand base class."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import ClassVar
|
||||
|
||||
import pytest
|
||||
from django.core.management import CommandError
|
||||
from rich.console import Console
|
||||
|
||||
from documents.management.commands.base import PaperlessCommand
|
||||
from documents.management.commands.base import ProcessResult
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
|
||||
# --- Test Commands ---
|
||||
# These simulate real command implementations for testing
|
||||
|
||||
|
||||
class SimpleCommand(PaperlessCommand):
|
||||
"""Command with default settings (progress bar, no multiprocessing)."""
|
||||
|
||||
help = "Simple test command"
|
||||
|
||||
def handle(self, *args, **options):
|
||||
items = list(range(5))
|
||||
results = []
|
||||
for item in self.track(items, description="Processing..."):
|
||||
results.append(item * 2)
|
||||
self.stdout.write(f"Results: {results}")
|
||||
|
||||
|
||||
class NoProgressBarCommand(PaperlessCommand):
|
||||
"""Command with progress bar disabled."""
|
||||
|
||||
help = "No progress bar command"
|
||||
supports_progress_bar = False
|
||||
|
||||
def handle(self, *args, **options):
|
||||
items = list(range(3))
|
||||
for item in self.track(items):
|
||||
pass
|
||||
self.stdout.write("Done")
|
||||
|
||||
|
||||
class MultiprocessCommand(PaperlessCommand):
|
||||
"""Command with multiprocessing support."""
|
||||
|
||||
help = "Multiprocess test command"
|
||||
supports_multiprocessing = True
|
||||
|
||||
def handle(self, *args, **options):
|
||||
items = list(range(5))
|
||||
results = []
|
||||
for result in self.process_parallel(
|
||||
_double_value,
|
||||
items,
|
||||
description="Processing...",
|
||||
):
|
||||
results.append(result)
|
||||
successes = sum(1 for r in results if r.success)
|
||||
self.stdout.write(f"Successes: {successes}")
|
||||
|
||||
|
||||
# --- Helper Functions for Multiprocessing ---
|
||||
# Must be at module level to be picklable
|
||||
|
||||
|
||||
def _double_value(x: int) -> int:
|
||||
"""Double the input value."""
|
||||
return x * 2
|
||||
|
||||
|
||||
def _divide_ten_by(x: int) -> float:
|
||||
"""Divide 10 by x. Raises ZeroDivisionError if x is 0."""
|
||||
return 10 / x
|
||||
|
||||
|
||||
# --- Fixtures ---
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def console() -> Console:
|
||||
"""Create a non-interactive console for testing."""
|
||||
return Console(force_terminal=False, force_interactive=False)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def simple_command(console: Console) -> SimpleCommand:
|
||||
"""Create a SimpleCommand instance configured for testing."""
|
||||
command = SimpleCommand()
|
||||
command.stdout = io.StringIO()
|
||||
command.stderr = io.StringIO()
|
||||
command.console = console
|
||||
command.no_progress_bar = True
|
||||
command.process_count = 1
|
||||
return command
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def multiprocess_command(console: Console) -> MultiprocessCommand:
|
||||
"""Create a MultiprocessCommand instance configured for testing."""
|
||||
command = MultiprocessCommand()
|
||||
command.stdout = io.StringIO()
|
||||
command.stderr = io.StringIO()
|
||||
command.console = console
|
||||
command.no_progress_bar = True
|
||||
command.process_count = 1
|
||||
return command
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_queryset():
|
||||
"""
|
||||
Create a mock Django QuerySet that tracks method calls.
|
||||
|
||||
This verifies we use .count() instead of len() for querysets.
|
||||
"""
|
||||
|
||||
class MockQuerySet:
|
||||
def __init__(self, items: list):
|
||||
self._items = items
|
||||
self.count_called = False
|
||||
|
||||
def count(self) -> int:
|
||||
self.count_called = True
|
||||
return len(self._items)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._items)
|
||||
|
||||
def __len__(self):
|
||||
raise AssertionError("len() should not be called on querysets")
|
||||
|
||||
return MockQuerySet
|
||||
|
||||
|
||||
# --- Test Classes ---
|
||||
|
||||
|
||||
@pytest.mark.management
|
||||
class TestProcessResult:
|
||||
"""Tests for the ProcessResult dataclass."""
|
||||
|
||||
def test_success_result(self):
|
||||
result = ProcessResult(item=1, result=2, error=None)
|
||||
|
||||
assert result.item == 1
|
||||
assert result.result == 2
|
||||
assert result.error is None
|
||||
assert result.success is True
|
||||
|
||||
def test_error_result(self):
|
||||
error = ValueError("test error")
|
||||
result = ProcessResult(item=1, result=None, error=error)
|
||||
|
||||
assert result.item == 1
|
||||
assert result.result is None
|
||||
assert result.error is error
|
||||
assert result.success is False
|
||||
|
||||
|
||||
@pytest.mark.management
|
||||
class TestPaperlessCommandArguments:
|
||||
"""Tests for argument parsing behavior."""
|
||||
|
||||
def test_progress_bar_argument_added_by_default(self):
|
||||
command = SimpleCommand()
|
||||
parser = command.create_parser("manage.py", "simple")
|
||||
|
||||
options = parser.parse_args(["--no-progress-bar"])
|
||||
assert options.no_progress_bar is True
|
||||
|
||||
options = parser.parse_args([])
|
||||
assert options.no_progress_bar is False
|
||||
|
||||
def test_progress_bar_argument_not_added_when_disabled(self):
|
||||
command = NoProgressBarCommand()
|
||||
parser = command.create_parser("manage.py", "noprogress")
|
||||
|
||||
options = parser.parse_args([])
|
||||
assert not hasattr(options, "no_progress_bar")
|
||||
|
||||
def test_processes_argument_added_when_multiprocessing_enabled(self):
|
||||
command = MultiprocessCommand()
|
||||
parser = command.create_parser("manage.py", "multiprocess")
|
||||
|
||||
options = parser.parse_args(["--processes", "4"])
|
||||
assert options.processes == 4
|
||||
|
||||
options = parser.parse_args([])
|
||||
assert options.processes >= 1
|
||||
|
||||
def test_processes_argument_not_added_when_multiprocessing_disabled(self):
|
||||
command = SimpleCommand()
|
||||
parser = command.create_parser("manage.py", "simple")
|
||||
|
||||
options = parser.parse_args([])
|
||||
assert not hasattr(options, "processes")
|
||||
|
||||
|
||||
@pytest.mark.management
|
||||
class TestPaperlessCommandExecute:
|
||||
"""Tests for the execute() setup behavior."""
|
||||
|
||||
@pytest.fixture
|
||||
def base_options(self) -> dict:
|
||||
"""Base options required for execute()."""
|
||||
return {
|
||||
"verbosity": 1,
|
||||
"no_color": True,
|
||||
"force_color": False,
|
||||
"skip_checks": True,
|
||||
}
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("no_progress_bar_flag", "expected"),
|
||||
[
|
||||
pytest.param(False, False, id="progress-bar-enabled"),
|
||||
pytest.param(True, True, id="progress-bar-disabled"),
|
||||
],
|
||||
)
|
||||
def test_no_progress_bar_state_set(
|
||||
self,
|
||||
base_options: dict,
|
||||
*,
|
||||
no_progress_bar_flag: bool,
|
||||
expected: bool,
|
||||
):
|
||||
command = SimpleCommand()
|
||||
command.stdout = io.StringIO()
|
||||
command.stderr = io.StringIO()
|
||||
|
||||
options = {**base_options, "no_progress_bar": no_progress_bar_flag}
|
||||
command.execute(**options)
|
||||
|
||||
assert command.no_progress_bar is expected
|
||||
|
||||
def test_no_progress_bar_always_true_when_not_supported(self, base_options: dict):
|
||||
command = NoProgressBarCommand()
|
||||
command.stdout = io.StringIO()
|
||||
command.stderr = io.StringIO()
|
||||
|
||||
command.execute(**base_options)
|
||||
|
||||
assert command.no_progress_bar is True
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("processes", "expected"),
|
||||
[
|
||||
pytest.param(1, 1, id="single-process"),
|
||||
pytest.param(4, 4, id="four-processes"),
|
||||
pytest.param(8, 8, id="eight-processes"),
|
||||
],
|
||||
)
|
||||
def test_process_count_set(
|
||||
self,
|
||||
base_options: dict,
|
||||
processes: int,
|
||||
expected: int,
|
||||
):
|
||||
command = MultiprocessCommand()
|
||||
command.stdout = io.StringIO()
|
||||
command.stderr = io.StringIO()
|
||||
|
||||
options = {**base_options, "processes": processes, "no_progress_bar": True}
|
||||
command.execute(**options)
|
||||
|
||||
assert command.process_count == expected
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_count",
|
||||
[
|
||||
pytest.param(0, id="zero"),
|
||||
pytest.param(-1, id="negative"),
|
||||
],
|
||||
)
|
||||
def test_process_count_validation_rejects_invalid(
|
||||
self,
|
||||
base_options: dict,
|
||||
invalid_count: int,
|
||||
):
|
||||
command = MultiprocessCommand()
|
||||
command.stdout = io.StringIO()
|
||||
command.stderr = io.StringIO()
|
||||
|
||||
options = {**base_options, "processes": invalid_count, "no_progress_bar": True}
|
||||
|
||||
with pytest.raises(CommandError, match="--processes must be at least 1"):
|
||||
command.execute(**options)
|
||||
|
||||
def test_process_count_defaults_to_one_when_not_supported(self, base_options: dict):
|
||||
command = SimpleCommand()
|
||||
command.stdout = io.StringIO()
|
||||
command.stderr = io.StringIO()
|
||||
|
||||
options = {**base_options, "no_progress_bar": True}
|
||||
command.execute(**options)
|
||||
|
||||
assert command.process_count == 1
|
||||
|
||||
|
||||
@pytest.mark.management
|
||||
class TestGetIterableLength:
|
||||
"""Tests for the _get_iterable_length() method."""
|
||||
|
||||
def test_uses_count_method_for_querysets(
|
||||
self,
|
||||
simple_command: SimpleCommand,
|
||||
mock_queryset,
|
||||
):
|
||||
"""Should use .count() for Django querysets (SELECT COUNT(*))."""
|
||||
queryset = mock_queryset([1, 2, 3, 4, 5])
|
||||
|
||||
result = simple_command._get_iterable_length(queryset)
|
||||
|
||||
assert result == 5
|
||||
assert queryset.count_called is True
|
||||
|
||||
def test_falls_back_to_len_for_sequences(self, simple_command: SimpleCommand):
|
||||
"""Should use len() for regular sequences without .count()."""
|
||||
items = [1, 2, 3, 4]
|
||||
|
||||
result = simple_command._get_iterable_length(items)
|
||||
|
||||
assert result == 4
|
||||
|
||||
def test_returns_none_for_generators(self, simple_command: SimpleCommand):
|
||||
"""Should return None for iterables without len() or count()."""
|
||||
|
||||
def gen():
|
||||
yield from [1, 2, 3]
|
||||
|
||||
result = simple_command._get_iterable_length(gen())
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_handles_count_not_callable(self, simple_command: SimpleCommand):
|
||||
"""Should skip .count if it's not callable (edge case)."""
|
||||
|
||||
class WeirdObject:
|
||||
count = 42 # Attribute, not method
|
||||
|
||||
def __len__(self):
|
||||
return 10
|
||||
|
||||
result = simple_command._get_iterable_length(WeirdObject())
|
||||
|
||||
assert result == 10
|
||||
|
||||
def test_handles_list_count_method(self, simple_command: SimpleCommand):
|
||||
"""list.count(value) requires an argument, should fall back to len()."""
|
||||
items = [1, 2, 3, 4, 5]
|
||||
|
||||
# list has a .count() method but it requires an argument
|
||||
# _get_iterable_length should catch TypeError and use len() instead
|
||||
result = simple_command._get_iterable_length(items)
|
||||
|
||||
assert result == 5
|
||||
|
||||
|
||||
@pytest.mark.management
|
||||
class TestTrack:
|
||||
"""Tests for the track() method."""
|
||||
|
||||
def test_yields_all_items(self, simple_command: SimpleCommand):
|
||||
items = [1, 2, 3, 4, 5]
|
||||
|
||||
result = list(simple_command.track(items))
|
||||
|
||||
assert result == items
|
||||
|
||||
def test_with_progress_bar_disabled(self, simple_command: SimpleCommand):
|
||||
simple_command.no_progress_bar = True
|
||||
items = ["a", "b", "c"]
|
||||
|
||||
result = list(simple_command.track(items, description="Test..."))
|
||||
|
||||
assert result == items
|
||||
|
||||
def test_with_progress_bar_enabled(self, simple_command: SimpleCommand):
|
||||
simple_command.no_progress_bar = False
|
||||
items = [1, 2, 3]
|
||||
|
||||
result = list(simple_command.track(items, description="Processing..."))
|
||||
|
||||
assert result == items
|
||||
|
||||
def test_with_explicit_total(self, simple_command: SimpleCommand):
|
||||
simple_command.no_progress_bar = False
|
||||
|
||||
def gen():
|
||||
yield from [1, 2, 3]
|
||||
|
||||
result = list(simple_command.track(gen(), total=3))
|
||||
|
||||
assert result == [1, 2, 3]
|
||||
|
||||
def test_with_generator_no_total(self, simple_command: SimpleCommand):
|
||||
def gen():
|
||||
yield from [1, 2, 3]
|
||||
|
||||
result = list(simple_command.track(gen()))
|
||||
|
||||
assert result == [1, 2, 3]
|
||||
|
||||
def test_empty_iterable(self, simple_command: SimpleCommand):
|
||||
result = list(simple_command.track([]))
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_uses_queryset_count(
|
||||
self,
|
||||
simple_command: SimpleCommand,
|
||||
mock_queryset,
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
"""Verify track() uses .count() for querysets."""
|
||||
simple_command.no_progress_bar = False
|
||||
queryset = mock_queryset([1, 2, 3])
|
||||
|
||||
spy = mocker.spy(simple_command, "_get_iterable_length")
|
||||
|
||||
result = list(simple_command.track(queryset))
|
||||
|
||||
assert result == [1, 2, 3]
|
||||
spy.assert_called_once_with(queryset)
|
||||
assert queryset.count_called is True
|
||||
|
||||
|
||||
@pytest.mark.management
|
||||
class TestProcessParallel:
|
||||
"""Tests for the process_parallel() method."""
|
||||
|
||||
def test_sequential_processing_single_process(
|
||||
self,
|
||||
multiprocess_command: MultiprocessCommand,
|
||||
):
|
||||
multiprocess_command.process_count = 1
|
||||
items = [1, 2, 3, 4, 5]
|
||||
|
||||
results = list(multiprocess_command.process_parallel(_double_value, items))
|
||||
|
||||
assert len(results) == 5
|
||||
assert all(r.success for r in results)
|
||||
|
||||
result_map = {r.item: r.result for r in results}
|
||||
assert result_map == {1: 2, 2: 4, 3: 6, 4: 8, 5: 10}
|
||||
|
||||
def test_sequential_processing_handles_errors(
|
||||
self,
|
||||
multiprocess_command: MultiprocessCommand,
|
||||
):
|
||||
multiprocess_command.process_count = 1
|
||||
items = [1, 2, 0, 4] # 0 causes ZeroDivisionError
|
||||
|
||||
results = list(multiprocess_command.process_parallel(_divide_ten_by, items))
|
||||
|
||||
assert len(results) == 4
|
||||
|
||||
successes = [r for r in results if r.success]
|
||||
failures = [r for r in results if not r.success]
|
||||
|
||||
assert len(successes) == 3
|
||||
assert len(failures) == 1
|
||||
assert failures[0].item == 0
|
||||
assert isinstance(failures[0].error, ZeroDivisionError)
|
||||
|
||||
def test_parallel_closes_db_connections(
|
||||
self,
|
||||
multiprocess_command: MultiprocessCommand,
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
multiprocess_command.process_count = 2
|
||||
items = [1, 2, 3]
|
||||
|
||||
mock_connections = mocker.patch(
|
||||
"documents.management.commands.base.db.connections",
|
||||
)
|
||||
|
||||
results = list(multiprocess_command.process_parallel(_double_value, items))
|
||||
|
||||
mock_connections.close_all.assert_called_once()
|
||||
assert len(results) == 3
|
||||
|
||||
def test_parallel_processing_handles_errors(
|
||||
self,
|
||||
multiprocess_command: MultiprocessCommand,
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
multiprocess_command.process_count = 2
|
||||
items = [1, 2, 0, 4]
|
||||
|
||||
mocker.patch("documents.management.commands.base.db.connections")
|
||||
|
||||
results = list(multiprocess_command.process_parallel(_divide_ten_by, items))
|
||||
|
||||
failures = [r for r in results if not r.success]
|
||||
assert len(failures) == 1
|
||||
assert failures[0].item == 0
|
||||
|
||||
def test_empty_items(self, multiprocess_command: MultiprocessCommand):
|
||||
results = list(multiprocess_command.process_parallel(_double_value, []))
|
||||
|
||||
assert results == []
|
||||
|
||||
def test_result_contains_original_item(
|
||||
self,
|
||||
multiprocess_command: MultiprocessCommand,
|
||||
):
|
||||
items = [10, 20, 30]
|
||||
|
||||
results = list(multiprocess_command.process_parallel(_double_value, items))
|
||||
|
||||
for result in results:
|
||||
assert result.item in items
|
||||
assert result.result == result.item * 2
|
||||
|
||||
def test_sequential_path_used_for_single_process(
|
||||
self,
|
||||
multiprocess_command: MultiprocessCommand,
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
"""Verify single process uses sequential path (important for testing)."""
|
||||
multiprocess_command.process_count = 1
|
||||
|
||||
spy_sequential = mocker.spy(multiprocess_command, "_process_sequential")
|
||||
spy_parallel = mocker.spy(multiprocess_command, "_process_parallel")
|
||||
|
||||
list(multiprocess_command.process_parallel(_double_value, [1, 2, 3]))
|
||||
|
||||
spy_sequential.assert_called_once()
|
||||
spy_parallel.assert_not_called()
|
||||
|
||||
def test_parallel_path_used_for_multiple_processes(
|
||||
self,
|
||||
multiprocess_command: MultiprocessCommand,
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
"""Verify multiple processes uses parallel path."""
|
||||
multiprocess_command.process_count = 2
|
||||
|
||||
mocker.patch("documents.management.commands.base.db.connections")
|
||||
spy_sequential = mocker.spy(multiprocess_command, "_process_sequential")
|
||||
spy_parallel = mocker.spy(multiprocess_command, "_process_parallel")
|
||||
|
||||
list(multiprocess_command.process_parallel(_double_value, [1, 2, 3]))
|
||||
|
||||
spy_parallel.assert_called_once()
|
||||
spy_sequential.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.management
|
||||
class TestClassVariableDefaults:
|
||||
"""Tests for class variable default behavior."""
|
||||
|
||||
def test_default_supports_progress_bar(self):
|
||||
assert PaperlessCommand.supports_progress_bar is True
|
||||
|
||||
def test_default_supports_multiprocessing(self):
|
||||
assert PaperlessCommand.supports_multiprocessing is False
|
||||
|
||||
def test_subclass_can_override_progress_bar(self):
|
||||
class NoProgressCommand(PaperlessCommand):
|
||||
supports_progress_bar: ClassVar[bool] = False
|
||||
|
||||
assert NoProgressCommand.supports_progress_bar is False
|
||||
assert PaperlessCommand.supports_progress_bar is True
|
||||
|
||||
def test_subclass_can_override_multiprocessing(self):
|
||||
class ParallelCommand(PaperlessCommand):
|
||||
supports_multiprocessing: ClassVar[bool] = True
|
||||
|
||||
assert ParallelCommand.supports_multiprocessing is True
|
||||
assert PaperlessCommand.supports_multiprocessing is False
|
||||
@@ -4,6 +4,7 @@ from io import StringIO
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from auditlog.models import LogEntry
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from django.core.management import call_command
|
||||
@@ -19,6 +20,7 @@ from documents.tests.utils import FileSystemAssertsMixin
|
||||
sample_file: Path = Path(__file__).parent / "samples" / "simple.pdf"
|
||||
|
||||
|
||||
@pytest.mark.management
|
||||
@override_settings(FILENAME_FORMAT="{correspondent}/{title}")
|
||||
class TestArchiver(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
|
||||
def make_models(self):
|
||||
@@ -94,6 +96,7 @@ class TestArchiver(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
|
||||
self.assertEqual(doc2.archive_filename, "document_01.pdf")
|
||||
|
||||
|
||||
@pytest.mark.management
|
||||
class TestMakeIndex(TestCase):
|
||||
@mock.patch("documents.management.commands.document_index.index_reindex")
|
||||
def test_reindex(self, m) -> None:
|
||||
@@ -106,6 +109,7 @@ class TestMakeIndex(TestCase):
|
||||
m.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.management
|
||||
class TestRenamer(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
|
||||
@override_settings(FILENAME_FORMAT="")
|
||||
def test_rename(self) -> None:
|
||||
@@ -140,6 +144,7 @@ class TestCreateClassifier(TestCase):
|
||||
m.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.management
|
||||
class TestSanityChecker(DirectoriesMixin, TestCase):
|
||||
def test_no_issues(self) -> None:
|
||||
with self.assertLogs() as capture:
|
||||
@@ -165,6 +170,7 @@ class TestSanityChecker(DirectoriesMixin, TestCase):
|
||||
self.assertIn("Checksum mismatch. Stored: abc, actual:", capture.output[1])
|
||||
|
||||
|
||||
@pytest.mark.management
|
||||
class TestConvertMariaDBUUID(TestCase):
|
||||
@mock.patch("django.db.connection.schema_editor")
|
||||
def test_convert(self, m) -> None:
|
||||
@@ -178,6 +184,7 @@ class TestConvertMariaDBUUID(TestCase):
|
||||
self.assertIn("Successfully converted", stdout.getvalue())
|
||||
|
||||
|
||||
@pytest.mark.management
|
||||
class TestPruneAuditLogs(TestCase):
|
||||
def test_prune_audit_logs(self) -> None:
|
||||
LogEntry.objects.create(
|
||||
|
||||
@@ -577,6 +577,7 @@ class TestTagsFromPath:
|
||||
assert len(tag_ids) == 0
|
||||
|
||||
|
||||
@pytest.mark.management
|
||||
class TestCommandValidation:
|
||||
"""Tests for command argument validation."""
|
||||
|
||||
@@ -605,6 +606,7 @@ class TestCommandValidation:
|
||||
cmd.handle(directory=str(sample_pdf), oneshot=True, testing=False)
|
||||
|
||||
|
||||
@pytest.mark.management
|
||||
@pytest.mark.usefixtures("mock_supported_extensions")
|
||||
class TestCommandOneshot:
|
||||
"""Tests for oneshot mode."""
|
||||
@@ -775,6 +777,7 @@ def start_consumer(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.management
|
||||
@pytest.mark.django_db
|
||||
class TestCommandWatch:
|
||||
"""Integration tests for the watch loop."""
|
||||
@@ -896,6 +899,7 @@ class TestCommandWatch:
|
||||
assert not thread.is_alive()
|
||||
|
||||
|
||||
@pytest.mark.management
|
||||
@pytest.mark.django_db
|
||||
class TestCommandWatchPolling:
|
||||
"""Tests for polling mode."""
|
||||
@@ -928,6 +932,7 @@ class TestCommandWatchPolling:
|
||||
mock_consume_file_delay.delay.assert_called()
|
||||
|
||||
|
||||
@pytest.mark.management
|
||||
@pytest.mark.django_db
|
||||
class TestCommandWatchRecursive:
|
||||
"""Tests for recursive watching."""
|
||||
@@ -991,6 +996,7 @@ class TestCommandWatchRecursive:
|
||||
assert len(overrides.tag_ids) == 2
|
||||
|
||||
|
||||
@pytest.mark.management
|
||||
@pytest.mark.django_db
|
||||
class TestCommandWatchEdgeCases:
|
||||
"""Tests for edge cases and error handling."""
|
||||
|
||||
@@ -7,6 +7,7 @@ from pathlib import Path
|
||||
from unittest import mock
|
||||
from zipfile import ZipFile
|
||||
|
||||
import pytest
|
||||
from allauth.socialaccount.models import SocialAccount
|
||||
from allauth.socialaccount.models import SocialApp
|
||||
from allauth.socialaccount.models import SocialToken
|
||||
@@ -45,6 +46,7 @@ from documents.tests.utils import paperless_environment
|
||||
from paperless_mail.models import MailAccount
|
||||
|
||||
|
||||
@pytest.mark.management
|
||||
class TestExportImport(
|
||||
DirectoriesMixin,
|
||||
FileSystemAssertsMixin,
|
||||
@@ -846,6 +848,7 @@ class TestExportImport(
|
||||
self.assertEqual(Document.objects.all().count(), 4)
|
||||
|
||||
|
||||
@pytest.mark.management
|
||||
class TestCryptExportImport(
|
||||
DirectoriesMixin,
|
||||
FileSystemAssertsMixin,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from io import StringIO
|
||||
|
||||
import pytest
|
||||
from django.core.management import CommandError
|
||||
from django.core.management import call_command
|
||||
from django.test import TestCase
|
||||
@@ -7,6 +8,7 @@ from django.test import TestCase
|
||||
from documents.models import Document
|
||||
|
||||
|
||||
@pytest.mark.management
|
||||
class TestFuzzyMatchCommand(TestCase):
|
||||
MSG_REGEX = r"Document \d fuzzy match to \d \(confidence \d\d\.\d\d\d\)"
|
||||
|
||||
@@ -49,19 +51,6 @@ class TestFuzzyMatchCommand(TestCase):
|
||||
self.call_command("--ratio", "101")
|
||||
self.assertIn("The ratio must be between 0 and 100", str(e.exception))
|
||||
|
||||
def test_invalid_process_count(self) -> None:
|
||||
"""
|
||||
GIVEN:
|
||||
- Invalid process count less than 0 above upper
|
||||
WHEN:
|
||||
- Command is called
|
||||
THEN:
|
||||
- Error is raised indicating issue
|
||||
"""
|
||||
with self.assertRaises(CommandError) as e:
|
||||
self.call_command("--processes", "0")
|
||||
self.assertIn("There must be at least 1 process", str(e.exception))
|
||||
|
||||
def test_no_matches(self) -> None:
|
||||
"""
|
||||
GIVEN:
|
||||
@@ -194,7 +183,7 @@ class TestFuzzyMatchCommand(TestCase):
|
||||
|
||||
self.assertEqual(Document.objects.count(), 3)
|
||||
|
||||
stdout, _ = self.call_command("--delete")
|
||||
stdout, _ = self.call_command("--delete", "--no-progress-bar")
|
||||
|
||||
self.assertIn(
|
||||
"The command is configured to delete documents. Use with caution",
|
||||
|
||||
@@ -4,6 +4,7 @@ from io import StringIO
|
||||
from pathlib import Path
|
||||
from zipfile import ZipFile
|
||||
|
||||
import pytest
|
||||
from django.contrib.auth.models import User
|
||||
from django.core.management import call_command
|
||||
from django.core.management.base import CommandError
|
||||
@@ -18,6 +19,7 @@ from documents.tests.utils import FileSystemAssertsMixin
|
||||
from documents.tests.utils import SampleDirMixin
|
||||
|
||||
|
||||
@pytest.mark.management
|
||||
class TestCommandImport(
|
||||
DirectoriesMixin,
|
||||
FileSystemAssertsMixin,
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import pytest
|
||||
from django.core.management import call_command
|
||||
from django.core.management.base import CommandError
|
||||
from django.test import TestCase
|
||||
@@ -10,6 +11,7 @@ from documents.models import Tag
|
||||
from documents.tests.utils import DirectoriesMixin
|
||||
|
||||
|
||||
@pytest.mark.management
|
||||
class TestRetagger(DirectoriesMixin, TestCase):
|
||||
def make_models(self) -> None:
|
||||
self.sp1 = StoragePath.objects.create(
|
||||
|
||||
@@ -2,6 +2,7 @@ import os
|
||||
from io import StringIO
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from django.contrib.auth.models import User
|
||||
from django.core.management import call_command
|
||||
from django.test import TestCase
|
||||
@@ -9,6 +10,7 @@ from django.test import TestCase
|
||||
from documents.tests.utils import DirectoriesMixin
|
||||
|
||||
|
||||
@pytest.mark.management
|
||||
class TestManageSuperUser(DirectoriesMixin, TestCase):
|
||||
def call_command(self, environ):
|
||||
out = StringIO()
|
||||
|
||||
@@ -2,6 +2,7 @@ import shutil
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from django.core.management import call_command
|
||||
from django.test import TestCase
|
||||
|
||||
@@ -12,6 +13,7 @@ from documents.tests.utils import DirectoriesMixin
|
||||
from documents.tests.utils import FileSystemAssertsMixin
|
||||
|
||||
|
||||
@pytest.mark.management
|
||||
class TestMakeThumbnails(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
|
||||
def make_models(self) -> None:
|
||||
self.d1 = Document.objects.create(
|
||||
|
||||
17
uv.lock
generated
17
uv.lock
generated
@@ -1137,6 +1137,19 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/6d/10/23c0644cf67567bbe4e3a2eeeec0e9c79b701990c0e07c5ee4a4f8897f91/django_multiselectfield-1.0.1-py3-none-any.whl", hash = "sha256:18dc14801f7eca844a48e21cba6d8ec35b9b581f2373bbb2cb75e6994518259a", size = 20481, upload-time = "2025-06-12T14:41:20.107Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "django-rich"
|
||||
version = "2.2.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "django", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "rich", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/a6/67/e307a5fef657e7992468f567b521534c52e01bdda5a1ae5b12de679a670f/django_rich-2.2.0.tar.gz", hash = "sha256:ecec7842d040024ed8a225699388535e46b87277550c33f46193b52cece2f780", size = 62427, upload-time = "2025-09-18T11:42:17.182Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/27/ed/23fa669493d78cd67e7f6734fa380f8690f2b4d75b4f72fd645a52c3b32a/django_rich-2.2.0-py3-none-any.whl", hash = "sha256:a0d2c916bd9750b6e9beb57407aef5e836c8705d7dbe9e4fd4725f7bbe41c407", size = 9210, upload-time = "2025-09-18T11:42:15.779Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "django-soft-delete"
|
||||
version = "1.0.22"
|
||||
@@ -3041,6 +3054,7 @@ dependencies = [
|
||||
{ name = "django-filter", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "django-guardian", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "django-multiselectfield", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "django-rich", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "django-soft-delete", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "django-treenode", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "djangorestframework", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
@@ -3081,7 +3095,6 @@ dependencies = [
|
||||
{ name = "tika-client", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "torch", version = "2.10.0", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "sys_platform == 'darwin'" },
|
||||
{ name = "torch", version = "2.10.0+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "sys_platform == 'linux'" },
|
||||
{ name = "tqdm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "watchfiles", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "whitenoise", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "whoosh-reloaded", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
@@ -3186,6 +3199,7 @@ requires-dist = [
|
||||
{ name = "django-filter", specifier = "~=25.1" },
|
||||
{ name = "django-guardian", specifier = "~=3.2.0" },
|
||||
{ name = "django-multiselectfield", specifier = "~=1.0.1" },
|
||||
{ name = "django-rich", specifier = "~=2.2.0" },
|
||||
{ name = "django-soft-delete", specifier = "~=1.0.18" },
|
||||
{ name = "django-treenode", specifier = ">=0.23.2" },
|
||||
{ name = "djangorestframework", specifier = "~=3.16" },
|
||||
@@ -3232,7 +3246,6 @@ requires-dist = [
|
||||
{ name = "setproctitle", specifier = "~=1.3.4" },
|
||||
{ name = "tika-client", specifier = "~=0.10.0" },
|
||||
{ name = "torch", specifier = "~=2.10.0", index = "https://download.pytorch.org/whl/cpu" },
|
||||
{ name = "tqdm", specifier = "~=4.67.1" },
|
||||
{ name = "watchfiles", specifier = ">=1.1.1" },
|
||||
{ name = "whitenoise", specifier = "~=6.11" },
|
||||
{ name = "whoosh-reloaded", specifier = ">=2.7.5" },
|
||||
|
||||
Reference in New Issue
Block a user