Adds a layer to translate between differing formats of socket based Redis URLs

This commit is contained in:
Trenton H 2022-12-02 09:34:59 -08:00
parent 2e8706f4e2
commit 01d070b882
2 changed files with 68 additions and 4 deletions

View File

@ -8,6 +8,7 @@ import tempfile
from typing import Final
from typing import Optional
from typing import Set
from typing import Tuple
from urllib.parse import urlparse
from celery.schedules import crontab
@ -65,6 +66,34 @@ def __get_path(key: str, default: str) -> str:
return os.path.abspath(os.path.normpath(os.environ.get(key, default)))
def _parse_redis_url(env_redis: Optional[str]) -> Tuple[str]:
"""
Gets the Redis information from the environment or a default and handles
converting from incompatible django_channels and celery formats.
Returns a tuple of (celery_url, channels_url)
"""
# Not set, return a compatible default
if env_redis is None:
return ("redis://localhost:6379", "redis://localhost:6379")
_, path = env_redis.split(":")
if "unix" in env_redis.lower():
# channels_redis socket format, looks like:
# "unix:///path/to/redis.sock"
return (f"redis+socket:{path}", env_redis)
elif "+socket" in env_redis.lower():
# celery socket style, looks like:
# "redis+socket:///path/to/redis.sock"
return (env_redis, f"unix:{path}")
# Not a socket
return (env_redis, env_redis)
# NEVER RUN WITH DEBUG IN PRODUCTION.
DEBUG = __get_boolean("PAPERLESS_DEBUG", "NO")
@ -182,7 +211,9 @@ ASGI_APPLICATION = "paperless.asgi.application"
STATIC_URL = os.getenv("PAPERLESS_STATIC_URL", BASE_URL + "static/")
WHITENOISE_STATIC_PREFIX = "/static/"
_REDIS_URL = os.getenv("PAPERLESS_REDIS", "redis://localhost:6379")
_CELERY_REDIS_URL, _CHANNELS_REDIS_URL = _parse_redis_url(
os.getenv("PAPERLESS_REDIS", None),
)
# TODO: what is this used for?
TEMPLATES = [
@ -205,7 +236,7 @@ CHANNEL_LAYERS = {
"default": {
"BACKEND": "channels_redis.core.RedisChannelLayer",
"CONFIG": {
"hosts": [_REDIS_URL],
"hosts": [_CHANNELS_REDIS_URL],
"capacity": 2000, # default 100
"expiry": 15, # default 60
},
@ -468,7 +499,7 @@ TASK_WORKERS = __get_int("PAPERLESS_TASK_WORKERS", 1)
WORKER_TIMEOUT: Final[int] = __get_int("PAPERLESS_WORKER_TIMEOUT", 1800)
CELERY_BROKER_URL = _REDIS_URL
CELERY_BROKER_URL = _CELERY_REDIS_URL
CELERY_TIMEZONE = TIME_ZONE
CELERY_WORKER_HIJACK_ROOT_LOGGER = False
@ -513,7 +544,7 @@ CELERY_BEAT_SCHEDULE_FILENAME = os.path.join(DATA_DIR, "celerybeat-schedule.db")
CACHES = {
"default": {
"BACKEND": "django.core.cache.backends.redis.RedisCache",
"LOCATION": _REDIS_URL,
"LOCATION": _CHANNELS_REDIS_URL,
},
}

View File

@ -3,6 +3,7 @@ from unittest import mock
from unittest import TestCase
from paperless.settings import _parse_ignore_dates
from paperless.settings import _parse_redis_url
from paperless.settings import default_threads_per_worker
@ -82,3 +83,35 @@ class TestIgnoreDateParsing(TestCase):
self.assertGreaterEqual(default_threads, 1)
self.assertLessEqual(default_workers * default_threads, i)
def test_redis_socket_parsing(self):
"""
GIVEN:
- Various Redis connection URI formats
WHEN:
- The URI is parsed
THEN:
- Socket based URIs are translated
- Non-socket URIs are unchanged
- None provided uses default
"""
for input, expected in [
(None, ("redis://localhost:6379", "redis://localhost:6379")),
(
"redis+socket:///run/redis/redis.sock",
(
"redis+socket:///run/redis/redis.sock",
"unix:///run/redis/redis.sock",
),
),
(
"unix:///run/redis/redis.sock",
(
"redis+socket:///run/redis/redis.sock",
"unix:///run/redis/redis.sock",
),
),
]:
result = _parse_redis_url(input)
self.assertTupleEqual(expected, result)