diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index 63def1785..6ded93f79 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -191,6 +191,26 @@ async def open_connection(url): assert is_open +@pytest.mark.asyncio +@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) +@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) +async def test_extra_headers(ws_protocol_cls, http_protocol_cls): + class App(WebSocketResponse): + async def websocket_connect(self, message): + await self.send( + {"type": "websocket.accept", "headers": [(b"extra", b"header")]} + ) + + async def open_connection(url): + async with websockets.connect(url) as websocket: + return websocket.response_headers + + config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off") + async with run_server(config): + extra_headers = await open_connection("ws://127.0.0.1:8000") + assert extra_headers.get("extra") == "header" + + @pytest.mark.asyncio @pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) diff --git a/uvicorn/protocols/websockets/websockets_impl.py b/uvicorn/protocols/websockets/websockets_impl.py index 6c61f3ea5..44c1f2306 100644 --- a/uvicorn/protocols/websockets/websockets_impl.py +++ b/uvicorn/protocols/websockets/websockets_impl.py @@ -87,6 +87,7 @@ def __init__( ping_timeout=self.config.ws_ping_timeout, extensions=extensions, logger=logging.getLogger("uvicorn.error"), + extra_headers=[], ) def connection_made(self, transport): @@ -236,6 +237,13 @@ async def asgi_send(self, message): ) self.initial_response = None self.accepted_subprotocol = message.get("subprotocol") + if "headers" in message: + self.extra_headers.extend( + # ASGI spec requires bytes + # But for compability we need to convert it to strings + (name.decode("latin-1"), value.decode("latin-1")) + for name, value in message.get("headers") + ) self.handshake_started_event.set() elif message_type == "websocket.close": diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index 6ed3ab702..bf4c90441 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -257,12 +257,15 @@ async def send(self, message): ) self.handshake_complete = True subprotocol = message.get("subprotocol") + extra_headers = message.get("headers", []) extensions = [] if self.config.ws_per_message_deflate: extensions.append(PerMessageDeflate()) output = self.conn.send( wsproto.events.AcceptConnection( - subprotocol=subprotocol, extensions=extensions + subprotocol=subprotocol, + extensions=extensions, + extra_headers=extra_headers, ) ) self.transport.write(output)