Skip to content

Commit ff01b96

Browse files
committed
models - openai - async client
1 parent 5190645 commit ff01b96

File tree

3 files changed

+61
-30
lines changed

3 files changed

+61
-30
lines changed

src/strands/models/openai.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config:
6161
logger.debug("config=<%s> | initializing", self.config)
6262

6363
client_args = client_args or {}
64-
self.client = openai.OpenAI(**client_args)
64+
self.client = openai.AsyncOpenAI(**client_args)
6565

6666
@override
6767
def update_config(self, **model_config: Unpack[OpenAIConfig]) -> None: # type: ignore[override]
@@ -91,14 +91,14 @@ async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any]
9191
Returns:
9292
An iterable of response events from the OpenAI model.
9393
"""
94-
response = self.client.chat.completions.create(**request)
94+
response = await self.client.chat.completions.create(**request)
9595

9696
yield {"chunk_type": "message_start"}
9797
yield {"chunk_type": "content_start", "data_type": "text"}
9898

9999
tool_calls: dict[int, list[Any]] = {}
100100

101-
for event in response:
101+
async for event in response:
102102
# Defensive: skip events with empty or missing choices
103103
if not getattr(event, "choices", None):
104104
continue
@@ -133,7 +133,7 @@ async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any]
133133
yield {"chunk_type": "message_stop", "data": choice.finish_reason}
134134

135135
# Skip remaining events as we don't have use for anything except the final usage payload
136-
for event in response:
136+
async for event in response:
137137
_ = event
138138

139139
yield {"chunk_type": "metadata", "data": event.usage}
@@ -151,7 +151,7 @@ async def structured_output(
151151
Yields:
152152
Model events with the last being the structured output.
153153
"""
154-
response: ParsedChatCompletion = self.client.beta.chat.completions.parse( # type: ignore
154+
response: ParsedChatCompletion = await self.client.beta.chat.completions.parse( # type: ignore
155155
model=self.get_config()["model_id"],
156156
messages=super().format_request(prompt)["messages"],
157157
response_format=output_model,

tests-integ/test_model_openai.py

Lines changed: 44 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from strands.models.openai import OpenAIModel
1313

1414

15-
@pytest.fixture
15+
@pytest.fixture(scope="module")
1616
def model():
1717
return OpenAIModel(
1818
model_id="gpt-4o",
@@ -22,7 +22,7 @@ def model():
2222
)
2323

2424

25-
@pytest.fixture
25+
@pytest.fixture(scope="module")
2626
def tools():
2727
@strands.tool
2828
def tool_time() -> str:
@@ -35,36 +35,65 @@ def tool_weather() -> str:
3535
return [tool_time, tool_weather]
3636

3737

38-
@pytest.fixture
38+
@pytest.fixture(scope="module")
3939
def agent(model, tools):
4040
return Agent(model=model, tools=tools)
4141

4242

43-
@pytest.fixture
43+
@pytest.fixture(scope="module")
4444
def test_image_path(request):
4545
return request.config.rootpath / "tests-integ" / "test_image.png"
4646

4747

48-
def test_agent(agent):
48+
@pytest.fixture(scope="module")
49+
def weather():
50+
class Weather(BaseModel):
51+
"""Extracts the time and weather from the user's message with the exact strings."""
52+
53+
time: str
54+
weather: str
55+
56+
return Weather(time="12:00", weather="sunny")
57+
58+
59+
def test_agent_invoke(agent):
4960
result = agent("What is the time and weather in New York?")
5061
text = result.message["content"][0]["text"].lower()
5162

5263
assert all(string in text for string in ["12:00", "sunny"])
5364

5465

55-
def test_structured_output(model):
56-
class Weather(BaseModel):
57-
"""Extracts the time and weather from the user's message with the exact strings."""
66+
@pytest.mark.asyncio
67+
async def test_agent_invoke_async(agent):
68+
result = await agent.invoke_async("What is the time and weather in New York?")
69+
text = result.message["content"][0]["text"].lower()
5870

59-
time: str
60-
weather: str
71+
assert all(string in text for string in ["12:00", "sunny"])
72+
73+
74+
@pytest.mark.asyncio
75+
async def test_agent_stream_async(agent):
76+
stream = agent.stream_async("What is the time and weather in New York?")
77+
async for event in stream:
78+
_ = event
79+
80+
result = event["result"]
81+
text = result.message["content"][0]["text"].lower()
82+
83+
assert all(string in text for string in ["12:00", "sunny"])
84+
85+
86+
def test_structured_output(agent, weather):
87+
tru_weather = agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny")
88+
exp_weather = weather
89+
assert tru_weather == exp_weather
6190

62-
agent = Agent(model=model)
6391

64-
result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny")
65-
assert isinstance(result, Weather)
66-
assert result.time == "12:00"
67-
assert result.weather == "sunny"
92+
@pytest.mark.asyncio
93+
async def test_structured_output_async(agent, weather):
94+
tru_weather = await agent.structured_output_async(type(weather), "The time is 12:00 and the weather is sunny")
95+
exp_weather = weather
96+
assert tru_weather == exp_weather
6897

6998

7099
def test_tool_returning_images(model, test_image_path):

tests/strands/models/test_openai.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
@pytest.fixture
1111
def openai_client_cls():
12-
with unittest.mock.patch.object(strands.models.openai.openai, "OpenAI") as mock_client_cls:
12+
with unittest.mock.patch.object(strands.models.openai.openai, "AsyncOpenAI") as mock_client_cls:
1313
yield mock_client_cls
1414

1515

@@ -70,7 +70,7 @@ def test_update_config(model, model_id):
7070

7171

7272
@pytest.mark.asyncio
73-
async def test_stream(openai_client, model, alist):
73+
async def test_stream(openai_client, model, agenerator, alist):
7474
mock_tool_call_1_part_1 = unittest.mock.Mock(index=0)
7575
mock_tool_call_2_part_1 = unittest.mock.Mock(index=1)
7676
mock_delta_1 = unittest.mock.Mock(
@@ -102,8 +102,8 @@ async def test_stream(openai_client, model, alist):
102102
mock_event_5 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="tool_calls", delta=mock_delta_5)])
103103
mock_event_6 = unittest.mock.Mock()
104104

105-
openai_client.chat.completions.create.return_value = iter(
106-
[mock_event_1, mock_event_2, mock_event_3, mock_event_4, mock_event_5, mock_event_6]
105+
openai_client.chat.completions.create = unittest.mock.AsyncMock(
106+
return_value=agenerator([mock_event_1, mock_event_2, mock_event_3, mock_event_4, mock_event_5, mock_event_6])
107107
)
108108

109109
request = {"model": "m1", "messages": [{"role": "user", "content": [{"type": "text", "text": "calculate 2+2"}]}]}
@@ -133,7 +133,7 @@ async def test_stream(openai_client, model, alist):
133133

134134

135135
@pytest.mark.asyncio
136-
async def test_stream_empty(openai_client, model, alist):
136+
async def test_stream_empty(openai_client, model, agenerator, alist):
137137
mock_delta = unittest.mock.Mock(content=None, tool_calls=None, reasoning_content=None)
138138
mock_usage = unittest.mock.Mock(prompt_tokens=0, completion_tokens=0, total_tokens=0)
139139

@@ -142,7 +142,9 @@ async def test_stream_empty(openai_client, model, alist):
142142
mock_event_3 = unittest.mock.Mock()
143143
mock_event_4 = unittest.mock.Mock(usage=mock_usage)
144144

145-
openai_client.chat.completions.create.return_value = iter([mock_event_1, mock_event_2, mock_event_3, mock_event_4])
145+
openai_client.chat.completions.create = unittest.mock.AsyncMock(
146+
return_value=agenerator([mock_event_1, mock_event_2, mock_event_3, mock_event_4]),
147+
)
146148

147149
request = {"model": "m1", "messages": [{"role": "user", "content": []}]}
148150
response = model.stream(request)
@@ -161,7 +163,7 @@ async def test_stream_empty(openai_client, model, alist):
161163

162164

163165
@pytest.mark.asyncio
164-
async def test_stream_with_empty_choices(openai_client, model, alist):
166+
async def test_stream_with_empty_choices(openai_client, model, agenerator, alist):
165167
mock_delta = unittest.mock.Mock(content="content", tool_calls=None, reasoning_content=None)
166168
mock_usage = unittest.mock.Mock(prompt_tokens=10, completion_tokens=20, total_tokens=30)
167169

@@ -180,8 +182,8 @@ async def test_stream_with_empty_choices(openai_client, model, alist):
180182
# Final event with usage info
181183
mock_event_5 = unittest.mock.Mock(usage=mock_usage)
182184

183-
openai_client.chat.completions.create.return_value = iter(
184-
[mock_event_1, mock_event_2, mock_event_3, mock_event_4, mock_event_5]
185+
openai_client.chat.completions.create = unittest.mock.AsyncMock(
186+
return_value=agenerator([mock_event_1, mock_event_2, mock_event_3, mock_event_4, mock_event_5])
185187
)
186188

187189
request = {"model": "m1", "messages": [{"role": "user", "content": ["test"]}]}
@@ -212,7 +214,7 @@ async def test_structured_output(openai_client, model, test_output_model_cls, al
212214
mock_response = unittest.mock.Mock()
213215
mock_response.choices = [mock_choice]
214216

215-
openai_client.beta.chat.completions.parse.return_value = mock_response
217+
openai_client.beta.chat.completions.parse = unittest.mock.AsyncMock(return_value=mock_response)
216218

217219
stream = model.structured_output(test_output_model_cls, messages)
218220
events = await alist(stream)

0 commit comments

Comments (0)