Commit 6344177

Remove debug prints and add compatibility interface.
1 parent bca4f5d commit 6344177

3 files changed: +238 -28 lines changed

flashinfer/comm/mnnvl.py

Lines changed: 0 additions & 11 deletions

@@ -136,9 +136,6 @@ def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> int:
     if not host_ptr_array:
         return None
 
-    for addr in host_ptr_array:
-        print(f"DEBUG: ptr_array: 0x{addr:x}")
-
     ArrayType = ctypes.c_uint64 * len(host_ptr_array)
     c_array = ArrayType(*host_ptr_array)
     size_in_bytes = ctypes.sizeof(c_array)
@@ -788,9 +785,6 @@ def __del__(self):
         if not hasattr(self, "is_multi_node"):
             return
 
-        if not self.is_multi_node:
-            return
-
         # Skip cleanup during Python finalization to avoid segfaults
         # Especially cause the CUDA context could be destroyed at this point.
         if sys.is_finalizing():
@@ -976,7 +970,6 @@ def _alloc_mn_mcast_mem(self, buf_size: int):
             )
         else:
             # Implement the allgather logic with ipc socket
-            # TODO: Do we need to model ipc socket as a comm backend? My tenative answer is no as it is not able to perform bootstrap without other communicator's help.
             all_shareable_uc_handles = [None] * self.group_size
             for i in range(self.group_size):
                 self.comm_backend.barrier()
@@ -988,10 +981,6 @@ def _alloc_mn_mcast_mem(self, buf_size: int):
                 all_shareable_uc_handles[src_rank] = self._ipc_socket.recv_fd()
             cuda.cuCtxSynchronize()
 
-            print(
-                f"[Rank {self.group_rank}] all_shareable_uc_handles: {all_shareable_uc_handles}"
-            )
-
             # Import remote handles
             for p in range(self.group_size):
                 if p != self.group_rank:
flashinfer/comm/trtllm_mnnvl_ar.py

Lines changed: 228 additions & 9 deletions

@@ -11,6 +11,7 @@
 from enum import Enum
 
 import torch
+from typing_extensions import deprecated
 
 from flashinfer.comm.mapping import Mapping
 
@@ -292,7 +293,7 @@ def trtllm_mnnvl_allreduce(
     return output
 
 
-def trtllm_mnnvl_fused_allreduce_rmsnorm(
+def trtllm_mnnvl_fused_allreduce_add_rmsnorm(
     input: torch.Tensor,
     residual_in: torch.Tensor,
     gamma: torch.Tensor,
@@ -303,10 +304,10 @@ def trtllm_mnnvl_fused_allreduce_rmsnorm(
     launch_with_pdl: bool = False,
     strategy: MNNVLAllreduceFusionStrategy = MNNVLAllreduceFusionStrategy.AUTO,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Performs MNNVL Allreduce + RMSNorm.
+    """Performs MNNVL Allreduce + Residual + RMSNorm.
 
     This function performs a multi-node all-reduce (sum) operation by first calling trtllm_mnnvl_allreduce on the shard_input.
-    After this, it performs RMSNorm on the all-reduced result, reading it directly from the multicast buffer.
+    After this, it performs residual addition and RMSNorm on the all-reduced result, reading it directly from the multicast buffer.
     Note: multicast buffer is the same as the unicast buffer for the current rank.
 
     Args:
@@ -321,8 +322,8 @@ def trtllm_mnnvl_fused_allreduce_rmsnorm(
         strategy: MNNVLAllreduceFusionStrategy. Internal heuristics will be used if not provided.
 
     Returns:
-        output: Normalized tensor [num_tokens, hidden_dim]
-        residual_out: Residual output tensor [num_tokens, hidden_dim]
+        output: Add-residual and normalized tensor [num_tokens, hidden_dim]
+        residual_out: Add-residual tensor [num_tokens, hidden_dim]
     """
 
     if epsilon is None:
@@ -365,10 +366,6 @@ def trtllm_mnnvl_fused_allreduce_rmsnorm(
         )
     )
 
-    print(
-        f"[Rank {workspace.rank}] workspace.mc_ptr: {workspace.mc_ptr}, workspace.uc_ptrs_dev: {workspace.uc_ptrs_dev}, workspace.uc_ptr_local: {workspace.uc_ptr_local}"
-    )
-
     module.trtllm_mnnvl_allreduce_fusion(
         input,
         workspace.mc_ptr,
@@ -387,3 +384,225 @@ def trtllm_mnnvl_fused_allreduce_rmsnorm(
         epsilon,
     )
     return output, residual_out
+
+
+# Legacy API that has been deprecated; Left for backward compatibility
+@deprecated(
+    "get_allreduce_mnnvl_workspace is deprecated, use MNNVLAllreduceFusionWorkspace class to manage the workspace instead"
+)
+def get_allreduce_mnnvl_workspace(
+    mapping: Mapping, dtype: torch.dtype, buffer_size_in_bytes: Optional[int] = None
+) -> Tuple[McastGPUBuffer, torch.Tensor, int]:
+    """Get workspace buffers needed for multi-node NVLink all-reduce operation.
+
+    This function allocates and initializes the workspace buffers required for performing
+    multi-node NVLink all-reduce operations. It creates:
+    1. A multicast GPU buffer for communication between nodes
+    2. A flags tensor to track buffer state
+    3. Maximum number of elements that can fit in the buffer
+
+    The buffer size is calculated to efficiently handle common hidden dimensions
+    (2048, 4096, 5120, 7168, 8192) by using their LCM of 286720.
+
+    Args:
+        mapping: Tensor parallel mapping configuration containing rank info
+        dtype: Data type of the tensors being reduced
+        buffer_size_in_bytes: Optional buffer size. Practically, assign this to 3 * 2 * dtype.itemsize * hidden_dim * max_tokens
+
+    Returns:
+        Tuple containing:
+        - McastGPUBuffer: Multicast buffer for inter-node communication
+        - torch.Tensor: Buffer flags tensor tracking state
+        - int: Maximum number of elements that can fit in buffer
+    """
+    # buffer shape: [3, 2, buffer_tokens, hidden_dim]
+    stride = 3 * 2 * dtype.itemsize
+    # LCM for hidden_dim: 2048, 4096, 5120, 7168, 8192 = 286720
+    # max_num_elements must be a multiple of 286720
+    lcm_hidden_dim = 286720
+    TARGET_WORKSPACE_SIZE_BYTES = (
+        buffer_size_in_bytes if buffer_size_in_bytes is not None else 12_000_000
+    )
+    buffer_size_in_bytes = math.ceil(
+        TARGET_WORKSPACE_SIZE_BYTES / (lcm_hidden_dim * stride)
+    ) * (lcm_hidden_dim * stride)
+
+    # Redirect to the new workspace allocation logic. The new kernel needs the new flag buffer layout.
+    workspace = MNNVLAllreduceFusionWorkspace(mapping, buffer_size_in_bytes)
+
+    mcast_buffer = workspace.mcast_buffer_handle
+    buffer_flags = workspace.buffer_flags
+    max_num_elements = workspace.buffer_size_bytes // stride
+
+    return (
+        mcast_buffer,
+        buffer_flags,
+        max_num_elements,
+    )
+
+
+@deprecated(
+    "trtllm_mnnvl_all_reduce is deprecated, use trtllm_mnnvl_allreduce instead. This function will be removed in the future."
+)
+def trtllm_mnnvl_all_reduce(
+    inp: torch.Tensor,
+    multicast_buffer_ptr: int,  # Pointer address as integer
+    buffer_ptrs_dev: int,  # Pointer address as integer
+    buffer_M: int,
+    buffer_flags_mnnvl: torch.Tensor,
+    nranks: int,
+    rank: int,
+    wait_for_results: bool,
+    launch_with_pdl: bool,
+    out: Optional[torch.Tensor] = None,
+) -> None:
+    """Perform a multi-node NVLink all-reduce operation across multiple GPUs.
+
+    This function performs an all-reduce (sum) operation using NVIDIA's multi-node NVLink (MNNVL)
+    technology to efficiently combine tensors across multiple GPUs and nodes.
+
+    There are 3 steps:
+    1. Scatter each GPU's input shard to the right unicast buffer
+    2. Perform the all-reduce on each GPU
+    3. Broadcast the result to all GPUs
+
+    Args:
+        inp: Local input shard
+        multicast_buffer_ptr: Pointer to the multicast buffer as an integer
+        buffer_ptrs_dev: Pointer to device buffer pointers as an integer
+        buffer_M: Maximum number of elements // hidden_dim
+        buffer_flags_mnnvl: Tensor containing buffer state flags
+        nranks: Total number of ranks participating in the all-reduce
+        rank: Current process rank
+        wait_for_results: If True, store the result to out
+        launch_with_pdl: If True, launch using Programmatic Dependent Launch
+        out: [Optional] Output tensor to store the result (required if wait_for_results is True)
+
+    """
+
+    if len(inp.shape) != 2:
+        raise ValueError(
+            f"The input tensor must be 2D, got {len(inp.shape)}D. The shape is {inp.shape}."
+        )
+
+    # buffer_M is no longer used in this kernel, but keep this check for consistency in behavior.
+    if inp.shape[0] > buffer_M:
+        raise ValueError(
+            f"The number of tokens in the input tensor {inp.shape[0]} is greater than the buffer_M {buffer_M}. This is not supported. Please increase the workspace size, or decrease the amount of tokens to at most {buffer_M}."
+        )
+
+    # Even in legacy code, this should only be used when we implement the fused allreduce+rmsnorm.
+    assert wait_for_results and (out is not None), (
+        "Calling the legacy trtllm_mnnvl_all_reduce with wait_for_results=False is not supported. Please use trtllm_mnnvl_allreduce instead."
+    )
+    module = get_trtllm_mnnvl_comm_module()
+    module.trtllm_mnnvl_allreduce_fusion(
+        inp,
+        multicast_buffer_ptr,
+        buffer_ptrs_dev,
+        0,  # The allreduce kernel itself does not use this local pointer; still, this could be risky, but it is only needed for legacy-code compatibility.
+        buffer_flags_mnnvl,
+        nranks,
+        rank,
+        False,  # No RMSNorm fusion
+        launch_with_pdl,
+        False,  # Use two-shot
+        out,
+        None,
+        None,
+        None,
+        None,
+    )
+
+
+@deprecated(
+    "trtllm_mnnvl_fused_allreduce_rmsnorm is deprecated, use trtllm_mnnvl_fused_allreduce_add_rmsnorm instead. This function will be removed in the future."
+)
+def trtllm_mnnvl_fused_allreduce_rmsnorm(
+    prenorm_output: torch.Tensor,
+    normed_output: torch.Tensor,
+    shard_input: torch.Tensor,
+    multicast_buffer_ptr: int,  # Pointer address as integer
+    buffer_ptrs_dev: int,  # Pointer address as integer
+    unicast_ptr: int,  # Local unicast buffer pointer
+    buffer_M: int,
+    buffer_flags_mnnvl: torch.Tensor,
+    nranks: int,
+    rank: int,
+    gamma: torch.Tensor,
+    epsilon: float,
+    residual: torch.Tensor,
+    launch_with_pdl: bool,
+) -> None:
+    """Performs MNNVL TwoShot Allreduce + RMSNorm.
+
+    This function performs a multi-node all-reduce (sum) operation by first calling trtllm_mnnvl_all_reduce on the shard_input.
+    After this, it performs RMSNorm on the all-reduced result, reading it directly from the multicast buffer.
+    Note: multicast buffer is the same as the unicast buffer for the current rank.
+
+    Args:
+        prenorm_output: Output tensor for prenorm results
+        normed_output: Output tensor for normalized results
+        shard_input: Input tensor shard
+        multicast_buffer_ptr: Pointer address as integer for multicast buffer
+        buffer_ptrs_dev: Pointer address as integer for device buffer pointers
+        unicast_ptr: Pointer address as integer for unicast buffer
+        buffer_M: Maximum number of elements // hidden_dim
+        buffer_flags_mnnvl: Buffer flags for synchronization
+        nranks: Number of ranks in the tensor parallel group
+        rank: Current rank in the tensor parallel group
+        gamma: The gamma (norm weight) parameter for RMSNorm
+        epsilon: The epsilon parameter for RMSNorm
+        residual: The residual tensor to add
+        launch_with_pdl: Whether to launch with PDL
+
+    """
+    if len(shard_input.shape) != 2:
+        raise ValueError(
+            f"The input tensor must be 2D, got {len(shard_input.shape)}D. The shape is {shard_input.shape}."
+        )
+
+    # buffer_M is no longer used in this kernel, but keep this check for consistency in behavior.
+    if shard_input.shape[0] > buffer_M:
+        raise ValueError(
+            f"The number of tokens in the input tensor {shard_input.shape[0]} is greater than the buffer_M {buffer_M}. This is not supported. Please increase the workspace size, or decrease the amount of tokens to at most {buffer_M}."
+        )
+
+    if len(residual.shape) != 2:
+        raise ValueError(
+            f"The residual input tensor must be 2D, got {len(residual.shape)}D. The shape is {residual.shape}."
+        )
+    if gamma.numel() != shard_input.shape[1]:
+        raise ValueError(
+            f"The gamma tensor must have the same number of elements as the hidden dimension, got {gamma.numel()} elements but expected {shard_input.shape[1]} elements."
+        )
+
+    if len(normed_output.shape) != 2:
+        raise ValueError(
+            f"The output tensor must be 2D, got {len(normed_output.shape)}D. The shape is {normed_output.shape}."
+        )
+
+    if len(prenorm_output.shape) != 2:
+        raise ValueError(
+            f"The prenorm output tensor must be 2D, got {len(prenorm_output.shape)}D. The shape is {prenorm_output.shape}."
+        )
+
+    module = get_trtllm_mnnvl_comm_module()
+
+    module.trtllm_mnnvl_allreduce_fusion(
+        shard_input,
+        multicast_buffer_ptr,
+        buffer_ptrs_dev,
+        unicast_ptr,
+        buffer_flags_mnnvl,
+        nranks,
+        rank,
+        True,  # RMSNorm Fusion
+        launch_with_pdl,
+        False,
+        normed_output,
+        prenorm_output,
+        residual,
+        gamma,
+        epsilon,
+    )

tests/comm/test_trtllm_mnnvl_allreduce.py

Lines changed: 10 additions & 8 deletions

@@ -42,14 +42,16 @@ def func(
     if enable_fusion:
         trtllm_mnnvl_ar.mpi_barrier()
 
-        output, residual_out = trtllm_mnnvl_ar.trtllm_mnnvl_fused_allreduce_rmsnorm(
-            input,
-            residual,
-            norm_weight,
-            workspace,
-            eps,
-            launch_with_pdl=use_pdl,
-            strategy=trtllm_mnnvl_ar.MNNVLAllreduceFusionStrategy.AUTO,
+        output, residual_out = (
+            trtllm_mnnvl_ar.trtllm_mnnvl_fused_allreduce_add_rmsnorm(
+                input,
+                residual,
+                norm_weight,
+                workspace,
+                eps,
+                launch_with_pdl=use_pdl,
+                strategy=trtllm_mnnvl_ar.MNNVLAllreduceFusionStrategy.AUTO,
+            )
         )
 
         return output.view(shape), residual_out.view(shape)
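For call sites migrating from the deprecated wrappers, here is a minimal usage sketch of the renamed entry point, mirroring the test change above. The Mapping object, tensor shapes, and the 16 MB buffer size are illustrative assumptions rather than values from this commit.

```python
from flashinfer.comm import trtllm_mnnvl_ar


def fused_allreduce_add_rmsnorm_example(mapping, shard, residual, norm_weight, eps=1e-5):
    # Workspace is now managed by the MNNVLAllreduceFusionWorkspace class
    # (replacing the deprecated get_allreduce_mnnvl_workspace helper).
    # The 16 MiB size here is an arbitrary placeholder.
    workspace = trtllm_mnnvl_ar.MNNVLAllreduceFusionWorkspace(mapping, 16 * 1024 * 1024)

    # Allreduce(shard) + residual add + RMSNorm, fused into one kernel launch.
    output, residual_out = trtllm_mnnvl_ar.trtllm_mnnvl_fused_allreduce_add_rmsnorm(
        shard,        # [num_tokens, hidden_dim] local input shard
        residual,     # [num_tokens, hidden_dim] residual input
        norm_weight,  # [hidden_dim] RMSNorm gamma
        workspace,
        eps,
        launch_with_pdl=False,
        strategy=trtllm_mnnvl_ar.MNNVLAllreduceFusionStrategy.AUTO,
    )
    return output, residual_out
```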
