diff --git a/reflex/app.py b/reflex/app.py index 79073fe08c8..f0bd0d6b807 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -13,22 +13,26 @@ import json import sys import traceback -from collections.abc import AsyncIterator, Callable, Coroutine, MutableMapping +from collections.abc import AsyncIterator, Callable, Coroutine, Sequence from datetime import datetime from pathlib import Path from timeit import default_timer as timer from types import SimpleNamespace from typing import TYPE_CHECKING, Any, BinaryIO, get_args, get_type_hints -from fastapi import FastAPI, HTTPException, Request -from fastapi import UploadFile as FastAPIUploadFile -from fastapi.middleware import cors -from fastapi.responses import JSONResponse, StreamingResponse -from fastapi.staticfiles import StaticFiles +from fastapi import FastAPI from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn -from socketio import ASGIApp, AsyncNamespace, AsyncServer +from socketio import ASGIApp as EngineIOApp +from socketio import AsyncNamespace, AsyncServer +from starlette.applications import Starlette from starlette.datastructures import Headers from starlette.datastructures import UploadFile as StarletteUploadFile +from starlette.exceptions import HTTPException +from starlette.middleware import cors +from starlette.requests import Request +from starlette.responses import JSONResponse, Response, StreamingResponse +from starlette.staticfiles import StaticFiles +from typing_extensions import deprecated from reflex import constants from reflex.admin import AdminDash @@ -101,6 +105,7 @@ ) from reflex.utils.exec import get_compile_context, is_prod_mode, is_testing_env from reflex.utils.imports import ImportVar +from reflex.utils.types import ASGIApp, Message, Receive, Scope, Send if TYPE_CHECKING: from reflex.vars import Var @@ -388,7 +393,7 @@ class App(MiddlewareMixin, LifespanMixin): _stateful_pages: dict[str, None] = dataclasses.field(default_factory=dict) # The backend API object. - _api: FastAPI | None = None + _api: Starlette | None = None # The state class to use for the app. _state: type[BaseState] | None = None @@ -423,14 +428,34 @@ class App(MiddlewareMixin, LifespanMixin): # Put the toast provider in the app wrap. toaster: Component | None = dataclasses.field(default_factory=toast.provider) + # Transform the ASGI app before running it. + api_transformer: ( + Sequence[Callable[[ASGIApp], ASGIApp] | Starlette] + | Callable[[ASGIApp], ASGIApp] + | Starlette + | None + ) = None + + # FastAPI app for compatibility with FastAPI. + _cached_fastapi_app: FastAPI | None = None + @property - def api(self) -> FastAPI | None: + @deprecated("Use `api_transformer=your_fastapi_app` instead.") + def api(self) -> FastAPI: """Get the backend api. Returns: The backend api. """ - return self._api + if self._cached_fastapi_app is None: + self._cached_fastapi_app = FastAPI() + console.deprecate( + feature_name="App.api", + reason="Set `api_transformer=your_fastapi_app` instead.", + deprecation_version="0.7.9", + removal_version="0.8.0", + ) + return self._cached_fastapi_app @property def event_namespace(self) -> EventNamespace | None: @@ -462,7 +487,7 @@ def __post_init__(self): set_breakpoints(self.style.pop("breakpoints")) # Set up the API. - self._api = FastAPI(lifespan=self._run_lifespan_tasks) + self._api = Starlette(lifespan=self._run_lifespan_tasks) self._add_cors() self._add_default_endpoints() @@ -528,7 +553,7 @@ def _setup_state(self) -> None: ) # Create the socket app. Note event endpoint constant replaces the default 'socket.io' path. - socket_app = ASGIApp(self.sio, socketio_path="") + socket_app = EngineIOApp(self.sio, socketio_path="") namespace = config.get_event_namespace() # Create the event namespace and attach the main app. Not related to any paths. @@ -537,18 +562,16 @@ def _setup_state(self) -> None: # Register the event namespace with the socket. self.sio.register_namespace(self.event_namespace) # Mount the socket app with the API. - if self.api: + if self._api: class HeaderMiddleware: def __init__(self, app: ASGIApp): self.app = app - async def __call__( - self, scope: MutableMapping[str, Any], receive: Any, send: Callable - ): + async def __call__(self, scope: Scope, receive: Receive, send: Send): original_send = send - async def modified_send(message: dict): + async def modified_send(message: Message): if message["type"] == "websocket.accept": if scope.get("subprotocols"): # The following *does* say "subprotocol" instead of "subprotocols", intentionally. @@ -567,7 +590,7 @@ async def modified_send(message: dict): return await self.app(scope, receive, modified_send) socket_app_with_headers = HeaderMiddleware(socket_app) - self.api.mount(str(constants.Endpoint.EVENT), socket_app_with_headers) + self._api.mount(str(constants.Endpoint.EVENT), socket_app_with_headers) # Check the exception handlers self._validate_exception_handlers() @@ -580,7 +603,7 @@ def __repr__(self) -> str: """ return f"" - def __call__(self) -> FastAPI: + def __call__(self) -> ASGIApp: """Run the backend api instance. Raises: @@ -589,8 +612,18 @@ def __call__(self) -> FastAPI: Returns: The backend api. """ - if not self.api: - raise ValueError("The app has not been initialized.") + if self._cached_fastapi_app is not None: + asgi_app = self._cached_fastapi_app + + if not asgi_app or not self._api: + raise ValueError("The app has not been initialized.") + + asgi_app.mount("", self._api) + else: + asgi_app = self._api + + if not asgi_app: + raise ValueError("The app has not been initialized.") # For py3.9 compatibility when redis is used, we MUST add any decorator pages # before compiling the app in a thread to avoid event loop error (REF-2172). @@ -607,30 +640,58 @@ def __call__(self) -> FastAPI: if is_prod_mode(): compile_future.result() - return self.api + if self.api_transformer is not None: + api_transformers: Sequence[Starlette | Callable[[ASGIApp], ASGIApp]] = ( + [self.api_transformer] + if not isinstance(self.api_transformer, Sequence) + else self.api_transformer + ) + + for api_transformer in api_transformers: + if isinstance(api_transformer, Starlette): + # Mount the api to the fastapi app. + api_transformer.mount("", asgi_app) + asgi_app = api_transformer + else: + # Transform the asgi app. + asgi_app = api_transformer(asgi_app) + + return asgi_app def _add_default_endpoints(self): """Add default api endpoints (ping).""" # To test the server. - if not self.api: + if not self._api: return - self.api.get(str(constants.Endpoint.PING))(ping) - self.api.get(str(constants.Endpoint.HEALTH))(health) + self._api.add_route( + str(constants.Endpoint.PING), + ping, + methods=["GET"], + ) + self._api.add_route( + str(constants.Endpoint.HEALTH), + health, + methods=["GET"], + ) def _add_optional_endpoints(self): """Add optional api endpoints (_upload).""" - if not self.api: + if not self._api: return upload_is_used_marker = ( prerequisites.get_backend_dir() / constants.Dirs.UPLOAD_IS_USED ) if Upload.is_used or upload_is_used_marker.exists(): # To upload files. - self.api.post(str(constants.Endpoint.UPLOAD))(upload(self)) + self._api.add_route( + str(constants.Endpoint.UPLOAD), + upload(self), + methods=["POST"], + ) # To access uploaded files. - self.api.mount( + self._api.mount( str(constants.Endpoint.UPLOAD), StaticFiles(directory=get_upload_dir()), name="uploaded_files", @@ -639,17 +700,19 @@ def _add_optional_endpoints(self): upload_is_used_marker.parent.mkdir(parents=True, exist_ok=True) upload_is_used_marker.touch() if codespaces.is_running_in_codespaces(): - self.api.get(str(constants.Endpoint.AUTH_CODESPACE))( - codespaces.auth_codespace + self._api.add_route( + str(constants.Endpoint.AUTH_CODESPACE), + codespaces.auth_codespace, + methods=["GET"], ) if environment.REFLEX_ADD_ALL_ROUTES_ENDPOINT.get(): self.add_all_routes_endpoint() def _add_cors(self): """Add CORS middleware to the app.""" - if not self.api: + if not self._api: return - self.api.add_middleware( + self._api.add_middleware( cors.CORSMiddleware, allow_credentials=True, allow_methods=["*"], @@ -914,7 +977,7 @@ def _setup_admin_dash(self): return # Get the admin dash. - if not self.api: + if not self._api: return admin_dash = self.admin_dash @@ -935,7 +998,7 @@ def _setup_admin_dash(self): view = admin_dash.view_overrides.get(model, ModelView) admin.add_view(view(model)) - admin.mount_to(self.api) + admin.mount_to(self._api) def _get_frontend_packages(self, imports: dict[str, set[ImportVar]]): """Gets the frontend packages to be installed and filters out the unnecessary ones. @@ -1433,12 +1496,15 @@ def _write_stateful_pages_marker(self): def add_all_routes_endpoint(self): """Add an endpoint to the app that returns all the routes.""" - if not self.api: + if not self._api: return - @self.api.get(str(constants.Endpoint.ALL_ROUTES)) - async def all_routes(): - return list(self._unevaluated_pages.keys()) + async def all_routes(_request: Request) -> Response: + return JSONResponse(list(self._unevaluated_pages.keys())) + + self._api.add_route( + str(constants.Endpoint.ALL_ROUTES), all_routes, methods=["GET"] + ) @contextlib.asynccontextmanager async def modify_state(self, token: str) -> AsyncIterator[BaseState]: @@ -1693,18 +1759,24 @@ async def process( raise -async def ping() -> str: +async def ping(_request: Request) -> Response: """Test API endpoint. + Args: + _request: The Starlette request object. + Returns: The response. """ - return "pong" + return JSONResponse("pong") -async def health() -> JSONResponse: +async def health(_request: Request) -> JSONResponse: """Health check endpoint to assess the status of the database and Redis services. + Args: + _request: The Starlette request object. + Returns: JSONResponse: A JSON object with the health status: - "status" (bool): Overall health, True if all checks pass. @@ -1746,12 +1818,11 @@ def upload(app: App): The upload function. """ - async def upload_file(request: Request, files: list[FastAPIUploadFile]): + async def upload_file(request: Request): """Upload a file. Args: - request: The FastAPI request object. - files: The file(s) to upload. + request: The Starlette request object. Returns: StreamingResponse yielding newline-delimited JSON of StateUpdate @@ -1764,6 +1835,12 @@ async def upload_file(request: Request, files: list[FastAPIUploadFile]): """ from reflex.utils.exceptions import UploadTypeError, UploadValueError + # Get the files from the request. + files = await request.form() + files = files.getlist("files") + if not files: + raise UploadValueError("No files were uploaded.") + token = request.headers.get("reflex-client-token") handler = request.headers.get("reflex-event-handler") @@ -1816,6 +1893,10 @@ async def upload_file(request: Request, files: list[FastAPIUploadFile]): # event is handled. file_copies = [] for file in files: + if not isinstance(file, StarletteUploadFile): + raise UploadValueError( + "Uploaded file is not an UploadFile." + str(file) + ) content_copy = io.BytesIO() content_copy.write(await file.read()) content_copy.seek(0) diff --git a/reflex/app_mixins/lifespan.py b/reflex/app_mixins/lifespan.py index fe3eb9267da..26ebd934c94 100644 --- a/reflex/app_mixins/lifespan.py +++ b/reflex/app_mixins/lifespan.py @@ -9,7 +9,7 @@ import inspect from collections.abc import Callable, Coroutine -from fastapi import FastAPI +from starlette.applications import Starlette from reflex.utils import console from reflex.utils.exceptions import InvalidLifespanTaskTypeError @@ -27,7 +27,7 @@ class LifespanMixin(AppMixin): ) @contextlib.asynccontextmanager - async def _run_lifespan_tasks(self, app: FastAPI): + async def _run_lifespan_tasks(self, app: Starlette): running_tasks = [] try: async with contextlib.AsyncExitStack() as stack: diff --git a/reflex/testing.py b/reflex/testing.py index 6cae0d4629a..a86d9b7217d 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -322,11 +322,11 @@ async def _shutdown(*args, **kwargs) -> None: return _shutdown def _start_backend(self, port: int = 0): - if self.app_instance is None or self.app_instance.api is None: + if self.app_instance is None or self.app_instance._api is None: raise RuntimeError("App was not initialized.") self.backend = uvicorn.Server( uvicorn.Config( - app=self.app_instance.api, + app=self.app_instance._api, host="127.0.0.1", port=port, ) diff --git a/reflex/utils/codespaces.py b/reflex/utils/codespaces.py index 1546156da93..0e4c8702d00 100644 --- a/reflex/utils/codespaces.py +++ b/reflex/utils/codespaces.py @@ -4,7 +4,8 @@ import os -from fastapi.responses import HTMLResponse +from starlette.requests import Request +from starlette.responses import HTMLResponse from reflex.components.base.script import Script from reflex.components.component import Component @@ -74,9 +75,12 @@ def codespaces_auto_redirect() -> list[Component]: return [] -async def auth_codespace() -> HTMLResponse: +async def auth_codespace(_request: Request) -> HTMLResponse: """Page automatically redirecting back to the app after authenticating a codespace port forward. + Args: + _request: The request object. + Returns: An HTML response with an embedded script to redirect back to the app. """ diff --git a/reflex/utils/types.py b/reflex/utils/types.py index 82d89543a24..581d0d7ed83 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -11,11 +11,13 @@ from typing import ( # noqa: UP035 TYPE_CHECKING, Any, + Awaitable, ClassVar, Dict, ForwardRef, List, Literal, + MutableMapping, NoReturn, Tuple, Union, @@ -73,6 +75,13 @@ else: ArgsSpec = Callable[..., list[Any]] +Scope = MutableMapping[str, Any] +Message = MutableMapping[str, Any] + +Receive = Callable[[], Awaitable[Message]] +Send = Callable[[Message], Awaitable[None]] + +ASGIApp = Callable[[Scope, Receive, Send], Awaitable[None]] PrimitiveToAnnotation = { list: List, # noqa: UP006 diff --git a/tests/integration/test_lifespan.py b/tests/integration/test_lifespan.py index ad5f05ea1ea..50eacff6962 100644 --- a/tests/integration/test_lifespan.py +++ b/tests/integration/test_lifespan.py @@ -1,4 +1,4 @@ -"""Test cases for the FastAPI lifespan integration.""" +"""Test cases for the Starlette lifespan integration.""" from collections.abc import Generator diff --git a/tests/units/test_app.py b/tests/units/test_app.py index f252cd4dfe3..1e9fa33f1b3 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -14,8 +14,9 @@ import pytest import sqlmodel -from fastapi import FastAPI, UploadFile from pytest_mock import MockerFixture +from starlette.applications import Starlette +from starlette.datastructures import UploadFile from starlette_admin.auth import AuthProvider from starlette_admin.contrib.sqla.admin import Admin from starlette_admin.contrib.sqla.view import ModelView @@ -813,8 +814,22 @@ async def test_upload_file(tmp_path, state, delta, token: str, mocker): filename="image2.jpg", file=bio, ) + + async def form(): + files_mock = unittest.mock.Mock() + + def getlist(key: str): + assert key == "files" + return [file1, file2] + + files_mock.getlist = getlist + + return files_mock + + request_mock.form = form + upload_fn = upload(app) - streaming_response = await upload_fn(request_mock, [file1, file2]) # pyright: ignore [reportFunctionMemberAccess] + streaming_response = await upload_fn(request_mock) async for state_update in streaming_response.body_iterator: assert ( state_update @@ -853,10 +868,23 @@ async def test_upload_file_without_annotation(state, tmp_path, token): "reflex-client-token": token, "reflex-event-handler": f"{state.get_full_name()}.handle_upload2", } - file_mock = unittest.mock.Mock(filename="image1.jpg") + + async def form(): + files_mock = unittest.mock.Mock() + + def getlist(key: str): + assert key == "files" + return [unittest.mock.Mock(filename="image1.jpg")] + + files_mock.getlist = getlist + + return files_mock + + request_mock.form = form + fn = upload(app) with pytest.raises(ValueError) as err: - await fn(request_mock, [file_mock]) + await fn(request_mock) assert ( err.value.args[0] == f"`{state.get_full_name()}.handle_upload2` handler should have a parameter annotated as list[rx.UploadFile]" @@ -887,10 +915,23 @@ async def test_upload_file_background(state, tmp_path, token): "reflex-client-token": token, "reflex-event-handler": f"{state.get_full_name()}.bg_upload", } - file_mock = unittest.mock.Mock(filename="image1.jpg") + + async def form(): + files_mock = unittest.mock.Mock() + + def getlist(key: str): + assert key == "files" + return [unittest.mock.Mock(filename="image1.jpg")] + + files_mock.getlist = getlist + + return files_mock + + request_mock.form = form + fn = upload(app) with pytest.raises(TypeError) as err: - await fn(request_mock, [file_mock]) + await fn(request_mock) assert ( err.value.args[0] == f"@rx.event(background=True) is not supported for upload handler `{state.get_full_name()}.bg_upload`." @@ -1462,7 +1503,7 @@ def test_call_app(): """Test that the app can be called.""" app = App() api = app() - assert isinstance(api, FastAPI) + assert isinstance(api, Starlette) def test_app_with_optional_endpoints(): diff --git a/tests/units/test_health_endpoint.py b/tests/units/test_health_endpoint.py index 514d19d031f..52230e7600c 100644 --- a/tests/units/test_health_endpoint.py +++ b/tests/units/test_health_endpoint.py @@ -119,8 +119,10 @@ async def test_health( return_value={"redis": redis_status}, ) + request = Mock() + # Call the async health function - response = await health() + response = await health(request) # Verify the response content and status code assert response.status_code == expected_code