From 7080322c419ec0ec3877cedb198ce3d0f06f01de Mon Sep 17 00:00:00 2001 From: Trenton H <797416+stumpylog@users.noreply.github.com> Date: Wed, 25 Feb 2026 14:30:50 -0800 Subject: [PATCH] Initial conversion to new base class --- pyproject.toml | 3 +- src/documents/management/commands/base.py | 316 ++++++++++ .../management/commands/document_archiver.py | 55 +- .../commands/document_fuzzy_match.py | 94 ++- .../management/commands/document_renamer.py | 21 +- .../management/commands/document_retagger.py | 16 +- .../commands/document_thumbnails.py | 53 +- .../management/commands/prune_audit_logs.py | 30 +- src/documents/tests/management/__init__.py | 0 .../management/test_management_base_cmd.py | 579 ++++++++++++++++++ src/documents/tests/test_management.py | 7 + .../tests/test_management_consumer.py | 6 + .../tests/test_management_exporter.py | 3 + src/documents/tests/test_management_fuzzy.py | 17 +- .../tests/test_management_importer.py | 2 + .../tests/test_management_retagger.py | 2 + .../tests/test_management_superuser.py | 2 + .../tests/test_management_thumbnails.py | 2 + uv.lock | 17 +- 19 files changed, 1032 insertions(+), 193 deletions(-) create mode 100644 src/documents/management/commands/base.py create mode 100644 src/documents/tests/management/__init__.py create mode 100644 src/documents/tests/management/test_management_base_cmd.py diff --git a/pyproject.toml b/pyproject.toml index 7edde8dcf..f36f9f1be 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ dependencies = [ "django-filter~=25.1", "django-guardian~=3.2.0", "django-multiselectfield~=1.0.1", + "django-rich~=2.2.0", "django-soft-delete~=1.0.18", "django-treenode>=0.23.2", "djangorestframework~=3.16", @@ -76,7 +77,6 @@ dependencies = [ "setproctitle~=1.3.4", "tika-client~=0.10.0", "torch~=2.10.0", - "tqdm~=4.67.1", "watchfiles>=1.1.1", "whitenoise~=6.11", "whoosh-reloaded>=2.7.5", @@ -304,6 +304,7 @@ markers = [ "tika: Tests requiring Tika service", "greenmail: Tests requiring Greenmail service", "date_parsing: Tests which cover date parsing from content or filename", + "management: Tests which cover management commands/functionality", ] [tool.pytest_env] diff --git a/src/documents/management/commands/base.py b/src/documents/management/commands/base.py new file mode 100644 index 000000000..dc9f7e98d --- /dev/null +++ b/src/documents/management/commands/base.py @@ -0,0 +1,316 @@ +""" +Base command class for Paperless-ngx management commands. + +Provides automatic progress bar and multiprocessing support with minimal boilerplate. +""" + +from __future__ import annotations + +import os +from concurrent.futures import ProcessPoolExecutor +from concurrent.futures import as_completed +from dataclasses import dataclass +from typing import TYPE_CHECKING +from typing import ClassVar +from typing import Generic +from typing import TypeVar + +from django import db +from django.core.management import CommandError +from django_rich.management import RichCommand +from rich.progress import BarColumn +from rich.progress import MofNCompleteColumn +from rich.progress import Progress +from rich.progress import SpinnerColumn +from rich.progress import TextColumn +from rich.progress import TimeElapsedColumn +from rich.progress import TimeRemainingColumn + +if TYPE_CHECKING: + from argparse import ArgumentParser + from collections.abc import Callable + from collections.abc import Generator + from collections.abc import Iterable + from collections.abc import Sequence + +T = TypeVar("T") +R = TypeVar("R") + + +@dataclass(frozen=True, slots=True) +class ProcessResult(Generic[T, R]): + """ + Result of processing a single item in parallel. + + Attributes: + item: The input item that was processed. + result: The return value from the processing function, or None if an error occurred. + error: The exception if processing failed, or None on success. + """ + + item: T + result: R | None + error: BaseException | None + + @property + def success(self) -> bool: + """Return True if the item was processed successfully.""" + return self.error is None + + +class PaperlessCommand(RichCommand): + """ + Base command class with automatic progress bar and multiprocessing support. + + Features are opt-in via class attributes: + supports_progress_bar: Adds --no-progress-bar argument (default: True) + supports_multiprocessing: Adds --processes argument (default: False) + + Example usage: + + class Command(PaperlessCommand): + help = "Process all documents" + + def handle(self, *args, **options): + documents = Document.objects.all() + for doc in self.track(documents, description="Processing..."): + process_document(doc) + + class Command(PaperlessCommand): + help = "Regenerate thumbnails" + supports_multiprocessing = True + + def handle(self, *args, **options): + ids = list(Document.objects.values_list("id", flat=True)) + for result in self.process_parallel(process_doc, ids): + if result.error: + self.console.print(f"[red]Failed: {result.error}[/red]") + """ + + supports_progress_bar: ClassVar[bool] = True + supports_multiprocessing: ClassVar[bool] = False + + # Instance attributes set by execute() before handle() runs + no_progress_bar: bool + process_count: int + + def add_arguments(self, parser: ArgumentParser) -> None: + """Add arguments based on supported features.""" + super().add_arguments(parser) + + if self.supports_progress_bar: + parser.add_argument( + "--no-progress-bar", + default=False, + action="store_true", + help="Disable the progress bar", + ) + + if self.supports_multiprocessing: + default_processes = max(1, (os.cpu_count() or 1) // 4) + parser.add_argument( + "--processes", + default=default_processes, + type=int, + help=f"Number of processes to use (default: {default_processes})", + ) + + def execute(self, *args, **options) -> int: + """ + Set up instance state before handle() is called. + + This is called by Django's command infrastructure after argument parsing + but before handle(). We use it to set instance attributes from options. + """ + # Set progress bar state + if self.supports_progress_bar: + self.no_progress_bar = options.get("no_progress_bar", False) + else: + self.no_progress_bar = True + + # Set multiprocessing state + if self.supports_multiprocessing: + self.process_count = options.get("processes", 1) + if self.process_count < 1: + raise CommandError("--processes must be at least 1") + else: + self.process_count = 1 + + return super().execute(*args, **options) + + def _create_progress(self, description: str) -> Progress: + """ + Create a configured Progress instance. + + Args: + description: Text to display alongside the progress bar. + + Returns: + A Progress instance configured with appropriate columns. + """ + return Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + MofNCompleteColumn(), + TimeElapsedColumn(), + TimeRemainingColumn(), + console=self.console, + transient=False, + ) + + def _get_iterable_length(self, iterable: Iterable[T]) -> int | None: + """ + Attempt to determine the length of an iterable without consuming it. + + Tries .count() first (for Django querysets - executes SELECT COUNT(*)), + then falls back to len() for sequences. + + Args: + iterable: The iterable to measure. + + Returns: + The length if determinable, None otherwise. + """ + # Django querysets have .count() which is a SELECT COUNT(*) + # This is much more efficient than len() which evaluates the queryset + # Note: list.count(value) requires an argument, so we catch TypeError + if hasattr(iterable, "count") and callable(iterable.count): + try: + return iterable.count() + except TypeError: + # list.count() requires an argument, fall through to len() + pass + + # Fall back to len() for sequences + try: + return len(iterable) # type: ignore[arg-type] + except TypeError: + return None + + def track( + self, + iterable: Iterable[T], + *, + description: str = "Processing...", + total: int | None = None, + ) -> Generator[T, None, None]: + """ + Iterate over items with an optional progress bar. + + Respects --no-progress-bar flag. When disabled, simply yields items + without any progress display. + + Args: + iterable: The items to iterate over. + description: Text to display alongside the progress bar. + total: Total number of items. If None, attempts to determine + automatically via .count() (for querysets) or len(). + + Yields: + Items from the iterable. + + Example: + for doc in self.track(documents, description="Renaming..."): + process(doc) + """ + if self.no_progress_bar: + yield from iterable + return + + # Attempt to determine total if not provided + if total is None: + total = self._get_iterable_length(iterable) + + with self._create_progress(description) as progress: + task_id = progress.add_task(description, total=total) + for item in iterable: + yield item + progress.advance(task_id) + + def process_parallel( + self, + fn: Callable[[T], R], + items: Sequence[T], + *, + description: str = "Processing...", + ) -> Generator[ProcessResult[T, R], None, None]: + """ + Process items in parallel with progress tracking. + + When --processes=1, runs sequentially in the main process without + spawning subprocesses. This is critical for testing, as multiprocessing + breaks fixtures, mocks, and database transactions. + + When --processes > 1, uses ProcessPoolExecutor and automatically closes + database connections before spawning workers (required for PostgreSQL). + + Args: + fn: Function to apply to each item. Must be picklable for parallel + execution (i.e., defined at module level, not a lambda or closure). + items: Sequence of items to process. + description: Text to display alongside the progress bar. + + Yields: + ProcessResult for each item, containing the item, result, and any error. + + Example: + def regenerate_thumbnail(doc_id: int) -> Path: + ... + + for result in self.process_parallel(regenerate_thumbnail, doc_ids): + if result.error: + self.console.print(f"[red]Failed {result.item}[/red]") + """ + total = len(items) + + if self.process_count == 1: + # Sequential execution in main process - critical for testing + yield from self._process_sequential(fn, items, description, total) + else: + # Parallel execution with ProcessPoolExecutor + yield from self._process_parallel(fn, items, description, total) + + def _process_sequential( + self, + fn: Callable[[T], R], + items: Sequence[T], + description: str, + total: int, + ) -> Generator[ProcessResult[T, R], None, None]: + """Process items sequentially in the main process.""" + for item in self.track(items, description=description, total=total): + try: + result = fn(item) + yield ProcessResult(item=item, result=result, error=None) + except Exception as e: + yield ProcessResult(item=item, result=None, error=e) + + def _process_parallel( + self, + fn: Callable[[T], R], + items: Sequence[T], + description: str, + total: int, + ) -> Generator[ProcessResult[T, R], None, None]: + """Process items in parallel using ProcessPoolExecutor.""" + # Close database connections before forking - required for PostgreSQL + db.connections.close_all() + + with self._create_progress(description) as progress: + task_id = progress.add_task(description, total=total) + + with ProcessPoolExecutor(max_workers=self.process_count) as executor: + # Submit all tasks and map futures back to items + future_to_item = {executor.submit(fn, item): item for item in items} + + # Yield results as they complete + for future in as_completed(future_to_item): + item = future_to_item[future] + try: + result = future.result() + yield ProcessResult(item=item, result=result, error=None) + except Exception as e: + yield ProcessResult(item=item, result=None, error=e) + finally: + progress.advance(task_id) diff --git a/src/documents/management/commands/document_archiver.py b/src/documents/management/commands/document_archiver.py index 1aa52117a..f0f122d85 100644 --- a/src/documents/management/commands/document_archiver.py +++ b/src/documents/management/commands/document_archiver.py @@ -1,20 +1,15 @@ import logging -import multiprocessing -import tqdm -from django import db from django.conf import settings -from django.core.management.base import BaseCommand -from documents.management.commands.mixins import MultiProcessMixin -from documents.management.commands.mixins import ProgressBarMixin +from documents.management.commands.base import PaperlessCommand from documents.models import Document from documents.tasks import update_document_content_maybe_archive_file logger = logging.getLogger("paperless.management.archiver") -class Command(MultiProcessMixin, ProgressBarMixin, BaseCommand): +class Command(PaperlessCommand): help = ( "Using the current classification model, assigns correspondents, tags " "and document types to all documents, effectively allowing you to " @@ -22,7 +17,10 @@ class Command(MultiProcessMixin, ProgressBarMixin, BaseCommand): "modified) after their initial import." ) + supports_multiprocessing = True + def add_arguments(self, parser): + super().add_arguments(parser) parser.add_argument( "-f", "--overwrite", @@ -44,13 +42,8 @@ class Command(MultiProcessMixin, ProgressBarMixin, BaseCommand): "run on this specific document." ), ) - self.add_argument_progress_bar_mixin(parser) - self.add_argument_processes_mixin(parser) def handle(self, *args, **options): - self.handle_processes_mixin(**options) - self.handle_progress_bar_mixin(**options) - settings.SCRATCH_DIR.mkdir(parents=True, exist_ok=True) overwrite = options["overwrite"] @@ -60,35 +53,21 @@ class Command(MultiProcessMixin, ProgressBarMixin, BaseCommand): else: documents = Document.objects.all() - document_ids = list( - map( - lambda doc: doc.id, - filter(lambda d: overwrite or not d.has_archive_version, documents), - ), - ) - - # Note to future self: this prevents django from reusing database - # connections between processes, which is bad and does not work - # with postgres. - db.connections.close_all() + document_ids = [ + doc.id for doc in documents if overwrite or not doc.has_archive_version + ] try: logging.getLogger().handlers[0].level = logging.ERROR - if self.process_count == 1: - for doc_id in document_ids: - update_document_content_maybe_archive_file(doc_id) - else: # pragma: no cover - with multiprocessing.Pool(self.process_count) as pool: - list( - tqdm.tqdm( - pool.imap_unordered( - update_document_content_maybe_archive_file, - document_ids, - ), - total=len(document_ids), - disable=self.no_progress_bar, - ), + for result in self.process_parallel( + update_document_content_maybe_archive_file, + document_ids, + description="Archiving...", + ): + if result.error: + self.console.print( + f"[red]Failed document {result.item}: {result.error}[/red]", ) except KeyboardInterrupt: - self.stdout.write(self.style.NOTICE("Aborting...")) + self.console.print("[yellow]Aborting...[/yellow]") diff --git a/src/documents/management/commands/document_fuzzy_match.py b/src/documents/management/commands/document_fuzzy_match.py index 4ecdf6d01..98709b540 100644 --- a/src/documents/management/commands/document_fuzzy_match.py +++ b/src/documents/management/commands/document_fuzzy_match.py @@ -1,24 +1,20 @@ import dataclasses -import multiprocessing from typing import Final import rapidfuzz -import tqdm -from django.core.management import BaseCommand from django.core.management import CommandError -from documents.management.commands.mixins import MultiProcessMixin -from documents.management.commands.mixins import ProgressBarMixin +from documents.management.commands.base import PaperlessCommand from documents.models import Document -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, slots=True) class _WorkPackage: first_doc: Document second_doc: Document -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, slots=True) class _WorkResult: doc_one_pk: int doc_two_pk: int @@ -31,22 +27,23 @@ class _WorkResult: def _process_and_match(work: _WorkPackage) -> _WorkResult: """ Does basic processing of document content, gets the basic ratio - and returns the result package + and returns the result package. """ - # Normalize the string some, lower case, whitespace, etc first_string = rapidfuzz.utils.default_process(work.first_doc.content) second_string = rapidfuzz.utils.default_process(work.second_doc.content) - # Basic matching ratio match = rapidfuzz.fuzz.ratio(first_string, second_string) return _WorkResult(work.first_doc.pk, work.second_doc.pk, match) -class Command(MultiProcessMixin, ProgressBarMixin, BaseCommand): +class Command(PaperlessCommand): help = "Searches for documents where the content almost matches" + supports_multiprocessing = True + def add_arguments(self, parser): + super().add_arguments(parser) parser.add_argument( "--ratio", default=85.0, @@ -59,91 +56,72 @@ class Command(MultiProcessMixin, ProgressBarMixin, BaseCommand): action="store_true", help="If set, one document of matches above the ratio WILL BE DELETED", ) - self.add_argument_progress_bar_mixin(parser) - self.add_argument_processes_mixin(parser) def handle(self, *args, **options): RATIO_MIN: Final[float] = 0.0 RATIO_MAX: Final[float] = 100.0 - self.handle_processes_mixin(**options) - self.handle_progress_bar_mixin(**options) - if options["delete"]: - self.stdout.write( - self.style.WARNING( - "The command is configured to delete documents. Use with caution", - ), + self.console.print( + "[yellow]The command is configured to delete documents. " + "Use with caution.[/yellow]", ) opt_ratio = options["ratio"] checked_pairs: set[tuple[int, int]] = set() work_pkgs: list[_WorkPackage] = [] - # Ratio is a float from 0.0 to 100.0 if opt_ratio < RATIO_MIN or opt_ratio > RATIO_MAX: raise CommandError("The ratio must be between 0 and 100") all_docs = Document.objects.all().order_by("id") - # Build work packages for processing for first_doc in all_docs: for second_doc in all_docs: - # doc to doc is obviously not useful if first_doc.pk == second_doc.pk: continue - # Skip empty documents (e.g. password-protected) if first_doc.content.strip() == "" or second_doc.content.strip() == "": continue - # Skip matching which have already been matched together - # doc 1 to doc 2 is the same as doc 2 to doc 1 doc_1_to_doc_2 = (first_doc.pk, second_doc.pk) doc_2_to_doc_1 = doc_1_to_doc_2[::-1] if doc_1_to_doc_2 in checked_pairs or doc_2_to_doc_1 in checked_pairs: continue checked_pairs.update([doc_1_to_doc_2, doc_2_to_doc_1]) - # Actually something useful to work on now work_pkgs.append(_WorkPackage(first_doc, second_doc)) - # Don't spin up a pool of 1 process - if self.process_count == 1: - results = [] - for work in tqdm.tqdm(work_pkgs, disable=self.no_progress_bar): - results.append(_process_and_match(work)) - else: # pragma: no cover - with multiprocessing.Pool(processes=self.process_count) as pool: - results = list( - tqdm.tqdm( - pool.imap_unordered(_process_and_match, work_pkgs), - total=len(work_pkgs), - disable=self.no_progress_bar, - ), + results: list[_WorkResult] = [] + for result in self.process_parallel( + _process_and_match, + work_pkgs, + description="Matching...", + ): + if result.error: + self.console.print( + f"[red]Failed: {result.error}[/red]", ) + elif result.result is not None: + results.append(result.result) - # Check results - messages = [] - maybe_delete_ids = [] - for result in sorted(results): - if result.ratio >= opt_ratio: + messages: list[str] = [] + maybe_delete_ids: list[int] = [] + for match_result in sorted(results): + if match_result.ratio >= opt_ratio: messages.append( self.style.NOTICE( - f"Document {result.doc_one_pk} fuzzy match" - f" to {result.doc_two_pk} (confidence {result.ratio:.3f})\n", + f"Document {match_result.doc_one_pk} fuzzy match" + f" to {match_result.doc_two_pk}" + f" (confidence {match_result.ratio:.3f})\n", ), ) - maybe_delete_ids.append(result.doc_two_pk) + maybe_delete_ids.append(match_result.doc_two_pk) if len(messages) == 0: - messages.append( - self.style.SUCCESS("No matches found\n"), - ) - self.stdout.writelines( - messages, - ) + messages.append(self.style.SUCCESS("No matches found\n")) + self.stdout.writelines(messages) + if options["delete"]: - self.stdout.write( - self.style.NOTICE( - f"Deleting {len(maybe_delete_ids)} documents based on ratio matches", - ), + self.console.print( + f"[yellow]Deleting {len(maybe_delete_ids)} documents " + f"based on ratio matches[/yellow]", ) Document.objects.filter(pk__in=maybe_delete_ids).delete() diff --git a/src/documents/management/commands/document_renamer.py b/src/documents/management/commands/document_renamer.py index 2dfca217e..05f0224bb 100644 --- a/src/documents/management/commands/document_renamer.py +++ b/src/documents/management/commands/document_renamer.py @@ -1,25 +1,12 @@ -import logging - -import tqdm -from django.core.management.base import BaseCommand from django.db.models.signals import post_save -from documents.management.commands.mixins import ProgressBarMixin +from documents.management.commands.base import PaperlessCommand from documents.models import Document -class Command(ProgressBarMixin, BaseCommand): - help = "This will rename all documents to match the latest filename format." - - def add_arguments(self, parser): - self.add_argument_progress_bar_mixin(parser) +class Command(PaperlessCommand): + help = "Rename all documents" def handle(self, *args, **options): - self.handle_progress_bar_mixin(**options) - logging.getLogger().handlers[0].level = logging.ERROR - - for document in tqdm.tqdm( - Document.objects.all(), - disable=self.no_progress_bar, - ): + for document in self.track(Document.objects.all(), description="Renaming..."): post_save.send(Document, instance=document, created=False) diff --git a/src/documents/management/commands/document_retagger.py b/src/documents/management/commands/document_retagger.py index 10bb54b71..32f895d4e 100644 --- a/src/documents/management/commands/document_retagger.py +++ b/src/documents/management/commands/document_retagger.py @@ -1,10 +1,7 @@ import logging -import tqdm -from django.core.management.base import BaseCommand - from documents.classifier import load_classifier -from documents.management.commands.mixins import ProgressBarMixin +from documents.management.commands.base import PaperlessCommand from documents.models import Document from documents.signals.handlers import set_correspondent from documents.signals.handlers import set_document_type @@ -14,7 +11,7 @@ from documents.signals.handlers import set_tags logger = logging.getLogger("paperless.management.retagger") -class Command(ProgressBarMixin, BaseCommand): +class Command(PaperlessCommand): help = ( "Using the current classification model, assigns correspondents, tags " "and document types to all documents, effectively allowing you to " @@ -23,6 +20,7 @@ class Command(ProgressBarMixin, BaseCommand): ) def add_arguments(self, parser): + super().add_arguments(parser) parser.add_argument("-c", "--correspondent", default=False, action="store_true") parser.add_argument("-T", "--tags", default=False, action="store_true") parser.add_argument("-t", "--document_type", default=False, action="store_true") @@ -34,7 +32,7 @@ class Command(ProgressBarMixin, BaseCommand): action="store_true", help=( "By default this command won't try to assign a correspondent " - "if more than one matches the document. Use this flag if " + "if more than one matches the document. Use this flag if " "you'd rather it just pick the first one it finds." ), ) @@ -49,7 +47,6 @@ class Command(ProgressBarMixin, BaseCommand): "and tags that do not match anymore due to changed rules." ), ) - self.add_argument_progress_bar_mixin(parser) parser.add_argument( "--suggest", default=False, @@ -68,8 +65,6 @@ class Command(ProgressBarMixin, BaseCommand): ) def handle(self, *args, **options): - self.handle_progress_bar_mixin(**options) - if options["inbox_only"]: queryset = Document.objects.filter(tags__is_inbox_tag=True) else: @@ -84,7 +79,7 @@ class Command(ProgressBarMixin, BaseCommand): classifier = load_classifier() - for document in tqdm.tqdm(documents, disable=self.no_progress_bar): + for document in self.track(documents, description="Retagging..."): if options["correspondent"]: set_correspondent( sender=None, @@ -122,6 +117,7 @@ class Command(ProgressBarMixin, BaseCommand): stdout=self.stdout, style_func=self.style, ) + if options["storage_path"]: set_storage_path( sender=None, diff --git a/src/documents/management/commands/document_thumbnails.py b/src/documents/management/commands/document_thumbnails.py index e50c837d3..03824a63e 100644 --- a/src/documents/management/commands/document_thumbnails.py +++ b/src/documents/management/commands/document_thumbnails.py @@ -1,43 +1,39 @@ import logging -import multiprocessing import shutil -import tqdm -from django import db -from django.core.management.base import BaseCommand - -from documents.management.commands.mixins import MultiProcessMixin -from documents.management.commands.mixins import ProgressBarMixin +from documents.management.commands.base import PaperlessCommand from documents.models import Document from documents.parsers import get_parser_class_for_mime_type -def _process_document(doc_id) -> None: +def _process_document(doc_id: int) -> None: document: Document = Document.objects.get(id=doc_id) parser_class = get_parser_class_for_mime_type(document.mime_type) - if parser_class: - parser = parser_class(logging_group=None) - else: + if parser_class is None: print(f"{document} No parser for mime type {document.mime_type}") # noqa: T201 return + parser = parser_class(logging_group=None) + try: thumb = parser.get_thumbnail( document.source_path, document.mime_type, document.get_public_filename(), ) - shutil.move(thumb, document.thumbnail_path) finally: parser.cleanup() -class Command(MultiProcessMixin, ProgressBarMixin, BaseCommand): +class Command(PaperlessCommand): help = "This will regenerate the thumbnails for all documents." + supports_multiprocessing = True + def add_arguments(self, parser) -> None: + super().add_arguments(parser) parser.add_argument( "-d", "--document", @@ -49,36 +45,23 @@ class Command(MultiProcessMixin, ProgressBarMixin, BaseCommand): "run on this specific document." ), ) - self.add_argument_progress_bar_mixin(parser) - self.add_argument_processes_mixin(parser) def handle(self, *args, **options): logging.getLogger().handlers[0].level = logging.ERROR - self.handle_processes_mixin(**options) - self.handle_progress_bar_mixin(**options) - if options["document"]: documents = Document.objects.filter(pk=options["document"]) else: documents = Document.objects.all() - ids = [doc.id for doc in documents] + ids = list(documents.values_list("id", flat=True)) - # Note to future self: this prevents django from reusing database - # connections between processes, which is bad and does not work - # with postgres. - db.connections.close_all() - - if self.process_count == 1: - for doc_id in ids: - _process_document(doc_id) - else: # pragma: no cover - with multiprocessing.Pool(processes=self.process_count) as pool: - list( - tqdm.tqdm( - pool.imap_unordered(_process_document, ids), - total=len(ids), - disable=self.no_progress_bar, - ), + for result in self.process_parallel( + _process_document, + ids, + description="Regenerating thumbnails...", + ): + if result.error: + self.console.print( + f"[red]Failed document {result.item}: {result.error}[/red]", ) diff --git a/src/documents/management/commands/prune_audit_logs.py b/src/documents/management/commands/prune_audit_logs.py index b49f4afc2..eac690757 100644 --- a/src/documents/management/commands/prune_audit_logs.py +++ b/src/documents/management/commands/prune_audit_logs.py @@ -1,27 +1,21 @@ from auditlog.models import LogEntry -from django.core.management.base import BaseCommand from django.db import transaction -from tqdm import tqdm -from documents.management.commands.mixins import ProgressBarMixin +from documents.management.commands.base import PaperlessCommand -class Command(BaseCommand, ProgressBarMixin): - """ - Prune the audit logs of objects that no longer exist. - """ +class Command(PaperlessCommand): + """Prune the audit logs of objects that no longer exist.""" help = "Prunes the audit logs of objects that no longer exist." - def add_arguments(self, parser): - self.add_argument_progress_bar_mixin(parser) - - def handle(self, **options): - self.handle_progress_bar_mixin(**options) + def handle(self, *args, **options): with transaction.atomic(): - for log_entry in tqdm(LogEntry.objects.all(), disable=self.no_progress_bar): + for log_entry in self.track( + LogEntry.objects.all(), + description="Pruning audit logs...", + ): model_class = log_entry.content_type.model_class() - # use global_objects for SoftDeleteModel objects = ( model_class.global_objects if hasattr(model_class, "global_objects") @@ -32,8 +26,8 @@ class Command(BaseCommand, ProgressBarMixin): and not objects.filter(pk=log_entry.object_id).exists() ): log_entry.delete() - tqdm.write( - self.style.NOTICE( - f"Deleted audit log entry for {model_class.__name__} #{log_entry.object_id}", - ), + self.console.print( + f"Deleted audit log entry for " + f"{model_class.__name__} #{log_entry.object_id}", + style="yellow", ) diff --git a/src/documents/tests/management/__init__.py b/src/documents/tests/management/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/documents/tests/management/test_management_base_cmd.py b/src/documents/tests/management/test_management_base_cmd.py new file mode 100644 index 000000000..14a1ebb44 --- /dev/null +++ b/src/documents/tests/management/test_management_base_cmd.py @@ -0,0 +1,579 @@ +"""Tests for PaperlessCommand base class.""" + +from __future__ import annotations + +import io +from typing import TYPE_CHECKING +from typing import ClassVar + +import pytest +from django.core.management import CommandError +from rich.console import Console + +from documents.management.commands.base import PaperlessCommand +from documents.management.commands.base import ProcessResult + +if TYPE_CHECKING: + from pytest_mock import MockerFixture + + +# --- Test Commands --- +# These simulate real command implementations for testing + + +class SimpleCommand(PaperlessCommand): + """Command with default settings (progress bar, no multiprocessing).""" + + help = "Simple test command" + + def handle(self, *args, **options): + items = list(range(5)) + results = [] + for item in self.track(items, description="Processing..."): + results.append(item * 2) + self.stdout.write(f"Results: {results}") + + +class NoProgressBarCommand(PaperlessCommand): + """Command with progress bar disabled.""" + + help = "No progress bar command" + supports_progress_bar = False + + def handle(self, *args, **options): + items = list(range(3)) + for item in self.track(items): + pass + self.stdout.write("Done") + + +class MultiprocessCommand(PaperlessCommand): + """Command with multiprocessing support.""" + + help = "Multiprocess test command" + supports_multiprocessing = True + + def handle(self, *args, **options): + items = list(range(5)) + results = [] + for result in self.process_parallel( + _double_value, + items, + description="Processing...", + ): + results.append(result) + successes = sum(1 for r in results if r.success) + self.stdout.write(f"Successes: {successes}") + + +# --- Helper Functions for Multiprocessing --- +# Must be at module level to be picklable + + +def _double_value(x: int) -> int: + """Double the input value.""" + return x * 2 + + +def _divide_ten_by(x: int) -> float: + """Divide 10 by x. Raises ZeroDivisionError if x is 0.""" + return 10 / x + + +# --- Fixtures --- + + +@pytest.fixture +def console() -> Console: + """Create a non-interactive console for testing.""" + return Console(force_terminal=False, force_interactive=False) + + +@pytest.fixture +def simple_command(console: Console) -> SimpleCommand: + """Create a SimpleCommand instance configured for testing.""" + command = SimpleCommand() + command.stdout = io.StringIO() + command.stderr = io.StringIO() + command.console = console + command.no_progress_bar = True + command.process_count = 1 + return command + + +@pytest.fixture +def multiprocess_command(console: Console) -> MultiprocessCommand: + """Create a MultiprocessCommand instance configured for testing.""" + command = MultiprocessCommand() + command.stdout = io.StringIO() + command.stderr = io.StringIO() + command.console = console + command.no_progress_bar = True + command.process_count = 1 + return command + + +@pytest.fixture +def mock_queryset(): + """ + Create a mock Django QuerySet that tracks method calls. + + This verifies we use .count() instead of len() for querysets. + """ + + class MockQuerySet: + def __init__(self, items: list): + self._items = items + self.count_called = False + + def count(self) -> int: + self.count_called = True + return len(self._items) + + def __iter__(self): + return iter(self._items) + + def __len__(self): + raise AssertionError("len() should not be called on querysets") + + return MockQuerySet + + +# --- Test Classes --- + + +@pytest.mark.management +class TestProcessResult: + """Tests for the ProcessResult dataclass.""" + + def test_success_result(self): + result = ProcessResult(item=1, result=2, error=None) + + assert result.item == 1 + assert result.result == 2 + assert result.error is None + assert result.success is True + + def test_error_result(self): + error = ValueError("test error") + result = ProcessResult(item=1, result=None, error=error) + + assert result.item == 1 + assert result.result is None + assert result.error is error + assert result.success is False + + +@pytest.mark.management +class TestPaperlessCommandArguments: + """Tests for argument parsing behavior.""" + + def test_progress_bar_argument_added_by_default(self): + command = SimpleCommand() + parser = command.create_parser("manage.py", "simple") + + options = parser.parse_args(["--no-progress-bar"]) + assert options.no_progress_bar is True + + options = parser.parse_args([]) + assert options.no_progress_bar is False + + def test_progress_bar_argument_not_added_when_disabled(self): + command = NoProgressBarCommand() + parser = command.create_parser("manage.py", "noprogress") + + options = parser.parse_args([]) + assert not hasattr(options, "no_progress_bar") + + def test_processes_argument_added_when_multiprocessing_enabled(self): + command = MultiprocessCommand() + parser = command.create_parser("manage.py", "multiprocess") + + options = parser.parse_args(["--processes", "4"]) + assert options.processes == 4 + + options = parser.parse_args([]) + assert options.processes >= 1 + + def test_processes_argument_not_added_when_multiprocessing_disabled(self): + command = SimpleCommand() + parser = command.create_parser("manage.py", "simple") + + options = parser.parse_args([]) + assert not hasattr(options, "processes") + + +@pytest.mark.management +class TestPaperlessCommandExecute: + """Tests for the execute() setup behavior.""" + + @pytest.fixture + def base_options(self) -> dict: + """Base options required for execute().""" + return { + "verbosity": 1, + "no_color": True, + "force_color": False, + "skip_checks": True, + } + + @pytest.mark.parametrize( + ("no_progress_bar_flag", "expected"), + [ + pytest.param(False, False, id="progress-bar-enabled"), + pytest.param(True, True, id="progress-bar-disabled"), + ], + ) + def test_no_progress_bar_state_set( + self, + base_options: dict, + *, + no_progress_bar_flag: bool, + expected: bool, + ): + command = SimpleCommand() + command.stdout = io.StringIO() + command.stderr = io.StringIO() + + options = {**base_options, "no_progress_bar": no_progress_bar_flag} + command.execute(**options) + + assert command.no_progress_bar is expected + + def test_no_progress_bar_always_true_when_not_supported(self, base_options: dict): + command = NoProgressBarCommand() + command.stdout = io.StringIO() + command.stderr = io.StringIO() + + command.execute(**base_options) + + assert command.no_progress_bar is True + + @pytest.mark.parametrize( + ("processes", "expected"), + [ + pytest.param(1, 1, id="single-process"), + pytest.param(4, 4, id="four-processes"), + pytest.param(8, 8, id="eight-processes"), + ], + ) + def test_process_count_set( + self, + base_options: dict, + processes: int, + expected: int, + ): + command = MultiprocessCommand() + command.stdout = io.StringIO() + command.stderr = io.StringIO() + + options = {**base_options, "processes": processes, "no_progress_bar": True} + command.execute(**options) + + assert command.process_count == expected + + @pytest.mark.parametrize( + "invalid_count", + [ + pytest.param(0, id="zero"), + pytest.param(-1, id="negative"), + ], + ) + def test_process_count_validation_rejects_invalid( + self, + base_options: dict, + invalid_count: int, + ): + command = MultiprocessCommand() + command.stdout = io.StringIO() + command.stderr = io.StringIO() + + options = {**base_options, "processes": invalid_count, "no_progress_bar": True} + + with pytest.raises(CommandError, match="--processes must be at least 1"): + command.execute(**options) + + def test_process_count_defaults_to_one_when_not_supported(self, base_options: dict): + command = SimpleCommand() + command.stdout = io.StringIO() + command.stderr = io.StringIO() + + options = {**base_options, "no_progress_bar": True} + command.execute(**options) + + assert command.process_count == 1 + + +@pytest.mark.management +class TestGetIterableLength: + """Tests for the _get_iterable_length() method.""" + + def test_uses_count_method_for_querysets( + self, + simple_command: SimpleCommand, + mock_queryset, + ): + """Should use .count() for Django querysets (SELECT COUNT(*)).""" + queryset = mock_queryset([1, 2, 3, 4, 5]) + + result = simple_command._get_iterable_length(queryset) + + assert result == 5 + assert queryset.count_called is True + + def test_falls_back_to_len_for_sequences(self, simple_command: SimpleCommand): + """Should use len() for regular sequences without .count().""" + items = [1, 2, 3, 4] + + result = simple_command._get_iterable_length(items) + + assert result == 4 + + def test_returns_none_for_generators(self, simple_command: SimpleCommand): + """Should return None for iterables without len() or count().""" + + def gen(): + yield from [1, 2, 3] + + result = simple_command._get_iterable_length(gen()) + + assert result is None + + def test_handles_count_not_callable(self, simple_command: SimpleCommand): + """Should skip .count if it's not callable (edge case).""" + + class WeirdObject: + count = 42 # Attribute, not method + + def __len__(self): + return 10 + + result = simple_command._get_iterable_length(WeirdObject()) + + assert result == 10 + + def test_handles_list_count_method(self, simple_command: SimpleCommand): + """list.count(value) requires an argument, should fall back to len().""" + items = [1, 2, 3, 4, 5] + + # list has a .count() method but it requires an argument + # _get_iterable_length should catch TypeError and use len() instead + result = simple_command._get_iterable_length(items) + + assert result == 5 + + +@pytest.mark.management +class TestTrack: + """Tests for the track() method.""" + + def test_yields_all_items(self, simple_command: SimpleCommand): + items = [1, 2, 3, 4, 5] + + result = list(simple_command.track(items)) + + assert result == items + + def test_with_progress_bar_disabled(self, simple_command: SimpleCommand): + simple_command.no_progress_bar = True + items = ["a", "b", "c"] + + result = list(simple_command.track(items, description="Test...")) + + assert result == items + + def test_with_progress_bar_enabled(self, simple_command: SimpleCommand): + simple_command.no_progress_bar = False + items = [1, 2, 3] + + result = list(simple_command.track(items, description="Processing...")) + + assert result == items + + def test_with_explicit_total(self, simple_command: SimpleCommand): + simple_command.no_progress_bar = False + + def gen(): + yield from [1, 2, 3] + + result = list(simple_command.track(gen(), total=3)) + + assert result == [1, 2, 3] + + def test_with_generator_no_total(self, simple_command: SimpleCommand): + def gen(): + yield from [1, 2, 3] + + result = list(simple_command.track(gen())) + + assert result == [1, 2, 3] + + def test_empty_iterable(self, simple_command: SimpleCommand): + result = list(simple_command.track([])) + + assert result == [] + + def test_uses_queryset_count( + self, + simple_command: SimpleCommand, + mock_queryset, + mocker: MockerFixture, + ): + """Verify track() uses .count() for querysets.""" + simple_command.no_progress_bar = False + queryset = mock_queryset([1, 2, 3]) + + spy = mocker.spy(simple_command, "_get_iterable_length") + + result = list(simple_command.track(queryset)) + + assert result == [1, 2, 3] + spy.assert_called_once_with(queryset) + assert queryset.count_called is True + + +@pytest.mark.management +class TestProcessParallel: + """Tests for the process_parallel() method.""" + + def test_sequential_processing_single_process( + self, + multiprocess_command: MultiprocessCommand, + ): + multiprocess_command.process_count = 1 + items = [1, 2, 3, 4, 5] + + results = list(multiprocess_command.process_parallel(_double_value, items)) + + assert len(results) == 5 + assert all(r.success for r in results) + + result_map = {r.item: r.result for r in results} + assert result_map == {1: 2, 2: 4, 3: 6, 4: 8, 5: 10} + + def test_sequential_processing_handles_errors( + self, + multiprocess_command: MultiprocessCommand, + ): + multiprocess_command.process_count = 1 + items = [1, 2, 0, 4] # 0 causes ZeroDivisionError + + results = list(multiprocess_command.process_parallel(_divide_ten_by, items)) + + assert len(results) == 4 + + successes = [r for r in results if r.success] + failures = [r for r in results if not r.success] + + assert len(successes) == 3 + assert len(failures) == 1 + assert failures[0].item == 0 + assert isinstance(failures[0].error, ZeroDivisionError) + + def test_parallel_closes_db_connections( + self, + multiprocess_command: MultiprocessCommand, + mocker: MockerFixture, + ): + multiprocess_command.process_count = 2 + items = [1, 2, 3] + + mock_connections = mocker.patch( + "documents.management.commands.base.db.connections", + ) + + results = list(multiprocess_command.process_parallel(_double_value, items)) + + mock_connections.close_all.assert_called_once() + assert len(results) == 3 + + def test_parallel_processing_handles_errors( + self, + multiprocess_command: MultiprocessCommand, + mocker: MockerFixture, + ): + multiprocess_command.process_count = 2 + items = [1, 2, 0, 4] + + mocker.patch("documents.management.commands.base.db.connections") + + results = list(multiprocess_command.process_parallel(_divide_ten_by, items)) + + failures = [r for r in results if not r.success] + assert len(failures) == 1 + assert failures[0].item == 0 + + def test_empty_items(self, multiprocess_command: MultiprocessCommand): + results = list(multiprocess_command.process_parallel(_double_value, [])) + + assert results == [] + + def test_result_contains_original_item( + self, + multiprocess_command: MultiprocessCommand, + ): + items = [10, 20, 30] + + results = list(multiprocess_command.process_parallel(_double_value, items)) + + for result in results: + assert result.item in items + assert result.result == result.item * 2 + + def test_sequential_path_used_for_single_process( + self, + multiprocess_command: MultiprocessCommand, + mocker: MockerFixture, + ): + """Verify single process uses sequential path (important for testing).""" + multiprocess_command.process_count = 1 + + spy_sequential = mocker.spy(multiprocess_command, "_process_sequential") + spy_parallel = mocker.spy(multiprocess_command, "_process_parallel") + + list(multiprocess_command.process_parallel(_double_value, [1, 2, 3])) + + spy_sequential.assert_called_once() + spy_parallel.assert_not_called() + + def test_parallel_path_used_for_multiple_processes( + self, + multiprocess_command: MultiprocessCommand, + mocker: MockerFixture, + ): + """Verify multiple processes uses parallel path.""" + multiprocess_command.process_count = 2 + + mocker.patch("documents.management.commands.base.db.connections") + spy_sequential = mocker.spy(multiprocess_command, "_process_sequential") + spy_parallel = mocker.spy(multiprocess_command, "_process_parallel") + + list(multiprocess_command.process_parallel(_double_value, [1, 2, 3])) + + spy_parallel.assert_called_once() + spy_sequential.assert_not_called() + + +@pytest.mark.management +class TestClassVariableDefaults: + """Tests for class variable default behavior.""" + + def test_default_supports_progress_bar(self): + assert PaperlessCommand.supports_progress_bar is True + + def test_default_supports_multiprocessing(self): + assert PaperlessCommand.supports_multiprocessing is False + + def test_subclass_can_override_progress_bar(self): + class NoProgressCommand(PaperlessCommand): + supports_progress_bar: ClassVar[bool] = False + + assert NoProgressCommand.supports_progress_bar is False + assert PaperlessCommand.supports_progress_bar is True + + def test_subclass_can_override_multiprocessing(self): + class ParallelCommand(PaperlessCommand): + supports_multiprocessing: ClassVar[bool] = True + + assert ParallelCommand.supports_multiprocessing is True + assert PaperlessCommand.supports_multiprocessing is False diff --git a/src/documents/tests/test_management.py b/src/documents/tests/test_management.py index 63c870e23..074e8039a 100644 --- a/src/documents/tests/test_management.py +++ b/src/documents/tests/test_management.py @@ -4,6 +4,7 @@ from io import StringIO from pathlib import Path from unittest import mock +import pytest from auditlog.models import LogEntry from django.contrib.contenttypes.models import ContentType from django.core.management import call_command @@ -19,6 +20,7 @@ from documents.tests.utils import FileSystemAssertsMixin sample_file: Path = Path(__file__).parent / "samples" / "simple.pdf" +@pytest.mark.management @override_settings(FILENAME_FORMAT="{correspondent}/{title}") class TestArchiver(DirectoriesMixin, FileSystemAssertsMixin, TestCase): def make_models(self): @@ -94,6 +96,7 @@ class TestArchiver(DirectoriesMixin, FileSystemAssertsMixin, TestCase): self.assertEqual(doc2.archive_filename, "document_01.pdf") +@pytest.mark.management class TestMakeIndex(TestCase): @mock.patch("documents.management.commands.document_index.index_reindex") def test_reindex(self, m) -> None: @@ -106,6 +109,7 @@ class TestMakeIndex(TestCase): m.assert_called_once() +@pytest.mark.management class TestRenamer(DirectoriesMixin, FileSystemAssertsMixin, TestCase): @override_settings(FILENAME_FORMAT="") def test_rename(self) -> None: @@ -140,6 +144,7 @@ class TestCreateClassifier(TestCase): m.assert_called_once() +@pytest.mark.management class TestSanityChecker(DirectoriesMixin, TestCase): def test_no_issues(self) -> None: with self.assertLogs() as capture: @@ -165,6 +170,7 @@ class TestSanityChecker(DirectoriesMixin, TestCase): self.assertIn("Checksum mismatch. Stored: abc, actual:", capture.output[1]) +@pytest.mark.management class TestConvertMariaDBUUID(TestCase): @mock.patch("django.db.connection.schema_editor") def test_convert(self, m) -> None: @@ -178,6 +184,7 @@ class TestConvertMariaDBUUID(TestCase): self.assertIn("Successfully converted", stdout.getvalue()) +@pytest.mark.management class TestPruneAuditLogs(TestCase): def test_prune_audit_logs(self) -> None: LogEntry.objects.create( diff --git a/src/documents/tests/test_management_consumer.py b/src/documents/tests/test_management_consumer.py index 810ae63e2..f4451f545 100644 --- a/src/documents/tests/test_management_consumer.py +++ b/src/documents/tests/test_management_consumer.py @@ -577,6 +577,7 @@ class TestTagsFromPath: assert len(tag_ids) == 0 +@pytest.mark.management class TestCommandValidation: """Tests for command argument validation.""" @@ -605,6 +606,7 @@ class TestCommandValidation: cmd.handle(directory=str(sample_pdf), oneshot=True, testing=False) +@pytest.mark.management @pytest.mark.usefixtures("mock_supported_extensions") class TestCommandOneshot: """Tests for oneshot mode.""" @@ -775,6 +777,7 @@ def start_consumer( ) +@pytest.mark.management @pytest.mark.django_db class TestCommandWatch: """Integration tests for the watch loop.""" @@ -896,6 +899,7 @@ class TestCommandWatch: assert not thread.is_alive() +@pytest.mark.management @pytest.mark.django_db class TestCommandWatchPolling: """Tests for polling mode.""" @@ -928,6 +932,7 @@ class TestCommandWatchPolling: mock_consume_file_delay.delay.assert_called() +@pytest.mark.management @pytest.mark.django_db class TestCommandWatchRecursive: """Tests for recursive watching.""" @@ -991,6 +996,7 @@ class TestCommandWatchRecursive: assert len(overrides.tag_ids) == 2 +@pytest.mark.management @pytest.mark.django_db class TestCommandWatchEdgeCases: """Tests for edge cases and error handling.""" diff --git a/src/documents/tests/test_management_exporter.py b/src/documents/tests/test_management_exporter.py index 391d87f41..e6daf5991 100644 --- a/src/documents/tests/test_management_exporter.py +++ b/src/documents/tests/test_management_exporter.py @@ -7,6 +7,7 @@ from pathlib import Path from unittest import mock from zipfile import ZipFile +import pytest from allauth.socialaccount.models import SocialAccount from allauth.socialaccount.models import SocialApp from allauth.socialaccount.models import SocialToken @@ -45,6 +46,7 @@ from documents.tests.utils import paperless_environment from paperless_mail.models import MailAccount +@pytest.mark.management class TestExportImport( DirectoriesMixin, FileSystemAssertsMixin, @@ -846,6 +848,7 @@ class TestExportImport( self.assertEqual(Document.objects.all().count(), 4) +@pytest.mark.management class TestCryptExportImport( DirectoriesMixin, FileSystemAssertsMixin, diff --git a/src/documents/tests/test_management_fuzzy.py b/src/documents/tests/test_management_fuzzy.py index 5ba57b15b..e097a015f 100644 --- a/src/documents/tests/test_management_fuzzy.py +++ b/src/documents/tests/test_management_fuzzy.py @@ -1,5 +1,6 @@ from io import StringIO +import pytest from django.core.management import CommandError from django.core.management import call_command from django.test import TestCase @@ -7,6 +8,7 @@ from django.test import TestCase from documents.models import Document +@pytest.mark.management class TestFuzzyMatchCommand(TestCase): MSG_REGEX = r"Document \d fuzzy match to \d \(confidence \d\d\.\d\d\d\)" @@ -49,19 +51,6 @@ class TestFuzzyMatchCommand(TestCase): self.call_command("--ratio", "101") self.assertIn("The ratio must be between 0 and 100", str(e.exception)) - def test_invalid_process_count(self) -> None: - """ - GIVEN: - - Invalid process count less than 0 above upper - WHEN: - - Command is called - THEN: - - Error is raised indicating issue - """ - with self.assertRaises(CommandError) as e: - self.call_command("--processes", "0") - self.assertIn("There must be at least 1 process", str(e.exception)) - def test_no_matches(self) -> None: """ GIVEN: @@ -194,7 +183,7 @@ class TestFuzzyMatchCommand(TestCase): self.assertEqual(Document.objects.count(), 3) - stdout, _ = self.call_command("--delete") + stdout, _ = self.call_command("--delete", "--no-progress-bar") self.assertIn( "The command is configured to delete documents. Use with caution", diff --git a/src/documents/tests/test_management_importer.py b/src/documents/tests/test_management_importer.py index 8537716ee..04045c805 100644 --- a/src/documents/tests/test_management_importer.py +++ b/src/documents/tests/test_management_importer.py @@ -4,6 +4,7 @@ from io import StringIO from pathlib import Path from zipfile import ZipFile +import pytest from django.contrib.auth.models import User from django.core.management import call_command from django.core.management.base import CommandError @@ -18,6 +19,7 @@ from documents.tests.utils import FileSystemAssertsMixin from documents.tests.utils import SampleDirMixin +@pytest.mark.management class TestCommandImport( DirectoriesMixin, FileSystemAssertsMixin, diff --git a/src/documents/tests/test_management_retagger.py b/src/documents/tests/test_management_retagger.py index 87912211b..29b322c28 100644 --- a/src/documents/tests/test_management_retagger.py +++ b/src/documents/tests/test_management_retagger.py @@ -1,3 +1,4 @@ +import pytest from django.core.management import call_command from django.core.management.base import CommandError from django.test import TestCase @@ -10,6 +11,7 @@ from documents.models import Tag from documents.tests.utils import DirectoriesMixin +@pytest.mark.management class TestRetagger(DirectoriesMixin, TestCase): def make_models(self) -> None: self.sp1 = StoragePath.objects.create( diff --git a/src/documents/tests/test_management_superuser.py b/src/documents/tests/test_management_superuser.py index 55484eb05..0a6bcb8cd 100644 --- a/src/documents/tests/test_management_superuser.py +++ b/src/documents/tests/test_management_superuser.py @@ -2,6 +2,7 @@ import os from io import StringIO from unittest import mock +import pytest from django.contrib.auth.models import User from django.core.management import call_command from django.test import TestCase @@ -9,6 +10,7 @@ from django.test import TestCase from documents.tests.utils import DirectoriesMixin +@pytest.mark.management class TestManageSuperUser(DirectoriesMixin, TestCase): def call_command(self, environ): out = StringIO() diff --git a/src/documents/tests/test_management_thumbnails.py b/src/documents/tests/test_management_thumbnails.py index 0cb65e4d4..160cb4419 100644 --- a/src/documents/tests/test_management_thumbnails.py +++ b/src/documents/tests/test_management_thumbnails.py @@ -2,6 +2,7 @@ import shutil from pathlib import Path from unittest import mock +import pytest from django.core.management import call_command from django.test import TestCase @@ -12,6 +13,7 @@ from documents.tests.utils import DirectoriesMixin from documents.tests.utils import FileSystemAssertsMixin +@pytest.mark.management class TestMakeThumbnails(DirectoriesMixin, FileSystemAssertsMixin, TestCase): def make_models(self) -> None: self.d1 = Document.objects.create( diff --git a/uv.lock b/uv.lock index 2b1ee98b1..35b259921 100644 --- a/uv.lock +++ b/uv.lock @@ -1137,6 +1137,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6d/10/23c0644cf67567bbe4e3a2eeeec0e9c79b701990c0e07c5ee4a4f8897f91/django_multiselectfield-1.0.1-py3-none-any.whl", hash = "sha256:18dc14801f7eca844a48e21cba6d8ec35b9b581f2373bbb2cb75e6994518259a", size = 20481, upload-time = "2025-06-12T14:41:20.107Z" }, ] +[[package]] +name = "django-rich" +version = "2.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "django", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "rich", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a6/67/e307a5fef657e7992468f567b521534c52e01bdda5a1ae5b12de679a670f/django_rich-2.2.0.tar.gz", hash = "sha256:ecec7842d040024ed8a225699388535e46b87277550c33f46193b52cece2f780", size = 62427, upload-time = "2025-09-18T11:42:17.182Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/27/ed/23fa669493d78cd67e7f6734fa380f8690f2b4d75b4f72fd645a52c3b32a/django_rich-2.2.0-py3-none-any.whl", hash = "sha256:a0d2c916bd9750b6e9beb57407aef5e836c8705d7dbe9e4fd4725f7bbe41c407", size = 9210, upload-time = "2025-09-18T11:42:15.779Z" }, +] + [[package]] name = "django-soft-delete" version = "1.0.22" @@ -3041,6 +3054,7 @@ dependencies = [ { name = "django-filter", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "django-guardian", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "django-multiselectfield", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "django-rich", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "django-soft-delete", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "django-treenode", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "djangorestframework", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, @@ -3081,7 +3095,6 @@ dependencies = [ { name = "tika-client", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "torch", version = "2.10.0", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "sys_platform == 'darwin'" }, { name = "torch", version = "2.10.0+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "sys_platform == 'linux'" }, - { name = "tqdm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "watchfiles", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "whitenoise", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "whoosh-reloaded", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, @@ -3186,6 +3199,7 @@ requires-dist = [ { name = "django-filter", specifier = "~=25.1" }, { name = "django-guardian", specifier = "~=3.2.0" }, { name = "django-multiselectfield", specifier = "~=1.0.1" }, + { name = "django-rich", specifier = "~=2.2.0" }, { name = "django-soft-delete", specifier = "~=1.0.18" }, { name = "django-treenode", specifier = ">=0.23.2" }, { name = "djangorestframework", specifier = "~=3.16" }, @@ -3232,7 +3246,6 @@ requires-dist = [ { name = "setproctitle", specifier = "~=1.3.4" }, { name = "tika-client", specifier = "~=0.10.0" }, { name = "torch", specifier = "~=2.10.0", index = "https://download.pytorch.org/whl/cpu" }, - { name = "tqdm", specifier = "~=4.67.1" }, { name = "watchfiles", specifier = ">=1.1.1" }, { name = "whitenoise", specifier = "~=6.11" }, { name = "whoosh-reloaded", specifier = ">=2.7.5" },