diff --git a/src/paperless/consumers.py b/src/paperless/consumers.py index c72b58aa7..8a58fce99 100644 --- a/src/paperless/consumers.py +++ b/src/paperless/consumers.py @@ -1,9 +1,14 @@ import json +import logging from asgiref.sync import async_to_sync from channels.exceptions import AcceptConnection from channels.exceptions import DenyConnection from channels.generic.websocket import WebsocketConsumer +from django.db import close_old_connections +from django.db import connections + +logger = logging.getLogger("paperless.websockets") class StatusConsumer(WebsocketConsumer): @@ -20,7 +25,14 @@ class StatusConsumer(WebsocketConsumer): else True ) + def _discard_database_connections(self): + logger.debug("Discarding %s database connections...", len(connections.all())) + for conn in connections.all(): + conn.close() + def connect(self): + logger.debug("Connecting ws...") + close_old_connections() if not self._authenticated(): raise DenyConnection else: @@ -31,11 +43,17 @@ class StatusConsumer(WebsocketConsumer): raise AcceptConnection def disconnect(self, close_code): + logger.debug("Disconnecting ws...") + self._discard_database_connections() async_to_sync(self.channel_layer.group_discard)( "status_updates", self.channel_name, ) + def close(self, code=None, reason=None): + self._discard_database_connections() + return super().close(code, reason) + def status_update(self, event): if not self._authenticated(): self.close()