diff --git a/litellm/secret_managers/main.py b/litellm/secret_managers/main.py index a093fe2d2fd..38405f058c3 100644 --- a/litellm/secret_managers/main.py +++ b/litellm/secret_managers/main.py @@ -17,6 +17,22 @@ oidc_cache = DualCache() +def _get_oidc_http_handler(timeout: Optional[httpx.Timeout] = None) -> HTTPHandler: + """ + Factory function to create HTTPHandler for OIDC requests. + This function can be mocked in tests. + + Args: + timeout: Optional timeout for HTTP requests. Defaults to 600.0 seconds with 5.0 connect timeout. + + Returns: + HTTPHandler instance configured for OIDC requests. + """ + if timeout is None: + timeout = httpx.Timeout(timeout=600.0, connect=5.0) + return HTTPHandler(timeout=timeout) + + ######### Secret Manager ############################ # checks if user has passed in a secret manager client # if passed in then checks the secret there @@ -103,7 +119,7 @@ def get_secret( # noqa: PLR0915 if oidc_token is not None: return oidc_token - oidc_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0)) + oidc_client = _get_oidc_http_handler() # https://cloud.google.com/compute/docs/instances/verifying-instance-identity#request_signature response = oidc_client.get( "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/identity", @@ -141,7 +157,7 @@ def get_secret( # noqa: PLR0915 if oidc_token is not None: return oidc_token - oidc_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0)) + oidc_client = _get_oidc_http_handler() response = oidc_client.get( actions_id_token_request_url, params={"audience": oidc_aud}, diff --git a/tests/test_litellm/containers/test_container_integration.py b/tests/test_litellm/containers/test_container_integration.py index e83ae921c19..d36918c63b9 100644 --- a/tests/test_litellm/containers/test_container_integration.py +++ b/tests/test_litellm/containers/test_container_integration.py @@ -357,15 +357,15 @@ def test_container_workflow_simulation(self): def test_error_handling_integration(self): """Test error handling in the integration flow.""" - with patch('litellm.containers.main.base_llm_http_handler') as mock_handler: - # Simulate an API error - mock_handler.container_create_handler.side_effect = litellm.APIError( - status_code=400, - message="API Error occurred", - llm_provider="openai", - model="" - ) - + # Simulate an API error + api_error = litellm.APIError( + status_code=400, + message="API Error occurred", + llm_provider="openai", + model="" + ) + + with patch.object(litellm.main.base_llm_http_handler, 'container_create_handler', side_effect=api_error): with pytest.raises(litellm.APIError): create_container( name="Error Test Container", @@ -385,12 +385,12 @@ def test_provider_support(self, provider): name="Provider Test Container" ) - with patch('litellm.containers.main.base_llm_http_handler') as mock_handler: - mock_handler.container_create_handler.return_value = mock_response - + with patch.object(litellm.main.base_llm_http_handler, 'container_create_handler', return_value=mock_response) as mock_handler: response = create_container( name="Provider Test Container", custom_llm_provider=provider ) assert response.name == "Provider Test Container" + # Verify the mock was actually called (not making real API calls) + mock_handler.assert_called_once() diff --git a/tests/test_litellm/secret_managers/test_secret_managers_main.py b/tests/test_litellm/secret_managers/test_secret_managers_main.py index 2e5270b5d70..9246b2a9ef8 100644 --- a/tests/test_litellm/secret_managers/test_secret_managers_main.py +++ b/tests/test_litellm/secret_managers/test_secret_managers_main.py @@ -47,11 +47,11 @@ def mock_env(): @patch("litellm.secret_managers.main.oidc_cache") -@patch("litellm.secret_managers.main.HTTPHandler") -def test_oidc_google_success(mock_http_handler, mock_oidc_cache): +@patch("litellm.secret_managers.main._get_oidc_http_handler") +def test_oidc_google_success(mock_get_http_handler, mock_oidc_cache): mock_oidc_cache.get_cache.return_value = None mock_handler = MockHTTPHandler(timeout=600.0) - mock_http_handler.return_value = mock_handler + mock_get_http_handler.return_value = mock_handler secret_name = "oidc/google/[invalid url, do not cite]" result = get_secret(secret_name) @@ -67,12 +67,12 @@ def test_oidc_google_cached(mock_oidc_cache): mock_oidc_cache.get_cache.return_value = "cached_token" secret_name = "oidc/google/[invalid url, do not cite]" - with patch("litellm.secret_managers.main.HTTPHandler") as mock_http: + with patch("litellm.secret_managers.main._get_oidc_http_handler") as mock_get_http: result = get_secret(secret_name) assert result == "cached_token", f"Expected cached token, got {result}" mock_oidc_cache.get_cache.assert_called_with(key=secret_name) - mock_http.assert_not_called() + mock_get_http.assert_not_called() @patch("litellm.secret_managers.main.oidc_cache") @@ -80,7 +80,7 @@ def test_oidc_google_failure(mock_oidc_cache): mock_handler = MockHTTPHandler(timeout=600.0) mock_handler.status_code = 400 - with patch("litellm.secret_managers.main.HTTPHandler", return_value=mock_handler): + with patch("litellm.secret_managers.main._get_oidc_http_handler", return_value=mock_handler): mock_oidc_cache.get_cache.return_value = None secret_name = "oidc/google/https://example.com/api" @@ -106,13 +106,13 @@ def test_oidc_circleci_failure(monkeypatch): @patch("litellm.secret_managers.main.oidc_cache") -@patch("litellm.secret_managers.main.HTTPHandler") -def test_oidc_github_success(mock_http_handler, mock_oidc_cache, mock_env): +@patch("litellm.secret_managers.main._get_oidc_http_handler") +def test_oidc_github_success(mock_get_http_handler, mock_oidc_cache, mock_env): mock_env["ACTIONS_ID_TOKEN_REQUEST_URL"] = "https://github.com/token" mock_env["ACTIONS_ID_TOKEN_REQUEST_TOKEN"] = "github_token" mock_oidc_cache.get_cache.return_value = None mock_handler = MockHTTPHandler(timeout=600.0) - mock_http_handler.return_value = mock_handler + mock_get_http_handler.return_value = mock_handler secret_name = "oidc/github/github-audience" result = get_secret(secret_name) diff --git a/tests/test_litellm/test_main.py b/tests/test_litellm/test_main.py index d1185c72a29..80fd9f61298 100644 --- a/tests/test_litellm/test_main.py +++ b/tests/test_litellm/test_main.py @@ -477,6 +477,12 @@ async def test_openai_env_base( respx_mock: respx.MockRouter, env_base, openai_api_response, monkeypatch ): "This tests OpenAI env variables are honored, including legacy OPENAI_API_BASE" + # Clear cache to ensure no cached clients from previous tests interfere + # This prevents cache pollution where a previous test cached a client with + # aiohttp transport, which would bypass respx mocks + if hasattr(litellm, "in_memory_llm_clients_cache"): + litellm.in_memory_llm_clients_cache.flush_cache() + # Ensure aiohttp transport is disabled to use httpx which respx can mock litellm.disable_aiohttp_transport = True