Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions litellm/proxy/management_endpoints/ui_sso.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,15 @@ def process_sso_jwt_access_token(
if access_token_str and result:
import jwt

access_token_payload = jwt.decode(
access_token_str, options={"verify_signature": False}
)
try:
access_token_payload = jwt.decode(
access_token_str, options={"verify_signature": False}
)
except jwt.exceptions.DecodeError:
verbose_proxy_logger.debug(
"Access token is not a valid JWT (possibly an opaque token), skipping JWT-based extraction"
)
return

# Extract team IDs from access token if sso_jwt_handler is available
if sso_jwt_handler:
Expand Down
193 changes: 124 additions & 69 deletions tests/test_litellm/proxy/management_endpoints/test_ui_sso.py
Original file line number Diff line number Diff line change
Expand Up @@ -2374,47 +2374,6 @@ def sample_jwt_payload(self):
"groups": ["team1", "team2", "team3"],
}

def test_process_sso_jwt_access_token_with_valid_token(
self, mock_jwt_handler, sample_jwt_token, sample_jwt_payload
):
"""Test processing a valid JWT access token with team extraction"""
from litellm.proxy.management_endpoints.ui_sso import (
process_sso_jwt_access_token,
)

# Create a result object without team_ids
result = CustomOpenID(
id="test_user",
email="test@example.com",
first_name="Test",
last_name="User",
display_name="Test User",
provider="generic",
team_ids=[],
)

with patch("jwt.decode", return_value=sample_jwt_payload) as mock_jwt_decode:
# Act
process_sso_jwt_access_token(
access_token_str=sample_jwt_token,
sso_jwt_handler=mock_jwt_handler,
result=result,
)

# Assert
# Verify JWT was decoded correctly
mock_jwt_decode.assert_called_once_with(
sample_jwt_token, options={"verify_signature": False}
)

# Verify team IDs were extracted from JWT
mock_jwt_handler.get_team_ids_from_jwt.assert_called_once_with(
sample_jwt_payload
)

# Verify team IDs were set on the result object
assert result.team_ids == ["team1", "team2", "team3"]

def test_process_sso_jwt_access_token_with_existing_team_ids(
self, mock_jwt_handler, sample_jwt_token
):
Expand Down Expand Up @@ -2549,27 +2508,6 @@ def test_process_sso_jwt_access_token_no_access_token(self, mock_jwt_handler):
mock_jwt_handler.get_team_ids_from_jwt.assert_not_called()
assert result.team_ids == []

def test_process_sso_jwt_access_token_no_sso_jwt_handler(self, sample_jwt_token):
"""Test that JWT is decoded for role extraction even when sso_jwt_handler is None,
but team_ids are not extracted (team extraction requires sso_jwt_handler)."""
from litellm.proxy.management_endpoints.ui_sso import (
process_sso_jwt_access_token,
)

result = CustomOpenID(id="test_user", email="test@example.com", team_ids=[])

mock_payload = {"sub": "test_user", "email": "test@example.com"}
with patch("jwt.decode", return_value=mock_payload) as mock_jwt_decode:
# Act
process_sso_jwt_access_token(
access_token_str=sample_jwt_token, sso_jwt_handler=None, result=result
)

# JWT is decoded (for role extraction) but team_ids are not extracted
mock_jwt_decode.assert_called_once()
assert result.team_ids == []
assert result.user_role is None

