Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions litellm/responses/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,8 @@ async def aresponses(
_, custom_llm_provider, _, _ = litellm.get_llm_provider(
model=model, api_base=local_vars.get("base_url", None)
)
# Update local_vars with detected provider (fixes #19782)
local_vars["custom_llm_provider"] = custom_llm_provider

func = partial(
responses,
Expand Down Expand Up @@ -583,6 +585,9 @@ def responses(
api_key=litellm_params.api_key,
)

# Update local_vars with detected provider (fixes #19782)
local_vars["custom_llm_provider"] = custom_llm_provider

# Use dynamic credentials from get_llm_provider (e.g., when use_litellm_proxy=True)
if dynamic_api_key is not None:
litellm_params.api_key = dynamic_api_key
Expand Down Expand Up @@ -1411,6 +1416,8 @@ async def acompact_responses(
_, custom_llm_provider, _, _ = litellm.get_llm_provider(
model=model, api_base=local_vars.get("base_url", None)
)
# Update local_vars with detected provider (fixes #19782)
local_vars["custom_llm_provider"] = custom_llm_provider

func = partial(
compact_responses,
Expand Down Expand Up @@ -1498,6 +1505,9 @@ def compact_responses(
api_key=litellm_params.api_key,
)

# Update local_vars with detected provider (fixes #19782)
local_vars["custom_llm_provider"] = custom_llm_provider

# Use dynamic credentials from get_llm_provider (e.g., when use_litellm_proxy=True)
if dynamic_api_key is not None:
litellm_params.api_key = dynamic_api_key
Expand Down
43 changes: 43 additions & 0 deletions tests/test_litellm/responses/test_responses_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,3 +309,46 @@ def test_transform_response_api_usage_mixed_details(self):
assert result.completion_tokens_details.reasoning_tokens == 30
assert result.completion_tokens_details.image_tokens == 100
assert result.completion_tokens_details.text_tokens == 70


class TestResponsesAPIProviderSpecificParams:
    """
    Regression tests for fix #19782: provider-specific params (aws_*, vertex_*)
    must be processed without explicitly passing custom_llm_provider.
    """

    def _process_params(self, provider, extra_kwargs):
        """Build a request dict for *provider* with *extra_kwargs* and run it
        through get_requested_response_api_optional_param, returning the result."""
        request = {
            "temperature": 0.7,
            "custom_llm_provider": provider,
            "kwargs": extra_kwargs,
        }
        # Must not raise for any provider / provider-specific kwarg combination.
        return ResponsesAPIRequestUtils.get_requested_response_api_optional_param(request)

    def test_provider_specific_params_no_crash_with_bedrock(self):
        """Test that processing aws_* params with bedrock provider doesn't crash."""
        processed = self._process_params("bedrock", {"aws_region_name": "eu-central-1"})
        assert "temperature" in processed

    def test_provider_specific_params_no_crash_with_openai(self):
        """Test that processing aws_* params with openai provider doesn't crash."""
        processed = self._process_params("openai", {"aws_region_name": "eu-central-1"})
        assert "temperature" in processed

    def test_provider_specific_params_no_crash_with_vertex_ai(self):
        """Test that processing vertex_* params with vertex_ai provider doesn't crash."""
        processed = self._process_params("vertex_ai", {"vertex_project": "my-project"})
        assert "temperature" in processed
Loading