Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/mcp/server/fastmcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from mcp.server.stdio import stdio_server
from mcp.server.streamable_http import EventStore
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
from mcp.server.transport_security import TransportSecuritySettings
from mcp.shared.context import LifespanContextT, RequestContext, RequestT
from mcp.types import (
AnyFunction,
Expand Down Expand Up @@ -119,6 +120,9 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
) = Field(None, description="Lifespan context manager")

auth: AuthSettings | None = None

# Transport security settings (DNS rebinding protection)
transport_security: TransportSecuritySettings | None = None


def lifespan_wrapper(
Expand Down Expand Up @@ -672,6 +676,7 @@ def sse_app(self, mount_path: str | None = None) -> Starlette:

sse = SseServerTransport(
normalized_message_endpoint,
security_settings=self.settings.transport_security,
)

async def handle_sse(scope: Scope, receive: Receive, send: Send):
Expand Down Expand Up @@ -779,6 +784,7 @@ def streamable_http_app(self) -> Starlette:
event_store=self._event_store,
json_response=self.settings.json_response,
stateless=self.settings.stateless_http, # Use the stateless setting
security_settings=self.settings.transport_security,
)

# Create the ASGI handler
Expand Down
26 changes: 25 additions & 1 deletion src/mcp/server/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ async def handle_sse(request):
from starlette.types import Receive, Scope, Send

import mcp.types as types
from mcp.server.transport_security import (
TransportSecurityMiddleware,
TransportSecuritySettings,
)
from mcp.shared.message import ServerMessageMetadata, SessionMessage

logger = logging.getLogger(__name__)
Expand All @@ -71,16 +75,24 @@ class SseServerTransport:

_endpoint: str
_read_stream_writers: dict[UUID, MemoryObjectSendStream[SessionMessage | Exception]]
_security: TransportSecurityMiddleware

def __init__(self, endpoint: str) -> None:
def __init__(
self, endpoint: str, security_settings: TransportSecuritySettings | None = None
) -> None:
"""
Creates a new SSE server transport, which will direct the client to POST
messages to the relative or absolute URL given.

Args:
endpoint: The relative or absolute URL for POST messages.
security_settings: Optional security settings for DNS rebinding protection.
"""

super().__init__()
self._endpoint = endpoint
self._read_stream_writers = {}
self._security = TransportSecurityMiddleware(security_settings)
logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}")

@asynccontextmanager
Expand All @@ -89,6 +101,13 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send):
logger.error("connect_sse received non-HTTP request")
raise ValueError("connect_sse can only handle HTTP requests")

# Validate request headers for DNS rebinding protection
request = Request(scope, receive)
error_response = await self._security.validate_request(request, is_post=False)
if error_response:
await error_response(scope, receive, send)
raise ValueError("Request validation failed")

logger.debug("Setting up SSE connection")
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
Expand Down Expand Up @@ -169,6 +188,11 @@ async def handle_post_message(
) -> None:
logger.debug("Handling POST message")
request = Request(scope, receive)

# Validate request headers for DNS rebinding protection
error_response = await self._security.validate_request(request, is_post=True)
if error_response:
return await error_response(scope, receive, send)

session_id_param = request.query_params.get("session_id")
if session_id_param is None:
Expand Down
16 changes: 16 additions & 0 deletions src/mcp/server/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
from starlette.responses import Response
from starlette.types import Receive, Scope, Send

from mcp.server.transport_security import (
TransportSecurityMiddleware,
TransportSecuritySettings,
)
from mcp.shared.message import ServerMessageMetadata, SessionMessage
from mcp.types import (
INTERNAL_ERROR,
Expand Down Expand Up @@ -131,12 +135,14 @@ class StreamableHTTPServerTransport:
_read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] | None = None
_write_stream: MemoryObjectSendStream[SessionMessage] | None = None
_write_stream_reader: MemoryObjectReceiveStream[SessionMessage] | None = None
_security: TransportSecurityMiddleware

