diff --git a/docs/servers/middleware.mdx b/docs/servers/middleware.mdx index 234de648f1..c99ae54667 100644 --- a/docs/servers/middleware.mdx +++ b/docs/servers/middleware.mdx @@ -617,6 +617,36 @@ stats = middleware.statistics() print(f"Total cache operations: {stats}") ``` +### Ping Middleware + + + +For long-lived connections, servers can send periodic pings to keep connections alive and prevent clients from dropping the connection due to inactivity. FastMCP provides ping middleware at `fastmcp.server.middleware.ping`. + +```python +from fastmcp import FastMCP +from fastmcp.server.middleware import PingMiddleware + +mcp = FastMCP("MyServer") +mcp.add_middleware(PingMiddleware(interval_ms=5000)) +``` + +The middleware starts a background ping task on the first message from each client session. Pings continue at the configured interval until the session ends. When the session closes, the ping task is automatically cancelled and cleaned up. + +Configure the ping interval based on your client's timeout settings: + +```python +# Ping every 30 seconds (default) +mcp.add_middleware(PingMiddleware()) + +# Ping every 5 seconds for clients with short timeouts +mcp.add_middleware(PingMiddleware(interval_ms=5000)) +``` + + +Ping middleware is most useful for stateful HTTP connections where clients maintain long-lived sessions. For stateless connections, pings have no effect since each request is independent. + + ### Logging Middleware Request and response logging is crucial for debugging, monitoring, and understanding usage patterns in your MCP server. FastMCP provides comprehensive logging middleware at `fastmcp.server.middleware.logging`. diff --git a/src/fastmcp/server/middleware/__init__.py b/src/fastmcp/server/middleware/__init__.py index d53d5f05d5..98afc1875f 100644 --- a/src/fastmcp/server/middleware/__init__.py +++ b/src/fastmcp/server/middleware/__init__.py @@ -3,9 +3,11 @@ Middleware, MiddlewareContext, ) +from .ping import PingMiddleware __all__ = [ "CallNext", "Middleware", "MiddlewareContext", + "PingMiddleware", ] diff --git a/src/fastmcp/server/middleware/ping.py b/src/fastmcp/server/middleware/ping.py new file mode 100644 index 0000000000..3ee61dac27 --- /dev/null +++ b/src/fastmcp/server/middleware/ping.py @@ -0,0 +1,70 @@ +"""Ping middleware for keeping client connections alive.""" + +from typing import Any + +import anyio + +from .middleware import CallNext, Middleware, MiddlewareContext + + +class PingMiddleware(Middleware): + """Middleware that sends periodic pings to keep client connections alive. + + Starts a background ping task on first message from each session. The task + sends server-to-client pings at the configured interval until the session + ends. + + Example: + ```python + from fastmcp import FastMCP + from fastmcp.server.middleware import PingMiddleware + + mcp = FastMCP("MyServer") + mcp.add_middleware(PingMiddleware(interval_ms=5000)) + ``` + """ + + def __init__(self, interval_ms: int = 30000): + """Initialize ping middleware. + + Args: + interval_ms: Interval between pings in milliseconds (default: 30000) + + Raises: + ValueError: If interval_ms is not positive + """ + if interval_ms <= 0: + raise ValueError("interval_ms must be positive") + self.interval_ms = interval_ms + self._active_sessions: set[int] = set() + self._lock = anyio.Lock() + + async def on_message(self, context: MiddlewareContext, call_next: CallNext) -> Any: + """Start ping task on first message from a session.""" + if ( + context.fastmcp_context is None + or context.fastmcp_context.request_context is None + ): + return await call_next(context) + + session = context.fastmcp_context.session + session_id = id(session) + + async with self._lock: + if session_id not in self._active_sessions: + # _subscription_task_group is added by MiddlewareServerSession + tg = session._subscription_task_group # type: ignore[attr-defined] + if tg is not None: + self._active_sessions.add(session_id) + tg.start_soon(self._ping_loop, session, session_id) + + return await call_next(context) + + async def _ping_loop(self, session: Any, session_id: int) -> None: + """Send periodic pings until session ends.""" + try: + while True: + await anyio.sleep(self.interval_ms / 1000) + await session.send_ping() + finally: + self._active_sessions.discard(session_id) diff --git a/tests/server/middleware/test_ping.py b/tests/server/middleware/test_ping.py new file mode 100644 index 0000000000..ba5e10fdf6 --- /dev/null +++ b/tests/server/middleware/test_ping.py @@ -0,0 +1,241 @@ +"""Tests for ping middleware.""" + +from unittest.mock import AsyncMock, MagicMock + +import anyio +import pytest + +from fastmcp import FastMCP +from fastmcp.client import Client +from fastmcp.server.middleware.ping import PingMiddleware + + +class TestPingMiddlewareInit: + """Test PingMiddleware initialization.""" + + def test_init_default(self): + """Test default initialization.""" + middleware = PingMiddleware() + assert middleware.interval_ms == 30000 + assert middleware._active_sessions == set() + + def test_init_custom(self): + """Test custom interval initialization.""" + middleware = PingMiddleware(interval_ms=5000) + assert middleware.interval_ms == 5000 + + def test_init_invalid_interval_zero(self): + """Test that zero interval raises ValueError.""" + with pytest.raises(ValueError, match="interval_ms must be positive"): + PingMiddleware(interval_ms=0) + + def test_init_invalid_interval_negative(self): + """Test that negative interval raises ValueError.""" + with pytest.raises(ValueError, match="interval_ms must be positive"): + PingMiddleware(interval_ms=-1000) + + +class TestPingMiddlewareOnMessage: + """Test on_message hook behavior.""" + + async def test_starts_ping_task_on_first_message(self): + """Test that ping task is started on first message from a session.""" + middleware = PingMiddleware(interval_ms=1000) + + mock_session = MagicMock() + mock_session._subscription_task_group = MagicMock() + mock_session._subscription_task_group.start_soon = MagicMock() + + mock_context = MagicMock() + mock_context.fastmcp_context.session = mock_session + + mock_call_next = AsyncMock(return_value="result") + + result = await middleware.on_message(mock_context, mock_call_next) + + assert result == "result" + assert id(mock_session) in middleware._active_sessions + mock_session._subscription_task_group.start_soon.assert_called_once() + + async def test_does_not_start_duplicate_task(self): + """Test that duplicate messages from same session don't spawn duplicate tasks.""" + middleware = PingMiddleware(interval_ms=1000) + + mock_session = MagicMock() + mock_session._subscription_task_group = MagicMock() + mock_session._subscription_task_group.start_soon = MagicMock() + + mock_context = MagicMock() + mock_context.fastmcp_context.session = mock_session + + mock_call_next = AsyncMock(return_value="result") + + # First message + await middleware.on_message(mock_context, mock_call_next) + # Second message from same session + await middleware.on_message(mock_context, mock_call_next) + # Third message from same session + await middleware.on_message(mock_context, mock_call_next) + + # Should only start task once + assert mock_session._subscription_task_group.start_soon.call_count == 1 + + async def test_starts_separate_task_per_session(self): + """Test that different sessions get separate ping tasks.""" + middleware = PingMiddleware(interval_ms=1000) + + mock_session1 = MagicMock() + mock_session1._subscription_task_group = MagicMock() + mock_session1._subscription_task_group.start_soon = MagicMock() + + mock_session2 = MagicMock() + mock_session2._subscription_task_group = MagicMock() + mock_session2._subscription_task_group.start_soon = MagicMock() + + mock_context1 = MagicMock() + mock_context1.fastmcp_context.session = mock_session1 + + mock_context2 = MagicMock() + mock_context2.fastmcp_context.session = mock_session2 + + mock_call_next = AsyncMock(return_value="result") + + await middleware.on_message(mock_context1, mock_call_next) + await middleware.on_message(mock_context2, mock_call_next) + + mock_session1._subscription_task_group.start_soon.assert_called_once() + mock_session2._subscription_task_group.start_soon.assert_called_once() + assert len(middleware._active_sessions) == 2 + + async def test_skips_task_when_no_task_group(self): + """Test graceful handling when session has no task group.""" + middleware = PingMiddleware(interval_ms=1000) + + mock_session = MagicMock() + mock_session._subscription_task_group = None + + mock_context = MagicMock() + mock_context.fastmcp_context.session = mock_session + + mock_call_next = AsyncMock(return_value="result") + + result = await middleware.on_message(mock_context, mock_call_next) + + assert result == "result" + # Session should NOT be added if task group is None + assert id(mock_session) not in middleware._active_sessions + + async def test_skips_when_fastmcp_context_is_none(self): + """Test that middleware passes through when fastmcp_context is None.""" + middleware = PingMiddleware(interval_ms=1000) + + mock_context = MagicMock() + mock_context.fastmcp_context = None + + mock_call_next = AsyncMock(return_value="result") + + result = await middleware.on_message(mock_context, mock_call_next) + + assert result == "result" + assert len(middleware._active_sessions) == 0 + + async def test_skips_when_request_context_is_none(self): + """Test that middleware passes through when request_context is None.""" + middleware = PingMiddleware(interval_ms=1000) + + mock_context = MagicMock() + mock_context.fastmcp_context = MagicMock() + mock_context.fastmcp_context.request_context = None + + mock_call_next = AsyncMock(return_value="result") + + result = await middleware.on_message(mock_context, mock_call_next) + + assert result == "result" + assert len(middleware._active_sessions) == 0 + + +class TestPingLoop: + """Test the ping loop behavior.""" + + async def test_ping_loop_sends_pings_at_interval(self): + """Test that ping loop sends pings at configured interval.""" + middleware = PingMiddleware(interval_ms=50) + + mock_session = MagicMock() + mock_session.send_ping = AsyncMock() + + session_id = id(mock_session) + middleware._active_sessions.add(session_id) + + # Run ping loop for a short time then cancel + with anyio.move_on_after(0.15): + await middleware._ping_loop(mock_session, session_id) + + # Should have sent at least 2 pings in 150ms with 50ms interval + assert mock_session.send_ping.call_count >= 2 + + async def test_ping_loop_cleans_up_on_cancellation(self): + """Test that session is removed from active sessions on cancellation.""" + middleware = PingMiddleware(interval_ms=50) + + mock_session = MagicMock() + mock_session.send_ping = AsyncMock() + + session_id = 12345 + middleware._active_sessions.add(session_id) + + # Run and cancel the ping loop + with anyio.move_on_after(0.1): + await middleware._ping_loop(mock_session, session_id) + + # Session should be cleaned up after cancellation + assert session_id not in middleware._active_sessions + + +class TestPingMiddlewareIntegration: + """Integration tests for PingMiddleware with real FastMCP server.""" + + async def test_ping_middleware_registers_session(self): + """Test that PingMiddleware registers sessions on first request.""" + mcp = FastMCP("PingTestServer") + middleware = PingMiddleware(interval_ms=50) + mcp.add_middleware(middleware) + + @mcp.tool + def hello() -> str: + return "Hello!" + + assert len(middleware._active_sessions) == 0 + + async with Client(mcp) as client: + result = await client.call_tool("hello") + assert result.content[0].text == "Hello!" + + # Should have registered the session + assert len(middleware._active_sessions) == 1 + + # Make another request - should not add duplicate + await client.call_tool("hello") + assert len(middleware._active_sessions) == 1 + + async def test_ping_task_cancelled_on_disconnect(self): + """Test that ping task is properly cancelled when client disconnects.""" + mcp = FastMCP("PingTestServer") + middleware = PingMiddleware(interval_ms=50) + mcp.add_middleware(middleware) + + @mcp.tool + def hello() -> str: + return "Hello!" + + async with Client(mcp) as client: + await client.call_tool("hello") + # Should have one active session + assert len(middleware._active_sessions) == 1 + + # After disconnect, give a moment for cleanup + await anyio.sleep(0.01) + + # Session should be cleaned up + assert len(middleware._active_sessions) == 0