Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 125 additions & 44 deletions reflex/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,26 @@
import json
import sys
import traceback
from collections.abc import AsyncIterator, Callable, Coroutine, MutableMapping
from collections.abc import AsyncIterator, Callable, Coroutine, Sequence
from datetime import datetime
from pathlib import Path
from timeit import default_timer as timer
from types import SimpleNamespace
from typing import TYPE_CHECKING, Any, BinaryIO, get_args, get_type_hints

from fastapi import FastAPI, HTTPException, Request
from fastapi import UploadFile as FastAPIUploadFile
from fastapi.middleware import cors
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from fastapi import FastAPI
from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn
from socketio import ASGIApp, AsyncNamespace, AsyncServer
from socketio import ASGIApp as EngineIOApp
from socketio import AsyncNamespace, AsyncServer
from starlette.applications import Starlette
from starlette.datastructures import Headers
from starlette.datastructures import UploadFile as StarletteUploadFile
from starlette.exceptions import HTTPException
from starlette.middleware import cors
from starlette.requests import Request
from starlette.responses import JSONResponse, Response, StreamingResponse
from starlette.staticfiles import StaticFiles
from typing_extensions import deprecated

from reflex import constants
from reflex.admin import AdminDash
Expand Down Expand Up @@ -101,6 +105,7 @@
)
from reflex.utils.exec import get_compile_context, is_prod_mode, is_testing_env
from reflex.utils.imports import ImportVar
from reflex.utils.types import ASGIApp, Message, Receive, Scope, Send

if TYPE_CHECKING:
from reflex.vars import Var
Expand Down Expand Up @@ -388,7 +393,7 @@ class App(MiddlewareMixin, LifespanMixin):
_stateful_pages: dict[str, None] = dataclasses.field(default_factory=dict)

# The backend API object.
_api: FastAPI | None = None
_api: Starlette | None = None

# The state class to use for the app.
_state: type[BaseState] | None = None
Expand Down Expand Up @@ -423,14 +428,34 @@ class App(MiddlewareMixin, LifespanMixin):
# Put the toast provider in the app wrap.
toaster: Component | None = dataclasses.field(default_factory=toast.provider)

# Transform the ASGI app before running it.
api_transformer: (
Sequence[Callable[[ASGIApp], ASGIApp] | Starlette]
| Callable[[ASGIApp], ASGIApp]
| Starlette
| None
) = None

# FastAPI app for compatibility with FastAPI.
_cached_fastapi_app: FastAPI | None = None

@property
def api(self) -> FastAPI | None:
@deprecated("Use `api_transformer=your_fastapi_app` instead.")
def api(self) -> FastAPI:
"""Get the backend api.

Returns:
The backend api.
"""
return self._api
if self._cached_fastapi_app is None:
self._cached_fastapi_app = FastAPI()
console.deprecate(
feature_name="App.api",
reason="Set `api_transformer=your_fastapi_app` instead.",
deprecation_version="0.7.9",
removal_version="0.8.0",
)
return self._cached_fastapi_app

@property
def event_namespace(self) -> EventNamespace | None:
Expand Down Expand Up @@ -462,7 +487,7 @@ def __post_init__(self):
set_breakpoints(self.style.pop("breakpoints"))

# Set up the API.
self._api = FastAPI(lifespan=self._run_lifespan_tasks)
self._api = Starlette(lifespan=self._run_lifespan_tasks)
self._add_cors()
self._add_default_endpoints()

Expand Down Expand Up @@ -528,7 +553,7 @@ def _setup_state(self) -> None:
)

# Create the socket app. Note event endpoint constant replaces the default 'socket.io' path.
socket_app = ASGIApp(self.sio, socketio_path="")
socket_app = EngineIOApp(self.sio, socketio_path="")
namespace = config.get_event_namespace()

# Create the event namespace and attach the main app. Not related to any paths.
Expand All @@ -537,18 +562,16 @@ def _setup_state(self) -> None:
# Register the event namespace with the socket.
self.sio.register_namespace(self.event_namespace)
# Mount the socket app with the API.
if self.api:
if self._api:

class HeaderMiddleware:
def __init__(self, app: ASGIApp):
self.app = app

async def __call__(
self, scope: MutableMapping[str, Any], receive: Any, send: Callable
):
async def __call__(self, scope: Scope, receive: Receive, send: Send):
original_send = send

