1111from collections .abc import Awaitable
1212from concurrent .futures import Future
1313from dataclasses import dataclass
14+ from enum import Enum , auto
1415from threading import Thread
1516from typing import Any , Callable , Optional , TypeVar , Union
1617
@@ -257,11 +258,20 @@ def collective_rpc(self,
257258 return self .engine_core .collective_rpc (method , timeout , args , kwargs )
258259
259260
261+ class CoreEngineState (Enum ):
262+ NEW = auto ()
263+ CONNECTED = auto ()
264+ READY = auto ()
265+
266+
260267class CoreEngine :
261268 """One per data parallel rank."""
262269
263- def __init__ (self , index : int = 0 ):
270+ def __init__ (self , index : int = 0 , local : bool = True ):
271+ self .local = local
264272 self .identity = index .to_bytes (length = 2 , byteorder = "little" )
273+
274+ self .state = CoreEngineState .NEW
265275 self .num_reqs_in_flight = 0
266276
267277
@@ -352,20 +362,12 @@ def sigusr1_handler(signum, frame):
352362 self .resources = BackgroundResources (ctx = sync_ctx )
353363 self ._finalizer = weakref .finalize (self , self .resources )
354364
355- # TODO move address setup to separate method
356365 parallel_config = vllm_config .parallel_config
357366 dp_size = parallel_config .data_parallel_size
358367 local_engine_count = parallel_config .data_parallel_size_local
359368
360- if local_engine_count == dp_size :
361- input_address = get_open_zmq_ipc_path ()
362- output_address = get_open_zmq_ipc_path ()
363- else :
364- host = parallel_config .data_parallel_master_ip
365- input_port = parallel_config .data_parallel_rpc_port
366- output_port = get_open_port ()
367- input_address = f"tcp://{ host } :{ input_port } "
368- output_address = f"tcp://{ host } :{ output_port } "
369+ input_address , output_address = self ._get_zmq_addresses (
370+ parallel_config )
369371
370372 # Create input and output sockets.
371373 self .input_socket = self .resources .input_socket = make_zmq_socket (
@@ -386,14 +388,35 @@ def sigusr1_handler(signum, frame):
386388 local_engine_count = local_engine_count ,
387389 start_index = 0 )
388390
389- self .core_engines = [CoreEngine (i ) for i in range (dp_size )]
391+ self .core_engines = [
392+ CoreEngine (index = i , local = (i < local_engine_count ))
393+ for i in range (dp_size )
394+ ]
390395 self .core_engine = self .core_engines [0 ]
391396
392397 # Wait for engine core process(es) to start.
393398 self ._wait_for_engine_startup (output_address , parallel_config )
394399
395400 self .utility_results : dict [int , AnyFuture ] = {}
396401
402+ @staticmethod
403+ def _get_zmq_addresses (parallel_config : ParallelConfig ) -> tuple [str , str ]:
404+ """Returns (input_address, output_address)."""
405+ dp_size = parallel_config .data_parallel_size
406+ local_engine_count = parallel_config .data_parallel_size_local
407+
408+ if local_engine_count == dp_size :
409+ input_address = get_open_zmq_ipc_path ()
410+ output_address = get_open_zmq_ipc_path ()
411+ else :
412+ host = parallel_config .data_parallel_master_ip
413+ input_port = parallel_config .data_parallel_rpc_port
414+ output_port = get_open_port ()
415+ input_address = f"tcp://{ host } :{ input_port } "
416+ output_address = f"tcp://{ host } :{ output_port } "
417+
418+ return input_address , output_address
419+
397420 def _wait_for_engine_startup (self , output_address : str ,
398421 parallel_config : ParallelConfig ):
399422 # Get a sync handle to the socket which can be sync or async.
@@ -402,60 +425,39 @@ def _wait_for_engine_startup(self, output_address: str,
402425 # TODO offline case compatibility
403426
404427 # Wait for engine core process(es) to send ready messages.
405- local_engine_count = parallel_config .data_parallel_size_local
406- remote_engine_count = len (self .core_engines ) - local_engine_count
407-
408- # TODO simplify the startup tracking logic below!
409- pending_hello_local = set (range (local_engine_count ))
410- pending_hello_remote = set (
411- range (local_engine_count , len (self .core_engines )))
412- pending_ready_local = set (pending_hello_local )
413- pending_ready_remote = set (pending_hello_remote )
414- while pending_ready_local or pending_ready_remote :
428+ local_count = parallel_config .data_parallel_size_local
429+ remote_count = len (self .core_engines ) - local_count
430+ # [local, remote] counts
431+ conn_pending , start_pending = [local_count , remote_count ], [0 , 0 ]
432+
433+ while any (conn_pending ) or any (start_pending ):
415434 while not sync_input_socket .poll (timeout = STARTUP_POLL_PERIOD_MS ):
416- local_conn = local_engine_count - len (pending_hello_local )
417- local_ready = local_engine_count - len (pending_ready_local )
418- if local_ready != local_engine_count :
435+ if any (conn_pending ):
419436 logger .info (
420- "Waiting for local core engine procs: "
421- "%d/%d connected, %d/%d ready." , local_conn ,
422- local_engine_count , local_ready , local_engine_count )
423- if remote_engine_count :
424- remote_conn = remote_engine_count - len (
425- pending_hello_remote )
426- remote_ready = remote_engine_count - len (
427- pending_ready_remote )
428- if remote_ready != remote_engine_count :
429- logger .info (
430- "Waiting for remote core engine procs: "
431- "%d/%d connected, %d/%d ready." , remote_conn ,
432- remote_engine_count , remote_ready ,
433- remote_engine_count )
437+ "Waiting for %d local, %d remote core engine proc(s) "
438+ "to connect." , * conn_pending )
439+ if any (start_pending ):
440+ logger .info (
441+ "Waiting for %d local, %d remote core engine proc(s) "
442+ "to start." , * start_pending )
434443
435444 # Receive HELLO and READY messages from the input socket.
436445 eng_identity , ready_msg_bytes = sync_input_socket .recv_multipart ()
437446 eng_index = int .from_bytes (eng_identity , byteorder = "little" )
447+ if eng_index > len (self .core_engines ):
448+ raise RuntimeError (
449+ f"Message from engine rank larger than "
450+ f"configured data parallel size: { eng_index } " )
451+ engine = self .core_engines [eng_index ]
438452 msg = msgspec .msgpack .decode (ready_msg_bytes )
439453 status , local = msg ["status" ], msg ["local" ]
440- hello_set = pending_hello_local if local else pending_hello_remote
441- ready_set = pending_ready_local if local else pending_ready_remote
442- if status == "HELLO" :
443- index_set = hello_set
444- elif status == "READY" :
445- index_set = ready_set
446- else :
447- raise RuntimeError (f"{ 'Local' if local else 'Remote' } engine "
448- f"{ eng_index } failed: { status } " )
449- if eng_index not in index_set :
450- raise RuntimeError (
451- f"Unexpected or duplicate { status } "
452- f"{ 'local' if local else 'remote' } engine: { eng_index } " )
453- if status == "READY" and eng_index in hello_set :
454- raise RuntimeError (
455- f"Unexpected READY before HELLO for "
456- f"{ 'local' if local else 'remote' } engine: { eng_index } " )
454+ if local != engine .local :
455+ raise RuntimeError (f"{ status } message from "
456+ f"{ 'local' if local else 'remote' } "
457+ f" engine { eng_index } , expected it to be "
458+ f"{ 'local' if engine .local else 'remote' } " )
459+ if status == "HELLO" and engine .state == CoreEngineState .NEW :
457460
458- if status == "HELLO" :
459461 # Send init message with DP config info.
460462 init_message = self .encoder .encode ({
461463 "output_socket_address" : output_address ,
@@ -470,10 +472,20 @@ def _wait_for_engine_startup(self, output_address: str,
470472 })
471473 sync_input_socket .send_multipart ((eng_identity , init_message ),
472474 copy = False )
475+ conn_pending [0 if local else 1 ] -= 1
476+ start_pending [0 if local else 1 ] += 1
477+ engine .state = CoreEngineState .CONNECTED
478+ elif status == "READY" and (engine .state
479+ == CoreEngineState .CONNECTED ):
480+ start_pending [0 if local else 1 ] -= 1
481+ engine .state = CoreEngineState .READY
482+ else :
483+ raise RuntimeError (f"Unexpected { status } message for "
484+ f"{ 'local' if local else 'remote' } engine "
485+ f"{ eng_index } in { engine .state } state." )
473486
474487 logger .debug ("%s from %s core engine process %s." , status ,
475488 "local" if local else "remote" , eng_index )
476- index_set .discard (eng_index )
477489
478490 # Double check that the process are running.
479491 engine_manager = self .resources .local_engine_manager
0 commit comments