diff --git a/src/documents/plugins/helpers.py b/src/documents/plugins/helpers.py index f012a0dcb..1e2139aba 100644 --- a/src/documents/plugins/helpers.py +++ b/src/documents/plugins/helpers.py @@ -1,4 +1,5 @@ import enum +from collections.abc import Mapping from typing import TYPE_CHECKING from asgiref.sync import async_to_sync @@ -47,7 +48,7 @@ class BaseStatusManager: async_to_sync(self._channel.flush) self._channel = None - def send(self, payload: dict[str, str | int | None]) -> None: + def send(self, payload: Mapping[str, object]) -> None: # Ensure the layer is open self.open() @@ -73,26 +74,28 @@ class ProgressManager(BaseStatusManager): max_progress: int, extra_args: dict[str, str | int | None] | None = None, ) -> None: - payload = { - "type": "status_update", - "data": { - "filename": self.filename, - "task_id": self.task_id, - "current_progress": current_progress, - "max_progress": max_progress, - "status": status, - "message": message, - }, + data: dict[str, object] = { + "filename": self.filename, + "task_id": self.task_id, + "current_progress": current_progress, + "max_progress": max_progress, + "status": status, + "message": message, } if extra_args is not None: - payload["data"].update(extra_args) + data.update(extra_args) + + payload: dict[str, object] = { + "type": "status_update", + "data": data, + } self.send(payload) class DocumentsStatusManager(BaseStatusManager): def send_documents_deleted(self, documents: list[int]) -> None: - payload = { + payload: dict[str, object] = { "type": "documents_deleted", "data": { "documents": documents, @@ -110,7 +113,7 @@ class DocumentsStatusManager(BaseStatusManager): users_can_view: list[int] | None = None, groups_can_view: list[int] | None = None, ) -> None: - payload: dict[str, str | int | None] = { + payload: dict[str, object] = { "type": "document_updated", "data": { "document_id": document_id, diff --git a/src/documents/tests/test_workflows.py b/src/documents/tests/test_workflows.py index 56dc3157e..59745d8ec 100644 --- a/src/documents/tests/test_workflows.py +++ b/src/documents/tests/test_workflows.py @@ -3,6 +3,7 @@ import json import shutil import socket import tempfile +from collections.abc import Callable from datetime import timedelta from pathlib import Path from typing import TYPE_CHECKING @@ -4135,7 +4136,7 @@ class TestWebhookSecurity: def test_strips_user_supplied_host_header( self, httpx_mock: HTTPXMock, - resolve_to, + resolve_to: Callable[[str], None], ) -> None: """ GIVEN: