diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/auth.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/auth.py index b8f6d204d2ed1..59610e0fc9469 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/auth.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/auth.py @@ -21,7 +21,7 @@ from airflow.api_fastapi.common.router import AirflowRouter from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc -from airflow.api_fastapi.core_api.security import is_safe_url +from airflow.api_fastapi.core_api.security import AuthManagerDep, is_safe_url auth_router = AirflowRouter(tags=["Login"], prefix="/auth") @@ -30,9 +30,9 @@ "/login", responses=create_openapi_http_exception_doc([status.HTTP_307_TEMPORARY_REDIRECT]), ) -def login(request: Request, next: None | str = None) -> RedirectResponse: +def login(request: Request, auth_manager: AuthManagerDep, next: None | str = None) -> RedirectResponse: """Redirect to the login URL depending on the AuthManager configured.""" - login_url = request.app.state.auth_manager.get_url_login() + login_url = auth_manager.get_url_login() if next and not is_safe_url(next, request=request): raise HTTPException(status_code=400, detail="Invalid or unsafe next URL") @@ -47,11 +47,11 @@ def login(request: Request, next: None | str = None) -> RedirectResponse: "/logout", responses=create_openapi_http_exception_doc([status.HTTP_307_TEMPORARY_REDIRECT]), ) -def logout(request: Request, next: None | str = None) -> RedirectResponse: +def logout(auth_manager: AuthManagerDep, next: None | str = None) -> RedirectResponse: """Logout the user.""" - logout_url = request.app.state.auth_manager.get_url_logout() + logout_url = auth_manager.get_url_logout() if not logout_url: - logout_url = request.app.state.auth_manager.get_url_login() + logout_url = auth_manager.get_url_login() return RedirectResponse(logout_url) diff --git a/airflow-core/src/airflow/api_fastapi/core_api/security.py b/airflow-core/src/airflow/api_fastapi/core_api/security.py index 5e7d676bf9da2..05ad624353b58 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/security.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/security.py @@ -27,7 +27,10 @@ from pydantic import NonNegativeInt from airflow.api_fastapi.app import get_auth_manager -from airflow.api_fastapi.auth.managers.base_auth_manager import COOKIE_NAME_JWT_TOKEN +from airflow.api_fastapi.auth.managers.base_auth_manager import ( + COOKIE_NAME_JWT_TOKEN, + BaseAuthManager, +) from airflow.api_fastapi.auth.managers.models.base_user import BaseUser from airflow.api_fastapi.auth.managers.models.batch_apis import ( IsAuthorizedConnectionRequest, @@ -70,7 +73,20 @@ from fastapi.security import HTTPAuthorizationCredentials from sqlalchemy.sql import Select - from airflow.api_fastapi.auth.managers.base_auth_manager import BaseAuthManager, ResourceMethod + from airflow.api_fastapi.auth.managers.base_auth_manager import ResourceMethod + + +def auth_manager_from_app(request: Request) -> BaseAuthManager: + """ + FastAPI dependency resolver that returns the shared AuthManager instance from app.state. + + This ensures that all API routes using AuthManager via dependency injection receive the same + singleton instance that was initialized at app startup. + """ + return request.app.state.auth_manager + + +AuthManagerDep = Annotated[BaseAuthManager, Depends(auth_manager_from_app)] auth_description = ( "To authenticate Airflow API requests, clients must include a JWT (JSON Web Token) in " @@ -196,7 +212,7 @@ def to_orm(self, select: Select) -> Select: def permitted_dag_filter_factory( method: ResourceMethod, filter_class=PermittedDagFilter -) -> Callable[[Request, BaseUser], PermittedDagFilter]: +) -> Callable[[BaseUser, BaseAuthManager], PermittedDagFilter]: """ Create a callable for Depends in FastAPI that returns a filter of the permitted dags for the user. @@ -205,10 +221,9 @@ def permitted_dag_filter_factory( """ def depends_permitted_dags_filter( - request: Request, user: GetUserDep, + auth_manager: AuthManagerDep, ) -> PermittedDagFilter: - auth_manager: BaseAuthManager = request.app.state.auth_manager authorized_dags: set[str] = auth_manager.get_authorized_dag_ids(user=user, method=method) return filter_class(authorized_dags) @@ -260,7 +275,7 @@ def to_orm(self, select: Select) -> Select: def permitted_pool_filter_factory( method: ResourceMethod, -) -> Callable[[Request, BaseUser], PermittedPoolFilter]: +) -> Callable[[BaseUser, BaseAuthManager], PermittedPoolFilter]: """ Create a callable for Depends in FastAPI that returns a filter of the permitted pools for the user. @@ -268,10 +283,9 @@ def permitted_pool_filter_factory( """ def depends_permitted_pools_filter( - request: Request, user: GetUserDep, + auth_manager: AuthManagerDep, ) -> PermittedPoolFilter: - auth_manager: BaseAuthManager = request.app.state.auth_manager authorized_pools: set[str] = auth_manager.get_authorized_pools(user=user, method=method) return PermittedPoolFilter(authorized_pools) @@ -353,7 +367,7 @@ def to_orm(self, select: Select) -> Select: def permitted_connection_filter_factory( method: ResourceMethod, -) -> Callable[[Request, BaseUser], PermittedConnectionFilter]: +) -> Callable[[BaseUser, BaseAuthManager], PermittedConnectionFilter]: """ Create a callable for Depends in FastAPI that returns a filter of the permitted connections for the user. @@ -361,10 +375,9 @@ def permitted_connection_filter_factory( """ def depends_permitted_connections_filter( - request: Request, user: GetUserDep, + auth_manager: AuthManagerDep, ) -> PermittedConnectionFilter: - auth_manager: BaseAuthManager = request.app.state.auth_manager authorized_connections: set[str] = auth_manager.get_authorized_connections(user=user, method=method) return PermittedConnectionFilter(authorized_connections) @@ -470,14 +483,13 @@ def to_orm(self, select: Select) -> Select: return select.where(Team.name.in_(self.value)) -def permitted_team_filter_factory() -> Callable[[Request, BaseUser], PermittedTeamFilter]: +def permitted_team_filter_factory() -> Callable[[BaseUser, BaseAuthManager], PermittedTeamFilter]: """Create a callable for Depends in FastAPI that returns a filter of the permitted teams for the user.""" def depends_permitted_teams_filter( - request: Request, user: GetUserDep, + auth_manager: AuthManagerDep, ) -> PermittedTeamFilter: - auth_manager: BaseAuthManager = request.app.state.auth_manager authorized_teams: set[str] = auth_manager.get_authorized_teams(user=user, method="GET") return PermittedTeamFilter(authorized_teams) @@ -496,7 +508,7 @@ def to_orm(self, select: Select) -> Select: def permitted_variable_filter_factory( method: ResourceMethod, -) -> Callable[[Request, BaseUser], PermittedVariableFilter]: +) -> Callable[[BaseUser, BaseAuthManager], PermittedVariableFilter]: """ Create a callable for Depends in FastAPI that returns a filter of the permitted variables for the user. @@ -504,10 +516,9 @@ def permitted_variable_filter_factory( """ def depends_permitted_variables_filter( - request: Request, user: GetUserDep, + auth_manager: AuthManagerDep, ) -> PermittedVariableFilter: - auth_manager: BaseAuthManager = request.app.state.auth_manager authorized_variables: set[str] = auth_manager.get_authorized_variables(user=user, method=method) return PermittedVariableFilter(authorized_variables) diff --git a/airflow-core/tests/unit/api_fastapi/core_api/test_security.py b/airflow-core/tests/unit/api_fastapi/core_api/test_security.py index 026dd10fa31ea..6aea3b2352882 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/test_security.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/test_security.py @@ -484,3 +484,40 @@ def test_requires_access_pool_bulk(self, mock_get_auth_manager, mock_get_name_to ], user=user, ) + + +class TestAuthManagerDependency: + """Test the auth_manager_from_app dependency function.""" + + def test_auth_manager_from_app_returns_instance_from_state(self): + """Test that auth_manager_from_app correctly retrieves auth_manager from app.state.""" + from airflow.api_fastapi.core_api.security import auth_manager_from_app + + # Create a mock auth manager + mock_auth_manager = Mock() + + # Create a mock request with app.state.auth_manager + mock_request = Mock() + mock_request.app.state.auth_manager = mock_auth_manager + + # Call the dependency function + result = auth_manager_from_app(mock_request) + + # Assert it returns the correct auth manager + assert result is mock_auth_manager + + def test_auth_manager_from_app_integration_with_test_client(self, test_client): + """Test that auth_manager_from_app works with the test client setup.""" + from airflow.api_fastapi.core_api.security import auth_manager_from_app + + # Create a mock request using the test client's app + mock_request = Mock() + mock_request.app = test_client.app + + # Get the auth manager + auth_manager = auth_manager_from_app(mock_request) + + # Verify it's not None (should be SimpleAuthManager from test fixture) + assert auth_manager is not None + assert hasattr(auth_manager, "get_url_login") + assert hasattr(auth_manager, "get_url_logout")