Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions docs/servers/middleware.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,36 @@ stats = middleware.statistics()
print(f"Total cache operations: {stats}")
```

### Ping Middleware

<VersionBadge version="3.0.0" />

For long-lived connections, servers can send periodic pings to keep connections alive and prevent clients from dropping the connection due to inactivity. FastMCP provides ping middleware at `fastmcp.server.middleware.ping`.

```python
from fastmcp import FastMCP
from fastmcp.server.middleware import PingMiddleware

mcp = FastMCP("MyServer")
mcp.add_middleware(PingMiddleware(interval_ms=5000))
```

The middleware starts a background ping task on the first message from each client session. Pings continue at the configured interval until the session ends. When the session closes, the ping task is automatically cancelled and cleaned up.

Configure the ping interval based on your client's timeout settings:

```python
# Ping every 30 seconds (default)
mcp.add_middleware(PingMiddleware())

# Ping every 5 seconds for clients with short timeouts
mcp.add_middleware(PingMiddleware(interval_ms=5000))
```

<Note>
Ping middleware is most useful for stateful HTTP connections where clients maintain long-lived sessions. For stateless connections, pings have no effect since each request is independent.
</Note>

### Logging Middleware

Request and response logging is crucial for debugging, monitoring, and understanding usage patterns in your MCP server. FastMCP provides comprehensive logging middleware at `fastmcp.server.middleware.logging`.
Expand Down
2 changes: 2 additions & 0 deletions src/fastmcp/server/middleware/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
Middleware,
MiddlewareContext,
)
from .ping import PingMiddleware

__all__ = [
"CallNext",
"Middleware",
"MiddlewareContext",
"PingMiddleware",
]
70 changes: 70 additions & 0 deletions src/fastmcp/server/middleware/ping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""Ping middleware for keeping client connections alive."""

from typing import Any

import anyio

from .middleware import CallNext, Middleware, MiddlewareContext


class PingMiddleware(Middleware):
"""Middleware that sends periodic pings to keep client connections alive.

Starts a background ping task on first message from each session. The task
sends server-to-client pings at the configured interval until the session
ends.

Example:
```python
from fastmcp import FastMCP
from fastmcp.server.middleware import PingMiddleware

