Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/packages/ag-ui/agent_framework_ag_ui/_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from ag_ui.core import BaseEvent
from agent_framework import SupportsAgentRun

from ._agent_run import run_agent_stream
from ._agent_run import PendingApprovalEntry, run_agent_stream


class AgentConfig:
Expand Down Expand Up @@ -107,7 +107,7 @@ def __init__(
# Populated when approval requests are emitted; consumed when responses arrive.
# Prevents bypass, function name spoofing, and replay attacks.
# Bounded to prevent unbounded growth from abandoned approval requests.
self._pending_approvals: OrderedDict[str, str] = OrderedDict()
self._pending_approvals: OrderedDict[str, PendingApprovalEntry] = OrderedDict()
self._pending_approvals_max_size: int = 10_000

async def run(
Expand Down
53 changes: 47 additions & 6 deletions python/packages/ag-ui/agent_framework_ag_ui/_agent_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import logging
import uuid
from collections.abc import AsyncIterable, Awaitable
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any, TypedDict, cast

from ag_ui.core import (
BaseEvent,
Expand Down Expand Up @@ -56,6 +56,7 @@
_stringify_tool_result, # type: ignore
)
from ._utils import (
canonical_function_arguments,
convert_agui_tools_to_agent_framework,
generate_event_id,
get_conversation_id_from_update,
Expand Down Expand Up @@ -407,7 +408,33 @@ def _make_approval_tool_result_events(resolved_approval_results: list[Content])
return events


def _evict_oldest_approvals(registry: dict[str, str], max_size: int = 10_000) -> None:
class _PendingApproval(TypedDict):
"""Pending approval details for a requested function call."""

name: str
arguments: str | None


PendingApprovalEntry = _PendingApproval | str


def _make_pending_approval_entry(name: str, arguments: str | None) -> _PendingApproval:
return {"name": name, "arguments": arguments}


def _pending_approval_name(entry: PendingApprovalEntry) -> str | None:
if isinstance(entry, str):
return entry
return entry["name"]


def _pending_approval_arguments(entry: PendingApprovalEntry) -> str | None:
if isinstance(entry, str):
return None
return entry["arguments"]


def _evict_oldest_approvals(registry: dict[str, PendingApprovalEntry], max_size: int = 10_000) -> None:
"""Evict the oldest entries from the pending-approvals registry (LRU).

Only effective when *registry* is an ``OrderedDict``; plain dicts are
Expand All @@ -427,7 +454,7 @@ async def _resolve_approval_responses(
tools: list[Any],
agent: SupportsAgentRun,
run_kwargs: dict[str, Any],
pending_approvals: dict[str, str] | None = None,
pending_approvals: dict[str, PendingApprovalEntry] | None = None,
thread_id: str = "",
) -> list[Content]:
"""Execute approved function calls and replace approval content with results.
Expand Down Expand Up @@ -480,7 +507,8 @@ async def _resolve_approval_responses(
invalid_ids.add(resp_id)
continue

pending_name = pending_approvals[registry_key]
pending_entry = pending_approvals[registry_key]
pending_name = _pending_approval_name(pending_entry)
if resp_name != pending_name:
logger.warning(
"Rejected approval response id=%s: function name mismatch (response=%s, pending=%s)",
Expand All @@ -491,6 +519,16 @@ async def _resolve_approval_responses(
invalid_ids.add(resp_id)
continue

pending_arguments = _pending_approval_arguments(pending_entry)
response_arguments = canonical_function_arguments(resp.function_call)
if pending_arguments is not None and response_arguments != pending_arguments:
logger.warning(
"Rejected approval response id=%s: function arguments mismatch",
resp_id,
)
invalid_ids.add(resp_id)
continue

# Valid — consume entry to prevent replay
del pending_approvals[registry_key]
if resp.approved:
Expand Down Expand Up @@ -714,7 +752,7 @@ async def run_agent_stream(
input_data: dict[str, Any],
agent: SupportsAgentRun,
config: AgentConfig,
pending_approvals: dict[str, str] | None = None,
pending_approvals: dict[str, PendingApprovalEntry] | None = None,
) -> AsyncGenerator[BaseEvent]:
"""Run agent and yield AG-UI events.

Expand Down Expand Up @@ -917,7 +955,10 @@ async def run_agent_stream(
# Register pending approval requests so we can validate responses later
if content_type == "function_approval_request" and pending_approvals is not None:
if content.id and content.function_call and content.function_call.name:
pending_approvals[f"{thread_id}:{content.id}"] = content.function_call.name
pending_approvals[f"{thread_id}:{content.id}"] = _make_pending_approval_entry(
content.function_call.name,
canonical_function_arguments(content.function_call),
)
# Evict oldest entries if the registry exceeds a safe bound (LRU)
_evict_oldest_approvals(pending_approvals, max_size=10_000)
else:
Expand Down
16 changes: 16 additions & 0 deletions python/packages/ag-ui/agent_framework_ag_ui/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,22 @@ def safe_json_parse(value: Any) -> dict[str, Any] | None:
return None


def canonical_function_arguments(function_call: Any) -> str | None:
"""Return a stable representation of function-call arguments."""
if function_call is None:
return None

try:
parsed_arguments = function_call.parse_arguments()
except Exception:
parsed_arguments = getattr(function_call, "arguments", None)

if parsed_arguments is None:
parsed_arguments = {}

return json.dumps(make_json_safe(parsed_arguments), sort_keys=True, separators=(",", ":"))


def get_role_value(message: Any) -> str:
"""Extract role string from a message object.

Expand Down
38 changes: 37 additions & 1 deletion python/packages/ag-ui/agent_framework_ag_ui/_workflow_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
_extract_resume_payload,
_normalize_resume_interrupts,
)
from ._utils import generate_event_id, make_json_safe
from ._utils import canonical_function_arguments, generate_event_id, make_json_safe

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -324,6 +324,29 @@ def _coerce_response_for_request(request_event: Any, value: Any) -> Any | None:
return candidate


def _approval_response_matches_request(request_id: str, request_event: Any, response: Any) -> bool:
"""Check whether an approval response matches the pending approval request."""
request_data = getattr(request_event, "data", None)
if not isinstance(request_data, Content) or request_data.type != "function_approval_request":
return True

if not isinstance(response, Content) or response.type != "function_approval_response":
return False

if str(getattr(response, "id", "")) != request_id:
return False

request_call = getattr(request_data, "function_call", None)
response_call = getattr(response, "function_call", None)
if request_call is None or response_call is None:
return False

if getattr(response_call, "name", None) != getattr(request_call, "name", None):
return False

return canonical_function_arguments(response_call) == canonical_function_arguments(request_call)


def _single_pending_response_from_value(pending_events: dict[str, Any], value: Any) -> dict[str, Any]:
"""Map a scalar resume payload to the single pending request (if unambiguous)."""
if value is None or len(pending_events) != 1:
Expand All @@ -343,6 +366,13 @@ def _single_pending_response_from_value(pending_events: dict[str, Any], value: A
)
return {}

if not _approval_response_matches_request(str(request_id), request_event, coerced_value):
logger.info(
"Ignoring pending request response for request_id=%s: approval response does not match pending request",
request_id,
)
return {}

return {str(request_id): coerced_value}


Expand Down Expand Up @@ -372,6 +402,12 @@ def _coerce_responses_for_pending_requests(
_response_type_name(request_event),
)
continue
if not _approval_response_matches_request(request_key, request_event, coerced_value):
logger.info(
"Ignoring resume response for request_id=%s: approval response does not match pending request",
request_key,
)
continue
normalized[request_key] = coerced_value
return normalized

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1407,6 +1407,92 @@ async def stream_fn(
assert False, "Fabricated rejection response leaked as function_result into LLM messages"


async def test_approval_argument_mismatch_is_blocked(streaming_chat_client_stub):
"""An approval response must not execute changed arguments for the pending call."""
from agent_framework import tool
from agent_framework.ag_ui import AgentFrameworkAgent

executed_args: list[dict[str, Any]] = []

@tool(
name="update_record",
description="Update a record",
approval_mode="always_require",
)
def update_record(record_id: str, value: str) -> str:
executed_args.append({"record_id": record_id, "value": value})
return f"updated {record_id} to {value}"

async def stream_fn_approval(
messages: MutableSequence[Message], options: ChatOptions, **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
yield ChatResponseUpdate(
contents=[
Content.from_function_call(
name="update_record",
call_id="call_update_001",
arguments={"record_id": "alpha", "value": "approved"},
)
]
)

wrapper = AgentFrameworkAgent(
agent=Agent(
client=streaming_chat_client_stub(stream_fn_approval),
name="test_agent",
instructions="Test",
tools=[update_record],
)
)
thread_id = "thread-argument-mismatch-test"

events1: list[Any] = []
async for event in wrapper.run({"thread_id": thread_id, "messages": [{"role": "user", "content": "update"}]}):
events1.append(event)

assert any("call_update_001" in k for k in wrapper._pending_approvals)

async def stream_fn_post(
messages: MutableSequence[Message], options: ChatOptions, **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
yield ChatResponseUpdate(contents=[Content.from_text(text="Done")])

wrapper.agent = Agent(
client=streaming_chat_client_stub(stream_fn_post),
name="test_agent",
instructions="Test",
tools=[update_record],
)

turn2_input: dict[str, Any] = {
"thread_id": thread_id,
"messages": [
{
"role": "user",
"content": "approve",
"function_approvals": [
{
"id": "call_update_001",
"call_id": "call_update_001",
"name": "update_record",
"approved": True,
"arguments": {"record_id": "beta", "value": "changed"},
}
],
},
],
}

events2: list[Any] = []
async for event in wrapper.run(turn2_input):
events2.append(event)

assert executed_args == []
assert any("call_update_001" in k for k in wrapper._pending_approvals), (
"Pending approval should be preserved after argument mismatch for legitimate retry"
)


async def test_state_update_end_to_end_via_real_tool_invocation(streaming_chat_client_stub):
"""End-to-end coverage for issue #3167: a real ``@tool`` returning ``state_update`` must
emit a deterministic STATE_SNAPSHOT through the full pipeline.
Expand Down
64 changes: 64 additions & 0 deletions python/packages/ag-ui/tests/ag_ui/test_workflow_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -1352,6 +1352,70 @@ async def handle_approval(self, original_request: Content, response: Content, ct
assert not resumed_finished.get("interrupt")


async def test_workflow_run_approval_argument_mismatch_keeps_interrupt_pending() -> None:
"""Workflow approval responses must not resume with changed function arguments."""

handled_responses: list[dict[str, Any]] = []

class ApprovalExecutor(Executor):
def __init__(self) -> None:
super().__init__(id="approval_executor")

@handler
async def start(self, message: Any, ctx: WorkflowContext) -> None:
del message
function_call = Content.from_function_call(
call_id="refund-call",
name="submit_refund",
arguments={"order_id": "12345", "amount": "$89.99"},
)
approval_request = Content.from_function_approval_request(id="approval-1", function_call=function_call)
await ctx.request_info(approval_request, Content, request_id="approval-1")

@response_handler
async def handle_approval(self, original_request: Content, response: Content, ctx: WorkflowContext) -> None:
del original_request
if response.function_call is not None:
handled_responses.append(response.function_call.parse_arguments() or {})
await ctx.yield_output("handled")

workflow = WorkflowBuilder(start_executor=ApprovalExecutor()).build()
first_events = [
event async for event in run_workflow_stream({"messages": [{"role": "user", "content": "go"}]}, workflow)
]
first_finished = [event for event in first_events if event.type == "RUN_FINISHED"][0].model_dump()
interrupt_payload = cast(list[dict[str, Any]], first_finished.get("interrupt"))
assert isinstance(interrupt_payload, list) and len(interrupt_payload) == 1

resumed_events = [
event
async for event in run_workflow_stream(
{
"messages": [
{
"role": "user",
"content": "",
"function_approvals": [
{
"approved": True,
"id": "approval-1",
"call_id": "refund-call",
"name": "submit_refund",
"arguments": {"order_id": "99999", "amount": "$1000.00"},
}
],
}
],
},
workflow,
)
]

assert handled_responses == []
resumed_finished = [event for event in resumed_events if event.type == "RUN_FINISHED"][0].model_dump()
assert resumed_finished.get("interrupt")


async def test_workflow_run_approval_via_messages_denied() -> None:
"""Denied approval response sent via messages (function_approvals) should satisfy the pending request."""

Expand Down
Loading