From b1f2606022e9fdb9d9143bd78ddfcb74bff06f66 Mon Sep 17 00:00:00 2001 From: shamoon <4887959+shamoon@users.noreply.github.com> Date: Tue, 10 Feb 2026 21:52:13 -0800 Subject: [PATCH] et tu, mypy? --- src/documents/bulk_edit.py | 6 ++-- src/documents/conditionals.py | 10 ++++-- src/documents/consumer.py | 13 ++++--- src/documents/models.py | 10 ++---- src/documents/serialisers.py | 26 +++++++++++--- .../tests/test_api_document_versions.py | 2 +- src/documents/tests/test_bulk_edit.py | 36 +++++++++---------- src/documents/views.py | 23 ++++++++---- 8 files changed, 79 insertions(+), 47 deletions(-) diff --git a/src/documents/bulk_edit.py b/src/documents/bulk_edit.py index 5d1f3f639..c72cda599 100644 --- a/src/documents/bulk_edit.py +++ b/src/documents/bulk_edit.py @@ -811,17 +811,17 @@ def edit_pdf( if delete_original and len(pdf_docs) == 1: overrides.asn = root_doc.archive_serial_number for idx, pdf in enumerate(pdf_docs, start=1): - filepath: Path = ( + version_filepath: Path = ( Path(tempfile.mkdtemp(dir=settings.SCRATCH_DIR)) / f"{root_doc.id}_edit_{idx}.pdf" ) pdf.remove_unreferenced_resources() - pdf.save(filepath) + pdf.save(version_filepath) consume_tasks.append( consume_file.s( ConsumableDocument( source=DocumentSource.ConsumeFolder, - original_file=filepath, + original_file=version_filepath, ), overrides, ), diff --git a/src/documents/conditionals.py b/src/documents/conditionals.py index b809d828c..c1160d563 100644 --- a/src/documents/conditionals.py +++ b/src/documents/conditionals.py @@ -1,5 +1,6 @@ from datetime import datetime from datetime import timezone +from typing import Any from django.conf import settings from django.core.cache import cache @@ -55,7 +56,12 @@ def _resolve_effective_doc(pk: int, request) -> Document | None: # Default behavior: if pk is a root doc, prefer its newest child version if request_doc.root_document_id is None: - latest = root_doc.versions.only("id").order_by("id").last() + latest = ( + Document.objects.filter(root_document=root_doc) + .only("id") + .order_by("id") + .last() + ) return latest or root_doc # pk is already a version @@ -167,7 +173,7 @@ def preview_last_modified(request, pk: int) -> datetime | None: return None -def thumbnail_last_modified(request, pk: int) -> datetime | None: +def thumbnail_last_modified(request: Any, pk: int) -> datetime | None: """ Returns the filesystem last modified either from cache or from filesystem. Cache should be (slightly?) faster than filesystem diff --git a/src/documents/consumer.py b/src/documents/consumer.py index 98540e679..9deb2aea8 100644 --- a/src/documents/consumer.py +++ b/src/documents/consumer.py @@ -103,9 +103,10 @@ class ConsumerStatusShortMessage(str, Enum): class ConsumerPluginMixin: if TYPE_CHECKING: + from logging import Logger from logging import LoggerAdapter - log: "LoggerAdapter" + log: "LoggerAdapter" # type: ignore[type-arg] def __init__( self, @@ -126,7 +127,7 @@ class ConsumerPluginMixin: f"Document root document id: {input_doc.root_document_id}", ) root_document = Document.objects.get(pk=input_doc.root_document_id) - version_index = root_document.versions.count() + version_index = Document.objects.filter(root_document=root_document).count() self.filename += f"_v{version_index}" def _send_progress( @@ -530,11 +531,15 @@ class ConsumerPlugin( settings.AUDIT_LOG_ENABLED and self.metadata.actor_id is not None ): - from auditlog.models import LogEntry + from auditlog.models import ( + LogEntry, # type: ignore[import-untyped] + ) actor = User.objects.filter(pk=self.metadata.actor_id).first() if actor is not None: - from auditlog.context import set_actor + from auditlog.context import ( # type: ignore[import-untyped] + set_actor, + ) with set_actor(actor): original_document.save() diff --git a/src/documents/models.py b/src/documents/models.py index 249de4b4a..45cd3c4e1 100644 --- a/src/documents/models.py +++ b/src/documents/models.py @@ -1,6 +1,5 @@ import datetime from pathlib import Path -from typing import TYPE_CHECKING from typing import Final import pathvalidate @@ -156,12 +155,7 @@ class StoragePath(MatchingModel): verbose_name_plural = _("storage paths") -class Document(SoftDeleteModel, ModelWithOwner): - if TYPE_CHECKING: - from django.db.models.query import QuerySet - - versions: "QuerySet[Document]" - +class Document(SoftDeleteModel, ModelWithOwner): # type: ignore[django-manager-missing] correspondent = models.ForeignKey( Correspondent, blank=True, @@ -1748,5 +1742,5 @@ class WorkflowRun(SoftDeleteModel): verbose_name = _("workflow run") verbose_name_plural = _("workflow runs") - def __str__(self): + def __str__(self) -> str: return f"WorkflowRun of {self.workflow} at {self.run_at} on {self.document}" diff --git a/src/documents/serialisers.py b/src/documents/serialisers.py index a168aea7b..71953bd90 100644 --- a/src/documents/serialisers.py +++ b/src/documents/serialisers.py @@ -7,7 +7,9 @@ from datetime import datetime from datetime import timedelta from decimal import Decimal from typing import TYPE_CHECKING +from typing import Any from typing import Literal +from typing import TypedDict import magic from celery import states @@ -89,6 +91,8 @@ if TYPE_CHECKING: from collections.abc import Iterable from django.db.models.query import QuerySet + from rest_framework.relations import ManyRelatedField + from rest_framework.relations import RelatedField logger = logging.getLogger("paperless.serializers") @@ -1071,6 +1075,14 @@ class DocumentVersionInfoSerializer(serializers.Serializer): is_root = serializers.BooleanField() +class _DocumentVersionInfo(TypedDict): + id: int + added: datetime + version_label: str | None + checksum: str | None + is_root: bool + + @extend_schema_serializer( deprecate_fields=["created_date"], ) @@ -1091,7 +1103,9 @@ class DocumentSerializer( duplicate_documents = SerializerMethodField() notes = NotesSerializer(many=True, required=False, read_only=True) - root_document = serializers.PrimaryKeyRelatedField(read_only=True) + root_document: RelatedField[Document, Document, Any] | ManyRelatedField = ( + serializers.PrimaryKeyRelatedField(read_only=True) + ) versions = SerializerMethodField() custom_fields = CustomFieldInstanceSerializer( @@ -1129,6 +1143,8 @@ class DocumentSerializer( @extend_schema_field(DocumentVersionInfoSerializer(many=True)) def get_versions(self, obj): root_doc = obj if obj.root_document_id is None else obj.root_document + if root_doc is None: + return [] versions_qs = Document.objects.filter(root_document=root_doc).only( "id", "added", @@ -1137,7 +1153,7 @@ class DocumentSerializer( ) versions = [*versions_qs, root_doc] - def build_info(doc: Document) -> dict[str, object]: + def build_info(doc: Document) -> _DocumentVersionInfo: return { "id": doc.id, "added": doc.added, @@ -2249,7 +2265,7 @@ class TasksViewSerializer(OwnedObjectSerializer): return list(duplicates.values("id", "title", "deleted_at")) -class RunTaskViewSerializer(serializers.Serializer): +class RunTaskViewSerializer(serializers.Serializer[dict[str, Any]]): task_name = serializers.ChoiceField( choices=PaperlessTask.TaskName.choices, label="Task Name", @@ -2257,7 +2273,7 @@ class RunTaskViewSerializer(serializers.Serializer): ) -class AcknowledgeTasksViewSerializer(serializers.Serializer): +class AcknowledgeTasksViewSerializer(serializers.Serializer[dict[str, Any]]): tasks = serializers.ListField( required=True, label="Tasks", @@ -3004,7 +3020,7 @@ class TrashSerializer(SerializerWithPerms): write_only=True, ) - def validate_documents(self, documents): + def validate_documents(self, documents: list[int]) -> list[int]: count = Document.deleted_objects.filter(id__in=documents).count() if not count == len(documents): raise serializers.ValidationError( diff --git a/src/documents/tests/test_api_document_versions.py b/src/documents/tests/test_api_document_versions.py index 851578d20..29665bcfe 100644 --- a/src/documents/tests/test_api_document_versions.py +++ b/src/documents/tests/test_api_document_versions.py @@ -2,7 +2,7 @@ from __future__ import annotations from unittest import mock -from auditlog.models import LogEntry +from auditlog.models import LogEntry # type: ignore[import-untyped] from django.contrib.auth.models import User from django.contrib.contenttypes.models import ContentType from rest_framework import status diff --git a/src/documents/tests/test_bulk_edit.py b/src/documents/tests/test_bulk_edit.py index b3c5f9f28..4a509d0fe 100644 --- a/src/documents/tests/test_bulk_edit.py +++ b/src/documents/tests/test_bulk_edit.py @@ -1291,12 +1291,12 @@ class TestPDFActions(DirectoriesMixin, TestCase): @mock.patch("pikepdf.open") def test_remove_password_creates_consumable_document( self, - mock_open, - mock_mkdtemp, - mock_consume_file, - mock_group, - mock_chord, - ): + mock_open: mock.Mock, + mock_mkdtemp: mock.Mock, + mock_consume_file: mock.Mock, + mock_group: mock.Mock, + mock_chord: mock.Mock, + ) -> None: doc = self.doc2 temp_dir = self.dirs.scratch_dir / "remove-password" temp_dir.mkdir(parents=True, exist_ok=True) @@ -1305,8 +1305,8 @@ class TestPDFActions(DirectoriesMixin, TestCase): fake_pdf = mock.MagicMock() fake_pdf.pages = [mock.Mock(), mock.Mock()] - def save_side_effect(target_path): - Path(target_path).write_bytes(b"password removed") + def save_side_effect(target_path: Path) -> None: + target_path.write_bytes(b"password removed") fake_pdf.save.side_effect = save_side_effect mock_open.return_value.__enter__.return_value = fake_pdf @@ -1348,13 +1348,13 @@ class TestPDFActions(DirectoriesMixin, TestCase): @mock.patch("pikepdf.open") def test_remove_password_deletes_original( self, - mock_open, - mock_mkdtemp, - mock_consume_file, - mock_group, - mock_chord, - mock_delete, - ): + mock_open: mock.Mock, + mock_mkdtemp: mock.Mock, + mock_consume_file: mock.Mock, + mock_group: mock.Mock, + mock_chord: mock.Mock, + mock_delete: mock.Mock, + ) -> None: doc = self.doc2 temp_dir = self.dirs.scratch_dir / "remove-password-delete" temp_dir.mkdir(parents=True, exist_ok=True) @@ -1363,8 +1363,8 @@ class TestPDFActions(DirectoriesMixin, TestCase): fake_pdf = mock.MagicMock() fake_pdf.pages = [mock.Mock(), mock.Mock()] - def save_side_effect(target_path): - Path(target_path).write_bytes(b"password removed") + def save_side_effect(target_path: Path) -> None: + target_path.write_bytes(b"password removed") fake_pdf.save.side_effect = save_side_effect mock_open.return_value.__enter__.return_value = fake_pdf @@ -1387,7 +1387,7 @@ class TestPDFActions(DirectoriesMixin, TestCase): mock_delete.si.assert_called_once_with([doc.id]) @mock.patch("pikepdf.open") - def test_remove_password_open_failure(self, mock_open): + def test_remove_password_open_failure(self, mock_open: mock.Mock) -> None: mock_open.side_effect = RuntimeError("wrong password") with self.assertLogs("paperless.bulk_edit", level="ERROR") as cm: diff --git a/src/documents/views.py b/src/documents/views.py index 46732ffb9..a8237c669 100644 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -10,6 +10,7 @@ from collections import deque from datetime import datetime from pathlib import Path from time import mktime +from typing import Any from typing import Literal from unicodedata import normalize from urllib.parse import quote @@ -37,8 +38,10 @@ from django.db.models import Sum from django.db.models import When from django.db.models.functions import Lower from django.db.models.manager import Manager +from django.db.models.query import QuerySet from django.http import FileResponse from django.http import Http404 +from django.http import HttpRequest from django.http import HttpResponse from django.http import HttpResponseBadRequest from django.http import HttpResponseForbidden @@ -83,6 +86,7 @@ from rest_framework.mixins import ListModelMixin from rest_framework.mixins import RetrieveModelMixin from rest_framework.mixins import UpdateModelMixin from rest_framework.permissions import IsAuthenticated +from rest_framework.request import Request from rest_framework.response import Response from rest_framework.viewsets import GenericViewSet from rest_framework.viewsets import ModelViewSet @@ -825,6 +829,8 @@ class DocumentViewSet( raise Http404 root_doc = doc if doc.root_document_id is None else doc.root_document + if root_doc is None: + raise Http404 if request.user is not None and not has_perms_owner_aware( request.user, "view_document", @@ -889,7 +895,7 @@ class DocumentViewSet( ): raise Http404 return candidate - latest = root_doc.versions.order_by("id").last() + latest = Document.objects.filter(root_document=root_doc).order_by("id").last() return latest or root_doc def file_response(self, pk, request, disposition): @@ -1922,7 +1928,7 @@ class SavedViewViewSet(ModelViewSet, PassUserMixin): .prefetch_related("filter_rules") ) - def perform_create(self, serializer) -> None: + def perform_create(self, serializer: serializers.BaseSerializer[Any]) -> None: serializer.save(owner=self.request.user) @@ -3466,7 +3472,7 @@ class CustomFieldViewSet(ModelViewSet): queryset = CustomField.objects.all().order_by("-created") - def get_queryset(self): + def get_queryset(self) -> QuerySet[CustomField]: filter = ( Q(fields__document__deleted_at__isnull=True) if self.request.user is None or self.request.user.is_superuser @@ -3779,11 +3785,16 @@ class TrashView(ListModelMixin, PassUserMixin): queryset = Document.deleted_objects.all() - def get(self, request, format=None): + def get(self, request: Request, format: str | None = None) -> Response: self.serializer_class = DocumentSerializer return self.list(request, format) - def post(self, request, *args, **kwargs): + def post( + self, + request: Request, + *args: Any, + **kwargs: Any, + ) -> Response | HttpResponse: serializer = self.get_serializer(data=request.data) serializer.is_valid(raise_exception=True) @@ -3807,7 +3818,7 @@ class TrashView(ListModelMixin, PassUserMixin): return Response({"result": "OK", "doc_ids": doc_ids}) -def serve_logo(request, filename=None): +def serve_logo(request: HttpRequest, filename: str | None = None) -> FileResponse: """ Serves the configured logo file with Content-Disposition: attachment. Prevents inline execution of SVGs. See GHSA-6p53-hqqw-8j62