Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 31 additions & 43 deletions python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,27 +263,21 @@ def _deduplicate_messages(messages: list[Message]) -> list[Message]:
return unique_messages


def _parse_multimodal_media_part(part: dict[str, Any]) -> Content | None:
"""Convert a multimodal media part into Agent Framework content."""
part_type = str(part.get("type", "")).lower()
source = part.get("source")

mime_type = cast(
str | None,
part.get("mimeType")
or part.get("mime_type")
or {
"image": "image/*",
"audio": "audio/*",
"video": "video/*",
"document": "application/octet-stream",
"binary": "application/octet-stream",
}.get(part_type, "application/octet-stream"),
)
def _extract_multimodal_source_fields(
part: dict[str, Any],
) -> tuple[str | None, str | None, str | None, str | None]:
"""Extract ``(url, data, binary_id, mime_type)`` from an AG-UI multimodal part.

Handles both the current AG-UI spec (``source.value`` for base64 payloads) and the
legacy ``source.data`` field for backward compatibility. Returned values are the
raw extracted strings (or ``None`` when absent); callers apply their own defaults.
"""
mime_type = cast(str | None, part.get("mimeType") or part.get("mime_type"))
url = cast(str | None, part.get("url") or part.get("uri"))
data = cast(str | None, part.get("data"))
binary_id = cast(str | None, part.get("id"))

source = part.get("source")
if isinstance(source, dict):
source_dict = cast(dict[str, Any], source)
source_type = str(source_dict.get("type", "")).lower()
Expand All @@ -294,14 +288,31 @@ def _parse_multimodal_media_part(part: dict[str, Any]) -> Content | None:
if source_type in {"url", "uri"}:
url = cast(str | None, source_dict.get("url") or source_dict.get("uri"))
elif source_type in {"base64", "data", "binary"}:
data = cast(str | None, source_dict.get("data"))
data = cast(str | None, source_dict.get("value") or source_dict.get("data"))
elif source_type in {"id", "file"}:
binary_id = cast(str | None, source_dict.get("id"))
else:
url = cast(str | None, source_dict.get("url") or source_dict.get("uri") or url)
data = cast(str | None, source_dict.get("data") or data)
data = cast(str | None, source_dict.get("value") or source_dict.get("data") or data)
binary_id = cast(str | None, source_dict.get("id") or binary_id)

return url, data, binary_id, mime_type


def _parse_multimodal_media_part(part: dict[str, Any]) -> Content | None:
"""Convert a multimodal media part into Agent Framework content."""
part_type = str(part.get("type", "")).lower()
url, data, binary_id, mime_type = _extract_multimodal_source_fields(part)

if not mime_type:
mime_type = {
"image": "image/*",
"audio": "audio/*",
"video": "video/*",
"document": "application/octet-stream",
"binary": "application/octet-stream",
}.get(part_type, "application/octet-stream")

if isinstance(url, str) and url:
return Content.from_uri(uri=url, media_type=mime_type)

Expand Down Expand Up @@ -389,30 +400,7 @@ def _normalize_snapshot_content(content: Any) -> Any:
def _legacy_binary_part(part: dict[str, Any]) -> dict[str, Any]:
"""Convert draft/legacy multimodal parts to AG-UI snapshot binary shape."""
normalized: dict[str, Any] = {"type": "binary"}

mime_type = cast(str | None, part.get("mimeType") or part.get("mime_type"))
url = cast(str | None, part.get("url") or part.get("uri"))
data = cast(str | None, part.get("data"))
binary_id = cast(str | None, part.get("id"))

source = part.get("source")
if isinstance(source, dict):
source_part = cast(dict[str, Any], source)
source_mime = source_part.get("mimeType") or source_part.get("mime_type")
if isinstance(source_mime, str) and source_mime:
mime_type = source_mime

source_type = str(source_part.get("type", "")).lower()
if source_type in {"url", "uri"}:
url = cast(str | None, source_part.get("url") or source_part.get("uri"))
elif source_type in {"base64", "data", "binary"}:
data = cast(str | None, source_part.get("data"))
elif source_type in {"id", "file"}:
binary_id = cast(str | None, source_part.get("id"))
else:
url = cast(str | None, source_part.get("url") or source_part.get("uri") or url)
data = cast(str | None, source_part.get("data") or data)
binary_id = cast(str | None, source_part.get("id") or binary_id)
url, data, binary_id, mime_type = _extract_multimodal_source_fields(part)

