@@ -363,11 +363,28 @@ def sigusr1_handler(signum, frame):
363363 self ._finalizer = weakref .finalize (self , self .resources )
364364
365365 parallel_config = vllm_config .parallel_config
366- dp_size = parallel_config .data_parallel_size
367366 local_engine_count = parallel_config .data_parallel_size_local
367+ start_index = parallel_config .data_parallel_rank
368+ local_start_index = parallel_config .data_parallel_rank_local
369+
370+ # SPMD mode is where there is an LLM instance per DP rank and one
371+ # core engine per LLM, see examples/offline_inference/data_parallel.py.
372+ spmd_mode = local_start_index is not None
373+ if spmd_mode :
374+ assert local_engine_count == 1
375+ self .core_engines = [
376+ CoreEngine (index = local_start_index , local = True )
377+ ]
378+ else :
379+ assert start_index == 0
380+ local_start_index = 0
381+ self .core_engines = [
382+ CoreEngine (index = i , local = (i < local_engine_count ))
383+ for i in range (parallel_config .data_parallel_size )
384+ ]
368385
369386 input_address , output_address = self ._get_zmq_addresses (
370- parallel_config )
387+ parallel_config , spmd_mode )
371388
372389 # Create input and output sockets.
373390 self .input_socket = self .resources .input_socket = make_zmq_socket (
@@ -378,6 +395,7 @@ def sigusr1_handler(signum, frame):
378395 zmq .constants .PULL )
379396 # Start local engines.
380397 if local_engine_count :
398+ # In server mode, start_index and local_start_index will both be 0.
381399 self .resources .local_engine_manager = CoreEngineProcManager (
382400 EngineCoreProc .run_engine_core ,
383401 vllm_config = vllm_config ,
@@ -386,12 +404,9 @@ def sigusr1_handler(signum, frame):
386404 input_address = input_address ,
387405 on_head_node = True ,
388406 local_engine_count = local_engine_count ,
389- start_index = 0 )
407+ start_index = start_index ,
408+ local_start_index = local_start_index )
390409
391- self .core_engines = [
392- CoreEngine (index = i , local = (i < local_engine_count ))
393- for i in range (dp_size )
394- ]
395410 self .core_engine = self .core_engines [0 ]
396411
397412 # Wait for engine core process(es) to start.
@@ -400,12 +415,13 @@ def sigusr1_handler(signum, frame):
400415 self .utility_results : dict [int , AnyFuture ] = {}
401416
402417 @staticmethod
403- def _get_zmq_addresses (parallel_config : ParallelConfig ) -> tuple [str , str ]:
418+ def _get_zmq_addresses (parallel_config : ParallelConfig ,
419+ spmd_mode : bool ) -> tuple [str , str ]:
404420 """Returns (input_address, output_address)."""
405421 dp_size = parallel_config .data_parallel_size
406422 local_engine_count = parallel_config .data_parallel_size_local
407423
408- if local_engine_count == dp_size :
424+ if local_engine_count == dp_size or spmd_mode :
409425 input_address = get_open_zmq_ipc_path ()
410426 output_address = get_open_zmq_ipc_path ()
411427 else :
@@ -422,8 +438,6 @@ def _wait_for_engine_startup(self, output_address: str,
422438 # Get a sync handle to the socket which can be sync or async.
423439 sync_input_socket = zmq .Socket .shadow (self .input_socket )
424440
425- # TODO offline case compatibility
426-
427441 # Wait for engine core process(es) to send ready messages.
428442 local_count = parallel_config .data_parallel_size_local
429443 remote_count = len (self .core_engines ) - local_count
@@ -444,18 +458,20 @@ def _wait_for_engine_startup(self, output_address: str,
444458 # Receive HELLO and READY messages from the input socket.
445459 eng_identity , ready_msg_bytes = sync_input_socket .recv_multipart ()
446460 eng_index = int .from_bytes (eng_identity , byteorder = "little" )
447- if eng_index > len (self .core_engines ):
448- raise RuntimeError (
449- f"Message from engine rank larger than "
450- f"configured data parallel size: { eng_index } " )
451- engine = self .core_engines [eng_index ]
461+ engine = next (
462+ (e for e in self .core_engines if e .identity == eng_identity ),
463+ None )
464+ if engine is None :
465+ raise RuntimeError (f"Message from engine with unexpected data "
466+ f"parallel rank: { eng_index } " )
452467 msg = msgspec .msgpack .decode (ready_msg_bytes )
453468 status , local = msg ["status" ], msg ["local" ]
454469 if local != engine .local :
455470 raise RuntimeError (f"{ status } message from "
456471 f"{ 'local' if local else 'remote' } "
457- f" engine { eng_index } , expected it to be "
472+ f"engine { eng_index } , expected it to be "
458473 f"{ 'local' if engine .local else 'remote' } " )
474+
459475 if status == "HELLO" and engine .state == CoreEngineState .NEW :
460476
461477 # Send init message with DP config info.
0 commit comments