From 275df266f3f8485542803b7bce49b8ef9e6306f0 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Mon, 8 Jun 2026 17:26:58 +0200 Subject: [PATCH 1/2] Filter MCP tool kwargs to declared params via allowlist Previously MCPTool combined framework runtime kwargs (from FunctionInvocationContext.kwargs) with the LLM-supplied arguments and stripped only a hardcoded denylist of known framework keys before forwarding to the MCP server. Any new framework-injected kwarg leaked to the server unless the denylist was updated. Switch to an allowlist built from each tool's declared parameters (inputSchema.properties). Only declared params are forwarded; everything else is stripped. Add an `additional_tool_argument_names` constructor argument so users can opt extra names back in, globally (Sequence[str]) and/or per remote tool name (Mapping with reserved "*" global key). The existing denylist is kept as a safety net for framework-named params a server declares in its schema; explicitly opted-in extras always win. The reserved _meta handling is unchanged. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- python/packages/core/AGENTS.md | 2 + python/packages/core/agent_framework/_mcp.py | 178 +++++++++++++---- python/packages/core/tests/core/test_mcp.py | 189 +++++++++++++++++++ 3 files changed, 333 insertions(+), 36 deletions(-) diff --git a/python/packages/core/AGENTS.md b/python/packages/core/AGENTS.md index aadda788b0..ca0c5843a3 100644 --- a/python/packages/core/AGENTS.md +++ b/python/packages/core/AGENTS.md @@ -80,6 +80,8 @@ agent_framework/ - **`MCPTool`** - Base wrapper that owns the MCP `ClientSession` and exposes the remote server's tools as `FunctionTool`s. - **`MCPStdioTool`** / **`MCPStreamableHTTPTool`** / **`MCPWebsocketTool`** - Transport-specific subclasses. +- **Argument allowlist (`_prepare_call_kwargs`)** - Before each `tools/call`, kwargs are filtered to an **allowlist** built from the tool's declared parameters (`inputSchema.properties`) plus any user-configured extras. Framework runtime kwargs injected through the function-invocation pipeline (e.g. `thread`, `conversation_id`, `chat_options`, `options`, `response_format`) are stripped by default rather than forwarded. A tool that declares no usable `properties` (including schemas with `additionalProperties: true`) forwards only the configured extras. The `_MCP_FRAMEWORK_DENYLIST` is a safety net for framework-named params a server *declares* in its schema (those are dropped); names explicitly opted in via `additional_tool_argument_names` always win. The reserved `_meta` key is extracted as MCP request metadata, never forwarded as an argument. +- **`additional_tool_argument_names`** (constructor arg on all `MCPTool` subclasses) - Opt extra argument names back into the allowlist. Accepts a `Sequence[str]` (applied to every tool) or a `Mapping[str, Sequence[str]]` keyed by **remote tool name**, where the reserved key `"*"` denotes global extras. It is configured only in user code at construction; there is **no per-call/runtime override**, so a model-issued tool call cannot change which names pass through. To use a server that accepts `additionalProperties: true`, list the extra names here and then either (1) manually extend that tool's `inputSchema` (via the `.functions` list after connecting) so the model is prompted to supply them, or (2) supply the values yourself via `function_invocation_kwargs`. If a name is supplied by both the model and `function_invocation_kwargs`, the model-supplied value wins. - **`MCPTaskOptions`** (experimental, `MCP_LONG_RUNNING_TASKS` feature, **frozen**) - Per-tool-instance options controlling the SEP-2663 long-running task lifecycle. When the server advertises a tool with `execution.taskSupport == "required"`, `MCPTool.call_tool` transparently routes through `call_tool_as_task`, which sends an augmented `tools/call`, polls `tasks/get` until terminal, and reinterprets `tasks/result` as a normal `CallToolResult`. Instances are immutable; replace via `MCPTool.task_options = MCPTaskOptions(...)`. Fields: - `default_ttl: timedelta | None` — forwarded to the server as `params.task.ttl` (milliseconds). When `None`, the server's default applies. - `cancel_remote_task_on_local_cancellation: bool = True` — only gates the `CancelledError` path. Abandonment paths (see below) always cancel. diff --git a/python/packages/core/agent_framework/_mcp.py b/python/packages/core/agent_framework/_mcp.py index 8c2bdaefac..ac4c1fb97d 100644 --- a/python/packages/core/agent_framework/_mcp.py +++ b/python/packages/core/agent_framework/_mcp.py @@ -70,6 +70,29 @@ class MCPSpecificApproval(TypedDict, total=False): _MCP_REMOTE_NAME_KEY = "_mcp_remote_name" _MCP_NORMALIZED_NAME_KEY = "_mcp_normalized_name" +# Reserved key in an ``additional_tool_argument_names`` mapping that applies its +# values to every tool on the server rather than a single named tool. +_MCP_GLOBAL_EXTRA_ARGS_KEY = "*" +# Framework kwargs that flow through the function-invocation pipeline (via +# ``FunctionInvocationContext.kwargs``) but must never be forwarded to an MCP +# server: they are internal objects that the MCP SDK cannot serialize. These are +# always dropped as a safety net, even when a tool oddly declares one of them. +# - chat_options/tools/tool_choice/session/thread: framework runtime objects. +# - conversation_id: internal tracking ID used by services like Azure AI. +# - options: metadata/store used by AG-UI for Azure AI client requirements. +# - response_format: a Pydantic model class for structured output (not serializable). +# - _meta: reserved key extracted separately as MCP request metadata. +_MCP_FRAMEWORK_DENYLIST: frozenset[str] = frozenset({ + "chat_options", + "tools", + "tool_choice", + "session", + "thread", + "conversation_id", + "options", + "response_format", + "_meta", +}) _mcp_call_headers: contextvars.ContextVar[dict[str, str]] = contextvars.ContextVar("_mcp_call_headers") MCP_DEFAULT_TIMEOUT = 30 MCP_DEFAULT_SSE_READ_TIMEOUT = 60 * 5 @@ -135,6 +158,31 @@ def _build_prefixed_mcp_name( return f"{normalized_prefix}_{trimmed_name}" if trimmed_name else normalized_prefix +def _normalize_additional_tool_argument_names( + additional_tool_argument_names: Sequence[str] | Mapping[str, Sequence[str]] | None, +) -> tuple[set[str], dict[str, set[str]]]: + """Split user-supplied extra argument names into global and per-tool sets. + + Accepts either a sequence (applied to every tool) or a mapping keyed by remote + tool name, where the reserved key ``"*"`` is treated as global. Returns a + ``(global_extras, per_tool_extras)`` tuple. + """ + if additional_tool_argument_names is None: + return set(), {} + if isinstance(additional_tool_argument_names, str): + return {additional_tool_argument_names}, {} + if isinstance(additional_tool_argument_names, Mapping): + global_extras: set[str] = set() + per_tool_extras: dict[str, set[str]] = {} + for tool_name, names in additional_tool_argument_names.items(): + if tool_name == _MCP_GLOBAL_EXTRA_ARGS_KEY: + global_extras.update(names) + else: + per_tool_extras[tool_name] = set(names) + return global_extras, per_tool_extras + return set(additional_tool_argument_names), {} + + def _inject_otel_into_mcp_meta(meta: dict[str, Any] | None = None) -> dict[str, Any] | None: """Inject OpenTelemetry trace context into MCP request _meta via the global propagator(s).""" carrier: dict[str, str] = {} @@ -294,6 +342,7 @@ def __init__( client: SupportsChatGetResponse | None = None, additional_properties: dict[str, Any] | None = None, task_options: MCPTaskOptions | None = None, + additional_tool_argument_names: Sequence[str] | Mapping[str, Sequence[str]] | None = None, ) -> None: """Initialize the MCP Tool base. @@ -328,6 +377,10 @@ def __init__( task_options: Options controlling how long-running MCP tasks are driven for tools that advertise ``execution.taskSupport == "required"``. When ``None``, the defaults from :class:`MCPTaskOptions` are used. + additional_tool_argument_names: Extra argument names to forward to the MCP server + in addition to each tool's declared parameters. A ``Sequence[str]`` applies to + every tool; a ``Mapping[str, Sequence[str]]`` is keyed by remote tool name with + ``"*"`` as a global key. See the transport subclasses for full details. """ self.name = name self.description = description or "" @@ -355,6 +408,10 @@ def __init__( self._functions: list[FunctionTool] = [] self._tool_call_meta_by_name: dict[str, dict[str, Any]] = {} self._tool_task_support_by_name: dict[str, str] = {} + self._tool_param_names_by_name: dict[str, set[str]] = {} + self._global_extra_arg_names, self._tool_extra_arg_names = _normalize_additional_tool_argument_names( + additional_tool_argument_names + ) self.is_connected: bool = False self._tools_loaded: bool = False self._prompts_loaded: bool = False @@ -1229,6 +1286,7 @@ async def _load_tools_locked(self) -> None: existing_names = {func.name for func in self._functions} tool_call_meta_by_name: dict[str, dict[str, Any]] = {} tool_task_support_by_name: dict[str, str] = {} + tool_param_names_by_name: dict[str, set[str]] = {} params: types.PaginatedRequestParams | None = None while True: @@ -1288,6 +1346,11 @@ async def _load_tools_locked(self) -> None: if input_schema.get("type") == "object" and "properties" not in input_schema: input_schema["properties"] = {} + schema_properties = input_schema.get("properties") + tool_param_names_by_name[tool.name] = ( + set(schema_properties) if isinstance(schema_properties, dict) else set() + ) + async def _call_tool_with_runtime_kwargs( ctx: FunctionInvocationContext, *, @@ -1320,6 +1383,7 @@ async def _call_tool_with_runtime_kwargs( self._tool_call_meta_by_name = tool_call_meta_by_name self._tool_task_support_by_name = tool_task_support_by_name + self._tool_param_names_by_name = tool_param_names_by_name async def _close_on_owner(self) -> None: # Cancel any pending reload tasks before tearing down the session. @@ -1530,10 +1594,14 @@ async def _call_tool_with_retries( raise ToolExecutionException(f"Failed to call tool '{tool_name}'.", inner_exception=ex) from ex raise ToolExecutionException(f"Failed to call tool '{tool_name}' after retries.") + def _resolved_extra_args(self, tool_name: str) -> set[str]: + """Return the user-configured extra argument names allowed for a tool.""" + return self._global_extra_arg_names | self._tool_extra_arg_names.get(tool_name, set()) + def _prepare_call_kwargs( self, tool_name: str, kwargs: dict[str, Any] ) -> tuple[dict[str, Any], dict[str, Any] | None]: - """Filter framework-only kwargs and build the merged MCP request metadata.""" + """Filter kwargs down to the tool's arguments and build the merged MCP request metadata.""" raw_user_meta: object | None = kwargs.get("_meta") user_meta: dict[str, Any] | None = None if raw_user_meta is not None and not isinstance(raw_user_meta, dict): @@ -1546,27 +1614,28 @@ def _prepare_call_kwargs( raise ToolExecutionException("MCP tool metadata provided via _meta must use string keys.") user_meta[key] = value - # Filter out framework kwargs that cannot be serialized by the MCP SDK. - # These are internal objects passed through the function invocation pipeline - # that should not be forwarded to external MCP servers. - # conversation_id is an internal tracking ID used by services like Azure AI. - # options contains metadata/store used by AG-UI for Azure AI client requirements. - # response_format is a Pydantic model class used for structured output (not serializable). + # Allowlist: forward only the tool's declared parameters (from inputSchema.properties) + # plus any user-configured extra argument names. Everything else - notably the + # framework runtime kwargs injected through the function-invocation pipeline - is + # stripped so it is never forwarded to the MCP server. Tools that declare no usable + # properties forward only the user-configured extras. + # + # The extra names come exclusively from additional_tool_argument_names, which is set in + # user code at construction time; there is no per-call override, so a model-issued tool + # call cannot change which names are allowed through. + # + # The framework denylist acts as a safety net for keys a server *declares* in its + # schema that collide with internal, non-serializable framework objects (e.g. a tool + # that declares a parameter literally named "thread"): such declared-but-denylisted + # keys are dropped. Names the user explicitly opts in via additional_tool_argument_names + # always win. The reserved _meta key is handled separately above and never forwarded as + # an argument. + declared = self._tool_param_names_by_name.get(tool_name, set()) + extras = self._resolved_extra_args(tool_name) filtered_kwargs = { k: v for k, v in kwargs.items() - if k - not in { - "chat_options", - "tools", - "tool_choice", - "session", - "thread", - "conversation_id", - "options", - "response_format", - "_meta", - } + if k != "_meta" and (k in extras or (k in declared and k not in _MCP_FRAMEWORK_DENYLIST)) } # Some MCP proxies require their tools/list metadata to be echoed on tools/call. @@ -1643,9 +1712,7 @@ async def call_tool_as_task(self, tool_name: str, **kwargs: Any) -> str | list[C return parser(fallback_result) if task_id is None: - raise ToolExecutionException( - f"MCP server did not return a task_id or fallback result for '{tool_name}'." - ) + raise ToolExecutionException(f"MCP server did not return a task_id or fallback result for '{tool_name}'.") # Track to completion: poll until terminal, then fetch payload. Never re-issue # tools/call past this point; reconnect-and-retry only against the same task_id. @@ -1765,9 +1832,7 @@ async def _poll_task_until_terminal(self, task_id: str) -> types.GetTaskResult: transient_codes: frozenset[int] = frozenset({int(httpx.codes.REQUEST_TIMEOUT)}) while True: - request = types.ClientRequest( - types.GetTaskRequest(params=types.GetTaskRequestParams(taskId=task_id)) - ) + request = types.ClientRequest(types.GetTaskRequest(params=types.GetTaskRequestParams(taskId=task_id))) try: # GetTaskResult.ttl is required-but-Optional in the SDK; coerce below. lenient = await self._send_with_one_reconnect( @@ -1775,9 +1840,7 @@ async def _poll_task_until_terminal(self, task_id: str) -> types.GetTaskResult: ) except McpError as ex: if ex.error.code in transient_codes: - logger.debug( - "Transient %s on tasks/get for '%s'; will retry.", ex.error.code, task_id - ) + logger.debug("Transient %s on tasks/get for '%s'; will retry.", ex.error.code, task_id) await asyncio.sleep(_MCP_TASK_MIN_POLL_INTERVAL.total_seconds()) continue # Hard server error mid-poll: task may still be running. @@ -1906,9 +1969,7 @@ async def _send_with_one_reconnect( if not self._is_connection_lost(ex): raise if attempt < _MCP_RECONNECT_ATTEMPTS - 1: - logger.info( - "MCP connection lost during %s; reconnecting (task_id=%s).", operation, task_id - ) + logger.info("MCP connection lost during %s; reconnecting (task_id=%s).", operation, task_id) try: await self.connect(reset=True) except Exception as reconn_ex: @@ -1967,9 +2028,7 @@ async def _try_cancel_task(self, task_id: str) -> None: """ from mcp import types - request = types.ClientRequest( - types.CancelTaskRequest(params=types.CancelTaskRequestParams(taskId=task_id)) - ) + request = types.ClientRequest(types.CancelTaskRequest(params=types.CancelTaskRequestParams(taskId=task_id))) try: await asyncio.wait_for( self.session.send_request(request, types.CancelTaskResult), # type: ignore[union-attr] @@ -1979,8 +2038,7 @@ async def _try_cancel_task(self, task_id: str) -> None: raise except asyncio.TimeoutError: logger.warning( - "Best-effort tasks/cancel for '%s' timed out after %.1fs; " - "remote task may still be running.", + "Best-effort tasks/cancel for '%s' timed out after %.1fs; remote task may still be running.", task_id, _MCP_TASK_CANCEL_TIMEOUT.total_seconds(), ) @@ -2153,6 +2211,7 @@ def __init__( client: SupportsChatGetResponse | None = None, additional_properties: dict[str, Any] | None = None, task_options: MCPTaskOptions | None = None, + additional_tool_argument_names: Sequence[str] | Mapping[str, Sequence[str]] | None = None, **kwargs: Any, ) -> None: """Initialize the MCP stdio tool. @@ -2199,6 +2258,20 @@ def __init__( client: The chat client to use for sampling. task_options: Options for tools that advertise ``execution.taskSupport == "required"``. See :class:`MCPTaskOptions`. + additional_tool_argument_names: Extra argument names to forward to the MCP server in + addition to each tool's declared parameters (from its ``inputSchema.properties``). + By default only declared parameters are sent; framework runtime kwargs injected + through the function-invocation pipeline are stripped. Use this to opt specific + keys back in. Accepts either a ``Sequence[str]`` applied to every tool, or a + ``Mapping[str, Sequence[str]]`` keyed by remote tool name where the reserved key + ``"*"`` applies to every tool. This is configured only here in user code; there is + no per-call override, so a model-issued tool call cannot change which names pass + through. To use a server that accepts ``additionalProperties: true``, list the + extra names here and then either (1) manually extend that tool's ``inputSchema`` + (via the ``.functions`` list after connecting) so the model is prompted to supply + them, or (2) supply the values yourself through ``function_invocation_kwargs``. If + a name is supplied via both the model and ``function_invocation_kwargs``, the + model-supplied value wins. kwargs: Any extra arguments to pass to the stdio client. """ super().__init__( @@ -2216,6 +2289,7 @@ def __init__( parse_prompt_results=parse_prompt_results, request_timeout=request_timeout, task_options=task_options, + additional_tool_argument_names=additional_tool_argument_names, ) self.command = command self.args = args or [] @@ -2295,6 +2369,7 @@ def __init__( http_client: AsyncClient | None = None, header_provider: Callable[[dict[str, Any]], dict[str, str]] | None = None, task_options: MCPTaskOptions | None = None, + additional_tool_argument_names: Sequence[str] | Mapping[str, Sequence[str]] | None = None, **kwargs: Any, ) -> None: """Initialize the MCP streamable HTTP tool. @@ -2349,6 +2424,20 @@ def __init__( agent middleware) without creating a separate ``httpx.AsyncClient``. task_options: Options for tools that advertise ``execution.taskSupport == "required"``. See :class:`MCPTaskOptions`. + additional_tool_argument_names: Extra argument names to forward to the MCP server in + addition to each tool's declared parameters (from its ``inputSchema.properties``). + By default only declared parameters are sent; framework runtime kwargs injected + through the function-invocation pipeline are stripped. Use this to opt specific + keys back in. Accepts either a ``Sequence[str]`` applied to every tool, or a + ``Mapping[str, Sequence[str]]`` keyed by remote tool name where the reserved key + ``"*"`` applies to every tool. This is configured only here in user code; there is + no per-call override, so a model-issued tool call cannot change which names pass + through. To use a server that accepts ``additionalProperties: true``, list the + extra names here and then either (1) manually extend that tool's ``inputSchema`` + (via the ``.functions`` list after connecting) so the model is prompted to supply + them, or (2) supply the values yourself through ``function_invocation_kwargs``. If + a name is supplied via both the model and ``function_invocation_kwargs``, the + model-supplied value wins. kwargs: Additional keyword arguments (accepted for backward compatibility but not used). """ super().__init__( @@ -2366,6 +2455,7 @@ def __init__( parse_prompt_results=parse_prompt_results, request_timeout=request_timeout, task_options=task_options, + additional_tool_argument_names=additional_tool_argument_names, ) self.url = url self.terminate_on_close = terminate_on_close @@ -2492,6 +2582,7 @@ def __init__( client: SupportsChatGetResponse | None = None, additional_properties: dict[str, Any] | None = None, task_options: MCPTaskOptions | None = None, + additional_tool_argument_names: Sequence[str] | Mapping[str, Sequence[str]] | None = None, **kwargs: Any, ) -> None: """Initialize the MCP WebSocket tool. @@ -2536,6 +2627,20 @@ def __init__( client: The chat client to use for sampling. task_options: Options for tools that advertise ``execution.taskSupport == "required"``. See :class:`MCPTaskOptions`. + additional_tool_argument_names: Extra argument names to forward to the MCP server in + addition to each tool's declared parameters (from its ``inputSchema.properties``). + By default only declared parameters are sent; framework runtime kwargs injected + through the function-invocation pipeline are stripped. Use this to opt specific + keys back in. Accepts either a ``Sequence[str]`` applied to every tool, or a + ``Mapping[str, Sequence[str]]`` keyed by remote tool name where the reserved key + ``"*"`` applies to every tool. This is configured only here in user code; there is + no per-call override, so a model-issued tool call cannot change which names pass + through. To use a server that accepts ``additionalProperties: true``, list the + extra names here and then either (1) manually extend that tool's ``inputSchema`` + (via the ``.functions`` list after connecting) so the model is prompted to supply + them, or (2) supply the values yourself through ``function_invocation_kwargs``. If + a name is supplied via both the model and ``function_invocation_kwargs``, the + model-supplied value wins. kwargs: Any extra arguments to pass to the WebSocket client. """ super().__init__( @@ -2553,6 +2658,7 @@ def __init__( parse_prompt_results=parse_prompt_results, request_timeout=request_timeout, task_options=task_options, + additional_tool_argument_names=additional_tool_argument_names, ) self.url = url self._client_kwargs = kwargs diff --git a/python/packages/core/tests/core/test_mcp.py b/python/packages/core/tests/core/test_mcp.py index 52a3f05f2c..8c101b0c44 100644 --- a/python/packages/core/tests/core/test_mcp.py +++ b/python/packages/core/tests/core/test_mcp.py @@ -30,6 +30,7 @@ MCPTool, _build_prefixed_mcp_name, _get_input_model_from_mcp_prompt, + _normalize_additional_tool_argument_names, _normalize_mcp_name, _should_propagate_cancelled_error, logger, @@ -6057,3 +6058,191 @@ async def fake_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> An # endregion + + +# region additional_tool_argument_names / allowlist filtering + + +def test_normalize_additional_tool_argument_names_none() -> None: + global_extras, per_tool = _normalize_additional_tool_argument_names(None) + assert global_extras == set() + assert per_tool == {} + + +def test_normalize_additional_tool_argument_names_sequence() -> None: + global_extras, per_tool = _normalize_additional_tool_argument_names(["a", "b", "a"]) + assert global_extras == {"a", "b"} + assert per_tool == {} + + +def test_normalize_additional_tool_argument_names_single_string() -> None: + # A bare string must be treated as a single name, not split into characters. + global_extras, per_tool = _normalize_additional_tool_argument_names("conversation_id") + assert global_extras == {"conversation_id"} + assert per_tool == {} + + +def test_normalize_additional_tool_argument_names_mapping_with_global_key() -> None: + global_extras, per_tool = _normalize_additional_tool_argument_names({ + "*": ["g1"], + "tool_a": ["a1", "a2"], + "tool_b": ["b1"], + }) + assert global_extras == {"g1"} + assert per_tool == {"tool_a": {"a1", "a2"}, "tool_b": {"b1"}} + + +def test_prepare_call_kwargs_strips_undeclared_arguments() -> None: + server = MCPTool(name="test_server") + server._tool_param_names_by_name = {"test_tool": {"param"}} + + filtered, meta = server._prepare_call_kwargs( + "test_tool", + {"param": "value", "conversation_id": "c", "thread": object(), "unexpected": 1}, + ) + + assert filtered == {"param": "value"} + assert meta is None + + +def test_prepare_call_kwargs_global_extras_allowed() -> None: + server = MCPTool(name="test_server", additional_tool_argument_names=["conversation_id"]) + server._tool_param_names_by_name = {"test_tool": {"param"}} + + filtered, _ = server._prepare_call_kwargs( + "test_tool", + {"param": "value", "conversation_id": "c", "options": {}}, + ) + + assert filtered == {"param": "value", "conversation_id": "c"} + + +def test_prepare_call_kwargs_per_tool_and_global_extras() -> None: + server = MCPTool( + name="test_server", + additional_tool_argument_names={"*": ["conversation_id"], "test_tool": ["custom"]}, + ) + server._tool_param_names_by_name = {"test_tool": {"param"}, "other_tool": {"x"}} + + filtered, _ = server._prepare_call_kwargs( + "test_tool", + {"param": "v", "conversation_id": "c", "custom": "y", "thread": object()}, + ) + assert filtered == {"param": "v", "conversation_id": "c", "custom": "y"} + + # The per-tool extra does not leak to other tools; the global one still applies. + filtered_other, _ = server._prepare_call_kwargs( + "other_tool", + {"x": 1, "conversation_id": "c", "custom": "y"}, + ) + assert filtered_other == {"x": 1, "conversation_id": "c"} + + +def test_prepare_call_kwargs_denylist_guards_server_declared_names() -> None: + # The denylist is a safety net for framework-named params a server *declares* in its + # schema: they are dropped so internal objects never leak. Names explicitly opted in + # via extras always win. + server = MCPTool(name="test_server", additional_tool_argument_names=["conversation_id"]) + server._tool_param_names_by_name = {"test_tool": {"param", "thread"}} + + filtered, _ = server._prepare_call_kwargs( + "test_tool", + {"param": "v", "thread": object(), "conversation_id": "c"}, + ) + # "thread" is declared by the schema but denylisted -> dropped; conversation_id opted in -> kept. + assert filtered == {"param": "v", "conversation_id": "c"} + + +def test_prepare_call_kwargs_extras_override_denylist() -> None: + # A user that explicitly opts a denylisted name back in takes responsibility for it. + server = MCPTool(name="test_server", additional_tool_argument_names=["conversation_id"]) + server._tool_param_names_by_name = {"test_tool": {"param"}} + + filtered, _ = server._prepare_call_kwargs( + "test_tool", + {"param": "v", "conversation_id": "c", "thread": object()}, + ) + assert filtered == {"param": "v", "conversation_id": "c"} + + +def test_prepare_call_kwargs_zero_arg_tool_passes_no_arguments() -> None: + server = MCPTool(name="test_server") + server._tool_param_names_by_name = {"test_tool": set()} + + filtered, _ = server._prepare_call_kwargs( + "test_tool", + {"conversation_id": "c", "thread": object(), "stray": 1}, + ) + assert filtered == {} + + +def test_prepare_call_kwargs_unknown_tool_passes_only_global_extras() -> None: + server = MCPTool(name="test_server", additional_tool_argument_names=["conversation_id"]) + # No entry in _tool_param_names_by_name for this tool name. + + filtered, _ = server._prepare_call_kwargs( + "unknown_tool", + {"conversation_id": "c", "other": 1}, + ) + assert filtered == {"conversation_id": "c"} + + +def test_prepare_call_kwargs_extracts_meta() -> None: + server = MCPTool(name="test_server") + server._tool_param_names_by_name = {"test_tool": {"param"}} + + filtered, meta = server._prepare_call_kwargs( + "test_tool", + {"param": "v", "_meta": {"trace": "abc"}}, + ) + assert filtered == {"param": "v"} + assert meta is not None + assert meta.get("trace") == "abc" + + +async def test_call_tool_forwards_only_declared_arguments() -> None: + """End-to-end: framework runtime kwargs are stripped before reaching the server.""" + + class TestServer(MCPTool): + async def connect(self): + self.session = Mock(spec=ClientSession) + self.session.list_tools = AsyncMock( + return_value=types.ListToolsResult( + tools=[ + types.Tool( + name="test_tool", + description="Test tool", + inputSchema={ + "type": "object", + "properties": {"param": {"type": "string"}}, + "required": ["param"], + }, + ) + ] + ) + ) + self.session.call_tool = AsyncMock( + return_value=types.CallToolResult(content=[types.TextContent(type="text", text="ok")]) + ) + + def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: + return None + + server = TestServer(name="test_server", additional_tool_argument_names=["conversation_id"]) + async with server: + await server.load_tools() + session_mock = server.session + await server.call_tool( + "test_tool", + param="value", + conversation_id="c", + thread=object(), + response_format=object(), + ) + + session_mock.call_tool.assert_called_once() + _, call_kwargs = session_mock.call_tool.call_args + assert call_kwargs["arguments"] == {"param": "value", "conversation_id": "c"} + + +# endregion From 6ec7313434dd5c9932421e4240c3ef2cdeb3ac09 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Mon, 8 Jun 2026 17:39:57 +0200 Subject: [PATCH 2/2] Address MCP allowlist review comments and fix reload arg loss - Fix pyright reportUnknownArgumentType in _load_tools (cast schema properties). - Register declared param names before the existing-tool skip guard so that tool-list reloads preserve the allowlist for already-loaded tools (previously unchanged tools silently dropped all declared args after a background reload). - Handle bare-string values in an additional_tool_argument_names mapping instead of iterating their characters. - Clarify the framework denylist comment: explicit extras override the denylist. - Make the extras-override-denylist test unambiguous (opt in a denylisted name). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- python/packages/core/agent_framework/_mcp.py | 38 ++++++++++++-------- python/packages/core/tests/core/test_mcp.py | 22 +++++++++--- 2 files changed, 42 insertions(+), 18 deletions(-) diff --git a/python/packages/core/agent_framework/_mcp.py b/python/packages/core/agent_framework/_mcp.py index ac4c1fb97d..784c618302 100644 --- a/python/packages/core/agent_framework/_mcp.py +++ b/python/packages/core/agent_framework/_mcp.py @@ -75,8 +75,10 @@ class MCPSpecificApproval(TypedDict, total=False): _MCP_GLOBAL_EXTRA_ARGS_KEY = "*" # Framework kwargs that flow through the function-invocation pipeline (via # ``FunctionInvocationContext.kwargs``) but must never be forwarded to an MCP -# server: they are internal objects that the MCP SDK cannot serialize. These are -# always dropped as a safety net, even when a tool oddly declares one of them. +# server: they are internal objects that the MCP SDK cannot serialize. They are +# dropped as a safety net when a tool declares one of them in its schema, unless +# the user explicitly opts the name back in via ``additional_tool_argument_names`` +# (explicit extras always win over the denylist). # - chat_options/tools/tool_choice/session/thread: framework runtime objects. # - conversation_id: internal tracking ID used by services like Azure AI. # - options: metadata/store used by AG-UI for Azure AI client requirements. @@ -164,7 +166,8 @@ def _normalize_additional_tool_argument_names( """Split user-supplied extra argument names into global and per-tool sets. Accepts either a sequence (applied to every tool) or a mapping keyed by remote - tool name, where the reserved key ``"*"`` is treated as global. Returns a + tool name, where the reserved key ``"*"`` is treated as global. Mapping values + may be a sequence or a single string. Returns a ``(global_extras, per_tool_extras)`` tuple. """ if additional_tool_argument_names is None: @@ -175,10 +178,12 @@ def _normalize_additional_tool_argument_names( global_extras: set[str] = set() per_tool_extras: dict[str, set[str]] = {} for tool_name, names in additional_tool_argument_names.items(): + # Treat a bare string value as a single name rather than iterating its characters. + names_set = {names} if isinstance(names, str) else set(names) if tool_name == _MCP_GLOBAL_EXTRA_ARGS_KEY: - global_extras.update(names) + global_extras.update(names_set) else: - per_tool_extras[tool_name] = set(names) + per_tool_extras[tool_name] = names_set return global_extras, per_tool_extras return set(additional_tool_argument_names), {} @@ -1329,14 +1334,6 @@ async def _load_tools_locked(self) -> None: if task_support is not None: tool_task_support_by_name[tool.name] = task_support - normalized_name = _normalize_mcp_name(tool.name) - local_name = _build_prefixed_mcp_name(normalized_name, self.tool_name_prefix) - - # Skip if already loaded - if local_name in existing_names: - continue - - approval_mode = self._determine_approval_mode(local_name, normalized_name, tool.name) # Normalize inputSchema: ensure "properties" exists for object schemas. # Some MCP servers (e.g. zero-argument tools) omit "properties", # which causes OpenAI API to reject the schema with a 400 error. @@ -1346,11 +1343,24 @@ async def _load_tools_locked(self) -> None: if input_schema.get("type") == "object" and "properties" not in input_schema: input_schema["properties"] = {} + # Register declared param names before the existing-tool skip below so that + # reloads (e.g. notifications/tools/list_changed) preserve the allowlist for + # tools that are already loaded, consistent with tool_call_meta_by_name and + # tool_task_support_by_name above. schema_properties = input_schema.get("properties") tool_param_names_by_name[tool.name] = ( - set(schema_properties) if isinstance(schema_properties, dict) else set() + set(cast(dict[str, Any], schema_properties)) if isinstance(schema_properties, dict) else set() ) + normalized_name = _normalize_mcp_name(tool.name) + local_name = _build_prefixed_mcp_name(normalized_name, self.tool_name_prefix) + + # Skip if already loaded + if local_name in existing_names: + continue + + approval_mode = self._determine_approval_mode(local_name, normalized_name, tool.name) + async def _call_tool_with_runtime_kwargs( ctx: FunctionInvocationContext, *, diff --git a/python/packages/core/tests/core/test_mcp.py b/python/packages/core/tests/core/test_mcp.py index 8c101b0c44..7c45296cbb 100644 --- a/python/packages/core/tests/core/test_mcp.py +++ b/python/packages/core/tests/core/test_mcp.py @@ -6092,6 +6092,16 @@ def test_normalize_additional_tool_argument_names_mapping_with_global_key() -> N assert per_tool == {"tool_a": {"a1", "a2"}, "tool_b": {"b1"}} +def test_normalize_additional_tool_argument_names_mapping_with_string_values() -> None: + # A bare string mapping value is a single name, not an iterable of characters. + global_extras, per_tool = _normalize_additional_tool_argument_names({ + "*": "conversation_id", + "tool_a": "custom", + }) + assert global_extras == {"conversation_id"} + assert per_tool == {"tool_a": {"custom"}} + + def test_prepare_call_kwargs_strips_undeclared_arguments() -> None: server = MCPTool(name="test_server") server._tool_param_names_by_name = {"test_tool": {"param"}} @@ -6154,15 +6164,19 @@ def test_prepare_call_kwargs_denylist_guards_server_declared_names() -> None: def test_prepare_call_kwargs_extras_override_denylist() -> None: - # A user that explicitly opts a denylisted name back in takes responsibility for it. - server = MCPTool(name="test_server", additional_tool_argument_names=["conversation_id"]) + # Opting a denylisted framework name back in via extras takes precedence over the + # denylist safety net. "thread" is on the framework denylist, but an explicit extra wins. + server = MCPTool(name="test_server", additional_tool_argument_names=["thread"]) server._tool_param_names_by_name = {"test_tool": {"param"}} + sentinel = object() filtered, _ = server._prepare_call_kwargs( "test_tool", - {"param": "v", "conversation_id": "c", "thread": object()}, + {"param": "v", "thread": sentinel, "conversation_id": "c"}, ) - assert filtered == {"param": "v", "conversation_id": "c"} + # "thread" opted in via extras -> kept despite the denylist; conversation_id is denylisted, + # not declared, and not opted in -> dropped. + assert filtered == {"param": "v", "thread": sentinel} def test_prepare_call_kwargs_zero_arg_tool_passes_no_arguments() -> None: