Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions flashinfer/comm/allreduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@

from .mapping import Mapping

from .mnnvl import CommBackend, SymmDeviceMemory
from .mnnvl import CommBackend

# Note: AllReduceFusionPattern and QuantizationSFLayout are pseudo-types (classes with int constants)
# Import them for runtime use but type hint as int for mypy compatibility
Expand Down Expand Up @@ -107,6 +107,7 @@ def __init__(
dtype: torch.dtype = torch.float16,
comm_backend: Optional[CommBackend] = None,
group: Optional[ProcessGroup] = None,
use_torch_symm_mem: bool = False,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@aleozlx is this considered breaking backward compatibility? or is this fine?

):
"""
Create TensorRT-LLM AllReduce fusion workspace.
Expand All @@ -118,7 +119,9 @@ def __init__(
hidden_dim: Hidden dimension size
dtype: Data type
comm_backend: Communication backend
group: Process group for symmetric memory rendezvous. Defaults to torch.distributed.group.WORLD.
group: Process group for workspace allocation. Defaults to torch.distributed.group.WORLD.
use_torch_symm_mem: If True, use torch symmetric memory for workspace
allocation. Defaults to False (uses FlashInfer/TensorRT-style SymmDeviceMemory).
"""
super().__init__(tp_size, tp_rank)

Expand All @@ -133,12 +136,13 @@ def __init__(
create_metadata=True,
use_fp32_lamport=dtype == torch.float32,
use_symm_dev_mem=True,
use_torch_symm_mem=use_torch_symm_mem,
)