if isinstance(mime_type, str) and mime_type:
normalized["mimeType"] = mime_type
Expand Down
4 changes: 2 additions & 2 deletions python/packages/ag-ui/agent_framework_ag_ui/_run_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,7 @@ def _emit_text_reasoning(content: Content, flow: FlowState | None = None) -> lis
events.extend(_close_reasoning_block(flow))
# Open new reasoning block.
events.append(ReasoningStartEvent(message_id=message_id))
events.append(ReasoningMessageStartEvent(message_id=message_id, role="assistant"))
events.append(ReasoningMessageStartEvent(message_id=message_id, role="reasoning"))
flow.reasoning_message_id = message_id

if text:
Expand All @@ -613,7 +613,7 @@ def _emit_text_reasoning(content: Content, flow: FlowState | None = None) -> lis
else:
# No flow -- backward-compatible full sequence per call.
events.append(ReasoningStartEvent(message_id=message_id))
events.append(ReasoningMessageStartEvent(message_id=message_id, role="assistant"))
events.append(ReasoningMessageStartEvent(message_id=message_id, role="reasoning"))

if text:
events.append(ReasoningMessageContentEvent(message_id=message_id, delta=text))
Expand Down
2 changes: 1 addition & 1 deletion python/packages/ag-ui/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ classifiers = [
]
dependencies = [
"agent-framework-core>=1.1.1,<2",
"ag-ui-protocol==0.1.13",
"ag-ui-protocol>=0.1.16,<0.2",
"fastapi>=0.115.0,<0.133.1",
"uvicorn[standard]>=0.30.0,<0.42.0"
]
Expand Down
135 changes: 135 additions & 0 deletions python/packages/ag-ui/tests/ag_ui/test_message_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,77 @@ def test_agui_snapshot_format_preserves_multimodal_content():
assert content_parts[1]["url"] == "https://example.com/image.png"


def test_agui_snapshot_format_reads_base64_value_field():
    """A base64 source that carries its payload under 'value' becomes a binary snapshot part."""
    encoded = base64.b64encode(b"abc").decode("utf-8")
    image_part = {
        "type": "image",
        "source": {"type": "base64", "value": encoded, "mimeType": "image/png"},
    }
    user_message = {"role": "user", "content": [image_part]}

    snapshot = agui_messages_to_snapshot_format([user_message])

    first_part = snapshot[0]["content"][0]
    # The payload supplied as ``source.value`` must surface as ``data`` on the binary part.
    assert first_part["type"] == "binary"
    assert first_part["mimeType"] == "image/png"
    assert first_part["data"] == encoded


def test_agui_snapshot_format_base64_value_preferred_over_data():
    """If a base64 source sets both 'value' and 'data', the snapshot keeps the 'value' payload."""
    spec_payload = base64.b64encode(b"new-spec").decode("utf-8")
    legacy_payload = base64.b64encode(b"legacy").decode("utf-8")
    source = {
        "type": "base64",
        "value": spec_payload,
        "data": legacy_payload,
        "mimeType": "image/png",
    }
    user_message = {"role": "user", "content": [{"type": "image", "source": source}]}

    snapshot = agui_messages_to_snapshot_format([user_message])

    first_part = snapshot[0]["content"][0]
    assert first_part["data"] == spec_payload


def test_agui_snapshot_format_base64_data_field_backward_compat():
    """With no 'value' present, the legacy 'data' field of a base64 source is still honoured."""
    encoded = base64.b64encode(b"legacy").decode("utf-8")
    image_part = {
        "type": "image",
        "source": {"type": "base64", "data": encoded, "mimeType": "image/png"},
    }
    user_message = {"role": "user", "content": [image_part]}

    snapshot = agui_messages_to_snapshot_format([user_message])

    first_part = snapshot[0]["content"][0]
    assert first_part["data"] == encoded


