Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions litellm/proxy/management_endpoints/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,24 @@ def is_valid_litellm_user_role(role_str: str) -> bool:
return False


def get_litellm_user_role(role_str: str) -> Optional[LitellmUserRoles]:
def get_litellm_user_role(role_str) -> Optional[LitellmUserRoles]:
"""
Convert a string to a LitellmUserRoles enum if valid (case-insensitive).
Convert a string (or list of strings) to a LitellmUserRoles enum if valid (case-insensitive).

Handles list inputs since some SSO providers (e.g., Keycloak) return roles
as arrays like ["proxy_admin"] instead of plain strings.

Args:
role_str: String to convert (e.g., "proxy_admin", "PROXY_ADMIN", "internal_user")
role_str: String or list to convert (e.g., "proxy_admin", ["proxy_admin"])

Returns:
LitellmUserRoles enum if valid, None otherwise
"""
try:
if isinstance(role_str, list):
if len(role_str) == 0:
return None
role_str = role_str[0]
# Use _value2member_map_ for O(1) lookup, case-insensitive
result = LitellmUserRoles._value2member_map_.get(role_str.lower())
return cast(Optional[LitellmUserRoles], result)
Expand Down
84 changes: 67 additions & 17 deletions litellm/proxy/management_endpoints/ui_sso.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,36 +171,86 @@ def process_sso_jwt_access_token(
access_token_str: Optional[str],
sso_jwt_handler: Optional[JWTHandler],
result: Union[OpenID, dict, None],
role_mappings: Optional["RoleMappings"] = None,
) -> None:
"""
Process SSO JWT access token and extract team IDs if available.
Process SSO JWT access token and extract team IDs and user role if available.

This function decodes the JWT access token and extracts team IDs using the
sso_jwt_handler, then sets the team_ids attribute on the result object.
This function decodes the JWT access token and extracts team IDs and user
role, then sets them on the result object. Role extraction from the access
token is needed because some SSO providers (e.g., Keycloak) do not include
role claims in the UserInfo endpoint response.

Args:
access_token_str: The JWT access token string
sso_jwt_handler: SSO-specific JWT handler for team ID extraction
result: The SSO result object to update with team IDs
result: The SSO result object to update with team IDs and role
role_mappings: Optional role mappings configuration for group-based role determination
"""
if access_token_str and sso_jwt_handler and result:
if access_token_str and result:
import jwt

access_token_payload = jwt.decode(
access_token_str, options={"verify_signature": False}
)

# Handle both dict and object result types
if isinstance(result, dict):
result_team_ids: Optional[List[str]] = result.get("team_ids", [])
if not result_team_ids:
team_ids = sso_jwt_handler.get_team_ids_from_jwt(access_token_payload)
result["team_ids"] = team_ids
else:
result_team_ids = getattr(result, "team_ids", []) if result else []
if not result_team_ids:
team_ids = sso_jwt_handler.get_team_ids_from_jwt(access_token_payload)
setattr(result, "team_ids", team_ids)
# Extract team IDs from access token if sso_jwt_handler is available
if sso_jwt_handler:
if isinstance(result, dict):
result_team_ids: Optional[List[str]] = result.get("team_ids", [])
if not result_team_ids:
team_ids = sso_jwt_handler.get_team_ids_from_jwt(access_token_payload)
result["team_ids"] = team_ids
else:
result_team_ids = getattr(result, "team_ids", []) if result else []
if not result_team_ids:
team_ids = sso_jwt_handler.get_team_ids_from_jwt(access_token_payload)
setattr(result, "team_ids", team_ids)

# Extract user role from access token if not already set from UserInfo
existing_role = result.get("user_role") if isinstance(result, dict) else getattr(result, "user_role", None)
if existing_role is None:
user_role: Optional[LitellmUserRoles] = None

# Try role_mappings first (group-based role determination)
if role_mappings is not None and role_mappings.roles:
group_claim = role_mappings.group_claim
user_groups_raw: Any = get_nested_value(access_token_payload, group_claim)

user_groups: List[str] = []
if isinstance(user_groups_raw, list):
user_groups = [str(g) for g in user_groups_raw]
elif isinstance(user_groups_raw, str):
user_groups = [g.strip() for g in user_groups_raw.split(",") if g.strip()]
elif user_groups_raw is not None:
user_groups = [str(user_groups_raw)]

if user_groups:
user_role = determine_role_from_groups(user_groups, role_mappings)
verbose_proxy_logger.debug(
f"Determined role '{user_role}' from access token groups '{user_groups}' using role_mappings"
)
elif role_mappings.default_role:
user_role = role_mappings.default_role

# Fallback: try GENERIC_USER_ROLE_ATTRIBUTE on the access token payload
if user_role is None:
generic_user_role_attribute_name = os.getenv("GENERIC_USER_ROLE_ATTRIBUTE", "role")
user_role_from_token = get_nested_value(access_token_payload, generic_user_role_attribute_name)
if user_role_from_token is not None:
user_role = get_litellm_user_role(user_role_from_token)
verbose_proxy_logger.debug(
f"Extracted role '{user_role}' from access token field '{generic_user_role_attribute_name}'"
)

if user_role is not None:
if isinstance(result, dict):
result["user_role"] = user_role
else:
setattr(result, "user_role", user_role)
verbose_proxy_logger.debug(
f"Set user_role='{user_role}' from JWT access token"
)


@router.get("/sso/key/generate", tags=["experimental"], include_in_schema=False)
Expand Down Expand Up @@ -688,7 +738,7 @@ def response_convertor(response, client):
)

access_token_str: Optional[str] = generic_sso.access_token
process_sso_jwt_access_token(access_token_str, sso_jwt_handler, result)
process_sso_jwt_access_token(access_token_str, sso_jwt_handler, result, role_mappings=role_mappings)

except Exception as e:
verbose_proxy_logger.exception(
Expand Down
Loading
Loading