Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions litellm/secret_managers/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,22 @@
oidc_cache = DualCache()


def _get_oidc_http_handler(timeout: Optional[httpx.Timeout] = None) -> HTTPHandler:
"""
Factory function to create HTTPHandler for OIDC requests.
This function can be mocked in tests.

Args:
timeout: Optional timeout for HTTP requests. Defaults to 600.0 seconds with 5.0 connect timeout.

Returns:
HTTPHandler instance configured for OIDC requests.
"""
if timeout is None:
timeout = httpx.Timeout(timeout=600.0, connect=5.0)
return HTTPHandler(timeout=timeout)


######### Secret Manager ############################
# checks if user has passed in a secret manager client
# if passed in then checks the secret there
Expand Down Expand Up @@ -103,7 +119,7 @@ def get_secret( # noqa: PLR0915
if oidc_token is not None:
return oidc_token

oidc_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
oidc_client = _get_oidc_http_handler()
# https://cloud.google.com/compute/docs/instances/verifying-instance-identity#request_signature
response = oidc_client.get(
"http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/identity",
Expand Down Expand Up @@ -141,7 +157,7 @@ def get_secret( # noqa: PLR0915
if oidc_token is not None:
return oidc_token

oidc_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
oidc_client = _get_oidc_http_handler()
response = oidc_client.get(
actions_id_token_request_url,
params={"audience": oidc_aud},
Expand Down
24 changes: 12 additions & 12 deletions tests/test_litellm/containers/test_container_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,15 +357,15 @@ def test_container_workflow_simulation(self):

def test_error_handling_integration(self):
"""Test error handling in the integration flow."""
with patch('litellm.containers.main.base_llm_http_handler') as mock_handler:
# Simulate an API error
mock_handler.container_create_handler.side_effect = litellm.APIError(
status_code=400,
message="API Error occurred",
llm_provider="openai",
model=""
)
# Simulate an API error
api_error = litellm.APIError(
status_code=400,
message="API Error occurred",
llm_provider="openai",
model=""
)

with patch.object(litellm.main.base_llm_http_handler, 'container_create_handler', side_effect=api_error):
with pytest.raises(litellm.APIError):
create_container(
name="Error Test Container",
Expand All @@ -385,12 +385,12 @@ def test_provider_support(self, provider):
name="Provider Test Container"
)

with patch('litellm.containers.main.base_llm_http_handler') as mock_handler:
mock_handler.container_create_handler.return_value = mock_response

with patch.object(litellm.main.base_llm_http_handler, 'container_create_handler', return_value=mock_response) as mock_handler:
response = create_container(
name="Provider Test Container",
custom_llm_provider=provider
)

assert response.name == "Provider Test Container"
# Verify the mock was actually called (not making real API calls)
mock_handler.assert_called_once()
18 changes: 9 additions & 9 deletions tests/test_litellm/secret_managers/test_secret_managers_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@ def mock_env():


@patch("litellm.secret_managers.main.oidc_cache")
@patch("litellm.secret_managers.main.HTTPHandler")
def test_oidc_google_success(mock_http_handler, mock_oidc_cache):
@patch("litellm.secret_managers.main._get_oidc_http_handler")
def test_oidc_google_success(mock_get_http_handler, mock_oidc_cache):
mock_oidc_cache.get_cache.return_value = None
mock_handler = MockHTTPHandler(timeout=600.0)
mock_http_handler.return_value = mock_handler
mock_get_http_handler.return_value = mock_handler
secret_name = "oidc/google/[invalid url, do not cite]"
result = get_secret(secret_name)

Expand All @@ -67,20 +67,20 @@ def test_oidc_google_cached(mock_oidc_cache):
mock_oidc_cache.get_cache.return_value = "cached_token"

secret_name = "oidc/google/[invalid url, do not cite]"
with patch("litellm.secret_managers.main.HTTPHandler") as mock_http:
with patch("litellm.secret_managers.main._get_oidc_http_handler") as mock_get_http:
result = get_secret(secret_name)

assert result == "cached_token", f"Expected cached token, got {result}"
mock_oidc_cache.get_cache.assert_called_with(key=secret_name)
mock_http.assert_not_called()
mock_get_http.assert_not_called()


@patch("litellm.secret_managers.main.oidc_cache")
def test_oidc_google_failure(mock_oidc_cache):
mock_handler = MockHTTPHandler(timeout=600.0)
mock_handler.status_code = 400

with patch("litellm.secret_managers.main.HTTPHandler", return_value=mock_handler):
with patch("litellm.secret_managers.main._get_oidc_http_handler", return_value=mock_handler):
mock_oidc_cache.get_cache.return_value = None
secret_name = "oidc/google/https://example.com/api"

Expand All @@ -106,13 +106,13 @@ def test_oidc_circleci_failure(monkeypatch):


@patch("litellm.secret_managers.main.oidc_cache")
@patch("litellm.secret_managers.main.HTTPHandler")
def test_oidc_github_success(mock_http_handler, mock_oidc_cache, mock_env):
@patch("litellm.secret_managers.main._get_oidc_http_handler")
def test_oidc_github_success(mock_get_http_handler, mock_oidc_cache, mock_env):
mock_env["ACTIONS_ID_TOKEN_REQUEST_URL"] = "https://github.com/token"
mock_env["ACTIONS_ID_TOKEN_REQUEST_TOKEN"] = "github_token"
mock_oidc_cache.get_cache.return_value = None
mock_handler = MockHTTPHandler(timeout=600.0)
mock_http_handler.return_value = mock_handler
mock_get_http_handler.return_value = mock_handler

secret_name = "oidc/github/github-audience"
result = get_secret(secret_name)
Expand Down
6 changes: 6 additions & 0 deletions tests/test_litellm/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,12 @@ async def test_openai_env_base(
respx_mock: respx.MockRouter, env_base, openai_api_response, monkeypatch
):
"This tests OpenAI env variables are honored, including legacy OPENAI_API_BASE"
# Clear cache to ensure no cached clients from previous tests interfere
# This prevents cache pollution where a previous test cached a client with
# aiohttp transport, which would bypass respx mocks
if hasattr(litellm, "in_memory_llm_clients_cache"):
litellm.in_memory_llm_clients_cache.flush_cache()

# Ensure aiohttp transport is disabled to use httpx which respx can mock
litellm.disable_aiohttp_transport = True

Expand Down
Loading