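Decorator for views that checks that the user passes the given test, redirecting to the log-in page if necessary. The test should be a callable that takes the user object and returns a truthy value if the user passes; it may be a regular function or a coroutine function, and the decorated view may be either synchronous or asynchronous.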
(
test_func, login_url=None, redirect_field_name=REDIRECT_FIELD_NAME
)
| 11 | |
| 12 | |
def user_passes_test(
    test_func, login_url=None, redirect_field_name=REDIRECT_FIELD_NAME
):
    """
    Return a view decorator that only lets users who pass ``test_func`` in.

    ``test_func`` is called with the user object and its result is used for
    its truthiness: truthy runs the wrapped view, falsy redirects to the
    log-in page. It may be an ordinary callable or a coroutine function; the
    decorated view may likewise be sync or async, and the wrapper bridges
    the two with ``sync_to_async`` / ``async_to_sync`` where needed.

    ``login_url`` is passed through ``resolve_url``; when it is falsy,
    ``settings.LOGIN_URL`` is used instead. ``redirect_field_name`` is
    forwarded unchanged to ``redirect_to_login``.
    """

    def _matches(login_part, request_part):
        # An empty component on the login-URL side is compatible with anything.
        return not login_part or login_part == request_part

    def decorator(view_func):
        def _login_redirect(request):
            next_url = request.build_absolute_uri()
            login_target = resolve_url(login_url or settings.LOGIN_URL)
            login_parts = urlsplit(login_target)
            request_parts = urlsplit(next_url)
            # Same scheme and host (or left unspecified on the login URL):
            # a host-relative "next" path is sufficient.
            if _matches(login_parts.scheme, request_parts.scheme) and _matches(
                login_parts.netloc, request_parts.netloc
            ):
                next_url = request.get_full_path()
            # Imported at call time; presumably to avoid an import cycle with
            # django.contrib.auth.views -- confirm before hoisting it.
            from django.contrib.auth.views import redirect_to_login

            return redirect_to_login(next_url, login_target, redirect_field_name)

        if iscoroutinefunction(view_func):

            async def _guarded_view(request, *args, **kwargs):
                user = await request.auser()
                # Adapt a synchronous test so it can be awaited here.
                check = (
                    test_func
                    if iscoroutinefunction(test_func)
                    else sync_to_async(test_func)
                )
                if not await check(user):
                    return _login_redirect(request)
                return await view_func(request, *args, **kwargs)

        else:

            def _guarded_view(request, *args, **kwargs):
                # A coroutine test has to be driven to completion from sync code.
                check = (
                    async_to_sync(test_func)
                    if iscoroutinefunction(test_func)
                    else test_func
                )
                if not check(request.user):
                    return _login_redirect(request)
                return view_func(request, *args, **kwargs)

        # Attributes read by LoginRequiredMiddleware.
        # NOTE(review): functools.wraps copies view_func.__dict__ onto the
        # wrapper *after* these assignments, so if view_func already carries
        # login_url / redirect_field_name (e.g. stacked decorators) the inner
        # values win. Confirm this precedence is intended.
        _guarded_view.login_url = login_url
        _guarded_view.redirect_field_name = redirect_field_name
        return wraps(view_func)(_guarded_view)

    return decorator
| 70 |
no outgoing calls