et tu, mypy?

This commit is contained in:
shamoon
2026-02-10 21:52:13 -08:00
parent 5a0a8a58b3
commit b1f2606022
8 changed files with 79 additions and 47 deletions

View File

@@ -811,17 +811,17 @@ def edit_pdf(
if delete_original and len(pdf_docs) == 1:
overrides.asn = root_doc.archive_serial_number
for idx, pdf in enumerate(pdf_docs, start=1):
filepath: Path = (
version_filepath: Path = (
Path(tempfile.mkdtemp(dir=settings.SCRATCH_DIR))
/ f"{root_doc.id}_edit_{idx}.pdf"
)
pdf.remove_unreferenced_resources()
pdf.save(filepath)
pdf.save(version_filepath)
consume_tasks.append(
consume_file.s(
ConsumableDocument(
source=DocumentSource.ConsumeFolder,
original_file=filepath,
original_file=version_filepath,
),
overrides,
),

View File

@@ -1,5 +1,6 @@
from datetime import datetime
from datetime import timezone
from typing import Any
from django.conf import settings
from django.core.cache import cache
@@ -55,7 +56,12 @@ def _resolve_effective_doc(pk: int, request) -> Document | None:
# Default behavior: if pk is a root doc, prefer its newest child version
if request_doc.root_document_id is None:
latest = root_doc.versions.only("id").order_by("id").last()
latest = (
Document.objects.filter(root_document=root_doc)
.only("id")
.order_by("id")
.last()
)
return latest or root_doc
# pk is already a version
@@ -167,7 +173,7 @@ def preview_last_modified(request, pk: int) -> datetime | None:
return None
def thumbnail_last_modified(request, pk: int) -> datetime | None:
def thumbnail_last_modified(request: Any, pk: int) -> datetime | None:
"""
Returns the filesystem last modified either from cache or from filesystem.
Cache should be (slightly?) faster than filesystem

View File

@@ -103,9 +103,10 @@ class ConsumerStatusShortMessage(str, Enum):
class ConsumerPluginMixin:
if TYPE_CHECKING:
from logging import Logger
from logging import LoggerAdapter
log: "LoggerAdapter"
log: "LoggerAdapter" # type: ignore[type-arg]
def __init__(
self,
@@ -126,7 +127,7 @@ class ConsumerPluginMixin:
f"Document root document id: {input_doc.root_document_id}",
)
root_document = Document.objects.get(pk=input_doc.root_document_id)
version_index = root_document.versions.count()
version_index = Document.objects.filter(root_document=root_document).count()
self.filename += f"_v{version_index}"
def _send_progress(
@@ -530,11 +531,15 @@ class ConsumerPlugin(
settings.AUDIT_LOG_ENABLED
and self.metadata.actor_id is not None
):
from auditlog.models import LogEntry
from auditlog.models import (
LogEntry, # type: ignore[import-untyped]
)
actor = User.objects.filter(pk=self.metadata.actor_id).first()
if actor is not None:
from auditlog.context import set_actor
from auditlog.context import ( # type: ignore[import-untyped]
set_actor,
)
with set_actor(actor):
original_document.save()

View File

@@ -1,6 +1,5 @@
import datetime
from pathlib import Path
from typing import TYPE_CHECKING
from typing import Final
import pathvalidate
@@ -156,12 +155,7 @@ class StoragePath(MatchingModel):
verbose_name_plural = _("storage paths")
class Document(SoftDeleteModel, ModelWithOwner):
if TYPE_CHECKING:
from django.db.models.query import QuerySet
versions: "QuerySet[Document]"
class Document(SoftDeleteModel, ModelWithOwner): # type: ignore[django-manager-missing]
correspondent = models.ForeignKey(
Correspondent,
blank=True,
@@ -1748,5 +1742,5 @@ class WorkflowRun(SoftDeleteModel):
verbose_name = _("workflow run")
verbose_name_plural = _("workflow runs")
def __str__(self):
def __str__(self) -> str:
return f"WorkflowRun of {self.workflow} at {self.run_at} on {self.document}"

View File

@@ -7,7 +7,9 @@ from datetime import datetime
from datetime import timedelta
from decimal import Decimal
from typing import TYPE_CHECKING
from typing import Any
from typing import Literal
from typing import TypedDict
import magic
from celery import states
@@ -89,6 +91,8 @@ if TYPE_CHECKING:
from collections.abc import Iterable
from django.db.models.query import QuerySet
from rest_framework.relations import ManyRelatedField
from rest_framework.relations import RelatedField
logger = logging.getLogger("paperless.serializers")
@@ -1071,6 +1075,14 @@ class DocumentVersionInfoSerializer(serializers.Serializer):
is_root = serializers.BooleanField()
class _DocumentVersionInfo(TypedDict):
id: int
added: datetime
version_label: str | None
checksum: str | None
is_root: bool
@extend_schema_serializer(
deprecate_fields=["created_date"],
)
@@ -1091,7 +1103,9 @@ class DocumentSerializer(
duplicate_documents = SerializerMethodField()
notes = NotesSerializer(many=True, required=False, read_only=True)
root_document = serializers.PrimaryKeyRelatedField(read_only=True)
root_document: RelatedField[Document, Document, Any] | ManyRelatedField = (
serializers.PrimaryKeyRelatedField(read_only=True)
)
versions = SerializerMethodField()
custom_fields = CustomFieldInstanceSerializer(
@@ -1129,6 +1143,8 @@ class DocumentSerializer(
@extend_schema_field(DocumentVersionInfoSerializer(many=True))
def get_versions(self, obj):
root_doc = obj if obj.root_document_id is None else obj.root_document
if root_doc is None:
return []
versions_qs = Document.objects.filter(root_document=root_doc).only(
"id",
"added",
@@ -1137,7 +1153,7 @@ class DocumentSerializer(
)
versions = [*versions_qs, root_doc]
def build_info(doc: Document) -> dict[str, object]:
def build_info(doc: Document) -> _DocumentVersionInfo:
return {
"id": doc.id,
"added": doc.added,
@@ -2249,7 +2265,7 @@ class TasksViewSerializer(OwnedObjectSerializer):
return list(duplicates.values("id", "title", "deleted_at"))
class RunTaskViewSerializer(serializers.Serializer):
class RunTaskViewSerializer(serializers.Serializer[dict[str, Any]]):
task_name = serializers.ChoiceField(
choices=PaperlessTask.TaskName.choices,
label="Task Name",
@@ -2257,7 +2273,7 @@ class RunTaskViewSerializer(serializers.Serializer):
)
class AcknowledgeTasksViewSerializer(serializers.Serializer):
class AcknowledgeTasksViewSerializer(serializers.Serializer[dict[str, Any]]):
tasks = serializers.ListField(
required=True,
label="Tasks",
@@ -3004,7 +3020,7 @@ class TrashSerializer(SerializerWithPerms):
write_only=True,
)
def validate_documents(self, documents):
def validate_documents(self, documents: list[int]) -> list[int]:
count = Document.deleted_objects.filter(id__in=documents).count()
if not count == len(documents):
raise serializers.ValidationError(

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from unittest import mock
from auditlog.models import LogEntry
from auditlog.models import LogEntry # type: ignore[import-untyped]
from django.contrib.auth.models import User
from django.contrib.contenttypes.models import ContentType
from rest_framework import status

View File

@@ -1291,12 +1291,12 @@ class TestPDFActions(DirectoriesMixin, TestCase):
@mock.patch("pikepdf.open")
def test_remove_password_creates_consumable_document(
self,
mock_open,
mock_mkdtemp,
mock_consume_file,
mock_group,
mock_chord,
):
mock_open: mock.Mock,
mock_mkdtemp: mock.Mock,
mock_consume_file: mock.Mock,
mock_group: mock.Mock,
mock_chord: mock.Mock,
) -> None:
doc = self.doc2
temp_dir = self.dirs.scratch_dir / "remove-password"
temp_dir.mkdir(parents=True, exist_ok=True)
@@ -1305,8 +1305,8 @@ class TestPDFActions(DirectoriesMixin, TestCase):
fake_pdf = mock.MagicMock()
fake_pdf.pages = [mock.Mock(), mock.Mock()]
def save_side_effect(target_path):
Path(target_path).write_bytes(b"password removed")
def save_side_effect(target_path: Path) -> None:
target_path.write_bytes(b"password removed")
fake_pdf.save.side_effect = save_side_effect
mock_open.return_value.__enter__.return_value = fake_pdf
@@ -1348,13 +1348,13 @@ class TestPDFActions(DirectoriesMixin, TestCase):
@mock.patch("pikepdf.open")
def test_remove_password_deletes_original(
self,
mock_open,
mock_mkdtemp,
mock_consume_file,
mock_group,
mock_chord,
mock_delete,
):
mock_open: mock.Mock,
mock_mkdtemp: mock.Mock,
mock_consume_file: mock.Mock,
mock_group: mock.Mock,
mock_chord: mock.Mock,
mock_delete: mock.Mock,
) -> None:
doc = self.doc2
temp_dir = self.dirs.scratch_dir / "remove-password-delete"
temp_dir.mkdir(parents=True, exist_ok=True)
@@ -1363,8 +1363,8 @@ class TestPDFActions(DirectoriesMixin, TestCase):
fake_pdf = mock.MagicMock()
fake_pdf.pages = [mock.Mock(), mock.Mock()]
def save_side_effect(target_path):
Path(target_path).write_bytes(b"password removed")
def save_side_effect(target_path: Path) -> None:
target_path.write_bytes(b"password removed")
fake_pdf.save.side_effect = save_side_effect
mock_open.return_value.__enter__.return_value = fake_pdf
@@ -1387,7 +1387,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
mock_delete.si.assert_called_once_with([doc.id])
@mock.patch("pikepdf.open")
def test_remove_password_open_failure(self, mock_open):
def test_remove_password_open_failure(self, mock_open: mock.Mock) -> None:
mock_open.side_effect = RuntimeError("wrong password")
with self.assertLogs("paperless.bulk_edit", level="ERROR") as cm:

View File

@@ -10,6 +10,7 @@ from collections import deque
from datetime import datetime
from pathlib import Path
from time import mktime
from typing import Any
from typing import Literal
from unicodedata import normalize
from urllib.parse import quote
@@ -37,8 +38,10 @@ from django.db.models import Sum
from django.db.models import When
from django.db.models.functions import Lower
from django.db.models.manager import Manager
from django.db.models.query import QuerySet
from django.http import FileResponse
from django.http import Http404
from django.http import HttpRequest
from django.http import HttpResponse
from django.http import HttpResponseBadRequest
from django.http import HttpResponseForbidden
@@ -83,6 +86,7 @@ from rest_framework.mixins import ListModelMixin
from rest_framework.mixins import RetrieveModelMixin
from rest_framework.mixins import UpdateModelMixin
from rest_framework.permissions import IsAuthenticated
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.viewsets import GenericViewSet
from rest_framework.viewsets import ModelViewSet
@@ -825,6 +829,8 @@ class DocumentViewSet(
raise Http404
root_doc = doc if doc.root_document_id is None else doc.root_document
if root_doc is None:
raise Http404
if request.user is not None and not has_perms_owner_aware(
request.user,
"view_document",
@@ -889,7 +895,7 @@ class DocumentViewSet(
):
raise Http404
return candidate
latest = root_doc.versions.order_by("id").last()
latest = Document.objects.filter(root_document=root_doc).order_by("id").last()
return latest or root_doc
def file_response(self, pk, request, disposition):
@@ -1922,7 +1928,7 @@ class SavedViewViewSet(ModelViewSet, PassUserMixin):
.prefetch_related("filter_rules")
)
def perform_create(self, serializer) -> None:
def perform_create(self, serializer: serializers.BaseSerializer[Any]) -> None:
serializer.save(owner=self.request.user)
@@ -3466,7 +3472,7 @@ class CustomFieldViewSet(ModelViewSet):
queryset = CustomField.objects.all().order_by("-created")
def get_queryset(self):
def get_queryset(self) -> QuerySet[CustomField]:
filter = (
Q(fields__document__deleted_at__isnull=True)
if self.request.user is None or self.request.user.is_superuser
@@ -3779,11 +3785,16 @@ class TrashView(ListModelMixin, PassUserMixin):
queryset = Document.deleted_objects.all()
def get(self, request, format=None):
def get(self, request: Request, format: str | None = None) -> Response:
self.serializer_class = DocumentSerializer
return self.list(request, format)
def post(self, request, *args, **kwargs):
def post(
self,
request: Request,
*args: Any,
**kwargs: Any,
) -> Response | HttpResponse:
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
@@ -3807,7 +3818,7 @@ class TrashView(ListModelMixin, PassUserMixin):
return Response({"result": "OK", "doc_ids": doc_ids})
def serve_logo(request, filename=None):
def serve_logo(request: HttpRequest, filename: str | None = None) -> FileResponse:
"""
Serves the configured logo file with Content-Disposition: attachment.
Prevents inline execution of SVGs. See GHSA-6p53-hqqw-8j62