Skip to content
10 changes: 5 additions & 5 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@
class SamplingFnT(Protocol):
async def __call__(
self,
context: RequestContext["ClientSession", Any],
context: RequestContext["ClientSession", Any, Any],
params: types.CreateMessageRequestParams,
) -> types.CreateMessageResult | types.ErrorData: ...


class ListRootsFnT(Protocol):
async def __call__(
self, context: RequestContext["ClientSession", Any]
self, context: RequestContext["ClientSession", Any, Any]
) -> types.ListRootsResult | types.ErrorData: ...


Expand Down Expand Up @@ -53,7 +53,7 @@ async def _default_message_handler(


async def _default_sampling_callback(
context: RequestContext["ClientSession", Any],
context: RequestContext["ClientSession", Any, Any],
params: types.CreateMessageRequestParams,
) -> types.CreateMessageResult | types.ErrorData:
return types.ErrorData(
Expand All @@ -63,7 +63,7 @@ async def _default_sampling_callback(


async def _default_list_roots_callback(
context: RequestContext["ClientSession", Any],
context: RequestContext["ClientSession", Any, Any],
) -> types.ListRootsResult | types.ErrorData:
return types.ErrorData(
code=types.INVALID_REQUEST,
Expand Down Expand Up @@ -367,7 +367,7 @@ async def send_roots_list_changed(self) -> None:
async def _received_request(
self, responder: RequestResponder[types.ServerRequest, types.ClientResult]
) -> None:
ctx = RequestContext[ClientSession, Any](
ctx = RequestContext[ClientSession, Any, Any](
request_id=responder.request_id,
meta=responder.request_meta,
session=self,
Expand Down
16 changes: 11 additions & 5 deletions src/mcp/server/fastmcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,11 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
def lifespan_wrapper(
app: FastMCP,
lifespan: Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]],
) -> Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[object]]:
) -> Callable[
[MCPServer[LifespanResultT, Request]], AbstractAsyncContextManager[object]
]:
@asynccontextmanager
async def wrap(s: MCPServer[LifespanResultT]) -> AsyncIterator[object]:
async def wrap(s: MCPServer[LifespanResultT, Request]) -> AsyncIterator[object]:
async with lifespan(app) as context:
yield context

Expand Down Expand Up @@ -684,6 +686,7 @@ async def handle_sse(scope: Scope, receive: Receive, send: Send):
streams[0],
streams[1],
self._mcp_server.create_initialization_options(),
request=request,
)
return Response()

Expand Down Expand Up @@ -927,13 +930,14 @@ def my_tool(x: int, ctx: Context) -> str:
The context is optional - tools that don't need it can omit the parameter.
"""

_request_context: RequestContext[ServerSessionT, LifespanContextT] | None
_request_context: RequestContext[ServerSessionT, LifespanContextT, Request] | None
_fastmcp: FastMCP | None

def __init__(
self,
*,
request_context: RequestContext[ServerSessionT, LifespanContextT] | None = None,
request_context: RequestContext[ServerSessionT, LifespanContextT, Request]
| None = None,
fastmcp: FastMCP | None = None,
**kwargs: Any,
):
Expand All @@ -949,7 +953,9 @@ def fastmcp(self) -> FastMCP:
return self._fastmcp

@property
def request_context(self) -> RequestContext[ServerSessionT, LifespanContextT]:
def request_context(
self,
) -> RequestContext[ServerSessionT, LifespanContextT, Request]:
"""Access to the underlying request context."""
if self._request_context is None:
raise ValueError("Context is not available outside of a request")
Expand Down
27 changes: 20 additions & 7 deletions src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ async def main():
from mcp.server.models import InitializationOptions
from mcp.server.session import ServerSession
from mcp.server.stdio import stdio_server as stdio_server
from mcp.shared.context import RequestContext
from mcp.shared.context import RequestContext, RequestT
from mcp.shared.exceptions import McpError
from mcp.shared.message import SessionMessage
from mcp.shared.session import RequestResponder
Expand All @@ -93,7 +93,7 @@ async def main():
LifespanResultT = TypeVar("LifespanResultT")

# This will be properly typed in each Server instance's context
request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any]] = (
request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any, Any]] = (
contextvars.ContextVar("request_ctx")
)

Expand All @@ -111,7 +111,7 @@ def __init__(


@asynccontextmanager
async def lifespan(server: Server[LifespanResultT]) -> AsyncIterator[object]:
async def lifespan(server: Server[LifespanResultT, RequestT]) -> AsyncIterator[object]:
"""Default lifespan context manager that does nothing.

Args:
Expand All @@ -123,14 +123,15 @@ async def lifespan(server: Server[LifespanResultT]) -> AsyncIterator[object]:
yield {}


class Server(Generic[LifespanResultT]):
class Server(Generic[LifespanResultT, RequestT]):
def __init__(
self,
name: str,
version: str | None = None,
instructions: str | None = None,
lifespan: Callable[
[Server[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]
[Server[LifespanResultT, RequestT]],
AbstractAsyncContextManager[LifespanResultT],
] = lifespan,
):
self.name = name
Expand Down Expand Up @@ -215,7 +216,9 @@ def get_capabilities(
)

@property
def request_context(self) -> RequestContext[ServerSession, LifespanResultT]:
def request_context(
self,
) -> RequestContext[ServerSession, LifespanResultT, RequestT]:
"""If called outside of a request context, this will raise a LookupError."""
return request_ctx.get()

Expand Down Expand Up @@ -486,6 +489,7 @@ async def run(
# but also make tracing exceptions much easier during testing and when using
# in-process servers.
raise_exceptions: bool = False,
request: RequestT | None = None,
# When True, the server is stateless and
# clients can perform initialization with any node. The client must still follow
# the initialization lifecycle, but can do so with any available node
Expand Down Expand Up @@ -513,6 +517,7 @@ async def run(
session,
lifespan_context,
raise_exceptions,
request,
)

async def _handle_message(
Expand All @@ -523,6 +528,7 @@ async def _handle_message(
session: ServerSession,
lifespan_context: LifespanResultT,
raise_exceptions: bool = False,
request: RequestT | None = None,
):
with warnings.catch_warnings(record=True) as w:
# TODO(Marcelo): We should be checking if message is Exception here.
Expand All @@ -532,7 +538,12 @@ async def _handle_message(
):
with responder:
await self._handle_request(
message, req, session, lifespan_context, raise_exceptions
message,
req,
session,
lifespan_context,
raise_exceptions,
request,
)
case types.ClientNotification(root=notify):
await self._handle_notification(notify)
Expand All @@ -547,6 +558,7 @@ async def _handle_request(
session: ServerSession,
lifespan_context: LifespanResultT,
raise_exceptions: bool,
request: RequestT | None,
):
logger.info(f"Processing request of type {type(req).__name__}")
if type(req) in self.request_handlers:
Expand All @@ -563,6 +575,7 @@ async def _handle_request(
message.request_meta,
session,
lifespan_context,
request=request,
)
)
response = await handler(req)
Expand Down
4 changes: 3 additions & 1 deletion src/mcp/shared/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@

SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
LifespanContextT = TypeVar("LifespanContextT")
RequestT = TypeVar("RequestT")


@dataclass
class RequestContext(Generic[SessionT, LifespanContextT]):
class RequestContext(Generic[SessionT, LifespanContextT, RequestT]):
request_id: RequestId
meta: RequestParams.Meta | None
session: SessionT
lifespan_context: LifespanContextT
request: RequestT | None = None
2 changes: 1 addition & 1 deletion src/mcp/shared/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ async def create_client_server_memory_streams() -> (

@asynccontextmanager
async def create_connected_server_and_client_session(
server: Server[Any],
server: Server[Any, Any],
read_timeout_seconds: timedelta | None = None,
sampling_callback: SamplingFnT | None = None,
list_roots_callback: ListRootsFnT | None = None,
Expand Down
3 changes: 2 additions & 1 deletion src/mcp/shared/progress.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections.abc import Generator
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Generic
from typing import Any, Generic

from pydantic import BaseModel

Expand Down Expand Up @@ -62,6 +62,7 @@ def progress(
ReceiveNotificationT,
],
LifespanContextT,
Any,
],
total: float | None = None,
) -> Generator[
Expand Down
2 changes: 1 addition & 1 deletion tests/client/test_list_roots_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ async def test_list_roots_callback():
)

async def list_roots_callback(
context: RequestContext[ClientSession, None],
context: RequestContext[ClientSession, None, None],
) -> ListRootsResult:
return callback_return

Expand Down
2 changes: 1 addition & 1 deletion tests/client/test_sampling_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ async def test_sampling_callback():
)

async def sampling_callback(
context: RequestContext[ClientSession, None],
context: RequestContext[ClientSession, None, None],
params: CreateMessageRequestParams,
) -> CreateMessageResult:
return callback_return
Expand Down
Loading