2025-03-20 13:34:59 -07:00

69 lines
2.1 KiB
Python

import json
import logging
from asgiref.sync import async_to_sync
from channels.exceptions import AcceptConnection
from channels.exceptions import DenyConnection
from channels.generic.websocket import WebsocketConsumer
from django.db import close_old_connections
from django.db import connections
# Module-level logger used by the websocket consumer below for debug output.
logger = logging.getLogger("paperless.websockets")
class StatusConsumer(WebsocketConsumer):
    """Websocket consumer that relays "status_updates" group events to clients.

    Only authenticated users are accepted. Each accepted connection joins the
    ``"status_updates"`` channel-layer group and receives the events handled
    by :meth:`status_update` and :meth:`documents_deleted` as JSON text.
    """

    def _authenticated(self):
        """Return a truthy value if the scope carries an authenticated user."""
        return "user" in self.scope and self.scope["user"].is_authenticated

    def _is_owner_or_unowned(self, data):
        """Decide whether the connected user may see an event payload.

        Args:
            data: The event's ``"data"`` mapping; may contain ``"owner_id"``.

        Returns:
            Truthy when the payload has no owner (``"owner_id"`` missing or
            ``None``), when the scope has no ``"user"`` entry, or when the
            user is a superuser or the owner. Falsy otherwise.
        """
        # Callers check _authenticated() first, so a scope without a user is
        # not expected here; the permissive result mirrors the prior behavior.
        if "user" not in self.scope:
            return True

        owner_id = data.get("owner_id")
        # An explicit ``None`` owner means "unowned", same as a missing key.
        # Comparing it against the user's id would hide unowned items from
        # every non-superuser.
        if owner_id is None:
            return True

        user = self.scope["user"]
        return user.is_superuser or user.id == owner_id

    def _discard_database_connections(self):
        """Close every database connection known to ``django.db.connections``."""
        # Evaluate connections.all() once; it is used for both the log line
        # and the loop.
        all_connections = connections.all()
        logger.debug(
            "Discarding %s database connections...",
            len(all_connections),
        )
        for conn in all_connections:
            conn.close()

    def connect(self):
        """Accept the socket for authenticated users and join the status group.

        Raises:
            DenyConnection: If the scope has no authenticated user.
            AcceptConnection: After the channel joined ``"status_updates"``;
                Channels treats this as a request to accept the socket.
        """
        logger.debug("Connecting ws...")
        close_old_connections()

        if not self._authenticated():
            raise DenyConnection

        async_to_sync(self.channel_layer.group_add)(
            "status_updates",
            self.channel_name,
        )
        raise AcceptConnection

    def disconnect(self, close_code):
        """Release database connections and leave the status group.

        Args:
            close_code: Close code passed in by the caller; unused here.
        """
        logger.debug("Disconnecting ws...")
        self._discard_database_connections()
        async_to_sync(self.channel_layer.group_discard)(
            "status_updates",
            self.channel_name,
        )

    def close(self, code=None, reason=None):
        """Close database connections, then delegate to the base ``close``.

        Args:
            code: Optional close code forwarded to the base class.
            reason: Optional close reason forwarded to the base class.

        Returns:
            Whatever the base class ``close`` returns.
        """
        self._discard_database_connections()
        return super().close(code, reason)

    def status_update(self, event):
        """Forward a status event to the client if the user may see it.

        The socket is closed when the user is no longer authenticated. The
        event is dropped silently when ``event["data"]`` belongs to another
        user.

        Args:
            event: Event mapping with a ``"data"`` entry; sent as JSON.
        """
        if not self._authenticated():
            self.close()
            return

        if self._is_owner_or_unowned(event["data"]):
            self.send(json.dumps(event))

    def documents_deleted(self, event):
        """Forward a documents-deleted event to the client as JSON.

        The socket is closed when the user is no longer authenticated. No
        per-owner filtering is applied to this event.

        Args:
            event: Event mapping; sent as JSON.
        """
        if not self._authenticated():
            self.close()
            return

        self.send(json.dumps(event))