118 lines
4.6 KiB
Python

import logging
from urllib.parse import quote
from allauth.account.adapter import DefaultAccountAdapter
from allauth.core import context
from allauth.socialaccount.adapter import DefaultSocialAccountAdapter
from django.conf import settings
from django.contrib.auth.models import Group
from django.contrib.auth.models import User
from django.forms import ValidationError
from django.urls import reverse
# Logger shared by the adapters in this module; it emits the debug messages
# about default group assignment under the "paperless.auth" logger name.
logger = logging.getLogger("paperless.auth")
class CustomAccountAdapter(DefaultAccountAdapter):
    """
    Account adapter that applies paperless settings on top of allauth's
    defaults.

    It covers signup and regular-login switches, safe-redirect checking, and
    password-reset URLs built from PAPERLESS_URL. It also assigns the groups in
    ACCOUNT_DEFAULT_GROUPS to newly saved users.
    """

    def is_open_for_signup(self, request):
        """
        Check whether the site is open for signups, which can be
        disabled via the ACCOUNT_ALLOW_SIGNUPS setting.

        If ACCOUNT_ALLOW_SIGNUPS is not defined in settings, the parent
        adapter's decision is returned.
        """
        allow_signups = super().is_open_for_signup(request)
        # Override with setting, otherwise default to super.
        return getattr(settings, "ACCOUNT_ALLOW_SIGNUPS", allow_signups)

    def pre_authenticate(self, request, **credentials):
        """
        Called prior to calling the authenticate method on the
        authentication backend.

        Raises:
            ValidationError: if settings.DISABLE_REGULAR_LOGIN is truthy,
                which prevents the login.
        """
        if settings.DISABLE_REGULAR_LOGIN:
            raise ValidationError("Regular login is disabled")
        return super().pre_authenticate(request, **credentials)

    def is_safe_url(self, url):
        """
        Check if the URL is safe to redirect to.

        URLs pointing at the current request host or at any entry of
        settings.ALLOWED_HOSTS are accepted. A wildcard ("*") entry is ignored
        so that it cannot allow redirects to arbitrary hosts.
        See https://github.com/paperless-ngx/paperless-ngx/issues/5780
        """
        from django.utils.http import url_has_allowed_host_and_scheme

        # get_host already validates the given host, so no need to check it again
        allowed_hosts = {context.request.get_host()} | set(settings.ALLOWED_HOSTS)
        # Don't let a wildcard allow URLs from any host. The request host is
        # already in the set, so dropping "*" is all that is needed.
        allowed_hosts.discard("*")
        return url_has_allowed_host_and_scheme(url, allowed_hosts=allowed_hosts)

    def get_reset_password_from_key_url(self, key):
        """
        Return the URL to reset a password e.g. in reset email.

        When PAPERLESS_URL is set, the URL is PAPERLESS_URL followed by the
        reversed reset path. Otherwise the parent adapter's implementation is
        used.
        """
        base_url = settings.PAPERLESS_URL
        if not base_url:
            # Unset (None) or empty. Concatenating an empty base would put a
            # relative link into the email, so defer to the parent instead.
            return super().get_reset_password_from_key_url(key)
        path = reverse(
            "account_reset_password_from_key",
            kwargs={"uidb36": "UID", "key": "KEY"},
        )
        # reverse() is called with placeholder values. The resulting "UID-KEY"
        # segment is then replaced by the URL-quoted reset key.
        path = path.replace("UID-KEY", quote(key))
        # reverse() yields an absolute path starting with "/". Strip any
        # trailing slash from the base so the link does not contain "//".
        return base_url.rstrip("/") + path

    def save_user(self, request, user, form, commit=True):  # noqa: FBT002
        """
        Save the user instance. Default groups are assigned to the user, if
        specified in the settings.

        Groups from ACCOUNT_DEFAULT_GROUPS can only be attached to a user that
        already exists in the database. They are therefore only assigned when
        ``commit`` is true. With ``commit=False`` the user is returned unsaved
        and without default groups, as the caller asked to save it itself.

        Returns:
            The user returned by the parent adapter.
        """
        user: User = super().save_user(request, user, form, commit)
        group_names: list[str] = settings.ACCOUNT_DEFAULT_GROUPS
        if not group_names:
            return user
        if not commit:
            # Adding M2M rows needs a primary key, and saving here would
            # violate the commit=False contract.
            logger.debug(
                "Not adding default groups to unsaved user `%s`",
                user,
            )
            return user
        groups = Group.objects.filter(name__in=group_names)
        logger.debug(
            "Adding default groups to user `%s`: %s",
            user,
            group_names,
        )
        user.groups.add(*groups)
        user.save()
        return user
class CustomSocialAccountAdapter(DefaultSocialAccountAdapter):
    """
    Social-account adapter driven by paperless settings.

    It lets SOCIALACCOUNT_ALLOW_SIGNUPS override whether social signups are
    allowed, and redirects to the ``base`` view after an account is connected.
    It also adds the groups named in SOCIAL_ACCOUNT_DEFAULT_GROUPS to saved
    users.
    """

    def is_open_for_signup(self, request, sociallogin):
        """
        Tell allauth whether new users may sign up through a social account.

        The SOCIALACCOUNT_ALLOW_SIGNUPS setting takes precedence when it is
        defined. Otherwise the parent adapter's answer is returned.
        """
        parent_answer = super().is_open_for_signup(request, sociallogin)
        return getattr(settings, "SOCIALACCOUNT_ALLOW_SIGNUPS", parent_answer)

    def get_connect_redirect_url(self, request, socialaccount):
        """
        Give the URL to redirect to once a social account has been connected.

        This is always the ``base`` view.
        """
        return reverse("base")

    def save_user(self, request, sociallogin, form=None):
        """
        Persist the user through the parent adapter, then add the groups
        listed in SOCIAL_ACCOUNT_DEFAULT_GROUPS, if any.

        Returns:
            The user returned by the parent adapter.
        """
        # The parent save_user also calls the account adapter's save_user,
        # which would set ACCOUNT_DEFAULT_GROUPS.
        # NOTE(review): presumably only when a form is passed; confirm against
        # the installed allauth version.
        new_user: User = super().save_user(request, sociallogin, form)
        social_group_names: list[str] = settings.SOCIAL_ACCOUNT_DEFAULT_GROUPS
        if len(social_group_names) == 0:
            return new_user

        matching_groups = Group.objects.filter(name__in=social_group_names)
        logger.debug(
            f"Adding default social groups to user `{new_user}`: {social_group_names}",
        )
        new_user.groups.add(*matching_groups)
        new_user.save()
        return new_user