Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 58 additions & 1 deletion src/strands/models/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from typing_extensions import Unpack, override

from ..types.content import ContentBlock, Messages
from .openai import OpenAIModel
from ..types.models.openai import OpenAIModel

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -103,6 +103,63 @@ def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]

return super().format_request_message_content(content)

@override
async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
"""Send the request to the LiteLLM model and get the streaming response.

Args:
request: The formatted request to send to the LiteLLM model.

Returns:
An iterable of response events from the LiteLLM model.
"""
response = self.client.chat.completions.create(**request)

yield {"chunk_type": "message_start"}
yield {"chunk_type": "content_start", "data_type": "text"}

tool_calls: dict[int, list[Any]] = {}

for event in response:
# Defensive: skip events with empty or missing choices
if not getattr(event, "choices", None):
continue
choice = event.choices[0]

if choice.delta.content:
yield {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content}

if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content:
yield {
"chunk_type": "content_delta",
"data_type": "reasoning_content",
"data": choice.delta.reasoning_content,
}

for tool_call in choice.delta.tool_calls or []:
tool_calls.setdefault(tool_call.index, []).append(tool_call)

if choice.finish_reason:
break

yield {"chunk_type": "content_stop", "data_type": "text"}

for tool_deltas in tool_calls.values():
yield {"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]}

for tool_delta in tool_deltas:
yield {"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta}

yield {"chunk_type": "content_stop", "data_type": "tool"}

yield {"chunk_type": "message_stop", "data": choice.finish_reason}

# Skip remaining events as we don't have use for anything except the final usage payload
for event in response:
_ = event

yield {"chunk_type": "metadata", "data": event.usage}

@override
async def structured_output(
self, output_model: Type[T], prompt: Messages
Expand Down
10 changes: 5 additions & 5 deletions src/strands/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config:
logger.debug("config=<%s> | initializing", self.config)

client_args = client_args or {}
self.client = openai.OpenAI(**client_args)
self.client = openai.AsyncOpenAI(**client_args)

@override
def update_config(self, **model_config: Unpack[OpenAIConfig]) -> None: # type: ignore[override]
Expand Down Expand Up @@ -91,14 +91,14 @@ async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any]
Returns:
An iterable of response events from the OpenAI model.
"""
response = self.client.chat.completions.create(**request)
response = await self.client.chat.completions.create(**request)

yield {"chunk_type": "message_start"}
yield {"chunk_type": "content_start", "data_type": "text"}

tool_calls: dict[int, list[Any]] = {}

for event in response:
async for event in response:
# Defensive: skip events with empty or missing choices
if not getattr(event, "choices", None):
continue
Expand Down Expand Up @@ -133,7 +133,7 @@ async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any]
yield {"chunk_type": "message_stop", "data": choice.finish_reason}

# Skip remaining events as we don't have use for anything except the final usage payload
for event in response:
async for event in response:
_ = event

yield {"chunk_type": "metadata", "data": event.usage}
Expand All @@ -151,7 +151,7 @@ async def structured_output(
Yields:
Model events with the last being the structured output.
"""
response: ParsedChatCompletion = self.client.beta.chat.completions.parse( # type: ignore
response: ParsedChatCompletion = await self.client.beta.chat.completions.parse( # type: ignore
model=self.get_config()["model_id"],
messages=super().format_request(prompt)["messages"],
response_format=output_model,
Expand Down
59 changes: 44 additions & 15 deletions tests-integ/test_model_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from strands.models.openai import OpenAIModel


@pytest.fixture
@pytest.fixture(scope="module")
def model():
return OpenAIModel(
model_id="gpt-4o",
Expand All @@ -22,7 +22,7 @@ def model():
)


@pytest.fixture
@pytest.fixture(scope="module")
def tools():
@strands.tool
def tool_time() -> str:
Expand All @@ -35,36 +35,65 @@ def tool_weather() -> str:
return [tool_time, tool_weather]


# Diff residue fix: both the old `@pytest.fixture` and new `@pytest.fixture(scope="module")`
# decorator lines were present, stacking two fixture decorators; keep only the new one.
@pytest.fixture(scope="module")
def agent(model, tools):
    """One shared Agent per test module, built from the model and tool fixtures."""
    return Agent(model=model, tools=tools)


# Diff residue fix: duplicated old/new `@pytest.fixture` decorator lines were both
# present; keep only the module-scoped one.
@pytest.fixture(scope="module")
def weather():
    class Weather(BaseModel):
        """Extracts the time and weather from the user's message with the exact strings."""

        time: str
        weather: str

    return Weather(time="12:00", weather="sunny")


@pytest.fixture(scope="module")
def test_image_path(request):
    """Path to the checked-in test image used by image-handling tests."""
    return request.config.rootpath / "tests-integ" / "test_image.png"


# Diff residue fix: both the old `def test_agent(agent):` and new
# `def test_agent_invoke(agent):` lines were present (a syntax error); keep the new name.
def test_agent_invoke(agent):
    """Synchronous invocation should surface both tool results in the reply."""
    result = agent("What is the time and weather in New York?")
    text = result.message["content"][0]["text"].lower()

    assert all(string in text for string in ["12:00", "sunny"])


def test_structured_output(model):
class Weather(BaseModel):
"""Extracts the time and weather from the user's message with the exact strings."""
# Diff residue fix: stale deleted lines (`time: str`, `weather: str`) from the removed
# inline Weather model were interleaved in this body (a syntax error); keep only the test.
@pytest.mark.asyncio
async def test_agent_invoke_async(agent):
    """Async invocation should surface both tool results in the reply."""
    result = await agent.invoke_async("What is the time and weather in New York?")
    text = result.message["content"][0]["text"].lower()

    assert all(string in text for string in ["12:00", "sunny"])


@pytest.mark.asyncio
async def test_agent_stream_async(agent):
    """Drain the async event stream and check the final result payload."""
    async for streamed in agent.stream_async("What is the time and weather in New York?"):
        event = streamed

    # The last streamed event carries the final result.
    text = event["result"].message["content"][0]["text"].lower()

    assert all(expected in text for expected in ["12:00", "sunny"])


def test_agent_structured_output(agent, weather):
    """Structured output should round-trip the fixture's exact field values."""
    actual = agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny")
    assert actual == weather

# Diff residue fix: lines from the deleted old `test_structured_output` body
# (`agent = Agent(model=model)`, inline Weather assertions) were fused before this
# test; keep only the new async test.
@pytest.mark.asyncio
async def test_agent_structured_output_async(agent, weather):
    """Async structured output should round-trip the fixture's exact field values."""
    tru_weather = await agent.structured_output_async(type(weather), "The time is 12:00 and the weather is sunny")
    exp_weather = weather
    assert tru_weather == exp_weather


def test_tool_returning_images(model, test_image_path):
Expand Down
63 changes: 63 additions & 0 deletions tests/strands/models/test_litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,69 @@ def test_format_request_message_content(content, exp_result):
assert tru_result == exp_result


@pytest.mark.asyncio
async def test_stream(litellm_client, model, alist):
    """Verify the litellm event stream is mapped to the expected chunk sequence."""
    # Two tool calls, each split across two streamed parts (matched by index).
    tool_1_part_1 = unittest.mock.Mock(index=0)
    tool_2_part_1 = unittest.mock.Mock(index=1)
    tool_1_part_2 = unittest.mock.Mock(index=0)
    tool_2_part_2 = unittest.mock.Mock(index=1)

    deltas = [
        unittest.mock.Mock(reasoning_content="", content=None, tool_calls=None),
        unittest.mock.Mock(reasoning_content="\nI'm thinking", content=None, tool_calls=None),
        unittest.mock.Mock(
            reasoning_content=None, content="I'll calculate", tool_calls=[tool_1_part_1, tool_2_part_1]
        ),
        unittest.mock.Mock(
            reasoning_content=None, content="that for you", tool_calls=[tool_1_part_2, tool_2_part_2]
        ),
        unittest.mock.Mock(reasoning_content=None, content="", tool_calls=None),
    ]
    finish_reasons = [None, None, None, None, "tool_calls"]
    events = [
        unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=reason, delta=delta)])
        for delta, reason in zip(deltas, finish_reasons)
    ]
    usage_event = unittest.mock.Mock()

    litellm_client.chat.completions.create.return_value = iter([*events, usage_event])

    request = {"model": "m1", "messages": [{"role": "user", "content": [{"type": "text", "text": "calculate 2+2"}]}]}
    tru_events = await alist(model.stream(request))
    exp_events = [
        {"chunk_type": "message_start"},
        {"chunk_type": "content_start", "data_type": "text"},
        {"chunk_type": "content_delta", "data_type": "reasoning_content", "data": "\nI'm thinking"},
        {"chunk_type": "content_delta", "data_type": "text", "data": "I'll calculate"},
        {"chunk_type": "content_delta", "data_type": "text", "data": "that for you"},
        {"chunk_type": "content_stop", "data_type": "text"},
        {"chunk_type": "content_start", "data_type": "tool", "data": tool_1_part_1},
        {"chunk_type": "content_delta", "data_type": "tool", "data": tool_1_part_1},
        {"chunk_type": "content_delta", "data_type": "tool", "data": tool_1_part_2},
        {"chunk_type": "content_stop", "data_type": "tool"},
        {"chunk_type": "content_start", "data_type": "tool", "data": tool_2_part_1},
        {"chunk_type": "content_delta", "data_type": "tool", "data": tool_2_part_1},
        {"chunk_type": "content_delta", "data_type": "tool", "data": tool_2_part_2},
        {"chunk_type": "content_stop", "data_type": "tool"},
        {"chunk_type": "message_stop", "data": "tool_calls"},
        {"chunk_type": "metadata", "data": usage_event.usage},
    ]

    assert tru_events == exp_events
    litellm_client.chat.completions.create.assert_called_once_with(**request)


@pytest.mark.asyncio
async def test_structured_output(litellm_client, model, test_output_model_cls, alist):
messages = [{"role": "user", "content": [{"text": "Generate a person"}]}]
Expand Down
22 changes: 12 additions & 10 deletions tests/strands/models/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

@pytest.fixture
def openai_client_cls():
    # Diff residue fix: both the old (`"OpenAI"`) and new (`"AsyncOpenAI"`) patch lines
    # were present (a syntax error); keep the async one to match the model's AsyncOpenAI client.
    with unittest.mock.patch.object(strands.models.openai.openai, "AsyncOpenAI") as mock_client_cls:
        yield mock_client_cls

Expand Down Expand Up @@ -70,7 +70,7 @@ def test_update_config(model, model_id):


@pytest.mark.asyncio
async def test_stream(openai_client, model, alist):
async def test_stream(openai_client, model, agenerator, alist):
mock_tool_call_1_part_1 = unittest.mock.Mock(index=0)
mock_tool_call_2_part_1 = unittest.mock.Mock(index=1)
mock_delta_1 = unittest.mock.Mock(
Expand Down Expand Up @@ -102,8 +102,8 @@ async def test_stream(openai_client, model, alist):
mock_event_5 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="tool_calls", delta=mock_delta_5)])
mock_event_6 = unittest.mock.Mock()

openai_client.chat.completions.create.return_value = iter(
[mock_event_1, mock_event_2, mock_event_3, mock_event_4, mock_event_5, mock_event_6]
openai_client.chat.completions.create = unittest.mock.AsyncMock(
return_value=agenerator([mock_event_1, mock_event_2, mock_event_3, mock_event_4, mock_event_5, mock_event_6])
)

request = {"model": "m1", "messages": [{"role": "user", "content": [{"type": "text", "text": "calculate 2+2"}]}]}
Expand Down Expand Up @@ -133,7 +133,7 @@ async def test_stream(openai_client, model, alist):


@pytest.mark.asyncio
async def test_stream_empty(openai_client, model, alist):
async def test_stream_empty(openai_client, model, agenerator, alist):
mock_delta = unittest.mock.Mock(content=None, tool_calls=None, reasoning_content=None)
mock_usage = unittest.mock.Mock(prompt_tokens=0, completion_tokens=0, total_tokens=0)

Expand All @@ -142,7 +142,9 @@ async def test_stream_empty(openai_client, model, alist):
mock_event_3 = unittest.mock.Mock()
mock_event_4 = unittest.mock.Mock(usage=mock_usage)

openai_client.chat.completions.create.return_value = iter([mock_event_1, mock_event_2, mock_event_3, mock_event_4])
openai_client.chat.completions.create = unittest.mock.AsyncMock(
return_value=agenerator([mock_event_1, mock_event_2, mock_event_3, mock_event_4]),
)

request = {"model": "m1", "messages": [{"role": "user", "content": []}]}
response = model.stream(request)
Expand All @@ -161,7 +163,7 @@ async def test_stream_empty(openai_client, model, alist):


@pytest.mark.asyncio
async def test_stream_with_empty_choices(openai_client, model, alist):
async def test_stream_with_empty_choices(openai_client, model, agenerator, alist):
mock_delta = unittest.mock.Mock(content="content", tool_calls=None, reasoning_content=None)
mock_usage = unittest.mock.Mock(prompt_tokens=10, completion_tokens=20, total_tokens=30)

Expand All @@ -180,8 +182,8 @@ async def test_stream_with_empty_choices(openai_client, model, alist):
# Final event with usage info
mock_event_5 = unittest.mock.Mock(usage=mock_usage)

openai_client.chat.completions.create.return_value = iter(
[mock_event_1, mock_event_2, mock_event_3, mock_event_4, mock_event_5]
openai_client.chat.completions.create = unittest.mock.AsyncMock(
return_value=agenerator([mock_event_1, mock_event_2, mock_event_3, mock_event_4, mock_event_5])
)

request = {"model": "m1", "messages": [{"role": "user", "content": ["test"]}]}
Expand Down Expand Up @@ -212,7 +214,7 @@ async def test_structured_output(openai_client, model, test_output_model_cls, al
mock_response = unittest.mock.Mock()
mock_response.choices = [mock_choice]

openai_client.beta.chat.completions.parse.return_value = mock_response
openai_client.beta.chat.completions.parse = unittest.mock.AsyncMock(return_value=mock_response)

stream = model.structured_output(test_output_model_cls, messages)
events = await alist(stream)
Expand Down