Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion flashinfer/comm/allreduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from .workspace_base import AllReduceFusionWorkspace

import torch
from torch.distributed import ProcessGroup

from flashinfer.api_logging import flashinfer_api
from flashinfer.trace.templates.comm import allreduce_fusion_trace
Expand Down Expand Up @@ -105,6 +106,7 @@ def __init__(
hidden_dim: int,
dtype: torch.dtype = torch.float16,
comm_backend: Optional[CommBackend] = None,
group: Optional[ProcessGroup] = None,
):
"""
Create TensorRT-LLM AllReduce fusion workspace.
Expand All @@ -116,7 +118,7 @@ def __init__(
hidden_dim: Hidden dimension size
dtype: Data type
comm_backend: Communication backend
**kwargs: Additional arguments for workspace creation
group: Process group for symmetric memory rendezvous. Defaults to torch.distributed.group.WORLD.
"""
super().__init__(tp_size, tp_rank)

Expand All @@ -126,6 +128,7 @@ def __init__(
tp_size=tp_size,
max_token_num=max_token_num,
hidden_dim=hidden_dim,
group=group,
comm_backend=comm_backend,
create_metadata=True,
use_fp32_lamport=dtype == torch.float32,
Expand Down Expand Up @@ -290,6 +293,7 @@ def create_allreduce_fusion_workspace(
gpus_per_node: int = None,
comm_backend: Optional[CommBackend] = None,
force_oneshot_support: bool = False,
group: Optional[ProcessGroup] = None,
) -> AllReduceFusionWorkspace:
"""
Create workspace for AllReduce fusion operations.
Expand Down Expand Up @@ -324,6 +328,7 @@ def create_allreduce_fusion_workspace(
False: Allocate workspace for twoshot strategy for all problem sizes, and for oneshot strategy up to the heuristic threshold.
Note that only the workspace for MNNVL backend needs to be initialized with the correct strategy.
The trtllm backend will be sufficient for both strategies.
group: Process group for symmetric memory rendezvous (trtllm backend only). Defaults to torch.distributed.group.WORLD.

Returns:
Workspace object (TRTLLMAllReduceFusionWorkspace or MNNVLAllReduceFusionWorkspace)
Expand Down Expand Up @@ -412,6 +417,7 @@ def create_allreduce_fusion_workspace(
hidden_dim=hidden_dim,
dtype=dtype,
comm_backend=comm_backend,
group=group,
)

elif actual_backend == "mnnvl":
Expand Down
Loading