 from tensorrt_llm.functional import AllReduceParams, AllReduceStrategy
 
 
-def check_mnnvl_available():
-    """Check if MNNVL is available on this system"""
-    try:
-        from tensorrt_llm._mnnvl_utils import supports_mnnvl
-        return supports_mnnvl()
-    except Exception:
-        return False
+def test_mnnvl_available(allreduce_module, rank):
+    """
+    Test if MNNVL is actually available and active in the AllReduce module.
+
+    Returns:
+        tuple: (is_available: bool, details: dict)
+    """
+    import platform
+
+    details = {
+        'has_mnnvl_attribute': False,
+        'mnnvl_is_not_none': False,
+        'architecture': platform.machine(),
+        'is_arm64': platform.machine() in ['aarch64', 'arm64'],
+    }
+
+    # Check if module has mnnvl_allreduce attribute
+    details['has_mnnvl_attribute'] = hasattr(allreduce_module,
+                                             'mnnvl_allreduce')
+
+    if details['has_mnnvl_attribute']:
+        details[
+            'mnnvl_is_not_none'] = allreduce_module.mnnvl_allreduce is not None
+
+    # MNNVL is available if both checks pass
+    is_available = details['has_mnnvl_attribute'] and details[
+        'mnnvl_is_not_none']
+
+    return is_available, details
 
 
 def verify_strategy_active(allreduce_module, requested_strategy, rank):
@@ -54,16 +76,15 @@ def verify_strategy_active(allreduce_module, requested_strategy, rank):
         tuple: (is_active: bool, actual_implementation: str, message: str)
     """
     if requested_strategy == AllReduceStrategy.MNNVL:
-        # Check if MNNVL implementation is active
-        has_mnnvl = hasattr(allreduce_module, 'mnnvl_allreduce') and \
-            allreduce_module.mnnvl_allreduce is not None
+        # Test MNNVL availability
+        is_available, details = test_mnnvl_available(allreduce_module, rank)
 
-        if has_mnnvl:
+        if is_available:
             return True, "MNNVL", "MNNVL AllReduce is active"
         else:
             # MNNVL requested but not active - likely fell back to C++ plugin
-            return False, "C++ Plugin (likely NCCL)", \
-                "MNNVL requested but module does not have active mnnvl_allreduce. Likely fell back to NCCL via C++ plugin."
+            msg = f"MNNVL requested but not active. Details: {details}"
+            return False, "C++ Plugin (likely NCCL)", msg
 
     # For other strategies, we can't easily verify, so assume they're working
     return True, requested_strategy.name, f"{requested_strategy.name} assumed active"
@@ -215,7 +236,11 @@ def run_benchmark(args):
     gpus_per_node = local_mpi_size()
 
     if world_size == 1:
-        raise RuntimeError("Benchmark must run with mpi_world_size > 1")
+        if rank == 0:
+            print("ERROR: Benchmark must run with mpi_world_size > 1",
+                  file=sys.stderr,
+                  flush=True)
+        sys.exit(1)
 
     # Device setup
     torch.cuda.set_device(local_rank)
@@ -251,26 +276,6 @@ def run_benchmark(args):
251276 False )
252277 # NCCL, MIN_LATENCY, MNNVL don't need UB initialization
253278
254- # Pre-flight check for MNNVL
255- if strategy == AllReduceStrategy .MNNVL :
256- if rank == 0 :
257- if not check_mnnvl_available ():
258- print (
259- f"ERROR: MNNVL strategy requested but MNNVL is not available on this system" ,
260- file = sys .stderr ,
261- flush = True )
262- print (f"Possible reasons:" , file = sys .stderr )
263- print (f" - System is not ARM64/aarch64" , file = sys .stderr )
264- print (
265- f" - NVLink is not available or not all links are active" ,
266- file = sys .stderr )
267- print (f" - MNNVL dependencies not met" ,
268- file = sys .stderr ,
269- flush = True )
270- sys .exit (1 )
271- else :
272- print (f"MNNVL pre-flight check: PASSED" , flush = True )
273-
274279 # Create AllReduce module (pass dtype for MNNVL support) and capture logs
275280 def create_allreduce ():
276281 return AllReduce (mapping = mapping , strategy = strategy , dtype = torch_dtype )
@@ -310,10 +315,82 @@ def create_allreduce():
             print(f"The AllReduce module fell back to: {actual_impl}",
                   file=sys.stderr,
                   flush=True)
+
+            # Provide strategy-specific diagnostic info
+            if strategy == AllReduceStrategy.MNNVL:
+                print(f"\nMNNVL troubleshooting:", file=sys.stderr)
+                print(f"  - Ensure you're on ARM64/aarch64 architecture",
+                      file=sys.stderr)
+                print(
+                    f"  - Verify NVLink is available and all links are active",
+                    file=sys.stderr)
+                print(f"  - Check that world_size matches number of GPUs",
+                      file=sys.stderr)
+                print(f"  - Confirm MNNVL dependencies are installed",
+                      file=sys.stderr)
+
             if init_logs:
-                print(f"Initialization logs:\n{init_logs}",
+                print(f"\nInitialization logs:\n{init_logs}",
                       file=sys.stderr,
                       flush=True)
+
+        # Ensure all ranks exit together
+        tllm.mpi_barrier()
+        sys.exit(1)
+
+    # Run a test operation to verify strategy works before benchmarking
+    if rank == 0:
+        print(f"\nTesting {strategy.name} with a small operation...",
+              flush=True)
+
+    test_tensor = torch.ones((128, args.hidden_size),
+                             dtype=torch_dtype,
+                             device="cuda")
+
+    # Use appropriate fusion op for testing (UB doesn't support NONE)
+    test_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM if strategy == AllReduceStrategy.UB else AllReduceFusionOp.NONE
+
+    if test_fusion_op == AllReduceFusionOp.RESIDUAL_RMS_NORM:
+        # Setup RMS norm for testing
+        norm_weight = torch.randn((args.hidden_size, ),
+                                  dtype=torch_dtype,
+                                  device="cuda")
+        norm = RMSNorm(hidden_size=args.hidden_size,
+                       dtype=torch_dtype,
+                       eps=1e-5).cuda()
+        norm.weight.data.copy_(norm_weight)
+        test_params = AllReduceParams(
+            fusion_op=test_fusion_op,
+            residual=test_tensor,
+            norm_weight=norm.weight,
+            eps=norm.variance_epsilon,
+        )
+    else:
+        test_params = AllReduceParams(fusion_op=test_fusion_op)
+
+    try:
+        test_output = allreduce(test_tensor, all_reduce_params=test_params)
+        torch.cuda.synchronize()
+
+        # Verify correctness for NONE fusion
+        if test_fusion_op == AllReduceFusionOp.NONE:
+            expected = test_tensor * world_size
+            torch.testing.assert_close(test_output, expected)
+
+        if rank == 0:
+            print(
+                f"Test passed! Strategy {strategy.name} is working correctly.",
+                flush=True)
+    except Exception as e:
+        if rank == 0:
+            print(
+                f"\nERROR: Test operation failed for strategy {strategy.name}!",
+                file=sys.stderr,
+                flush=True)
+            print(f"Error: {e}", file=sys.stderr, flush=True)
+
+        # Ensure all ranks exit together
+        tllm.mpi_barrier()
 
         sys.exit(1)
     # Print header
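
For reference, and not part of the patch: a minimal sketch of how the new test_mnnvl_available() helper could be exercised on its own under MPI, inside the same script where the helper is defined. The AllReduce import path, the Mapping construction, and the bfloat16 dtype are assumptions here; the patch itself only shows test_mnnvl_available(), the AllReduce(mapping=..., strategy=..., dtype=...) call, and the MNNVL strategy.

import torch
import tensorrt_llm as tllm
from tensorrt_llm import Mapping
from tensorrt_llm.functional import AllReduceStrategy
from tensorrt_llm._torch.distributed import AllReduce  # assumed import path
from tensorrt_llm._utils import mpi_rank, mpi_world_size

rank = mpi_rank()
world_size = mpi_world_size()
torch.cuda.set_device(rank % torch.cuda.device_count())

# Assumed Mapping: pure tensor parallelism across all MPI ranks.
mapping = Mapping(world_size=world_size, rank=rank, tp_size=world_size)
allreduce = AllReduce(mapping=mapping,
                      strategy=AllReduceStrategy.MNNVL,
                      dtype=torch.bfloat16)

# test_mnnvl_available() is the helper added by this commit.
is_available, details = test_mnnvl_available(allreduce, rank)
print(f"[rank {rank}] MNNVL active: {is_available}, details: {details}")
tllm.mpi_barrier()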