diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py
index 523b0da8b..07d7fb555 100644
--- a/src/strands/models/litellm.py
+++ b/src/strands/models/litellm.py
@@ -48,13 +48,11 @@ def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config:
                 https://github.com/BerriAI/litellm/blob/main/litellm/main.py.
             **model_config: Configuration options for the LiteLLM model.
         """
+        self.client_args = client_args or {}
         self.config = dict(model_config)

         logger.debug("config=<%s> | initializing", self.config)

-        client_args = client_args or {}
-        self.client = litellm.LiteLLM(**client_args)
-
     @override
     def update_config(self, **model_config: Unpack[LiteLLMConfig]) -> None:  # type: ignore[override]
         """Update the LiteLLM model configuration with the provided arguments.
@@ -124,7 +122,7 @@ async def stream(
         logger.debug("formatted request=<%s>", request)

         logger.debug("invoking model")
-        response = self.client.chat.completions.create(**request)
+        response = await litellm.acompletion(**self.client_args, **request)

         logger.debug("got response from model")
         yield self.format_chunk({"chunk_type": "message_start"})
@@ -132,7 +130,7 @@

         tool_calls: dict[int, list[Any]] = {}

-        for event in response:
+        async for event in response:
             # Defensive: skip events with empty or missing choices
             if not getattr(event, "choices", None):
                 continue
@@ -171,7 +169,7 @@
             yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason})

         # Skip remaining events as we don't have use for anything except the final usage payload
-        for event in response:
+        async for event in response:
             _ = event

         yield self.format_chunk({"chunk_type": "metadata", "data": event.usage})
@@ -191,10 +189,8 @@ async def structured_output(
         Yields:
             Model events with the last being the structured output.
         """
-        # The LiteLLM `Client` inits with Chat().
-        # Chat() inits with self.completions
-        # completions() has a method `create()` which wraps the real completion API of Litellm
-        response = self.client.chat.completions.create(
+        response = await litellm.acompletion(
+            **self.client_args,
             model=self.get_config()["model_id"],
             messages=self.format_request(prompt)["messages"],
             response_format=output_model,
diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py
index 2bafc331a..bddd44abb 100644
--- a/tests/strands/models/test_litellm.py
+++ b/tests/strands/models/test_litellm.py
@@ -8,14 +8,14 @@

 @pytest.fixture
-def litellm_client_cls():
-    with unittest.mock.patch.object(strands.models.litellm.litellm, "LiteLLM") as mock_client_cls:
-        yield mock_client_cls
+def litellm_acompletion():
+    with unittest.mock.patch.object(strands.models.litellm.litellm, "acompletion") as mock_acompletion:
+        yield mock_acompletion


 @pytest.fixture
-def litellm_client(litellm_client_cls):
-    return litellm_client_cls.return_value
+def api_key():
+    return "a1"


 @pytest.fixture
@@ -24,10 +24,10 @@ def model_id():

 @pytest.fixture
-def model(litellm_client, model_id):
-    _ = litellm_client
+def model(litellm_acompletion, api_key, model_id):
+    _ = litellm_acompletion

-    return LiteLLMModel(model_id=model_id)
+    return LiteLLMModel(client_args={"api_key": api_key}, model_id=model_id)


 @pytest.fixture
@@ -49,17 +49,6 @@ class TestOutputModel(pydantic.BaseModel):
     return TestOutputModel


-def test__init__(litellm_client_cls, model_id):
-    model = LiteLLMModel({"api_key": "k1"}, model_id=model_id, params={"max_tokens": 1})
-
-    tru_config = model.get_config()
-    exp_config = {"model_id": "m1", "params": {"max_tokens": 1}}
-
-    assert tru_config == exp_config
-
-    litellm_client_cls.assert_called_once_with(api_key="k1")
-
-
 def test_update_config(model, model_id):
     model.update_config(model_id=model_id)

@@ -116,7 +105,7 @@ def test_format_request_message_content(content, exp_result):

 @pytest.mark.asyncio
-async def test_stream(litellm_client, model, alist):
+async def test_stream(litellm_acompletion, api_key, model_id, model, agenerator, alist):
     mock_tool_call_1_part_1 = unittest.mock.Mock(index=0)
     mock_tool_call_2_part_1 = unittest.mock.Mock(index=1)
     mock_delta_1 = unittest.mock.Mock(
@@ -148,8 +137,8 @@ async def test_stream(litellm_client, model, alist):
     mock_event_5 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="tool_calls", delta=mock_delta_5)])
     mock_event_6 = unittest.mock.Mock()

-    litellm_client.chat.completions.create.return_value = iter(
-        [mock_event_1, mock_event_2, mock_event_3, mock_event_4, mock_event_5, mock_event_6]
+    litellm_acompletion.side_effect = unittest.mock.AsyncMock(
+        return_value=agenerator([mock_event_1, mock_event_2, mock_event_3, mock_event_4, mock_event_5, mock_event_6])
     )

     messages = [{"role": "user", "content": [{"type": "text", "text": "calculate 2+2"}]}]
@@ -196,18 +185,20 @@ async def test_stream(litellm_client, model, alist):
     ]
     assert tru_events == exp_events
+
     expected_request = {
-        "model": "m1",
+        "api_key": api_key,
+        "model": model_id,
         "messages": [{"role": "user", "content": [{"text": "calculate 2+2", "type": "text"}]}],
         "stream": True,
         "stream_options": {"include_usage": True},
         "tools": [],
     }
-    litellm_client.chat.completions.create.assert_called_once_with(**expected_request)
+    litellm_acompletion.assert_called_once_with(**expected_request)


 @pytest.mark.asyncio
-async def test_structured_output(litellm_client, model, test_output_model_cls, alist):
+async def test_structured_output(litellm_acompletion, model, test_output_model_cls, alist):
     messages = [{"role": "user", "content": [{"text": "Generate a person"}]}]

     mock_choice = unittest.mock.Mock()
@@ -216,7 +207,7 @@ async def test_structured_output(litellm_client, model, test_output_model_cls, alist):
     mock_response = unittest.mock.Mock()
     mock_response.choices = [mock_choice]

-    litellm_client.chat.completions.create.return_value = mock_response
+    litellm_acompletion.side_effect = unittest.mock.AsyncMock(return_value=mock_response)

     with unittest.mock.patch.object(strands.models.litellm, "supports_response_schema", return_value=True):
         stream = model.structured_output(test_output_model_cls, messages)
diff --git a/tests_integ/models/test_model_litellm.py b/tests_integ/models/test_model_litellm.py
index 382f75194..efdd6a5ed 100644
--- a/tests_integ/models/test_model_litellm.py
+++ b/tests_integ/models/test_model_litellm.py
@@ -29,6 +29,17 @@ def agent(model, tools):
     return Agent(model=model, tools=tools)


+@pytest.fixture
+def weather():
+    class Weather(pydantic.BaseModel):
+        """Extracts the time and weather from the user's message with the exact strings."""
+
+        time: str
+        weather: str
+
+    return Weather(time="12:00", weather="sunny")
+
+
 @pytest.fixture
 def yellow_color():
     class Color(pydantic.BaseModel):
@@ -44,24 +55,44 @@ def lower(_, value):
     return Color(name="yellow")


-def test_agent(agent):
+def test_agent_invoke(agent):
     result = agent("What is the time and weather in New York?")
     text = result.message["content"][0]["text"].lower()

     assert all(string in text for string in ["12:00", "sunny"])


-def test_structured_output(model):
-    class Weather(pydantic.BaseModel):
-        time: str
-        weather: str
+@pytest.mark.asyncio
+async def test_agent_invoke_async(agent):
+    result = await agent.invoke_async("What is the time and weather in New York?")
+    text = result.message["content"][0]["text"].lower()
+
+    assert all(string in text for string in ["12:00", "sunny"])
+
+
+@pytest.mark.asyncio
+async def test_agent_stream_async(agent):
+    stream = agent.stream_async("What is the time and weather in New York?")
+    async for event in stream:
+        _ = event
+
+    result = event["result"]
+    text = result.message["content"][0]["text"].lower()
+
+    assert all(string in text for string in ["12:00", "sunny"])
+
+
+def test_structured_output(agent, weather):
+    tru_weather = agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny")
+    exp_weather = weather
+    assert tru_weather == exp_weather

-    agent_no_tools = Agent(model=model)
-    result = agent_no_tools.structured_output(Weather, "The time is 12:00 and the weather is sunny")

-    assert isinstance(result, Weather)
-    assert result.time == "12:00"
-    assert result.weather == "sunny"
+@pytest.mark.asyncio
+async def test_agent_structured_output_async(agent, weather):
+    tru_weather = await agent.structured_output_async(type(weather), "The time is 12:00 and the weather is sunny")
+    exp_weather = weather
+    assert tru_weather == exp_weather


 def test_invoke_multi_modal_input(agent, yellow_img):
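Note for reviewers: the switch from `self.client.chat.completions.create(**request)` to `litellm.acompletion` changes two things at once; `acompletion` must be awaited to obtain the stream, and the stream itself is then consumed with `async for`, which is why both loops in `stream()` change. A minimal standalone sketch of that call pattern, outside the SDK (the model id and prompt are placeholders, not values from this patch):

```python
import asyncio

import litellm


async def main() -> None:
    # Awaiting acompletion with stream=True yields an async iterator of chunks.
    response = await litellm.acompletion(
        model="gpt-4o-mini",  # placeholder model id
        messages=[{"role": "user", "content": "calculate 2+2"}],
        stream=True,
        stream_options={"include_usage": True},  # final chunk then carries usage
    )

    async for event in response:
        # Skip chunks with empty or missing choices, mirroring the
        # defensive check in LiteLLMModel.stream.
        if getattr(event, "choices", None) and event.choices[0].delta.content:
            print(event.choices[0].delta.content, end="")


asyncio.run(main())
```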
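The updated unit tests also lean on `agenerator` and `alist` fixtures that are not part of this diff (presumably shared conftest helpers). Assuming they behave as their names suggest, this is the shape `test_stream` appears to expect from them:

```python
import pytest


@pytest.fixture
def agenerator():
    # Turns a plain iterable into an async generator, standing in for the
    # async stream returned by litellm.acompletion.
    def factory(items):
        async def gen():
            for item in items:
                yield item

        return gen()

    return factory


@pytest.fixture
def alist():
    # Drains an async iterable into a list so tests can assert on it directly.
    async def collect(aiterable):
        return [item async for item in aiterable]

    return collect
```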