Skip to content
94 changes: 42 additions & 52 deletions flashinfer/comm/allreduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,17 @@

from .trtllm_ar import trtllm_allreduce_fusion
from .trtllm_ar import trtllm_create_ipc_workspace_for_all_reduce_fusion
from .trtllm_ar import trtllm_destroy_ipc_workspace_for_all_reduce_fusion
from .trtllm_ar import check_trtllm_allreduce_fusion_workspace_metadata

from .mapping import Mapping

from .mnnvl import CommBackend
from .mnnvl import CommBackend, SymmDeviceMemory

# Note: AllReduceFusionPattern and QuantizationSFLayout are pseudo-types (classes with int constants)
# Import them for runtime use but type hint as int for mypy compatibility
from .trtllm_ar import AllReduceFusionPattern
from .trtllm_mnnvl_ar import MNNVLAllReduceFusionWorkspace
from .trtllm_mnnvl_ar import MNNVLAllreduceFusionStrategy
from .trtllm_mnnvl_ar import trtllm_mnnvl_allreduce
from .trtllm_mnnvl_ar import trtllm_mnnvl_fused_allreduce_add_rmsnorm

Expand Down Expand Up @@ -95,7 +95,7 @@ def __init__(
max_token_num: int,
hidden_dim: int,
dtype: torch.dtype = torch.float16,
process_group: Optional["torch.distributed.ProcessGroup"] = None,
comm_backend: Optional[CommBackend] = None,
):
"""
Create TensorRT-LLM AllReduce fusion workspace.
Expand All @@ -106,7 +106,7 @@ def __init__(
max_token_num: Maximum number of tokens
hidden_dim: Hidden dimension size
dtype: Data type
process_group: PyTorch distributed process group
comm_backend: Communication backend
**kwargs: Additional arguments for workspace creation
"""
super().__init__(tp_size, tp_rank)
Expand All @@ -117,19 +117,22 @@ def __init__(
tp_size=tp_size,
max_token_num=max_token_num,
hidden_dim=hidden_dim,
group=process_group,
comm_backend=comm_backend,
create_metadata=True,
use_fp32_lamport=dtype == torch.float32,
use_symm_dev_mem=True,
)

# Store essential attributes for easy access
# Cast to a 4-tuple to make the linter happy, since we always call with create_metadata=True
workspace_tuple = cast(
Tuple[List[List[int]], torch.Tensor, dict], self._internal_workspace
Tuple[List[List[int]], torch.Tensor, List[SymmDeviceMemory], dict],
self._internal_workspace,
)
self.ipc_handles = workspace_tuple[0]
self.workspace_tensor = workspace_tuple[1]
self.metadata = workspace_tuple[2]
self.mem_handles = workspace_tuple[2]
self.metadata = workspace_tuple[3]

@property
def backend(self) -> str:
Expand Down Expand Up @@ -165,7 +168,10 @@ def destroy(self) -> None:
if getattr(self, "_destroyed", False):
return # Already destroyed, nothing to do

trtllm_destroy_ipc_workspace_for_all_reduce_fusion(self.ipc_handles)
del self.ipc_handles
del self.workspace_tensor
del self.mem_handles
del self.metadata
self._destroyed = True


Expand All @@ -181,20 +187,14 @@ def _trtllm_workspace_check(
max_token_num: int,
hidden_dim: int,
dtype: torch.dtype,
topology: Literal["single_node", "multi_node"],
) -> bool:
"""
Check if trtllm backend CAN be used for workspace creation.

Hard requirements:
- Single-node topology (multi-node not supported)

- Up to 16 ranks supported.
"""
# trtllm is optimized for single-node
if topology == "multi_node":
return False

return True
return world_size <= 16


