diff --git a/flashinfer/comm/allreduce.py b/flashinfer/comm/allreduce.py index fc6f433403..828ddcb321 100644 --- a/flashinfer/comm/allreduce.py +++ b/flashinfer/comm/allreduce.py @@ -56,6 +56,7 @@ from .workspace_base import AllReduceFusionWorkspace import torch +from torch.distributed import ProcessGroup from flashinfer.api_logging import flashinfer_api from flashinfer.trace.templates.comm import allreduce_fusion_trace @@ -105,6 +106,7 @@ def __init__( hidden_dim: int, dtype: torch.dtype = torch.float16, comm_backend: Optional[CommBackend] = None, + group: Optional[ProcessGroup] = None, ): """ Create TensorRT-LLM AllReduce fusion workspace. @@ -116,7 +118,7 @@ def __init__( hidden_dim: Hidden dimension size dtype: Data type comm_backend: Communication backend - **kwargs: Additional arguments for workspace creation + group: Process group for symmetric memory rendezvous. Defaults to torch.distributed.group.WORLD. """ super().__init__(tp_size, tp_rank) @@ -126,6 +128,7 @@ def __init__( tp_size=tp_size, max_token_num=max_token_num, hidden_dim=hidden_dim, + group=group, comm_backend=comm_backend, create_metadata=True, use_fp32_lamport=dtype == torch.float32, @@ -290,6 +293,7 @@ def create_allreduce_fusion_workspace( gpus_per_node: int = None, comm_backend: Optional[CommBackend] = None, force_oneshot_support: bool = False, + group: Optional[ProcessGroup] = None, ) -> AllReduceFusionWorkspace: """ Create workspace for AllReduce fusion operations. @@ -324,6 +328,7 @@ def create_allreduce_fusion_workspace( False: Allocate workspace for twoshot strategy for all problem sizes, and for oneshot strategy up to the heuristic threshold. Note that only the workspace for MNNVL backend needs to be initialized with the correct strategy. The trtllm backend will be sufficient for both strategies. + group: Process group for symmetric memory rendezvous (trtllm backend only). Defaults to torch.distributed.group.WORLD. Returns: Workspace object (TRTLLMAllReduceFusionWorkspace or MNNVLAllReduceFusionWorkspace) @@ -412,6 +417,7 @@ def create_allreduce_fusion_workspace( hidden_dim=hidden_dim, dtype=dtype, comm_backend=comm_backend, + group=group, ) elif actual_backend == "mnnvl":