mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2026-02-26 01:09:34 -06:00
Initial conversion to new base class
This commit is contained in:
@@ -37,6 +37,7 @@ dependencies = [
|
|||||||
"django-filter~=25.1",
|
"django-filter~=25.1",
|
||||||
"django-guardian~=3.2.0",
|
"django-guardian~=3.2.0",
|
||||||
"django-multiselectfield~=1.0.1",
|
"django-multiselectfield~=1.0.1",
|
||||||
|
"django-rich~=2.2.0",
|
||||||
"django-soft-delete~=1.0.18",
|
"django-soft-delete~=1.0.18",
|
||||||
"django-treenode>=0.23.2",
|
"django-treenode>=0.23.2",
|
||||||
"djangorestframework~=3.16",
|
"djangorestframework~=3.16",
|
||||||
@@ -76,7 +77,6 @@ dependencies = [
|
|||||||
"setproctitle~=1.3.4",
|
"setproctitle~=1.3.4",
|
||||||
"tika-client~=0.10.0",
|
"tika-client~=0.10.0",
|
||||||
"torch~=2.10.0",
|
"torch~=2.10.0",
|
||||||
"tqdm~=4.67.1",
|
|
||||||
"watchfiles>=1.1.1",
|
"watchfiles>=1.1.1",
|
||||||
"whitenoise~=6.11",
|
"whitenoise~=6.11",
|
||||||
"whoosh-reloaded>=2.7.5",
|
"whoosh-reloaded>=2.7.5",
|
||||||
@@ -304,6 +304,7 @@ markers = [
|
|||||||
"tika: Tests requiring Tika service",
|
"tika: Tests requiring Tika service",
|
||||||
"greenmail: Tests requiring Greenmail service",
|
"greenmail: Tests requiring Greenmail service",
|
||||||
"date_parsing: Tests which cover date parsing from content or filename",
|
"date_parsing: Tests which cover date parsing from content or filename",
|
||||||
|
"management: Tests which cover management commands/functionality",
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.pytest_env]
|
[tool.pytest_env]
|
||||||
|
|||||||
316
src/documents/management/commands/base.py
Normal file
316
src/documents/management/commands/base.py
Normal file
@@ -0,0 +1,316 @@
|
|||||||
|
"""
|
||||||
|
Base command class for Paperless-ngx management commands.
|
||||||
|
|
||||||
|
Provides automatic progress bar and multiprocessing support with minimal boilerplate.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
from concurrent.futures import ProcessPoolExecutor
|
||||||
|
from concurrent.futures import as_completed
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
from typing import ClassVar
|
||||||
|
from typing import Generic
|
||||||
|
from typing import TypeVar
|
||||||
|
|
||||||
|
from django import db
|
||||||
|
from django.core.management import CommandError
|
||||||
|
from django_rich.management import RichCommand
|
||||||
|
from rich.progress import BarColumn
|
||||||
|
from rich.progress import MofNCompleteColumn
|
||||||
|
from rich.progress import Progress
|
||||||
|
from rich.progress import SpinnerColumn
|
||||||
|
from rich.progress import TextColumn
|
||||||
|
from rich.progress import TimeElapsedColumn
|
||||||
|
from rich.progress import TimeRemainingColumn
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
from collections.abc import Callable
|
||||||
|
from collections.abc import Generator
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
R = TypeVar("R")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class ProcessResult(Generic[T, R]):
|
||||||
|
"""
|
||||||
|
Result of processing a single item in parallel.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
item: The input item that was processed.
|
||||||
|
result: The return value from the processing function, or None if an error occurred.
|
||||||
|
error: The exception if processing failed, or None on success.
|
||||||
|
"""
|
||||||
|
|
||||||
|
item: T
|
||||||
|
result: R | None
|
||||||
|
error: BaseException | None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def success(self) -> bool:
|
||||||
|
"""Return True if the item was processed successfully."""
|
||||||
|
return self.error is None
|
||||||
|
|
||||||
|
|
||||||
|
class PaperlessCommand(RichCommand):
    """
    Base command class with automatic progress bar and multiprocessing support.

    Features are opt-in via class attributes:
        supports_progress_bar: Adds --no-progress-bar argument (default: True)
        supports_multiprocessing: Adds --processes argument (default: False)

    execute() copies the parsed options onto the instance as
    ``no_progress_bar`` and ``process_count`` before handle() runs, so
    subclasses only need to call track() or process_parallel().

    Example usage:

        class Command(PaperlessCommand):
            help = "Process all documents"

            def handle(self, *args, **options):
                documents = Document.objects.all()
                for doc in self.track(documents, description="Processing..."):
                    process_document(doc)

        class Command(PaperlessCommand):
            help = "Regenerate thumbnails"
            supports_multiprocessing = True

            def handle(self, *args, **options):
                ids = list(Document.objects.values_list("id", flat=True))
                for result in self.process_parallel(process_doc, ids):
                    if result.error:
                        self.console.print(f"[red]Failed: {result.error}[/red]")
    """

    supports_progress_bar: ClassVar[bool] = True
    supports_multiprocessing: ClassVar[bool] = False

    # Instance attributes set by execute() before handle() runs
    no_progress_bar: bool
    process_count: int

    def add_arguments(self, parser: ArgumentParser) -> None:
        """
        Add arguments based on supported features.

        Subclasses that define their own arguments must call
        ``super().add_arguments(parser)`` so these options are registered.

        Args:
            parser: The argument parser for this command.
        """
        super().add_arguments(parser)

        if self.supports_progress_bar:
            parser.add_argument(
                "--no-progress-bar",
                default=False,
                action="store_true",
                help="Disable the progress bar",
            )

        if self.supports_multiprocessing:
            # Default to a quarter of the available cores, but never fewer
            # than one (os.cpu_count() may return None)
            default_processes = max(1, (os.cpu_count() or 1) // 4)
            parser.add_argument(
                "--processes",
                default=default_processes,
                type=int,
                help=f"Number of processes to use (default: {default_processes})",
            )

    def execute(self, *args, **options) -> str | None:
        """
        Set up instance state before handle() is called.

        This is called by Django's command infrastructure after argument parsing
        but before handle(). We use it to set instance attributes from options.

        Returns:
            Whatever the parent class's execute() returns.

        Raises:
            CommandError: If --processes is less than 1.
        """
        # Set progress bar state; commands which do not support a progress
        # bar behave as if it was disabled
        if self.supports_progress_bar:
            self.no_progress_bar = options.get("no_progress_bar", False)
        else:
            self.no_progress_bar = True

        # Set multiprocessing state; commands which do not support
        # multiprocessing always run with a single process
        if self.supports_multiprocessing:
            self.process_count = options.get("processes", 1)
            if self.process_count < 1:
                raise CommandError("--processes must be at least 1")
        else:
            self.process_count = 1

        return super().execute(*args, **options)

    def _create_progress(self, description: str) -> Progress:
        """
        Create a configured Progress instance.

        The instance is created disabled when --no-progress-bar is in effect,
        so every caller (including the multiprocessing path) honours the flag.

        Args:
            description: Text to display alongside the progress bar. Currently
                unused here; callers pass it to ``add_task()`` themselves.

        Returns:
            A Progress instance configured with appropriate columns.
        """
        return Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            MofNCompleteColumn(),
            TimeElapsedColumn(),
            TimeRemainingColumn(),
            console=self.console,
            transient=False,
            disable=self.no_progress_bar,
        )

    def _get_iterable_length(self, iterable: Iterable[T]) -> int | None:
        """
        Attempt to determine the length of an iterable without consuming it.

        Tries .count() first (for Django querysets - executes SELECT COUNT(*)),
        then falls back to len() for sequences.

        Args:
            iterable: The iterable to measure.

        Returns:
            The length if determinable, None otherwise.
        """
        # Django querysets have .count() which is a SELECT COUNT(*)
        # This is much more efficient than len() which evaluates the queryset
        if hasattr(iterable, "count") and callable(iterable.count):
            try:
                return iterable.count()
            except TypeError:
                # list.count() requires an argument, fall through to len()
                pass

        # Fall back to len() for sequences
        try:
            return len(iterable)  # type: ignore[arg-type]
        except TypeError:
            # Not sized (e.g. a generator): the total is unknown
            return None

    def track(
        self,
        iterable: Iterable[T],
        *,
        description: str = "Processing...",
        total: int | None = None,
    ) -> Generator[T, None, None]:
        """
        Iterate over items with an optional progress bar.

        Respects --no-progress-bar flag. When disabled, simply yields items
        without any progress display.

        Args:
            iterable: The items to iterate over.
            description: Text to display alongside the progress bar.
            total: Total number of items. If None, attempts to determine
                automatically via .count() (for querysets) or len().

        Yields:
            Items from the iterable.

        Example:
            for doc in self.track(documents, description="Renaming..."):
                process(doc)
        """
        if self.no_progress_bar:
            # Return before measuring the iterable so no counting work
            # (e.g. a COUNT query) is done for a bar nobody will see
            yield from iterable
            return

        # Attempt to determine total if not provided
        if total is None:
            total = self._get_iterable_length(iterable)

        with self._create_progress(description) as progress:
            task_id = progress.add_task(description, total=total)
            for item in iterable:
                yield item
                # Advance once the consumer has finished with the item
                progress.advance(task_id)

    def process_parallel(
        self,
        fn: Callable[[T], R],
        items: Sequence[T],
        *,
        description: str = "Processing...",
    ) -> Generator[ProcessResult[T, R], None, None]:
        """
        Process items in parallel with progress tracking.

        When --processes=1, runs sequentially in the main process without
        spawning subprocesses. This is critical for testing, as multiprocessing
        breaks fixtures, mocks, and database transactions.

        When --processes > 1, uses ProcessPoolExecutor and automatically closes
        database connections before spawning workers (required for PostgreSQL).

        Exceptions raised by ``fn`` are captured in the yielded ProcessResult
        rather than propagated.

        Args:
            fn: Function to apply to each item. Must be picklable for parallel
                execution (i.e., defined at module level, not a lambda or closure).
            items: Sequence of items to process.
            description: Text to display alongside the progress bar.

        Yields:
            ProcessResult for each item, containing the item, result, and any error.
            In parallel mode results arrive in completion order, not input order.

        Example:
            def regenerate_thumbnail(doc_id: int) -> Path:
                ...

            for result in self.process_parallel(regenerate_thumbnail, doc_ids):
                if result.error:
                    self.console.print(f"[red]Failed {result.item}[/red]")
        """
        total = len(items)

        if self.process_count == 1:
            # Sequential execution in main process - critical for testing
            yield from self._process_sequential(fn, items, description, total)
        else:
            # Parallel execution with ProcessPoolExecutor
            yield from self._process_parallel(fn, items, description, total)

    def _process_sequential(
        self,
        fn: Callable[[T], R],
        items: Sequence[T],
        description: str,
        total: int,
    ) -> Generator[ProcessResult[T, R], None, None]:
        """
        Process items sequentially in the main process.

        Args:
            fn: Function to apply to each item.
            items: Items to process, in order.
            description: Text to display alongside the progress bar.
            total: Number of items, used as the progress bar total.

        Yields:
            One ProcessResult per item, in input order.
        """
        for item in self.track(items, description=description, total=total):
            # Only the call to fn is guarded: an exception thrown into this
            # generator at the yield below must not be reported as a failure
            # of the item
            try:
                value = fn(item)
            except Exception as exc:
                outcome: ProcessResult[T, R] = ProcessResult(
                    item=item,
                    result=None,
                    error=exc,
                )
            else:
                outcome = ProcessResult(item=item, result=value, error=None)
            yield outcome

    def _process_parallel(
        self,
        fn: Callable[[T], R],
        items: Sequence[T],
        description: str,
        total: int,
    ) -> Generator[ProcessResult[T, R], None, None]:
        """
        Process items in parallel using ProcessPoolExecutor.

        Args:
            fn: Picklable function to apply to each item in a worker process.
            items: Items to process.
            description: Text to display alongside the progress bar.
            total: Number of items, used as the progress bar total.

        Yields:
            One ProcessResult per item, in completion order.
        """
        # Close database connections before forking - required for PostgreSQL
        db.connections.close_all()

        with self._create_progress(description) as progress:
            task_id = progress.add_task(description, total=total)

            with ProcessPoolExecutor(max_workers=self.process_count) as executor:
                # Submit all tasks and map futures back to items
                future_to_item = {executor.submit(fn, item): item for item in items}

                try:
                    # Yield results as they complete
                    for future in as_completed(future_to_item):
                        item = future_to_item[future]
                        # Only retrieving the result is guarded, so errors
                        # raised at the yield are not attributed to the item
                        try:
                            value = future.result()
                        except Exception as exc:
                            outcome: ProcessResult[T, R] = ProcessResult(
                                item=item,
                                result=None,
                                error=exc,
                            )
                        else:
                            outcome = ProcessResult(
                                item=item,
                                result=value,
                                error=None,
                            )
                        progress.advance(task_id)
                        yield outcome
                finally:
                    # On an early exit (consumer stopped iterating, Ctrl-C)
                    # drop work which has not started yet; otherwise leaving
                    # the executor block would wait for every queued item.
                    # Cancelling a finished or running future is a no-op.
                    for pending in future_to_item:
                        pending.cancel()
|
||||||
@@ -1,20 +1,15 @@
|
|||||||
import logging
|
import logging
|
||||||
import multiprocessing
|
|
||||||
|
|
||||||
import tqdm
|
|
||||||
from django import db
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.core.management.base import BaseCommand
|
|
||||||
|
|
||||||
from documents.management.commands.mixins import MultiProcessMixin
|
from documents.management.commands.base import PaperlessCommand
|
||||||
from documents.management.commands.mixins import ProgressBarMixin
|
|
||||||
from documents.models import Document
|
from documents.models import Document
|
||||||
from documents.tasks import update_document_content_maybe_archive_file
|
from documents.tasks import update_document_content_maybe_archive_file
|
||||||
|
|
||||||
logger = logging.getLogger("paperless.management.archiver")
|
logger = logging.getLogger("paperless.management.archiver")
|
||||||
|
|
||||||
|
|
||||||
class Command(MultiProcessMixin, ProgressBarMixin, BaseCommand):
|
class Command(PaperlessCommand):
|
||||||
help = (
|
help = (
|
||||||
"Using the current classification model, assigns correspondents, tags "
|
"Using the current classification model, assigns correspondents, tags "
|
||||||
"and document types to all documents, effectively allowing you to "
|
"and document types to all documents, effectively allowing you to "
|
||||||
@@ -22,7 +17,10 @@ class Command(MultiProcessMixin, ProgressBarMixin, BaseCommand):
|
|||||||
"modified) after their initial import."
|
"modified) after their initial import."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
supports_multiprocessing = True
|
||||||
|
|
||||||
def add_arguments(self, parser):
|
def add_arguments(self, parser):
|
||||||
|
super().add_arguments(parser)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-f",
|
"-f",
|
||||||
"--overwrite",
|
"--overwrite",
|
||||||
@@ -44,13 +42,8 @@ class Command(MultiProcessMixin, ProgressBarMixin, BaseCommand):
|
|||||||
"run on this specific document."
|
"run on this specific document."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self.add_argument_progress_bar_mixin(parser)
|
|
||||||
self.add_argument_processes_mixin(parser)
|
|
||||||
|
|
||||||
def handle(self, *args, **options):
|
def handle(self, *args, **options):
|
||||||
self.handle_processes_mixin(**options)
|
|
||||||
self.handle_progress_bar_mixin(**options)
|
|
||||||
|
|
||||||
settings.SCRATCH_DIR.mkdir(parents=True, exist_ok=True)
|
settings.SCRATCH_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
overwrite = options["overwrite"]
|
overwrite = options["overwrite"]
|
||||||
@@ -60,35 +53,21 @@ class Command(MultiProcessMixin, ProgressBarMixin, BaseCommand):
|
|||||||
else:
|
else:
|
||||||
documents = Document.objects.all()
|
documents = Document.objects.all()
|
||||||
|
|
||||||
document_ids = list(
|
document_ids = [
|
||||||
map(
|
doc.id for doc in documents if overwrite or not doc.has_archive_version
|
||||||
lambda doc: doc.id,
|
]
|
||||||
filter(lambda d: overwrite or not d.has_archive_version, documents),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Note to future self: this prevents django from reusing database
|
|
||||||
# connections between processes, which is bad and does not work
|
|
||||||
# with postgres.
|
|
||||||
db.connections.close_all()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logging.getLogger().handlers[0].level = logging.ERROR
|
logging.getLogger().handlers[0].level = logging.ERROR
|
||||||
|
|
||||||
if self.process_count == 1:
|
for result in self.process_parallel(
|
||||||
for doc_id in document_ids:
|
update_document_content_maybe_archive_file,
|
||||||
update_document_content_maybe_archive_file(doc_id)
|
document_ids,
|
||||||
else: # pragma: no cover
|
description="Archiving...",
|
||||||
with multiprocessing.Pool(self.process_count) as pool:
|
):
|
||||||
list(
|
if result.error:
|
||||||
tqdm.tqdm(
|
self.console.print(
|
||||||
pool.imap_unordered(
|
f"[red]Failed document {result.item}: {result.error}[/red]",
|
||||||
update_document_content_maybe_archive_file,
|
|
||||||
document_ids,
|
|
||||||
),
|
|
||||||
total=len(document_ids),
|
|
||||||
disable=self.no_progress_bar,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
self.stdout.write(self.style.NOTICE("Aborting..."))
|
self.console.print("[yellow]Aborting...[/yellow]")
|
||||||
|
|||||||
@@ -1,24 +1,20 @@
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
import multiprocessing
|
|
||||||
from typing import Final
|
from typing import Final
|
||||||
|
|
||||||
import rapidfuzz
|
import rapidfuzz
|
||||||
import tqdm
|
|
||||||
from django.core.management import BaseCommand
|
|
||||||
from django.core.management import CommandError
|
from django.core.management import CommandError
|
||||||
|
|
||||||
from documents.management.commands.mixins import MultiProcessMixin
|
from documents.management.commands.base import PaperlessCommand
|
||||||
from documents.management.commands.mixins import ProgressBarMixin
|
|
||||||
from documents.models import Document
|
from documents.models import Document
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True, slots=True)
|
||||||
class _WorkPackage:
|
class _WorkPackage:
|
||||||
first_doc: Document
|
first_doc: Document
|
||||||
second_doc: Document
|
second_doc: Document
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True, slots=True)
|
||||||
class _WorkResult:
|
class _WorkResult:
|
||||||
doc_one_pk: int
|
doc_one_pk: int
|
||||||
doc_two_pk: int
|
doc_two_pk: int
|
||||||
@@ -31,22 +27,23 @@ class _WorkResult:
|
|||||||
def _process_and_match(work: _WorkPackage) -> _WorkResult:
|
def _process_and_match(work: _WorkPackage) -> _WorkResult:
|
||||||
"""
|
"""
|
||||||
Does basic processing of document content, gets the basic ratio
|
Does basic processing of document content, gets the basic ratio
|
||||||
and returns the result package
|
and returns the result package.
|
||||||
"""
|
"""
|
||||||
# Normalize the string some, lower case, whitespace, etc
|
|
||||||
first_string = rapidfuzz.utils.default_process(work.first_doc.content)
|
first_string = rapidfuzz.utils.default_process(work.first_doc.content)
|
||||||
second_string = rapidfuzz.utils.default_process(work.second_doc.content)
|
second_string = rapidfuzz.utils.default_process(work.second_doc.content)
|
||||||
|
|
||||||
# Basic matching ratio
|
|
||||||
match = rapidfuzz.fuzz.ratio(first_string, second_string)
|
match = rapidfuzz.fuzz.ratio(first_string, second_string)
|
||||||
|
|
||||||
return _WorkResult(work.first_doc.pk, work.second_doc.pk, match)
|
return _WorkResult(work.first_doc.pk, work.second_doc.pk, match)
|
||||||
|
|
||||||
|
|
||||||
class Command(MultiProcessMixin, ProgressBarMixin, BaseCommand):
|
class Command(PaperlessCommand):
|
||||||
help = "Searches for documents where the content almost matches"
|
help = "Searches for documents where the content almost matches"
|
||||||
|
|
||||||
|
supports_multiprocessing = True
|
||||||
|
|
||||||
def add_arguments(self, parser):
|
def add_arguments(self, parser):
|
||||||
|
super().add_arguments(parser)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--ratio",
|
"--ratio",
|
||||||
default=85.0,
|
default=85.0,
|
||||||
@@ -59,91 +56,72 @@ class Command(MultiProcessMixin, ProgressBarMixin, BaseCommand):
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="If set, one document of matches above the ratio WILL BE DELETED",
|
help="If set, one document of matches above the ratio WILL BE DELETED",
|
||||||
)
|
)
|
||||||
self.add_argument_progress_bar_mixin(parser)
|
|
||||||
self.add_argument_processes_mixin(parser)
|
|
||||||
|
|
||||||
def handle(self, *args, **options):
|
def handle(self, *args, **options):
|
||||||
RATIO_MIN: Final[float] = 0.0
|
RATIO_MIN: Final[float] = 0.0
|
||||||
RATIO_MAX: Final[float] = 100.0
|
RATIO_MAX: Final[float] = 100.0
|
||||||
|
|
||||||
self.handle_processes_mixin(**options)
|
|
||||||
self.handle_progress_bar_mixin(**options)
|
|
||||||
|
|
||||||
if options["delete"]:
|
if options["delete"]:
|
||||||
self.stdout.write(
|
self.console.print(
|
||||||
self.style.WARNING(
|
"[yellow]The command is configured to delete documents. "
|
||||||
"The command is configured to delete documents. Use with caution",
|
"Use with caution.[/yellow]",
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
opt_ratio = options["ratio"]
|
opt_ratio = options["ratio"]
|
||||||
checked_pairs: set[tuple[int, int]] = set()
|
checked_pairs: set[tuple[int, int]] = set()
|
||||||
work_pkgs: list[_WorkPackage] = []
|
work_pkgs: list[_WorkPackage] = []
|
||||||
|
|
||||||
# Ratio is a float from 0.0 to 100.0
|
|
||||||
if opt_ratio < RATIO_MIN or opt_ratio > RATIO_MAX:
|
if opt_ratio < RATIO_MIN or opt_ratio > RATIO_MAX:
|
||||||
raise CommandError("The ratio must be between 0 and 100")
|
raise CommandError("The ratio must be between 0 and 100")
|
||||||
|
|
||||||
all_docs = Document.objects.all().order_by("id")
|
all_docs = Document.objects.all().order_by("id")
|
||||||
|
|
||||||
# Build work packages for processing
|
|
||||||
for first_doc in all_docs:
|
for first_doc in all_docs:
|
||||||
for second_doc in all_docs:
|
for second_doc in all_docs:
|
||||||
# doc to doc is obviously not useful
|
|
||||||
if first_doc.pk == second_doc.pk:
|
if first_doc.pk == second_doc.pk:
|
||||||
continue
|
continue
|
||||||
# Skip empty documents (e.g. password-protected)
|
|
||||||
if first_doc.content.strip() == "" or second_doc.content.strip() == "":
|
if first_doc.content.strip() == "" or second_doc.content.strip() == "":
|
||||||
continue
|
continue
|
||||||
# Skip matching which have already been matched together
|
|
||||||
# doc 1 to doc 2 is the same as doc 2 to doc 1
|
|
||||||
doc_1_to_doc_2 = (first_doc.pk, second_doc.pk)
|
doc_1_to_doc_2 = (first_doc.pk, second_doc.pk)
|
||||||
doc_2_to_doc_1 = doc_1_to_doc_2[::-1]
|
doc_2_to_doc_1 = doc_1_to_doc_2[::-1]
|
||||||
if doc_1_to_doc_2 in checked_pairs or doc_2_to_doc_1 in checked_pairs:
|
if doc_1_to_doc_2 in checked_pairs or doc_2_to_doc_1 in checked_pairs:
|
||||||
continue
|
continue
|
||||||
checked_pairs.update([doc_1_to_doc_2, doc_2_to_doc_1])
|
checked_pairs.update([doc_1_to_doc_2, doc_2_to_doc_1])
|
||||||
# Actually something useful to work on now
|
|
||||||
work_pkgs.append(_WorkPackage(first_doc, second_doc))
|
work_pkgs.append(_WorkPackage(first_doc, second_doc))
|
||||||
|
|
||||||
# Don't spin up a pool of 1 process
|
results: list[_WorkResult] = []
|
||||||
if self.process_count == 1:
|
for result in self.process_parallel(
|
||||||
results = []
|
_process_and_match,
|
||||||
for work in tqdm.tqdm(work_pkgs, disable=self.no_progress_bar):
|
work_pkgs,
|
||||||
results.append(_process_and_match(work))
|
description="Matching...",
|
||||||
else: # pragma: no cover
|
):
|
||||||
with multiprocessing.Pool(processes=self.process_count) as pool:
|
if result.error:
|
||||||
results = list(
|
self.console.print(
|
||||||
tqdm.tqdm(
|
f"[red]Failed: {result.error}[/red]",
|
||||||
pool.imap_unordered(_process_and_match, work_pkgs),
|
|
||||||
total=len(work_pkgs),
|
|
||||||
disable=self.no_progress_bar,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
elif result.result is not None:
|
||||||
|
results.append(result.result)
|
||||||
|
|
||||||
# Check results
|
messages: list[str] = []
|
||||||
messages = []
|
maybe_delete_ids: list[int] = []
|
||||||
maybe_delete_ids = []
|
for match_result in sorted(results):
|
||||||
for result in sorted(results):
|
if match_result.ratio >= opt_ratio:
|
||||||
if result.ratio >= opt_ratio:
|
|
||||||
messages.append(
|
messages.append(
|
||||||
self.style.NOTICE(
|
self.style.NOTICE(
|
||||||
f"Document {result.doc_one_pk} fuzzy match"
|
f"Document {match_result.doc_one_pk} fuzzy match"
|
||||||
f" to {result.doc_two_pk} (confidence {result.ratio:.3f})\n",
|
f" to {match_result.doc_two_pk}"
|
||||||
|
f" (confidence {match_result.ratio:.3f})\n",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
maybe_delete_ids.append(result.doc_two_pk)
|
maybe_delete_ids.append(match_result.doc_two_pk)
|
||||||
|
|
||||||
if len(messages) == 0:
|
if len(messages) == 0:
|
||||||
messages.append(
|
messages.append(self.style.SUCCESS("No matches found\n"))
|
||||||
self.style.SUCCESS("No matches found\n"),
|
self.stdout.writelines(messages)
|
||||||
)
|
|
||||||
self.stdout.writelines(
|
|
||||||
messages,
|
|
||||||
)
|
|
||||||
if options["delete"]:
|
if options["delete"]:
|
||||||
self.stdout.write(
|
self.console.print(
|
||||||
self.style.NOTICE(
|
f"[yellow]Deleting {len(maybe_delete_ids)} documents "
|
||||||
f"Deleting {len(maybe_delete_ids)} documents based on ratio matches",
|
f"based on ratio matches[/yellow]",
|
||||||
),
|
|
||||||
)
|
)
|
||||||
Document.objects.filter(pk__in=maybe_delete_ids).delete()
|
Document.objects.filter(pk__in=maybe_delete_ids).delete()
|
||||||
|
|||||||
@@ -1,25 +1,12 @@
|
|||||||
import logging
|
|
||||||
|
|
||||||
import tqdm
|
|
||||||
from django.core.management.base import BaseCommand
|
|
||||||
from django.db.models.signals import post_save
|
from django.db.models.signals import post_save
|
||||||
|
|
||||||
from documents.management.commands.mixins import ProgressBarMixin
|
from documents.management.commands.base import PaperlessCommand
|
||||||
from documents.models import Document
|
from documents.models import Document
|
||||||
|
|
||||||
|
|
||||||
class Command(ProgressBarMixin, BaseCommand):
|
class Command(PaperlessCommand):
|
||||||
help = "This will rename all documents to match the latest filename format."
|
help = "Rename all documents"
|
||||||
|
|
||||||
def add_arguments(self, parser):
|
|
||||||
self.add_argument_progress_bar_mixin(parser)
|
|
||||||
|
|
||||||
def handle(self, *args, **options):
|
def handle(self, *args, **options):
|
||||||
self.handle_progress_bar_mixin(**options)
|
for document in self.track(Document.objects.all(), description="Renaming..."):
|
||||||
logging.getLogger().handlers[0].level = logging.ERROR
|
|
||||||
|
|
||||||
for document in tqdm.tqdm(
|
|
||||||
Document.objects.all(),
|
|
||||||
disable=self.no_progress_bar,
|
|
||||||
):
|
|
||||||
post_save.send(Document, instance=document, created=False)
|
post_save.send(Document, instance=document, created=False)
|
||||||
|
|||||||
@@ -1,10 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
import tqdm
|
|
||||||
from django.core.management.base import BaseCommand
|
|
||||||
|
|
||||||
from documents.classifier import load_classifier
|
from documents.classifier import load_classifier
|
||||||
from documents.management.commands.mixins import ProgressBarMixin
|
from documents.management.commands.base import PaperlessCommand
|
||||||
from documents.models import Document
|
from documents.models import Document
|
||||||
from documents.signals.handlers import set_correspondent
|
from documents.signals.handlers import set_correspondent
|
||||||
from documents.signals.handlers import set_document_type
|
from documents.signals.handlers import set_document_type
|
||||||
@@ -14,7 +11,7 @@ from documents.signals.handlers import set_tags
|
|||||||
logger = logging.getLogger("paperless.management.retagger")
|
logger = logging.getLogger("paperless.management.retagger")
|
||||||
|
|
||||||
|
|
||||||
class Command(ProgressBarMixin, BaseCommand):
|
class Command(PaperlessCommand):
|
||||||
help = (
|
help = (
|
||||||
"Using the current classification model, assigns correspondents, tags "
|
"Using the current classification model, assigns correspondents, tags "
|
||||||
"and document types to all documents, effectively allowing you to "
|
"and document types to all documents, effectively allowing you to "
|
||||||
@@ -23,6 +20,7 @@ class Command(ProgressBarMixin, BaseCommand):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def add_arguments(self, parser):
|
def add_arguments(self, parser):
|
||||||
|
super().add_arguments(parser)
|
||||||
parser.add_argument("-c", "--correspondent", default=False, action="store_true")
|
parser.add_argument("-c", "--correspondent", default=False, action="store_true")
|
||||||
parser.add_argument("-T", "--tags", default=False, action="store_true")
|
parser.add_argument("-T", "--tags", default=False, action="store_true")
|
||||||
parser.add_argument("-t", "--document_type", default=False, action="store_true")
|
parser.add_argument("-t", "--document_type", default=False, action="store_true")
|
||||||
@@ -34,7 +32,7 @@ class Command(ProgressBarMixin, BaseCommand):
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help=(
|
help=(
|
||||||
"By default this command won't try to assign a correspondent "
|
"By default this command won't try to assign a correspondent "
|
||||||
"if more than one matches the document. Use this flag if "
|
"if more than one matches the document. Use this flag if "
|
||||||
"you'd rather it just pick the first one it finds."
|
"you'd rather it just pick the first one it finds."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@@ -49,7 +47,6 @@ class Command(ProgressBarMixin, BaseCommand):
|
|||||||
"and tags that do not match anymore due to changed rules."
|
"and tags that do not match anymore due to changed rules."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self.add_argument_progress_bar_mixin(parser)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--suggest",
|
"--suggest",
|
||||||
default=False,
|
default=False,
|
||||||
@@ -68,8 +65,6 @@ class Command(ProgressBarMixin, BaseCommand):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def handle(self, *args, **options):
|
def handle(self, *args, **options):
|
||||||
self.handle_progress_bar_mixin(**options)
|
|
||||||
|
|
||||||
if options["inbox_only"]:
|
if options["inbox_only"]:
|
||||||
queryset = Document.objects.filter(tags__is_inbox_tag=True)
|
queryset = Document.objects.filter(tags__is_inbox_tag=True)
|
||||||
else:
|
else:
|
||||||
@@ -84,7 +79,7 @@ class Command(ProgressBarMixin, BaseCommand):
|
|||||||
|
|
||||||
classifier = load_classifier()
|
classifier = load_classifier()
|
||||||
|
|
||||||
for document in tqdm.tqdm(documents, disable=self.no_progress_bar):
|
for document in self.track(documents, description="Retagging..."):
|
||||||
if options["correspondent"]:
|
if options["correspondent"]:
|
||||||
set_correspondent(
|
set_correspondent(
|
||||||
sender=None,
|
sender=None,
|
||||||
@@ -122,6 +117,7 @@ class Command(ProgressBarMixin, BaseCommand):
|
|||||||
stdout=self.stdout,
|
stdout=self.stdout,
|
||||||
style_func=self.style,
|
style_func=self.style,
|
||||||
)
|
)
|
||||||
|
|
||||||
if options["storage_path"]:
|
if options["storage_path"]:
|
||||||
set_storage_path(
|
set_storage_path(
|
||||||
sender=None,
|
sender=None,
|
||||||
|
|||||||
@@ -1,43 +1,39 @@
|
|||||||
import logging
|
import logging
|
||||||
import multiprocessing
|
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
import tqdm
|
from documents.management.commands.base import PaperlessCommand
|
||||||
from django import db
|
|
||||||
from django.core.management.base import BaseCommand
|
|
||||||
|
|
||||||
from documents.management.commands.mixins import MultiProcessMixin
|
|
||||||
from documents.management.commands.mixins import ProgressBarMixin
|
|
||||||
from documents.models import Document
|
from documents.models import Document
|
||||||
from documents.parsers import get_parser_class_for_mime_type
|
from documents.parsers import get_parser_class_for_mime_type
|
||||||
|
|
||||||
|
|
||||||
def _process_document(doc_id) -> None:
|
def _process_document(doc_id: int) -> None:
|
||||||
document: Document = Document.objects.get(id=doc_id)
|
document: Document = Document.objects.get(id=doc_id)
|
||||||
parser_class = get_parser_class_for_mime_type(document.mime_type)
|
parser_class = get_parser_class_for_mime_type(document.mime_type)
|
||||||
|
|
||||||
if parser_class:
|
if parser_class is None:
|
||||||
parser = parser_class(logging_group=None)
|
|
||||||
else:
|
|
||||||
print(f"{document} No parser for mime type {document.mime_type}") # noqa: T201
|
print(f"{document} No parser for mime type {document.mime_type}") # noqa: T201
|
||||||
return
|
return
|
||||||
|
|
||||||
|
parser = parser_class(logging_group=None)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
thumb = parser.get_thumbnail(
|
thumb = parser.get_thumbnail(
|
||||||
document.source_path,
|
document.source_path,
|
||||||
document.mime_type,
|
document.mime_type,
|
||||||
document.get_public_filename(),
|
document.get_public_filename(),
|
||||||
)
|
)
|
||||||
|
|
||||||
shutil.move(thumb, document.thumbnail_path)
|
shutil.move(thumb, document.thumbnail_path)
|
||||||
finally:
|
finally:
|
||||||
parser.cleanup()
|
parser.cleanup()
|
||||||
|
|
||||||
|
|
||||||
class Command(MultiProcessMixin, ProgressBarMixin, BaseCommand):
|
class Command(PaperlessCommand):
|
||||||
help = "This will regenerate the thumbnails for all documents."
|
help = "This will regenerate the thumbnails for all documents."
|
||||||
|
|
||||||
|
supports_multiprocessing = True
|
||||||
|
|
||||||
def add_arguments(self, parser) -> None:
|
def add_arguments(self, parser) -> None:
|
||||||
|
super().add_arguments(parser)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-d",
|
"-d",
|
||||||
"--document",
|
"--document",
|
||||||
@@ -49,36 +45,23 @@ class Command(MultiProcessMixin, ProgressBarMixin, BaseCommand):
|
|||||||
"run on this specific document."
|
"run on this specific document."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self.add_argument_progress_bar_mixin(parser)
|
|
||||||
self.add_argument_processes_mixin(parser)
|
|
||||||
|
|
||||||
def handle(self, *args, **options):
|
def handle(self, *args, **options):
|
||||||
logging.getLogger().handlers[0].level = logging.ERROR
|
logging.getLogger().handlers[0].level = logging.ERROR
|
||||||
|
|
||||||
self.handle_processes_mixin(**options)
|
|
||||||
self.handle_progress_bar_mixin(**options)
|
|
||||||
|
|
||||||
if options["document"]:
|
if options["document"]:
|
||||||
documents = Document.objects.filter(pk=options["document"])
|
documents = Document.objects.filter(pk=options["document"])
|
||||||
else:
|
else:
|
||||||
documents = Document.objects.all()
|
documents = Document.objects.all()
|
||||||
|
|
||||||
ids = [doc.id for doc in documents]
|
ids = list(documents.values_list("id", flat=True))
|
||||||
|
|
||||||
# Note to future self: this prevents django from reusing database
|
for result in self.process_parallel(
|
||||||
# connections between processes, which is bad and does not work
|
_process_document,
|
||||||
# with postgres.
|
ids,
|
||||||
db.connections.close_all()
|
description="Regenerating thumbnails...",
|
||||||
|
):
|
||||||
if self.process_count == 1:
|
if result.error:
|
||||||
for doc_id in ids:
|
self.console.print(
|
||||||
_process_document(doc_id)
|
f"[red]Failed document {result.item}: {result.error}[/red]",
|
||||||
else: # pragma: no cover
|
|
||||||
with multiprocessing.Pool(processes=self.process_count) as pool:
|
|
||||||
list(
|
|
||||||
tqdm.tqdm(
|
|
||||||
pool.imap_unordered(_process_document, ids),
|
|
||||||
total=len(ids),
|
|
||||||
disable=self.no_progress_bar,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,27 +1,21 @@
|
|||||||
from auditlog.models import LogEntry
|
from auditlog.models import LogEntry
|
||||||
from django.core.management.base import BaseCommand
|
|
||||||
from django.db import transaction
|
from django.db import transaction
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from documents.management.commands.mixins import ProgressBarMixin
|
from documents.management.commands.base import PaperlessCommand
|
||||||
|
|
||||||
|
|
||||||
class Command(BaseCommand, ProgressBarMixin):
|
class Command(PaperlessCommand):
|
||||||
"""
|
"""Prune the audit logs of objects that no longer exist."""
|
||||||
Prune the audit logs of objects that no longer exist.
|
|
||||||
"""
|
|
||||||
|
|
||||||
help = "Prunes the audit logs of objects that no longer exist."
|
help = "Prunes the audit logs of objects that no longer exist."
|
||||||
|
|
||||||
def add_arguments(self, parser):
|
def handle(self, *args, **options):
|
||||||
self.add_argument_progress_bar_mixin(parser)
|
|
||||||
|
|
||||||
def handle(self, **options):
|
|
||||||
self.handle_progress_bar_mixin(**options)
|
|
||||||
with transaction.atomic():
|
with transaction.atomic():
|
||||||
for log_entry in tqdm(LogEntry.objects.all(), disable=self.no_progress_bar):
|
for log_entry in self.track(
|
||||||
|
LogEntry.objects.all(),
|
||||||
|
description="Pruning audit logs...",
|
||||||
|
):
|
||||||
model_class = log_entry.content_type.model_class()
|
model_class = log_entry.content_type.model_class()
|
||||||
# use global_objects for SoftDeleteModel
|
|
||||||
objects = (
|
objects = (
|
||||||
model_class.global_objects
|
model_class.global_objects
|
||||||
if hasattr(model_class, "global_objects")
|
if hasattr(model_class, "global_objects")
|
||||||
@@ -32,8 +26,8 @@ class Command(BaseCommand, ProgressBarMixin):
|
|||||||
and not objects.filter(pk=log_entry.object_id).exists()
|
and not objects.filter(pk=log_entry.object_id).exists()
|
||||||
):
|
):
|
||||||
log_entry.delete()
|
log_entry.delete()
|
||||||
tqdm.write(
|
self.console.print(
|
||||||
self.style.NOTICE(
|
f"Deleted audit log entry for "
|
||||||
f"Deleted audit log entry for {model_class.__name__} #{log_entry.object_id}",
|
f"{model_class.__name__} #{log_entry.object_id}",
|
||||||
),
|
style="yellow",
|
||||||
)
|
)
|
||||||
|
|||||||
0
src/documents/tests/management/__init__.py
Normal file
0
src/documents/tests/management/__init__.py
Normal file
579
src/documents/tests/management/test_management_base_cmd.py
Normal file
579
src/documents/tests/management/test_management_base_cmd.py
Normal file
@@ -0,0 +1,579 @@
|
|||||||
|
"""Tests for PaperlessCommand base class."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import io
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
from typing import ClassVar
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from django.core.management import CommandError
|
||||||
|
from rich.console import Console
|
||||||
|
|
||||||
|
from documents.management.commands.base import PaperlessCommand
|
||||||
|
from documents.management.commands.base import ProcessResult
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from pytest_mock import MockerFixture
|
||||||
|
|
||||||
|
|
||||||
|
# --- Test Commands ---
|
||||||
|
# These simulate real command implementations for testing
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleCommand(PaperlessCommand):
|
||||||
|
"""Command with default settings (progress bar, no multiprocessing)."""
|
||||||
|
|
||||||
|
help = "Simple test command"
|
||||||
|
|
||||||
|
def handle(self, *args, **options):
|
||||||
|
items = list(range(5))
|
||||||
|
results = []
|
||||||
|
for item in self.track(items, description="Processing..."):
|
||||||
|
results.append(item * 2)
|
||||||
|
self.stdout.write(f"Results: {results}")
|
||||||
|
|
||||||
|
|
||||||
|
class NoProgressBarCommand(PaperlessCommand):
|
||||||
|
"""Command with progress bar disabled."""
|
||||||
|
|
||||||
|
help = "No progress bar command"
|
||||||
|
supports_progress_bar = False
|
||||||
|
|
||||||
|
def handle(self, *args, **options):
|
||||||
|
items = list(range(3))
|
||||||
|
for item in self.track(items):
|
||||||
|
pass
|
||||||
|
self.stdout.write("Done")
|
||||||
|
|
||||||
|
|
||||||
|
class MultiprocessCommand(PaperlessCommand):
|
||||||
|
"""Command with multiprocessing support."""
|
||||||
|
|
||||||
|
help = "Multiprocess test command"
|
||||||
|
supports_multiprocessing = True
|
||||||
|
|
||||||
|
def handle(self, *args, **options):
|
||||||
|
items = list(range(5))
|
||||||
|
results = []
|
||||||
|
for result in self.process_parallel(
|
||||||
|
_double_value,
|
||||||
|
items,
|
||||||
|
description="Processing...",
|
||||||
|
):
|
||||||
|
results.append(result)
|
||||||
|
successes = sum(1 for r in results if r.success)
|
||||||
|
self.stdout.write(f"Successes: {successes}")
|
||||||
|
|
||||||
|
|
||||||
|
# --- Helper Functions for Multiprocessing ---
|
||||||
|
# Must be at module level to be picklable
|
||||||
|
|
||||||
|
|
||||||
|
def _double_value(x: int) -> int:
|
||||||
|
"""Double the input value."""
|
||||||
|
return x * 2
|
||||||
|
|
||||||
|
|
||||||
|
def _divide_ten_by(x: int) -> float:
|
||||||
|
"""Divide 10 by x. Raises ZeroDivisionError if x is 0."""
|
||||||
|
return 10 / x
|
||||||
|
|
||||||
|
|
||||||
|
# --- Fixtures ---
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def console() -> Console:
|
||||||
|
"""Create a non-interactive console for testing."""
|
||||||
|
return Console(force_terminal=False, force_interactive=False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def simple_command(console: Console) -> SimpleCommand:
|
||||||
|
"""Create a SimpleCommand instance configured for testing."""
|
||||||
|
command = SimpleCommand()
|
||||||
|
command.stdout = io.StringIO()
|
||||||
|
command.stderr = io.StringIO()
|
||||||
|
command.console = console
|
||||||
|
command.no_progress_bar = True
|
||||||
|
command.process_count = 1
|
||||||
|
return command
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def multiprocess_command(console: Console) -> MultiprocessCommand:
|
||||||
|
"""Create a MultiprocessCommand instance configured for testing."""
|
||||||
|
command = MultiprocessCommand()
|
||||||
|
command.stdout = io.StringIO()
|
||||||
|
command.stderr = io.StringIO()
|
||||||
|
command.console = console
|
||||||
|
command.no_progress_bar = True
|
||||||
|
command.process_count = 1
|
||||||
|
return command
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_queryset():
|
||||||
|
"""
|
||||||
|
Create a mock Django QuerySet that tracks method calls.
|
||||||
|
|
||||||
|
This verifies we use .count() instead of len() for querysets.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class MockQuerySet:
|
||||||
|
def __init__(self, items: list):
|
||||||
|
self._items = items
|
||||||
|
self.count_called = False
|
||||||
|
|
||||||
|
def count(self) -> int:
|
||||||
|
self.count_called = True
|
||||||
|
return len(self._items)
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return iter(self._items)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
raise AssertionError("len() should not be called on querysets")
|
||||||
|
|
||||||
|
return MockQuerySet
|
||||||
|
|
||||||
|
|
||||||
|
# --- Test Classes ---
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.management
|
||||||
|
class TestProcessResult:
|
||||||
|
"""Tests for the ProcessResult dataclass."""
|
||||||
|
|
||||||
|
def test_success_result(self):
|
||||||
|
result = ProcessResult(item=1, result=2, error=None)
|
||||||
|
|
||||||
|
assert result.item == 1
|
||||||
|
assert result.result == 2
|
||||||
|
assert result.error is None
|
||||||
|
assert result.success is True
|
||||||
|
|
||||||
|
def test_error_result(self):
|
||||||
|
error = ValueError("test error")
|
||||||
|
result = ProcessResult(item=1, result=None, error=error)
|
||||||
|
|
||||||
|
assert result.item == 1
|
||||||
|
assert result.result is None
|
||||||
|
assert result.error is error
|
||||||
|
assert result.success is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.management
|
||||||
|
class TestPaperlessCommandArguments:
|
||||||
|
"""Tests for argument parsing behavior."""
|
||||||
|
|
||||||
|
def test_progress_bar_argument_added_by_default(self):
|
||||||
|
command = SimpleCommand()
|
||||||
|
parser = command.create_parser("manage.py", "simple")
|
||||||
|
|
||||||
|
options = parser.parse_args(["--no-progress-bar"])
|
||||||
|
assert options.no_progress_bar is True
|
||||||
|
|
||||||
|
options = parser.parse_args([])
|
||||||
|
assert options.no_progress_bar is False
|
||||||
|
|
||||||
|
def test_progress_bar_argument_not_added_when_disabled(self):
|
||||||
|
command = NoProgressBarCommand()
|
||||||
|
parser = command.create_parser("manage.py", "noprogress")
|
||||||
|
|
||||||
|
options = parser.parse_args([])
|
||||||
|
assert not hasattr(options, "no_progress_bar")
|
||||||
|
|
||||||
|
def test_processes_argument_added_when_multiprocessing_enabled(self):
|
||||||
|
command = MultiprocessCommand()
|
||||||
|
parser = command.create_parser("manage.py", "multiprocess")
|
||||||
|
|
||||||
|
options = parser.parse_args(["--processes", "4"])
|
||||||
|
assert options.processes == 4
|
||||||
|
|
||||||
|
options = parser.parse_args([])
|
||||||
|
assert options.processes >= 1
|
||||||
|
|
||||||
|
def test_processes_argument_not_added_when_multiprocessing_disabled(self):
|
||||||
|
command = SimpleCommand()
|
||||||
|
parser = command.create_parser("manage.py", "simple")
|
||||||
|
|
||||||
|
options = parser.parse_args([])
|
||||||
|
assert not hasattr(options, "processes")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.management
|
||||||
|
class TestPaperlessCommandExecute:
|
||||||
|
"""Tests for the execute() setup behavior."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def base_options(self) -> dict:
|
||||||
|
"""Base options required for execute()."""
|
||||||
|
return {
|
||||||
|
"verbosity": 1,
|
||||||
|
"no_color": True,
|
||||||
|
"force_color": False,
|
||||||
|
"skip_checks": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("no_progress_bar_flag", "expected"),
|
||||||
|
[
|
||||||
|
pytest.param(False, False, id="progress-bar-enabled"),
|
||||||
|
pytest.param(True, True, id="progress-bar-disabled"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_no_progress_bar_state_set(
|
||||||
|
self,
|
||||||
|
base_options: dict,
|
||||||
|
*,
|
||||||
|
no_progress_bar_flag: bool,
|
||||||
|
expected: bool,
|
||||||
|
):
|
||||||
|
command = SimpleCommand()
|
||||||
|
command.stdout = io.StringIO()
|
||||||
|
command.stderr = io.StringIO()
|
||||||
|
|
||||||
|
options = {**base_options, "no_progress_bar": no_progress_bar_flag}
|
||||||
|
command.execute(**options)
|
||||||
|
|
||||||
|
assert command.no_progress_bar is expected
|
||||||
|
|
||||||
|
def test_no_progress_bar_always_true_when_not_supported(self, base_options: dict):
|
||||||
|
command = NoProgressBarCommand()
|
||||||
|
command.stdout = io.StringIO()
|
||||||
|
command.stderr = io.StringIO()
|
||||||
|
|
||||||
|
command.execute(**base_options)
|
||||||
|
|
||||||
|
assert command.no_progress_bar is True
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("processes", "expected"),
|
||||||
|
[
|
||||||
|
pytest.param(1, 1, id="single-process"),
|
||||||
|
pytest.param(4, 4, id="four-processes"),
|
||||||
|
pytest.param(8, 8, id="eight-processes"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_process_count_set(
|
||||||
|
self,
|
||||||
|
base_options: dict,
|
||||||
|
processes: int,
|
||||||
|
expected: int,
|
||||||
|
):
|
||||||
|
command = MultiprocessCommand()
|
||||||
|
command.stdout = io.StringIO()
|
||||||
|
command.stderr = io.StringIO()
|
||||||
|
|
||||||
|
options = {**base_options, "processes": processes, "no_progress_bar": True}
|
||||||
|
command.execute(**options)
|
||||||
|
|
||||||
|
assert command.process_count == expected
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"invalid_count",
|
||||||
|
[
|
||||||
|
pytest.param(0, id="zero"),
|
||||||
|
pytest.param(-1, id="negative"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_process_count_validation_rejects_invalid(
|
||||||
|
self,
|
||||||
|
base_options: dict,
|
||||||
|
invalid_count: int,
|
||||||
|
):
|
||||||
|
command = MultiprocessCommand()
|
||||||
|
command.stdout = io.StringIO()
|
||||||
|
command.stderr = io.StringIO()
|
||||||
|
|
||||||
|
options = {**base_options, "processes": invalid_count, "no_progress_bar": True}
|
||||||
|
|
||||||
|
with pytest.raises(CommandError, match="--processes must be at least 1"):
|
||||||
|
command.execute(**options)
|
||||||
|
|
||||||
|
def test_process_count_defaults_to_one_when_not_supported(self, base_options: dict):
|
||||||
|
command = SimpleCommand()
|
||||||
|
command.stdout = io.StringIO()
|
||||||
|
command.stderr = io.StringIO()
|
||||||
|
|
||||||
|
options = {**base_options, "no_progress_bar": True}
|
||||||
|
command.execute(**options)
|
||||||
|
|
||||||
|
assert command.process_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.management
|
||||||
|
class TestGetIterableLength:
|
||||||
|
"""Tests for the _get_iterable_length() method."""
|
||||||
|
|
||||||
|
def test_uses_count_method_for_querysets(
|
||||||
|
self,
|
||||||
|
simple_command: SimpleCommand,
|
||||||
|
mock_queryset,
|
||||||
|
):
|
||||||
|
"""Should use .count() for Django querysets (SELECT COUNT(*))."""
|
||||||
|
queryset = mock_queryset([1, 2, 3, 4, 5])
|
||||||
|
|
||||||
|
result = simple_command._get_iterable_length(queryset)
|
||||||
|
|
||||||
|
assert result == 5
|
||||||
|
assert queryset.count_called is True
|
||||||
|
|
||||||
|
def test_falls_back_to_len_for_sequences(self, simple_command: SimpleCommand):
|
||||||
|
"""Should use len() for regular sequences without .count()."""
|
||||||
|
items = [1, 2, 3, 4]
|
||||||
|
|
||||||
|
result = simple_command._get_iterable_length(items)
|
||||||
|
|
||||||
|
assert result == 4
|
||||||
|
|
||||||
|
def test_returns_none_for_generators(self, simple_command: SimpleCommand):
|
||||||
|
"""Should return None for iterables without len() or count()."""
|
||||||
|
|
||||||
|
def gen():
|
||||||
|
yield from [1, 2, 3]
|
||||||
|
|
||||||
|
result = simple_command._get_iterable_length(gen())
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_handles_count_not_callable(self, simple_command: SimpleCommand):
|
||||||
|
"""Should skip .count if it's not callable (edge case)."""
|
||||||
|
|
||||||
|
class WeirdObject:
|
||||||
|
count = 42 # Attribute, not method
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return 10
|
||||||
|
|
||||||
|
result = simple_command._get_iterable_length(WeirdObject())
|
||||||
|
|
||||||
|
assert result == 10
|
||||||
|
|
||||||
|
def test_handles_list_count_method(self, simple_command: SimpleCommand):
|
||||||
|
"""list.count(value) requires an argument, should fall back to len()."""
|
||||||
|
items = [1, 2, 3, 4, 5]
|
||||||
|
|
||||||
|
# list has a .count() method but it requires an argument
|
||||||
|
# _get_iterable_length should catch TypeError and use len() instead
|
||||||
|
result = simple_command._get_iterable_length(items)
|
||||||
|
|
||||||
|
assert result == 5
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.management
|
||||||
|
class TestTrack:
|
||||||
|
"""Tests for the track() method."""
|
||||||
|
|
||||||
|
def test_yields_all_items(self, simple_command: SimpleCommand):
|
||||||
|
items = [1, 2, 3, 4, 5]
|
||||||
|
|
||||||
|
result = list(simple_command.track(items))
|
||||||
|
|
||||||
|
assert result == items
|
||||||
|
|
||||||
|
def test_with_progress_bar_disabled(self, simple_command: SimpleCommand):
|
||||||
|
simple_command.no_progress_bar = True
|
||||||
|
items = ["a", "b", "c"]
|
||||||
|
|
||||||
|
result = list(simple_command.track(items, description="Test..."))
|
||||||
|
|
||||||
|
assert result == items
|
||||||
|
|
||||||
|
def test_with_progress_bar_enabled(self, simple_command: SimpleCommand):
|
||||||
|
simple_command.no_progress_bar = False
|
||||||
|
items = [1, 2, 3]
|
||||||
|
|
||||||
|
result = list(simple_command.track(items, description="Processing..."))
|
||||||
|
|
||||||
|
assert result == items
|
||||||
|
|
||||||
|
def test_with_explicit_total(self, simple_command: SimpleCommand):
|
||||||
|
simple_command.no_progress_bar = False
|
||||||
|
|
||||||
|
def gen():
|
||||||
|
yield from [1, 2, 3]
|
||||||
|
|
||||||
|
result = list(simple_command.track(gen(), total=3))
|
||||||
|
|
||||||
|
assert result == [1, 2, 3]
|
||||||
|
|
||||||
|
def test_with_generator_no_total(self, simple_command: SimpleCommand):
|
||||||
|
def gen():
|
||||||
|
yield from [1, 2, 3]
|
||||||
|
|
||||||
|
result = list(simple_command.track(gen()))
|
||||||
|
|
||||||
|
assert result == [1, 2, 3]
|
||||||
|
|
||||||
|
def test_empty_iterable(self, simple_command: SimpleCommand):
|
||||||
|
result = list(simple_command.track([]))
|
||||||
|
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
def test_uses_queryset_count(
|
||||||
|
self,
|
||||||
|
simple_command: SimpleCommand,
|
||||||
|
mock_queryset,
|
||||||
|
mocker: MockerFixture,
|
||||||
|
):
|
||||||
|
"""Verify track() uses .count() for querysets."""
|
||||||
|
simple_command.no_progress_bar = False
|
||||||
|
queryset = mock_queryset([1, 2, 3])
|
||||||
|
|
||||||
|
spy = mocker.spy(simple_command, "_get_iterable_length")
|
||||||
|
|
||||||
|
result = list(simple_command.track(queryset))
|
||||||
|
|
||||||
|
assert result == [1, 2, 3]
|
||||||
|
spy.assert_called_once_with(queryset)
|
||||||
|
assert queryset.count_called is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.management
|
||||||
|
class TestProcessParallel:
|
||||||
|
"""Tests for the process_parallel() method."""
|
||||||
|
|
||||||
|
def test_sequential_processing_single_process(
|
||||||
|
self,
|
||||||
|
multiprocess_command: MultiprocessCommand,
|
||||||
|
):
|
||||||
|
multiprocess_command.process_count = 1
|
||||||
|
items = [1, 2, 3, 4, 5]
|
||||||
|
|
||||||
|
results = list(multiprocess_command.process_parallel(_double_value, items))
|
||||||
|
|
||||||
|
assert len(results) == 5
|
||||||
|
assert all(r.success for r in results)
|
||||||
|
|
||||||
|
result_map = {r.item: r.result for r in results}
|
||||||
|
assert result_map == {1: 2, 2: 4, 3: 6, 4: 8, 5: 10}
|
||||||
|
|
||||||
|
def test_sequential_processing_handles_errors(
|
||||||
|
self,
|
||||||
|
multiprocess_command: MultiprocessCommand,
|
||||||
|
):
|
||||||
|
multiprocess_command.process_count = 1
|
||||||
|
items = [1, 2, 0, 4] # 0 causes ZeroDivisionError
|
||||||
|
|
||||||
|
results = list(multiprocess_command.process_parallel(_divide_ten_by, items))
|
||||||
|
|
||||||
|
assert len(results) == 4
|
||||||
|
|
||||||
|
successes = [r for r in results if r.success]
|
||||||
|
failures = [r for r in results if not r.success]
|
||||||
|
|
||||||
|
assert len(successes) == 3
|
||||||
|
assert len(failures) == 1
|
||||||
|
assert failures[0].item == 0
|
||||||
|
assert isinstance(failures[0].error, ZeroDivisionError)
|
||||||
|
|
||||||
|
def test_parallel_closes_db_connections(
|
||||||
|
self,
|
||||||
|
multiprocess_command: MultiprocessCommand,
|
||||||
|
mocker: MockerFixture,
|
||||||
|
):
|
||||||
|
multiprocess_command.process_count = 2
|
||||||
|
items = [1, 2, 3]
|
||||||
|
|
||||||
|
mock_connections = mocker.patch(
|
||||||
|
"documents.management.commands.base.db.connections",
|
||||||
|
)
|
||||||
|
|
||||||
|
results = list(multiprocess_command.process_parallel(_double_value, items))
|
||||||
|
|
||||||
|
mock_connections.close_all.assert_called_once()
|
||||||
|
assert len(results) == 3
|
||||||
|
|
||||||
|
def test_parallel_processing_handles_errors(
|
||||||
|
self,
|
||||||
|
multiprocess_command: MultiprocessCommand,
|
||||||
|
mocker: MockerFixture,
|
||||||
|
):
|
||||||
|
multiprocess_command.process_count = 2
|
||||||
|
items = [1, 2, 0, 4]
|
||||||
|
|
||||||
|
mocker.patch("documents.management.commands.base.db.connections")
|
||||||
|
|
||||||
|
results = list(multiprocess_command.process_parallel(_divide_ten_by, items))
|
||||||
|
|
||||||
|
failures = [r for r in results if not r.success]
|
||||||
|
assert len(failures) == 1
|
||||||
|
assert failures[0].item == 0
|
||||||
|
|
||||||
|
def test_empty_items(self, multiprocess_command: MultiprocessCommand):
|
||||||
|
results = list(multiprocess_command.process_parallel(_double_value, []))
|
||||||
|
|
||||||
|
assert results == []
|
||||||
|
|
||||||
|
def test_result_contains_original_item(
|
||||||
|
self,
|
||||||
|
multiprocess_command: MultiprocessCommand,
|
||||||
|
):
|
||||||
|
items = [10, 20, 30]
|
||||||
|
|
||||||
|
results = list(multiprocess_command.process_parallel(_double_value, items))
|
||||||
|
|
||||||
|
for result in results:
|
||||||
|
assert result.item in items
|
||||||
|
assert result.result == result.item * 2
|
||||||
|
|
||||||
|
def test_sequential_path_used_for_single_process(
|
||||||
|
self,
|
||||||
|
multiprocess_command: MultiprocessCommand,
|
||||||
|
mocker: MockerFixture,
|
||||||
|
):
|
||||||
|
"""Verify single process uses sequential path (important for testing)."""
|
||||||
|
multiprocess_command.process_count = 1
|
||||||
|
|
||||||
|
spy_sequential = mocker.spy(multiprocess_command, "_process_sequential")
|
||||||
|
spy_parallel = mocker.spy(multiprocess_command, "_process_parallel")
|
||||||
|
|
||||||
|
list(multiprocess_command.process_parallel(_double_value, [1, 2, 3]))
|
||||||
|
|
||||||
|
spy_sequential.assert_called_once()
|
||||||
|
spy_parallel.assert_not_called()
|
||||||
|
|
||||||
|
def test_parallel_path_used_for_multiple_processes(
|
||||||
|
self,
|
||||||
|
multiprocess_command: MultiprocessCommand,
|
||||||
|
mocker: MockerFixture,
|
||||||
|
):
|
||||||
|
"""Verify multiple processes uses parallel path."""
|
||||||
|
multiprocess_command.process_count = 2
|
||||||
|
|
||||||
|
mocker.patch("documents.management.commands.base.db.connections")
|
||||||
|
spy_sequential = mocker.spy(multiprocess_command, "_process_sequential")
|
||||||
|
spy_parallel = mocker.spy(multiprocess_command, "_process_parallel")
|
||||||
|
|
||||||
|
list(multiprocess_command.process_parallel(_double_value, [1, 2, 3]))
|
||||||
|
|
||||||
|
spy_parallel.assert_called_once()
|
||||||
|
spy_sequential.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.management
|
||||||
|
class TestClassVariableDefaults:
|
||||||
|
"""Tests for class variable default behavior."""
|
||||||
|
|
||||||
|
def test_default_supports_progress_bar(self):
|
||||||
|
assert PaperlessCommand.supports_progress_bar is True
|
||||||
|
|
||||||
|
def test_default_supports_multiprocessing(self):
|
||||||
|
assert PaperlessCommand.supports_multiprocessing is False
|
||||||
|
|
||||||
|
def test_subclass_can_override_progress_bar(self):
|
||||||
|
class NoProgressCommand(PaperlessCommand):
|
||||||
|
supports_progress_bar: ClassVar[bool] = False
|
||||||
|
|
||||||
|
assert NoProgressCommand.supports_progress_bar is False
|
||||||
|
assert PaperlessCommand.supports_progress_bar is True
|
||||||
|
|
||||||
|
def test_subclass_can_override_multiprocessing(self):
|
||||||
|
class ParallelCommand(PaperlessCommand):
|
||||||
|
supports_multiprocessing: ClassVar[bool] = True
|
||||||
|
|
||||||
|
assert ParallelCommand.supports_multiprocessing is True
|
||||||
|
assert PaperlessCommand.supports_multiprocessing is False
|
||||||
@@ -4,6 +4,7 @@ from io import StringIO
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
from auditlog.models import LogEntry
|
from auditlog.models import LogEntry
|
||||||
from django.contrib.contenttypes.models import ContentType
|
from django.contrib.contenttypes.models import ContentType
|
||||||
from django.core.management import call_command
|
from django.core.management import call_command
|
||||||
@@ -19,6 +20,7 @@ from documents.tests.utils import FileSystemAssertsMixin
|
|||||||
sample_file: Path = Path(__file__).parent / "samples" / "simple.pdf"
|
sample_file: Path = Path(__file__).parent / "samples" / "simple.pdf"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.management
|
||||||
@override_settings(FILENAME_FORMAT="{correspondent}/{title}")
|
@override_settings(FILENAME_FORMAT="{correspondent}/{title}")
|
||||||
class TestArchiver(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
|
class TestArchiver(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
|
||||||
def make_models(self):
|
def make_models(self):
|
||||||
@@ -94,6 +96,7 @@ class TestArchiver(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
|
|||||||
self.assertEqual(doc2.archive_filename, "document_01.pdf")
|
self.assertEqual(doc2.archive_filename, "document_01.pdf")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.management
|
||||||
class TestMakeIndex(TestCase):
|
class TestMakeIndex(TestCase):
|
||||||
@mock.patch("documents.management.commands.document_index.index_reindex")
|
@mock.patch("documents.management.commands.document_index.index_reindex")
|
||||||
def test_reindex(self, m) -> None:
|
def test_reindex(self, m) -> None:
|
||||||
@@ -106,6 +109,7 @@ class TestMakeIndex(TestCase):
|
|||||||
m.assert_called_once()
|
m.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.management
|
||||||
class TestRenamer(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
|
class TestRenamer(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
|
||||||
@override_settings(FILENAME_FORMAT="")
|
@override_settings(FILENAME_FORMAT="")
|
||||||
def test_rename(self) -> None:
|
def test_rename(self) -> None:
|
||||||
@@ -140,6 +144,7 @@ class TestCreateClassifier(TestCase):
|
|||||||
m.assert_called_once()
|
m.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.management
|
||||||
class TestSanityChecker(DirectoriesMixin, TestCase):
|
class TestSanityChecker(DirectoriesMixin, TestCase):
|
||||||
def test_no_issues(self) -> None:
|
def test_no_issues(self) -> None:
|
||||||
with self.assertLogs() as capture:
|
with self.assertLogs() as capture:
|
||||||
@@ -165,6 +170,7 @@ class TestSanityChecker(DirectoriesMixin, TestCase):
|
|||||||
self.assertIn("Checksum mismatch. Stored: abc, actual:", capture.output[1])
|
self.assertIn("Checksum mismatch. Stored: abc, actual:", capture.output[1])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.management
|
||||||
class TestConvertMariaDBUUID(TestCase):
|
class TestConvertMariaDBUUID(TestCase):
|
||||||
@mock.patch("django.db.connection.schema_editor")
|
@mock.patch("django.db.connection.schema_editor")
|
||||||
def test_convert(self, m) -> None:
|
def test_convert(self, m) -> None:
|
||||||
@@ -178,6 +184,7 @@ class TestConvertMariaDBUUID(TestCase):
|
|||||||
self.assertIn("Successfully converted", stdout.getvalue())
|
self.assertIn("Successfully converted", stdout.getvalue())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.management
|
||||||
class TestPruneAuditLogs(TestCase):
|
class TestPruneAuditLogs(TestCase):
|
||||||
def test_prune_audit_logs(self) -> None:
|
def test_prune_audit_logs(self) -> None:
|
||||||
LogEntry.objects.create(
|
LogEntry.objects.create(
|
||||||
|
|||||||
@@ -577,6 +577,7 @@ class TestTagsFromPath:
|
|||||||
assert len(tag_ids) == 0
|
assert len(tag_ids) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.management
|
||||||
class TestCommandValidation:
|
class TestCommandValidation:
|
||||||
"""Tests for command argument validation."""
|
"""Tests for command argument validation."""
|
||||||
|
|
||||||
@@ -605,6 +606,7 @@ class TestCommandValidation:
|
|||||||
cmd.handle(directory=str(sample_pdf), oneshot=True, testing=False)
|
cmd.handle(directory=str(sample_pdf), oneshot=True, testing=False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.management
|
||||||
@pytest.mark.usefixtures("mock_supported_extensions")
|
@pytest.mark.usefixtures("mock_supported_extensions")
|
||||||
class TestCommandOneshot:
|
class TestCommandOneshot:
|
||||||
"""Tests for oneshot mode."""
|
"""Tests for oneshot mode."""
|
||||||
@@ -775,6 +777,7 @@ def start_consumer(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.management
|
||||||
@pytest.mark.django_db
|
@pytest.mark.django_db
|
||||||
class TestCommandWatch:
|
class TestCommandWatch:
|
||||||
"""Integration tests for the watch loop."""
|
"""Integration tests for the watch loop."""
|
||||||
@@ -896,6 +899,7 @@ class TestCommandWatch:
|
|||||||
assert not thread.is_alive()
|
assert not thread.is_alive()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.management
|
||||||
@pytest.mark.django_db
|
@pytest.mark.django_db
|
||||||
class TestCommandWatchPolling:
|
class TestCommandWatchPolling:
|
||||||
"""Tests for polling mode."""
|
"""Tests for polling mode."""
|
||||||
@@ -928,6 +932,7 @@ class TestCommandWatchPolling:
|
|||||||
mock_consume_file_delay.delay.assert_called()
|
mock_consume_file_delay.delay.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.management
|
||||||
@pytest.mark.django_db
|
@pytest.mark.django_db
|
||||||
class TestCommandWatchRecursive:
|
class TestCommandWatchRecursive:
|
||||||
"""Tests for recursive watching."""
|
"""Tests for recursive watching."""
|
||||||
@@ -991,6 +996,7 @@ class TestCommandWatchRecursive:
|
|||||||
assert len(overrides.tag_ids) == 2
|
assert len(overrides.tag_ids) == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.management
|
||||||
@pytest.mark.django_db
|
@pytest.mark.django_db
|
||||||
class TestCommandWatchEdgeCases:
|
class TestCommandWatchEdgeCases:
|
||||||
"""Tests for edge cases and error handling."""
|
"""Tests for edge cases and error handling."""
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from pathlib import Path
|
|||||||
from unittest import mock
|
from unittest import mock
|
||||||
from zipfile import ZipFile
|
from zipfile import ZipFile
|
||||||
|
|
||||||
|
import pytest
|
||||||
from allauth.socialaccount.models import SocialAccount
|
from allauth.socialaccount.models import SocialAccount
|
||||||
from allauth.socialaccount.models import SocialApp
|
from allauth.socialaccount.models import SocialApp
|
||||||
from allauth.socialaccount.models import SocialToken
|
from allauth.socialaccount.models import SocialToken
|
||||||
@@ -45,6 +46,7 @@ from documents.tests.utils import paperless_environment
|
|||||||
from paperless_mail.models import MailAccount
|
from paperless_mail.models import MailAccount
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.management
|
||||||
class TestExportImport(
|
class TestExportImport(
|
||||||
DirectoriesMixin,
|
DirectoriesMixin,
|
||||||
FileSystemAssertsMixin,
|
FileSystemAssertsMixin,
|
||||||
@@ -846,6 +848,7 @@ class TestExportImport(
|
|||||||
self.assertEqual(Document.objects.all().count(), 4)
|
self.assertEqual(Document.objects.all().count(), 4)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.management
|
||||||
class TestCryptExportImport(
|
class TestCryptExportImport(
|
||||||
DirectoriesMixin,
|
DirectoriesMixin,
|
||||||
FileSystemAssertsMixin,
|
FileSystemAssertsMixin,
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from io import StringIO
|
from io import StringIO
|
||||||
|
|
||||||
|
import pytest
|
||||||
from django.core.management import CommandError
|
from django.core.management import CommandError
|
||||||
from django.core.management import call_command
|
from django.core.management import call_command
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
@@ -7,6 +8,7 @@ from django.test import TestCase
|
|||||||
from documents.models import Document
|
from documents.models import Document
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.management
|
||||||
class TestFuzzyMatchCommand(TestCase):
|
class TestFuzzyMatchCommand(TestCase):
|
||||||
MSG_REGEX = r"Document \d fuzzy match to \d \(confidence \d\d\.\d\d\d\)"
|
MSG_REGEX = r"Document \d fuzzy match to \d \(confidence \d\d\.\d\d\d\)"
|
||||||
|
|
||||||
@@ -49,19 +51,6 @@ class TestFuzzyMatchCommand(TestCase):
|
|||||||
self.call_command("--ratio", "101")
|
self.call_command("--ratio", "101")
|
||||||
self.assertIn("The ratio must be between 0 and 100", str(e.exception))
|
self.assertIn("The ratio must be between 0 and 100", str(e.exception))
|
||||||
|
|
||||||
def test_invalid_process_count(self) -> None:
|
|
||||||
"""
|
|
||||||
GIVEN:
|
|
||||||
- Invalid process count less than 0 above upper
|
|
||||||
WHEN:
|
|
||||||
- Command is called
|
|
||||||
THEN:
|
|
||||||
- Error is raised indicating issue
|
|
||||||
"""
|
|
||||||
with self.assertRaises(CommandError) as e:
|
|
||||||
self.call_command("--processes", "0")
|
|
||||||
self.assertIn("There must be at least 1 process", str(e.exception))
|
|
||||||
|
|
||||||
def test_no_matches(self) -> None:
|
def test_no_matches(self) -> None:
|
||||||
"""
|
"""
|
||||||
GIVEN:
|
GIVEN:
|
||||||
@@ -194,7 +183,7 @@ class TestFuzzyMatchCommand(TestCase):
|
|||||||
|
|
||||||
self.assertEqual(Document.objects.count(), 3)
|
self.assertEqual(Document.objects.count(), 3)
|
||||||
|
|
||||||
stdout, _ = self.call_command("--delete")
|
stdout, _ = self.call_command("--delete", "--no-progress-bar")
|
||||||
|
|
||||||
self.assertIn(
|
self.assertIn(
|
||||||
"The command is configured to delete documents. Use with caution",
|
"The command is configured to delete documents. Use with caution",
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from io import StringIO
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from zipfile import ZipFile
|
from zipfile import ZipFile
|
||||||
|
|
||||||
|
import pytest
|
||||||
from django.contrib.auth.models import User
|
from django.contrib.auth.models import User
|
||||||
from django.core.management import call_command
|
from django.core.management import call_command
|
||||||
from django.core.management.base import CommandError
|
from django.core.management.base import CommandError
|
||||||
@@ -18,6 +19,7 @@ from documents.tests.utils import FileSystemAssertsMixin
|
|||||||
from documents.tests.utils import SampleDirMixin
|
from documents.tests.utils import SampleDirMixin
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.management
|
||||||
class TestCommandImport(
|
class TestCommandImport(
|
||||||
DirectoriesMixin,
|
DirectoriesMixin,
|
||||||
FileSystemAssertsMixin,
|
FileSystemAssertsMixin,
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import pytest
|
||||||
from django.core.management import call_command
|
from django.core.management import call_command
|
||||||
from django.core.management.base import CommandError
|
from django.core.management.base import CommandError
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
@@ -10,6 +11,7 @@ from documents.models import Tag
|
|||||||
from documents.tests.utils import DirectoriesMixin
|
from documents.tests.utils import DirectoriesMixin
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.management
|
||||||
class TestRetagger(DirectoriesMixin, TestCase):
|
class TestRetagger(DirectoriesMixin, TestCase):
|
||||||
def make_models(self) -> None:
|
def make_models(self) -> None:
|
||||||
self.sp1 = StoragePath.objects.create(
|
self.sp1 = StoragePath.objects.create(
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import os
|
|||||||
from io import StringIO
|
from io import StringIO
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
from django.contrib.auth.models import User
|
from django.contrib.auth.models import User
|
||||||
from django.core.management import call_command
|
from django.core.management import call_command
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
@@ -9,6 +10,7 @@ from django.test import TestCase
|
|||||||
from documents.tests.utils import DirectoriesMixin
|
from documents.tests.utils import DirectoriesMixin
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.management
|
||||||
class TestManageSuperUser(DirectoriesMixin, TestCase):
|
class TestManageSuperUser(DirectoriesMixin, TestCase):
|
||||||
def call_command(self, environ):
|
def call_command(self, environ):
|
||||||
out = StringIO()
|
out = StringIO()
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import shutil
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
from django.core.management import call_command
|
from django.core.management import call_command
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
|
|
||||||
@@ -12,6 +13,7 @@ from documents.tests.utils import DirectoriesMixin
|
|||||||
from documents.tests.utils import FileSystemAssertsMixin
|
from documents.tests.utils import FileSystemAssertsMixin
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.management
|
||||||
class TestMakeThumbnails(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
|
class TestMakeThumbnails(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
|
||||||
def make_models(self) -> None:
|
def make_models(self) -> None:
|
||||||
self.d1 = Document.objects.create(
|
self.d1 = Document.objects.create(
|
||||||
|
|||||||
17
uv.lock
generated
17
uv.lock
generated
@@ -1137,6 +1137,19 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/6d/10/23c0644cf67567bbe4e3a2eeeec0e9c79b701990c0e07c5ee4a4f8897f91/django_multiselectfield-1.0.1-py3-none-any.whl", hash = "sha256:18dc14801f7eca844a48e21cba6d8ec35b9b581f2373bbb2cb75e6994518259a", size = 20481, upload-time = "2025-06-12T14:41:20.107Z" },
|
{ url = "https://files.pythonhosted.org/packages/6d/10/23c0644cf67567bbe4e3a2eeeec0e9c79b701990c0e07c5ee4a4f8897f91/django_multiselectfield-1.0.1-py3-none-any.whl", hash = "sha256:18dc14801f7eca844a48e21cba6d8ec35b9b581f2373bbb2cb75e6994518259a", size = 20481, upload-time = "2025-06-12T14:41:20.107Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "django-rich"
|
||||||
|
version = "2.2.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "django", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||||
|
{ name = "rich", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/a6/67/e307a5fef657e7992468f567b521534c52e01bdda5a1ae5b12de679a670f/django_rich-2.2.0.tar.gz", hash = "sha256:ecec7842d040024ed8a225699388535e46b87277550c33f46193b52cece2f780", size = 62427, upload-time = "2025-09-18T11:42:17.182Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/27/ed/23fa669493d78cd67e7f6734fa380f8690f2b4d75b4f72fd645a52c3b32a/django_rich-2.2.0-py3-none-any.whl", hash = "sha256:a0d2c916bd9750b6e9beb57407aef5e836c8705d7dbe9e4fd4725f7bbe41c407", size = 9210, upload-time = "2025-09-18T11:42:15.779Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "django-soft-delete"
|
name = "django-soft-delete"
|
||||||
version = "1.0.22"
|
version = "1.0.22"
|
||||||
@@ -3041,6 +3054,7 @@ dependencies = [
|
|||||||
{ name = "django-filter", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
{ name = "django-filter", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||||
{ name = "django-guardian", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
{ name = "django-guardian", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||||
{ name = "django-multiselectfield", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
{ name = "django-multiselectfield", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||||
|
{ name = "django-rich", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||||
{ name = "django-soft-delete", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
{ name = "django-soft-delete", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||||
{ name = "django-treenode", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
{ name = "django-treenode", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||||
{ name = "djangorestframework", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
{ name = "djangorestframework", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||||
@@ -3081,7 +3095,6 @@ dependencies = [
|
|||||||
{ name = "tika-client", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
{ name = "tika-client", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||||
{ name = "torch", version = "2.10.0", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "sys_platform == 'darwin'" },
|
{ name = "torch", version = "2.10.0", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "sys_platform == 'darwin'" },
|
||||||
{ name = "torch", version = "2.10.0+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "sys_platform == 'linux'" },
|
{ name = "torch", version = "2.10.0+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "sys_platform == 'linux'" },
|
||||||
{ name = "tqdm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
|
||||||
{ name = "watchfiles", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
{ name = "watchfiles", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||||
{ name = "whitenoise", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
{ name = "whitenoise", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||||
{ name = "whoosh-reloaded", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
{ name = "whoosh-reloaded", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||||
@@ -3186,6 +3199,7 @@ requires-dist = [
|
|||||||
{ name = "django-filter", specifier = "~=25.1" },
|
{ name = "django-filter", specifier = "~=25.1" },
|
||||||
{ name = "django-guardian", specifier = "~=3.2.0" },
|
{ name = "django-guardian", specifier = "~=3.2.0" },
|
||||||
{ name = "django-multiselectfield", specifier = "~=1.0.1" },
|
{ name = "django-multiselectfield", specifier = "~=1.0.1" },
|
||||||
|
{ name = "django-rich", specifier = "~=2.2.0" },
|
||||||
{ name = "django-soft-delete", specifier = "~=1.0.18" },
|
{ name = "django-soft-delete", specifier = "~=1.0.18" },
|
||||||
{ name = "django-treenode", specifier = ">=0.23.2" },
|
{ name = "django-treenode", specifier = ">=0.23.2" },
|
||||||
{ name = "djangorestframework", specifier = "~=3.16" },
|
{ name = "djangorestframework", specifier = "~=3.16" },
|
||||||
@@ -3232,7 +3246,6 @@ requires-dist = [
|
|||||||
{ name = "setproctitle", specifier = "~=1.3.4" },
|
{ name = "setproctitle", specifier = "~=1.3.4" },
|
||||||
{ name = "tika-client", specifier = "~=0.10.0" },
|
{ name = "tika-client", specifier = "~=0.10.0" },
|
||||||
{ name = "torch", specifier = "~=2.10.0", index = "https://download.pytorch.org/whl/cpu" },
|
{ name = "torch", specifier = "~=2.10.0", index = "https://download.pytorch.org/whl/cpu" },
|
||||||
{ name = "tqdm", specifier = "~=4.67.1" },
|
|
||||||
{ name = "watchfiles", specifier = ">=1.1.1" },
|
{ name = "watchfiles", specifier = ">=1.1.1" },
|
||||||
{ name = "whitenoise", specifier = "~=6.11" },
|
{ name = "whitenoise", specifier = "~=6.11" },
|
||||||
{ name = "whoosh-reloaded", specifier = ">=2.7.5" },
|
{ name = "whoosh-reloaded", specifier = ">=2.7.5" },
|
||||||
|
|||||||
Reference in New Issue
Block a user