Skip to content

Commit ec23b86

Browse files
njhill and yiz-liu
authored and committed
[BugFix] Fix multi-node offline data-parallel (vllm-project#18981)
Signed-off-by: Nick Hill <[email protected]> Co-authored-by: Yizhou Liu <[email protected]> Signed-off-by: amit <[email protected]>
1 parent 2158a3d commit ec23b86

File tree

2 files changed

+11
-8
lines changed

2 files changed

+11
-8
lines changed

examples/offline_inference/data_parallel.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,14 @@ def main(
9797
# with DP, each rank should process different prompts.
9898
# usually all the DP ranks process a full dataset,
9999
# and each rank processes a different part of the dataset.
100-
promts_per_rank = len(prompts) // dp_size
101-
start = global_dp_rank * promts_per_rank
102-
end = start + promts_per_rank
103-
prompts = prompts[start:end]
100+
floor = len(prompts) // dp_size
101+
remainder = len(prompts) % dp_size
102+
103+
# Distribute prompts into even groups.
104+
def start(rank):
105+
return rank * floor + min(rank, remainder)
106+
107+
prompts = prompts[start(global_dp_rank) : start(global_dp_rank + 1)]
104108
if len(prompts) == 0:
105109
# if any rank has no prompts to process,
106110
# we need to set a placeholder prompt

vllm/v1/engine/core_client.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -363,18 +363,17 @@ def __init__(
363363
local_engine_count = parallel_config.data_parallel_size_local
364364
local_start_index = parallel_config.data_parallel_rank_local
365365
dp_size = parallel_config.data_parallel_size
366+
dp_rank = parallel_config.data_parallel_rank
366367

367368
# SPMD mode is where there is an LLM instance per DP rank and
368369
# one core engine per LLM, see
369370
# examples/offline_inference/data_parallel.py.
370371
spmd_mode = local_start_index is not None
371372
if spmd_mode:
372373
assert local_engine_count == 1
373-
self.core_engines = [
374-
CoreEngine(index=local_start_index, local=True)
375-
]
374+
self.core_engines = [CoreEngine(index=dp_rank, local=True)]
376375
else:
377-
assert parallel_config.data_parallel_rank == 0
376+
assert dp_rank == 0
378377
local_start_index = 0
379378
self.core_engines = [
380379
CoreEngine(index=i, local=(i < local_engine_count))

0 commit comments

Comments
 (0)