2 files changed: +11 -8 lines changed
examples/offline_inference
@@ -97,10 +97,14 @@ def main(
9797 # with DP, each rank should process different prompts.
9898 # usually all the DP ranks process a full dataset,
9999 # and each rank processes a different part of the dataset.
100- promts_per_rank = len (prompts ) // dp_size
101- start = global_dp_rank * promts_per_rank
102- end = start + promts_per_rank
103- prompts = prompts [start :end ]
100+ floor = len (prompts ) // dp_size
101+ remainder = len (prompts ) % dp_size
102+
103+ # Distribute prompts as evenly as possible; the first `remainder` ranks get one extra prompt.
104+ def start (rank ):
105+ return rank * floor + min (rank , remainder )
106+
107+ prompts = prompts [start (global_dp_rank ) : start (global_dp_rank + 1 )]
104108 if len (prompts ) == 0 :
105109 # if any rank has no prompts to process,
106110 # we need to set a placeholder prompt
@@ -363,18 +363,17 @@ def __init__(
363363 local_engine_count = parallel_config .data_parallel_size_local
364364 local_start_index = parallel_config .data_parallel_rank_local
365365 dp_size = parallel_config .data_parallel_size
366+ dp_rank = parallel_config .data_parallel_rank
366367
367368 # SPMD mode is where there is an LLM instance per DP rank and
368369 # one core engine per LLM, see
369370 # examples/offline_inference/data_parallel.py.
370371 spmd_mode = local_start_index is not None
371372 if spmd_mode :
372373 assert local_engine_count == 1
373- self .core_engines = [
374- CoreEngine (index = local_start_index , local = True )
375- ]
374+ self .core_engines = [CoreEngine (index = dp_rank , local = True )]
376375 else :
377- assert parallel_config . data_parallel_rank == 0
376+ assert dp_rank == 0
378377 local_start_index = 0
379378 self .core_engines = [
380379 CoreEngine (index = i , local = (i < local_engine_count ))
You can’t perform that action at this time.
0 commit comments