@@ -126,10 +126,12 @@ def allreduce_benchmark(
126126 strategy : str = None ,
127127 inner_loop : int = 200 ,
128128 outer_loop : int = 10 ,
129+ tokens_range : str = "1,16384,2" ,
130+ hidden_sizes_range : str = "128,8192,2" ,
129131):
130132 """
131133 Benchmark AllReduce operations.
132-
134+
133135 Args:
134136 dtype: Data type for benchmarking
135137 test_range: Range specification (min,max,ratio)
@@ -139,6 +141,8 @@ def allreduce_benchmark(
139141 strategy: Specific strategy to test (if None, tests default set: NCCL, NCCL_SYMMETRIC, NCCL_DEVICE, MNNVL)
140142 inner_loop: Number of iterations per timing measurement (default: 200)
141143 outer_loop: Number of timing measurements to take (default: 10)
144+ tokens_range: Range for number of tokens in 2D mode (min,max,ratio) (default: "1,16384,2")
145+ hidden_sizes_range: Range for hidden sizes in 2D mode (min,max,ratio) (default: "128,8192,2")
142146 """
143147 tllm .logger .set_level ('error' )
144148 world_size = tllm .mpi_world_size ()
@@ -166,15 +170,36 @@ def allreduce_benchmark(
166170
167171 # Parse test range
168172 min_size , max_size , ratio = [int (i ) for i in test_range .split ("," )]
169-
173+
170174 # generate shape list
171175 shape_list = []
172176
173177 if explore_2d :
174- num_seqs_list = [
175- 1 , 2 , 4 , 8 , 16 , 32 , 64 , 128 , 256 , 512 , 1024 , 2048 , 4096 , 8192 , 16384
178+ # Parse tokens range
179+ min_tokens , max_tokens , tokens_ratio = [
180+ int (i ) for i in tokens_range .split ("," )
181+ ]
182+
183+ # Parse hidden sizes range
184+ min_hidden , max_hidden , hidden_ratio = [
185+ int (i ) for i in hidden_sizes_range .split ("," )
176186 ]
177- hidden_size_list = [128 , 256 , 512 , 1024 , 2048 , 4096 , 8192 ]
187+
188+ # Generate token counts list
189+ num_seqs_list = []
190+ current = min_tokens
191+ while current <= max_tokens :
192+ num_seqs_list .append (current )
193+ current *= tokens_ratio
194+
195+ # Generate hidden sizes list
196+ hidden_size_list = []
197+ current = min_hidden
198+ while current <= max_hidden :
199+ hidden_size_list .append (current )
200+ current *= hidden_ratio
201+
202+ # Create all combinations
178203 for num_tokens , hidden_size in product (num_seqs_list , hidden_size_list ):
179204 shape_list .append ((num_tokens , hidden_size ))
180205 else :
@@ -196,7 +221,7 @@ def allreduce_benchmark(
196221 AllReduceFusionOp .RESIDUAL_RMS_NORM_QUANT_FP8 ,
197222 AllReduceFusionOp .RESIDUAL_RMS_NORM_QUANT_NVFP4 ,
198223 ]
199-
224+
200225 # Map strategy names to enum values
201226 strategy_map = {
202227 "NCCL" : AllReduceStrategy .NCCL ,
@@ -209,12 +234,14 @@ def allreduce_benchmark(
209234 "TWOSHOT" : AllReduceStrategy .TWOSHOT ,
210235 "AUTO" : AllReduceStrategy .AUTO ,
211236 }
212-
237+
213238 # Select strategies based on input
214239 if strategy :
215240 # Single strategy specified
216241 if strategy .upper () not in strategy_map :
217- raise ValueError (f"Unknown strategy: { strategy } . Available: { ', ' .join (strategy_map .keys ())} " )
242+ raise ValueError (
243+ f"Unknown strategy: { strategy } . Available: { ', ' .join (strategy_map .keys ())} "
244+ )
218245 strategies = [strategy_map [strategy .upper ()]]
219246 else :
220247 # Default: test main strategies
@@ -224,38 +251,44 @@ def allreduce_benchmark(
224251 AllReduceStrategy .NCCL_DEVICE ,
225252 AllReduceStrategy .MNNVL ,
226253 ]
227-
254+
228255 # Validate strategy compatibility for user buffer initialization
229256 # NCCL_SYMMETRIC and NCCL_DEVICE need UB with use_multicast=True
230257 # UB strategy needs UB with use_multicast=False
231258 # These two groups cannot be mixed in a single run
232- ub_multicast_strategies = {AllReduceStrategy .NCCL_SYMMETRIC , AllReduceStrategy .NCCL_DEVICE }
259+ ub_multicast_strategies = {
260+ AllReduceStrategy .NCCL_SYMMETRIC , AllReduceStrategy .NCCL_DEVICE
261+ }
233262 ub_no_multicast_strategies = {AllReduceStrategy .UB }
234-
235- has_multicast_strategies = any (s in ub_multicast_strategies for s in strategies )
236- has_no_multicast_strategies = any (s in ub_no_multicast_strategies for s in strategies )
237-
263+
264+ has_multicast_strategies = any (s in ub_multicast_strategies
265+ for s in strategies )
266+ has_no_multicast_strategies = any (s in ub_no_multicast_strategies
267+ for s in strategies )
268+
238269 # Error out if incompatible strategies are mixed
239270 if has_multicast_strategies and has_no_multicast_strategies :
240- multicast_strats = [s .name for s in strategies if s in ub_multicast_strategies ]
241- no_multicast_strats = [s .name for s in strategies if s in ub_no_multicast_strategies ]
271+ multicast_strats = [
272+ s .name for s in strategies if s in ub_multicast_strategies
273+ ]
274+ no_multicast_strats = [
275+ s .name for s in strategies if s in ub_no_multicast_strategies
276+ ]
242277 raise ValueError (
243278 f"Incompatible strategies selected: { multicast_strats } require use_multicast=True "
244279 f"while { no_multicast_strats } require use_multicast=False. "
245- f"Please run these strategies separately using --strategy."
246- )
247-
280+ f"Please run these strategies separately using --strategy." )
281+
248282 # Initialize user buffers if any strategy needs it
249283 needs_ub = has_multicast_strategies or has_no_multicast_strategies
250-
284+
251285 if needs_ub :
252286 max_bytes = max_size * dtype_size_bytes
253287 use_multicast = has_multicast_strategies # True for NCCL_SYMMETRIC/NCCL_DEVICE, False for UB
254-
255- ub .initialize_userbuffers_manager (
256- world_size , 1 , 1 , rank ,
257- torch .cuda .device_count (), max_bytes , use_multicast
258- )
288+
289+ ub .initialize_userbuffers_manager (world_size , 1 , 1 , rank ,
290+ torch .cuda .device_count (), max_bytes ,
291+ use_multicast )
259292
260293 df = pd .DataFrame ()
261294 for (num_tokens , hidden_size ) in shape_list :
@@ -285,6 +318,10 @@ def allreduce_benchmark(
285318 if fusion == AllReduceFusionOp .RESIDUAL_RMS_NORM_QUANT_NVFP4 and sm_version < 100 :
286319 continue
287320
321+ # UB strategy doesn't support NONE fusion
322+ if strategy == AllReduceStrategy .UB and fusion == AllReduceFusionOp .NONE :
323+ continue
324+
288325 median_ms = profile_allreduce (
289326 mapping = mapping ,
290327 enable_cudagraph = enable_cudagraph ,
@@ -327,26 +364,59 @@ def allreduce_benchmark(
327364
328365if __name__ == "__main__" :
329366 parser = ArgumentParser ()
330- parser .add_argument ("--dtype" , "-t" , default = "bfloat16" ,
367+ parser .add_argument ("--dtype" ,
368+ "-t" ,
369+ default = "bfloat16" ,
331370 help = "Data type for benchmarking" )
332371 parser .add_argument (
333372 "--range" ,
334373 "-r" ,
335374 default = "256,256000000,4" , # 256 to 256M
336375 help = "min_size,max_size,multiplicative_ratio" )
337- parser .add_argument ("--explore_2d" , action = "store_true" , default = False ,
338- help = "Explore 2D parameter space (num_tokens x hidden_size)" )
339- parser .add_argument ("--enable_cudagraph" , action = "store_true" ,
376+ parser .add_argument (
377+ "--explore_2d" ,
378+ action = "store_true" ,
379+ default = False ,
380+ help = "Explore 2D parameter space (num_tokens x hidden_size)" )
381+ parser .add_argument ("--enable_cudagraph" ,
382+ action = "store_true" ,
340383 help = "Enable CUDA graph capture" )
341- parser .add_argument ("--save_csv" , type = str , default = None ,
384+ parser .add_argument ("--save_csv" ,
385+ type = str ,
386+ default = None ,
342387 help = "Path to save CSV results" )
343- parser .add_argument ("--strategy" , type = str , default = None ,
344- help = "Test specific strategy. If not specified, defaults to: NCCL, NCCL_SYMMETRIC, NCCL_DEVICE, MNNVL. "
345- "Available: NCCL, NCCL_SYMMETRIC, NCCL_DEVICE, MNNVL, MIN_LATENCY, UB, ONESHOT, TWOSHOT, AUTO" )
346- parser .add_argument ("--inner_loop" , type = int , default = 200 ,
347- help = "Number of iterations per timing measurement (default: 200)" )
348- parser .add_argument ("--outer_loop" , type = int , default = 10 ,
349- help = "Number of timing measurements to take (default: 10)" )
388+ parser .add_argument (
389+ "--strategy" ,
390+ type = str ,
391+ default = None ,
392+ help =
393+ "Test specific strategy. If not specified, defaults to: NCCL, NCCL_SYMMETRIC, NCCL_DEVICE, MNNVL. "
394+ "Available: NCCL, NCCL_SYMMETRIC, NCCL_DEVICE, MNNVL, MIN_LATENCY, UB, ONESHOT, TWOSHOT, AUTO"
395+ )
396+ parser .add_argument (
397+ "--inner_loop" ,
398+ type = int ,
399+ default = 200 ,
400+ help = "Number of iterations per timing measurement (default: 200)" )
401+ parser .add_argument (
402+ "--outer_loop" ,
403+ type = int ,
404+ default = 10 ,
405+ help = "Number of timing measurements to take (default: 10)" )
406+ parser .add_argument (
407+ "--tokens_range" ,
408+ type = str ,
409+ default = "1,16384,2" ,
410+ help =
411+ "Range for number of tokens in 2D mode: min,max,ratio (default: 1,16384,2)"
412+ )
413+ parser .add_argument (
414+ "--hidden_sizes_range" ,
415+ type = str ,
416+ default = "128,8192,2" ,
417+ help =
418+ "Range for hidden sizes in 2D mode: min,max,ratio (default: 128,8192,2)"
419+ )
350420
351421 args = parser .parse_args ()
352422
@@ -359,4 +429,6 @@ def allreduce_benchmark(
359429 strategy = args .strategy ,
360430 inner_loop = args .inner_loop ,
361431 outer_loop = args .outer_loop ,
432+ tokens_range = args .tokens_range ,
433+ hidden_sizes_range = args .hidden_sizes_range ,
362434 )
0 commit comments