diff --git a/docs/servers/context.mdx b/docs/servers/context.mdx index 05e65163d0..cf365cd28a 100644 --- a/docs/servers/context.mdx +++ b/docs/servers/context.mdx @@ -300,6 +300,31 @@ async def my_tool(ctx: Context) -> None: ... ``` +### Transport + +The `ctx.transport` property indicates which transport is being used to run the server. This is useful when your tool needs to behave differently depending on whether the server is running over STDIO, SSE, or Streamable HTTP. For example, you might want to return shorter responses over STDIO or adjust timeout behavior based on transport characteristics. + +The transport type is set once when the server starts and remains constant for the server's lifetime. It returns `None` when called outside of a server context (for example, in unit tests or when running code outside of an MCP request). + +```python +from fastmcp import FastMCP, Context + +mcp = FastMCP("example") + +@mcp.tool +def connection_info(ctx: Context) -> str: + if ctx.transport == "stdio": + return "Connected via STDIO" + elif ctx.transport == "sse": + return "Connected via SSE" + elif ctx.transport == "streamable-http": + return "Connected via Streamable HTTP" + else: + return "Transport unknown" +``` + +**Property signature:** `ctx.transport -> Literal["stdio", "sse", "streamable-http"] | None` + ### MCP Request Access metadata about the current request and client. diff --git a/loq.toml b/loq.toml index dcbd5b57c9..48e868b310 100644 --- a/loq.toml +++ b/loq.toml @@ -60,7 +60,7 @@ max_lines = 590 [[rules]] path = "src/fastmcp/server/context.py" -max_lines = 1246 +max_lines = 1272 [[rules]] path = "tests/test_mcp_config.py" @@ -288,7 +288,7 @@ max_lines = 1019 [[rules]] path = "src/fastmcp/server/server.py" -max_lines = 2676 +max_lines = 2682 [[rules]] path = "tests/deprecated/test_import_server.py" @@ -312,7 +312,7 @@ max_lines = 1584 [[rules]] path = "docs/servers/context.mdx" -max_lines = 623 +max_lines = 648 [[rules]] path = "src/fastmcp/resources/resource.py" diff --git a/src/fastmcp/server/context.py b/src/fastmcp/server/context.py index f65cba8f28..f0a9dce5a8 100644 --- a/src/fastmcp/server/context.py +++ b/src/fastmcp/server/context.py @@ -77,6 +77,23 @@ _current_context: ContextVar[Context | None] = ContextVar("context", default=None) # type: ignore[assignment] +TransportType = Literal["stdio", "sse", "streamable-http"] +_current_transport: ContextVar[TransportType | None] = ContextVar( + "transport", default=None +) + + +def set_transport( + transport: TransportType, +) -> Token[TransportType | None]: + """Set the current transport type. Returns token for reset.""" + return _current_transport.set(transport) + + +def reset_transport(token: Token[TransportType | None]) -> None: + """Reset transport to previous value.""" + _current_transport.reset(token) + @dataclass class LogData: @@ -401,6 +418,15 @@ async def log( related_request_id=self.request_id, ) + @property + def transport(self) -> TransportType | None: + """Get the current transport type. + + Returns the transport type used to run this server: "stdio", "sse", + or "streamable-http". Returns None if called outside of a server context. + """ + return _current_transport.get() + @property def client_id(self) -> str | None: """Get the client ID if available.""" diff --git a/src/fastmcp/server/http.py b/src/fastmcp/server/http.py index 825fcd0aa0..324fed1ec9 100644 --- a/src/fastmcp/server/http.py +++ b/src/fastmcp/server/http.py @@ -85,7 +85,7 @@ def set_http_request(request: Request) -> Generator[Request, None, None]: class RequestContextMiddleware: """ - Middleware that stores each request in a ContextVar + Middleware that stores each request in a ContextVar and sets transport type. """ def __init__(self, app): @@ -93,8 +93,17 @@ def __init__(self, app): async def __call__(self, scope, receive, send): if scope["type"] == "http": - with set_http_request(Request(scope)): - await self.app(scope, receive, send) + from fastmcp.server.context import reset_transport, set_transport + + # Get transport type from app state (set during app creation) + transport_type = getattr(scope["app"].state, "transport_type", None) + transport_token = set_transport(transport_type) if transport_type else None + try: + with set_http_request(Request(scope)): + await self.app(scope, receive, send) + finally: + if transport_token is not None: + reset_transport(transport_token) else: await self.app(scope, receive, send) @@ -255,6 +264,7 @@ async def lifespan(app: Starlette) -> AsyncGenerator[None, None]: # Store the FastMCP server instance on the Starlette app state app.state.fastmcp_server = server app.state.path = sse_path + app.state.transport_type = "sse" return app @@ -366,7 +376,7 @@ async def lifespan(app: Starlette) -> AsyncGenerator[None, None]: ) # Store the FastMCP server instance on the Starlette app state app.state.fastmcp_server = server - app.state.path = streamable_http_path + app.state.transport_type = "streamable-http" return app diff --git a/src/fastmcp/server/server.py b/src/fastmcp/server/server.py index c48d18d8cd..d1b48ae9a9 100644 --- a/src/fastmcp/server/server.py +++ b/src/fastmcp/server/server.py @@ -2114,32 +2114,38 @@ async def run_stdio_async( log_level: Log level for the server stateless: Whether to run in stateless mode (no session initialization) """ + from fastmcp.server.context import reset_transport, set_transport + # Display server banner if show_banner: log_server_banner(server=self) - with temporary_log_level(log_level): - async with self._lifespan_manager(): - async with stdio_server() as (read_stream, write_stream): - mode = " (stateless)" if stateless else "" - logger.info( - f"Starting MCP server {self.name!r} with transport 'stdio'{mode}" - ) - - # Build experimental capabilities - experimental_capabilities = get_task_capabilities() - - await self._mcp_server.run( - read_stream, - write_stream, - self._mcp_server.create_initialization_options( - notification_options=NotificationOptions( - tools_changed=True + token = set_transport("stdio") + try: + with temporary_log_level(log_level): + async with self._lifespan_manager(): + async with stdio_server() as (read_stream, write_stream): + mode = " (stateless)" if stateless else "" + logger.info( + f"Starting MCP server {self.name!r} with transport 'stdio'{mode}" + ) + + # Build experimental capabilities + experimental_capabilities = get_task_capabilities() + + await self._mcp_server.run( + read_stream, + write_stream, + self._mcp_server.create_initialization_options( + notification_options=NotificationOptions( + tools_changed=True + ), + experimental_capabilities=experimental_capabilities, ), - experimental_capabilities=experimental_capabilities, - ), - stateless=stateless, - ) + stateless=stateless, + ) + finally: + reset_transport(token) async def run_http_async( self, diff --git a/tests/server/test_context.py b/tests/server/test_context.py index 45275f835a..c2f59c7f1f 100644 --- a/tests/server/test_context.py +++ b/tests/server/test_context.py @@ -6,6 +6,8 @@ from fastmcp.server.context import ( Context, _parse_model_preferences, + reset_transport, + set_transport, ) from fastmcp.server.server import FastMCP @@ -180,3 +182,112 @@ def test_request_context_meta_none(self, context): assert retrieved_meta is None request_ctx.reset(token) + + +class TestTransport: + """Test suite for Context transport property.""" + + def test_transport_returns_none_outside_server_context(self, context): + """Test that transport returns None when not in a server context.""" + assert context.transport is None + + def test_transport_returns_stdio(self, context): + """Test that transport returns 'stdio' when set.""" + token = set_transport("stdio") + try: + assert context.transport == "stdio" + finally: + reset_transport(token) + + def test_transport_returns_sse(self, context): + """Test that transport returns 'sse' when set.""" + token = set_transport("sse") + try: + assert context.transport == "sse" + finally: + reset_transport(token) + + def test_transport_returns_streamable_http(self, context): + """Test that transport returns 'streamable-http' when set.""" + token = set_transport("streamable-http") + try: + assert context.transport == "streamable-http" + finally: + reset_transport(token) + + def test_transport_reset(self, context): + """Test that transport resets correctly.""" + assert context.transport is None + token = set_transport("stdio") + assert context.transport == "stdio" + reset_transport(token) + assert context.transport is None + + +class TestTransportIntegration: + """Integration tests for transport property with actual server/client.""" + + async def test_transport_in_tool_via_client(self): + """Test that transport is accessible from within a tool via Client.""" + from fastmcp import Client + + mcp = FastMCP("test") + observed_transport = None + + @mcp.tool + def get_transport(ctx: Context) -> str: + nonlocal observed_transport + observed_transport = ctx.transport + return observed_transport or "none" + + # Client uses in-memory transport which doesn't set transport type + # so we expect None here (the transport is only set by run_* methods) + async with Client(mcp) as client: + result = await client.call_tool("get_transport", {}) + assert observed_transport is None + assert result.data == "none" + + async def test_transport_set_manually_is_visible_in_tool(self): + """Test that manually set transport is visible from within a tool.""" + from fastmcp import Client + + mcp = FastMCP("test") + observed_transport = None + + @mcp.tool + def get_transport(ctx: Context) -> str: + nonlocal observed_transport + observed_transport = ctx.transport + return observed_transport or "none" + + # Manually set transport before running + token = set_transport("stdio") + try: + async with Client(mcp) as client: + result = await client.call_tool("get_transport", {}) + assert observed_transport == "stdio" + assert result.data == "stdio" + finally: + reset_transport(token) + + async def test_transport_set_via_http_middleware(self): + """Test that transport is set per-request via HTTP middleware.""" + from fastmcp import Client + from fastmcp.client.transports import StreamableHttpTransport + from fastmcp.utilities.tests import run_server_async + + mcp = FastMCP("test") + observed_transport = None + + @mcp.tool + def get_transport(ctx: Context) -> str: + nonlocal observed_transport + observed_transport = ctx.transport + return observed_transport or "none" + + async with run_server_async(mcp, transport="streamable-http") as url: + transport = StreamableHttpTransport(url=url) + async with Client(transport=transport) as client: + result = await client.call_tool("get_transport", {}) + assert observed_transport == "streamable-http" + assert result.data == "streamable-http"