def test_process_sso_jwt_access_token_no_result(
self, mock_jwt_handler, sample_jwt_token
):
Expand All @@ -2590,30 +2528,29 @@ def test_process_sso_jwt_access_token_no_result(
mock_jwt_decode.assert_not_called()
mock_jwt_handler.get_team_ids_from_jwt.assert_not_called()

def test_process_sso_jwt_access_token_jwt_decode_exception(
def test_process_sso_jwt_access_token_non_decode_exception_propagates(
self, mock_jwt_handler, sample_jwt_token
):
"""Test that JWT decode exceptions are not caught (should propagate up)"""
"""Test that non-DecodeError JWT exceptions still propagate up."""
import jwt as pyjwt

from litellm.proxy.management_endpoints.ui_sso import (
process_sso_jwt_access_token,
)

result = CustomOpenID(id="test_user", email="test@example.com", team_ids=[])

with patch(
"jwt.decode", side_effect=Exception("JWT decode error")
"jwt.decode", side_effect=pyjwt.exceptions.InvalidKeyError("Invalid key")
) as mock_jwt_decode:
# Act & Assert
with pytest.raises(Exception, match="JWT decode error"):
with pytest.raises(pyjwt.exceptions.InvalidKeyError, match="Invalid key"):
process_sso_jwt_access_token(
access_token_str=sample_jwt_token,
sso_jwt_handler=mock_jwt_handler,
result=result,
)

# Verify JWT decode was attempted
mock_jwt_decode.assert_called_once()
# But team extraction should not have been called
mock_jwt_handler.get_team_ids_from_jwt.assert_not_called()

def test_process_sso_jwt_access_token_empty_team_ids_from_jwt(
Expand Down Expand Up @@ -2646,6 +2583,124 @@ def test_process_sso_jwt_access_token_empty_team_ids_from_jwt(
# Even empty team IDs should be set
assert result.team_ids == []

def test_process_sso_jwt_access_token_with_opaque_token(self, mock_jwt_handler):
"""Test that opaque (non-JWT) access tokens are handled gracefully without raising."""
from litellm.proxy.management_endpoints.ui_sso import (
process_sso_jwt_access_token,
)

result = CustomOpenID(
id="test_user",
email="test@example.com",
first_name="Test",
last_name="User",
display_name="Test User",
provider="generic",
team_ids=["existing_team"],
user_role=None,
)

# Opaque tokens like those from Logto are short random strings, not JWTs
opaque_token = "uTxyjXbS_random_opaque_token_string"

# Should NOT raise - opaque tokens should be silently skipped
process_sso_jwt_access_token(
access_token_str=opaque_token,
sso_jwt_handler=mock_jwt_handler,
result=result,
)

# Result should be untouched
mock_jwt_handler.get_team_ids_from_jwt.assert_not_called()
assert result.team_ids == ["existing_team"]
assert result.user_role is None

def test_process_sso_jwt_access_token_real_jwt_with_role_and_teams(
self, mock_jwt_handler
):
"""Test that a real JWT containing role and team fields is correctly processed."""
import jwt as pyjwt

from litellm.proxy.management_endpoints.ui_sso import (
process_sso_jwt_access_token,
)

payload = {
"sub": "user123",
"email": "admin@example.com",
"role": "proxy_admin",
"groups": ["team_alpha", "team_beta"],
}
real_jwt_token = pyjwt.encode(payload, "test-secret", algorithm="HS256")

mock_jwt_handler.get_team_ids_from_jwt.return_value = [
"team_alpha",
"team_beta",
]

result = CustomOpenID(
id="user123",
email="admin@example.com",
first_name="Admin",
last_name="User",
display_name="Admin User",
provider="generic",
team_ids=[],
user_role=None,
)

process_sso_jwt_access_token(
access_token_str=real_jwt_token,
sso_jwt_handler=mock_jwt_handler,
result=result,
)

# Team IDs should be extracted via sso_jwt_handler
mock_jwt_handler.get_team_ids_from_jwt.assert_called_once_with(payload)
assert result.team_ids == ["team_alpha", "team_beta"]

# Role should be extracted from the "role" field in the JWT
from litellm.proxy._types import LitellmUserRoles

assert result.user_role == LitellmUserRoles.PROXY_ADMIN

def test_process_sso_jwt_access_token_real_jwt_without_role_and_teams(self):
"""Test that a real JWT without role/team fields leaves result unchanged."""
import jwt as pyjwt

from litellm.proxy.management_endpoints.ui_sso import (
process_sso_jwt_access_token,
)

payload = {
"sub": "user456",
"email": "plain@example.com",
"iat": 1700000000,
}
real_jwt_token = pyjwt.encode(payload, "test-secret", algorithm="HS256")

result = CustomOpenID(
id="user456",
email="plain@example.com",
first_name="Plain",
last_name="User",
display_name="Plain User",
provider="generic",
team_ids=[],
user_role=None,
)

# No sso_jwt_handler, no role/team fields in JWT
process_sso_jwt_access_token(
access_token_str=real_jwt_token,
sso_jwt_handler=None,
result=result,
)

# Nothing should be modified
assert result.team_ids == []
assert result.user_role is None


@pytest.mark.asyncio
async def test_get_ui_settings_includes_api_doc_base_url():
Expand Down
Loading