Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 65 additions & 30 deletions litellm/proxy/management_endpoints/ui_sso.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import os
import secrets
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union, cast

from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import RedirectResponse
Expand Down Expand Up @@ -82,7 +82,15 @@
get_server_root_path,
)
from litellm.secret_managers.main import get_secret_bool, str_to_bool
from litellm.types.proxy.management_endpoints.ui_sso import *
from litellm.types.proxy.management_endpoints.ui_sso import (
DefaultTeamSSOParams,
MicrosoftGraphAPIUserGroupDirectoryObject,
MicrosoftGraphAPIUserGroupResponse,
MicrosoftServicePrincipalTeam,
RoleMappings,
TeamMappings,
)
from litellm.types.proxy.management_endpoints.ui_sso import * # noqa: F403, F401
from litellm.types.proxy.ui_sso import ParsedOpenIDResult

if TYPE_CHECKING:
Expand All @@ -96,15 +104,15 @@
def normalize_email(email: Optional[str]) -> Optional[str]:
"""
Normalize email address to lowercase for consistent storage and comparison.

Email addresses should be treated as case-insensitive for SSO purposes,
even though RFC 5321 technically allows case-sensitive local parts.
This prevents issues where SSO providers return emails with different casing
than what's stored in the database.

Args:
email: Email address to normalize, can be None

Returns:
Lowercased email address, or None if input is None
"""
Expand Down Expand Up @@ -280,7 +288,7 @@ async def google_login(
# check if user defined a custom auth sso sign in handler, if yes, use it
if user_custom_ui_sso_sign_in_handler is not None:
try:
from litellm_enterprise.proxy.auth.custom_sso_handler import (
from litellm_enterprise.proxy.auth.custom_sso_handler import ( # type: ignore[import-untyped]
EnterpriseCustomSSOHandler,
)

Expand Down Expand Up @@ -428,7 +436,9 @@ def generic_response_convertor(
display_name=get_nested_value(
response, generic_user_display_name_attribute_name
),
email=normalize_email(get_nested_value(response, generic_user_email_attribute_name)),
email=normalize_email(
get_nested_value(response, generic_user_email_attribute_name)
),
first_name=get_nested_value(response, generic_user_first_name_attribute_name),
last_name=get_nested_value(response, generic_user_last_name_attribute_name),
provider=get_nested_value(response, generic_provider_attribute_name),
Expand Down Expand Up @@ -517,6 +527,7 @@ async def _setup_team_mappings() -> Optional["TeamMappings"]:

if team_mappings_data:
from litellm.types.proxy.management_endpoints.ui_sso import TeamMappings

if isinstance(team_mappings_data, dict):
team_mappings = TeamMappings(**team_mappings_data)
elif isinstance(team_mappings_data, TeamMappings):
Expand Down Expand Up @@ -554,6 +565,7 @@ async def _setup_role_mappings() -> Optional["RoleMappings"]:

if role_mappings_data:
from litellm.types.proxy.management_endpoints.ui_sso import RoleMappings

if isinstance(role_mappings_data, dict):
role_mappings = RoleMappings(**role_mappings_data)
elif isinstance(role_mappings_data, RoleMappings):
Expand All @@ -567,7 +579,7 @@ async def _setup_role_mappings() -> Optional["RoleMappings"]:
verbose_proxy_logger.debug(
f"Could not load role_mappings from database: {e}. Continuing with existing role logic."
)

generic_role_mappings = os.getenv("GENERIC_ROLE_MAPPINGS_ROLES", None)
generic_role_mappings_group_claim = os.getenv(
"GENERIC_ROLE_MAPPINGS_GROUP_CLAIM", None
Expand All @@ -577,8 +589,8 @@ async def _setup_role_mappings() -> Optional["RoleMappings"]:
)
if generic_role_mappings is not None:
verbose_proxy_logger.debug(
"Found role_mappings for generic provider in environment variables"
)
"Found role_mappings for generic provider in environment variables"
)
import ast

try:
Expand All @@ -603,7 +615,9 @@ async def _setup_role_mappings() -> Optional["RoleMappings"]:
)
return role_mappings
except TypeError as e:
verbose_proxy_logger.warning(f"Error decoding role mappings from environment variables: {e}. Continuing with existing role logic.")
verbose_proxy_logger.warning(
f"Error decoding role mappings from environment variables: {e}. Continuing with existing role logic."
)
return role_mappings


Expand Down Expand Up @@ -680,7 +694,7 @@ def response_convertor(response, client):
try:
result = await generic_sso.verify_and_process(
request,
params=SSOAuthenticationHandler.prepare_token_exchange_parameters(
params=await SSOAuthenticationHandler.prepare_token_exchange_parameters(
request=request,
generic_include_client_id=generic_include_client_id,
),
Expand Down Expand Up @@ -875,7 +889,7 @@ def _build_sso_user_update_data(

Returns:
dict: Update data containing user_email and optionally user_role if valid
"""
"""
update_data: dict = {"user_email": normalize_email(user_email)}

# Get SSO role from result and include if valid
Expand Down Expand Up @@ -1673,7 +1687,7 @@ async def get_generic_sso_redirect_response(
"""
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse

from litellm.proxy.proxy_server import user_api_key_cache
from litellm.proxy.proxy_server import redis_usage_cache, user_api_key_cache

with generic_sso:
# TODO: state should be a random string and added to the user session with cookie
Expand Down Expand Up @@ -1702,13 +1716,21 @@ async def get_generic_sso_redirect_response(

# If PKCE is enabled, add PKCE parameters to the redirect URL
if code_verifier and "state" in redirect_params:
# Store code_verifier in cache (10 min TTL)
# Store code_verifier in cache (10 min TTL). Use Redis when available
# so callbacks landing on another pod can retrieve it (multi-pod SSO).
cache_key = f"pkce_verifier:{redirect_params['state']}"
user_api_key_cache.set_cache(
key=cache_key,
value=code_verifier,
ttl=600,
)
if redis_usage_cache is not None:
await redis_usage_cache.async_set_cache(
key=cache_key,
value=code_verifier,
ttl=600,
)
else:
await user_api_key_cache.async_set_cache(
key=cache_key,
value=code_verifier,
ttl=600,
)

# Add PKCE parameters to the authorization URL
if pkce_params:
Expand Down Expand Up @@ -2305,7 +2327,7 @@ async def get_redirect_response_from_openid( # noqa: PLR0915
return redirect_response

@staticmethod
def prepare_token_exchange_parameters(
async def prepare_token_exchange_parameters(
request: Request,
generic_include_client_id: bool,
) -> dict:
Expand All @@ -2319,27 +2341,38 @@ def prepare_token_exchange_parameters(
Returns:
dict: Token exchange parameters
"""
# Prepare token exchange parameters
token_params = {"include_client_id": generic_include_client_id}
# Prepare token exchange parameters (may add code_verifier: str later)
token_params: Dict[str, Any] = {"include_client_id": generic_include_client_id}

# Retrieve PKCE code_verifier if PKCE was used in authorization
# Retrieve PKCE code_verifier if PKCE was used in authorization.
# Use same cache as store: Redis when available (multi-pod), else in-memory.
query_params = dict(request.query_params)
state = query_params.get("state")
if state:
from litellm.proxy.proxy_server import user_api_key_cache
from litellm.proxy.proxy_server import redis_usage_cache, user_api_key_cache

cache_key = f"pkce_verifier:{state}"
code_verifier = user_api_key_cache.get_cache(key=cache_key)
if redis_usage_cache is not None:
code_verifier = await redis_usage_cache.async_get_cache(key=cache_key)
else:
code_verifier = await user_api_key_cache.async_get_cache(key=cache_key)

if code_verifier:
# Add code_verifier to token exchange parameters
token_params["code_verifier"] = code_verifier
# Add code_verifier to token exchange parameters (Redis returns decoded string)
token_params["code_verifier"] = (
code_verifier
if isinstance(code_verifier, str)
else str(code_verifier)
)
verbose_proxy_logger.debug(
"PKCE code_verifier retrieved and will be included in token exchange"
)

# Clean up the cache entry (single-use verifier)
user_api_key_cache.delete_cache(key=cache_key)
if redis_usage_cache is not None:
await redis_usage_cache.async_delete_cache(key=cache_key)
else:
await user_api_key_cache.async_delete_cache(key=cache_key)
return token_params

@staticmethod
Expand Down Expand Up @@ -2482,7 +2515,9 @@ def openid_from_response(
response = response or {}
verbose_proxy_logger.debug(f"Microsoft SSO Callback Response: {response}")
openid_response = CustomOpenID(
email=normalize_email(response.get(MICROSOFT_USER_EMAIL_ATTRIBUTE) or response.get("mail")),
email=normalize_email(
response.get(MICROSOFT_USER_EMAIL_ATTRIBUTE) or response.get("mail")
),
display_name=response.get(MICROSOFT_USER_DISPLAY_NAME_ATTRIBUTE),
provider="microsoft",
id=response.get(MICROSOFT_USER_ID_ATTRIBUTE),
Expand Down
Loading
Loading