Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions python/packages/core/agent_framework/_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from mcp.shared.context import RequestContext
from mcp.shared.exceptions import McpError
from mcp.shared.session import RequestResponder
from opentelemetry.trace import get_current_span
from pydantic import BaseModel, create_model

from ._tools import AIFunction, HostedMCPSpecificApproval
Expand Down Expand Up @@ -810,6 +811,7 @@ def __init__(
terminate_on_close: bool | None = None,
chat_client: "ChatClientProtocol | None" = None,
additional_properties: dict[str, Any] | None = None,
include_traceparent: bool = False,
**kwargs: Any,
) -> None:
"""Initialize the MCP streamable HTTP tool.
Expand Down Expand Up @@ -843,6 +845,7 @@ def __init__(
sse_read_timeout: The timeout for reading from the SSE stream.
terminate_on_close: Close the transport when the MCP client is terminated.
chat_client: The chat client to use for sampling.
include_traceparent: Whether to include the traceparent header in requests.
kwargs: Any extra arguments to pass to the SSE client.
"""
super().__init__(
Expand All @@ -863,6 +866,7 @@ def __init__(
self.sse_read_timeout = sse_read_timeout
self.terminate_on_close = terminate_on_close
self._client_kwargs = kwargs
self._include_traceparent = include_traceparent

def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
"""Get an MCP streamable HTTP client.
Expand All @@ -883,6 +887,10 @@ def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
args["terminate_on_close"] = self.terminate_on_close
if self._client_kwargs:
args.update(self._client_kwargs)
if self._include_traceparent:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it’s better to build this around per-request dynamic headers, not mutate the tool's static dict. Capture the static headers once in MCPStreamableHTTPTool, then pass an optional dynamic_headers callable down into streamablehttp_client / StreamableHTTPTransport. The callable can pull the current span, make sure the context is valid, derive the sampled flag, and return {"traceparent": ...} only when it should be sent. We can then have _prepare_request_headers merge that callable's output on every POST so each RPC carries the active span while the static headers stay untouched. That then handles a possible KeyError when no headers were supplied and keeps the trace flag accurate instead of hardcoding 01.

span_ctx = get_current_span().get_span_context()
args["headers"]["traceparent"] = \
f"00-{span_ctx.trace_id:032x}-{span_ctx.span_id:016x}-01"
return streamablehttp_client(**args)


Expand Down
Loading