diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py
index d894e58e2..1536fc4d6 100644
--- a/src/strands/models/litellm.py
+++ b/src/strands/models/litellm.py
@@ -13,7 +13,7 @@
 from typing_extensions import Unpack, override
 
 from ..types.content import ContentBlock, Messages
-from .openai import OpenAIModel
+from ..types.models.openai import OpenAIModel
 
 logger = logging.getLogger(__name__)
 
@@ -103,6 +103,63 @@ def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]
 
         return super().format_request_message_content(content)
 
+    @override
+    async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
+        """Send the request to the LiteLLM model and get the streaming response.
+
+        Args:
+            request: The formatted request to send to the LiteLLM model.
+
+        Returns:
+            An iterable of response events from the LiteLLM model.
+        """
+        response = self.client.chat.completions.create(**request)
+
+        yield {"chunk_type": "message_start"}
+        yield {"chunk_type": "content_start", "data_type": "text"}
+
+        tool_calls: dict[int, list[Any]] = {}
+
+        for event in response:
+            # Defensive: skip events with empty or missing choices
+            if not getattr(event, "choices", None):
+                continue
+            choice = event.choices[0]
+
+            if choice.delta.content:
+                yield {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content}
+
+            if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content:
+                yield {
+                    "chunk_type": "content_delta",
+                    "data_type": "reasoning_content",
+                    "data": choice.delta.reasoning_content,
+                }
+
+            for tool_call in choice.delta.tool_calls or []:
+                tool_calls.setdefault(tool_call.index, []).append(tool_call)
+
+            if choice.finish_reason:
+                break
+
+        yield {"chunk_type": "content_stop", "data_type": "text"}
+
+        for tool_deltas in tool_calls.values():
+            yield {"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]}
+
+            for tool_delta in tool_deltas:
+                yield {"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta}
+
+            yield {"chunk_type": "content_stop", "data_type": "tool"}
+
+        yield {"chunk_type": "message_stop", "data": choice.finish_reason}
+
+        # Skip remaining events as we don't have use for anything except the final usage payload
+        for event in response:
+            _ = event
+
+        yield {"chunk_type": "metadata", "data": event.usage}
+
     @override
     async def structured_output(
         self, output_model: Type[T], prompt: Messages
diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py
index 5446cbd3d..bde0bb45f 100644
--- a/src/strands/models/openai.py
+++ b/src/strands/models/openai.py
@@ -61,7 +61,7 @@ def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config:
         logger.debug("config=<%s> | initializing", self.config)
 
         client_args = client_args or {}
-        self.client = openai.OpenAI(**client_args)
+        self.client = openai.AsyncOpenAI(**client_args)
 
     @override
     def update_config(self, **model_config: Unpack[OpenAIConfig]) -> None:  # type: ignore[override]
@@ -91,14 +91,14 @@ async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any]
         Returns:
             An iterable of response events from the OpenAI model.
         """
-        response = self.client.chat.completions.create(**request)
+        response = await self.client.chat.completions.create(**request)
 
         yield {"chunk_type": "message_start"}
         yield {"chunk_type": "content_start", "data_type": "text"}
 
         tool_calls: dict[int, list[Any]] = {}
 
-        for event in response:
+        async for event in response:
             # Defensive: skip events with empty or missing choices
             if not getattr(event, "choices", None):
                 continue
@@ -133,7 +133,7 @@ async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any]
         yield {"chunk_type": "message_stop", "data": choice.finish_reason}
 
         # Skip remaining events as we don't have use for anything except the final usage payload
-        for event in response:
+        async for event in response:
             _ = event
 
         yield {"chunk_type": "metadata", "data": event.usage}
@@ -151,7 +151,7 @@ async def structured_output( self, output_model: Type[T], prompt: Messages
         Yields:
             Model events with the last being the structured output.
         """
-        response: ParsedChatCompletion = self.client.beta.chat.completions.parse(  # type: ignore
+        response: ParsedChatCompletion = await self.client.beta.chat.completions.parse(  # type: ignore
             model=self.get_config()["model_id"],
             messages=super().format_request(prompt)["messages"],
             response_format=output_model,
diff --git a/tests-integ/test_model_openai.py b/tests-integ/test_model_openai.py
index bca874af8..e0dfcb34b 100644
--- a/tests-integ/test_model_openai.py
+++ b/tests-integ/test_model_openai.py
@@ -12,7 +12,7 @@
 from strands.models.openai import OpenAIModel
 
 
-@pytest.fixture
+@pytest.fixture(scope="module")
 def model():
     return OpenAIModel(
         model_id="gpt-4o",
@@ -22,7 +22,7 @@
     )
 
 
-@pytest.fixture
+@pytest.fixture(scope="module")
 def tools():
     @strands.tool
     def tool_time() -> str:
@@ -35,36 +35,65 @@ def tool_weather() -> str:
     return [tool_time, tool_weather]
 
 
-@pytest.fixture
+@pytest.fixture(scope="module")
 def agent(model, tools):
     return Agent(model=model, tools=tools)
 
 
-@pytest.fixture
+@pytest.fixture(scope="module")
+def weather():
+    class Weather(BaseModel):
+        """Extracts the time and weather from the user's message with the exact strings."""
+
+        time: str
+        weather: str
+
+    return Weather(time="12:00", weather="sunny")
+
+
+@pytest.fixture(scope="module")
 def test_image_path(request):
     return request.config.rootpath / "tests-integ" / "test_image.png"
 
 
-def test_agent(agent):
+def test_agent_invoke(agent):
     result = agent("What is the time and weather in New York?")
     text = result.message["content"][0]["text"].lower()
 
     assert all(string in text for string in ["12:00", "sunny"])
 
 
-def test_structured_output(model):
-    class Weather(BaseModel):
-        """Extracts the time and weather from the user's message with the exact strings."""
+@pytest.mark.asyncio
+async def test_agent_invoke_async(agent):
+    result = await agent.invoke_async("What is the time and weather in New York?")
+    text = result.message["content"][0]["text"].lower()
 
-        time: str
-        weather: str
+    assert all(string in text for string in ["12:00", "sunny"])
+
+
+@pytest.mark.asyncio
+async def test_agent_stream_async(agent):
+    stream = agent.stream_async("What is the time and weather in New York?")
+    async for event in stream:
+        _ = event
+
+    result = event["result"]
+    text = result.message["content"][0]["text"].lower()
+
+    assert all(string in text for string in ["12:00", "sunny"])
+
+
+def test_agent_structured_output(agent, weather):
+    tru_weather = agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny")
+    exp_weather = weather
+    assert tru_weather == exp_weather
 
-    agent = Agent(model=model)
-
-    result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny")
-    assert isinstance(result, Weather)
-    assert result.time == "12:00"
-    assert result.weather == "sunny"
+
+
+@pytest.mark.asyncio
+async def test_agent_structured_output_async(agent, weather):
+    tru_weather = await agent.structured_output_async(type(weather), "The time is 12:00 and the weather is sunny")
+    exp_weather = weather
+    assert tru_weather == exp_weather
 
 
 def test_tool_returning_images(model, test_image_path):
diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py
index 989b7eae6..8f4a9e341 100644
--- a/tests/strands/models/test_litellm.py
+++ b/tests/strands/models/test_litellm.py
@@ -115,6 +115,69 @@ def test_format_request_message_content(content, exp_result):
     assert tru_result == exp_result
 
 
+@pytest.mark.asyncio
+async def test_stream(litellm_client, model, alist):
+    mock_tool_call_1_part_1 = unittest.mock.Mock(index=0)
+    mock_tool_call_2_part_1 = unittest.mock.Mock(index=1)
+    mock_delta_1 = unittest.mock.Mock(
+        reasoning_content="",
+        content=None,
+        tool_calls=None,
+    )
+    mock_delta_2 = unittest.mock.Mock(
+        reasoning_content="\nI'm thinking",
+        content=None,
+        tool_calls=None,
+    )
+    mock_delta_3 = unittest.mock.Mock(
+        content="I'll calculate", tool_calls=[mock_tool_call_1_part_1, mock_tool_call_2_part_1], reasoning_content=None
+    )
+
+    mock_tool_call_1_part_2 = unittest.mock.Mock(index=0)
+    mock_tool_call_2_part_2 = unittest.mock.Mock(index=1)
+    mock_delta_4 = unittest.mock.Mock(
+        content="that for you", tool_calls=[mock_tool_call_1_part_2, mock_tool_call_2_part_2], reasoning_content=None
+    )
+
+    mock_delta_5 = unittest.mock.Mock(content="", tool_calls=None, reasoning_content=None)
+
+    mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_1)])
+    mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_2)])
+    mock_event_3 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_3)])
+    mock_event_4 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_4)])
+    mock_event_5 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="tool_calls", delta=mock_delta_5)])
+    mock_event_6 = unittest.mock.Mock()
+
+    litellm_client.chat.completions.create.return_value = iter(
+        [mock_event_1, mock_event_2, mock_event_3, mock_event_4, mock_event_5, mock_event_6]
+    )
+
+    request = {"model": "m1", "messages": [{"role": "user", "content": [{"type": "text", "text": "calculate 2+2"}]}]}
+    response = model.stream(request)
+    tru_events = await alist(response)
+    exp_events = [
+        {"chunk_type": "message_start"},
+        {"chunk_type": "content_start", "data_type": "text"},
+        {"chunk_type": "content_delta", "data_type": "reasoning_content", "data": "\nI'm thinking"},
+        {"chunk_type": "content_delta", "data_type": "text", "data": "I'll calculate"},
+        {"chunk_type": "content_delta", "data_type": "text", "data": "that for you"},
+        {"chunk_type": "content_stop", "data_type": "text"},
+        {"chunk_type": "content_start", "data_type": "tool", "data": mock_tool_call_1_part_1},
+        {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_1_part_1},
+        {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_1_part_2},
+        {"chunk_type": "content_stop", "data_type": "tool"},
+        {"chunk_type": "content_start", "data_type": "tool", "data": mock_tool_call_2_part_1},
+        {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_2_part_1},
+        {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_2_part_2},
+        {"chunk_type": "content_stop", "data_type": "tool"},
+        {"chunk_type": "message_stop", "data": "tool_calls"},
+        {"chunk_type": "metadata", "data": mock_event_6.usage},
+    ]
+
+    assert tru_events == exp_events
+    litellm_client.chat.completions.create.assert_called_once_with(**request)
+
+
 @pytest.mark.asyncio
 async def test_structured_output(litellm_client, model, test_output_model_cls, alist):
     messages = [{"role": "user", "content": [{"text": "Generate a person"}]}]
diff --git a/tests/strands/models/test_openai.py b/tests/strands/models/test_openai.py
index 7bc16e5c2..ec659effe 100644
--- a/tests/strands/models/test_openai.py
+++ b/tests/strands/models/test_openai.py
@@ -9,7 +9,7 @@
 
 @pytest.fixture
 def openai_client_cls():
-    with unittest.mock.patch.object(strands.models.openai.openai, "OpenAI") as mock_client_cls:
+    with unittest.mock.patch.object(strands.models.openai.openai, "AsyncOpenAI") as mock_client_cls:
         yield mock_client_cls
 
 
@@ -70,7 +70,7 @@ def test_update_config(model, model_id):
 
 
 @pytest.mark.asyncio
-async def test_stream(openai_client, model, alist):
+async def test_stream(openai_client, model, agenerator, alist):
     mock_tool_call_1_part_1 = unittest.mock.Mock(index=0)
     mock_tool_call_2_part_1 = unittest.mock.Mock(index=1)
     mock_delta_1 = unittest.mock.Mock(
@@ -102,8 +102,8 @@ async def test_stream(openai_client, model, alist):
     mock_event_5 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="tool_calls", delta=mock_delta_5)])
     mock_event_6 = unittest.mock.Mock()
 
-    openai_client.chat.completions.create.return_value = iter(
-        [mock_event_1, mock_event_2, mock_event_3, mock_event_4, mock_event_5, mock_event_6]
+    openai_client.chat.completions.create = unittest.mock.AsyncMock(
+        return_value=agenerator([mock_event_1, mock_event_2, mock_event_3, mock_event_4, mock_event_5, mock_event_6])
     )
 
     request = {"model": "m1", "messages": [{"role": "user", "content": [{"type": "text", "text": "calculate 2+2"}]}]}
@@ -133,7 +133,7 @@
 
 
 @pytest.mark.asyncio
-async def test_stream_empty(openai_client, model, alist):
+async def test_stream_empty(openai_client, model, agenerator, alist):
     mock_delta = unittest.mock.Mock(content=None, tool_calls=None, reasoning_content=None)
     mock_usage = unittest.mock.Mock(prompt_tokens=0, completion_tokens=0, total_tokens=0)
 
@@ -142,7 +142,9 @@ async def test_stream_empty(openai_client, model, alist):
     mock_event_3 = unittest.mock.Mock()
     mock_event_4 = unittest.mock.Mock(usage=mock_usage)
 
-    openai_client.chat.completions.create.return_value = iter([mock_event_1, mock_event_2, mock_event_3, mock_event_4])
+    openai_client.chat.completions.create = unittest.mock.AsyncMock(
+        return_value=agenerator([mock_event_1, mock_event_2, mock_event_3, mock_event_4]),
+    )
 
     request = {"model": "m1", "messages": [{"role": "user", "content": []}]}
     response = model.stream(request)
@@ -161,7 +163,7 @@
 
 
 @pytest.mark.asyncio
-async def test_stream_with_empty_choices(openai_client, model, alist):
+async def test_stream_with_empty_choices(openai_client, model, agenerator, alist):
     mock_delta = unittest.mock.Mock(content="content", tool_calls=None, reasoning_content=None)
     mock_usage = unittest.mock.Mock(prompt_tokens=10, completion_tokens=20, total_tokens=30)
 
@@ -180,8 +182,8 @@ async def test_stream_with_empty_choices(openai_client, model, alist):
     # Final event with usage info
     mock_event_5 = unittest.mock.Mock(usage=mock_usage)
 
-    openai_client.chat.completions.create.return_value = iter(
-        [mock_event_1, mock_event_2, mock_event_3, mock_event_4, mock_event_5]
+    openai_client.chat.completions.create = unittest.mock.AsyncMock(
+        return_value=agenerator([mock_event_1, mock_event_2, mock_event_3, mock_event_4, mock_event_5])
     )
 
     request = {"model": "m1", "messages": [{"role": "user", "content": ["test"]}]}
@@ -212,7 +214,7 @@ async def test_structured_output(openai_client, model, test_output_model_cls, al
     mock_response = unittest.mock.Mock()
     mock_response.choices = [mock_choice]
 
-    openai_client.beta.chat.completions.parse.return_value = mock_response
+    openai_client.beta.chat.completions.parse = unittest.mock.AsyncMock(return_value=mock_response)
 
     stream = model.structured_output(test_output_model_cls, messages)
     events = await alist(stream)
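Note on usage: with openai.AsyncOpenAI plus await/async for in place, OpenAIModel.stream is now an async generator and must be driven from an event loop. Below is a minimal sketch of calling it directly; the model id, API key handling, and the hand-built request dict are illustrative only (real callers would normally go through the agent layer and format_request, which is not shown in this patch):

import asyncio
import os

from strands.models.openai import OpenAIModel


async def main() -> None:
    # Placeholder configuration; swap in whatever model the account can access.
    model = OpenAIModel(client_args={"api_key": os.environ["OPENAI_API_KEY"]}, model_id="gpt-4o")

    # Hand-built request mirroring the unit tests above; format_request(...) is the normal path.
    request = {
        "model": "gpt-4o",
        "stream": True,
        "messages": [{"role": "user", "content": [{"type": "text", "text": "calculate 2+2"}]}],
    }

    # stream() now yields chunk dicts asynchronously, ending with message_stop and metadata.
    async for event in model.stream(request):
        print(event["chunk_type"], event.get("data"))


asyncio.run(main())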
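Note on the test changes: the reworked unit tests take agenerator and alist fixtures that are not defined in this diff; they presumably live in the shared pytest conftest. A sketch of what such helpers could look like, assuming agenerator only needs to wrap a list of mock events as an async generator (standing in for the AsyncOpenAI streaming response) and alist only needs to drain an async generator into a list for assertions:

import pytest


@pytest.fixture
def agenerator():
    # Turn a plain iterable into an async generator, e.g. as the return_value of an AsyncMock.
    def factory(items):
        async def gen():
            for item in items:
                yield item

        return gen()

    return factory


@pytest.fixture
def alist():
    # Collect an async generator such as model.stream(request) into a list.
    async def collect(aiterable):
        return [item async for item in aiterable]

    return collect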