async def modified_send(message: dict):
async def modified_send(message: Message):
if message["type"] == "websocket.accept":
if scope.get("subprotocols"):
# The following *does* say "subprotocol" instead of "subprotocols", intentionally.
Expand All @@ -567,7 +590,7 @@ async def modified_send(message: dict):
return await self.app(scope, receive, modified_send)

socket_app_with_headers = HeaderMiddleware(socket_app)
self.api.mount(str(constants.Endpoint.EVENT), socket_app_with_headers)
self._api.mount(str(constants.Endpoint.EVENT), socket_app_with_headers)

# Check the exception handlers
self._validate_exception_handlers()
Expand All @@ -580,7 +603,7 @@ def __repr__(self) -> str:
"""
return f"<App state={self._state.__name__ if self._state else None}>"

def __call__(self) -> FastAPI:
def __call__(self) -> ASGIApp:
"""Run the backend api instance.

Raises:
Expand All @@ -589,8 +612,18 @@ def __call__(self) -> FastAPI:
Returns:
The backend api.
"""
if not self.api:
raise ValueError("The app has not been initialized.")
if self._cached_fastapi_app is not None:
asgi_app = self._cached_fastapi_app

if not asgi_app or not self._api:
raise ValueError("The app has not been initialized.")

asgi_app.mount("", self._api)
else:
asgi_app = self._api

if not asgi_app:
raise ValueError("The app has not been initialized.")

# For py3.9 compatibility when redis is used, we MUST add any decorator pages
# before compiling the app in a thread to avoid event loop error (REF-2172).
Expand All @@ -607,30 +640,58 @@ def __call__(self) -> FastAPI:
if is_prod_mode():
compile_future.result()

return self.api
if self.api_transformer is not None:
api_transformers: Sequence[Starlette | Callable[[ASGIApp], ASGIApp]] = (
[self.api_transformer]
if not isinstance(self.api_transformer, Sequence)
else self.api_transformer
)

for api_transformer in api_transformers:
if isinstance(api_transformer, Starlette):
# Mount the api to the fastapi app.
api_transformer.mount("", asgi_app)
asgi_app = api_transformer
else:
# Transform the asgi app.
asgi_app = api_transformer(asgi_app)

return asgi_app

def _add_default_endpoints(self):
"""Add default api endpoints (ping)."""
# To test the server.
if not self.api:
if not self._api:
return

self.api.get(str(constants.Endpoint.PING))(ping)
self.api.get(str(constants.Endpoint.HEALTH))(health)
self._api.add_route(
str(constants.Endpoint.PING),
ping,
methods=["GET"],
)
self._api.add_route(
str(constants.Endpoint.HEALTH),
health,
methods=["GET"],
)

def _add_optional_endpoints(self):
"""Add optional api endpoints (_upload)."""
if not self.api:
if not self._api:
return
upload_is_used_marker = (
prerequisites.get_backend_dir() / constants.Dirs.UPLOAD_IS_USED
)
if Upload.is_used or upload_is_used_marker.exists():
# To upload files.
self.api.post(str(constants.Endpoint.UPLOAD))(upload(self))
self._api.add_route(
str(constants.Endpoint.UPLOAD),
upload(self),
methods=["POST"],
)

# To access uploaded files.
self.api.mount(
self._api.mount(
str(constants.Endpoint.UPLOAD),
StaticFiles(directory=get_upload_dir()),
name="uploaded_files",
Expand All @@ -639,17 +700,19 @@ def _add_optional_endpoints(self):
upload_is_used_marker.parent.mkdir(parents=True, exist_ok=True)
upload_is_used_marker.touch()
if codespaces.is_running_in_codespaces():
self.api.get(str(constants.Endpoint.AUTH_CODESPACE))(
codespaces.auth_codespace
self._api.add_route(
str(constants.Endpoint.AUTH_CODESPACE),
codespaces.auth_codespace,
methods=["GET"],
)
if environment.REFLEX_ADD_ALL_ROUTES_ENDPOINT.get():
self.add_all_routes_endpoint()

def _add_cors(self):
"""Add CORS middleware to the app."""
if not self.api:
if not self._api:
return
self.api.add_middleware(
self._api.add_middleware(
cors.CORSMiddleware,
allow_credentials=True,
allow_methods=["*"],
Expand Down Expand Up @@ -914,7 +977,7 @@ def _setup_admin_dash(self):
return

# Get the admin dash.
if not self.api:
if not self._api:
return

admin_dash = self.admin_dash
Expand All @@ -935,7 +998,7 @@ def _setup_admin_dash(self):
view = admin_dash.view_overrides.get(model, ModelView)
admin.add_view(view(model))

admin.mount_to(self.api)
admin.mount_to(self._api)

def _get_frontend_packages(self, imports: dict[str, set[ImportVar]]):
"""Gets the frontend packages to be installed and filters out the unnecessary ones.
Expand Down Expand Up @@ -1433,12 +1496,15 @@ def _write_stateful_pages_marker(self):

def add_all_routes_endpoint(self):
"""Add an endpoint to the app that returns all the routes."""
if not self.api:
if not self._api:
return

@self.api.get(str(constants.Endpoint.ALL_ROUTES))
async def all_routes():
return list(self._unevaluated_pages.keys())
async def all_routes(_request: Request) -> Response:
return JSONResponse(list(self._unevaluated_pages.keys()))

self._api.add_route(
str(constants.Endpoint.ALL_ROUTES), all_routes, methods=["GET"]
)

@contextlib.asynccontextmanager
async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
Expand Down Expand Up @@ -1693,18 +1759,24 @@ async def process(
raise


async def ping() -> str:
async def ping(_request: Request) -> Response:
"""Test API endpoint.

Args:
_request: The Starlette request object.

Returns:
The response.
"""
return "pong"
return JSONResponse("pong")


async def health() -> JSONResponse:
async def health(_request: Request) -> JSONResponse:
"""Health check endpoint to assess the status of the database and Redis services.

Args:
_request: The Starlette request object.

Returns:
JSONResponse: A JSON object with the health status:
- "status" (bool): Overall health, True if all checks pass.
Expand Down Expand Up @@ -1746,12 +1818,11 @@ def upload(app: App):
The upload function.
"""

async def upload_file(request: Request, files: list[FastAPIUploadFile]):
async def upload_file(request: Request):
"""Upload a file.

Args:
request: The FastAPI request object.
files: The file(s) to upload.
request: The Starlette request object.

Returns:
StreamingResponse yielding newline-delimited JSON of StateUpdate
Expand All @@ -1764,6 +1835,12 @@ async def upload_file(request: Request, files: list[FastAPIUploadFile]):
"""
from reflex.utils.exceptions import UploadTypeError, UploadValueError

# Get the files from the request.
files = await request.form()
files = files.getlist("files")
if not files:
raise UploadValueError("No files were uploaded.")

token = request.headers.get("reflex-client-token")
handler = request.headers.get("reflex-event-handler")

Expand Down Expand Up @@ -1816,6 +1893,10 @@ async def upload_file(request: Request, files: list[FastAPIUploadFile]):
# event is handled.
file_copies = []
for file in files:
if not isinstance(file, StarletteUploadFile):
raise UploadValueError(
"Uploaded file is not an UploadFile." + str(file)
)
content_copy = io.BytesIO()
content_copy.write(await file.read())
content_copy.seek(0)
Expand Down
4 changes: 2 additions & 2 deletions reflex/app_mixins/lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import inspect
from collections.abc import Callable, Coroutine

from fastapi import FastAPI
from starlette.applications import Starlette

from reflex.utils import console
from reflex.utils.exceptions import InvalidLifespanTaskTypeError
Expand All @@ -27,7 +27,7 @@ class LifespanMixin(AppMixin):
)

@contextlib.asynccontextmanager
async def _run_lifespan_tasks(self, app: FastAPI):
async def _run_lifespan_tasks(self, app: Starlette):
running_tasks = []
try:
async with contextlib.AsyncExitStack() as stack:
Expand Down
4 changes: 2 additions & 2 deletions reflex/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,11 +322,11 @@ async def _shutdown(*args, **kwargs) -> None:
return _shutdown

def _start_backend(self, port: int = 0):
if self.app_instance is None or self.app_instance.api is None:
if self.app_instance is None or self.app_instance._api is None:
raise RuntimeError("App was not initialized.")
self.backend = uvicorn.Server(
uvicorn.Config(
app=self.app_instance.api,
app=self.app_instance._api,
host="127.0.0.1",
port=port,
)
Expand Down
Loading
Loading