77from dataclasses import dataclass
88from typing import Any , Callable
99
10+ from exceptiongroup import BaseExceptionGroup
1011from starlette .applications import Starlette
1112from starlette .middleware .cors import CORSMiddleware
1213from starlette .requests import Request
@@ -137,8 +138,6 @@ async def serve_index(request: Request) -> HTMLResponse:
137138def _setup_single_view_dispatcher_route (
138139 options : Options , app : Starlette , component : RootComponentConstructor
139140) -> None :
140- @app .websocket_route (str (STREAM_PATH ))
141- @app .websocket_route (f"{ STREAM_PATH } /{{path:path}}" )
142141 async def model_stream (socket : WebSocket ) -> None :
143142 await socket .accept ()
144143 send , recv = _make_send_recv_callbacks (socket )
@@ -162,8 +161,16 @@ async def model_stream(socket: WebSocket) -> None:
162161 send ,
163162 recv ,
164163 )
165- except WebSocketDisconnect as error :
166- logger .info (f"WebSocket disconnect: { error .code } " )
164+ except BaseExceptionGroup as egroup :
165+ for e in egroup .exceptions :
166+ if isinstance (e , WebSocketDisconnect ):
167+ logger .info (f"WebSocket disconnect: { e .code } " )
168+ break
169+ else :
170+ raise
171+
172+ app .add_websocket_route (str (STREAM_PATH ), model_stream )
173+ app .add_websocket_route (f"{ STREAM_PATH } /{{path:path}}" , model_stream )
167174
168175
169176def _make_send_recv_callbacks (
0 commit comments