From 5c2dcff1472d0866992c15ee5806852d7b1676ba Mon Sep 17 00:00:00 2001 From: root Date: Tue, 16 Dec 2025 16:42:49 +0000 Subject: [PATCH 01/10] Check fabric handle support --- flashinfer/comm/allreduce.py | 20 ++++--- flashinfer/comm/mnnvl.py | 90 ++++++++++++++++-------------- flashinfer/comm/trtllm_ar.py | 26 +++++++-- flashinfer/comm/trtllm_mnnvl_ar.py | 1 - 4 files changed, 80 insertions(+), 57 deletions(-) diff --git a/flashinfer/comm/allreduce.py b/flashinfer/comm/allreduce.py index cfafa99220..2c70590109 100644 --- a/flashinfer/comm/allreduce.py +++ b/flashinfer/comm/allreduce.py @@ -60,7 +60,7 @@ from .mapping import Mapping -from .mnnvl import CommBackend +from .mnnvl import CommBackend, SymmDeviceMemory # Note: AllReduceFusionPattern and QuantizationSFLayout are pseudo-types (classes with int constants) # Import them for runtime use but type hint as int for mypy compatibility @@ -95,7 +95,7 @@ def __init__( max_token_num: int, hidden_dim: int, dtype: torch.dtype = torch.float16, - process_group: Optional["torch.distributed.ProcessGroup"] = None, + comm_backend: Optional[CommBackend] = None, ): """ Create TensorRT-LLM AllReduce fusion workspace. @@ -106,7 +106,7 @@ def __init__( max_token_num: Maximum number of tokens hidden_dim: Hidden dimension size dtype: Data type - process_group: PyTorch distributed process group + comm_backend: Communication backend **kwargs: Additional arguments for workspace creation """ super().__init__(tp_size, tp_rank) @@ -117,7 +117,7 @@ def __init__( tp_size=tp_size, max_token_num=max_token_num, hidden_dim=hidden_dim, - group=process_group, + comm_backend=comm_backend, create_metadata=True, use_fp32_lamport=dtype == torch.float32, ) @@ -125,11 +125,12 @@ def __init__( # Store essential attributes for easy access # Cast to 3-tuple to make linter happy, since we always call with create_metadata=True workspace_tuple = cast( - Tuple[List[List[int]], torch.Tensor, dict], self._internal_workspace + Tuple[List[List[int]], torch.Tensor, List[SymmDeviceMemory], dict], self._internal_workspace, ) self.ipc_handles = workspace_tuple[0] self.workspace_tensor = workspace_tuple[1] - self.metadata = workspace_tuple[2] + self.mem_handles = workspace_tuple[2] + self.metadata = workspace_tuple[3] @property def backend(self) -> str: @@ -165,7 +166,10 @@ def destroy(self) -> None: if getattr(self, "_destroyed", False): return # Already destroyed, nothing to do - trtllm_destroy_ipc_workspace_for_all_reduce_fusion(self.ipc_handles) + del self.ipc_handles + del self.workspace_tensor + del self.mem_handles + del self.metadata self._destroyed = True @@ -423,7 +427,7 @@ def create_allreduce_fusion_workspace( max_token_num=max_token_num, hidden_dim=hidden_dim, dtype=dtype, - process_group=process_group, + comm_backend=comm_backend, ) elif actual_backend == "mnnvl": diff --git a/flashinfer/comm/mnnvl.py b/flashinfer/comm/mnnvl.py index 13ca4f534d..a366e1bc10 100644 --- a/flashinfer/comm/mnnvl.py +++ b/flashinfer/comm/mnnvl.py @@ -764,9 +764,31 @@ def close(self) -> None: self._socket.close() +def is_mnnvl_fabric_supported(device_idx: int) -> bool: + fabric_handle_supported = checkCudaErrors( + cuda.cuDeviceGetAttribute( + cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED, + device_idx, + ) + ) + if fabric_handle_supported == 0: + return False + + pynvml.nvmlInit() + try: + handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx) + fabric_info = pynvml.c_nvmlGpuFabricInfoV_t() + pynvml.nvmlDeviceGetGpuFabricInfoV(handle, ctypes.byref(fabric_info)) + if fabric_info.state >= pynvml.NVML_GPU_FABRIC_STATE_COMPLETED and fabric_info.clusterUuid[0] != 0: + return True + return False + finally: + pynvml.nvmlShutdown() + + # TODO: This class follows similar logic with MnnvlMemory, but the latter use single instance mode to manage the memory allocation. -class McastDeviceMemory: - """Python port of McastDeviceMemory from TensorRT-LLM""" +class SymmDeviceMemory: + """Python port of SymmDeviceMemory from TensorRT-LLM""" def __init__( self, @@ -774,8 +796,9 @@ def __init__( group_size: int, group_rank: int, device_idx: int, - is_multi_node: bool = True, comm_backend_for_handle_transfer: Optional[CommBackend] = None, + enable_multicast: bool = True, + allocate_signal_pads: bool = True, ): cu_device = checkCudaErrors(cuda.cuDeviceGet(device_idx)) @@ -795,7 +818,6 @@ def __init__( checkCudaErrors(cudart.cudaSetDevice(device_idx)) - self.is_multi_node = is_multi_node self.device_idx = device_idx self.group_size = group_size self.group_rank = group_rank @@ -828,31 +850,20 @@ def __init__( ) if multicast_supported == 0: raise RuntimeError( - "[McastDeviceMemory] Device does not support multicasting." + "[SymmDeviceMemory] Device does not support multicasting." ) # Calculate signal pad offset with alignment (matching C++ exactly) self.signal_pad_offset = round_up(buf_size, self.SIGNAL_PAD_ALIGNMENT) logging.info( - f"[McastDeviceMemory] Rank: {group_rank}, Group size: {group_size}, " - f"mnNvlink: {is_multi_node}, device_idx: {device_idx}, " + f"[SymmDeviceMemory] Rank: {group_rank}, Group size: {group_size}, " + f"device_idx: {device_idx}, " f"Signal pad offset: {self.signal_pad_offset}" ) - # Create handle exchanger based on multi-node mode - if self.is_multi_node: - # Check if fabric handle is supported - fabric_handle_supported = checkCudaErrors( - cuda.cuDeviceGetAttribute( - cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED, - device_idx, - ) - ) - if fabric_handle_supported == 0: - raise RuntimeError( - "[McastDeviceMemory] Device does not support fabric handle." - ) + # Create handle exchanger + if is_mnnvl_fabric_supported(device_idx): self._exchanger: HandleExchanger = FabricHandleExchanger( self.comm_backend, self.group_rank, self.group_size ) @@ -860,28 +871,24 @@ def __init__( self._exchanger = PosixFDHandleExchanger( self.comm_backend, self.group_rank, self.group_size ) - self._alloc_mn_mcast_mem(buf_size) - - # Initialize signal pads - self.signal_pads = [0] * self.group_size - for i in range(self.group_size): - self.signal_pads[i] = self.uc_ptrs[i] + self.signal_pad_offset - if i == self.group_rank: - checkCudaErrors( - cuda.cuMemsetD8(self.signal_pads[i], 0, self.SIGNAL_PAD_SIZE) - ) + self._alloc_mn_mcast_mem(buf_size, enable_multicast) + + if allocate_signal_pads: + # Initialize signal pads + self.signal_pads = [0] * self.group_size + for i in range(self.group_size): + self.signal_pads[i] = self.uc_ptrs[i] + self.signal_pad_offset + if i == self.group_rank: + checkCudaErrors( + cuda.cuMemsetD8(self.signal_pads[i], 0, self.SIGNAL_PAD_SIZE) + ) - # Create device pointers - self.signal_pads_dev = alloc_and_copy_to_cuda(self.signal_pads) + self.signal_pads_dev = alloc_and_copy_to_cuda(self.signal_pads) self.uc_ptrs_dev = alloc_and_copy_to_cuda(self.uc_ptrs) def __del__(self): """Destructor - cleanup allocated memory""" - # Check if we're in a valid state for cleanup - if not hasattr(self, "is_multi_node"): - return - if hasattr(self, "_exchanger"): self._exchanger.close() @@ -987,7 +994,7 @@ def get_usable_buffer_size(self) -> int: """Get the usable buffer size (excluding signal pad)""" return self.allocation_size - self.SIGNAL_PAD_SIZE - def _alloc_mn_mcast_mem(self, buf_size: int): + def _alloc_mn_mcast_mem(self, buf_size: int, enable_multicast: bool): """Allocate multi-node multicast memory using MNNVL""" self._verify_cuda_context() @@ -998,7 +1005,8 @@ def _alloc_mn_mcast_mem(self, buf_size: int): self._allocate_unicast_buffers(allocation_prop) # Setup multicast object, exchange handles, map and bind memory - self._setup_multicast(mc_prop) + if enable_multicast: + self._setup_multicast(mc_prop) def _verify_cuda_context(self): """Verify CUDA context is set to the correct device.""" @@ -1198,7 +1206,7 @@ def lamport_initialize(self, rank: int, dtype: torch.dtype): class McastGPUBuffer: """ - Wrapper class for McastDeviceMemory to facilitate PyTorch tensor creation. + Wrapper class for SymmDeviceMemory to facilitate PyTorch tensor creation. It manages a buffer accessible via unicast or multicast for multi-node communication. Python port of McastGPUBuffer from TensorRT-LLM @@ -1210,7 +1218,6 @@ def __init__( group_size: int, group_rank: int, device: torch.device, - mn_nvlink: bool = True, comm_backend_for_handle_transfer: Optional[CommBackend] = None, ): """ @@ -1224,12 +1231,11 @@ def __init__( mn_nvlink: Flag indicating if multi-node NVLink is used comm_backend_for_handle_transfer: Communication backend for handle transfer """ - self.mcast_device_memory = McastDeviceMemory( + self.mcast_device_memory = SymmDeviceMemory( buf_size, group_size, group_rank, device.index, - mn_nvlink, comm_backend_for_handle_transfer, ) # Update buf_size to reflect the actual usable buffer size after allocation diff --git a/flashinfer/comm/trtllm_ar.py b/flashinfer/comm/trtllm_ar.py index 85e953c766..05237dedde 100644 --- a/flashinfer/comm/trtllm_ar.py +++ b/flashinfer/comm/trtllm_ar.py @@ -21,6 +21,7 @@ from typing import List, Optional, Tuple, Union from typing_extensions import deprecated +from flashinfer.comm.mnnvl import CommBackend, MPIBackend, SymmDeviceMemory import torch import torch.distributed as dist from torch.distributed import ProcessGroup @@ -509,7 +510,7 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion( max_token_num: int, hidden_dim, use_fp32_lamport: bool = False, - group: Optional[ProcessGroup] = None, + comm_backend: Optional[CommBackend] = None, create_metadata: bool = False, ) -> Union[ Tuple[List[List[int]], torch.Tensor], Tuple[List[List[int]], torch.Tensor, dict] @@ -521,7 +522,7 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion( - max_token_num: the maximum number of tokens in a sequence. - hidden_dim: the dimension of the hidden states. - use_fp32_lamport: if True, we will use fp32 datatype in allreduce fusion. - - group: the process group to use. + - comm_backend: the communication backend to use. - create_metadata: if True, return metadata dict as third element (default: False). Returns: @@ -546,6 +547,9 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion( Reference: trtllm, cpp/tensorrt_llm/kernels/communicationKernels/allReduceWorkspace.cu, Workspace init """ + if comm_backend is None: + comm_backend = MPIBackend() + buffer_size = tp_size * max_token_num * hidden_dim * 2 flag_size = tp_size * BarrierFlagCount * 4 # lamport_comm_size = tp_size * max(max_token_num, OneShotMaxToken) * hidden_dim * 2 @@ -567,11 +571,21 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion( # [buffer_size, flag_size, lamport_buffer_size] ipc_handles: List[List[int]] = list() + mem_handles: List[SymmDeviceMemory] = list() for size in [buffer_size, flag_size, lamport_buffer_size]: # todo(review): confirm we need this alignment # all sizes should be aligned to 1LU << 21 bytes (2MB) aligned_size = round_up(size, 1 << 21) - ipc_handles.append(create_shared_buffer(aligned_size, group)) + symm_mem = SymmDeviceMemory(aligned_size, + tp_size, + tp_rank, + torch.device("cuda", tp_rank).index, + comm_backend, + False, + False) + ipc_handles.append(symm_mem.uc_ptrs) + mem_handles.append(symm_mem) + print( f"rank {tp_rank} allocated ipc_handles: {[[hex(handle) for handle in sublist] for sublist in ipc_handles]}" @@ -626,7 +640,7 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion( workspace, dtype=torch.int64, device=torch.device("cuda") ) - dist.barrier(group=group) # must sync after create_workspace + comm_backend.barrier() # must sync after create_workspace if create_metadata: metadata = { @@ -640,9 +654,9 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion( "lamport_comm_size": lamport_comm_size, "lamport_buffer_size": lamport_buffer_size, } - return ipc_handles, workspace_tensor, metadata + return ipc_handles, workspace_tensor, mem_handles, metadata else: - return ipc_handles, workspace_tensor + return ipc_handles, workspace_tensor, mem_handles def trtllm_destroy_ipc_workspace_for_all_reduce_fusion( diff --git a/flashinfer/comm/trtllm_mnnvl_ar.py b/flashinfer/comm/trtllm_mnnvl_ar.py index dfcb8317e5..3f8e198146 100644 --- a/flashinfer/comm/trtllm_mnnvl_ar.py +++ b/flashinfer/comm/trtllm_mnnvl_ar.py @@ -137,7 +137,6 @@ def __init__( mapping.tp_size, mapping.tp_rank, torch.device("cuda", mapping.local_rank), - mapping.is_multi_node(), comm_backend, ) From d421a672600ec4ad28f5d8c2c507f0f98de821a6 Mon Sep 17 00:00:00 2001 From: Maximilien Breughe Date: Wed, 17 Dec 2025 14:04:15 -0800 Subject: [PATCH 02/10] one_shot in workspace creation, removed topology --- flashinfer/comm/allreduce.py | 69 ++++++++++++++++-------------------- 1 file changed, 31 insertions(+), 38 deletions(-) diff --git a/flashinfer/comm/allreduce.py b/flashinfer/comm/allreduce.py index 2c70590109..2854bdb4d6 100644 --- a/flashinfer/comm/allreduce.py +++ b/flashinfer/comm/allreduce.py @@ -55,7 +55,6 @@ from .trtllm_ar import trtllm_allreduce_fusion from .trtllm_ar import trtllm_create_ipc_workspace_for_all_reduce_fusion -from .trtllm_ar import trtllm_destroy_ipc_workspace_for_all_reduce_fusion from .trtllm_ar import check_trtllm_allreduce_fusion_workspace_metadata from .mapping import Mapping @@ -66,6 +65,7 @@ # Import them for runtime use but type hint as int for mypy compatibility from .trtllm_ar import AllReduceFusionPattern from .trtllm_mnnvl_ar import MNNVLAllReduceFusionWorkspace +from .trtllm_mnnvl_ar import MNNVLAllreduceFusionStrategy from .trtllm_mnnvl_ar import trtllm_mnnvl_allreduce from .trtllm_mnnvl_ar import trtllm_mnnvl_fused_allreduce_add_rmsnorm @@ -125,7 +125,8 @@ def __init__( # Store essential attributes for easy access # Cast to 3-tuple to make linter happy, since we always call with create_metadata=True workspace_tuple = cast( - Tuple[List[List[int]], torch.Tensor, List[SymmDeviceMemory], dict], self._internal_workspace, + Tuple[List[List[int]], torch.Tensor, List[SymmDeviceMemory], dict], + self._internal_workspace, ) self.ipc_handles = workspace_tuple[0] self.workspace_tensor = workspace_tuple[1] @@ -185,7 +186,6 @@ def _trtllm_workspace_check( max_token_num: int, hidden_dim: int, dtype: torch.dtype, - topology: Literal["single_node", "multi_node"], ) -> bool: """ Check if trtllm backend CAN be used for workspace creation. @@ -194,8 +194,8 @@ def _trtllm_workspace_check( - Single-node topology (multi-node not supported) """ - # trtllm is optimized for single-node - if topology == "multi_node": + # trtllm is limited to 16 ranks + if world_size > 16: return False return True @@ -208,16 +208,12 @@ def _mnnvl_workspace_check( max_token_num: int, hidden_dim: int, dtype: torch.dtype, - topology: Literal["single_node", "multi_node"], ) -> bool: """ Check if mnnvl backend CAN be used for workspace creation. """ - if topology == "multi_node": - return True - return True @@ -234,7 +230,6 @@ def _workspace_creation_heuristic( max_token_num: int, hidden_dim: int, dtype: torch.dtype, - topology: Literal["single_node", "multi_node"], ) -> list[str]: """ Select best backend for workspace creation based on performance. @@ -250,7 +245,6 @@ def _workspace_creation_heuristic( max_token_num: Maximum number of tokens hidden_dim: Hidden dimension size dtype: Data type - topology: Network topology ("single_node" or "multi_node") **kwargs: Additional arguments Note that at this point, the backend selection does not take "runtime parameters" into account, such as layout_code, and fusion pattern. @@ -266,13 +260,6 @@ def _workspace_creation_heuristic( # Decision tree based on benchmark data - # Multi-node: MNNVL is designed for this - if topology == "multi_node": - if "mnnvl" in suitable_backends: - return ["mnnvl"] - else: - return [suitable_backends[0]] - # Single-node scenarios # From benchmarking data, we can see that MNNVL is either on par (smaller problem sizes) or significantly faster than TRTLLM (larger problem sizes such as hidden_dim=8192, token_num=64 for TP=4), for single-node scenarios. # However, trtllm has a larger support surface (more fusion patterns, more quantization support, etc.) @@ -294,10 +281,10 @@ def create_allreduce_fusion_workspace( max_token_num: int = None, hidden_dim: int = None, dtype: torch.dtype = None, - topology: Literal["single_node", "multi_node"] = "single_node", process_group: Optional["torch.distributed.ProcessGroup"] = None, gpus_per_node: int = None, comm_backend: Optional[CommBackend] = None, + use_oneshot: bool = False, ) -> AllReduceFusionWorkspace: """ Create workspace for AllReduce fusion operations. @@ -319,19 +306,21 @@ def create_allreduce_fusion_workspace( Args: backend: Backend to use ("trtllm", "mnnvl", or "auto") - "auto" uses heuristic to select best backend based on topology - and problem size + "auto" uses heuristic to select best backend world_size: Number of ranks in the process group rank: Current rank ID max_token_num: Maximum number of tokens to support hidden_dim: Hidden dimension size dtype: Data type for communication tensors - topology: Network topology hint for backend selection - "single_node" - All ranks on one node (default) - "multi_node" - Ranks span multiple nodes process_group: PyTorch distributed process group (for trtllm backend). gpus_per_node: Number of GPUs per node (for multi-node topology). - comm_backend: Communication backend to use (for multi-node topology). + comm_backend: Communication backend to use. + use_oneshot: Allocate workspace for oneshot strategy vs twoshot + True: Allocate workspace for oneshot strategy (larger workspace size) + False: Allocate workspace for twoshot strategy + If None, uses internal heuristics to select the strategy. + Note that only the workspace for MNNVL backend needs to be initialized with the correct strategy. + The trtllm backend will be sufficient for both strategies. Returns: Workspace object (TRTLLMAllReduceFusionWorkspace or MNNVLAllReduceFusionWorkspace) @@ -342,7 +331,7 @@ def create_allreduce_fusion_workspace( ValueError: If problem size not supported for the specified backend Examples: - >>> # Auto-select best backend based on topology + >>> # Auto-select best backend >>> workspace = create_allreduce_fusion_workspace( ... backend="auto", ... world_size=8, @@ -350,7 +339,6 @@ def create_allreduce_fusion_workspace( ... max_token_num=2048, ... hidden_dim=4096, ... dtype=torch.bfloat16, - ... topology="single_node" ... ) >>> print(workspace.backend) # "trtllm" >>> print(workspace.get_workspace_capacity()) # 8388608 elements @@ -367,7 +355,6 @@ def create_allreduce_fusion_workspace( ... max_token_num=2048, ... hidden_dim=4096, ... dtype=torch.bfloat16, - ... topology="multi_node" ... ) >>> print(workspace.backend) # "mnnvl" """ @@ -375,7 +362,7 @@ def create_allreduce_fusion_workspace( gpus_per_node = min(torch.cuda.device_count(), world_size) # Determine the actual backend to use if backend == "auto": - # Find suitable backends based on topology (anny CC check needs to be checked at kernel runtime, since there are no tensor available at this point) + # Find suitable backends (any compute capability check needs to be checked at kernel runtime, since there are no tensor available at this point) suitable_backends = [] if _trtllm_workspace_check( backend=backend, @@ -384,7 +371,6 @@ def create_allreduce_fusion_workspace( max_token_num=max_token_num, hidden_dim=hidden_dim, dtype=dtype, - topology=topology, ): suitable_backends.append("trtllm") if _mnnvl_workspace_check( @@ -394,15 +380,11 @@ def create_allreduce_fusion_workspace( max_token_num=max_token_num, hidden_dim=hidden_dim, dtype=dtype, - topology=topology, ): suitable_backends.append("mnnvl") if not suitable_backends: - raise ValueError( - f"No suitable backend found for topology={topology}. " - f"trtllm requires single_node topology, mnnvl works with both." - ) + raise ValueError("No suitable backend found. ") # Apply heuristic to select best backend selected = _workspace_creation_heuristic( @@ -413,7 +395,6 @@ def create_allreduce_fusion_workspace( max_token_num=max_token_num, hidden_dim=hidden_dim, dtype=dtype, - topology=topology, ) actual_backend = selected[0] else: @@ -437,12 +418,25 @@ def create_allreduce_fusion_workspace( gpus_per_node=gpus_per_node, tp_size=world_size, ) + workspace_size_bytes = None + if use_oneshot: + workspace_size_bytes = ( + MNNVLAllReduceFusionWorkspace.get_required_buffer_size_bytes( + world_size, + max_token_num, + hidden_dim, + dtype, + MNNVLAllreduceFusionStrategy.ONESHOT, + ) + ) + return MNNVLAllReduceFusionWorkspace( mapping=mapping, max_num_tokens=max_token_num, hidden_dim=hidden_dim, dtype=dtype, comm_backend=comm_backend, + workspace_size_bytes=workspace_size_bytes, ) else: raise RuntimeError(f"Unknown backend: {actual_backend}") @@ -518,7 +512,7 @@ def allreduce_fusion( # ===== Control parameters ===== use_oneshot: Use oneshot strategy vs twoshot If None, uses internal heuristics. - Note that the MNNVL backend needs to be initialized with a sufficiently large workspace if one_shot is used. + Note: when explicitly set to True, the MNNVL backend needs to be initialized with a sufficiently large workspace. fp32_acc: [trtllm only] Use FP32 accumulation for AllReduce Returns: @@ -533,7 +527,6 @@ def allreduce_fusion( ... max_token_num=2048, ... hidden_dim=4096, ... dtype=torch.bfloat16, - ... topology="single_node" ... ) >>> >>> # Pre-allocate output tensors From 5285e3205b1b09fd472dca8967d4b1be9f994a14 Mon Sep 17 00:00:00 2001 From: Maximilien Breughe Date: Thu, 18 Dec 2025 10:09:17 -0800 Subject: [PATCH 03/10] Backward compatibility, test integration Made trtllm_ar backward compatibile, updated tests with new API --- flashinfer/comm/allreduce.py | 9 +-- flashinfer/comm/mnnvl.py | 85 +++++++++++++++++++++- flashinfer/comm/trtllm_ar.py | 64 +++++++++++----- tests/comm/test_allreduce_unified_api.py | 14 +--- tests/comm/test_trtllm_allreduce_fusion.py | 6 +- 5 files changed, 141 insertions(+), 37 deletions(-) diff --git a/flashinfer/comm/allreduce.py b/flashinfer/comm/allreduce.py index 2854bdb4d6..7d677b0d21 100644 --- a/flashinfer/comm/allreduce.py +++ b/flashinfer/comm/allreduce.py @@ -120,6 +120,7 @@ def __init__( comm_backend=comm_backend, create_metadata=True, use_fp32_lamport=dtype == torch.float32, + use_symm_dev_mem=True, ) # Store essential attributes for easy access @@ -281,7 +282,6 @@ def create_allreduce_fusion_workspace( max_token_num: int = None, hidden_dim: int = None, dtype: torch.dtype = None, - process_group: Optional["torch.distributed.ProcessGroup"] = None, gpus_per_node: int = None, comm_backend: Optional[CommBackend] = None, use_oneshot: bool = False, @@ -312,7 +312,6 @@ def create_allreduce_fusion_workspace( max_token_num: Maximum number of tokens to support hidden_dim: Hidden dimension size dtype: Data type for communication tensors - process_group: PyTorch distributed process group (for trtllm backend). gpus_per_node: Number of GPUs per node (for multi-node topology). comm_backend: Communication backend to use. use_oneshot: Allocate workspace for oneshot strategy vs twoshot @@ -418,9 +417,9 @@ def create_allreduce_fusion_workspace( gpus_per_node=gpus_per_node, tp_size=world_size, ) - workspace_size_bytes = None + buffer_size_in_bytes = None if use_oneshot: - workspace_size_bytes = ( + buffer_size_in_bytes = ( MNNVLAllReduceFusionWorkspace.get_required_buffer_size_bytes( world_size, max_token_num, @@ -436,7 +435,7 @@ def create_allreduce_fusion_workspace( hidden_dim=hidden_dim, dtype=dtype, comm_backend=comm_backend, - workspace_size_bytes=workspace_size_bytes, + buffer_size_in_bytes=buffer_size_in_bytes, ) else: raise RuntimeError(f"Unknown backend: {actual_backend}") diff --git a/flashinfer/comm/mnnvl.py b/flashinfer/comm/mnnvl.py index a366e1bc10..667f9f6815 100644 --- a/flashinfer/comm/mnnvl.py +++ b/flashinfer/comm/mnnvl.py @@ -232,6 +232,86 @@ def Split(self, color: int, key: int) -> CommBackend: return MPIBackend() # Returns new adapter +class TorchDistBackend(CommBackend): + """Communication backend using torch.distributed""" + + def __init__(self, group: Optional[Any] = None): + """ + Initialize TorchDistBackend. + + Args: + group: Optional process group. If None, uses the default process group. + """ + import torch.distributed as dist + + if not dist.is_initialized(): + raise RuntimeError( + "torch.distributed is not initialized. " + "Please call torch.distributed.init_process_group() first." + ) + self._group = group + self._dist = dist + + def Get_rank(self) -> int: + return self._dist.get_rank(self._group) + + def Get_size(self) -> int: + return self._dist.get_world_size(self._group) + + def allgather(self, data: Any) -> List[Any]: + """All-gather arbitrary Python objects across all ranks.""" + output_list = [None] * self.Get_size() + self._dist.all_gather_object(output_list, data, group=self._group) + return output_list + + def bcast(self, data: Any, root: int) -> Any: + """Broadcast a Python object from root to all ranks.""" + object_list = [data] + self._dist.broadcast_object_list(object_list, src=root, group=self._group) + return object_list[0] + + def barrier(self) -> None: + self._dist.barrier(group=self._group) + + def Split(self, color: int, key: int) -> "TorchDistBackend": + """ + Split the communicator into sub-groups based on color. + + All processes with the same color will be in the same new group. + The key determines the rank ordering within the new group. + + Args: + color: Processes with the same color are placed in the same group + key: Determines rank ordering within the new group (lower key = lower rank) + + Returns: + New TorchDistBackend with the split process group + """ + # Gather (color, key, global_rank) from all processes + global_rank = self.Get_rank() + + all_info = self.allgather((color, key, global_rank)) + + # Group ranks by color, sort by key within each group + color_groups: Dict[int, List[tuple]] = {} + for c, k, r in all_info: + if c not in color_groups: + color_groups[c] = [] + color_groups[c].append((k, r)) + + # Sort each group by key to determine rank ordering + for c in color_groups: + color_groups[c].sort(key=lambda x: x[0]) + + # Find my new group's ranks (in sorted order by key) + my_group_ranks = [r for _, r in color_groups[color]] + + # Create new process group with the ranks in my color group + new_group = self._dist.new_group(ranks=my_group_ranks) + + return TorchDistBackend(group=new_group) + + @dataclass class MnnvlConfig: """Configuration for MNNVL memory management""" @@ -779,7 +859,10 @@ def is_mnnvl_fabric_supported(device_idx: int) -> bool: handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx) fabric_info = pynvml.c_nvmlGpuFabricInfoV_t() pynvml.nvmlDeviceGetGpuFabricInfoV(handle, ctypes.byref(fabric_info)) - if fabric_info.state >= pynvml.NVML_GPU_FABRIC_STATE_COMPLETED and fabric_info.clusterUuid[0] != 0: + if ( + fabric_info.state >= pynvml.NVML_GPU_FABRIC_STATE_COMPLETED + and fabric_info.clusterUuid[0] != 0 + ): return True return False finally: diff --git a/flashinfer/comm/trtllm_ar.py b/flashinfer/comm/trtllm_ar.py index 05237dedde..f8e17fd422 100644 --- a/flashinfer/comm/trtllm_ar.py +++ b/flashinfer/comm/trtllm_ar.py @@ -21,7 +21,7 @@ from typing import List, Optional, Tuple, Union from typing_extensions import deprecated -from flashinfer.comm.mnnvl import CommBackend, MPIBackend, SymmDeviceMemory +from flashinfer.comm.mnnvl import CommBackend, SymmDeviceMemory, TorchDistBackend import torch import torch.distributed as dist from torch.distributed import ProcessGroup @@ -510,10 +510,14 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion( max_token_num: int, hidden_dim, use_fp32_lamport: bool = False, - comm_backend: Optional[CommBackend] = None, + group: Optional[ProcessGroup] = None, create_metadata: bool = False, + comm_backend: Optional[CommBackend] = None, + use_symm_dev_mem: bool = False, ) -> Union[ - Tuple[List[List[int]], torch.Tensor], Tuple[List[List[int]], torch.Tensor, dict] + Tuple[List[List[int]], torch.Tensor], + Tuple[List[List[int]], torch.Tensor, dict], + Tuple[List[List[int]], torch.Tensor, List[SymmDeviceMemory], dict], ]: """ Parameters: @@ -524,12 +528,20 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion( - use_fp32_lamport: if True, we will use fp32 datatype in allreduce fusion. - comm_backend: the communication backend to use. - create_metadata: if True, return metadata dict as third element (default: False). + - group: the process group to use. + - create_metadata: if True, return metadata dict as third element (default: False). + - comm_backend: the communication backend to use. + - use_symm_dev_mem: if True, we will use symmetric device memory for the workspace. Returns: - If create_metadata=False: (ipc_handles, workspace_tensor) - - If create_metadata=True: (ipc_handles, workspace_tensor, metadata) + - If create_metadata=True: and use_symm_dev_mem=False: (ipc_handles, workspace_tensor, metadata) + where metadata contains: tp_rank, tp_size, max_token_num, hidden_dim, + use_fp32_lamport, buffer_size, flag_size, lamport_comm_size, lamport_buffer_size + - If create_metadata=True: and use_symm_dev_mem=True: (ipc_handles, workspace_tensor, mem_handles,metadata) where metadata contains: tp_rank, tp_size, max_token_num, hidden_dim, use_fp32_lamport, buffer_size, flag_size, lamport_comm_size, lamport_buffer_size + and mem_handles is a list of SymmDeviceMemory objects. Note: We would init 3 IPC buffers for trtllm_custom_all_reduce_fusion. @@ -547,8 +559,12 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion( Reference: trtllm, cpp/tensorrt_llm/kernels/communicationKernels/allReduceWorkspace.cu, Workspace init """ - if comm_backend is None: - comm_backend = MPIBackend() + if comm_backend is None and use_symm_dev_mem: + comm_backend = TorchDistBackend(group=group) + + # No need to support all variations. In the future we only support create_metadata=True and use_symm_dev_mem=True. + if use_symm_dev_mem and not create_metadata: + raise ValueError("use_symm_dev_mem is only supported when create_metadata=True") buffer_size = tp_size * max_token_num * hidden_dim * 2 flag_size = tp_size * BarrierFlagCount * 4 @@ -576,16 +592,21 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion( # todo(review): confirm we need this alignment # all sizes should be aligned to 1LU << 21 bytes (2MB) aligned_size = round_up(size, 1 << 21) - symm_mem = SymmDeviceMemory(aligned_size, - tp_size, - tp_rank, - torch.device("cuda", tp_rank).index, - comm_backend, - False, - False) - ipc_handles.append(symm_mem.uc_ptrs) - mem_handles.append(symm_mem) + if not use_symm_dev_mem: + ipc_handles.append(create_shared_buffer(aligned_size, group)) + else: + symm_mem = SymmDeviceMemory( + aligned_size, + tp_size, + tp_rank, + torch.device("cuda", tp_rank).index, + comm_backend, + False, + False, + ) + ipc_handles.append(symm_mem.uc_ptrs) + mem_handles.append(symm_mem) print( f"rank {tp_rank} allocated ipc_handles: {[[hex(handle) for handle in sublist] for sublist in ipc_handles]}" @@ -640,7 +661,10 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion( workspace, dtype=torch.int64, device=torch.device("cuda") ) - comm_backend.barrier() # must sync after create_workspace + if use_symm_dev_mem: + comm_backend.barrier() # must sync after create_workspace + else: + dist.barrier(group=group) if create_metadata: metadata = { @@ -654,9 +678,13 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion( "lamport_comm_size": lamport_comm_size, "lamport_buffer_size": lamport_buffer_size, } - return ipc_handles, workspace_tensor, mem_handles, metadata + if use_symm_dev_mem: + return ipc_handles, workspace_tensor, mem_handles, metadata + else: + return ipc_handles, workspace_tensor, metadata + else: - return ipc_handles, workspace_tensor, mem_handles + return ipc_handles, workspace_tensor def trtllm_destroy_ipc_workspace_for_all_reduce_fusion( diff --git a/tests/comm/test_allreduce_unified_api.py b/tests/comm/test_allreduce_unified_api.py index 732a3ddb92..d31a0a7e31 100644 --- a/tests/comm/test_allreduce_unified_api.py +++ b/tests/comm/test_allreduce_unified_api.py @@ -5,10 +5,10 @@ import pytest import torch -import torch.distributed as dist from mpi4py import MPI import flashinfer.comm.trtllm_mnnvl_ar as trtllm_mnnvl_ar +from flashinfer.comm.mnnvl import TorchDistBackend # Unified API imports from flashinfer.comm import ( @@ -197,23 +197,16 @@ def run_allreduce_test( local_rank = rank % gpus_per_node torch.cuda.set_device(local_rank) - # Initialize torch.distributed for trtllm backend (needed for IPC workspace) - # TODO: check if it is ok to do this with auto backend - process_group = None - if backend in ("trtllm", "auto"): - init_torch_distributed_from_mpi() - process_group = dist.group.WORLD - if local_rank == 0: print(f"Running AllReduce test with {world_size} ranks, backend={backend}") print(f"Rank {rank} using GPU {torch.cuda.current_device()}") eps = 1e-5 - torch.manual_seed(42 + rank) workspace = None try: + init_torch_distributed_from_mpi() # Create workspace using unified API workspace = create_allreduce_fusion_workspace( backend=backend, @@ -222,9 +215,8 @@ def run_allreduce_test( max_token_num=max(seq_lens), hidden_dim=hidden_size, dtype=dtype, - topology="single_node", gpus_per_node=gpus_per_node, - process_group=process_group, + comm_backend=TorchDistBackend(), ) print(f"Rank {rank}: Created workspace with backend={workspace.backend}") diff --git a/tests/comm/test_trtllm_allreduce_fusion.py b/tests/comm/test_trtllm_allreduce_fusion.py index dab4877fb9..5ae027c4ff 100644 --- a/tests/comm/test_trtllm_allreduce_fusion.py +++ b/tests/comm/test_trtllm_allreduce_fusion.py @@ -9,6 +9,8 @@ import flashinfer.comm as comm +from flashinfer.comm.mnnvl import TorchDistBackend + # todo(Yingyi): add benchmark and quant test @@ -83,8 +85,7 @@ def _run_correctness_worker( max_token_num=MAX_TOKEN_NUM, hidden_dim=hidden_dim, dtype=dtype, - topology="single_node", - process_group=group, + comm_backend=TorchDistBackend(), ) test_loop = 5 @@ -445,6 +446,7 @@ def multi_process_parallel( ) +# Run as: python tests/comm/test_trtllm_allreduce_fusion.py @pytest.mark.parametrize("world_size", [2, 4, 8]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("hidden_dim", [1024, 2048, 4096, 7168, 8192]) From 0f8370146e84dddd6dc03864136ba2e38e456993 Mon Sep 17 00:00:00 2001 From: Maximilien Breughe Date: Thu, 18 Dec 2025 10:45:29 -0800 Subject: [PATCH 04/10] Added documentation --- flashinfer/comm/allreduce.py | 11 +++++------ flashinfer/comm/trtllm_ar.py | 3 +++ 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/flashinfer/comm/allreduce.py b/flashinfer/comm/allreduce.py index 7d677b0d21..69b2b62733 100644 --- a/flashinfer/comm/allreduce.py +++ b/flashinfer/comm/allreduce.py @@ -284,7 +284,7 @@ def create_allreduce_fusion_workspace( dtype: torch.dtype = None, gpus_per_node: int = None, comm_backend: Optional[CommBackend] = None, - use_oneshot: bool = False, + force_oneshot_support: bool = False, ) -> AllReduceFusionWorkspace: """ Create workspace for AllReduce fusion operations. @@ -314,10 +314,9 @@ def create_allreduce_fusion_workspace( dtype: Data type for communication tensors gpus_per_node: Number of GPUs per node (for multi-node topology). comm_backend: Communication backend to use. - use_oneshot: Allocate workspace for oneshot strategy vs twoshot - True: Allocate workspace for oneshot strategy (larger workspace size) - False: Allocate workspace for twoshot strategy - If None, uses internal heuristics to select the strategy. + force_oneshot_support: Allocate workspace for oneshot strategy vs twoshot + True: Allocate workspace for oneshot strategy up to the largest problem size requested + None/False: Allocate workspace for twoshot strategy for all problem sizes, and for oneshot strategy up to the heuristic threshold. Note that only the workspace for MNNVL backend needs to be initialized with the correct strategy. The trtllm backend will be sufficient for both strategies. @@ -418,7 +417,7 @@ def create_allreduce_fusion_workspace( tp_size=world_size, ) buffer_size_in_bytes = None - if use_oneshot: + if force_oneshot_support: buffer_size_in_bytes = ( MNNVLAllReduceFusionWorkspace.get_required_buffer_size_bytes( world_size, diff --git a/flashinfer/comm/trtllm_ar.py b/flashinfer/comm/trtllm_ar.py index f8e17fd422..cf7f986c18 100644 --- a/flashinfer/comm/trtllm_ar.py +++ b/flashinfer/comm/trtllm_ar.py @@ -543,6 +543,9 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion( use_fp32_lamport, buffer_size, flag_size, lamport_comm_size, lamport_buffer_size and mem_handles is a list of SymmDeviceMemory objects. + Note: The optional parameters make the API clunky at this time. This will be refactored in the future, at the cost of backward compatibility, where the default behavior will be + create_metadata=True and use_symm_dev_mem=True. + Note: We would init 3 IPC buffers for trtllm_custom_all_reduce_fusion. They are sized as follows: From 9dd38fe112646bf3e3ebe3dff304549fb4509d51 Mon Sep 17 00:00:00 2001 From: Maximilien Breughe Date: Thu, 18 Dec 2025 12:17:48 -0800 Subject: [PATCH 05/10] Addressed bot reviewer comments --- flashinfer/comm/allreduce.py | 11 +++-------- flashinfer/comm/trtllm_ar.py | 2 -- tests/comm/test_allreduce_unified_api.py | 1 + 3 files changed, 4 insertions(+), 10 deletions(-) diff --git a/flashinfer/comm/allreduce.py b/flashinfer/comm/allreduce.py index 69b2b62733..237ef37102 100644 --- a/flashinfer/comm/allreduce.py +++ b/flashinfer/comm/allreduce.py @@ -192,14 +192,9 @@ def _trtllm_workspace_check( Check if trtllm backend CAN be used for workspace creation. Hard requirements: - - Single-node topology (multi-node not supported) - + - Up to 16 ranks supported. """ - # trtllm is limited to 16 ranks - if world_size > 16: - return False - - return True + return world_size <= 16 def _mnnvl_workspace_check( @@ -316,7 +311,7 @@ def create_allreduce_fusion_workspace( comm_backend: Communication backend to use. force_oneshot_support: Allocate workspace for oneshot strategy vs twoshot True: Allocate workspace for oneshot strategy up to the largest problem size requested - None/False: Allocate workspace for twoshot strategy for all problem sizes, and for oneshot strategy up to the heuristic threshold. + False: Allocate workspace for twoshot strategy for all problem sizes, and for oneshot strategy up to the heuristic threshold. Note that only the workspace for MNNVL backend needs to be initialized with the correct strategy. The trtllm backend will be sufficient for both strategies. diff --git a/flashinfer/comm/trtllm_ar.py b/flashinfer/comm/trtllm_ar.py index cf7f986c18..88d3530577 100644 --- a/flashinfer/comm/trtllm_ar.py +++ b/flashinfer/comm/trtllm_ar.py @@ -526,8 +526,6 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion( - max_token_num: the maximum number of tokens in a sequence. - hidden_dim: the dimension of the hidden states. - use_fp32_lamport: if True, we will use fp32 datatype in allreduce fusion. - - comm_backend: the communication backend to use. - - create_metadata: if True, return metadata dict as third element (default: False). - group: the process group to use. - create_metadata: if True, return metadata dict as third element (default: False). - comm_backend: the communication backend to use. diff --git a/tests/comm/test_allreduce_unified_api.py b/tests/comm/test_allreduce_unified_api.py index d31a0a7e31..2b80cf645c 100644 --- a/tests/comm/test_allreduce_unified_api.py +++ b/tests/comm/test_allreduce_unified_api.py @@ -202,6 +202,7 @@ def run_allreduce_test( print(f"Rank {rank} using GPU {torch.cuda.current_device()}") eps = 1e-5 + torch.manual_seed(0) workspace = None From 1e42bb6c3e0417bc95822ecee4f09c4e385372c1 Mon Sep 17 00:00:00 2001 From: Maximilien Breughe Date: Fri, 19 Dec 2025 16:16:36 -0800 Subject: [PATCH 06/10] Remove MPI dependency from MNNVL allreduce tests - Use TorchDistBackend instead of MPIBackend for workspace creation - Replace MPI calls with torch.distributed equivalents - Add SLURM environment variable support for rank info - Use explicit init_method to avoid setting env vars - Add conftest.py for proper torch.distributed cleanup (commits were squashed from another branch) --- tests/comm/conftest.py | 12 +++ tests/comm/test_trtllm_mnnvl_allreduce.py | 84 ++++++++++------ tests/test_helpers/comm.py | 117 ++++++++++++++++++---- 3 files changed, 161 insertions(+), 52 deletions(-) create mode 100644 tests/comm/conftest.py diff --git a/tests/comm/conftest.py b/tests/comm/conftest.py new file mode 100644 index 0000000000..c250c6a15d --- /dev/null +++ b/tests/comm/conftest.py @@ -0,0 +1,12 @@ +# Conftest for communication tests +import torch.distributed as dist + + +def pytest_sessionfinish(session, exitstatus): + """Cleanup torch.distributed at the end of pytest session. + + This runs after all tests complete but before Python shutdown, + avoiding the "destroy_process_group() was not called" warning. + """ + if dist.is_initialized(): + dist.destroy_process_group() diff --git a/tests/comm/test_trtllm_mnnvl_allreduce.py b/tests/comm/test_trtllm_mnnvl_allreduce.py index b71889e864..0a3b6fab1f 100644 --- a/tests/comm/test_trtllm_mnnvl_allreduce.py +++ b/tests/comm/test_trtllm_mnnvl_allreduce.py @@ -4,14 +4,22 @@ import pytest import torch -from mpi4py import MPI # Added MPI import +import torch.distributed as dist import flashinfer.comm.trtllm_mnnvl_ar as trtllm_mnnvl_ar from flashinfer.comm.mapping import Mapping +from flashinfer.comm.mnnvl import TorchDistBackend # Use flashinfer.norm.rmsnorm as reference implementation. from flashinfer.norm import rmsnorm +# Test helpers +from tests.test_helpers.comm import ( + init_torch_distributed_from_mpi, +) + +# Note: torch.distributed cleanup is handled by tests/comm/conftest.py + @torch.inference_mode() def row_linear_residual_norm_fusion_forward( @@ -25,7 +33,7 @@ def row_linear_residual_norm_fusion_forward( workspace: trtllm_mnnvl_ar.MNNVLAllReduceFusionWorkspace, ): tensor_parallel_rank = mapping.tp_rank - MPI.COMM_WORLD.barrier() + dist.barrier() def func( input, @@ -41,7 +49,7 @@ def func( use_pdl = True if enable_fusion: - trtllm_mnnvl_ar.mpi_barrier() + dist.barrier() output, residual_out = ( trtllm_mnnvl_ar.trtllm_mnnvl_fused_allreduce_add_rmsnorm( @@ -121,7 +129,7 @@ def row_linear_residual_norm_fusion_forward_legacy( ): tensor_parallel_size = mapping.tp_size tensor_parallel_rank = mapping.tp_rank - MPI.COMM_WORLD.barrier() + dist.barrier() def func( input, @@ -145,7 +153,7 @@ def func( prenorm_output = torch.empty_like(residual) normed_output = torch.empty_like(residual) - trtllm_mnnvl_ar.mpi_barrier() + dist.barrier() trtllm_mnnvl_ar.trtllm_mnnvl_fused_allreduce_rmsnorm( prenorm_output, @@ -231,10 +239,10 @@ def func( def prepare_test_data(seq_len: int, hidden_size: int, dtype: torch.dtype, fusion: bool): - # Communicator used for passing data between ranks - comm = MPI.COMM_WORLD - rank = comm.Get_rank() - world_size = comm.Get_size() + # Use torch.distributed for communication between ranks + rank = dist.get_rank() + world_size = dist.get_world_size() + if rank == 0: x_full = torch.randn((world_size, seq_len, hidden_size), dtype=dtype) residual = torch.randn((seq_len, hidden_size), dtype=dtype) @@ -244,10 +252,10 @@ def prepare_test_data(seq_len: int, hidden_size: int, dtype: torch.dtype, fusion residual = None norm_weight = None - # Use lowercase bcast() for Python object broadcasting - x_full = comm.bcast(x_full, root=0) - residual = comm.bcast(residual, root=0) - norm_weight = comm.bcast(norm_weight, root=0) + # Use torch.distributed broadcast_object_list for Python object broadcasting + data_list = [x_full, residual, norm_weight] + dist.broadcast_object_list(data_list, src=0) + x_full, residual, norm_weight = data_list x_full = x_full.cuda() residual = residual.cuda() @@ -291,16 +299,20 @@ def run_mnnvl_ar_full( explicit_workspace_bytes: If provided, use this workspace size instead of default """ - comm = MPI.COMM_WORLD - # Get MPI info - rank = comm.Get_rank() - world_size = comm.Get_size() gpus_per_node = torch.cuda.device_count() if gpus_per_node == 0: pytest.skip("MNNVL allreduce test requires at least one CUDA device per node") + + # Initialize torch.distributed (safe to call if already initialized) + init_torch_distributed_from_mpi() + + # Get rank info from torch.distributed + rank = dist.get_rank() + world_size = dist.get_world_size() + if world_size < 2: - pytest.skip(f"This test requires at least 2 MPI ranks, got {world_size}") + pytest.skip(f"This test requires at least 2 ranks, got {world_size}") mapping = Mapping( world_size=world_size, @@ -312,6 +324,9 @@ def run_mnnvl_ar_full( # Set CUDA device based on rank torch.cuda.set_device(mapping.local_rank) + # Create TorchDistBackend for workspace creation (non-MPI based) + comm_backend = TorchDistBackend() + if mapping.local_rank == 0: print( f"[Node {mapping.node_rank}] Running MNNVL AllReduce test with {world_size} ranks" @@ -330,7 +345,10 @@ def run_mnnvl_ar_full( if legacy_api: mcast_buffer_mnnvl, buffer_flags_mnnvl, max_num_elements_mnnvl = ( trtllm_mnnvl_ar.get_allreduce_mnnvl_workspace( - mapping, dtype, buffer_size_in_bytes=legacy_explicit_workspace_bytes + mapping, + dtype, + comm_backend_for_handle_transfer=comm_backend, + buffer_size_in_bytes=legacy_explicit_workspace_bytes, ) ) @@ -346,6 +364,7 @@ def run_mnnvl_ar_full( max_num_tokens=max(seq_lens), hidden_dim=hidden_size, dtype=dtype, + comm_backend=comm_backend, ) test_data = [] @@ -392,8 +411,8 @@ def run_mnnvl_ar_full( workspace, ) - # Synchronize before next test - trtllm_mnnvl_ar.mpi_barrier() + # Synchronize before next test using torch.distributed barrier + dist.barrier() print( f"PASSED[rank={rank}]: seq_len={seq_len}, fusion={fusion}, dtype={dtype}" @@ -405,28 +424,27 @@ def run_mnnvl_ar_full( print(failure_message) print(traceback.format_exc()) - # Gather failure status from all ranks for logging - all_failures = MPI.COMM_WORLD.allgather(rank_failed) + # Gather failure status from all ranks using torch.distributed + all_failures = [None] * world_size + dist.all_gather_object(all_failures, rank_failed) if any(all_failures): failed_ranks = [i for i, failed in enumerate(all_failures) if failed] if rank == 0: print(f"Test failed on ranks: {failed_ranks}") - # Cleanup before re-raising - if "workspace" in locals(): - del workspace - # Re-raise the original exception so it can be caught by pytest.raises in negative tests raise finally: - # Ensure cleanup happens for this list's workspace - if "workspace" in locals(): - del workspace - - # Final synchronization and check for failures across all ranks - trtllm_mnnvl_ar.mpi_barrier() + # Explicitly destroy workspace to avoid __del__ issues during Python shutdown + if "workspace" in locals() and workspace is not None: + workspace.destroy() + if "mcast_buffer_mnnvl" in locals(): + del mcast_buffer_mnnvl + + # Final synchronization using torch.distributed barrier + dist.barrier() """Test with default workspace size""" diff --git a/tests/test_helpers/comm.py b/tests/test_helpers/comm.py index fcd4e0a23b..e327b0e861 100644 --- a/tests/test_helpers/comm.py +++ b/tests/test_helpers/comm.py @@ -4,55 +4,134 @@ import pytest import torch import torch.distributed as dist -from mpi4py import MPI + + +def _get_rank_info_from_env(): + """Get rank and world_size from environment variables. + + Supports multiple launchers: + - SLURM (srun): SLURM_PROCID, SLURM_NTASKS, SLURM_LOCALID + - torchrun: RANK, WORLD_SIZE, LOCAL_RANK + - MPI (fallback): Uses mpi4py if environment variables not found + + Returns: + tuple: (rank, world_size, local_rank) + """ + # Try SLURM environment variables first (set by srun) + if "SLURM_PROCID" in os.environ: + rank = int(os.environ["SLURM_PROCID"]) + world_size = int(os.environ["SLURM_NTASKS"]) + local_rank = int( + os.environ.get("SLURM_LOCALID", rank % torch.cuda.device_count()) + ) + return rank, world_size, local_rank + + # Try torchrun/torch.distributed.launch environment variables + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + local_rank = int(os.environ.get("LOCAL_RANK", rank % torch.cuda.device_count())) + return rank, world_size, local_rank + + # Fallback to MPI if available + try: + from mpi4py import MPI + + comm = MPI.COMM_WORLD + rank = comm.Get_rank() + world_size = comm.Get_size() + local_rank = rank % torch.cuda.device_count() + return rank, world_size, local_rank + except ImportError as e: + raise RuntimeError( + "Could not determine rank/world_size. " + "Please set SLURM_PROCID/SLURM_NTASKS (srun), " + "RANK/WORLD_SIZE (torchrun), or install mpi4py." + ) from e def setup_mpi_and_cuda(): - """Setup MPI and CUDA device for tests. + """Setup distributed environment and CUDA device for tests. Returns: tuple: (rank, world_size, gpus_per_node) Raises: - pytest.skip: If no CUDA devices or fewer than 2 MPI ranks + pytest.skip: If no CUDA devices or fewer than 2 ranks """ - comm = MPI.COMM_WORLD - rank = comm.Get_rank() - world_size = comm.Get_size() gpus_per_node = torch.cuda.device_count() if gpus_per_node == 0: pytest.skip("Tests require at least one CUDA device per node") + + rank, world_size, local_rank = _get_rank_info_from_env() + if world_size < 2: - pytest.skip(f"Tests require at least 2 MPI ranks, got {world_size}") + pytest.skip(f"Tests require at least 2 ranks, got {world_size}") - local_rank = rank % gpus_per_node torch.cuda.set_device(local_rank) return rank, world_size, gpus_per_node +def _get_master_addr(): + """Get the master address for torch.distributed. + + For multi-node SLURM jobs, extracts the first node from SLURM_NODELIST. + """ + if "MASTER_ADDR" in os.environ: + return os.environ["MASTER_ADDR"] + + # For SLURM multi-node: get first node from nodelist + if "SLURM_NODELIST" in os.environ: + import subprocess + + try: + # Use scontrol to expand the nodelist and get first node + result = subprocess.run( + ["scontrol", "show", "hostnames", os.environ["SLURM_NODELIST"]], + capture_output=True, + text=True, + check=True, + ) + first_node = result.stdout.strip().split("\n")[0] + return first_node + except (subprocess.CalledProcessError, FileNotFoundError): + # scontrol not available, try simple parsing + nodelist = os.environ["SLURM_NODELIST"] + # Handle simple cases like "node[0-3]" -> "node0" or "node0,node1" -> "node0" + if "[" in nodelist: + base = nodelist.split("[")[0] + nums = nodelist.split("[")[1].split("]")[0] + first_num = nums.split(",")[0].split("-")[0] + return f"{base}{first_num}" + else: + return nodelist.split(",")[0] + + return "localhost" + + def init_torch_distributed_from_mpi(): - """Initialize torch.distributed using MPI rank info. + """Initialize torch.distributed using environment rank info. - This allows running torch.distributed operations within an MPI context. + Supports SLURM (srun), torchrun, or MPI launchers. Safe to call multiple times - will skip if already initialized. - """ - comm = MPI.COMM_WORLD - rank = comm.Get_rank() - world_size = comm.Get_size() + Uses explicit init_method to avoid modifying environment variables. + """ if dist.is_initialized(): return - # Set environment variables for torch.distributed - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "29500" - os.environ["RANK"] = str(rank) - os.environ["WORLD_SIZE"] = str(world_size) + rank, world_size, local_rank = _get_rank_info_from_env() + + # Use explicit init_method instead of setting environment variables + master_addr = _get_master_addr() + master_port = os.environ.get("MASTER_PORT", "29500") + init_method = f"tcp://{master_addr}:{master_port}" dist.init_process_group( backend="nccl", + init_method=init_method, rank=rank, world_size=world_size, ) From c7a6ead30e0a60936650acc3bfe1740cbdb5fcb5 Mon Sep 17 00:00:00 2001 From: Maximilien Breughe <50598321+nvmbreughe@users.noreply.github.com> Date: Fri, 19 Dec 2025 18:22:06 -0600 Subject: [PATCH 07/10] Update flashinfer/comm/trtllm_ar.py Co-authored-by: Zihao Ye --- flashinfer/comm/trtllm_ar.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flashinfer/comm/trtllm_ar.py b/flashinfer/comm/trtllm_ar.py index 88d3530577..b8a41f59e8 100644 --- a/flashinfer/comm/trtllm_ar.py +++ b/flashinfer/comm/trtllm_ar.py @@ -603,7 +603,7 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion( tp_rank, torch.device("cuda", tp_rank).index, comm_backend, - False, + enable_multicast=False, False, ) ipc_handles.append(symm_mem.uc_ptrs) From f6b0edcc8e70a47b1ed178285cc1d3dc98c01430 Mon Sep 17 00:00:00 2001 From: Maximilien Breughe <50598321+nvmbreughe@users.noreply.github.com> Date: Fri, 19 Dec 2025 18:22:19 -0600 Subject: [PATCH 08/10] Update flashinfer/comm/trtllm_ar.py Co-authored-by: Zihao Ye --- flashinfer/comm/trtllm_ar.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flashinfer/comm/trtllm_ar.py b/flashinfer/comm/trtllm_ar.py index b8a41f59e8..4392468e80 100644 --- a/flashinfer/comm/trtllm_ar.py +++ b/flashinfer/comm/trtllm_ar.py @@ -604,7 +604,7 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion( torch.device("cuda", tp_rank).index, comm_backend, enable_multicast=False, - False, + allocate_signal_pads=False, ) ipc_handles.append(symm_mem.uc_ptrs) mem_handles.append(symm_mem) From 9296de188bc67cd5e7f22576c736d86ecf192f58 Mon Sep 17 00:00:00 2001 From: Maximilien Breughe Date: Fri, 19 Dec 2025 16:29:34 -0800 Subject: [PATCH 09/10] Added comment on how to run the test --- tests/comm/test_trtllm_mnnvl_allreduce.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/comm/test_trtllm_mnnvl_allreduce.py b/tests/comm/test_trtllm_mnnvl_allreduce.py index 0a3b6fab1f..a28a46d2ba 100644 --- a/tests/comm/test_trtllm_mnnvl_allreduce.py +++ b/tests/comm/test_trtllm_mnnvl_allreduce.py @@ -449,6 +449,9 @@ def run_mnnvl_ar_full( """Test with default workspace size""" +# Multi-gpu test: mpirun -np 4 pytest tests/comm/test_trtllm_mnnvl_allreduce.py -vv -s +# Multi-node test:srun -A coreai_libraries_cudnn -N2 --container-image= -J --mpi=pmix -- bash -c 'hostname && cd && pip install -e . && python -m pytest tests/comm/test_trtllm_mnnvl_allreduce.py' + @pytest.mark.parametrize( "seq_lens", From 1cb0870740952b13c0ea522267f91ad7a62c6db5 Mon Sep 17 00:00:00 2001 From: Maximilien Breughe Date: Fri, 19 Dec 2025 16:38:17 -0800 Subject: [PATCH 10/10] Typo --- tests/comm/test_trtllm_mnnvl_allreduce.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/comm/test_trtllm_mnnvl_allreduce.py b/tests/comm/test_trtllm_mnnvl_allreduce.py index a28a46d2ba..26662fa71a 100644 --- a/tests/comm/test_trtllm_mnnvl_allreduce.py +++ b/tests/comm/test_trtllm_mnnvl_allreduce.py @@ -450,7 +450,7 @@ def run_mnnvl_ar_full( """Test with default workspace size""" # Multi-gpu test: mpirun -np 4 pytest tests/comm/test_trtllm_mnnvl_allreduce.py -vv -s -# Multi-node test:srun -A coreai_libraries_cudnn -N2 --container-image= -J --mpi=pmix -- bash -c 'hostname && cd && pip install -e . && python -m pytest tests/comm/test_trtllm_mnnvl_allreduce.py' +# Multi-node test:srun -A coreai_libraries_cudnn -N4 --container-image= -J --mpi=pmix -- bash -c 'hostname && cd && pip install -e . && python -m pytest tests/comm/test_trtllm_mnnvl_allreduce.py' @pytest.mark.parametrize(