# Store essential attributes for easy access
# Cast to the 4-tuple return form to make the linter happy, since we always call with create_metadata=True and use_symm_dev_mem=True
workspace_tuple = cast(
Tuple[List[List[int]], torch.Tensor, List[SymmDeviceMemory], dict],
Tuple[List[List[int]], torch.Tensor, List[Any], dict],
self._internal_workspace,
)
self.ipc_handles = workspace_tuple[0]
Expand Down Expand Up @@ -294,6 +298,7 @@ def create_allreduce_fusion_workspace(
comm_backend: Optional[CommBackend] = None,
force_oneshot_support: bool = False,
group: Optional[ProcessGroup] = None,
use_torch_symm_mem: bool = False,
) -> AllReduceFusionWorkspace:
"""
Create workspace for AllReduce fusion operations.
Expand Down Expand Up @@ -328,7 +333,9 @@ def create_allreduce_fusion_workspace(
False: Allocate workspace for twoshot strategy for all problem sizes, and for oneshot strategy up to the heuristic threshold.
Note that only the workspace for MNNVL backend needs to be initialized with the correct strategy.
The trtllm backend will be sufficient for both strategies.
group: Process group for symmetric memory rendezvous (trtllm backend only). Defaults to torch.distributed.group.WORLD.
group: Process group for workspace allocation (trtllm backend only). Defaults to torch.distributed.group.WORLD.
use_torch_symm_mem: If True, use torch symmetric memory for workspace allocation.
Defaults to False (uses FlashInfer/TensorRT-style SymmDeviceMemory).

Returns:
Workspace object (TRTLLMAllReduceFusionWorkspace or MNNVLAllReduceFusionWorkspace)
Expand Down Expand Up @@ -418,6 +425,7 @@ def create_allreduce_fusion_workspace(
dtype=dtype,
comm_backend=comm_backend,
group=group,
use_torch_symm_mem=use_torch_symm_mem,
)

elif actual_backend == "mnnvl":
Expand Down Expand Up @@ -446,6 +454,7 @@ def create_allreduce_fusion_workspace(
dtype=dtype,
comm_backend=comm_backend,
buffer_size_in_bytes=buffer_size_in_bytes,
use_torch_symm_mem=use_torch_symm_mem,
)
else:
raise RuntimeError(f"Unknown backend: {actual_backend}")
Expand Down
114 changes: 96 additions & 18 deletions flashinfer/comm/trtllm_ar.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import logging
from ctypes import c_void_p, cast
from types import SimpleNamespace
from typing import List, Optional, Tuple, Union
from typing import Any, List, Optional, Tuple, Union
from typing_extensions import deprecated

from flashinfer.comm.mnnvl import CommBackend, SymmDeviceMemory, TorchDistBackend
Expand All @@ -31,7 +31,40 @@

logger = logging.getLogger(__name__)
from .cuda_ipc import cudart
from .torch_symmetric_memory import _alloc_symm_buffer_bytes


def _alloc_trtllm_ar_workspace_buffer(
    size_bytes: int,
    world_size: int,
    rank: int,
    dtype: torch.dtype,
    device: torch.device,
    group_name: str,
    comm_backend: CommBackend,
    use_torch_symm_mem: bool,
) -> tuple[list[int], Any]:
    """Allocate one workspace buffer through the selected allocator.

    Two allocation paths are supported and chosen by ``use_torch_symm_mem``:

    * ``True``: delegate to ``_alloc_symm_buffer_bytes`` from
      ``.torch_symmetric_memory``. ``rank`` and ``comm_backend`` are not
      consulted on this path.
    * ``False``: construct a ``SymmDeviceMemory`` with multicast and signal
      pads disabled. ``dtype`` and ``group_name`` are not consulted on this
      path.

    Args:
        size_bytes: Size of the buffer to allocate, in bytes.
        world_size: Number of participating ranks.
        rank: Rank of the calling process (``SymmDeviceMemory`` path only).
        dtype: Element type forwarded to the torch symmetric-memory allocator.
        device: Target device. When ``device.index`` is ``None`` the
            ``SymmDeviceMemory`` path falls back to
            ``torch.cuda.current_device()``.
        group_name: Process-group name forwarded to the torch
            symmetric-memory allocator.
        comm_backend: Communication backend handed to ``SymmDeviceMemory``.
        use_torch_symm_mem: Selects the torch symmetric-memory allocator when
            true, ``SymmDeviceMemory`` otherwise.

    Returns:
        A ``(ptrs, mem_ref)`` pair. ``ptrs`` is the pointer list reported by
        the allocator (``uc_ptrs`` for ``SymmDeviceMemory``). ``mem_ref`` is
        the object the caller should hold on to: a ``(tensor, handle)`` tuple
        on the torch path, or the ``SymmDeviceMemory`` instance itself.
        NOTE(review): presumably ``mem_ref`` must stay referenced for the
        pointers to remain valid -- confirm against the allocators.
    """
    if not use_torch_symm_mem:
        # SymmDeviceMemory wants a concrete device ordinal; an index-less
        # device resolves to whatever CUDA device is current.
        if device.index is not None:
            device_idx = device.index
        else:
            device_idx = torch.cuda.current_device()

        symm_dev_mem = SymmDeviceMemory(
            size_bytes,
            world_size,
            rank,
            device_idx,
            comm_backend,
            enable_multicast=False,
            allocate_signal_pads=False,
        )
        return symm_dev_mem.uc_ptrs, symm_dev_mem

    # Imported lazily so the torch symmetric-memory module is only loaded
    # when this allocator is actually requested.
    from .torch_symmetric_memory import _alloc_symm_buffer_bytes

    peer_ptrs, buffer_tensor, symm_handle = _alloc_symm_buffer_bytes(
        size_bytes,
        world_size,
        dtype,
        device,
        group_name,
    )
    return peer_ptrs, (buffer_tensor, symm_handle)


class AllReduceStrategyType:
Expand Down Expand Up @@ -427,7 +460,7 @@ def trtllm_moe_finalize_allreduce_fusion(
MAX_ALL_REDUCE_BLOCKS = 24
LamportTokenNumThreshold = 16

_symm_workspace_refs: dict[int, list[torch.Tensor]] = {}
_symm_workspace_refs: dict[int, list[Any]] = {}


@deprecated(
Expand All @@ -439,6 +472,7 @@ def trtllm_create_ipc_workspace_for_all_reduce(
max_token_num: int,
hidden_dim,
group: Optional[ProcessGroup] = None,
use_torch_symm_mem: bool = False,
) -> List[List[int]]:
"""
Parameters:
Expand All @@ -447,6 +481,8 @@ def trtllm_create_ipc_workspace_for_all_reduce(
- max_token_num: the maximum number of tokens in a sequence.
- hidden_dim: the dimension of the hidden states.
- group: the process group to use.
- use_torch_symm_mem: if True, use torch symmetric memory for workspace allocation.
Defaults to False (uses FlashInfer/TensorRT-style SymmDeviceMemory).

Note:
This function is used to create a workspace for all reduce.
Expand Down Expand Up @@ -487,7 +523,21 @@ def trtllm_create_ipc_workspace_for_all_reduce(
if group is not None
else torch.distributed.group.WORLD.group_name
)
symm_refs: list[torch.Tensor] = []
if group is not None:
comm_backend = TorchDistBackend(group=group)
else:
comm_backend = TorchDistBackend()
if use_torch_symm_mem:
group_size = comm_backend.Get_size()
group_rank = comm_backend.Get_rank()
if group_size != tp_size or group_rank != rank:
raise ValueError(
"use_torch_symm_mem=True requires "
"tp_size/rank to match the TorchDistBackend process group "
f"(group size/rank: {group_size}/{group_rank}, "
f"tp_size/rank: {tp_size}/{rank})."
)
symm_refs: list[Any] = []
ipc_handles = list()

for size, dtype in [
Expand All @@ -500,14 +550,17 @@ def trtllm_create_ipc_workspace_for_all_reduce(
(lamport_buffer_size, torch.float16),
]:
aligned_size = round_up(size, 16)
ptrs, tensor, handle = _alloc_symm_buffer_bytes(
ptrs, mem_ref = _alloc_trtllm_ar_workspace_buffer(
aligned_size,
tp_size,
rank,
dtype,
device,
group_name,
comm_backend,
use_torch_symm_mem,
)
symm_refs.append((tensor, handle))
symm_refs.append(mem_ref)
ipc_handles.append(ptrs)

logger.debug(
Expand Down Expand Up @@ -564,10 +617,11 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion(
create_metadata: bool = False,
comm_backend: Optional[CommBackend] = None,
use_symm_dev_mem: bool = False,
use_torch_symm_mem: bool = False,
) -> Union[
Tuple[List[List[int]], torch.Tensor],
Tuple[List[List[int]], torch.Tensor, dict],
Tuple[List[List[int]], torch.Tensor, List[SymmDeviceMemory], dict],
Tuple[List[List[int]], torch.Tensor, List[Any], dict],
]:
"""
Parameters:
Expand All @@ -580,6 +634,8 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion(
- create_metadata: if True, return metadata dict as third element (default: False).
- comm_backend: the communication backend to use.
- use_symm_dev_mem: if True, we will use symmetric device memory for the workspace.
- use_torch_symm_mem: if True, use torch symmetric memory for workspace allocation.
Defaults to False (uses FlashInfer/TensorRT-style SymmDeviceMemory).

Returns:
- If create_metadata=False: (ipc_handles, workspace_tensor)
Expand All @@ -589,7 +645,7 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion(
- If create_metadata=True: and use_symm_dev_mem=True: (ipc_handles, workspace_tensor, mem_handles,metadata)
where metadata contains: tp_rank, tp_size, max_token_num, hidden_dim,
use_fp32_lamport, buffer_size, flag_size, lamport_comm_size, lamport_buffer_size
and mem_handles is a list of SymmDeviceMemory objects.
and mem_handles is a list of allocator-owned memory handles.

Note: The optional parameters make the API clunky at this time. This will be refactored in the future, at the cost of backward compatibility, where the default behavior will be
create_metadata=True and use_symm_dev_mem=True.
Expand All @@ -610,7 +666,7 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion(
Reference: trtllm, cpp/tensorrt_llm/kernels/communicationKernels/allReduceWorkspace.cu, Workspace init
"""

if comm_backend is None and use_symm_dev_mem:
if comm_backend is None:
comm_backend = TorchDistBackend(group=group)

# No need to support all variations. In the future we only support create_metadata=True and use_symm_dev_mem=True.
Expand Down Expand Up @@ -640,13 +696,35 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion(
if group is not None
else torch.distributed.group.WORLD.group_name
)
symm_refs: list[torch.Tensor] = []
if use_torch_symm_mem:
if not isinstance(comm_backend, TorchDistBackend):
raise ValueError(
"use_torch_symm_mem=True requires "
"trtllm_create_ipc_workspace_for_all_reduce_fusion to use a "
"TorchDistBackend."
)
torch_group = (
comm_backend._group
if comm_backend._group is not None
else torch.distributed.group.WORLD
)
group_size = comm_backend.Get_size()
group_rank = comm_backend.Get_rank()
if group_size != tp_size or group_rank != tp_rank:
raise ValueError(
"use_torch_symm_mem=True requires "
"tp_size/tp_rank to match the TorchDistBackend process group "
f"(group size/rank: {group_size}/{group_rank}, "
f"tp_size/tp_rank: {tp_size}/{tp_rank})."
)
group_name = torch_group.group_name
symm_refs: list[Any] = []

# we should init 3 buffers for all reduce fusion:
# [buffer_size, flag_size, lamport_buffer_size]

ipc_handles: List[List[int]] = list()
mem_handles: List[SymmDeviceMemory] = list()
mem_handles: List[Any] = list()
lamport_buffer_dtype = torch.float16 if not use_fp32_lamport else torch.float32
for size, dtype in [
(buffer_size, torch.float32),
Expand All @@ -655,16 +733,19 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion(
]:
aligned_size = round_up(size, 16)

ptrs, tensor, handle = _alloc_symm_buffer_bytes(
ptrs, mem_ref = _alloc_trtllm_ar_workspace_buffer(
aligned_size,
tp_size,
tp_rank,
dtype,
device,
group_name,
comm_backend,
use_torch_symm_mem,
)
symm_refs.append((tensor, handle))
symm_refs.append(mem_ref)
ipc_handles.append(ptrs)
mem_handles.append(handle)
mem_handles.append(mem_ref)

logger.debug(
"rank %s allocated ipc_handles: %s",
Expand Down Expand Up @@ -723,10 +804,7 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion(
workspace, dtype=torch.int64, device=torch.device("cuda")
)

if use_symm_dev_mem:
comm_backend.barrier() # must sync after create_workspace
else:
dist.barrier(group=group)
comm_backend.barrier() # must sync after create_workspace

if create_metadata:
metadata = {
Expand Down
Loading
Loading