
Commit 5268e88

dbschmigelski authored and pgrayy committed
fix(models): patch litellm bug to honor passing in use_litellm_proxy as client_args (strands-agents#808)
* fix(models): patch litellm bug to honor passing in use_litellm_proxy as client_args

---------

Co-authored-by: Patrick Gray <[email protected]>
1 parent cd51a4e commit 5268e88
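
For context, a minimal usage sketch of the behavior this commit fixes, assuming the strands-agents package is installed; the model IDs are illustrative values taken from the new tests below:

    # Sketch only: passing use_litellm_proxy in client_args should now
    # cause the configured model_id to be prefixed with "litellm_proxy/".
    from strands.models.litellm import LiteLLMModel

    model = LiteLLMModel(client_args={"use_litellm_proxy": True}, model_id="openai/gpt-4")
    print(model.get_config()["model_id"])  # litellm_proxy/openai/gpt-4

    # Without the flag, the model_id is left untouched.
    model = LiteLLMModel(client_args={"use_litellm_proxy": False}, model_id="openai/gpt-4")
    print(model.get_config()["model_id"])  # openai/gpt-4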

File tree

  src/strands/models/litellm.py
  tests/strands/models/test_litellm.py

2 files changed: +46 -0 lines changed


src/strands/models/litellm.py

Lines changed: 13 additions & 0 deletions
@@ -52,6 +52,7 @@ def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config:
         self.client_args = client_args or {}
         validate_config_keys(model_config, self.LiteLLMConfig)
         self.config = dict(model_config)
+        self._apply_proxy_prefix()

         logger.debug("config=<%s> | initializing", self.config)

@@ -64,6 +65,7 @@ def update_config(self, **model_config: Unpack[LiteLLMConfig]) -> None:  # type:
         """
         validate_config_keys(model_config, self.LiteLLMConfig)
         self.config.update(model_config)
+        self._apply_proxy_prefix()

     @override
     def get_config(self) -> LiteLLMConfig:
@@ -226,3 +228,14 @@ async def structured_output(

         # If no tool_calls found, raise an error
         raise ValueError("No tool_calls found in response")
+
+    def _apply_proxy_prefix(self) -> None:
+        """Apply litellm_proxy/ prefix to model_id when use_litellm_proxy is True.
+
+        This is a workaround for https://github.com/BerriAI/litellm/issues/13454
+        where use_litellm_proxy parameter is not honored.
+        """
+        if self.client_args.get("use_litellm_proxy") and "model_id" in self.config:
+            model_id = self.get_config()["model_id"]
+            if not model_id.startswith("litellm_proxy/"):
+                self.config["model_id"] = f"litellm_proxy/{model_id}"
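
Note on the design: _apply_proxy_prefix runs from both __init__ and update_config, and the startswith guard makes it idempotent, so repeated config updates never stack the prefix. A small sketch of that property, using the same illustrative model ID as the tests:

    # Sketch only: re-applying the hook must not produce
    # "litellm_proxy/litellm_proxy/..." thanks to the startswith guard.
    model = LiteLLMModel(client_args={"use_litellm_proxy": True}, model_id="openai/gpt-4")
    model.update_config(model_id=model.get_config()["model_id"])
    assert model.get_config()["model_id"] == "litellm_proxy/openai/gpt-4"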

tests/strands/models/test_litellm.py

Lines changed: 33 additions & 0 deletions
@@ -58,6 +58,39 @@ def test_update_config(model, model_id):
     assert tru_model_id == exp_model_id


+@pytest.mark.parametrize(
+    "client_args, model_id, expected_model_id",
+    [
+        ({"use_litellm_proxy": True}, "openai/gpt-4", "litellm_proxy/openai/gpt-4"),
+        ({"use_litellm_proxy": False}, "openai/gpt-4", "openai/gpt-4"),
+        ({"use_litellm_proxy": None}, "openai/gpt-4", "openai/gpt-4"),
+        ({}, "openai/gpt-4", "openai/gpt-4"),
+        (None, "openai/gpt-4", "openai/gpt-4"),
+        ({"use_litellm_proxy": True}, "litellm_proxy/openai/gpt-4", "litellm_proxy/openai/gpt-4"),
+        ({"use_litellm_proxy": False}, "litellm_proxy/openai/gpt-4", "litellm_proxy/openai/gpt-4"),
+    ],
+)
+def test__init__use_litellm_proxy_prefix(client_args, model_id, expected_model_id):
+    """Test litellm_proxy prefix behavior for various configurations."""
+    model = LiteLLMModel(client_args=client_args, model_id=model_id)
+    assert model.get_config()["model_id"] == expected_model_id
+
+
+@pytest.mark.parametrize(
+    "client_args, initial_model_id, new_model_id, expected_model_id",
+    [
+        ({"use_litellm_proxy": True}, "openai/gpt-4", "anthropic/claude-3", "litellm_proxy/anthropic/claude-3"),
+        ({"use_litellm_proxy": False}, "openai/gpt-4", "anthropic/claude-3", "anthropic/claude-3"),
+        (None, "openai/gpt-4", "anthropic/claude-3", "anthropic/claude-3"),
+    ],
+)
+def test_update_config_proxy_prefix(client_args, initial_model_id, new_model_id, expected_model_id):
+    """Test that update_config applies proxy prefix correctly."""
+    model = LiteLLMModel(client_args=client_args, model_id=initial_model_id)
+    model.update_config(model_id=new_model_id)
+    assert model.get_config()["model_id"] == expected_model_id
+
+
 @pytest.mark.parametrize(
     "content, exp_result",
     [
