Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 182 additions & 20 deletions tests/entrypoints/openai/responses/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def pairs_of_event_types() -> dict[str, str]:
"response.mcp_call.completed": "response.mcp_call.in_progress",
"response.function_call_arguments.done": "response.function_call_arguments.delta", # noqa: E501
"response.code_interpreter_call_code.done": "response.code_interpreter_call_code.delta", # noqa: E501
"response.code_interpreter_call.completed": "response.code_interpreter_call.in_progress", # noqa: E501
"response.web_search_call.completed": "response.web_search_call.in_progress",
}
# fmt: on
Expand Down Expand Up @@ -108,39 +109,200 @@ def events_contain_type(events: list, type_substring: str) -> bool:
return any(type_substring in getattr(e, "type", "") for e in events)


def validate_streaming_event_stack(
events: list, pairs_of_event_types: dict[str, str]
) -> None:
"""Validate that streaming events are properly nested/paired."""
def _validate_event_pairing(events: list, pairs_of_event_types: dict[str, str]) -> None:
"""Validate that streaming events are properly nested/paired.

Derives push/pop sets from *pairs_of_event_types* so that every
start/end pair in the dict is handled automatically.
"""
start_events = set(pairs_of_event_types.values())
end_events = set(pairs_of_event_types.keys())

stack: list[str] = []
for event in events:
etype = event.type
if etype == "response.created":
stack.append(etype)
elif etype == "response.completed":
assert stack and stack[-1] == pairs_of_event_types[etype], (
f"Unexpected stack top for {etype}: "
f"got {stack[-1] if stack else '<empty>'}"
)
stack.pop()
elif etype.endswith("added") or etype == "response.mcp_call.in_progress":
stack.append(etype)
elif etype.endswith("delta"):
if stack and stack[-1] == etype:
continue
stack.append(etype)
elif etype.endswith("done") or etype == "response.mcp_call.completed":
assert etype in pairs_of_event_types, f"Unknown done event: {etype}"
if etype in end_events:
expected_start = pairs_of_event_types[etype]
assert stack and stack[-1] == expected_start, (
f"Stack mismatch for {etype}: "
f"expected {expected_start}, "
f"got {stack[-1] if stack else '<empty>'}"
)
stack.pop()
elif etype in start_events:
# Consecutive deltas of the same type share a single stack slot.
if etype.endswith("delta") and stack and stack[-1] == etype:
continue
stack.append(etype)
# else: passthrough event (e.g. response.in_progress,
# web_search_call.searching, code_interpreter_call.interpreting)
assert len(stack) == 0, f"Unclosed events on stack: {stack}"


def _validate_event_ordering(events: list) -> None:
"""Validate that envelope events appear in the correct positions."""
assert len(events) >= 2, f"Expected at least 2 events, got {len(events)}"

# First event must be response.created
assert events[0].type == "response.created", (
f"First event must be response.created, got {events[0].type}"
)
# Last event must be response.completed
assert events[-1].type == "response.completed", (
f"Last event must be response.completed, got {events[-1].type}"
)

# response.in_progress, if present, must be the second event
in_progress_indices = [
i for i, e in enumerate(events) if e.type == "response.in_progress"
]
if in_progress_indices:
assert in_progress_indices == [1], (
f"response.in_progress must be the second event, "
f"found at indices {in_progress_indices}"
)

# Exactly one created and one completed
created_count = sum(1 for e in events if e.type == "response.created")
completed_count = sum(1 for e in events if e.type == "response.completed")
assert created_count == 1, (
f"Expected exactly 1 response.created, got {created_count}"
)
assert completed_count == 1, (
f"Expected exactly 1 response.completed, got {completed_count}"
)


