Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AuthenticationBackend #277

Merged
merged 5 commits into from
Aug 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,16 @@ autoflake==1.4
babel==2.10.3
black==22.6.0
coverage==6.4.2
email-validator==1.2.1
flake8==3.9.2
greenlet==1.1.2
httpx==0.23.0
isort==5.10.1
itsdangerous==2.1.2
mypy==0.971
pytest==7.1.2
pre-commit==2.20.0
greenlet==1.1.2
jinja2==3.1.2
sqlalchemy_utils==0.38.3
email-validator==1.2.1

# Documentation
mkdocs==1.3.1
Expand Down
55 changes: 48 additions & 7 deletions sqladmin/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from starlette.templating import Jinja2Templates

from sqladmin._types import ENGINE_TYPE
from sqladmin.authentication import AuthenticationBackend, login_required
from sqladmin.models import BaseView, ModelView

__all__ = [
Expand All @@ -39,19 +40,27 @@ def __init__(
logo_url: str = None,
templates_dir: str = "templates",
middlewares: Optional[Sequence[Middleware]] = None,
authentication_backend: Optional[AuthenticationBackend] = None,
) -> None:
self.app = app
self.engine = engine
self.base_url = base_url
self.templates_dir = templates_dir
self.title = title
self.logo_url = logo_url

middlewares = middlewares or []
self.authentication_backend = authentication_backend
if authentication_backend:
middlewares = list(middlewares)
middlewares.extend(authentication_backend.middlewares)

self.admin = Starlette(middleware=middlewares)
self._views: List[Union[BaseView, ModelView]] = []

self.templates = self.init_templating_engine(title=title, logo_url=logo_url)
self.templates = self.init_templating_engine()

def init_templating_engine(
self, title: str, logo_url: str = None
) -> Jinja2Templates:
def init_templating_engine(self) -> Jinja2Templates:
templates = Jinja2Templates("templates")
loaders = [
FileSystemLoader(self.templates_dir),
Expand All @@ -61,9 +70,7 @@ def init_templating_engine(
templates.env.loader = ChoiceLoader(loaders)
templates.env.globals["min"] = min
templates.env.globals["zip"] = zip
templates.env.globals["admin_title"] = title
templates.env.globals["admin_logo_url"] = logo_url
templates.env.globals["views"] = self.views
templates.env.globals["admin"] = self
templates.env.globals["is_list"] = lambda x: isinstance(x, list)

return templates
Expand Down Expand Up @@ -254,6 +261,7 @@ def __init__(
middlewares: Optional[Sequence[Middleware]] = None,
debug: bool = False,
templates_dir: str = "templates",
authentication_backend: Optional[AuthenticationBackend] = None,
) -> None:
"""
Args:
Expand All @@ -273,6 +281,7 @@ def __init__(
logo_url=logo_url,
templates_dir=templates_dir,
middlewares=middlewares,
authentication_backend=authentication_backend,
)

statics = StaticFiles(packages=["sqladmin"])
Expand Down Expand Up @@ -317,18 +326,22 @@ def http_exception(request: Request, exc: Exception) -> Response:
name="export",
methods=["GET"],
),
Route("/login", endpoint=self.login, name="login", methods=["GET", "POST"]),
Route("/logout", endpoint=self.logout, name="logout", methods=["GET"]),
]

self.admin.router.routes = routes
self.admin.exception_handlers = {HTTPException: http_exception}
self.admin.debug = debug
self.app.mount(base_url, app=self.admin, name="admin")

@login_required
async def index(self, request: Request) -> Response:
"""Index route which can be overridden to create dashboards."""

return self.templates.TemplateResponse("index.html", {"request": request})

@login_required
async def list(self, request: Request) -> Response:
"""List route to display paginated Model instances."""

Expand All @@ -353,6 +366,7 @@ async def list(self, request: Request) -> Response:

return self.templates.TemplateResponse(model_view.list_template, context)

@login_required
async def details(self, request: Request) -> Response:
"""Details route."""

Expand All @@ -373,6 +387,7 @@ async def details(self, request: Request) -> Response:

return self.templates.TemplateResponse(model_view.details_template, context)

@login_required
async def delete(self, request: Request) -> Response:
"""Delete route."""

Expand All @@ -389,6 +404,7 @@ async def delete(self, request: Request) -> Response:

return Response(content=request.url_for("admin:list", identity=identity))

@login_required
async def create(self, request: Request) -> Response:
"""Create model endpoint."""

Expand Down Expand Up @@ -423,6 +439,7 @@ async def create(self, request: Request) -> Response:
status_code=302,
)

@login_required
async def edit(self, request: Request) -> Response:
"""Edit model endpoint."""

Expand Down Expand Up @@ -461,6 +478,7 @@ async def edit(self, request: Request) -> Response:
status_code=302,
)

@login_required
async def export(self, request: Request) -> Response:
"""Export model endpoint."""

Expand All @@ -473,6 +491,29 @@ async def export(self, request: Request) -> Response:
rows = await model_view.get_model_objects(limit=model_view.export_max_rows)
return model_view.export_data(rows, export_type=export_type)

async def login(self, request: Request) -> Response:
assert self.authentication_backend is not None

context = {"request": request, "error": ""}

if request.method == "GET":
return self.templates.TemplateResponse("login.html", context)

ok = await self.authentication_backend.login(request)
if not ok:
context["error"] = "Invalid credentials."
return self.templates.TemplateResponse(
"login.html", context, status_code=400
)

return RedirectResponse(request.url_for("admin:index"), status_code=302)

async def logout(self, request: Request) -> Response:
assert self.authentication_backend is not None

await self.authentication_backend.logout(request)
return RedirectResponse(request.url_for("admin:index"), status_code=302)


def expose(
path: str,
Expand Down
58 changes: 58 additions & 0 deletions sqladmin/authentication.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import functools
from typing import Any, Callable

from starlette.middleware import Middleware
from starlette.middleware.sessions import SessionMiddleware
from starlette.requests import Request
from starlette.responses import RedirectResponse


class AuthenticationBackend:
"""Base class for implementing the Authentication into SQLAdmin.
You need to inherit this class and override the methods:
`login`, `logout` and `authenticate`.
"""

def __init__(self, secret_key: str) -> None:
self.middlewares = [
Middleware(SessionMiddleware, secret_key=secret_key),
]

async def login(self, request: Request) -> bool:
"""Implement login logic here.
You can access the login form data `await request.form()`
andvalidate the credentials.
"""
raise NotImplementedError()

async def logout(self, request: Request) -> bool:
"""Implement logout logic here.
This will usually clear the session with `request.session.clear()`.
"""
raise NotImplementedError()

async def authenticate(self, request: Request) -> bool:
"""Implement authenticate logic here.
This method will be called for each incoming request
to validate the authentication.
"""
raise NotImplementedError()


def login_required(func: Callable[..., Any]) -> Callable[..., Any]:
"""Decorator to check authentication of Admin routes.
If no authentication backend is setup, this will do nothing.
"""

@functools.wraps(func)
async def wrapper_decorator(*args: Any, **kwargs: Any) -> Any:
admin, request = args[0], args[1]
auth_backend = admin.authentication_backend
if auth_backend is not None:
is_authenticated = await auth_backend.authenticate(request)
if not is_authenticated:
return RedirectResponse(request.url_for("admin:login"), status_code=302)

return await func(*args, **kwargs)

return wrapper_decorator
2 changes: 1 addition & 1 deletion sqladmin/templates/base.html
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
<link rel="stylesheet" href="{{ url_for('admin:statics', path='css/main.css') }}">
{% block head %}
{% endblock %}
<title>{{ admin_title }}</title>
<title>{{ admin.title }}</title>
</head>
<body>
{% block body %}
Expand Down
16 changes: 10 additions & 6 deletions sqladmin/templates/layout.html
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@
<div class="container-fluid">
<h1 class="navbar-brand navbar-brand-autodark">
<a href="{{ url_for('admin:index') }}">
{% if admin_logo_url %}
<img src="{{ admin_logo_url }}" width="64" height="64" alt="Admin" class="navbar-brand-image"/>
{% if admin.logo_url %}
<img src="{{ admin.logo_url }}" width="64" height="64" alt="Admin" class="navbar-brand-image"/>
{% else %}
<h3>{{ admin_title }}</h3>
<h3>{{ admin.title }}</h3>
{% endif %}
</a>
</h1>
<div class="collapse navbar-collapse" id="navbar-menu">
<ul class="navbar-nav pt-lg-3">
{% for view in views %}
{% for view in admin.views %}
{% if view.is_visible(request) and view.is_accessible(request) %}
<li class="nav-item">
{% if view.is_model %}
Expand All @@ -32,6 +32,12 @@ <h3>{{ admin_title }}</h3>
{% endfor %}
</ul>
</div>
{% if admin.authentication_backend %}
<a href="{{ request.url_for('admin:logout') }}" class="btn btn-secondary btn-icon">
<i class="fa fa-sign-out"></i>
<span>Logout</span>
</a>
{% endif %}
</div>
</aside>
<div class="page-wrapper">
Expand All @@ -54,8 +60,6 @@ <h2 class="page-title">{{ title }}</h2>
</div>
</div>
</div>
{% block footer %}
{% endblock %}
</div>
</div>
{% endblock %}
35 changes: 35 additions & 0 deletions sqladmin/templates/login.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
{% extends "base.html" %}
{% block body %}
<div class="d-flex align-items-center justify-content-center vh-100">
<form class="Fcol-lg-6 col-md-6 card card-md" action="{{ request.url }}" method="POST" autocomplete="off">
<div class="card-body">
<h2 class="card-title text-center mb-4">Login to {{ admin.title }}</h2>
<div class="mb-3">
<label class="form-label">Username</label>
{% if error %}
<input type="text" class="form-control is-invalid" placeholder="Enter username" autocomplete="off">
<div class="invalid-feedback">{{ error }}</div>
{% else %}
<input type="text" class="form-control" placeholder="Enter username" autocomplete="off">
{% endif %}
</div>
<div class="mb-2">
<label class="form-label">
Password
</label>
<div class="input-group input-group-flat">
{% if error %}
<input type="password" class="form-control is-invalid" placeholder="Password" autocomplete="off">
<div class="invalid-feedback">{{ error }}</div>
{% else %}
<input type="password" class="form-control" placeholder="Password" autocomplete="off">
{% endif %}
</div>
</div>
<div class="form-footer">
<button type="submit" class="btn btn-primary w-100">Login</button>
</div>
</div>
</form>
</div>
{% endblock %}
68 changes: 68 additions & 0 deletions tests/test_authentication.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from typing import Generator

import pytest
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.testclient import TestClient

from sqladmin import Admin
from sqladmin.authentication import AuthenticationBackend
from tests.common import sync_engine as engine


class CustomBackend(AuthenticationBackend):
async def login(self, request: Request) -> bool:
form = await request.form()
if form["username"] != "a":
return False

request.session.update({"token": "amin"})
return True

async def logout(self, request: Request) -> bool:
request.session.clear()
return True

async def authenticate(self, request: Request) -> bool:
return "token" in request.session


app = Starlette()
authentication_backend = CustomBackend(secret_key="sqladmin")
admin = Admin(app=app, engine=engine, authentication_backend=authentication_backend)


@pytest.fixture
def client() -> Generator[TestClient, None, None]:
with TestClient(app=app, base_url="http://testserver") as c:
yield c


def test_access_logion_required_views(client: TestClient) -> None:
response = client.get("/admin/")
assert response.url == "http://testserver/admin/login"

response = client.get("/admin/users/list")
assert response.url == "http://testserver/admin/login"


def test_login_failure(client: TestClient) -> None:
response = client.post("/admin/login", data={"username": "x", "password": "b"})

assert response.status_code == 400
assert response.url == "http://testserver/admin/login"


def test_login(client: TestClient) -> None:
response = client.post("/admin/login", data={"username": "a", "password": "b"})

assert len(response.cookies) == 1
assert response.status_code == 302


def test_logout(client: TestClient) -> None:
response = client.get("/admin/logout")

assert len(response.cookies) == 0
assert response.status_code == 200
assert response.url == "http://testserver/admin/login"