diff --git a/flashinfer/comm/allreduce.py b/flashinfer/comm/allreduce.py index cfafa99220..237ef37102 100644 --- a/flashinfer/comm/allreduce.py +++ b/flashinfer/comm/allreduce.py @@ -55,17 +55,17 @@ from .trtllm_ar import trtllm_allreduce_fusion from .trtllm_ar import trtllm_create_ipc_workspace_for_all_reduce_fusion -from .trtllm_ar import trtllm_destroy_ipc_workspace_for_all_reduce_fusion from .trtllm_ar import check_trtllm_allreduce_fusion_workspace_metadata from .mapping import Mapping -from .mnnvl import CommBackend +from .mnnvl import CommBackend, SymmDeviceMemory # Note: AllReduceFusionPattern and QuantizationSFLayout are pseudo-types (classes with int constants) # Import them for runtime use but type hint as int for mypy compatibility from .trtllm_ar import AllReduceFusionPattern from .trtllm_mnnvl_ar import MNNVLAllReduceFusionWorkspace +from .trtllm_mnnvl_ar import MNNVLAllreduceFusionStrategy from .trtllm_mnnvl_ar import trtllm_mnnvl_allreduce from .trtllm_mnnvl_ar import trtllm_mnnvl_fused_allreduce_add_rmsnorm @@ -95,7 +95,7 @@ def __init__( max_token_num: int, hidden_dim: int, dtype: torch.dtype = torch.float16, - process_group: Optional["torch.distributed.ProcessGroup"] = None, + comm_backend: Optional[CommBackend] = None, ): """ Create TensorRT-LLM AllReduce fusion workspace. 
@@ -106,7 +106,7 @@ def __init__( max_token_num: Maximum number of tokens hidden_dim: Hidden dimension size dtype: Data type - process_group: PyTorch distributed process group + comm_backend: Communication backend **kwargs: Additional arguments for workspace creation """ super().__init__(tp_size, tp_rank) @@ -117,19 +117,22 @@ def __init__( tp_size=tp_size, max_token_num=max_token_num, hidden_dim=hidden_dim, - group=process_group, + comm_backend=comm_backend, create_metadata=True, use_fp32_lamport=dtype == torch.float32, + use_symm_dev_mem=True, ) # Store essential attributes for easy access # Cast to 3-tuple to make linter happy, since we always call with create_metadata=True workspace_tuple = cast( - Tuple[List[List[int]], torch.Tensor, dict], self._internal_workspace + Tuple[List[List[int]], torch.Tensor, List[SymmDeviceMemory], dict], + self._internal_workspace, ) self.ipc_handles = workspace_tuple[0] self.workspace_tensor = workspace_tuple[1] - self.metadata = workspace_tuple[2] + self.mem_handles = workspace_tuple[2] + self.metadata = workspace_tuple[3] @property def backend(self) -> str: @@ -165,7 +168,10 @@ def destroy(self) -> None: if getattr(self, "_destroyed", False): return # Already destroyed, nothing to do - trtllm_destroy_ipc_workspace_for_all_reduce_fusion(self.ipc_handles) + del self.ipc_handles + del self.workspace_tensor + del self.mem_handles + del self.metadata self._destroyed = True @@ -181,20 +187,14 @@ def _trtllm_workspace_check( max_token_num: int, hidden_dim: int, dtype: torch.dtype, - topology: Literal["single_node", "multi_node"], ) -> bool: """ Check if trtllm backend CAN be used for workspace creation. Hard requirements: - - Single-node topology (multi-node not supported) - + - Up to 16 ranks supported. 
""" - # trtllm is optimized for single-node - if topology == "multi_node": - return False - - return True + return world_size <= 16 def _mnnvl_workspace_check( @@ -204,16 +204,12 @@ def _mnnvl_workspace_check( max_token_num: int, hidden_dim: int, dtype: torch.dtype, - topology: Literal["single_node", "multi_node"], ) -> bool: """ Check if mnnvl backend CAN be used for workspace creation. """ - if topology == "multi_node": - return True - return True @@ -230,7 +226,6 @@ def _workspace_creation_heuristic( max_token_num: int, hidden_dim: int, dtype: torch.dtype, - topology: Literal["single_node", "multi_node"], ) -> list[str]: """ Select best backend for workspace creation based on performance. @@ -246,7 +241,6 @@ def _workspace_creation_heuristic( max_token_num: Maximum number of tokens hidden_dim: Hidden dimension size dtype: Data type - topology: Network topology ("single_node" or "multi_node") **kwargs: Additional arguments Note that at this point, the backend selection does not take "runtime parameters" into account, such as layout_code, and fusion pattern. @@ -262,13 +256,6 @@ def _workspace_creation_heuristic( # Decision tree based on benchmark data - # Multi-node: MNNVL is designed for this - if topology == "multi_node": - if "mnnvl" in suitable_backends: - return ["mnnvl"] - else: - return [suitable_backends[0]] - # Single-node scenarios # From benchmarking data, we can see that MNNVL is either on par (smaller problem sizes) or significantly faster than TRTLLM (larger problem sizes such as hidden_dim=8192, token_num=64 for TP=4), for single-node scenarios. # However, trtllm has a larger support surface (more fusion patterns, more quantization support, etc.) 
@@ -290,10 +277,9 @@ def create_allreduce_fusion_workspace( max_token_num: int = None, hidden_dim: int = None, dtype: torch.dtype = None, - topology: Literal["single_node", "multi_node"] = "single_node", - process_group: Optional["torch.distributed.ProcessGroup"] = None, gpus_per_node: int = None, comm_backend: Optional[CommBackend] = None, + force_oneshot_support: bool = False, ) -> AllReduceFusionWorkspace: """ Create workspace for AllReduce fusion operations. @@ -315,19 +301,19 @@ def create_allreduce_fusion_workspace( Args: backend: Backend to use ("trtllm", "mnnvl", or "auto") - "auto" uses heuristic to select best backend based on topology - and problem size + "auto" uses heuristic to select best backend world_size: Number of ranks in the process group rank: Current rank ID max_token_num: Maximum number of tokens to support hidden_dim: Hidden dimension size dtype: Data type for communication tensors - topology: Network topology hint for backend selection - "single_node" - All ranks on one node (default) - "multi_node" - Ranks span multiple nodes - process_group: PyTorch distributed process group (for trtllm backend). gpus_per_node: Number of GPUs per node (for multi-node topology). - comm_backend: Communication backend to use (for multi-node topology). + comm_backend: Communication backend to use. + force_oneshot_support: Allocate workspace for oneshot strategy vs twoshot + True: Allocate workspace for oneshot strategy up to the largest problem size requested + False: Allocate workspace for twoshot strategy for all problem sizes, and for oneshot strategy up to the heuristic threshold. + Note that only the workspace for MNNVL backend needs to be initialized with the correct strategy. + The trtllm backend will be sufficient for both strategies. 
Returns: Workspace object (TRTLLMAllReduceFusionWorkspace or MNNVLAllReduceFusionWorkspace) @@ -338,7 +324,7 @@ def create_allreduce_fusion_workspace( ValueError: If problem size not supported for the specified backend Examples: - >>> # Auto-select best backend based on topology + >>> # Auto-select best backend >>> workspace = create_allreduce_fusion_workspace( ... backend="auto", ... world_size=8, @@ -346,7 +332,6 @@ def create_allreduce_fusion_workspace( ... max_token_num=2048, ... hidden_dim=4096, ... dtype=torch.bfloat16, - ... topology="single_node" ... ) >>> print(workspace.backend) # "trtllm" >>> print(workspace.get_workspace_capacity()) # 8388608 elements @@ -363,7 +348,6 @@ def create_allreduce_fusion_workspace( ... max_token_num=2048, ... hidden_dim=4096, ... dtype=torch.bfloat16, - ... topology="multi_node" ... ) >>> print(workspace.backend) # "mnnvl" """ @@ -371,7 +355,7 @@ def create_allreduce_fusion_workspace( gpus_per_node = min(torch.cuda.device_count(), world_size) # Determine the actual backend to use if backend == "auto": - # Find suitable backends based on topology (anny CC check needs to be checked at kernel runtime, since there are no tensor available at this point) + # Find suitable backends (any compute capability check needs to be checked at kernel runtime, since there are no tensor available at this point) suitable_backends = [] if _trtllm_workspace_check( backend=backend, @@ -380,7 +364,6 @@ def create_allreduce_fusion_workspace( max_token_num=max_token_num, hidden_dim=hidden_dim, dtype=dtype, - topology=topology, ): suitable_backends.append("trtllm") if _mnnvl_workspace_check( @@ -390,15 +373,11 @@ def create_allreduce_fusion_workspace( max_token_num=max_token_num, hidden_dim=hidden_dim, dtype=dtype, - topology=topology, ): suitable_backends.append("mnnvl") if not suitable_backends: - raise ValueError( - f"No suitable backend found for topology={topology}. " - f"trtllm requires single_node topology, mnnvl works with both." 
- ) + raise ValueError("No suitable backend found. ") # Apply heuristic to select best backend selected = _workspace_creation_heuristic( @@ -409,7 +388,6 @@ def create_allreduce_fusion_workspace( max_token_num=max_token_num, hidden_dim=hidden_dim, dtype=dtype, - topology=topology, ) actual_backend = selected[0] else: @@ -423,7 +401,7 @@ def create_allreduce_fusion_workspace( max_token_num=max_token_num, hidden_dim=hidden_dim, dtype=dtype, - process_group=process_group, + comm_backend=comm_backend, ) elif actual_backend == "mnnvl": @@ -433,12 +411,25 @@ def create_allreduce_fusion_workspace( gpus_per_node=gpus_per_node, tp_size=world_size, ) + buffer_size_in_bytes = None + if force_oneshot_support: + buffer_size_in_bytes = ( + MNNVLAllReduceFusionWorkspace.get_required_buffer_size_bytes( + world_size, + max_token_num, + hidden_dim, + dtype, + MNNVLAllreduceFusionStrategy.ONESHOT, + ) + ) + return MNNVLAllReduceFusionWorkspace( mapping=mapping, max_num_tokens=max_token_num, hidden_dim=hidden_dim, dtype=dtype, comm_backend=comm_backend, + buffer_size_in_bytes=buffer_size_in_bytes, ) else: raise RuntimeError(f"Unknown backend: {actual_backend}") @@ -514,7 +505,7 @@ def allreduce_fusion( # ===== Control parameters ===== use_oneshot: Use oneshot strategy vs twoshot If None, uses internal heuristics. - Note that the MNNVL backend needs to be initialized with a sufficiently large workspace if one_shot is used. + Note: when explicitly set to True, the MNNVL backend needs to be initialized with a sufficiently large workspace. fp32_acc: [trtllm only] Use FP32 accumulation for AllReduce Returns: @@ -529,7 +520,6 @@ def allreduce_fusion( ... max_token_num=2048, ... hidden_dim=4096, ... dtype=torch.bfloat16, - ... topology="single_node" ... 
) >>> >>> # Pre-allocate output tensors
diff --git a/flashinfer/comm/mnnvl.py b/flashinfer/comm/mnnvl.py
index 13ca4f534d..667f9f6815 100644
--- a/flashinfer/comm/mnnvl.py
+++ b/flashinfer/comm/mnnvl.py
@@ -232,6 +232,86 @@ def Split(self, color: int, key: int) -> CommBackend:
         return MPIBackend()  # Returns new adapter
 
 
+class TorchDistBackend(CommBackend):
+    """Communication backend using torch.distributed"""
+
+    def __init__(self, group: Optional[Any] = None):
+        """
+        Initialize TorchDistBackend.
+
+        Args:
+            group: Optional process group. If None, uses the default process group.
+        """
+        import torch.distributed as dist
+
+        if not dist.is_initialized():
+            raise RuntimeError(
+                "torch.distributed is not initialized. "
+                "Please call torch.distributed.init_process_group() first."
+            )
+        self._group = group
+        self._dist = dist  # keep the module handle so methods need no re-import
+
+    def Get_rank(self) -> int:  # rank of this process within the wrapped group
+        return self._dist.get_rank(self._group)
+
+    def Get_size(self) -> int:  # number of processes in the wrapped group
+        return self._dist.get_world_size(self._group)
+
+    def allgather(self, data: Any) -> List[Any]:
+        """All-gather arbitrary Python objects across all ranks."""
+        output_list = [None] * self.Get_size()
+        self._dist.all_gather_object(output_list, data, group=self._group)
+        return output_list
+
+    def bcast(self, data: Any, root: int) -> Any:
+        """Broadcast a Python object from root to all ranks."""
+        object_list = [data]
+        self._dist.broadcast_object_list(object_list, src=root, group=self._group)
+        return object_list[0]
+
+    def barrier(self) -> None:  # blocks until every rank of the group arrives
+        self._dist.barrier(group=self._group)
+
+    def Split(self, color: int, key: int) -> "TorchDistBackend":
+        """
+        Split the communicator into sub-groups based on color.
+
+        All processes with the same color will be in the same new group.
+        The key determines the rank ordering within the new group.
+
+        Must be called collectively: torch.distributed.new_group() has to be
+        entered by every process, for every group, in the same order.
+
+        Args:
+            color: Processes with the same color are placed in the same group
+            key: Determines rank ordering within the new group (lower key = lower rank)
+
+        Returns:
+            New TorchDistBackend with the split process group
+        """
+        # new_group() takes global ranks, so gather the rank in the default
+        # group rather than the rank within self._group.
+        global_rank = self._dist.get_rank()
+        all_info = self.allgather((color, key, global_rank))
+
+        # Group ranks by color, sort by key within each group
+        color_groups: Dict[int, List[tuple]] = {}
+        for c, k, r in all_info:
+            color_groups.setdefault(c, []).append((k, r))
+
+        # Every rank creates every group, in the same order; keep only ours.
+        # NOTE(review): new_group() may sort ranks itself, ignoring key order - confirm.
+        my_group = None
+        for c in sorted(color_groups):
+            members = sorted(color_groups[c], key=lambda x: x[0])
+            group = self._dist.new_group(ranks=[r for _, r in members])
+            if c == color:
+                my_group = group
+
+        return TorchDistBackend(group=my_group)
+
+
 @dataclass
 class MnnvlConfig:
     """Configuration for MNNVL memory management"""
@@ -764,9 +844,34 @@ def close(self) -> None:
         self._socket.close()
 
 
+def is_mnnvl_fabric_supported(device_idx: int) -> bool:
+    fabric_handle_supported = checkCudaErrors(
+        cuda.cuDeviceGetAttribute(
+            cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED,
+            device_idx,
+        )
+    )
+    if fabric_handle_supported == 0:  # device reports no fabric-handle support
+        return False
+
+    pynvml.nvmlInit()  # NVML is only needed for the fabric-info query below
+    try:
+        handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx)  # NOTE(review): NVML index may differ from the CUDA ordinal under CUDA_VISIBLE_DEVICES - confirm
+        fabric_info = pynvml.c_nvmlGpuFabricInfoV_t()
+        pynvml.nvmlDeviceGetGpuFabricInfoV(handle, ctypes.byref(fabric_info))
+        if (
+            fabric_info.state >= pynvml.NVML_GPU_FABRIC_STATE_COMPLETED
+            and fabric_info.clusterUuid[0] != 0
+        ):
+            return True
+        return False
+    finally:
+        pynvml.nvmlShutdown()  # always paired with the nvmlInit() above
+
+
 # TODO: This class follows similar logic with MnnvlMemory, but the latter use single instance mode to manage the memory allocation.
-class McastDeviceMemory: - """Python port of McastDeviceMemory from TensorRT-LLM""" +class SymmDeviceMemory: + """Python port of SymmDeviceMemory from TensorRT-LLM""" def __init__( self, @@ -774,8 +879,9 @@ def __init__( group_size: int, group_rank: int, device_idx: int, - is_multi_node: bool = True, comm_backend_for_handle_transfer: Optional[CommBackend] = None, + enable_multicast: bool = True, + allocate_signal_pads: bool = True, ): cu_device = checkCudaErrors(cuda.cuDeviceGet(device_idx)) @@ -795,7 +901,6 @@ def __init__( checkCudaErrors(cudart.cudaSetDevice(device_idx)) - self.is_multi_node = is_multi_node self.device_idx = device_idx self.group_size = group_size self.group_rank = group_rank @@ -828,31 +933,20 @@ def __init__( ) if multicast_supported == 0: raise RuntimeError( - "[McastDeviceMemory] Device does not support multicasting." + "[SymmDeviceMemory] Device does not support multicasting." ) # Calculate signal pad offset with alignment (matching C++ exactly) self.signal_pad_offset = round_up(buf_size, self.SIGNAL_PAD_ALIGNMENT) logging.info( - f"[McastDeviceMemory] Rank: {group_rank}, Group size: {group_size}, " - f"mnNvlink: {is_multi_node}, device_idx: {device_idx}, " + f"[SymmDeviceMemory] Rank: {group_rank}, Group size: {group_size}, " + f"device_idx: {device_idx}, " f"Signal pad offset: {self.signal_pad_offset}" ) - # Create handle exchanger based on multi-node mode - if self.is_multi_node: - # Check if fabric handle is supported - fabric_handle_supported = checkCudaErrors( - cuda.cuDeviceGetAttribute( - cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED, - device_idx, - ) - ) - if fabric_handle_supported == 0: - raise RuntimeError( - "[McastDeviceMemory] Device does not support fabric handle." 
- ) + # Create handle exchanger + if is_mnnvl_fabric_supported(device_idx): self._exchanger: HandleExchanger = FabricHandleExchanger( self.comm_backend, self.group_rank, self.group_size ) @@ -860,28 +954,24 @@ def __init__( self._exchanger = PosixFDHandleExchanger( self.comm_backend, self.group_rank, self.group_size ) - self._alloc_mn_mcast_mem(buf_size) - - # Initialize signal pads - self.signal_pads = [0] * self.group_size - for i in range(self.group_size): - self.signal_pads[i] = self.uc_ptrs[i] + self.signal_pad_offset - if i == self.group_rank: - checkCudaErrors( - cuda.cuMemsetD8(self.signal_pads[i], 0, self.SIGNAL_PAD_SIZE) - ) + self._alloc_mn_mcast_mem(buf_size, enable_multicast) + + if allocate_signal_pads: + # Initialize signal pads + self.signal_pads = [0] * self.group_size + for i in range(self.group_size): + self.signal_pads[i] = self.uc_ptrs[i] + self.signal_pad_offset + if i == self.group_rank: + checkCudaErrors( + cuda.cuMemsetD8(self.signal_pads[i], 0, self.SIGNAL_PAD_SIZE) + ) - # Create device pointers - self.signal_pads_dev = alloc_and_copy_to_cuda(self.signal_pads) + self.signal_pads_dev = alloc_and_copy_to_cuda(self.signal_pads) self.uc_ptrs_dev = alloc_and_copy_to_cuda(self.uc_ptrs) def __del__(self): """Destructor - cleanup allocated memory""" - # Check if we're in a valid state for cleanup - if not hasattr(self, "is_multi_node"): - return - if hasattr(self, "_exchanger"): self._exchanger.close() @@ -987,7 +1077,7 @@ def get_usable_buffer_size(self) -> int: """Get the usable buffer size (excluding signal pad)""" return self.allocation_size - self.SIGNAL_PAD_SIZE - def _alloc_mn_mcast_mem(self, buf_size: int): + def _alloc_mn_mcast_mem(self, buf_size: int, enable_multicast: bool): """Allocate multi-node multicast memory using MNNVL""" self._verify_cuda_context() @@ -998,7 +1088,8 @@ def _alloc_mn_mcast_mem(self, buf_size: int): self._allocate_unicast_buffers(allocation_prop) # Setup multicast object, exchange handles, map and bind memory - 
self._setup_multicast(mc_prop) + if enable_multicast: + self._setup_multicast(mc_prop) def _verify_cuda_context(self): """Verify CUDA context is set to the correct device.""" @@ -1198,7 +1289,7 @@ def lamport_initialize(self, rank: int, dtype: torch.dtype): class McastGPUBuffer: """ - Wrapper class for McastDeviceMemory to facilitate PyTorch tensor creation. + Wrapper class for SymmDeviceMemory to facilitate PyTorch tensor creation. It manages a buffer accessible via unicast or multicast for multi-node communication. Python port of McastGPUBuffer from TensorRT-LLM @@ -1210,7 +1301,6 @@ def __init__( group_size: int, group_rank: int, device: torch.device, - mn_nvlink: bool = True, comm_backend_for_handle_transfer: Optional[CommBackend] = None, ): """ @@ -1224,12 +1314,11 @@ def __init__( mn_nvlink: Flag indicating if multi-node NVLink is used comm_backend_for_handle_transfer: Communication backend for handle transfer """ - self.mcast_device_memory = McastDeviceMemory( + self.mcast_device_memory = SymmDeviceMemory( buf_size, group_size, group_rank, device.index, - mn_nvlink, comm_backend_for_handle_transfer, ) # Update buf_size to reflect the actual usable buffer size after allocation diff --git a/flashinfer/comm/trtllm_ar.py b/flashinfer/comm/trtllm_ar.py index 85e953c766..4392468e80 100644 --- a/flashinfer/comm/trtllm_ar.py +++ b/flashinfer/comm/trtllm_ar.py @@ -21,6 +21,7 @@ from typing import List, Optional, Tuple, Union from typing_extensions import deprecated +from flashinfer.comm.mnnvl import CommBackend, SymmDeviceMemory, TorchDistBackend import torch import torch.distributed as dist from torch.distributed import ProcessGroup @@ -511,8 +512,12 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion( use_fp32_lamport: bool = False, group: Optional[ProcessGroup] = None, create_metadata: bool = False, + comm_backend: Optional[CommBackend] = None, + use_symm_dev_mem: bool = False, ) -> Union[ - Tuple[List[List[int]], torch.Tensor], Tuple[List[List[int]], 
torch.Tensor, dict] + Tuple[List[List[int]], torch.Tensor], + Tuple[List[List[int]], torch.Tensor, dict], + Tuple[List[List[int]], torch.Tensor, List[SymmDeviceMemory], dict], ]: """ Parameters: @@ -523,12 +528,21 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion( - use_fp32_lamport: if True, we will use fp32 datatype in allreduce fusion. - group: the process group to use. - create_metadata: if True, return metadata dict as third element (default: False). + - comm_backend: the communication backend to use. + - use_symm_dev_mem: if True, we will use symmetric device memory for the workspace. Returns: - If create_metadata=False: (ipc_handles, workspace_tensor) - - If create_metadata=True: (ipc_handles, workspace_tensor, metadata) + - If create_metadata=True: and use_symm_dev_mem=False: (ipc_handles, workspace_tensor, metadata) where metadata contains: tp_rank, tp_size, max_token_num, hidden_dim, use_fp32_lamport, buffer_size, flag_size, lamport_comm_size, lamport_buffer_size + - If create_metadata=True: and use_symm_dev_mem=True: (ipc_handles, workspace_tensor, mem_handles,metadata) + where metadata contains: tp_rank, tp_size, max_token_num, hidden_dim, + use_fp32_lamport, buffer_size, flag_size, lamport_comm_size, lamport_buffer_size + and mem_handles is a list of SymmDeviceMemory objects. + + Note: The optional parameters make the API clunky at this time. This will be refactored in the future, at the cost of backward compatibility, where the default behavior will be + create_metadata=True and use_symm_dev_mem=True. Note: We would init 3 IPC buffers for trtllm_custom_all_reduce_fusion. @@ -546,6 +560,13 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion( Reference: trtllm, cpp/tensorrt_llm/kernels/communicationKernels/allReduceWorkspace.cu, Workspace init """ + if comm_backend is None and use_symm_dev_mem: + comm_backend = TorchDistBackend(group=group) + + # No need to support all variations. 
In the future we only support create_metadata=True and use_symm_dev_mem=True. + if use_symm_dev_mem and not create_metadata: + raise ValueError("use_symm_dev_mem is only supported when create_metadata=True") + buffer_size = tp_size * max_token_num * hidden_dim * 2 flag_size = tp_size * BarrierFlagCount * 4 # lamport_comm_size = tp_size * max(max_token_num, OneShotMaxToken) * hidden_dim * 2 @@ -567,11 +588,26 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion( # [buffer_size, flag_size, lamport_buffer_size] ipc_handles: List[List[int]] = list() + mem_handles: List[SymmDeviceMemory] = list() for size in [buffer_size, flag_size, lamport_buffer_size]: # todo(review): confirm we need this alignment # all sizes should be aligned to 1LU << 21 bytes (2MB) aligned_size = round_up(size, 1 << 21) - ipc_handles.append(create_shared_buffer(aligned_size, group)) + + if not use_symm_dev_mem: + ipc_handles.append(create_shared_buffer(aligned_size, group)) + else: + symm_mem = SymmDeviceMemory( + aligned_size, + tp_size, + tp_rank, + torch.cuda.current_device(), # device bound to this process; tp_rank is a rank, not a device ordinal (matches workspace_tensor placement below) + comm_backend, + enable_multicast=False, + allocate_signal_pads=False, + ) + ipc_handles.append(symm_mem.uc_ptrs) + mem_handles.append(symm_mem) print( f"rank {tp_rank} allocated ipc_handles: {[[hex(handle) for handle in sublist] for sublist in ipc_handles]}" @@ -626,7 +662,10 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion( workspace, dtype=torch.int64, device=torch.device("cuda") ) - dist.barrier(group=group) # must sync after create_workspace + if use_symm_dev_mem: + comm_backend.barrier() # must sync after create_workspace + else: + dist.barrier(group=group) if create_metadata: metadata = { @@ -640,7 +679,11 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion( "lamport_comm_size": lamport_comm_size, "lamport_buffer_size": lamport_buffer_size, } - return ipc_handles, workspace_tensor, metadata + if use_symm_dev_mem: + return ipc_handles, workspace_tensor, mem_handles, metadata + else: +
return ipc_handles, workspace_tensor, metadata + else: return ipc_handles, workspace_tensor diff --git a/flashinfer/comm/trtllm_mnnvl_ar.py b/flashinfer/comm/trtllm_mnnvl_ar.py index dfcb8317e5..3f8e198146 100644 --- a/flashinfer/comm/trtllm_mnnvl_ar.py +++ b/flashinfer/comm/trtllm_mnnvl_ar.py @@ -137,7 +137,6 @@ def __init__( mapping.tp_size, mapping.tp_rank, torch.device("cuda", mapping.local_rank), - mapping.is_multi_node(), comm_backend, ) diff --git a/tests/comm/conftest.py b/tests/comm/conftest.py index 53b6a48258..7d4c60044e 100644 --- a/tests/comm/conftest.py +++ b/tests/comm/conftest.py @@ -11,6 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import torch.distributed as dist + + +def pytest_sessionfinish(session, exitstatus): + """Cleanup torch.distributed at the end of pytest session. + + This runs after all tests complete but before Python shutdown, + avoiding the "destroy_process_group() was not called" warning. + """ + if dist.is_initialized(): + dist.destroy_process_group() + """ Shared test utilities for comm tests. 
diff --git a/tests/comm/test_allreduce_unified_api.py b/tests/comm/test_allreduce_unified_api.py index 732a3ddb92..2b80cf645c 100644 --- a/tests/comm/test_allreduce_unified_api.py +++ b/tests/comm/test_allreduce_unified_api.py @@ -5,10 +5,10 @@ import pytest import torch -import torch.distributed as dist from mpi4py import MPI import flashinfer.comm.trtllm_mnnvl_ar as trtllm_mnnvl_ar +from flashinfer.comm.mnnvl import TorchDistBackend # Unified API imports from flashinfer.comm import ( @@ -197,23 +197,17 @@ def run_allreduce_test( local_rank = rank % gpus_per_node torch.cuda.set_device(local_rank) - # Initialize torch.distributed for trtllm backend (needed for IPC workspace) - # TODO: check if it is ok to do this with auto backend - process_group = None - if backend in ("trtllm", "auto"): - init_torch_distributed_from_mpi() - process_group = dist.group.WORLD - if local_rank == 0: print(f"Running AllReduce test with {world_size} ranks, backend={backend}") print(f"Rank {rank} using GPU {torch.cuda.current_device()}") eps = 1e-5 - torch.manual_seed(42 + rank) + torch.manual_seed(0) workspace = None try: + init_torch_distributed_from_mpi() # Create workspace using unified API workspace = create_allreduce_fusion_workspace( backend=backend, @@ -222,9 +216,8 @@ def run_allreduce_test( max_token_num=max(seq_lens), hidden_dim=hidden_size, dtype=dtype, - topology="single_node", gpus_per_node=gpus_per_node, - process_group=process_group, + comm_backend=TorchDistBackend(), ) print(f"Rank {rank}: Created workspace with backend={workspace.backend}") diff --git a/tests/comm/test_trtllm_allreduce_fusion.py b/tests/comm/test_trtllm_allreduce_fusion.py index dab4877fb9..5ae027c4ff 100644 --- a/tests/comm/test_trtllm_allreduce_fusion.py +++ b/tests/comm/test_trtllm_allreduce_fusion.py @@ -9,6 +9,8 @@ import flashinfer.comm as comm +from flashinfer.comm.mnnvl import TorchDistBackend + # todo(Yingyi): add benchmark and quant test @@ -83,8 +85,7 @@ def _run_correctness_worker( 
max_token_num=MAX_TOKEN_NUM, hidden_dim=hidden_dim, dtype=dtype, - topology="single_node", - process_group=group, + comm_backend=TorchDistBackend(), ) test_loop = 5 @@ -445,6 +446,7 @@ def multi_process_parallel( ) +# Run as: python tests/comm/test_trtllm_allreduce_fusion.py @pytest.mark.parametrize("world_size", [2, 4, 8]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("hidden_dim", [1024, 2048, 4096, 7168, 8192]) diff --git a/tests/comm/test_trtllm_mnnvl_allreduce.py b/tests/comm/test_trtllm_mnnvl_allreduce.py index b71889e864..26662fa71a 100644 --- a/tests/comm/test_trtllm_mnnvl_allreduce.py +++ b/tests/comm/test_trtllm_mnnvl_allreduce.py @@ -4,14 +4,22 @@ import pytest import torch -from mpi4py import MPI # Added MPI import +import torch.distributed as dist import flashinfer.comm.trtllm_mnnvl_ar as trtllm_mnnvl_ar from flashinfer.comm.mapping import Mapping +from flashinfer.comm.mnnvl import TorchDistBackend # Use flashinfer.norm.rmsnorm as reference implementation. 
from flashinfer.norm import rmsnorm +# Test helpers +from tests.test_helpers.comm import ( + init_torch_distributed_from_mpi, +) + +# Note: torch.distributed cleanup is handled by tests/comm/conftest.py + @torch.inference_mode() def row_linear_residual_norm_fusion_forward( @@ -25,7 +33,7 @@ def row_linear_residual_norm_fusion_forward( workspace: trtllm_mnnvl_ar.MNNVLAllReduceFusionWorkspace, ): tensor_parallel_rank = mapping.tp_rank - MPI.COMM_WORLD.barrier() + dist.barrier() def func( input, @@ -41,7 +49,7 @@ def func( use_pdl = True if enable_fusion: - trtllm_mnnvl_ar.mpi_barrier() + dist.barrier() output, residual_out = ( trtllm_mnnvl_ar.trtllm_mnnvl_fused_allreduce_add_rmsnorm( @@ -121,7 +129,7 @@ def row_linear_residual_norm_fusion_forward_legacy( ): tensor_parallel_size = mapping.tp_size tensor_parallel_rank = mapping.tp_rank - MPI.COMM_WORLD.barrier() + dist.barrier() def func( input, @@ -145,7 +153,7 @@ def func( prenorm_output = torch.empty_like(residual) normed_output = torch.empty_like(residual) - trtllm_mnnvl_ar.mpi_barrier() + dist.barrier() trtllm_mnnvl_ar.trtllm_mnnvl_fused_allreduce_rmsnorm( prenorm_output, @@ -231,10 +239,10 @@ def func( def prepare_test_data(seq_len: int, hidden_size: int, dtype: torch.dtype, fusion: bool): - # Communicator used for passing data between ranks - comm = MPI.COMM_WORLD - rank = comm.Get_rank() - world_size = comm.Get_size() + # Use torch.distributed for communication between ranks + rank = dist.get_rank() + world_size = dist.get_world_size() + if rank == 0: x_full = torch.randn((world_size, seq_len, hidden_size), dtype=dtype) residual = torch.randn((seq_len, hidden_size), dtype=dtype) @@ -244,10 +252,10 @@ def prepare_test_data(seq_len: int, hidden_size: int, dtype: torch.dtype, fusion residual = None norm_weight = None - # Use lowercase bcast() for Python object broadcasting - x_full = comm.bcast(x_full, root=0) - residual = comm.bcast(residual, root=0) - norm_weight = comm.bcast(norm_weight, root=0) + # Use 
torch.distributed broadcast_object_list for Python object broadcasting + data_list = [x_full, residual, norm_weight] + dist.broadcast_object_list(data_list, src=0) + x_full, residual, norm_weight = data_list x_full = x_full.cuda() residual = residual.cuda() @@ -291,16 +299,20 @@ def run_mnnvl_ar_full( explicit_workspace_bytes: If provided, use this workspace size instead of default """ - comm = MPI.COMM_WORLD - # Get MPI info - rank = comm.Get_rank() - world_size = comm.Get_size() gpus_per_node = torch.cuda.device_count() if gpus_per_node == 0: pytest.skip("MNNVL allreduce test requires at least one CUDA device per node") + + # Initialize torch.distributed (safe to call if already initialized) + init_torch_distributed_from_mpi() + + # Get rank info from torch.distributed + rank = dist.get_rank() + world_size = dist.get_world_size() + if world_size < 2: - pytest.skip(f"This test requires at least 2 MPI ranks, got {world_size}") + pytest.skip(f"This test requires at least 2 ranks, got {world_size}") mapping = Mapping( world_size=world_size, @@ -312,6 +324,9 @@ def run_mnnvl_ar_full( # Set CUDA device based on rank torch.cuda.set_device(mapping.local_rank) + # Create TorchDistBackend for workspace creation (non-MPI based) + comm_backend = TorchDistBackend() + if mapping.local_rank == 0: print( f"[Node {mapping.node_rank}] Running MNNVL AllReduce test with {world_size} ranks" @@ -330,7 +345,10 @@ def run_mnnvl_ar_full( if legacy_api: mcast_buffer_mnnvl, buffer_flags_mnnvl, max_num_elements_mnnvl = ( trtllm_mnnvl_ar.get_allreduce_mnnvl_workspace( - mapping, dtype, buffer_size_in_bytes=legacy_explicit_workspace_bytes + mapping, + dtype, + comm_backend_for_handle_transfer=comm_backend, + buffer_size_in_bytes=legacy_explicit_workspace_bytes, ) ) @@ -346,6 +364,7 @@ def run_mnnvl_ar_full( max_num_tokens=max(seq_lens), hidden_dim=hidden_size, dtype=dtype, + comm_backend=comm_backend, ) test_data = [] @@ -392,8 +411,8 @@ def run_mnnvl_ar_full( workspace, ) - # Synchronize 
before next test - trtllm_mnnvl_ar.mpi_barrier() + # Synchronize before next test using torch.distributed barrier + dist.barrier() print( f"PASSED[rank={rank}]: seq_len={seq_len}, fusion={fusion}, dtype={dtype}" @@ -405,32 +424,34 @@ def run_mnnvl_ar_full( print(failure_message) print(traceback.format_exc()) - # Gather failure status from all ranks for logging - all_failures = MPI.COMM_WORLD.allgather(rank_failed) + # Gather failure status from all ranks using torch.distributed + all_failures = [None] * world_size + dist.all_gather_object(all_failures, rank_failed) if any(all_failures): failed_ranks = [i for i, failed in enumerate(all_failures) if failed] if rank == 0: print(f"Test failed on ranks: {failed_ranks}") - # Cleanup before re-raising - if "workspace" in locals(): - del workspace - # Re-raise the original exception so it can be caught by pytest.raises in negative tests raise finally: - # Ensure cleanup happens for this list's workspace - if "workspace" in locals(): - del workspace + # Explicitly destroy workspace to avoid __del__ issues during Python shutdown + if "workspace" in locals() and workspace is not None: + workspace.destroy() + if "mcast_buffer_mnnvl" in locals(): + del mcast_buffer_mnnvl - # Final synchronization and check for failures across all ranks - trtllm_mnnvl_ar.mpi_barrier() + # Final synchronization using torch.distributed barrier + dist.barrier() """Test with default workspace size""" +# Multi-gpu test: mpirun -np 4 pytest tests/comm/test_trtllm_mnnvl_allreduce.py -vv -s +# Multi-node test:srun -A coreai_libraries_cudnn -N4 --container-image= -J --mpi=pmix -- bash -c 'hostname && cd && pip install -e . 
&& python -m pytest tests/comm/test_trtllm_mnnvl_allreduce.py' + @pytest.mark.parametrize( "seq_lens", diff --git a/tests/test_helpers/comm.py b/tests/test_helpers/comm.py index fcd4e0a23b..e327b0e861 100644 --- a/tests/test_helpers/comm.py +++ b/tests/test_helpers/comm.py @@ -4,55 +4,134 @@ import pytest import torch import torch.distributed as dist -from mpi4py import MPI + + +def _get_rank_info_from_env(): + """Get rank and world_size from environment variables. + + Supports multiple launchers: + - SLURM (srun): SLURM_PROCID, SLURM_NTASKS, SLURM_LOCALID + - torchrun: RANK, WORLD_SIZE, LOCAL_RANK + - MPI (fallback): Uses mpi4py if environment variables not found + + Returns: + tuple: (rank, world_size, local_rank) + """ + # Try SLURM environment variables first (set by srun) + if "SLURM_PROCID" in os.environ: + rank = int(os.environ["SLURM_PROCID"]) + world_size = int(os.environ["SLURM_NTASKS"]) + local_rank = int( + os.environ.get("SLURM_LOCALID", rank % torch.cuda.device_count()) + ) + return rank, world_size, local_rank + + # Try torchrun/torch.distributed.launch environment variables + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + local_rank = int(os.environ.get("LOCAL_RANK", rank % torch.cuda.device_count())) + return rank, world_size, local_rank + + # Fallback to MPI if available + try: + from mpi4py import MPI + + comm = MPI.COMM_WORLD + rank = comm.Get_rank() + world_size = comm.Get_size() + local_rank = rank % torch.cuda.device_count() + return rank, world_size, local_rank + except ImportError as e: + raise RuntimeError( + "Could not determine rank/world_size. " + "Please set SLURM_PROCID/SLURM_NTASKS (srun), " + "RANK/WORLD_SIZE (torchrun), or install mpi4py." + ) from e def setup_mpi_and_cuda(): - """Setup MPI and CUDA device for tests. + """Setup distributed environment and CUDA device for tests. 
Returns: tuple: (rank, world_size, gpus_per_node) Raises: - pytest.skip: If no CUDA devices or fewer than 2 MPI ranks + pytest.skip: If no CUDA devices or fewer than 2 ranks """ - comm = MPI.COMM_WORLD - rank = comm.Get_rank() - world_size = comm.Get_size() gpus_per_node = torch.cuda.device_count() if gpus_per_node == 0: pytest.skip("Tests require at least one CUDA device per node") + + rank, world_size, local_rank = _get_rank_info_from_env() + if world_size < 2: - pytest.skip(f"Tests require at least 2 MPI ranks, got {world_size}") + pytest.skip(f"Tests require at least 2 ranks, got {world_size}") - local_rank = rank % gpus_per_node torch.cuda.set_device(local_rank) return rank, world_size, gpus_per_node +def _get_master_addr(): + """Get the master address for torch.distributed. + + For multi-node SLURM jobs, extracts the first node from SLURM_NODELIST. + """ + if "MASTER_ADDR" in os.environ: + return os.environ["MASTER_ADDR"] + + # For SLURM multi-node: get first node from nodelist + if "SLURM_NODELIST" in os.environ: + import subprocess + + try: + # Use scontrol to expand the nodelist and get first node + result = subprocess.run( + ["scontrol", "show", "hostnames", os.environ["SLURM_NODELIST"]], + capture_output=True, + text=True, + check=True, + ) + first_node = result.stdout.strip().split("\n")[0] + return first_node + except (subprocess.CalledProcessError, FileNotFoundError): + # scontrol not available, try simple parsing + nodelist = os.environ["SLURM_NODELIST"] + # Handle simple cases like "node[0-3]" -> "node0" or "node0,node1" -> "node0" + if "[" in nodelist: + base = nodelist.split("[")[0] + nums = nodelist.split("[")[1].split("]")[0] + first_num = nums.split(",")[0].split("-")[0] + return f"{base}{first_num}" + else: + return nodelist.split(",")[0] + + return "localhost" + + def init_torch_distributed_from_mpi(): - """Initialize torch.distributed using MPI rank info. + """Initialize torch.distributed using environment rank info. 
- This allows running torch.distributed operations within an MPI context. + Supports SLURM (srun), torchrun, or MPI launchers. Safe to call multiple times - will skip if already initialized. - """ - comm = MPI.COMM_WORLD - rank = comm.Get_rank() - world_size = comm.Get_size() + Uses explicit init_method to avoid modifying environment variables. + """ if dist.is_initialized(): return - # Set environment variables for torch.distributed - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "29500" - os.environ["RANK"] = str(rank) - os.environ["WORLD_SIZE"] = str(world_size) + rank, world_size, local_rank = _get_rank_info_from_env() + + # Use explicit init_method instead of setting environment variables + master_addr = _get_master_addr() + master_port = os.environ.get("MASTER_PORT", "29500") + init_method = f"tcp://{master_addr}:{master_port}" dist.init_process_group( backend="nccl", + init_method=init_method, rank=rank, world_size=world_size, )