def _validate_field_consistency(events: list) -> None:
    """Validate item_id, output_index, and content_index consistency.

    Tracks the active output item established by ``output_item.added``
    and verifies that all subsequent events for that item carry matching
    identifiers until ``output_item.done`` closes it.

    Raises:
        AssertionError: on any identifier mismatch, a missing item/id on
            an ``output_item.added``/``done`` event, or an ``output_index``
            that decreases across items.
    """
    # Envelope events carry no item-level identifiers, so they are skipped.
    _SESSION_EVENTS = {
        "response.created",
        "response.in_progress",
        "response.completed",
    }

    # State for the item currently open (None between items).
    active_item_id: str | None = None
    active_output_index: int | None = None
    # Highest output_index observed so far; -1 means none seen yet.
    last_output_index: int = -1
    active_content_index: int | None = None

    for event in events:
        etype = event.type

        if etype in _SESSION_EVENTS:
            continue

        # --- output_item.added: opens a new item ------------------
        if etype == "response.output_item.added":
            item = getattr(event, "item", None)
            output_index = getattr(event, "output_index", None)

            assert item is not None, "output_item.added must have an item"
            item_id = getattr(item, "id", None)
            assert item_id, "output_item.added item must have an id"

            # output_index must be non-decreasing across items
            if output_index is not None:
                assert output_index >= last_output_index, (
                    f"output_index went backwards: {output_index} < {last_output_index}"
                )
                last_output_index = output_index

            active_item_id = item_id
            active_output_index = output_index
            active_content_index = None
            continue

        # --- output_item.done: closes the active item -------------
        if etype == "response.output_item.done":
            item = getattr(event, "item", None)
            output_index = getattr(event, "output_index", None)

            assert item is not None, "output_item.done must have an item"
            done_item_id = getattr(item, "id", None)

            # Comparisons are best-effort: only enforced when both sides
            # are present on the events.
            if active_item_id is not None and done_item_id:
                assert done_item_id == active_item_id, (
                    f"output_item.done item.id mismatch: "
                    f"expected {active_item_id}, got {done_item_id}"
                )
            if active_output_index is not None and output_index is not None:
                assert output_index == active_output_index, (
                    f"output_item.done output_index mismatch: "
                    f"expected {active_output_index}, got {output_index}"
                )

            # Reset — subsequent item-level events belong to the next item.
            active_item_id = None
            active_output_index = None
            active_content_index = None
            continue

        # --- content_part / reasoning_part added: sets content_index
        if etype in (
            "response.content_part.added",
            "response.reasoning_part.added",
        ):
            _assert_item_fields(event, etype, active_item_id, active_output_index)
            active_content_index = getattr(event, "content_index", None)
            continue

        # --- all other item-level events --------------------------
        _assert_item_fields(event, etype, active_item_id, active_output_index)

        # content_index (only meaningful on events that carry it)
        content_index = getattr(event, "content_index", None)
        if content_index is not None and active_content_index is not None:
            assert content_index == active_content_index, (
                f"{etype} content_index mismatch: "
                f"expected {active_content_index}, got {content_index}"
            )


def _assert_item_fields(
event,
etype: str,
active_item_id: str | None,
active_output_index: int | None,
) -> None:
"""Check that *event*'s item_id and output_index match the active item."""
event_item_id = getattr(event, "item_id", None)
output_index = getattr(event, "output_index", None)

if active_item_id is not None and event_item_id is not None:
assert event_item_id == active_item_id, (
f"{etype} item_id mismatch: expected {active_item_id}, got {event_item_id}"
)
if active_output_index is not None and output_index is not None:
assert output_index == active_output_index, (
f"{etype} output_index mismatch: "
f"expected {active_output_index}, got {output_index}"
)


def validate_streaming_event_stack(
    events: list, pairs_of_event_types: dict[str, str]
) -> None:
    """Run the full structural validation suite over a streamed response.

    Delegates, in order, to three focused checkers:

    1. **Pairing** — start/end events nest correctly, using stack
       discipline derived from *pairs_of_event_types*.
    2. **Ordering** — envelope events (``created``, ``in_progress``,
       ``completed``) occupy their mandated positions in the stream.
    3. **Field consistency** — ``item_id``, ``output_index``, and
       ``content_index`` stay coherent across the events of each
       output item's lifecycle.
    """
    _validate_event_pairing(events, pairs_of_event_types)
    _validate_event_ordering(events)
    _validate_field_consistency(events)


