|
1 | 1 | import contextlib |
2 | 2 | import logging |
3 | | -from http import HTTPStatus |
4 | | -from uuid import uuid4 |
| 3 | +from collections.abc import AsyncIterator |
5 | 4 |
|
6 | 5 | import anyio |
7 | 6 | import click |
8 | 7 | import mcp.types as types |
9 | 8 | from mcp.server.lowlevel import Server |
10 | | -from mcp.server.streamable_http import ( |
11 | | - MCP_SESSION_ID_HEADER, |
12 | | - StreamableHTTPServerTransport, |
13 | | -) |
| 9 | +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager |
14 | 10 | from pydantic import AnyUrl |
15 | 11 | from starlette.applications import Starlette |
16 | | -from starlette.requests import Request |
17 | | -from starlette.responses import Response |
18 | 12 | from starlette.routing import Mount |
| 13 | +from starlette.types import Receive, Scope, Send |
19 | 14 |
|
20 | 15 | from .event_store import InMemoryEventStore |
21 | 16 |
|
22 | 17 | # Configure logging |
23 | 18 | logger = logging.getLogger(__name__) |
24 | 19 |
|
25 | | -# Global task group that will be initialized in the lifespan |
26 | | -task_group = None |
27 | | - |
28 | | -# Event store for resumability |
29 | | -# The InMemoryEventStore enables resumability support for StreamableHTTP transport. |
30 | | -# It stores SSE events with unique IDs, allowing clients to: |
31 | | -# 1. Receive event IDs for each SSE message |
32 | | -# 2. Resume streams by sending Last-Event-ID in GET requests |
33 | | -# 3. Replay missed events after reconnection |
34 | | -# Note: This in-memory implementation is for demonstration ONLY. |
35 | | -# For production, use a persistent storage solution. |
36 | | -event_store = InMemoryEventStore() |
37 | | - |
38 | | - |
39 | | -@contextlib.asynccontextmanager |
40 | | -async def lifespan(app): |
41 | | - """Application lifespan context manager for managing task group.""" |
42 | | - global task_group |
43 | | - |
44 | | - async with anyio.create_task_group() as tg: |
45 | | - task_group = tg |
46 | | - logger.info("Application started, task group initialized!") |
47 | | - try: |
48 | | - yield |
49 | | - finally: |
50 | | - logger.info("Application shutting down, cleaning up resources...") |
51 | | - if task_group: |
52 | | - tg.cancel_scope.cancel() |
53 | | - task_group = None |
54 | | - logger.info("Resources cleaned up successfully.") |
55 | | - |
56 | 20 |
|
57 | 21 | @click.command() |
58 | 22 | @click.option("--port", default=3000, help="Port to listen on for HTTP") |
@@ -156,60 +120,38 @@ async def list_tools() -> list[types.Tool]: |
156 | 120 | ) |
157 | 121 | ] |
158 | 122 |
|
159 | | - # We need to store the server instances between requests |
160 | | - server_instances = {} |
161 | | - # Lock to prevent race conditions when creating new sessions |
162 | | - session_creation_lock = anyio.Lock() |
| 123 | + # Create event store for resumability |
| 124 | + # The InMemoryEventStore enables resumability support for StreamableHTTP transport. |
| 125 | + # It stores SSE events with unique IDs, allowing clients to: |
| 126 | + # 1. Receive event IDs for each SSE message |
| 127 | + # 2. Resume streams by sending Last-Event-ID in GET requests |
| 128 | + # 3. Replay missed events after reconnection |
| 129 | + # Note: This in-memory implementation is for demonstration ONLY. |
| 130 | + # For production, use a persistent storage solution. |
| 131 | + event_store = InMemoryEventStore() |
| 132 | + |
| 133 | + # Create the session manager with our app and event store |
| 134 | + session_manager = StreamableHTTPSessionManager( |
| 135 | + app=app, |
| 136 | + event_store=event_store, # Enable resumability |
| 137 | + json_response=json_response, |
| 138 | + ) |
163 | 139 |
|
164 | 140 | # ASGI handler for streamable HTTP connections |
165 | | - async def handle_streamable_http(scope, receive, send): |
166 | | - request = Request(scope, receive) |
167 | | - request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER) |
168 | | - if ( |
169 | | - request_mcp_session_id is not None |
170 | | - and request_mcp_session_id in server_instances |
171 | | - ): |
172 | | - transport = server_instances[request_mcp_session_id] |
173 | | - logger.debug("Session already exists, handling request directly") |
174 | | - await transport.handle_request(scope, receive, send) |
175 | | - elif request_mcp_session_id is None: |
176 | | - # try to establish new session |
177 | | - logger.debug("Creating new transport") |
178 | | - # Use lock to prevent race conditions when creating new sessions |
179 | | - async with session_creation_lock: |
180 | | - new_session_id = uuid4().hex |
181 | | - http_transport = StreamableHTTPServerTransport( |
182 | | - mcp_session_id=new_session_id, |
183 | | - is_json_response_enabled=json_response, |
184 | | - event_store=event_store, # Enable resumability |
185 | | - ) |
186 | | - server_instances[http_transport.mcp_session_id] = http_transport |
187 | | - logger.info(f"Created new transport with session ID: {new_session_id}") |
188 | | - |
189 | | - async def run_server(task_status=None): |
190 | | - async with http_transport.connect() as streams: |
191 | | - read_stream, write_stream = streams |
192 | | - if task_status: |
193 | | - task_status.started() |
194 | | - await app.run( |
195 | | - read_stream, |
196 | | - write_stream, |
197 | | - app.create_initialization_options(), |
198 | | - ) |
199 | | - |
200 | | - if not task_group: |
201 | | - raise RuntimeError("Task group is not initialized") |
202 | | - |
203 | | - await task_group.start(run_server) |
204 | | - |
205 | | - # Handle the HTTP request and return the response |
206 | | - await http_transport.handle_request(scope, receive, send) |
207 | | - else: |
208 | | - response = Response( |
209 | | - "Bad Request: No valid session ID provided", |
210 | | - status_code=HTTPStatus.BAD_REQUEST, |
211 | | - ) |
212 | | - await response(scope, receive, send) |
| 141 | + async def handle_streamable_http( |
| 142 | + scope: Scope, receive: Receive, send: Send |
| 143 | + ) -> None: |
| 144 | + await session_manager.handle_request(scope, receive, send) |
| 145 | + |
| 146 | + @contextlib.asynccontextmanager |
| 147 | + async def lifespan(app: Starlette) -> AsyncIterator[None]: |
| 148 | + """Context manager for managing session manager lifecycle.""" |
| 149 | + async with session_manager.run(): |
| 150 | + logger.info("Application started with StreamableHTTP session manager!") |
| 151 | + try: |
| 152 | + yield |
| 153 | + finally: |
| 154 | + logger.info("Application shutting down...") |
213 | 155 |
|
214 | 156 | # Create an ASGI application using the transport |
215 | 157 | starlette_app = Starlette( |
|
0 commit comments