def _mnnvl_workspace_check(
Expand All @@ -204,16 +204,12 @@ def _mnnvl_workspace_check(
max_token_num: int,
hidden_dim: int,
dtype: torch.dtype,
topology: Literal["single_node", "multi_node"],
) -> bool:
"""
Check if mnnvl backend CAN be used for workspace creation.

"""

if topology == "multi_node":
return True

return True


Expand All @@ -230,7 +226,6 @@ def _workspace_creation_heuristic(
max_token_num: int,
hidden_dim: int,
dtype: torch.dtype,
topology: Literal["single_node", "multi_node"],
) -> list[str]:
"""
Select best backend for workspace creation based on performance.
Expand All @@ -246,7 +241,6 @@ def _workspace_creation_heuristic(
max_token_num: Maximum number of tokens
hidden_dim: Hidden dimension size
dtype: Data type
topology: Network topology ("single_node" or "multi_node")
**kwargs: Additional arguments

Note that at this point, the backend selection does not take "runtime parameters" (such as layout_code and fusion pattern) into account.
Expand All @@ -262,13 +256,6 @@ def _workspace_creation_heuristic(

# Decision tree based on benchmark data

# Multi-node: MNNVL is designed for this
if topology == "multi_node":
if "mnnvl" in suitable_backends:
return ["mnnvl"]
else:
return [suitable_backends[0]]

# Single-node scenarios
# From benchmarking data, we can see that MNNVL is either on par (smaller problem sizes) or significantly faster than TRTLLM (larger problem sizes such as hidden_dim=8192, token_num=64 for TP=4), for single-node scenarios.
# However, trtllm has a larger support surface (more fusion patterns, more quantization support, etc.)
Expand All @@ -290,10 +277,9 @@ def create_allreduce_fusion_workspace(
max_token_num: int = None,
hidden_dim: int = None,
dtype: torch.dtype = None,
topology: Literal["single_node", "multi_node"] = "single_node",
process_group: Optional["torch.distributed.ProcessGroup"] = None,
gpus_per_node: int = None,
comm_backend: Optional[CommBackend] = None,
force_oneshot_support: bool = False,
Comment thread
coderabbitai[bot] marked this conversation as resolved.
) -> AllReduceFusionWorkspace:
"""
Create workspace for AllReduce fusion operations.
Expand All @@ -315,19 +301,19 @@ def create_allreduce_fusion_workspace(

Args:
backend: Backend to use ("trtllm", "mnnvl", or "auto")
"auto" uses heuristic to select best backend based on topology
and problem size
"auto" uses heuristic to select best backend
world_size: Number of ranks in the process group
rank: Current rank ID
max_token_num: Maximum number of tokens to support
hidden_dim: Hidden dimension size
dtype: Data type for communication tensors
topology: Network topology hint for backend selection
"single_node" - All ranks on one node (default)
"multi_node" - Ranks span multiple nodes
process_group: PyTorch distributed process group (for trtllm backend).
gpus_per_node: Number of GPUs per node (for multi-node topology).
comm_backend: Communication backend to use (for multi-node topology).
comm_backend: Communication backend to use.
force_oneshot_support: Allocate workspace for oneshot strategy vs twoshot
True: Allocate workspace for oneshot strategy up to the largest problem size requested
False: Allocate workspace for twoshot strategy for all problem sizes, and for oneshot strategy up to the heuristic threshold.
Note that only the MNNVL backend's workspace needs to be initialized for the intended strategy.
The trtllm backend's workspace is sufficient for both strategies.

Returns:
Workspace object (TRTLLMAllReduceFusionWorkspace or MNNVLAllReduceFusionWorkspace)
Expand All @@ -338,15 +324,14 @@ def create_allreduce_fusion_workspace(
ValueError: If problem size not supported for the specified backend

Examples:
>>> # Auto-select best backend based on topology
>>> # Auto-select best backend
>>> workspace = create_allreduce_fusion_workspace(
... backend="auto",
... world_size=8,
... rank=0,
... max_token_num=2048,
... hidden_dim=4096,
... dtype=torch.bfloat16,
... topology="single_node"
... )
>>> print(workspace.backend) # "trtllm"
>>> print(workspace.get_workspace_capacity()) # 8388608 elements
Expand All @@ -363,15 +348,14 @@ def create_allreduce_fusion_workspace(
... max_token_num=2048,
... hidden_dim=4096,
... dtype=torch.bfloat16,
... topology="multi_node"
... )
>>> print(workspace.backend) # "mnnvl"
"""
if gpus_per_node is None:
gpus_per_node = min(torch.cuda.device_count(), world_size)
# Determine the actual backend to use
if backend == "auto":
# Find suitable backends based on topology (anny CC check needs to be checked at kernel runtime, since there are no tensor available at this point)
# Find suitable backends (any compute capability check has to be done at kernel runtime, since no tensors are available at this point)
suitable_backends = []
if _trtllm_workspace_check(
backend=backend,
Expand All @@ -380,7 +364,6 @@ def create_allreduce_fusion_workspace(
max_token_num=max_token_num,
hidden_dim=hidden_dim,
dtype=dtype,
topology=topology,
):
suitable_backends.append("trtllm")
if _mnnvl_workspace_check(
Expand All @@ -390,15 +373,11 @@ def create_allreduce_fusion_workspace(
max_token_num=max_token_num,
hidden_dim=hidden_dim,
dtype=dtype,
topology=topology,
):
suitable_backends.append("mnnvl")

if not suitable_backends:
raise ValueError(
f"No suitable backend found for topology={topology}. "
f"trtllm requires single_node topology, mnnvl works with both."
)
raise ValueError("No suitable backend found. ")

# Apply heuristic to select best backend
selected = _workspace_creation_heuristic(
Expand All @@ -409,7 +388,6 @@ def create_allreduce_fusion_workspace(
max_token_num=max_token_num,
hidden_dim=hidden_dim,
dtype=dtype,
topology=topology,
)
actual_backend = selected[0]
else:
Expand All @@ -423,7 +401,7 @@ def create_allreduce_fusion_workspace(
max_token_num=max_token_num,
hidden_dim=hidden_dim,
dtype=dtype,
process_group=process_group,
comm_backend=comm_backend,
)

elif actual_backend == "mnnvl":
Expand All @@ -433,12 +411,25 @@ def create_allreduce_fusion_workspace(
gpus_per_node=gpus_per_node,
tp_size=world_size,
)
buffer_size_in_bytes = None
if force_oneshot_support:
buffer_size_in_bytes = (
MNNVLAllReduceFusionWorkspace.get_required_buffer_size_bytes(
world_size,
max_token_num,
hidden_dim,
dtype,
MNNVLAllreduceFusionStrategy.ONESHOT,
)
)

return MNNVLAllReduceFusionWorkspace(
mapping=mapping,
max_num_tokens=max_token_num,
hidden_dim=hidden_dim,
dtype=dtype,
comm_backend=comm_backend,
buffer_size_in_bytes=buffer_size_in_bytes,
)
else:
raise RuntimeError(f"Unknown backend: {actual_backend}")
Expand Down Expand Up @@ -514,7 +505,7 @@ def allreduce_fusion(
# ===== Control parameters =====
use_oneshot: Use oneshot strategy vs twoshot
If None, uses internal heuristics.
Note that the MNNVL backend needs to be initialized with a sufficiently large workspace if one_shot is used.
Note: when `use_oneshot` is explicitly set to True, the MNNVL backend needs to be initialized with a sufficiently large workspace.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This note is a bit vague. To improve clarity, consider explicitly mentioning the use_oneshot parameter it refers to.

Suggested change
Note: when explicitly set to True, the MNNVL backend needs to be initialized with a sufficiently large workspace.
Note: when `use_oneshot` is explicitly set to True, the MNNVL backend needs to be initialized with a sufficiently large workspace.

fp32_acc: [trtllm only] Use FP32 accumulation for AllReduce

Returns:
Expand All @@ -529,7 +520,6 @@ def allreduce_fusion(
... max_token_num=2048,
... hidden_dim=4096,
... dtype=torch.bfloat16,
... topology="single_node"
... )
>>>
>>> # Pre-allocate output tensors
Expand Down
Loading