Skip to content

Commit a662169

Browse files
committed
Fix offline DP compatibility
Signed-off-by: Nick Hill <[email protected]>
1 parent a551183 commit a662169

File tree

4 files changed

+42
-23
lines changed

4 files changed

+42
-23
lines changed

vllm/config.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1546,7 +1546,6 @@ def __post_init__(self) -> None:
15461546
if self.data_parallel_size > 1:
15471547
# Data parallel was specified in the engine args.
15481548
self.data_parallel_master_port = get_open_port()
1549-
# TODO multi-node
15501549
else:
15511550
# Otherwise fall back to env vars (e.g. for offline SPMD case).
15521551
self.data_parallel_size = envs.VLLM_DP_SIZE

vllm/entrypoints/cli/serve.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ def run_headless(args: argparse.Namespace):
105105
target_fn=EngineCoreProc.run_engine_core,
106106
local_engine_count=local_engine_count,
107107
start_index=engine_args.data_parallel_start_rank,
108+
local_start_index=0,
108109
vllm_config=vllm_config,
109110
on_head_node=False,
110111
input_address=input_address,

vllm/v1/engine/core_client.py

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -363,11 +363,28 @@ def sigusr1_handler(signum, frame):
363363
self._finalizer = weakref.finalize(self, self.resources)
364364

365365
parallel_config = vllm_config.parallel_config
366-
dp_size = parallel_config.data_parallel_size
367366
local_engine_count = parallel_config.data_parallel_size_local
367+
start_index = parallel_config.data_parallel_rank
368+
local_start_index = parallel_config.data_parallel_rank_local
369+
370+
# SPMD mode is where there is an LLM instance per DP rank and one
371+
# core engine per LLM, see examples/offline_inference/data_parallel.py.
372+
spmd_mode = local_start_index is not None
373+
if spmd_mode:
374+
assert local_engine_count == 1
375+
self.core_engines = [
376+
CoreEngine(index=local_start_index, local=True)
377+
]
378+
else:
379+
assert start_index == 0
380+
local_start_index = 0
381+
self.core_engines = [
382+
CoreEngine(index=i, local=(i < local_engine_count))
383+
for i in range(parallel_config.data_parallel_size)
384+
]
368385

369386
input_address, output_address = self._get_zmq_addresses(
370-
parallel_config)
387+
parallel_config, spmd_mode)
371388

372389
# Create input and output sockets.
373390
self.input_socket = self.resources.input_socket = make_zmq_socket(
@@ -378,6 +395,7 @@ def sigusr1_handler(signum, frame):
378395
zmq.constants.PULL)
379396
# Start local engines.
380397
if local_engine_count:
398+
# In server mode, start_index and local_start_index will both be 0.
381399
self.resources.local_engine_manager = CoreEngineProcManager(
382400
EngineCoreProc.run_engine_core,
383401
vllm_config=vllm_config,
@@ -386,12 +404,9 @@ def sigusr1_handler(signum, frame):
386404
input_address=input_address,
387405
on_head_node=True,
388406
local_engine_count=local_engine_count,
389-
start_index=0)
407+
start_index=start_index,
408+
local_start_index=local_start_index)
390409

391-
self.core_engines = [
392-
CoreEngine(index=i, local=(i < local_engine_count))
393-
for i in range(dp_size)
394-
]
395410
self.core_engine = self.core_engines[0]
396411

397412
# Wait for engine core process(es) to start.
@@ -400,12 +415,13 @@ def sigusr1_handler(signum, frame):
400415
self.utility_results: dict[int, AnyFuture] = {}
401416

402417
@staticmethod
403-
def _get_zmq_addresses(parallel_config: ParallelConfig) -> tuple[str, str]:
418+
def _get_zmq_addresses(parallel_config: ParallelConfig,
419+
spmd_mode: bool) -> tuple[str, str]:
404420
"""Returns (input_address, output_address)."""
405421
dp_size = parallel_config.data_parallel_size
406422
local_engine_count = parallel_config.data_parallel_size_local
407423

408-
if local_engine_count == dp_size:
424+
if local_engine_count == dp_size or spmd_mode:
409425
input_address = get_open_zmq_ipc_path()
410426
output_address = get_open_zmq_ipc_path()
411427
else:
@@ -422,8 +438,6 @@ def _wait_for_engine_startup(self, output_address: str,
422438
# Get a sync handle to the socket which can be sync or async.
423439
sync_input_socket = zmq.Socket.shadow(self.input_socket)
424440

425-
# TODO offline case compatibility
426-
427441
# Wait for engine core process(es) to send ready messages.
428442
local_count = parallel_config.data_parallel_size_local
429443
remote_count = len(self.core_engines) - local_count
@@ -444,18 +458,20 @@ def _wait_for_engine_startup(self, output_address: str,
444458
# Receive HELLO and READY messages from the input socket.
445459
eng_identity, ready_msg_bytes = sync_input_socket.recv_multipart()
446460
eng_index = int.from_bytes(eng_identity, byteorder="little")
447-
if eng_index > len(self.core_engines):
448-
raise RuntimeError(
449-
f"Message from engine rank larger than "
450-
f"configured data parallel size: {eng_index}")
451-
engine = self.core_engines[eng_index]
461+
engine = next(
462+
(e for e in self.core_engines if e.identity == eng_identity),
463+
None)
464+
if engine is None:
465+
raise RuntimeError(f"Message from engine with unexpected data "
466+
f"parallel rank: {eng_index}")
452467
msg = msgspec.msgpack.decode(ready_msg_bytes)
453468
status, local = msg["status"], msg["local"]
454469
if local != engine.local:
455470
raise RuntimeError(f"{status} message from "
456471
f"{'local' if local else 'remote'} "
457-
f" engine {eng_index}, expected it to be "
472+
f"engine {eng_index}, expected it to be "
458473
f"{'local' if engine.local else 'remote'}")
474+
459475
if status == "HELLO" and engine.state == CoreEngineState.NEW:
460476

461477
# Send init message with DP config info.

vllm/v1/utils.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ def __init__(
105105
target_fn: Callable,
106106
local_engine_count: int,
107107
start_index: int,
108+
local_start_index: int,
108109
vllm_config: VllmConfig,
109110
on_head_node: bool,
110111
input_address: str,
@@ -121,14 +122,15 @@ def __init__(
121122
}
122123

123124
self.processes = []
124-
for local_index in range(local_engine_count):
125-
index = local_index + start_index
125+
for index in range(local_engine_count):
126+
local_index = local_start_index + index
127+
global_index = start_index + index
126128
# Start EngineCore in background process.
127129
self.processes.append(
128130
context.Process(target=target_fn,
129-
name=f"EngineCore_{index}",
131+
name=f"EngineCore_{global_index}",
130132
kwargs=common_kwargs | {
131-
"dp_rank": index,
133+
"dp_rank": global_index,
132134
"local_dp_rank": local_index,
133135
}))
134136

@@ -172,7 +174,8 @@ def shutdown(procs: list[multiprocessing.Process], input_address: str):
172174
remaining = deadline - time.monotonic()
173175
if remaining <= 0:
174176
break
175-
proc.join(remaining)
177+
if proc.is_alive():
178+
proc.join(remaining)
176179

177180
for proc in procs:
178181
if proc.is_alive():

0 commit comments

Comments
 (0)