diff --git a/src/paperless/settings.py b/src/paperless/settings.py index 456e15745..eef7344da 100644 --- a/src/paperless/settings.py +++ b/src/paperless/settings.py @@ -8,6 +8,7 @@ import tempfile from typing import Final from typing import Optional from typing import Set +from typing import Tuple from urllib.parse import urlparse from celery.schedules import crontab @@ -65,6 +66,34 @@ def __get_path(key: str, default: str) -> str: return os.path.abspath(os.path.normpath(os.environ.get(key, default))) +def _parse_redis_url(env_redis: Optional[str]) -> Tuple[str]: + """ + Gets the Redis information from the environment or a default and handles + converting from incompatible django_channels and celery formats. + + Returns a tuple of (celery_url, channels_url) + """ + + # Not set, return a compatible default + if env_redis is None: + return ("redis://localhost:6379", "redis://localhost:6379") + + _, path = env_redis.split(":") + + if "unix" in env_redis.lower(): + # channels_redis socket format, looks like: + # "unix:///path/to/redis.sock" + return (f"redis+socket:{path}", env_redis) + + elif "+socket" in env_redis.lower(): + # celery socket style, looks like: + # "redis+socket:///path/to/redis.sock" + return (env_redis, f"unix:{path}") + + # Not a socket + return (env_redis, env_redis) + + # NEVER RUN WITH DEBUG IN PRODUCTION. DEBUG = __get_boolean("PAPERLESS_DEBUG", "NO") @@ -182,7 +211,9 @@ ASGI_APPLICATION = "paperless.asgi.application" STATIC_URL = os.getenv("PAPERLESS_STATIC_URL", BASE_URL + "static/") WHITENOISE_STATIC_PREFIX = "/static/" -_REDIS_URL = os.getenv("PAPERLESS_REDIS", "redis://localhost:6379") +_CELERY_REDIS_URL, _CHANNELS_REDIS_URL = _parse_redis_url( + os.getenv("PAPERLESS_REDIS", None), +) # TODO: what is this used for? TEMPLATES = [ @@ -205,7 +236,7 @@ CHANNEL_LAYERS = { "default": { "BACKEND": "channels_redis.core.RedisChannelLayer", "CONFIG": { - "hosts": [_REDIS_URL], + "hosts": [_CHANNELS_REDIS_URL], "capacity": 2000, # default 100 "expiry": 15, # default 60 }, @@ -468,7 +499,7 @@ TASK_WORKERS = __get_int("PAPERLESS_TASK_WORKERS", 1) WORKER_TIMEOUT: Final[int] = __get_int("PAPERLESS_WORKER_TIMEOUT", 1800) -CELERY_BROKER_URL = _REDIS_URL +CELERY_BROKER_URL = _CELERY_REDIS_URL CELERY_TIMEZONE = TIME_ZONE CELERY_WORKER_HIJACK_ROOT_LOGGER = False @@ -513,7 +544,7 @@ CELERY_BEAT_SCHEDULE_FILENAME = os.path.join(DATA_DIR, "celerybeat-schedule.db") CACHES = { "default": { "BACKEND": "django.core.cache.backends.redis.RedisCache", - "LOCATION": _REDIS_URL, + "LOCATION": _CHANNELS_REDIS_URL, }, } diff --git a/src/paperless/tests/test_settings.py b/src/paperless/tests/test_settings.py index fed4079e2..fa839299f 100644 --- a/src/paperless/tests/test_settings.py +++ b/src/paperless/tests/test_settings.py @@ -3,6 +3,7 @@ from unittest import mock from unittest import TestCase from paperless.settings import _parse_ignore_dates +from paperless.settings import _parse_redis_url from paperless.settings import default_threads_per_worker @@ -82,3 +83,35 @@ class TestIgnoreDateParsing(TestCase): self.assertGreaterEqual(default_threads, 1) self.assertLessEqual(default_workers * default_threads, i) + + def test_redis_socket_parsing(self): + """ + GIVEN: + - Various Redis connection URI formats + WHEN: + - The URI is parsed + THEN: + - Socket based URIs are translated + - Non-socket URIs are unchanged + - None provided uses default + """ + + for input, expected in [ + (None, ("redis://localhost:6379", "redis://localhost:6379")), + ( + "redis+socket:///run/redis/redis.sock", + ( + "redis+socket:///run/redis/redis.sock", + "unix:///run/redis/redis.sock", + ), + ), + ( + "unix:///run/redis/redis.sock", + ( + "redis+socket:///run/redis/redis.sock", + "unix:///run/redis/redis.sock", + ), + ), + ]: + result = _parse_redis_url(input) + self.assertTupleEqual(expected, result)