diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_agent.py b/python/packages/ag-ui/agent_framework_ag_ui/_agent.py index a5fcb54067e..ecde5a67e17 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_agent.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_agent.py @@ -9,7 +9,7 @@ from ag_ui.core import BaseEvent from agent_framework import SupportsAgentRun -from ._agent_run import run_agent_stream +from ._agent_run import PendingApprovalEntry, run_agent_stream class AgentConfig: @@ -107,7 +107,7 @@ def __init__( # Populated when approval requests are emitted; consumed when responses arrive. # Prevents bypass, function name spoofing, and replay attacks. # Bounded to prevent unbounded growth from abandoned approval requests. - self._pending_approvals: OrderedDict[str, str] = OrderedDict() + self._pending_approvals: OrderedDict[str, PendingApprovalEntry] = OrderedDict() self._pending_approvals_max_size: int = 10_000 async def run( diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_agent_run.py b/python/packages/ag-ui/agent_framework_ag_ui/_agent_run.py index 75e8e242dc6..38578f1bf2b 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_agent_run.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_agent_run.py @@ -8,7 +8,7 @@ import logging import uuid from collections.abc import AsyncIterable, Awaitable -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, TypedDict, cast from ag_ui.core import ( BaseEvent, @@ -56,6 +56,7 @@ _stringify_tool_result, # type: ignore ) from ._utils import ( + canonical_function_arguments, convert_agui_tools_to_agent_framework, generate_event_id, get_conversation_id_from_update, @@ -407,7 +408,33 @@ def _make_approval_tool_result_events(resolved_approval_results: list[Content]) return events -def _evict_oldest_approvals(registry: dict[str, str], max_size: int = 10_000) -> None: +class _PendingApproval(TypedDict): + """Pending approval details for a requested function call.""" + + name: str + arguments: str | None + + +PendingApprovalEntry = _PendingApproval | str + + +def _make_pending_approval_entry(name: str, arguments: str | None) -> _PendingApproval: + return {"name": name, "arguments": arguments} + + +def _pending_approval_name(entry: PendingApprovalEntry) -> str | None: + if isinstance(entry, str): + return entry + return entry["name"] + + +def _pending_approval_arguments(entry: PendingApprovalEntry) -> str | None: + if isinstance(entry, str): + return None + return entry["arguments"] + + +def _evict_oldest_approvals(registry: dict[str, PendingApprovalEntry], max_size: int = 10_000) -> None: """Evict the oldest entries from the pending-approvals registry (LRU). Only effective when *registry* is an ``OrderedDict``; plain dicts are @@ -427,7 +454,7 @@ async def _resolve_approval_responses( tools: list[Any], agent: SupportsAgentRun, run_kwargs: dict[str, Any], - pending_approvals: dict[str, str] | None = None, + pending_approvals: dict[str, PendingApprovalEntry] | None = None, thread_id: str = "", ) -> list[Content]: """Execute approved function calls and replace approval content with results. @@ -480,7 +507,8 @@ async def _resolve_approval_responses( invalid_ids.add(resp_id) continue - pending_name = pending_approvals[registry_key] + pending_entry = pending_approvals[registry_key] + pending_name = _pending_approval_name(pending_entry) if resp_name != pending_name: logger.warning( "Rejected approval response id=%s: function name mismatch (response=%s, pending=%s)", @@ -491,6 +519,16 @@ async def _resolve_approval_responses( invalid_ids.add(resp_id) continue + pending_arguments = _pending_approval_arguments(pending_entry) + response_arguments = canonical_function_arguments(resp.function_call) + if pending_arguments is not None and response_arguments != pending_arguments: + logger.warning( + "Rejected approval response id=%s: function arguments mismatch", + resp_id, + ) + invalid_ids.add(resp_id) + continue + # Valid — consume entry to prevent replay del pending_approvals[registry_key] if resp.approved: @@ -714,7 +752,7 @@ async def run_agent_stream( input_data: dict[str, Any], agent: SupportsAgentRun, config: AgentConfig, - pending_approvals: dict[str, str] | None = None, + pending_approvals: dict[str, PendingApprovalEntry] | None = None, ) -> AsyncGenerator[BaseEvent]: """Run agent and yield AG-UI events. @@ -917,7 +955,10 @@ async def run_agent_stream( # Register pending approval requests so we can validate responses later if content_type == "function_approval_request" and pending_approvals is not None: if content.id and content.function_call and content.function_call.name: - pending_approvals[f"{thread_id}:{content.id}"] = content.function_call.name + pending_approvals[f"{thread_id}:{content.id}"] = _make_pending_approval_entry( + content.function_call.name, + canonical_function_arguments(content.function_call), + ) # Evict oldest entries if the registry exceeds a safe bound (LRU) _evict_oldest_approvals(pending_approvals, max_size=10_000) else: diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_utils.py b/python/packages/ag-ui/agent_framework_ag_ui/_utils.py index c68301f7d29..db98e6bfc3f 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_utils.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_utils.py @@ -56,6 +56,22 @@ def safe_json_parse(value: Any) -> dict[str, Any] | None: return None +def canonical_function_arguments(function_call: Any) -> str | None: + """Return a stable representation of function-call arguments.""" + if function_call is None: + return None + + try: + parsed_arguments = function_call.parse_arguments() + except Exception: + parsed_arguments = getattr(function_call, "arguments", None) + + if parsed_arguments is None: + parsed_arguments = {} + + return json.dumps(make_json_safe(parsed_arguments), sort_keys=True, separators=(",", ":")) + + def get_role_value(message: Any) -> str: """Extract role string from a message object. diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_workflow_run.py b/python/packages/ag-ui/agent_framework_ag_ui/_workflow_run.py index 211657e6885..fab6bb210cf 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_workflow_run.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_workflow_run.py @@ -35,7 +35,7 @@ _extract_resume_payload, _normalize_resume_interrupts, ) -from ._utils import generate_event_id, make_json_safe +from ._utils import canonical_function_arguments, generate_event_id, make_json_safe logger = logging.getLogger(__name__) @@ -324,6 +324,29 @@ def _coerce_response_for_request(request_event: Any, value: Any) -> Any | None: return candidate +def _approval_response_matches_request(request_id: str, request_event: Any, response: Any) -> bool: + """Check whether an approval response matches the pending approval request.""" + request_data = getattr(request_event, "data", None) + if not isinstance(request_data, Content) or request_data.type != "function_approval_request": + return True + + if not isinstance(response, Content) or response.type != "function_approval_response": + return False + + if str(getattr(response, "id", "")) != request_id: + return False + + request_call = getattr(request_data, "function_call", None) + response_call = getattr(response, "function_call", None) + if request_call is None or response_call is None: + return False + + if getattr(response_call, "name", None) != getattr(request_call, "name", None): + return False + + return canonical_function_arguments(response_call) == canonical_function_arguments(request_call) + + def _single_pending_response_from_value(pending_events: dict[str, Any], value: Any) -> dict[str, Any]: """Map a scalar resume payload to the single pending request (if unambiguous).""" if value is None or len(pending_events) != 1: @@ -343,6 +366,13 @@ def _single_pending_response_from_value(pending_events: dict[str, Any], value: A ) return {} + if not _approval_response_matches_request(str(request_id), request_event, coerced_value): + logger.info( + "Ignoring pending request response for request_id=%s: approval response does not match pending request", + request_id, + ) + return {} + return {str(request_id): coerced_value} @@ -372,6 +402,12 @@ def _coerce_responses_for_pending_requests( _response_type_name(request_event), ) continue + if not _approval_response_matches_request(request_key, request_event, coerced_value): + logger.info( + "Ignoring resume response for request_id=%s: approval response does not match pending request", + request_key, + ) + continue normalized[request_key] = coerced_value return normalized diff --git a/python/packages/ag-ui/tests/ag_ui/test_agent_wrapper_comprehensive.py b/python/packages/ag-ui/tests/ag_ui/test_agent_wrapper_comprehensive.py index 5ea284c68de..b4b8fa04a70 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_agent_wrapper_comprehensive.py +++ b/python/packages/ag-ui/tests/ag_ui/test_agent_wrapper_comprehensive.py @@ -1407,6 +1407,92 @@ async def stream_fn( assert False, "Fabricated rejection response leaked as function_result into LLM messages" +async def test_approval_argument_mismatch_is_blocked(streaming_chat_client_stub): + """An approval response must not execute changed arguments for the pending call.""" + from agent_framework import tool + from agent_framework.ag_ui import AgentFrameworkAgent + + executed_args: list[dict[str, Any]] = [] + + @tool( + name="update_record", + description="Update a record", + approval_mode="always_require", + ) + def update_record(record_id: str, value: str) -> str: + executed_args.append({"record_id": record_id, "value": value}) + return f"updated {record_id} to {value}" + + async def stream_fn_approval( + messages: MutableSequence[Message], options: ChatOptions, **kwargs: Any + ) -> AsyncIterator[ChatResponseUpdate]: + yield ChatResponseUpdate( + contents=[ + Content.from_function_call( + name="update_record", + call_id="call_update_001", + arguments={"record_id": "alpha", "value": "approved"}, + ) + ] + ) + + wrapper = AgentFrameworkAgent( + agent=Agent( + client=streaming_chat_client_stub(stream_fn_approval), + name="test_agent", + instructions="Test", + tools=[update_record], + ) + ) + thread_id = "thread-argument-mismatch-test" + + events1: list[Any] = [] + async for event in wrapper.run({"thread_id": thread_id, "messages": [{"role": "user", "content": "update"}]}): + events1.append(event) + + assert any("call_update_001" in k for k in wrapper._pending_approvals) + + async def stream_fn_post( + messages: MutableSequence[Message], options: ChatOptions, **kwargs: Any + ) -> AsyncIterator[ChatResponseUpdate]: + yield ChatResponseUpdate(contents=[Content.from_text(text="Done")]) + + wrapper.agent = Agent( + client=streaming_chat_client_stub(stream_fn_post), + name="test_agent", + instructions="Test", + tools=[update_record], + ) + + turn2_input: dict[str, Any] = { + "thread_id": thread_id, + "messages": [ + { + "role": "user", + "content": "approve", + "function_approvals": [ + { + "id": "call_update_001", + "call_id": "call_update_001", + "name": "update_record", + "approved": True, + "arguments": {"record_id": "beta", "value": "changed"}, + } + ], + }, + ], + } + + events2: list[Any] = [] + async for event in wrapper.run(turn2_input): + events2.append(event) + + assert executed_args == [] + assert any("call_update_001" in k for k in wrapper._pending_approvals), ( + "Pending approval should be preserved after argument mismatch for legitimate retry" + ) + + async def test_state_update_end_to_end_via_real_tool_invocation(streaming_chat_client_stub): """End-to-end coverage for issue #3167: a real ``@tool`` returning ``state_update`` must emit a deterministic STATE_SNAPSHOT through the full pipeline. diff --git a/python/packages/ag-ui/tests/ag_ui/test_workflow_run.py b/python/packages/ag-ui/tests/ag_ui/test_workflow_run.py index a52cc4dd2cd..235a16c6e1c 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_workflow_run.py +++ b/python/packages/ag-ui/tests/ag_ui/test_workflow_run.py @@ -1352,6 +1352,70 @@ async def handle_approval(self, original_request: Content, response: Content, ct assert not resumed_finished.get("interrupt") +async def test_workflow_run_approval_argument_mismatch_keeps_interrupt_pending() -> None: + """Workflow approval responses must not resume with changed function arguments.""" + + handled_responses: list[dict[str, Any]] = [] + + class ApprovalExecutor(Executor): + def __init__(self) -> None: + super().__init__(id="approval_executor") + + @handler + async def start(self, message: Any, ctx: WorkflowContext) -> None: + del message + function_call = Content.from_function_call( + call_id="refund-call", + name="submit_refund", + arguments={"order_id": "12345", "amount": "$89.99"}, + ) + approval_request = Content.from_function_approval_request(id="approval-1", function_call=function_call) + await ctx.request_info(approval_request, Content, request_id="approval-1") + + @response_handler + async def handle_approval(self, original_request: Content, response: Content, ctx: WorkflowContext) -> None: + del original_request + if response.function_call is not None: + handled_responses.append(response.function_call.parse_arguments() or {}) + await ctx.yield_output("handled") + + workflow = WorkflowBuilder(start_executor=ApprovalExecutor()).build() + first_events = [ + event async for event in run_workflow_stream({"messages": [{"role": "user", "content": "go"}]}, workflow) + ] + first_finished = [event for event in first_events if event.type == "RUN_FINISHED"][0].model_dump() + interrupt_payload = cast(list[dict[str, Any]], first_finished.get("interrupt")) + assert isinstance(interrupt_payload, list) and len(interrupt_payload) == 1 + + resumed_events = [ + event + async for event in run_workflow_stream( + { + "messages": [ + { + "role": "user", + "content": "", + "function_approvals": [ + { + "approved": True, + "id": "approval-1", + "call_id": "refund-call", + "name": "submit_refund", + "arguments": {"order_id": "99999", "amount": "$1000.00"}, + } + ], + } + ], + }, + workflow, + ) + ] + + assert handled_responses == [] + resumed_finished = [event for event in resumed_events if event.type == "RUN_FINISHED"][0].model_dump() + assert resumed_finished.get("interrupt") + + async def test_workflow_run_approval_via_messages_denied() -> None: """Denied approval response sent via messages (function_approvals) should satisfy the pending request."""