Skip to content

Commit

Permalink
Switch to async templates (#652)
Browse files Browse the repository at this point in the history
  • Loading branch information
aminalaee authored Oct 23, 2023
1 parent b5355b5 commit 7f98575
Show file tree
Hide file tree
Showing 7 changed files with 147 additions and 99 deletions.
12 changes: 5 additions & 7 deletions docs/writing_custom_views.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,8 @@ To add custom views to the Admin interface, you can use the `BaseView` included
icon = "fa-chart-line"

@expose("/report", methods=["GET"])
def report_page(self, request):
return self.templates.TemplateResponse(
"report.html",
context={"request": request},
)
async def report_page(self, request):
return await self.templates.TemplateResponse(request, "report.html")

admin.add_view(ReportView)
```
Expand Down Expand Up @@ -77,9 +74,10 @@ The example above was very basic and you probably want to access database and SQ
result = await session.execute(stmt)
users_count = result.scalar_one()

return self.templates.TemplateResponse(
return await self.templates.TemplateResponse(
request,
"report.html",
context={"request": request, "users_count": users_count},
context={"users_count": users_count},
)


Expand Down
85 changes: 40 additions & 45 deletions sqladmin/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from starlette.responses import JSONResponse, RedirectResponse, Response
from starlette.routing import Mount, Route
from starlette.staticfiles import StaticFiles
from starlette.templating import Jinja2Templates

from sqladmin._menu import CategoryMenu, Menu, ViewMenu
from sqladmin._types import ENGINE_TYPE
Expand All @@ -41,6 +40,7 @@
slugify_action_name,
)
from sqladmin.models import BaseView, ModelView
from sqladmin.templating import Jinja2Templates

if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import async_sessionmaker
Expand Down Expand Up @@ -247,11 +247,8 @@ class CustomAdmin(BaseView):
icon = "fa-solid fa-chart-line"
@expose("/custom", methods=["GET"])
def test_page(self, request: Request):
return self.templates.TemplateResponse(
"custom.html",
context={"request": request},
)
async def test_page(self, request: Request):
return await self.templates.TemplateResponse(request, "custom.html")
admin.add_base_view(CustomAdmin)
```
Expand Down Expand Up @@ -373,16 +370,15 @@ def __init__(

statics = StaticFiles(packages=["sqladmin"])

def http_exception(request: Request, exc: Exception) -> Response:
assert isinstance(exc, HTTPException)
context = {
"request": request,
"status_code": exc.status_code,
"message": exc.detail,
}
return self.templates.TemplateResponse(
"error.html", context, status_code=exc.status_code
)
# def http_exception(request: Request, exc: Exception) -> Response:
# assert isinstance(exc, HTTPException)
# context = {
# "status_code": exc.status_code,
# "message": exc.detail,
# }
# return self.templates.TemplateResponse(
# request, "error.html", context, status_code=exc.status_code
# )

routes = [
Mount("/statics", app=statics, name="statics"),
Expand Down Expand Up @@ -418,15 +414,15 @@ def http_exception(request: Request, exc: Exception) -> Response:
]

self.admin.router.routes = routes
self.admin.exception_handlers = {HTTPException: http_exception}
# self.admin.exception_handlers = {HTTPException: http_exception}
self.admin.debug = debug
self.app.mount(base_url, app=self.admin, name="admin")

@login_required
async def index(self, request: Request) -> Response:
"""Index route which can be overridden to create dashboards."""

return self.templates.TemplateResponse("index.html", {"request": request})
return await self.templates.TemplateResponse(request, "index.html")

@login_required
async def list(self, request: Request) -> Response:
Expand All @@ -438,13 +434,10 @@ async def list(self, request: Request) -> Response:
pagination = await model_view.list(request)
pagination.add_pagination_urls(request.url)

context = {
"request": request,
"model_view": model_view,
"pagination": pagination,
}

return self.templates.TemplateResponse(model_view.list_template, context)
context = {"model_view": model_view, "pagination": pagination}
return await self.templates.TemplateResponse(
request, model_view.list_template, context
)

@login_required
async def details(self, request: Request) -> Response:
Expand All @@ -459,13 +452,14 @@ async def details(self, request: Request) -> Response:
raise HTTPException(status_code=404)

context = {
"request": request,
"model_view": model_view,
"model": model,
"title": model_view.name,
}

return self.templates.TemplateResponse(model_view.details_template, context)
return await self.templates.TemplateResponse(
request, model_view.details_template, context
)

@login_required
async def delete(self, request: Request) -> Response:
Expand Down Expand Up @@ -501,26 +495,27 @@ async def create(self, request: Request) -> Response:
form = Form(form_data)

context = {
"request": request,
"model_view": model_view,
"form": form,
}

if request.method == "GET":
return self.templates.TemplateResponse(model_view.create_template, context)
return await self.templates.TemplateResponse(
request, model_view.create_template, context
)

if not form.validate():
return self.templates.TemplateResponse(
model_view.create_template, context, status_code=400
return await self.templates.TemplateResponse(
request, model_view.create_template, context, status_code=400
)

try:
obj = await model_view.insert_model(request, form.data)
except Exception as e:
logger.exception(e)
context["error"] = str(e)
return self.templates.TemplateResponse(
model_view.create_template, context, status_code=400
return await self.templates.TemplateResponse(
request, model_view.create_template, context, status_code=400
)

url = self.get_save_redirect_url(
Expand All @@ -546,21 +541,22 @@ async def edit(self, request: Request) -> Response:

Form = await model_view.scaffold_form()
context = {
"request": request,
"obj": model,
"model_view": model_view,
"form": Form(obj=model),
}

if request.method == "GET":
return self.templates.TemplateResponse(model_view.edit_template, context)
return await self.templates.TemplateResponse(
request, model_view.edit_template, context
)

form_data = await self._handle_form_data(request, model)
form = Form(form_data)
if not form.validate():
context["form"] = form
return self.templates.TemplateResponse(
model_view.edit_template, context, status_code=400
return await self.templates.TemplateResponse(
request, model_view.edit_template, context, status_code=400
)

try:
Expand All @@ -573,8 +569,8 @@ async def edit(self, request: Request) -> Response:
except Exception as e:
logger.exception(e)
context["error"] = str(e)
return self.templates.TemplateResponse(
model_view.edit_template, context, status_code=400
return await self.templates.TemplateResponse(
request, model_view.edit_template, context, status_code=400
)

url = self.get_save_redirect_url(
Expand All @@ -598,21 +594,20 @@ async def export(self, request: Request) -> Response:
rows = await model_view.get_model_objects(
request=request, limit=model_view.export_max_rows
)
return model_view.export_data(rows, export_type=export_type)
return await model_view.export_data(rows, export_type=export_type)

async def login(self, request: Request) -> Response:
assert self.authentication_backend is not None

context = {"request": request, "error": ""}

context = {}
if request.method == "GET":
return self.templates.TemplateResponse("login.html", context)
return await self.templates.TemplateResponse(request, "login.html")

ok = await self.authentication_backend.login(request)
if not ok:
context["error"] = "Invalid credentials."
return self.templates.TemplateResponse(
"login.html", context, status_code=400
return await self.templates.TemplateResponse(
request, "login.html", context, status_code=400
)

return RedirectResponse(request.url_for("admin:index"), status_code=302)
Expand Down
3 changes: 2 additions & 1 deletion sqladmin/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from datetime import timedelta
from typing import (
Any,
AsyncGenerator,
Callable,
Dict,
Generator,
Expand Down Expand Up @@ -158,7 +159,7 @@ def write(self, value: T) -> T:


def stream_to_csv(
callback: Callable[[Writer], Generator[T, None, None]]
callback: Callable[[Writer], AsyncGenerator[T, None]]
) -> Generator[T, None, None]:
"""Function that takes a callable (that yields from a CSV Writer), and
provides it a writer that streams the output directly instead of
Expand Down
36 changes: 17 additions & 19 deletions sqladmin/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Callable,
ClassVar,
Dict,
Generator,
List,
Optional,
Sequence,
Expand All @@ -26,7 +26,6 @@
from starlette.datastructures import URL
from starlette.requests import Request
from starlette.responses import StreamingResponse
from starlette.templating import Jinja2Templates
from wtforms import Field, Form

from sqladmin._queries import Query
Expand All @@ -45,7 +44,10 @@
slugify_class_name,
stream_to_csv,
)

# stream_to_csv,
from sqladmin.pagination import Pagination
from sqladmin.templating import Jinja2Templates

if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import async_sessionmaker
Expand Down Expand Up @@ -138,11 +140,8 @@ class CustomAdmin(BaseView):
icon = "fa-solid fa-chart-line"
@expose("/custom", methods=["GET"])
def test_page(self, request: Request):
return self.templates.TemplateResponse(
"custom.html",
context={"request": request},
)
async def test_page(self, request: Request):
return await self.templates.TemplateResponse(request, "custom.html")
admin.add_base_view(CustomAdmin)
```
Expand Down Expand Up @@ -846,28 +845,28 @@ def _stmt_by_identifier(self, identifier: str) -> Select:

return stmt.where(*conditions)

def get_prop_value(self, obj: Any, prop: str) -> Any:
async def get_prop_value(self, obj: Any, prop: str) -> Any:
result = getattr(obj, prop, None)
if result and isinstance(result, Enum):
result = result.name

return result

def get_list_value(self, obj: Any, prop: str) -> Tuple[Any, Any]:
async def get_list_value(self, obj: Any, prop: str) -> Tuple[Any, Any]:
"""Get tuple of (value, formatted_value) for the list view."""

value = self.get_prop_value(obj, prop)
value = await self.get_prop_value(obj, prop)
formatted_value = self._default_formatter(value)

formatter = self._list_formatters.get(prop)
if formatter:
formatted_value = formatter(obj, prop)
return value, formatted_value

def get_detail_value(self, obj: Any, prop: str) -> Tuple[Any, Any]:
async def get_detail_value(self, obj: Any, prop: str) -> Tuple[Any, Any]:
"""Get tuple of (value, formatted_value) for the detail view."""

value = self.get_prop_value(obj, prop)
value = await self.get_prop_value(obj, prop)
formatted_value = self._default_formatter(value)

formatter = self._detail_formatters.get(prop)
Expand Down Expand Up @@ -1083,27 +1082,26 @@ def get_export_name(self, export_type: str) -> str:

return f"{self.name}_{time.strftime('%Y-%m-%d_%H-%M-%S')}.{export_type}"

def export_data(
async def export_data(
self,
data: List[Any],
export_type: str = "csv",
) -> StreamingResponse:
if export_type == "csv":
return self._export_csv(data)
else:
raise NotImplementedError("Only export_type='csv' is implemented.")
return await self._export_csv(data)
raise NotImplementedError("Only export_type='csv' is implemented.")

def _export_csv(
async def _export_csv(
self,
data: List[Any],
) -> StreamingResponse:
def generate(writer: Writer) -> Generator[Any, None, None]:
async def generate(writer: Writer) -> AsyncGenerator[Any, None]:
# Append the column titles at the beginning
yield writer.writerow(self._export_prop_names)

for row in data:
vals = [
str(self.get_prop_value(row, name))
str(await self.get_prop_value(row, name))
for name in self._export_prop_names
]
yield writer.writerow(vals)
Expand Down
Loading

0 comments on commit 7f98575

Please sign in to comment.