# TODO: Do we need to model the IPC socket as a comm backend? My tentative answer is no, as it cannot perform bootstrap without another communicator's help.
"trtllm_mnnvl_all_reduce is deprecated, use trtllm_mnnvl_allreduce instead. This function will be removed in the future."
)
def trtllm_mnnvl_all_reduce(
    inp: torch.Tensor,
    multicast_buffer_ptr: int,  # Pointer address as integer
    buffer_ptrs_dev: int,  # Pointer address as integer
    buffer_M: int,
    buffer_flags_mnnvl: torch.Tensor,
    nranks: int,
    rank: int,
    wait_for_results: bool,
    launch_with_pdl: bool,
    out: Optional[torch.Tensor] = None,
) -> None:
    """Perform a multi-node NVLink all-reduce operation across multiple GPUs.

    This function performs an all-reduce (sum) operation using NVIDIA's multi-node NVLink (MNNVL)
    technology to efficiently combine tensors across multiple GPUs and nodes.

    There are 3 steps:
    1. Scatter each GPU's input shard to the right unicast buffer.
    2. Perform the all-reduce on each GPU.
    3. Broadcast the result to all GPUs.

    Args:
        inp: Local input shard.
        multicast_buffer_ptr: Pointer to the multicast buffer as an integer.
        buffer_ptrs_dev: Pointer to device buffer pointers as an integer.
        buffer_M: Maximum number of elements // hidden_dim.
        buffer_flags_mnnvl: Tensor containing buffer state flags.
        nranks: Total number of ranks participating in the all-reduce.
        rank: Current process rank.
        wait_for_results: If True, store the result to out.
        launch_with_pdl: If True, launch using Programmatic Dependent Launch.
        out: Optional output tensor to store the result (required if wait_for_results is True).
    """
    if len(inp.shape) != 2:
        raise ValueError(
            f"The input tensor must be 2D, got {len(inp.shape)}D. The shape is {inp.shape}."
        )

    # buffer_M is no longer used in this kernel, but keep this check for consistency in behavior.
    if inp.shape[0] > buffer_M:
        raise ValueError(
            f"The number of tokens in the input tensor {inp.shape[0]} is greater than the buffer_M {buffer_M}. This is not supported. Please increase the workspace size, or decrease the number of tokens to at most {buffer_M}."
        )

    # Even in legacy code, this should only be used when we implement the fused allreduce+rmsnorm.
    assert wait_for_results and (out is not None), (
        "Calling the legacy trtllm_mnnvl_all_reduce with wait_for_results=False is not supported. Please use trtllm_mnnvl_allreduce instead."
    )
498
+
module=get_trtllm_mnnvl_comm_module()
499
+
module.trtllm_mnnvl_allreduce_fusion(
500
+
input,
501
+
multicast_buffer_ptr,
502
+
buffer_ptrs_dev,
503
+
0, # Allreduce kernel itself does not use this local pointer; still this could be risky but it is only used for legacy code compatibility.
504
+
buffer_flags_mnnvl,
505
+
nranks,
506
+
rank,
507
+
False, # No RMSNorm Fusion
508
+
launch_with_pdl,
509
+
False, # Use two-shot
510
+
out,
511
+
None,
512
+
None,
513
+
None,
514
+
None,
515
+
)
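

# A hedged usage sketch for the legacy wrapper above, not taken from the library's
# tests. The workspace names (mc_ptr, uc_ptrs_dev, flags, max_tokens, world_size)
# are hypothetical placeholders for whatever MNNVL workspace setup the caller
# already performs; only the call shape is illustrated.
#
#   out = torch.empty_like(inp)
#   trtllm_mnnvl_all_reduce(
#       inp,
#       multicast_buffer_ptr=mc_ptr,
#       buffer_ptrs_dev=uc_ptrs_dev,
#       buffer_M=max_tokens,
#       buffer_flags_mnnvl=flags,
#       nranks=world_size,
#       rank=rank,
#       wait_for_results=True,  # the legacy wrapper requires this (see assert above)
#       launch_with_pdl=False,
#       out=out,
#   )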
@deprecated(
    "trtllm_mnnvl_fused_allreduce_rmsnorm is deprecated, use trtllm_mnnvl_fused_allreduce_add_rmsnorm instead. This function will be removed in the future."
)
def trtllm_mnnvl_fused_allreduce_rmsnorm(
    prenorm_output: torch.Tensor,
    normed_output: torch.Tensor,
    shard_input: torch.Tensor,
    multicast_buffer_ptr: int,  # Pointer address as integer
    buffer_ptrs_dev: int,  # Pointer address as integer
    unicast_ptr: int,  # Local unicast buffer pointer
    buffer_M: int,
    buffer_flags_mnnvl: torch.Tensor,
    nranks: int,
    rank: int,
    gamma: torch.Tensor,
    epsilon: float,
    residual: torch.Tensor,
    launch_with_pdl: bool,
) -> None:
    """Performs MNNVL TwoShot Allreduce + RMSNorm.

    This function performs a multi-node all-reduce (sum) operation by first calling trtllm_mnnvl_all_reduce on the shard_input.
    After this, it performs RMSNorm on the all-reduced result, reading it directly from the multicast buffer.
    Note: the multicast buffer is the same as the unicast buffer for the current rank.

    Args:
        prenorm_output: Output tensor for prenorm results.
        normed_output: Output tensor for normalized results.
        shard_input: Input tensor shard.
        multicast_buffer_ptr: Pointer address as integer for the multicast buffer.
        buffer_ptrs_dev: Pointer address as integer for device buffer pointers.
        unicast_ptr: Pointer address as integer for the unicast buffer.
        buffer_M: Maximum number of elements // hidden_dim.
        buffer_flags_mnnvl: Buffer flags for synchronization.
        nranks: Number of ranks in the tensor parallel group.
        rank: Current rank in the tensor parallel group.
        gamma: The gamma (norm weight) parameter for RMSNorm.
        epsilon: The epsilon parameter for RMSNorm.
        residual: The residual tensor to add.
        launch_with_pdl: Whether to launch with PDL.
    """
    if len(shard_input.shape) != 2:
        raise ValueError(
            f"The input tensor must be 2D, got {len(shard_input.shape)}D. The shape is {shard_input.shape}."
        )

    # buffer_M is no longer used in this kernel, but keep this check for consistency in behavior.
    if shard_input.shape[0] > buffer_M:
        raise ValueError(
            f"The number of tokens in the input tensor {shard_input.shape[0]} is greater than the buffer_M {buffer_M}. This is not supported. Please increase the workspace size, or decrease the number of tokens to at most {buffer_M}."
        )

    if len(residual.shape) != 2:
        raise ValueError(
            f"The residual input tensor must be 2D, got {len(residual.shape)}D. The shape is {residual.shape}."
        )

    if gamma.numel() != shard_input.shape[1]:
        raise ValueError(
            f"The gamma tensor must have the same number of elements as the hidden dimension, got {gamma.numel()} elements but expected {shard_input.shape[1]} elements."
        )

    if len(normed_output.shape) != 2:
        raise ValueError(
            f"The output tensor must be 2D, got {len(normed_output.shape)}D. The shape is {normed_output.shape}."
        )

    if len(prenorm_output.shape) != 2:
        raise ValueError(
            f"The prenorm output tensor must be 2D, got {len(prenorm_output.shape)}D. The shape is {prenorm_output.shape}."
        )
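
# A hedged usage sketch for the fused wrapper above, mirroring the plain all-reduce
# example. Workspace names (mc_ptr, uc_ptrs_dev, uc_ptr, flags, max_tokens,
# world_size) are hypothetical placeholders; gamma, epsilon, and residual come from
# the caller's RMSNorm layer.
#
#   prenorm = torch.empty_like(shard_input)
#   normed = torch.empty_like(shard_input)
#   trtllm_mnnvl_fused_allreduce_rmsnorm(
#       prenorm,
#       normed,
#       shard_input,
#       multicast_buffer_ptr=mc_ptr,
#       buffer_ptrs_dev=uc_ptrs_dev,
#       unicast_ptr=uc_ptr,
#       buffer_M=max_tokens,
#       buffer_flags_mnnvl=flags,
#       nranks=world_size,
#       rank=rank,
#       gamma=rmsnorm_weight,
#       epsilon=1e-6,
#       residual=residual,
#       launch_with_pdl=False,
#   )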