44import os
55import threading
66import time
7+ import traceback
8+ from typing import Optional , cast
79
810import openai # use the official client for correctness check
911import pytest
@@ -41,12 +43,15 @@ def __init__(self,
4143 self .tp_size = tp_size
4244 self .api_server_count = api_server_count
4345 self .base_server_args = base_server_args
44- self .servers : list [tuple [RemoteOpenAIServer , list [str ]]] = []
46+ self .servers : list [Optional [tuple [RemoteOpenAIServer ,
47+ list [str ]]]] = [None ] * (dp_size //
48+ dp_per_node )
4549 self .server_threads : list [threading .Thread ] = []
4650
4751 def __enter__ (self ) -> list [tuple [RemoteOpenAIServer , list [str ]]]:
4852 """Start all server instances for multi-node internal LB mode."""
49- for rank in range (0 , self .dp_size , self .dp_per_node ):
53+ for server_idx , rank in enumerate (
54+ range (0 , self .dp_size , self .dp_per_node )):
5055 # Create server args for this specific rank
5156 server_args = self .base_server_args .copy ()
5257
@@ -87,7 +92,7 @@ def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]:
8792 ])
8893
8994 # Use a thread to start each server to allow parallel initialization
90- def start_server (r : int , sargs : list [str ]):
95+ def start_server (sidx : int , r : int , sargs : list [str ]):
9196 gpus_per_node = self .tp_size * self .dp_per_node
9297 try :
9398 # Start the server
@@ -110,13 +115,14 @@ def start_server(r: int, sargs: list[str]):
110115 f"{ self .api_server_count } API servers" )
111116 else :
112117 print (f"Headless node (rank { r } ) started successfully" )
113- self .servers . append (( server , sargs ) )
118+ self .servers [ sidx ] = ( server , sargs )
114119 except Exception as e :
115120 print (f"Failed to start server rank { r } : { e } " )
121+ traceback .print_exc ()
116122 raise
117123
118124 thread = threading .Thread (target = start_server ,
119- args = (rank , server_args ))
125+ args = (server_idx , rank , server_args ))
120126 thread .start ()
121127
122128 self .server_threads .append (thread )
@@ -128,18 +134,20 @@ def start_server(r: int, sargs: list[str]):
128134 # Give servers additional time to fully initialize and coordinate
129135 time .sleep (3 )
130136
131- if len (self .servers ) != self . dp_size // self . dp_per_node :
137+ if not all (self .servers ):
132138 raise Exception ("Servers failed to start" )
133139
134- return self .servers
140+ return cast ( list [ tuple [ RemoteOpenAIServer , list [ str ]]], self .servers )
135141
136142 def __exit__ (self , exc_type , exc_val , exc_tb ):
137143 """Stop all server instances."""
138144 while self .servers :
139- try :
140- self .servers .pop ()[0 ].__exit__ (exc_type , exc_val , exc_tb )
141- except Exception as e :
142- print (f"Error stopping server: { e } " )
145+ if server := self .servers .pop ():
146+ try :
147+ server [0 ].__exit__ (exc_type , exc_val , exc_tb )
148+ except Exception as e :
149+ print (f"Error stopping server: { e } " )
150+ traceback .print_exc ()
143151
144152
145153class APIOnlyServerManager :
@@ -157,7 +165,8 @@ def __init__(self,
157165 self .tp_size = tp_size
158166 self .api_server_count = api_server_count
159167 self .base_server_args = base_server_args
160- self .servers : list [tuple [RemoteOpenAIServer , list [str ]]] = []
168+ self .servers : list [Optional [tuple [RemoteOpenAIServer ,
169+ list [str ]]]] = [None ] * 2
161170 self .server_threads : list [threading .Thread ] = []
162171
163172 def __enter__ (self ) -> list [tuple [RemoteOpenAIServer , list [str ]]]:
@@ -209,7 +218,7 @@ def start_api_server():
209218 server .__enter__ ()
210219 print (f"API-only server started successfully with "
211220 f"{ self .api_server_count } API servers" )
212- self .servers . append (( server , api_server_args ) )
221+ self .servers [ 0 ] = ( server , api_server_args )
213222 except Exception as e :
214223 print (f"Failed to start API-only server: { e } " )
215224 raise
@@ -231,7 +240,7 @@ def start_engines_server():
231240 server .__enter__ ()
232241 print (f"Headless engines server started successfully with "
233242 f"{ self .dp_size } engines" )
234- self .servers . append (( server , engines_server_args ) )
243+ self .servers [ 1 ] = ( server , engines_server_args )
235244 except Exception as e :
236245 print (f"Failed to start headless engines server: { e } " )
237246 raise
@@ -253,18 +262,20 @@ def start_engines_server():
253262 # Give servers additional time to fully initialize and coordinate
254263 time .sleep (3 )
255264
256- if len (self .servers ) != 2 :
265+ if not all (self .servers ):
257266 raise Exception ("Both servers failed to start" )
258267
259- return self .servers
268+ return cast ( list [ tuple [ RemoteOpenAIServer , list [ str ]]], self .servers )
260269
261270 def __exit__ (self , exc_type , exc_val , exc_tb ):
262271 """Stop both server instances."""
263272 while self .servers :
264- try :
265- self .servers .pop ()[0 ].__exit__ (exc_type , exc_val , exc_tb )
266- except Exception as e :
267- print (f"Error stopping server: { e } " )
273+ if server := self .servers .pop ():
274+ try :
275+ server [0 ].__exit__ (exc_type , exc_val , exc_tb )
276+ except Exception as e :
277+ print (f"Error stopping server: { e } " )
278+ traceback .print_exc ()
268279
269280
270281@pytest .fixture (scope = "module" )
@@ -560,7 +571,7 @@ async def make_request():
560571 assert len (results ) == num_requests
561572 assert all (completion is not None for completion in results )
562573
563- _ , api_server_args = api_only_servers [0 ]
574+ api_server , api_server_args = api_only_servers [0 ]
564575 api_server_count = (
565576 api_server_args .count ('--api-server-count' )
566577 and api_server_args [api_server_args .index ('--api-server-count' ) + 1 ]
@@ -569,7 +580,6 @@ async def make_request():
569580 f"engines on headless server (API server count: { api_server_count } )" )
570581
571582 # Check request balancing via Prometheus metrics
572- api_server = api_only_servers [0 ][0 ]
573583 check_request_balancing (api_server , DP_SIZE )
574584
575585
0 commit comments