Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
from airflow.api_fastapi.core_api.security import is_safe_url
from airflow.api_fastapi.core_api.security import AuthManagerDep, is_safe_url

auth_router = AirflowRouter(tags=["Login"], prefix="/auth")

Expand All @@ -30,9 +30,9 @@
"/login",
responses=create_openapi_http_exception_doc([status.HTTP_307_TEMPORARY_REDIRECT]),
)
def login(request: Request, next: None | str = None) -> RedirectResponse:
def login(request: Request, auth_manager: AuthManagerDep, next: None | str = None) -> RedirectResponse:
"""Redirect to the login URL depending on the AuthManager configured."""
login_url = request.app.state.auth_manager.get_url_login()
login_url = auth_manager.get_url_login()

if next and not is_safe_url(next, request=request):
raise HTTPException(status_code=400, detail="Invalid or unsafe next URL")
Expand All @@ -47,11 +47,11 @@ def login(request: Request, next: None | str = None) -> RedirectResponse:
"/logout",
responses=create_openapi_http_exception_doc([status.HTTP_307_TEMPORARY_REDIRECT]),
)
def logout(request: Request, next: None | str = None) -> RedirectResponse:
def logout(auth_manager: AuthManagerDep, next: None | str = None) -> RedirectResponse:
"""Logout the user."""
logout_url = request.app.state.auth_manager.get_url_logout()
logout_url = auth_manager.get_url_logout()

if not logout_url:
logout_url = request.app.state.auth_manager.get_url_login()
logout_url = auth_manager.get_url_login()

return RedirectResponse(logout_url)
45 changes: 28 additions & 17 deletions airflow-core/src/airflow/api_fastapi/core_api/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@
from pydantic import NonNegativeInt

from airflow.api_fastapi.app import get_auth_manager
from airflow.api_fastapi.auth.managers.base_auth_manager import COOKIE_NAME_JWT_TOKEN
from airflow.api_fastapi.auth.managers.base_auth_manager import (
COOKIE_NAME_JWT_TOKEN,
BaseAuthManager,
)
from airflow.api_fastapi.auth.managers.models.base_user import BaseUser
from airflow.api_fastapi.auth.managers.models.batch_apis import (
IsAuthorizedConnectionRequest,
Expand Down Expand Up @@ -70,7 +73,20 @@
from fastapi.security import HTTPAuthorizationCredentials
from sqlalchemy.sql import Select

from airflow.api_fastapi.auth.managers.base_auth_manager import BaseAuthManager, ResourceMethod
from airflow.api_fastapi.auth.managers.base_auth_manager import ResourceMethod


def auth_manager_from_app(request: Request) -> BaseAuthManager:
"""
FastAPI dependency resolver that returns the shared AuthManager instance from app.state.

This ensures that all API routes using AuthManager via dependency injection receive the same
singleton instance that was initialized at app startup.
"""
return request.app.state.auth_manager


AuthManagerDep = Annotated[BaseAuthManager, Depends(auth_manager_from_app)]

auth_description = (
"To authenticate Airflow API requests, clients must include a JWT (JSON Web Token) in "
Expand Down Expand Up @@ -196,7 +212,7 @@ def to_orm(self, select: Select) -> Select:

def permitted_dag_filter_factory(
method: ResourceMethod, filter_class=PermittedDagFilter
) -> Callable[[Request, BaseUser], PermittedDagFilter]:
) -> Callable[[BaseUser, BaseAuthManager], PermittedDagFilter]:
"""
Create a callable for Depends in FastAPI that returns a filter of the permitted dags for the user.

Expand All @@ -205,10 +221,9 @@ def permitted_dag_filter_factory(
"""

def depends_permitted_dags_filter(
request: Request,
user: GetUserDep,
auth_manager: AuthManagerDep,
) -> PermittedDagFilter:
auth_manager: BaseAuthManager = request.app.state.auth_manager
authorized_dags: set[str] = auth_manager.get_authorized_dag_ids(user=user, method=method)
return filter_class(authorized_dags)

Expand Down Expand Up @@ -260,18 +275,17 @@ def to_orm(self, select: Select) -> Select:

def permitted_pool_filter_factory(
method: ResourceMethod,
) -> Callable[[Request, BaseUser], PermittedPoolFilter]:
) -> Callable[[BaseUser, BaseAuthManager], PermittedPoolFilter]:
"""
Create a callable for Depends in FastAPI that returns a filter of the permitted pools for the user.

:param method: whether filter readable or writable.
"""

def depends_permitted_pools_filter(
request: Request,
user: GetUserDep,
auth_manager: AuthManagerDep,
) -> PermittedPoolFilter:
auth_manager: BaseAuthManager = request.app.state.auth_manager
authorized_pools: set[str] = auth_manager.get_authorized_pools(user=user, method=method)
return PermittedPoolFilter(authorized_pools)

Expand Down Expand Up @@ -353,18 +367,17 @@ def to_orm(self, select: Select) -> Select:

def permitted_connection_filter_factory(
method: ResourceMethod,
) -> Callable[[Request, BaseUser], PermittedConnectionFilter]:
) -> Callable[[BaseUser, BaseAuthManager], PermittedConnectionFilter]:
"""
Create a callable for Depends in FastAPI that returns a filter of the permitted connections for the user.

:param method: whether filter readable or writable.
"""

def depends_permitted_connections_filter(
request: Request,
user: GetUserDep,
auth_manager: AuthManagerDep,
) -> PermittedConnectionFilter:
auth_manager: BaseAuthManager = request.app.state.auth_manager
authorized_connections: set[str] = auth_manager.get_authorized_connections(user=user, method=method)
return PermittedConnectionFilter(authorized_connections)

Expand Down Expand Up @@ -470,14 +483,13 @@ def to_orm(self, select: Select) -> Select:
return select.where(Team.name.in_(self.value))


def permitted_team_filter_factory() -> Callable[[Request, BaseUser], PermittedTeamFilter]:
def permitted_team_filter_factory() -> Callable[[BaseUser, BaseAuthManager], PermittedTeamFilter]:
"""Create a callable for Depends in FastAPI that returns a filter of the permitted teams for the user."""

def depends_permitted_teams_filter(
request: Request,
user: GetUserDep,
auth_manager: AuthManagerDep,
) -> PermittedTeamFilter:
auth_manager: BaseAuthManager = request.app.state.auth_manager
authorized_teams: set[str] = auth_manager.get_authorized_teams(user=user, method="GET")
return PermittedTeamFilter(authorized_teams)

Expand All @@ -496,18 +508,17 @@ def to_orm(self, select: Select) -> Select:

def permitted_variable_filter_factory(
method: ResourceMethod,
) -> Callable[[Request, BaseUser], PermittedVariableFilter]:
) -> Callable[[BaseUser, BaseAuthManager], PermittedVariableFilter]:
"""
Create a callable for Depends in FastAPI that returns a filter of the permitted variables for the user.

:param method: whether filter readable or writable.
"""

def depends_permitted_variables_filter(
request: Request,
user: GetUserDep,
auth_manager: AuthManagerDep,
) -> PermittedVariableFilter:
auth_manager: BaseAuthManager = request.app.state.auth_manager
authorized_variables: set[str] = auth_manager.get_authorized_variables(user=user, method=method)
return PermittedVariableFilter(authorized_variables)

Expand Down
37 changes: 37 additions & 0 deletions airflow-core/tests/unit/api_fastapi/core_api/test_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,3 +484,40 @@ def test_requires_access_pool_bulk(self, mock_get_auth_manager, mock_get_name_to
],
user=user,
)


class TestAuthManagerDependency:
"""Test the auth_manager_from_app dependency function."""

def test_auth_manager_from_app_returns_instance_from_state(self):
"""Test that auth_manager_from_app correctly retrieves auth_manager from app.state."""
from airflow.api_fastapi.core_api.security import auth_manager_from_app

# Create a mock auth manager
mock_auth_manager = Mock()

# Create a mock request with app.state.auth_manager
mock_request = Mock()
mock_request.app.state.auth_manager = mock_auth_manager

# Call the dependency function
result = auth_manager_from_app(mock_request)

# Assert it returns the correct auth manager
assert result is mock_auth_manager

def test_auth_manager_from_app_integration_with_test_client(self, test_client):
"""Test that auth_manager_from_app works with the test client setup."""
from airflow.api_fastapi.core_api.security import auth_manager_from_app

# Create a mock request using the test client's app
mock_request = Mock()
mock_request.app = test_client.app

# Get the auth manager
auth_manager = auth_manager_from_app(mock_request)

# Verify it's not None (should be SimpleAuthManager from test fixture)
assert auth_manager is not None
assert hasattr(auth_manager, "get_url_login")
assert hasattr(auth_manager, "get_url_logout")
Loading