mcp = FastMCP("MyServer")
mcp.add_middleware(PingMiddleware(interval_ms=5000))
```
"""

def __init__(self, interval_ms: int = 30000):
"""Initialize ping middleware.

Args:
interval_ms: Interval between pings in milliseconds (default: 30000)

Raises:
ValueError: If interval_ms is not positive
"""
if interval_ms <= 0:
raise ValueError("interval_ms must be positive")
self.interval_ms = interval_ms
self._active_sessions: set[int] = set()
self._lock = anyio.Lock()

async def on_message(self, context: MiddlewareContext, call_next: CallNext) -> Any:
"""Start ping task on first message from a session."""
if (
context.fastmcp_context is None
or context.fastmcp_context.request_context is None
):
return await call_next(context)

session = context.fastmcp_context.session
session_id = id(session)

async with self._lock:
if session_id not in self._active_sessions:
# _subscription_task_group is added by MiddlewareServerSession
tg = session._subscription_task_group # type: ignore[attr-defined]
if tg is not None:
self._active_sessions.add(session_id)
tg.start_soon(self._ping_loop, session, session_id)

return await call_next(context)

async def _ping_loop(self, session: Any, session_id: int) -> None:
"""Send periodic pings until session ends."""
try:
while True:
await anyio.sleep(self.interval_ms / 1000)
await session.send_ping()
finally:
self._active_sessions.discard(session_id)
241 changes: 241 additions & 0 deletions tests/server/middleware/test_ping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
"""Tests for ping middleware."""

from unittest.mock import AsyncMock, MagicMock

import anyio
import pytest

from fastmcp import FastMCP
from fastmcp.client import Client
from fastmcp.server.middleware.ping import PingMiddleware


class TestPingMiddlewareInit:
"""Test PingMiddleware initialization."""

def test_init_default(self):
"""Test default initialization."""
middleware = PingMiddleware()
assert middleware.interval_ms == 30000
assert middleware._active_sessions == set()

def test_init_custom(self):
"""Test custom interval initialization."""
middleware = PingMiddleware(interval_ms=5000)
assert middleware.interval_ms == 5000

def test_init_invalid_interval_zero(self):
"""Test that zero interval raises ValueError."""
with pytest.raises(ValueError, match="interval_ms must be positive"):
PingMiddleware(interval_ms=0)

def test_init_invalid_interval_negative(self):
"""Test that negative interval raises ValueError."""
with pytest.raises(ValueError, match="interval_ms must be positive"):
PingMiddleware(interval_ms=-1000)


class TestPingMiddlewareOnMessage:
"""Test on_message hook behavior."""

async def test_starts_ping_task_on_first_message(self):
"""Test that ping task is started on first message from a session."""
middleware = PingMiddleware(interval_ms=1000)

mock_session = MagicMock()
mock_session._subscription_task_group = MagicMock()
mock_session._subscription_task_group.start_soon = MagicMock()

mock_context = MagicMock()
mock_context.fastmcp_context.session = mock_session

mock_call_next = AsyncMock(return_value="result")

result = await middleware.on_message(mock_context, mock_call_next)

assert result == "result"
assert id(mock_session) in middleware._active_sessions
mock_session._subscription_task_group.start_soon.assert_called_once()

async def test_does_not_start_duplicate_task(self):
"""Test that duplicate messages from same session don't spawn duplicate tasks."""
middleware = PingMiddleware(interval_ms=1000)

mock_session = MagicMock()
mock_session._subscription_task_group = MagicMock()
mock_session._subscription_task_group.start_soon = MagicMock()

mock_context = MagicMock()
mock_context.fastmcp_context.session = mock_session

mock_call_next = AsyncMock(return_value="result")

# First message
await middleware.on_message(mock_context, mock_call_next)
# Second message from same session
await middleware.on_message(mock_context, mock_call_next)
# Third message from same session
await middleware.on_message(mock_context, mock_call_next)

# Should only start task once
assert mock_session._subscription_task_group.start_soon.call_count == 1

async def test_starts_separate_task_per_session(self):
"""Test that different sessions get separate ping tasks."""
middleware = PingMiddleware(interval_ms=1000)

mock_session1 = MagicMock()
mock_session1._subscription_task_group = MagicMock()
mock_session1._subscription_task_group.start_soon = MagicMock()

mock_session2 = MagicMock()
mock_session2._subscription_task_group = MagicMock()
mock_session2._subscription_task_group.start_soon = MagicMock()

mock_context1 = MagicMock()
mock_context1.fastmcp_context.session = mock_session1

mock_context2 = MagicMock()
mock_context2.fastmcp_context.session = mock_session2

mock_call_next = AsyncMock(return_value="result")

await middleware.on_message(mock_context1, mock_call_next)
await middleware.on_message(mock_context2, mock_call_next)

mock_session1._subscription_task_group.start_soon.assert_called_once()
mock_session2._subscription_task_group.start_soon.assert_called_once()
assert len(middleware._active_sessions) == 2

async def test_skips_task_when_no_task_group(self):
"""Test graceful handling when session has no task group."""
middleware = PingMiddleware(interval_ms=1000)

mock_session = MagicMock()
mock_session._subscription_task_group = None

mock_context = MagicMock()
mock_context.fastmcp_context.session = mock_session

mock_call_next = AsyncMock(return_value="result")

result = await middleware.on_message(mock_context, mock_call_next)

assert result == "result"
# Session should NOT be added if task group is None
assert id(mock_session) not in middleware._active_sessions

async def test_skips_when_fastmcp_context_is_none(self):
"""Test that middleware passes through when fastmcp_context is None."""
middleware = PingMiddleware(interval_ms=1000)

mock_context = MagicMock()
mock_context.fastmcp_context = None

mock_call_next = AsyncMock(return_value="result")

result = await middleware.on_message(mock_context, mock_call_next)

assert result == "result"
assert len(middleware._active_sessions) == 0

async def test_skips_when_request_context_is_none(self):
"""Test that middleware passes through when request_context is None."""
middleware = PingMiddleware(interval_ms=1000)

mock_context = MagicMock()
mock_context.fastmcp_context = MagicMock()
mock_context.fastmcp_context.request_context = None

mock_call_next = AsyncMock(return_value="result")

result = await middleware.on_message(mock_context, mock_call_next)

assert result == "result"
assert len(middleware._active_sessions) == 0


class TestPingLoop:
"""Test the ping loop behavior."""

async def test_ping_loop_sends_pings_at_interval(self):
"""Test that ping loop sends pings at configured interval."""
middleware = PingMiddleware(interval_ms=50)

mock_session = MagicMock()
mock_session.send_ping = AsyncMock()

session_id = id(mock_session)
middleware._active_sessions.add(session_id)

# Run ping loop for a short time then cancel
with anyio.move_on_after(0.15):
await middleware._ping_loop(mock_session, session_id)

# Should have sent at least 2 pings in 150ms with 50ms interval
assert mock_session.send_ping.call_count >= 2

async def test_ping_loop_cleans_up_on_cancellation(self):
"""Test that session is removed from active sessions on cancellation."""
middleware = PingMiddleware(interval_ms=50)

mock_session = MagicMock()
mock_session.send_ping = AsyncMock()

session_id = 12345
middleware._active_sessions.add(session_id)

# Run and cancel the ping loop
with anyio.move_on_after(0.1):
await middleware._ping_loop(mock_session, session_id)

# Session should be cleaned up after cancellation
assert session_id not in middleware._active_sessions


class TestPingMiddlewareIntegration:
"""Integration tests for PingMiddleware with real FastMCP server."""

async def test_ping_middleware_registers_session(self):
"""Test that PingMiddleware registers sessions on first request."""
mcp = FastMCP("PingTestServer")
middleware = PingMiddleware(interval_ms=50)
mcp.add_middleware(middleware)

@mcp.tool
def hello() -> str:
return "Hello!"

assert len(middleware._active_sessions) == 0

async with Client(mcp) as client:
result = await client.call_tool("hello")
assert result.content[0].text == "Hello!"

# Should have registered the session
assert len(middleware._active_sessions) == 1

# Make another request - should not add duplicate
await client.call_tool("hello")
assert len(middleware._active_sessions) == 1

async def test_ping_task_cancelled_on_disconnect(self):
"""Test that ping task is properly cancelled when client disconnects."""
mcp = FastMCP("PingTestServer")
middleware = PingMiddleware(interval_ms=50)
mcp.add_middleware(middleware)

@mcp.tool
def hello() -> str:
return "Hello!"

async with Client(mcp) as client:
await client.call_tool("hello")
# Should have one active session
assert len(middleware._active_sessions) == 1

# After disconnect, give a moment for cleanup
await anyio.sleep(0.01)

# Session should be cleaned up
assert len(middleware._active_sessions) == 0