def log_response_diagnostics(
response,
*,
Expand Down
67 changes: 24 additions & 43 deletions tests/entrypoints/openai/responses/test_harmony.py
Original file line number Diff line number Diff line change
Expand Up @@ -910,21 +910,25 @@ def _has_function_call(evts: list) -> bool:
reason="This test is flaky in CI, needs investigation and "
"potential fixes in the code interpreter MCP implementation."
)
async def test_mcp_code_interpreter_streaming(client: OpenAI, model_name: str, server):
tools = [{"type": "mcp", "server_label": "code_interpreter"}]
async def test_code_interpreter_streaming(
client: OpenAI,
model_name: str,
pairs_of_event_types: dict[str, str],
):
tools = [{"type": "code_interpreter", "container": {"type": "auto"}}]
input_text = (
"Calculate 123 * 456 using python. "
"The python interpreter is not stateful and you must "
"print to see the output."
)

def _has_mcp_call(evts: list) -> bool:
return events_contain_type(evts, "mcp_call")
def _has_code_interpreter(evts: list) -> bool:
return events_contain_type(evts, "code_interpreter")

events = await retry_streaming_for(
client,
model=model_name,
validate_events=_has_mcp_call,
validate_events=_has_code_interpreter,
input=input_text,
tools=tools,
temperature=0.0,
Expand All @@ -936,59 +940,36 @@ def _has_mcp_call(evts: list) -> bool:
event_types = [e.type for e in events]
event_types_set = set(event_types)
logger.info(
"\n====== MCP Streaming Diagnostics ======\n"
"\n====== Code Interpreter Streaming Diagnostics ======\n"
"Event count: %d\n"
"Event types (in order): %s\n"
"Unique event types: %s\n"
"=======================================",
"====================================================",
len(events),
event_types,
sorted(event_types_set),
)

# Verify the full MCP streaming lifecycle
assert "response.output_item.added" in event_types_set, (
f"MCP call was not added. Events: {sorted(event_types_set)}"
)
assert "response.mcp_call.in_progress" in event_types_set, (
f"MCP call in_progress not seen. Events: {sorted(event_types_set)}"
)
assert "response.mcp_call_arguments.delta" in event_types_set, (
f"MCP arguments delta not seen. Events: {sorted(event_types_set)}"
)
assert "response.mcp_call_arguments.done" in event_types_set, (
f"MCP arguments done not seen. Events: {sorted(event_types_set)}"
)
assert "response.mcp_call.completed" in event_types_set, (
f"MCP call completed not seen. Events: {sorted(event_types_set)}"
)
assert "response.output_item.done" in event_types_set, (
f"MCP item done not seen. Events: {sorted(event_types_set)}"
)
# Structural validation (pairing, ordering, field consistency)
validate_streaming_event_stack(events, pairs_of_event_types)

# Validate specific MCP event details
# Validate code interpreter item fields
for event in events:
if event.type == "response.output_item.added":
if hasattr(event.item, "type") and event.item.type == "mcp_call":
assert event.item.name == "python"
assert event.item.server_label == "code_interpreter"
elif event.type == "response.mcp_call_arguments.done":
assert event.name == "python"
assert event.arguments is not None
if (
event.type == "response.output_item.added"
and hasattr(event.item, "type")
and event.item.type == "code_interpreter_call"
):
assert event.item.status == "in_progress"
elif event.type == "response.code_interpreter_call_code.done":
assert event.code is not None
elif (
event.type == "response.output_item.done"
and hasattr(event.item, "type")
and event.item.type == "mcp_call"
and event.item.type == "code_interpreter_call"
):
assert event.item.name == "python"
assert event.item.status == "completed"

# code_interpreter events should NOT appear when using MCP type
code_interp_events = [e.type for e in events if "code_interpreter" in e.type]
assert not code_interp_events, (
"Should not see code_interpreter events when using MCP type, "
f"but got: {code_interp_events}"
)
assert event.item.code is not None


@pytest.mark.asyncio
Expand Down
Loading