def __init__(
self,
mcp_session_id: str | None,
is_json_response_enabled: bool = False,
event_store: EventStore | None = None,
security_settings: TransportSecuritySettings | None = None,
) -> None:
"""
Initialize a new StreamableHTTP server transport.
Expand All @@ -149,6 +155,7 @@ def __init__(
event_store: Event store for resumability support. If provided,
resumability will be enabled, allowing clients to
reconnect and resume messages.
security_settings: Optional security settings for DNS rebinding protection.

Raises:
ValueError: If the session ID contains invalid characters.
Expand All @@ -163,6 +170,7 @@ def __init__(
self.mcp_session_id = mcp_session_id
self.is_json_response_enabled = is_json_response_enabled
self._event_store = event_store
self._security = TransportSecurityMiddleware(security_settings)
self._request_streams: dict[
RequestId,
tuple[
Expand Down Expand Up @@ -260,6 +268,14 @@ async def _clean_up_memory_streams(self, request_id: RequestId) -> None:
async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None:
"""Application entry point that handles all HTTP requests"""
request = Request(scope, receive)

# Validate request headers for DNS rebinding protection
is_post = request.method == "POST"
error_response = await self._security.validate_request(request, is_post=is_post)
if error_response:
await error_response(scope, receive, send)
return

if self._terminated:
# If the session has been terminated, return 404 Not Found
response = self._create_error_response(
Expand Down
5 changes: 5 additions & 0 deletions src/mcp/server/streamable_http_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
EventStore,
StreamableHTTPServerTransport,
)
from mcp.server.transport_security import TransportSecuritySettings

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -60,11 +61,13 @@ def __init__(
event_store: EventStore | None = None,
json_response: bool = False,
stateless: bool = False,
security_settings: TransportSecuritySettings | None = None,
):
self.app = app
self.event_store = event_store
self.json_response = json_response
self.stateless = stateless
self.security_settings = security_settings

# Session tracking (only used if not stateless)
self._session_creation_lock = anyio.Lock()
Expand Down Expand Up @@ -162,6 +165,7 @@ async def _handle_stateless_request(
mcp_session_id=None, # No session tracking in stateless mode
is_json_response_enabled=self.json_response,
event_store=None, # No event store in stateless mode
security_settings=self.security_settings,
)

# Start server in a new task
Expand Down Expand Up @@ -222,6 +226,7 @@ async def _handle_stateful_request(
mcp_session_id=new_session_id,
is_json_response_enabled=self.json_response,
event_store=self.event_store, # May be None (no resumability)
security_settings=self.security_settings,
)

assert http_transport.mcp_session_id is not None
Expand Down
133 changes: 133 additions & 0 deletions src/mcp/server/transport_security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
"""DNS rebinding protection for MCP server transports."""

import logging

from pydantic import BaseModel, Field
from starlette.requests import Request
from starlette.responses import Response

logger = logging.getLogger(__name__)


class TransportSecuritySettings(BaseModel):
"""Settings for MCP transport security features.

These settings help protect against DNS rebinding attacks by validating
incoming request headers.
"""

enable_dns_rebinding_protection: bool = Field(
default=True,
description="Enable DNS rebinding protection (recommended for production)"
)

allowed_hosts: list[str] = Field(
default=[],
description="List of allowed Host header values. Only applies when " +
"enable_dns_rebinding_protection is True."
)

allowed_origins: list[str] = Field(
default=[],
description="List of allowed Origin header values. Only applies when " +
"enable_dns_rebinding_protection is True."
)


class TransportSecurityMiddleware:
"""Middleware to enforce DNS rebinding protection for MCP transport endpoints."""

def __init__(self, settings: TransportSecuritySettings | None = None):
# If not specified, disable DNS rebinding protection by default
# for backwards compatibility
self.settings = settings or TransportSecuritySettings(
enable_dns_rebinding_protection=False
)

def _validate_host(self, host: str | None) -> bool:
"""Validate the Host header against allowed values."""
if not self.settings.enable_dns_rebinding_protection:
return True

if not host:
logger.warning("Missing Host header in request")
return False

# Check exact match first
if host in self.settings.allowed_hosts:
return True

# Check wildcard port patterns
for allowed in self.settings.allowed_hosts:
if allowed.endswith(":*"):
# Extract base host from pattern
base_host = allowed[:-2]
# Check if the actual host starts with base host and has a port
if host.startswith(base_host + ":"):
return True

logger.warning(f"Invalid Host header: {host}")
return False

def _validate_origin(self, origin: str | None) -> bool:
"""Validate the Origin header against allowed values."""
if not self.settings.enable_dns_rebinding_protection:
return True

# Origin can be absent for same-origin requests
if not origin:
return True

# Check exact match first
if origin in self.settings.allowed_origins:
return True

# Check wildcard port patterns
for allowed in self.settings.allowed_origins:
if allowed.endswith(":*"):
# Extract base origin from pattern
base_origin = allowed[:-2]
# Check if the actual origin starts with base origin and has a port
if origin.startswith(base_origin + ":"):
return True

logger.warning(f"Invalid Origin header: {origin}")
return False

def _validate_content_type(self, content_type: str | None) -> bool:
"""Validate the Content-Type header for POST requests."""
if not content_type:
logger.warning("Missing Content-Type header in POST request")
return False

# Content-Type must start with application/json
if not content_type.lower().startswith("application/json"):
logger.warning(f"Invalid Content-Type header: {content_type}")
return False

return True

async def validate_request(
self, request: Request, is_post: bool = False
) -> Response | None:
"""Validate request headers for DNS rebinding protection.

Returns None if validation passes, or an error Response if validation fails.
"""
# Validate Host header
host = request.headers.get("host")
if not self._validate_host(host):
return Response("Invalid Host header", status_code=400)

# Validate Origin header
origin = request.headers.get("origin")
if not self._validate_origin(origin):
return Response("Invalid Origin header", status_code=400)

# Validate Content-Type for POST requests
if is_post:
content_type = request.headers.get("content-type")
if not self._validate_content_type(content_type):
return Response("Invalid Content-Type header", status_code=400)

return None
Loading
Loading