diff --git a/flashinfer/comm/allreduce.py b/flashinfer/comm/allreduce.py index c0bb663bd5..4a78443258 100644 --- a/flashinfer/comm/allreduce.py +++ b/flashinfer/comm/allreduce.py @@ -69,7 +69,7 @@ from .mapping import Mapping -from .mnnvl import CommBackend, SymmDeviceMemory +from .mnnvl import CommBackend # Note: AllReduceFusionPattern and QuantizationSFLayout are pseudo-types (classes with int constants) # Import them for runtime use but type hint as int for mypy compatibility @@ -107,6 +107,7 @@ def __init__( dtype: torch.dtype = torch.float16, comm_backend: Optional[CommBackend] = None, group: Optional[ProcessGroup] = None, + use_torch_symm_mem: bool = False, ): """ Create TensorRT-LLM AllReduce fusion workspace. @@ -118,7 +119,9 @@ def __init__( hidden_dim: Hidden dimension size dtype: Data type comm_backend: Communication backend - group: Process group for symmetric memory rendezvous. Defaults to torch.distributed.group.WORLD. + group: Process group for workspace allocation. Defaults to torch.distributed.group.WORLD. + use_torch_symm_mem: If True, use torch symmetric memory for workspace + allocation. Defaults to False (uses FlashInfer/TensorRT-style SymmDeviceMemory). """ super().__init__(tp_size, tp_rank) @@ -133,12 +136,13 @@ def __init__( create_metadata=True, use_fp32_lamport=dtype == torch.float32, use_symm_dev_mem=True, + use_torch_symm_mem=use_torch_symm_mem, ) # Store essential attributes for easy access # Cast to 3-tuple to make linter happy, since we always call with create_metadata=True workspace_tuple = cast( - Tuple[List[List[int]], torch.Tensor, List[SymmDeviceMemory], dict], + Tuple[List[List[int]], torch.Tensor, List[Any], dict], self._internal_workspace, ) self.ipc_handles = workspace_tuple[0] @@ -294,6 +298,7 @@ def create_allreduce_fusion_workspace( comm_backend: Optional[CommBackend] = None, force_oneshot_support: bool = False, group: Optional[ProcessGroup] = None, + use_torch_symm_mem: bool = False, ) -> AllReduceFusionWorkspace: """ Create workspace for AllReduce fusion operations. @@ -328,7 +333,9 @@ def create_allreduce_fusion_workspace( False: Allocate workspace for twoshot strategy for all problem sizes, and for oneshot strategy up to the heuristic threshold. Note that only the workspace for MNNVL backend needs to be initialized with the correct strategy. The trtllm backend will be sufficient for both strategies. - group: Process group for symmetric memory rendezvous (trtllm backend only). Defaults to torch.distributed.group.WORLD. + group: Process group for workspace allocation (trtllm backend only). Defaults to torch.distributed.group.WORLD. + use_torch_symm_mem: If True, use torch symmetric memory for workspace allocation. + Defaults to False (uses FlashInfer/TensorRT-style SymmDeviceMemory). Returns: Workspace object (TRTLLMAllReduceFusionWorkspace or MNNVLAllReduceFusionWorkspace) @@ -418,6 +425,7 @@ def create_allreduce_fusion_workspace( dtype=dtype, comm_backend=comm_backend, group=group, + use_torch_symm_mem=use_torch_symm_mem, ) elif actual_backend == "mnnvl": @@ -446,6 +454,7 @@ def create_allreduce_fusion_workspace( dtype=dtype, comm_backend=comm_backend, buffer_size_in_bytes=buffer_size_in_bytes, + use_torch_symm_mem=use_torch_symm_mem, ) else: raise RuntimeError(f"Unknown backend: {actual_backend}") diff --git a/flashinfer/comm/trtllm_ar.py b/flashinfer/comm/trtllm_ar.py index 240826cbc7..f04177d986 100644 --- a/flashinfer/comm/trtllm_ar.py +++ b/flashinfer/comm/trtllm_ar.py @@ -18,7 +18,7 @@ import logging from ctypes import c_void_p, cast from types import SimpleNamespace -from typing import List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union from typing_extensions import deprecated from flashinfer.comm.mnnvl import CommBackend, SymmDeviceMemory, TorchDistBackend @@ -31,7 +31,40 @@ logger = logging.getLogger(__name__) from .cuda_ipc import cudart -from .torch_symmetric_memory import _alloc_symm_buffer_bytes + + +def _alloc_trtllm_ar_workspace_buffer( + size_bytes: int, + world_size: int, + rank: int, + dtype: torch.dtype, + device: torch.device, + group_name: str, + comm_backend: CommBackend, + use_torch_symm_mem: bool, +) -> tuple[list[int], Any]: + if use_torch_symm_mem: + from .torch_symmetric_memory import _alloc_symm_buffer_bytes + + ptrs, tensor, handle = _alloc_symm_buffer_bytes( + size_bytes, + world_size, + dtype, + device, + group_name, + ) + return ptrs, (tensor, handle) + + symm_mem = SymmDeviceMemory( + size_bytes, + world_size, + rank, + device.index if device.index is not None else torch.cuda.current_device(), + comm_backend, + enable_multicast=False, + allocate_signal_pads=False, + ) + return symm_mem.uc_ptrs, symm_mem class AllReduceStrategyType: @@ -427,7 +460,7 @@ def trtllm_moe_finalize_allreduce_fusion( MAX_ALL_REDUCE_BLOCKS = 24 LamportTokenNumThreshold = 16 -_symm_workspace_refs: dict[int, list[torch.Tensor]] = {} +_symm_workspace_refs: dict[int, list[Any]] = {} @deprecated( @@ -439,6 +472,7 @@ def trtllm_create_ipc_workspace_for_all_reduce( max_token_num: int, hidden_dim, group: Optional[ProcessGroup] = None, + use_torch_symm_mem: bool = False, ) -> List[List[int]]: """ Parameters: @@ -447,6 +481,8 @@ def trtllm_create_ipc_workspace_for_all_reduce( - max_token_num: the maximum number of tokens in a sequence. - hidden_dim: the dimension of the hidden states. - group: the process group to use. + - use_torch_symm_mem: if True, use torch symmetric memory for workspace allocation. + Defaults to False (uses FlashInfer/TensorRT-style SymmDeviceMemory). Note: This function is used to create a workspace for all reduce. @@ -487,7 +523,21 @@ def trtllm_create_ipc_workspace_for_all_reduce( if group is not None else torch.distributed.group.WORLD.group_name ) - symm_refs: list[torch.Tensor] = [] + if group is not None: + comm_backend = TorchDistBackend(group=group) + else: + comm_backend = TorchDistBackend() + if use_torch_symm_mem: + group_size = comm_backend.Get_size() + group_rank = comm_backend.Get_rank() + if group_size != tp_size or group_rank != rank: + raise ValueError( + "use_torch_symm_mem=True requires " + "tp_size/rank to match the TorchDistBackend process group " + f"(group size/rank: {group_size}/{group_rank}, " + f"tp_size/rank: {tp_size}/{rank})." + ) + symm_refs: list[Any] = [] ipc_handles = list() for size, dtype in [ @@ -500,14 +550,17 @@ def trtllm_create_ipc_workspace_for_all_reduce( (lamport_buffer_size, torch.float16), ]: aligned_size = round_up(size, 16) - ptrs, tensor, handle = _alloc_symm_buffer_bytes( + ptrs, mem_ref = _alloc_trtllm_ar_workspace_buffer( aligned_size, tp_size, + rank, dtype, device, group_name, + comm_backend, + use_torch_symm_mem, ) - symm_refs.append((tensor, handle)) + symm_refs.append(mem_ref) ipc_handles.append(ptrs) logger.debug( @@ -564,10 +617,11 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion( create_metadata: bool = False, comm_backend: Optional[CommBackend] = None, use_symm_dev_mem: bool = False, + use_torch_symm_mem: bool = False, ) -> Union[ Tuple[List[List[int]], torch.Tensor], Tuple[List[List[int]], torch.Tensor, dict], - Tuple[List[List[int]], torch.Tensor, List[SymmDeviceMemory], dict], + Tuple[List[List[int]], torch.Tensor, List[Any], dict], ]: """ Parameters: @@ -580,6 +634,8 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion( - create_metadata: if True, return metadata dict as third element (default: False). - comm_backend: the communication backend to use. - use_symm_dev_mem: if True, we will use symmetric device memory for the workspace. + - use_torch_symm_mem: if True, use torch symmetric memory for workspace allocation. + Defaults to False (uses FlashInfer/TensorRT-style SymmDeviceMemory). Returns: - If create_metadata=False: (ipc_handles, workspace_tensor) @@ -589,7 +645,7 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion( - If create_metadata=True: and use_symm_dev_mem=True: (ipc_handles, workspace_tensor, mem_handles,metadata) where metadata contains: tp_rank, tp_size, max_token_num, hidden_dim, use_fp32_lamport, buffer_size, flag_size, lamport_comm_size, lamport_buffer_size - and mem_handles is a list of SymmDeviceMemory objects. + and mem_handles is a list of allocator-owned memory handles. Note: The optional parameters make the API clunky at this time. This will be refactored in the future, at the cost of backward compatibility, where the default behavior will be create_metadata=True and use_symm_dev_mem=True. @@ -610,7 +666,7 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion( Reference: trtllm, cpp/tensorrt_llm/kernels/communicationKernels/allReduceWorkspace.cu, Workspace init """ - if comm_backend is None and use_symm_dev_mem: + if comm_backend is None: comm_backend = TorchDistBackend(group=group) # No need to support all variations. In the future we only support create_metadata=True and use_symm_dev_mem=True. @@ -640,13 +696,35 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion( if group is not None else torch.distributed.group.WORLD.group_name ) - symm_refs: list[torch.Tensor] = [] + if use_torch_symm_mem: + if not isinstance(comm_backend, TorchDistBackend): + raise ValueError( + "use_torch_symm_mem=True requires " + "trtllm_create_ipc_workspace_for_all_reduce_fusion to use a " + "TorchDistBackend." + ) + torch_group = ( + comm_backend._group + if comm_backend._group is not None + else torch.distributed.group.WORLD + ) + group_size = comm_backend.Get_size() + group_rank = comm_backend.Get_rank() + if group_size != tp_size or group_rank != tp_rank: + raise ValueError( + "use_torch_symm_mem=True requires " + "tp_size/tp_rank to match the TorchDistBackend process group " + f"(group size/rank: {group_size}/{group_rank}, " + f"tp_size/tp_rank: {tp_size}/{tp_rank})." + ) + group_name = torch_group.group_name + symm_refs: list[Any] = [] # we should init 3 buffers for all reduce fusion: # [buffer_size, flag_size, lamport_buffer_size] ipc_handles: List[List[int]] = list() - mem_handles: List[SymmDeviceMemory] = list() + mem_handles: List[Any] = list() lamport_buffer_dtype = torch.float16 if not use_fp32_lamport else torch.float32 for size, dtype in [ (buffer_size, torch.float32), @@ -655,16 +733,19 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion( ]: aligned_size = round_up(size, 16) - ptrs, tensor, handle = _alloc_symm_buffer_bytes( + ptrs, mem_ref = _alloc_trtllm_ar_workspace_buffer( aligned_size, tp_size, + tp_rank, dtype, device, group_name, + comm_backend, + use_torch_symm_mem, ) - symm_refs.append((tensor, handle)) + symm_refs.append(mem_ref) ipc_handles.append(ptrs) - mem_handles.append(handle) + mem_handles.append(mem_ref) logger.debug( "rank %s allocated ipc_handles: %s", @@ -723,10 +804,7 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion( workspace, dtype=torch.int64, device=torch.device("cuda") ) - if use_symm_dev_mem: - comm_backend.barrier() # must sync after create_workspace - else: - dist.barrier(group=group) + comm_backend.barrier() # must sync after create_workspace if create_metadata: metadata = { diff --git a/flashinfer/comm/trtllm_mnnvl_ar.py b/flashinfer/comm/trtllm_mnnvl_ar.py index 32464ca98f..1c3a01a831 100644 --- a/flashinfer/comm/trtllm_mnnvl_ar.py +++ b/flashinfer/comm/trtllm_mnnvl_ar.py @@ -18,9 +18,8 @@ from ..jit import gen_trtllm_mnnvl_comm_module from ..utils import register_custom_op -from .mnnvl import CommBackend, MPIBackend +from .mnnvl import CommBackend, McastGPUBuffer, MPIBackend from .workspace_base import AllReduceFusionWorkspace -from .torch_symmetric_memory import _alloc_symm_buffer_bytes def mpi_barrier(): @@ -61,6 +60,7 @@ def __init__( dtype: Optional[torch.dtype] = None, buffer_size_in_bytes: Optional[int] = None, comm_backend: Optional[CommBackend] = None, + use_torch_symm_mem: bool = False, ): """ Initialize the MNNVL Allreduce Fusion Workspace. The workspace will be allocated and initialized based on the provided problem size. If max_num_tokens is larger than the one-shot threshold, the workspace will be created according to the max of required one-shot size at threshold, or the required two-shot size. Note that the workspace is not bind to the given problem size. It can be reused for different problem size without reinitialization given the allocated size is sufficient. @@ -138,26 +138,50 @@ def __init__( # support base_gpu_id != 0 scenarios where the actual CUDA device # index differs from the TP rank / local_rank. device = torch.device("cuda", torch.cuda.current_device()) - if isinstance(comm_backend, TorchDistBackend): + self._uses_torch_symm_mem = use_torch_symm_mem + if self._uses_torch_symm_mem: + if not isinstance(comm_backend, TorchDistBackend): + raise ValueError( + "use_torch_symm_mem=True requires " + "MNNVLAllReduceFusionWorkspace to be constructed with a " + "TorchDistBackend." + ) group = ( comm_backend._group if comm_backend._group is not None else torch.distributed.group.WORLD ) + group_size = comm_backend.Get_size() + group_rank = comm_backend.Get_rank() + if group_size != mapping.tp_size or group_rank != mapping.tp_rank: + raise ValueError( + "use_torch_symm_mem=True requires " + "Mapping.tp_size/tp_rank to match the TorchDistBackend " + f"process group (group size/rank: {group_size}/{group_rank}, " + f"mapping.tp_size/tp_rank: {mapping.tp_size}/{mapping.tp_rank})." + ) group_name = group.group_name - else: - group_name = torch.distributed.group.WORLD.group_name - self.ptrs, self.tensor, self.handle = _alloc_symm_buffer_bytes( - requested_workspace_size, - mapping.tp_size, - torch.float32, - device, - group_name, - ) + from .torch_symmetric_memory import _alloc_symm_buffer_bytes - # handle.buffer_size is the usable data size. torch symmetric memory - # allocator places signal_pad on top of it, not carved from within. - allocated_size = self.handle.buffer_size + self.ptrs, self.tensor, self.handle = _alloc_symm_buffer_bytes( + requested_workspace_size, + mapping.tp_size, + torch.float32, + device, + group_name, + ) + # handle.buffer_size is the usable data size. torch symmetric memory + # allocator places signal_pad on top of it, not carved from within. + allocated_size = self.handle.buffer_size + else: + self.mcast_buffer_handle = McastGPUBuffer( + requested_workspace_size, + mapping.tp_size, + mapping.tp_rank, + device, + comm_backend, + ) + allocated_size = self.mcast_buffer_handle.buf_size # We want the buffer size to be aligned to 16B which is the granularity for buffer management. self.buffer_size_bytes = ( math.floor(allocated_size / self.NUM_LAMPORT_BUFFERS) // 16 * 16 @@ -169,8 +193,12 @@ def __init__( f"[MNNVL Allreduce] Actual allocated size: {allocated_size} bytes, Actual buffer size per lamport buffer: {self.buffer_size_bytes} bytes, total workspace: {self.workspace_size_bytes} bytes." ) - # lamport initialize tensor to negative zero. - self.tensor.fill_(-0.0) + if self._uses_torch_symm_mem: + # lamport initialize tensor to negative zero. + self.tensor.fill_(-0.0) + else: + # The MNNVL workspace uses FP32 sentinels regardless of the tensor dtype. + self.mcast_buffer_handle.lamport_initialize(mapping.tp_rank, torch.float32) # Wait until the initialization is done torch.cuda.synchronize() comm_backend.barrier() @@ -186,9 +214,14 @@ def __init__( device=torch.device("cuda", torch.cuda.current_device()), ) - self.uc_ptrs_dev = self.handle.buffer_ptrs_dev - self.uc_ptr_local = self.handle.buffer_ptrs[self.rank] - self.mc_ptr = self.handle.multicast_ptr + if self._uses_torch_symm_mem: + self.uc_ptrs_dev = self.handle.buffer_ptrs_dev + self.uc_ptr_local = self.handle.buffer_ptrs[self.rank] + self.mc_ptr = self.handle.multicast_ptr + else: + self.uc_ptrs_dev = self.mcast_buffer_handle.get_buffer_ptrs_dev() + self.uc_ptr_local = self.mcast_buffer_handle.get_unicast_ptr(self.rank) + self.mc_ptr = self.mcast_buffer_handle.get_multicast_ptr() @functools.cache def is_buffer_size_sufficient( @@ -252,9 +285,12 @@ def destroy(self) -> None: del self.uc_ptrs_dev del self.uc_ptr_local del self.mc_ptr - del self.tensor - del self.handle - del self.ptrs + if getattr(self, "_uses_torch_symm_mem", True): + del self.tensor + del self.handle + del self.ptrs + else: + del self.mcast_buffer_handle self._destroyed = True @@ -539,7 +575,7 @@ def get_allreduce_mnnvl_workspace( Returns: Tuple containing: - - MNNVLAllReduceFusionWorkspace: The workspace object backed by torch symmetric memory + - MNNVLAllReduceFusionWorkspace: The workspace object. - torch.Tensor: Buffer flags tensor tracking state - int: Maximum number of elements that can fit in buffer """ diff --git a/tests/comm/test_trtllm_mnnvl_allreduce.py b/tests/comm/test_trtllm_mnnvl_allreduce.py index e0f7877502..0dce863faf 100644 --- a/tests/comm/test_trtllm_mnnvl_allreduce.py +++ b/tests/comm/test_trtllm_mnnvl_allreduce.py @@ -401,7 +401,7 @@ def run_mnnvl_ar_full( multicast_ptr = legacy_workspace.mc_ptr buffer_ptrs_dev = legacy_workspace.uc_ptrs_dev - unicast_ptr = legacy_workspace.handle.buffer_ptrs[mapping.tp_rank] + unicast_ptr = legacy_workspace.uc_ptr_local else: workspace = trtllm_mnnvl_ar.MNNVLAllReduceFusionWorkspace(