Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions litellm/proxy/auth/auth_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,17 @@ async def common_checks(
f"Team={team_object.team_id} is blocked. Update via `/team/unblock` if your admin."
)

# 1.1. If user is deactivated via SCIM
if user_object is not None and user_object.metadata is not None:
scim_active = user_object.metadata.get("scim_active")
if scim_active is False:
raise ProxyException(
message="User account is deactivated.",
type=ProxyErrorTypes.auth_error,
param="user_id",
code=status.HTTP_401_UNAUTHORIZED,
)

# 2. If team can call model
if _model and team_object:
if not can_team_access_model(
Expand Down
86 changes: 86 additions & 0 deletions tests/test_litellm/proxy/auth/test_auth_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
_log_budget_lookup_failure,
_virtual_key_max_budget_alert_check,
_virtual_key_soft_budget_check,
common_checks,
get_user_object,
vector_store_access_check,
)
Expand Down Expand Up @@ -1414,3 +1415,88 @@ async def test_get_fuzzy_user_object_case_insensitive_email():
assert call_args.kwargs["where"]["user_email"]["equals"] == "test@example.com"
assert call_args.kwargs["where"]["user_email"]["mode"] == "insensitive"
assert call_args.kwargs["include"] == {"organization_memberships": True}



@pytest.mark.asyncio
async def test_scim_deactivated_user_blocked():
"""Test that user with scim_active=False is blocked in common_checks"""
# Create a deactivated user
user_object = LiteLLM_UserTable(
user_id="test-user",
user_email="test@example.com",
metadata={"scim_active": False}
)

# Mock required objects
request = MagicMock()
request_body = {}
team_object = None
end_user_object = None
global_proxy_spend = None
general_settings = {}
route = "/models"
llm_router = None
proxy_logging_obj = MagicMock()
valid_token = MagicMock()

# Should raise ProxyException
with pytest.raises(ProxyException) as exc_info:
await common_checks(
request_body=request_body,
team_object=team_object,
user_object=user_object,
end_user_object=end_user_object,
global_proxy_spend=global_proxy_spend,
general_settings=general_settings,
route=route,
llm_router=llm_router,
proxy_logging_obj=proxy_logging_obj,
valid_token=valid_token,
request=request,
)

# Verify the error
assert "deactivated" in str(exc_info.value.message).lower()
assert exc_info.value.code == "401"
assert exc_info.value.type == ProxyErrorTypes.auth_error


@pytest.mark.asyncio
async def test_scim_active_user_allowed():
"""Test that user with scim_active=True is allowed"""
# Create an active user
user_object = LiteLLM_UserTable(
user_id="test-user",
user_email="test@example.com",
metadata={"scim_active": True}
)

# Mock required objects
request = MagicMock()
request_body = {}
team_object = None
end_user_object = None
global_proxy_spend = None
general_settings = {}
route = "/models"
llm_router = None
proxy_logging_obj = MagicMock()
valid_token = MagicMock()

# Should NOT raise exception
result = await common_checks(
request_body=request_body,
team_object=team_object,
user_object=user_object,
end_user_object=end_user_object,
global_proxy_spend=global_proxy_spend,
general_settings=general_settings,
route=route,
llm_router=llm_router,
proxy_logging_obj=proxy_logging_obj,
valid_token=valid_token,
request=request,
)

assert result is True
Loading