Adds a layer to translate between differing formats of socket based Redis URLs

This commit is contained in:
Trenton H 2022-12-02 09:34:59 -08:00
parent 2e8706f4e2
commit 01d070b882
2 changed files with 68 additions and 4 deletions

View File

@ -8,6 +8,7 @@ import tempfile
from typing import Final from typing import Final
from typing import Optional from typing import Optional
from typing import Set from typing import Set
from typing import Tuple
from urllib.parse import urlparse from urllib.parse import urlparse
from celery.schedules import crontab from celery.schedules import crontab
@ -65,6 +66,34 @@ def __get_path(key: str, default: str) -> str:
return os.path.abspath(os.path.normpath(os.environ.get(key, default))) return os.path.abspath(os.path.normpath(os.environ.get(key, default)))
def _parse_redis_url(env_redis: Optional[str]) -> Tuple[str]:
"""
Gets the Redis information from the environment or a default and handles
converting from incompatible django_channels and celery formats.
Returns a tuple of (celery_url, channels_url)
"""
# Not set, return a compatible default
if env_redis is None:
return ("redis://localhost:6379", "redis://localhost:6379")
_, path = env_redis.split(":")
if "unix" in env_redis.lower():
# channels_redis socket format, looks like:
# "unix:///path/to/redis.sock"
return (f"redis+socket:{path}", env_redis)
elif "+socket" in env_redis.lower():
# celery socket style, looks like:
# "redis+socket:///path/to/redis.sock"
return (env_redis, f"unix:{path}")
# Not a socket
return (env_redis, env_redis)
# NEVER RUN WITH DEBUG IN PRODUCTION. # NEVER RUN WITH DEBUG IN PRODUCTION.
DEBUG = __get_boolean("PAPERLESS_DEBUG", "NO") DEBUG = __get_boolean("PAPERLESS_DEBUG", "NO")
@ -182,7 +211,9 @@ ASGI_APPLICATION = "paperless.asgi.application"
STATIC_URL = os.getenv("PAPERLESS_STATIC_URL", BASE_URL + "static/") STATIC_URL = os.getenv("PAPERLESS_STATIC_URL", BASE_URL + "static/")
WHITENOISE_STATIC_PREFIX = "/static/" WHITENOISE_STATIC_PREFIX = "/static/"
_REDIS_URL = os.getenv("PAPERLESS_REDIS", "redis://localhost:6379") _CELERY_REDIS_URL, _CHANNELS_REDIS_URL = _parse_redis_url(
os.getenv("PAPERLESS_REDIS", None),
)
# TODO: what is this used for? # TODO: what is this used for?
TEMPLATES = [ TEMPLATES = [
@ -205,7 +236,7 @@ CHANNEL_LAYERS = {
"default": { "default": {
"BACKEND": "channels_redis.core.RedisChannelLayer", "BACKEND": "channels_redis.core.RedisChannelLayer",
"CONFIG": { "CONFIG": {
"hosts": [_REDIS_URL], "hosts": [_CHANNELS_REDIS_URL],
"capacity": 2000, # default 100 "capacity": 2000, # default 100
"expiry": 15, # default 60 "expiry": 15, # default 60
}, },
@ -468,7 +499,7 @@ TASK_WORKERS = __get_int("PAPERLESS_TASK_WORKERS", 1)
WORKER_TIMEOUT: Final[int] = __get_int("PAPERLESS_WORKER_TIMEOUT", 1800) WORKER_TIMEOUT: Final[int] = __get_int("PAPERLESS_WORKER_TIMEOUT", 1800)
CELERY_BROKER_URL = _REDIS_URL CELERY_BROKER_URL = _CELERY_REDIS_URL
CELERY_TIMEZONE = TIME_ZONE CELERY_TIMEZONE = TIME_ZONE
CELERY_WORKER_HIJACK_ROOT_LOGGER = False CELERY_WORKER_HIJACK_ROOT_LOGGER = False
@ -513,7 +544,7 @@ CELERY_BEAT_SCHEDULE_FILENAME = os.path.join(DATA_DIR, "celerybeat-schedule.db")
CACHES = { CACHES = {
"default": { "default": {
"BACKEND": "django.core.cache.backends.redis.RedisCache", "BACKEND": "django.core.cache.backends.redis.RedisCache",
"LOCATION": _REDIS_URL, "LOCATION": _CHANNELS_REDIS_URL,
}, },
} }

View File

@ -3,6 +3,7 @@ from unittest import mock
from unittest import TestCase from unittest import TestCase
from paperless.settings import _parse_ignore_dates from paperless.settings import _parse_ignore_dates
from paperless.settings import _parse_redis_url
from paperless.settings import default_threads_per_worker from paperless.settings import default_threads_per_worker
@ -82,3 +83,35 @@ class TestIgnoreDateParsing(TestCase):
self.assertGreaterEqual(default_threads, 1) self.assertGreaterEqual(default_threads, 1)
self.assertLessEqual(default_workers * default_threads, i) self.assertLessEqual(default_workers * default_threads, i)
def test_redis_socket_parsing(self):
"""
GIVEN:
- Various Redis connection URI formats
WHEN:
- The URI is parsed
THEN:
- Socket based URIs are translated
- Non-socket URIs are unchanged
- None provided uses default
"""
for input, expected in [
(None, ("redis://localhost:6379", "redis://localhost:6379")),
(
"redis+socket:///run/redis/redis.sock",
(
"redis+socket:///run/redis/redis.sock",
"unix:///run/redis/redis.sock",
),
),
(
"unix:///run/redis/redis.sock",
(
"redis+socket:///run/redis/redis.sock",
"unix:///run/redis/redis.sock",
),
),
]:
result = _parse_redis_url(input)
self.assertTupleEqual(expected, result)