diff --git a/docs/servers/middleware.mdx b/docs/servers/middleware.mdx
index 341d910461..8ffd683c0a 100644
--- a/docs/servers/middleware.mdx
+++ b/docs/servers/middleware.mdx
@@ -108,6 +108,32 @@ This hierarchy allows you to target your middleware logic with the right level o
The `on_initialize` hook receives the client's initialization request but **returns `None`** rather than a result. The initialization response is handled internally by the MCP protocol and cannot be modified by middleware. This hook is useful for client detection, logging connections, or initializing session state, but not for modifying the initialization handshake itself.
+**Example:**
+
+```python
+from fastmcp.server.middleware import Middleware, MiddlewareContext
+from mcp import McpError
+from mcp.types import ErrorData
+
+class InitializationMiddleware(Middleware):
+ async def on_initialize(self, context: MiddlewareContext, call_next):
+ # Check client capabilities before initialization
+ client_info = context.message.params.get("clientInfo", {})
+ client_name = client_info.get("name", "unknown")
+
+ # Reject unsupported clients BEFORE call_next
+ if client_name == "unsupported-client":
+ raise McpError(ErrorData(code=-32000, message="This client is not supported"))
+
+ # Log successful initialization
+ await call_next(context)
+ print(f"Client {client_name} initialized successfully")
+```
+
+
+If you raise `McpError` in `on_initialize` **after** calling `call_next()`, the error will only be logged and will not be sent to the client. The initialization response has already been sent at that point. Always raise `McpError` **before** `call_next()` if you want to reject the initialization.
+
+
### MCP Session Availability in Middleware
@@ -787,4 +813,4 @@ class CustomHeaderMiddleware(Middleware):
return result
mcp.add_middleware(CustomHeaderMiddleware())
-```
\ No newline at end of file
+```
diff --git a/src/fastmcp/server/low_level.py b/src/fastmcp/server/low_level.py
index e2f2603aef..536b6aa1f4 100644
--- a/src/fastmcp/server/low_level.py
+++ b/src/fastmcp/server/low_level.py
@@ -7,6 +7,7 @@
import anyio
import mcp.types
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
+from mcp import McpError
from mcp.server.lowlevel.server import (
LifespanResultT,
NotificationOptions,
@@ -104,9 +105,23 @@ async def call_original_handler(
fastmcp_context=fastmcp_ctx,
)
- return await self.fastmcp._apply_middleware(
- mw_context, call_original_handler
- )
+ try:
+ return await self.fastmcp._apply_middleware(
+ mw_context, call_original_handler
+ )
+ except McpError as e:
+ # McpError can be thrown from middleware in `on_initialize`
+ # send the error to responder.
+ if not responder._completed:
+ with responder:
+ await responder.respond(e.error)
+ else:
+ # Don't re-raise: prevents responding to initialize request twice
+ logger.warning(
+ "Received McpError but responder is already completed. "
+ "Cannot send error response as response was already sent.",
+ exc_info=e,
+ )
# Fall through to default handling (task methods now handled via registered handlers)
return await super()._received_request(responder)
diff --git a/tests/server/middleware/test_initialization_middleware.py b/tests/server/middleware/test_initialization_middleware.py
index 9552413b79..5c2266268e 100644
--- a/tests/server/middleware/test_initialization_middleware.py
+++ b/tests/server/middleware/test_initialization_middleware.py
@@ -3,6 +3,9 @@
from typing import Any
import mcp.types as mt
+import pytest
+from mcp import McpError
+from mcp.types import ErrorData
from fastmcp import Client, FastMCP
from fastmcp.server.middleware import CallNext, Middleware, MiddlewareContext
@@ -286,3 +289,87 @@ async def on_initialize(
assert middleware.initialize_result.serverInfo.name == "TestServer"
assert middleware.initialize_result.protocolVersion is not None
assert middleware.initialize_result.capabilities is not None
+
+
+async def test_middleware_mcp_error_during_initialization():
+ """Test that McpError raised in middleware during initialization is sent to client."""
+ server = FastMCP("TestServer")
+
+ class ErrorThrowingMiddleware(Middleware):
+ async def on_initialize(
+ self,
+ context: MiddlewareContext[mt.InitializeRequest],
+ call_next: CallNext[mt.InitializeRequest, None],
+ ) -> None:
+ raise McpError(
+ ErrorData(
+ code=mt.INVALID_PARAMS, message="Invalid initialization parameters"
+ )
+ )
+
+ server.add_middleware(ErrorThrowingMiddleware())
+
+ with pytest.raises(McpError) as exc_info:
+ async with Client(server):
+ pass
+
+ assert exc_info.value.error.message == "Invalid initialization parameters"
+ assert exc_info.value.error.code == mt.INVALID_PARAMS
+
+
+async def test_middleware_mcp_error_before_call_next():
+ """Test McpError raised before calling next middleware."""
+ server = FastMCP("TestServer")
+
+ class EarlyErrorMiddleware(Middleware):
+ async def on_initialize(
+ self,
+ context: MiddlewareContext[mt.InitializeRequest],
+ call_next: CallNext[mt.InitializeRequest, None],
+ ) -> None:
+ raise McpError(
+ ErrorData(code=mt.INVALID_REQUEST, message="Request validation failed")
+ )
+
+ server.add_middleware(EarlyErrorMiddleware())
+
+ with pytest.raises(McpError) as exc_info:
+ async with Client(server):
+ pass
+
+ assert exc_info.value.error.message == "Request validation failed"
+ assert exc_info.value.error.code == mt.INVALID_REQUEST
+
+
+async def test_middleware_mcp_error_after_call_next():
+ """Test that McpError raised after call_next doesn't break the connection.
+
+ When an error is raised after call_next, the responder has already completed,
+ so the error is caught but not sent to the responder (checked via _completed flag).
+ """
+ server = FastMCP("TestServer")
+
+ class PostProcessingErrorMiddleware(Middleware):
+ def __init__(self):
+ super().__init__()
+ self.error_raised = False
+
+ async def on_initialize(
+ self,
+ context: MiddlewareContext[mt.InitializeRequest],
+ call_next: CallNext[mt.InitializeRequest, mt.InitializeResult | None],
+ ) -> mt.InitializeResult | None:
+ await call_next(context)
+ self.error_raised = True
+ raise McpError(
+ ErrorData(code=mt.INTERNAL_ERROR, message="Post-processing failed")
+ )
+
+ middleware = PostProcessingErrorMiddleware()
+ server.add_middleware(middleware)
+
+ # Error is logged but not re-raised to prevent duplicate response
+ async with Client(server):
+ pass
+
+ assert middleware.error_raised is True