44Contains tests for both server and client sides of the StreamableHTTP transport.
55"""
66
7+ import json
78import multiprocessing
89import socket
910import time
1819import uvicorn
1920from pydantic import AnyUrl
2021from starlette .applications import Starlette
22+ from starlette .requests import Request
23+ from starlette .responses import Response
2124from starlette .routing import Mount
2225
2326import mcp .types as types
@@ -244,8 +247,46 @@ def create_app(
244247 return app
245248
246249
250+ def create_header_capture_app () -> Starlette :
251+ """Implement a minimal Starlette app that intercepts every request,
252+ extracts its headers, and responds with status 418 (Test Status code),
253+ embedding the captured headers as the JSON response body.
254+ We use this server solely to verify that the MCP Server is forwarding
255+ headers correctly."""
256+
257+ # Create a wrapper that captures headers and returns them in error response
258+ async def header_capture_wrapper (scope , receive , send ):
259+ # Capture headers
260+ request = Request (scope , receive = receive )
261+ headers = dict (request .headers )
262+
263+ # Return error response with headers in body
264+ response = Response (
265+ "[TESTING_HEADER_CAPTURE]:" + json .dumps ({"headers" : headers }),
266+ status_code = 418 ,
267+ )
268+ await response (scope , receive , send )
269+
270+ # Create an ASGI application that uses our wrapper
271+ app = Starlette (
272+ debug = True ,
273+ routes = [
274+ Mount ("/mcp" , app = header_capture_wrapper ),
275+ ],
276+ )
277+
278+ return app
279+
280+
281+ def _get_captured_headrs (str ) -> dict [str , str ]:
282+ return json .loads (str .split ("[TESTING_HEADER_CAPTURE]:" )[1 ])["headers" ]
283+
284+
247285def run_server (
248- port : int , is_json_response_enabled = False , event_store : EventStore | None = None
286+ port : int ,
287+ is_json_response_enabled = False ,
288+ event_store : EventStore | None = None ,
289+ testing_header_capture : bool = False ,
249290) -> None :
250291 """Run the test server.
251292
@@ -255,7 +296,11 @@ def run_server(
255296 event_store: Optional event store for testing resumability.
256297 """
257298
258- app = create_app (is_json_response_enabled , event_store )
299+ if testing_header_capture :
300+ app = create_header_capture_app ()
301+ else :
302+ app = create_app (is_json_response_enabled , event_store )
303+
259304 # Configure server
260305 config = uvicorn .Config (
261306 app = app ,
@@ -296,33 +341,48 @@ def json_server_port() -> int:
296341 return s .getsockname ()[1 ]
297342
298343
299- @ pytest . fixture
300- def basic_server ( basic_server_port : int ) -> Generator [ None , None , None ]:
301- """Start a basic server."""
344+ def _start_basic_server (
345+ basic_server_port : int , testing_header_capture : bool
346+ ) -> Generator [ None , None , None ]:
302347 proc = multiprocessing .Process (
303- target = run_server , kwargs = {"port" : basic_server_port }, daemon = True
348+ target = run_server ,
349+ kwargs = {
350+ "port" : basic_server_port ,
351+ "testing_header_capture" : testing_header_capture ,
352+ },
353+ daemon = True ,
304354 )
305355 proc .start ()
306356
307357 # Wait for server to be running
308358 max_attempts = 20
309- attempt = 0
310- while attempt < max_attempts :
359+ for attempt in range (max_attempts ):
311360 try :
312361 with socket .socket (socket .AF_INET , socket .SOCK_STREAM ) as s :
313362 s .connect (("127.0.0.1" , basic_server_port ))
314363 break
315364 except ConnectionRefusedError :
316365 time .sleep (0.1 )
317- attempt += 1
318366 else :
319367 raise RuntimeError (f"Server failed to start after { max_attempts } attempts" )
320368
321- yield
369+ try :
370+ yield
371+ finally :
372+ proc .kill ()
373+ proc .join (timeout = 2 )
322374
323- # Clean up
324- proc .kill ()
325- proc .join (timeout = 2 )
375+
376+ @pytest .fixture
377+ def basic_server (basic_server_port : int ) -> Generator [None , None , None ]:
378+ yield from _start_basic_server (basic_server_port , testing_header_capture = False )
379+
380+
381+ @pytest .fixture
382+ def basic_server_with_header_capture (
383+ basic_server_port : int ,
384+ ) -> Generator [None , None , None ]:
385+ yield from _start_basic_server (basic_server_port , testing_header_capture = True )
326386
327387
328388@pytest .fixture
@@ -1232,79 +1292,84 @@ class MockAuthClientProvider:
12321292 def __init__ (self , token : str ):
12331293 self .token = token
12341294
1235- async def get_token (self ) -> str :
1236- return self .token
1295+ async def get_auth_headers (self ) -> dict [ str , str ] :
1296+ return { "Authorization" : f"Bearer { self .token } " }
12371297
12381298
12391299@pytest .mark .anyio
1240- async def test_auth_client_provider_headers (basic_server , basic_server_url ):
1300+ async def test_auth_client_provider_headers (
1301+ basic_server_with_header_capture , basic_server_url
1302+ ):
12411303 """Test that auth token provider correctly sets Authorization header."""
12421304 # Create a mock token provider
1243- client_provider = MockAuthClientProvider ("test-token-123" )
1244- client_provider .get_token = AsyncMock (return_value = "test-token-123" )
1305+ client_provider = MockAuthClientProvider ("short-lived-token-123" )
12451306
12461307 # Create client with token provider
12471308 async with streamablehttp_client (
12481309 f"{ basic_server_url } /mcp" , auth_client_provider = client_provider
12491310 ) as (read_stream , write_stream , _ ):
12501311 async with ClientSession (read_stream , write_stream ) as session :
12511312 # Initialize the session
1252- result = await session .initialize ()
1253- assert isinstance (result , InitializeResult )
1254-
1255- # Make a request to verify headers
1256- tools = await session .list_tools ()
1257- assert len (tools .tools ) == 4
1258-
1259- client_provider .get_token .assert_called ()
1313+ with pytest .raises (McpError ) as mcpError :
1314+ _ = await session .initialize ()
1315+ assert (
1316+ _get_captured_headrs (mcpError .value .error .message )["Authorization" ]
1317+ == "Bearer short-lived-token-123"
1318+ )
12601319
12611320
12621321@pytest .mark .anyio
1263- async def test_auth_client_provider_token_update (basic_server , basic_server_url ):
1322+ async def test_auth_client_provider_token_called_on_every_request (
1323+ basic_server_with_header_capture , basic_server_url
1324+ ):
12641325 """Test that auth token provider can return different tokens."""
12651326 # Create a dynamic token provider
1266- client_provider = MockAuthClientProvider ("test-token-123" )
1267- client_provider .get_token = AsyncMock (return_value = "test-token-123" )
1327+ client_provider = MockAuthClientProvider ("short-lived-token-123" )
12681328
1269- # Create client with dynamic token provider
12701329 async with streamablehttp_client (
12711330 f"{ basic_server_url } /mcp" , auth_client_provider = client_provider
12721331 ) as (read_stream , write_stream , _ ):
12731332 async with ClientSession (read_stream , write_stream ) as session :
12741333 # Initialize the session
1275- result = await session .initialize ()
1276- assert isinstance (result , InitializeResult )
1277-
1278- # Make multiple requests to verify token updates
1279- for i in range (3 ):
1280- tools = await session .list_tools ()
1281- assert len (tools .tools ) == 4
1334+ with pytest .raises (McpError ) as mcpError :
1335+ _ = await session .initialize ()
1336+ assert (
1337+ _get_captured_headrs (mcpError .value .error .message )["Authorization" ]
1338+ == "Bearer short-lived-token-123"
1339+ )
12821340
1283- client_provider .get_token .call_count > 1
1341+ # Mock a new token and ensure the new token is returned
1342+ client_provider .get_auth_headers = AsyncMock (
1343+ return_value = {"Authorization" : "Bearer short-lived-token-456" }
1344+ )
1345+ with pytest .raises (McpError ) as mcpError :
1346+ _ = await session .initialize ()
1347+ assert (
1348+ _get_captured_headrs (mcpError .value .error .message )["Authorization" ]
1349+ == "Bearer short-lived-token-456"
1350+ )
12841351
12851352
12861353@pytest .mark .anyio
12871354async def test_auth_client_provider_headers_not_overridden (
1288- basic_server , basic_server_url
1355+ basic_server_with_header_capture , basic_server_url
12891356):
1290- """Test that auth token provider correctly sets Authorization header ."""
1357+ """Test that provided headers override auth client provider headers ."""
12911358 # Create a mock token provider
1292- client_provider = MockAuthClientProvider ("test-token-123" )
1293- client_provider .get_token = AsyncMock (return_value = "test-token-123" )
1359+ client_provider = MockAuthClientProvider ("short-lived-token" )
12941360
1295- # Create client with token provider
1361+ # Create client with token provider and custom headers
1362+ custom_headers = {"Authorization" : "Bearer original-long-lived-token" }
12961363 async with streamablehttp_client (
12971364 f"{ basic_server_url } /mcp" ,
12981365 auth_client_provider = client_provider ,
1299- headers = { "Authorization" : "test-token-123" } ,
1366+ headers = custom_headers ,
13001367 ) as (read_stream , write_stream , _ ):
13011368 async with ClientSession (read_stream , write_stream ) as session :
1302- # Initialize the session
1303- result = await session .initialize ()
1304- assert isinstance (result , InitializeResult )
1305-
1306- # Make a request to verify headers
1307- tools = await session .list_tools ()
1308- assert len (tools .tools ) == 4
1309-
1310- client_provider .get_token .assert_not_called ()
1369+ # Original token is used and not short-lived-token from the provider
1370+ with pytest .raises (McpError ) as mcpError :
1371+ _ = await session .initialize ()
1372+ assert (
1373+ _get_captured_headrs (mcpError .value .error .message )["Authorization" ]
1374+ == "Bearer original-long-lived-token"
1375+ )
0 commit comments