Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion docs/servers/middleware.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,32 @@ This hierarchy allows you to target your middleware logic with the right level o
The `on_initialize` hook receives the client's initialization request but **returns `None`** rather than a result. The initialization response is handled internally by the MCP protocol and cannot be modified by middleware. This hook is useful for client detection, logging connections, or initializing session state, but not for modifying the initialization handshake itself.
</Note>

**Example:**

```python
from fastmcp.server.middleware import Middleware, MiddlewareContext
from mcp import McpError
from mcp.types import ErrorData

class InitializationMiddleware(Middleware):
async def on_initialize(self, context: MiddlewareContext, call_next):
# Check client capabilities before initialization
client_info = context.message.params.get("clientInfo", {})
client_name = client_info.get("name", "unknown")

# Reject unsupported clients BEFORE call_next
if client_name == "unsupported-client":
raise McpError(ErrorData(code=-32000, message="This client is not supported"))

# Log successful initialization
await call_next(context)
print(f"Client {client_name} initialized successfully")
```

<Warning>
If you raise `McpError` in `on_initialize` **after** calling `call_next()`, the error will only be logged and will not be sent to the client. The initialization response has already been sent at that point. Always raise `McpError` **before** `call_next()` if you want to reject the initialization.
</Warning>

### MCP Session Availability in Middleware

<VersionBadge version="2.13.1" />
Expand Down Expand Up @@ -787,4 +813,4 @@ class CustomHeaderMiddleware(Middleware):
return result

mcp.add_middleware(CustomHeaderMiddleware())
```
```
21 changes: 18 additions & 3 deletions src/fastmcp/server/low_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import anyio
import mcp.types
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from mcp import McpError
from mcp.server.lowlevel.server import (
LifespanResultT,
NotificationOptions,
Expand Down Expand Up @@ -104,9 +105,23 @@ async def call_original_handler(
fastmcp_context=fastmcp_ctx,
)

return await self.fastmcp._apply_middleware(
mw_context, call_original_handler
)
try:
return await self.fastmcp._apply_middleware(
mw_context, call_original_handler
)
except McpError as e:
# McpError can be thrown from middleware in `on_initialize`
# send the error to responder.
if not responder._completed:
with responder:
await responder.respond(e.error)
else:
# Don't re-raise: prevents responding to initialize request twice
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logger.warning(
"Received McpError but responder is already completed. "
"Cannot send error response as response was already sent.",
exc_info=e,
)

# Fall through to default handling (task methods now handled via registered handlers)
return await super()._received_request(responder)
Expand Down
87 changes: 87 additions & 0 deletions tests/server/middleware/test_initialization_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
from typing import Any

import mcp.types as mt
import pytest
from mcp import McpError
from mcp.types import ErrorData

from fastmcp import Client, FastMCP
from fastmcp.server.middleware import CallNext, Middleware, MiddlewareContext
Expand Down Expand Up @@ -286,3 +289,87 @@ async def on_initialize(
assert middleware.initialize_result.serverInfo.name == "TestServer"
assert middleware.initialize_result.protocolVersion is not None
assert middleware.initialize_result.capabilities is not None


async def test_middleware_mcp_error_during_initialization():
"""Test that McpError raised in middleware during initialization is sent to client."""
server = FastMCP("TestServer")

class ErrorThrowingMiddleware(Middleware):
async def on_initialize(
self,
context: MiddlewareContext[mt.InitializeRequest],
call_next: CallNext[mt.InitializeRequest, None],
) -> None:
raise McpError(
ErrorData(
code=mt.INVALID_PARAMS, message="Invalid initialization parameters"
)
)

server.add_middleware(ErrorThrowingMiddleware())

with pytest.raises(McpError) as exc_info:
async with Client(server):
pass

assert exc_info.value.error.message == "Invalid initialization parameters"
assert exc_info.value.error.code == mt.INVALID_PARAMS


async def test_middleware_mcp_error_before_call_next():
"""Test McpError raised before calling next middleware."""
server = FastMCP("TestServer")

class EarlyErrorMiddleware(Middleware):
async def on_initialize(
self,
context: MiddlewareContext[mt.InitializeRequest],
call_next: CallNext[mt.InitializeRequest, None],
) -> None:
raise McpError(
ErrorData(code=mt.INVALID_REQUEST, message="Request validation failed")
)

server.add_middleware(EarlyErrorMiddleware())

with pytest.raises(McpError) as exc_info:
async with Client(server):
pass

assert exc_info.value.error.message == "Request validation failed"
assert exc_info.value.error.code == mt.INVALID_REQUEST


async def test_middleware_mcp_error_after_call_next():
"""Test that McpError raised after call_next doesn't break the connection.

When an error is raised after call_next, the responder has already completed,
so the error is caught but not sent to the responder (checked via _completed flag).
"""
server = FastMCP("TestServer")

class PostProcessingErrorMiddleware(Middleware):
def __init__(self):
super().__init__()
self.error_raised = False

async def on_initialize(
self,
context: MiddlewareContext[mt.InitializeRequest],
call_next: CallNext[mt.InitializeRequest, mt.InitializeResult | None],
) -> mt.InitializeResult | None:
await call_next(context)
self.error_raised = True
raise McpError(
ErrorData(code=mt.INTERNAL_ERROR, message="Post-processing failed")
)

middleware = PostProcessingErrorMiddleware()
server.add_middleware(middleware)

# Error is logged but not re-raised to prevent duplicate response
async with Client(server):
pass

assert middleware.error_raised is True