def test_agui_with_tool_calls_to_agent_framework():
"""Assistant message with tool_calls is converted to FunctionCallContent."""
agui_msg = {
Expand Down Expand Up @@ -1760,3 +1831,67 @@ def test_multi_turn_with_reasoning_in_prior_snapshot(self):
assert "First answer" in texts
assert "Follow-up question" in texts
assert "Prior reasoning" not in texts


def test_parse_multimodal_media_part_base64_value_field():
    """A source of type 'base64' is parsed from its 'value' field (the AG-UI spec field)."""
    from agent_framework_ag_ui._message_adapters import _parse_multimodal_media_part

    encoded = "aGVsbG8="
    part = {
        "type": "image",
        "source": {"type": "base64", "value": encoded, "mimeType": "image/png"},
    }

    parsed = _parse_multimodal_media_part(part)

    assert parsed is not None
    # The base64 payload must be embedded in the resulting content URI.
    assert encoded in parsed.uri


def test_parse_multimodal_media_part_data_source_value_field():
    """A source of type 'data' is parsed from its 'value' field (the AG-UI spec field)."""
    from agent_framework_ag_ui._message_adapters import _parse_multimodal_media_part

    encoded = "aGVsbG8="
    part = {
        "type": "image",
        "source": {"type": "data", "value": encoded, "mimeType": "image/png"},
    }

    parsed = _parse_multimodal_media_part(part)

    assert parsed is not None
    assert encoded in parsed.uri


def test_parse_multimodal_media_part_base64_data_field_backward_compat():
    """A source of type 'base64' that only sets the deprecated 'data' field is still parsed."""
    from agent_framework_ag_ui._message_adapters import _parse_multimodal_media_part

    encoded = "aGVsbG8="
    part = {
        "type": "image",
        "source": {"type": "base64", "data": encoded, "mimeType": "image/png"},
    }

    parsed = _parse_multimodal_media_part(part)

    assert parsed is not None
    assert encoded in parsed.uri


def test_parse_multimodal_media_part_value_preferred_over_data():
    """If the source sets both 'value' and 'data', the parsed content uses 'value'."""
    from agent_framework_ag_ui._message_adapters import _parse_multimodal_media_part

    source = {
        "type": "base64",
        "value": "dmFsdWU=",
        "data": "ZGF0YQ==",
        "mimeType": "image/png",
    }

    parsed = _parse_multimodal_media_part({"type": "image", "source": source})

    assert parsed is not None
    # "dmFsdWU=" is the base64 encoding of "value"; it is the payload expected in the URI.
    assert "dmFsdWU=" in parsed.uri


def test_parse_multimodal_media_part_unknown_source_value_fallback():
    """A source with an unrecognised type still yields content built from its 'value' field."""
    from agent_framework_ag_ui._message_adapters import _parse_multimodal_media_part

    encoded = "aGVsbG8="
    part = {
        "type": "image",
        "source": {"type": "custom", "value": encoded, "mimeType": "image/png"},
    }

    parsed = _parse_multimodal_media_part(part)

    assert parsed is not None
    assert encoded in parsed.uri
33 changes: 32 additions & 1 deletion python/packages/ag-ui/tests/ag_ui/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -1244,7 +1244,7 @@ def test_produces_reasoning_events(self):
assert events[0].message_id == "reason_1"
assert isinstance(events[1], ReasoningMessageStartEvent)
assert events[1].message_id == "reason_1"
assert events[1].role == "assistant"
assert events[1].role == "reasoning"
assert isinstance(events[2], ReasoningMessageContentEvent)
assert events[2].message_id == "reason_1"
assert events[2].delta == "The user is asking about weather, so I should call the weather tool."
Expand Down Expand Up @@ -1642,6 +1642,37 @@ def test_reasoning_distinct_ids_close_previous_block(self):
assert close[0].message_id == "block2"


class TestReasoningEventRole:
    """Checks that reasoning message start events carry role='reasoning' per the AG-UI spec."""

    def test_reasoning_role_without_flow(self):
        """Called without a flow, one ReasoningMessageStartEvent with role='reasoning' is emitted."""
        reasoning = Content.from_text_reasoning(
            id="reason_role_1",
            text="Thinking about the question.",
        )

        emitted = _emit_text_reasoning(reasoning)

        start_events = [event for event in emitted if isinstance(event, ReasoningMessageStartEvent)]
        assert len(start_events) == 1
        assert start_events[0].role == "reasoning"

    def test_reasoning_role_with_flow(self):
        """Called with a FlowState, one ReasoningMessageStartEvent with role='reasoning' is emitted."""
        flow_state = FlowState()
        reasoning = Content.from_text_reasoning(
            id="reason_role_2",
            text="Reasoning in streaming mode.",
        )

        emitted = _emit_text_reasoning(reasoning, flow_state)

        start_events = [event for event in emitted if isinstance(event, ReasoningMessageStartEvent)]
        assert len(start_events) == 1
        assert start_events[0].role == "reasoning"


async def test_session_id_matches_thread_id():
"""Session created by run_agent_stream uses the client thread_id as session_id."""
from conftest import StubAgent
Expand Down
8 changes: 4 additions & 4 deletions python/uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading