60 changes: 47 additions & 13 deletions src/strands/models/openai.py
@@ -214,10 +214,16 @@ def format_request_messages(cls, messages: Messages, system_prompt: Optional[str
for message in messages:
contents = message["content"]

# Check for reasoningContent and warn user
if any("reasoningContent" in content for content in contents):
logger.warning(
"reasoningContent is not supported in multi-turn conversations with the Chat Completions API."
)

formatted_contents = [
cls.format_request_message_content(content)
for content in contents
if not any(block_type in content for block_type in ["toolResult", "toolUse"])
if not any(block_type in content for block_type in ["toolResult", "toolUse", "reasoningContent"])
]
formatted_tool_calls = [
cls.format_request_message_tool_call(content["toolUse"]) for content in contents if "toolUse" in content
@@ -405,38 +411,46 @@ async def stream(

logger.debug("got response from model")
yield self.format_chunk({"chunk_type": "message_start"})
yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})

tool_calls: dict[int, list[Any]] = {}
data_type = None
finish_reason = None # Store finish_reason for later use
event = None # Initialize for scope safety

async for event in response:
# Defensive: skip events with empty or missing choices
if not getattr(event, "choices", None):
continue
choice = event.choices[0]

if choice.delta.content:
yield self.format_chunk(
{"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content}
)

if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content:
chunks, data_type = self._stream_switch_content("reasoning_content", data_type)
for chunk in chunks:
yield chunk
yield self.format_chunk(
{
"chunk_type": "content_delta",
"data_type": "reasoning_content",
"data_type": data_type,
"data": choice.delta.reasoning_content,
}
)

if choice.delta.content:
chunks, data_type = self._stream_switch_content("text", data_type)
for chunk in chunks:
yield chunk
yield self.format_chunk(
{"chunk_type": "content_delta", "data_type": data_type, "data": choice.delta.content}
)

for tool_call in choice.delta.tool_calls or []:
tool_calls.setdefault(tool_call.index, []).append(tool_call)

if choice.finish_reason:
finish_reason = choice.finish_reason # Store for use outside loop
if data_type:
yield self.format_chunk({"chunk_type": "content_stop", "data_type": data_type})
break

yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})

for tool_deltas in tool_calls.values():
yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]})

@@ -445,17 +459,37 @@ async def stream(

yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})

yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason})
yield self.format_chunk({"chunk_type": "message_stop", "data": finish_reason or "end_turn"})

# Skip remaining events; we only need the final usage payload
async for event in response:
_ = event

if event.usage:
if event and hasattr(event, "usage") and event.usage:
yield self.format_chunk({"chunk_type": "metadata", "data": event.usage})

logger.debug("finished streaming response from model")

def _stream_switch_content(self, data_type: str, prev_data_type: str | None) -> tuple[list[StreamEvent], str]:
"""Handle switching to a new content stream.

Args:
data_type: The next content data type.
prev_data_type: The previous content data type.

Returns:
Tuple containing:
- Stop block for previous content and the start block for the next content.
- Next content data type.
"""
chunks = []
if data_type != prev_data_type:
if prev_data_type is not None:
chunks.append(self.format_chunk({"chunk_type": "content_stop", "data_type": prev_data_type}))
chunks.append(self.format_chunk({"chunk_type": "content_start", "data_type": data_type}))

return chunks, data_type

@override
async def structured_output(
self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any
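Note on the new helper: _stream_switch_content only emits block boundaries when the content type actually changes, so a stream that moves from reasoning_content to text produces exactly one content_stop/content_start pair at the transition. Below is a minimal standalone sketch of that behavior; the format_chunk stand-in and the driver loop are illustrative only and are not part of this change.

from typing import Any, Optional


def format_chunk(spec: dict[str, Any]) -> dict[str, Any]:
    # Stand-in for OpenAIModel.format_chunk: echo the spec unchanged.
    return spec


def stream_switch_content(data_type: str, prev_data_type: Optional[str]) -> tuple[list[dict[str, Any]], str]:
    # Emit a stop chunk for the previous block and a start chunk for the new one,
    # but only when the content type changes.
    chunks: list[dict[str, Any]] = []
    if data_type != prev_data_type:
        if prev_data_type is not None:
            chunks.append(format_chunk({"chunk_type": "content_stop", "data_type": prev_data_type}))
        chunks.append(format_chunk({"chunk_type": "content_start", "data_type": data_type}))
    return chunks, data_type


if __name__ == "__main__":
    data_type = None
    for incoming in ["reasoning_content", "reasoning_content", "text", "text"]:
        chunks, data_type = stream_switch_content(incoming, data_type)
        for chunk in chunks:
            print(chunk)
    # Expected output (three boundary chunks, one transition):
    #   {'chunk_type': 'content_start', 'data_type': 'reasoning_content'}
    #   {'chunk_type': 'content_stop', 'data_type': 'reasoning_content'}
    #   {'chunk_type': 'content_start', 'data_type': 'text'}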
106 changes: 75 additions & 31 deletions tests/strands/models/test_openai.py
@@ -561,11 +561,13 @@ async def test_stream(openai_client, model_id, model, agenerator, alist):
tru_events = await alist(response)
exp_events = [
{"messageStart": {"role": "assistant"}},
{"contentBlockStart": {"start": {}}},
{"contentBlockStart": {"start": {}}}, # reasoning_content starts
{"contentBlockDelta": {"delta": {"reasoningContent": {"text": "\nI'm thinking"}}}},
{"contentBlockStop": {}}, # reasoning_content ends
{"contentBlockStart": {"start": {}}}, # text starts
{"contentBlockDelta": {"delta": {"text": "I'll calculate"}}},
{"contentBlockDelta": {"delta": {"text": "that for you"}}},
{"contentBlockStop": {}},
{"contentBlockStop": {}}, # text ends
{
"contentBlockStart": {
"start": {
@@ -631,9 +633,7 @@ async def test_stream_empty(openai_client, model_id, model, agenerator, alist):
tru_events = await alist(response)
exp_events = [
{"messageStart": {"role": "assistant"}},
{"contentBlockStart": {"start": {}}},
{"contentBlockStop": {}},
{"messageStop": {"stopReason": "end_turn"}},
{"messageStop": {"stopReason": "end_turn"}}, # No content blocks when no content
]

assert len(tru_events) == len(exp_events)
@@ -678,10 +678,10 @@ async def test_stream_with_empty_choices(openai_client, model, agenerator, alist
tru_events = await alist(response)
exp_events = [
{"messageStart": {"role": "assistant"}},
{"contentBlockStart": {"start": {}}},
{"contentBlockStart": {"start": {}}}, # text content starts
{"contentBlockDelta": {"delta": {"text": "content"}}},
{"contentBlockDelta": {"delta": {"text": "content"}}},
{"contentBlockStop": {}},
{"contentBlockStop": {}}, # text content ends
{"messageStop": {"stopReason": "end_turn"}},
{
"metadata": {
@@ -756,6 +756,74 @@ def test_tool_choice_none_no_warning(model, messages, captured_warnings):
assert len(captured_warnings) == 0


@pytest.mark.parametrize(
"new_data_type, prev_data_type, expected_chunks, expected_data_type",
[
("text", None, [{"contentBlockStart": {"start": {}}}], "text"),
(
"reasoning_content",
"text",
[{"contentBlockStop": {}}, {"contentBlockStart": {"start": {}}}],
"reasoning_content",
),
("text", "text", [], "text"),
],
)
def test__stream_switch_content(model, new_data_type, prev_data_type, expected_chunks, expected_data_type):
"""Test _stream_switch_content method for content type switching."""
chunks, data_type = model._stream_switch_content(new_data_type, prev_data_type)
assert chunks == expected_chunks
assert data_type == expected_data_type


def test_format_request_messages_excludes_reasoning_content():
"""Test that reasoningContent is excluded from formatted messages."""
messages = [
{
"content": [
{"text": "Hello"},
{"reasoningContent": {"reasoningText": {"text": "excluded"}}},
],
"role": "user",
},
]

tru_result = OpenAIModel.format_request_messages(messages)

# Only text content should be included
exp_result = [
{
"content": [{"text": "Hello", "type": "text"}],
"role": "user",
},
]
assert tru_result == exp_result


@pytest.mark.asyncio
async def test_structured_output_context_overflow_exception(openai_client, model, messages, test_output_model_cls):
"""Test that structured output also handles context overflow properly."""
# Create a mock OpenAI BadRequestError with context_length_exceeded code
mock_error = openai.BadRequestError(
message="This model's maximum context length is 4096 tokens. However, your messages resulted in 5000 tokens.",
response=unittest.mock.MagicMock(),
body={"error": {"code": "context_length_exceeded"}},
)
mock_error.code = "context_length_exceeded"

# Configure the mock client to raise the context overflow error
openai_client.beta.chat.completions.parse.side_effect = mock_error

# Test that the structured_output method converts the error properly
with pytest.raises(ContextWindowOverflowException) as exc_info:
async for _ in model.structured_output(test_output_model_cls, messages):
pass

# Verify the exception message contains the original error
assert "maximum context length" in str(exc_info.value)
assert exc_info.value.__cause__ == mock_error


@pytest.mark.asyncio
async def test_stream_context_overflow_exception(openai_client, model, messages):
"""Test that OpenAI context overflow errors are properly converted to ContextWindowOverflowException."""
@@ -803,30 +871,6 @@ async def test_stream_other_bad_request_errors_passthrough(openai_client, model,
assert exc_info.value == mock_error


@pytest.mark.asyncio
async def test_structured_output_context_overflow_exception(openai_client, model, messages, test_output_model_cls):
"""Test that structured output also handles context overflow properly."""
# Create a mock OpenAI BadRequestError with context_length_exceeded code
mock_error = openai.BadRequestError(
message="This model's maximum context length is 4096 tokens. However, your messages resulted in 5000 tokens.",
response=unittest.mock.MagicMock(),
body={"error": {"code": "context_length_exceeded"}},
)
mock_error.code = "context_length_exceeded"

# Configure the mock client to raise the context overflow error
openai_client.beta.chat.completions.parse.side_effect = mock_error

# Test that the structured_output method converts the error properly
with pytest.raises(ContextWindowOverflowException) as exc_info:
async for _ in model.structured_output(test_output_model_cls, messages):
pass

# Verify the exception message contains the original error
assert "maximum context length" in str(exc_info.value)
assert exc_info.value.__cause__ == mock_error


@pytest.mark.asyncio
async def test_stream_rate_limit_as_throttle(openai_client, model, messages):
"""Test that all rate limit errors are converted to ModelThrottledException."""
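The context-overflow and rate-limit tests above exercise error mapping that is not part of the diff shown here. The following is a hypothetical sketch of that mapping, inferred from the test assertions (the raised exception carries the original message and chains the OpenAI error via __cause__); the helper name and the exceptions import path are assumptions, not code from this PR.

import openai

from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException


def raise_mapped_error(error: Exception) -> None:
    # Translate OpenAI client errors into the SDK's provider-agnostic exceptions
    # (assumed import path for the exception types).
    if isinstance(error, openai.BadRequestError) and getattr(error, "code", None) == "context_length_exceeded":
        # Chain the original error so callers can assert on __cause__ and the message.
        raise ContextWindowOverflowException(str(error)) from error
    if isinstance(error, openai.RateLimitError):
        raise ModelThrottledException(str(error)) from error
    raise error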