17 changes: 14 additions & 3 deletions cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp
@@ -237,7 +237,7 @@ std::tuple<std::vector<torch::Tensor>, int64_t> moeA2ADispatchOp(torch::Tensor c
int64_t sizePerRank = workspace.size(1);
int64_t requiredSize = offsets[PAYLOAD_DATA_OFFSET_INDEX] + totalBytesNeeded;
TORCH_CHECK(sizePerRank >= requiredSize,
"Workspace size per rank insufficient. "
"Workspace size per rank insufficient for dispatch. "
"Need at least ",
requiredSize, " bytes (", offsets[PAYLOAD_DATA_OFFSET_INDEX], " for auxiliary data + ", totalBytesNeeded,
" for payloads), but got ", sizePerRank);
@@ -404,8 +404,10 @@ torch::Tensor moeA2ACombineOp(torch::Tensor const& payload, int64_t localNumToke

int64_t payloadSize = payload.numel() * payload.element_size();
TORCH_CHECK(combinePayloadOffset >= 0 && combinePayloadOffset + payloadSize <= sizePerRank,
"workspace does not contain enough space for the payload region for combine. combine payload offset=",
combinePayloadOffset, ", payload size needed=", payloadSize, ", workspace size per rank=", sizePerRank);
"Workspace size per rank insufficient for combine. "
"Need at least ",
combinePayloadOffset + payloadSize, " bytes (", combinePayloadOffset, " for offset + ", payloadSize,
" for payload), but got ", sizePerRank);

// Create output tensor (local on current rank), no need for initialization
torch::Tensor output = torch::empty({localNumTokens, elementsPerToken}, payload.options());
@@ -508,6 +510,13 @@ torch::Tensor moeA2AGetCombinePayloadTensorOp(torch::Tensor const& workspace, in
return t;
}

// Return the size of auxiliary data in workspace
int64_t moeA2AGetAuxDataSizeOp(int64_t epSize, int64_t maxNumTokens)
{
MoeA2ADataOffsets offsets = calculateOffsets(static_cast<int>(epSize), static_cast<int>(maxNumTokens));
return static_cast<int64_t>(offsets[PAYLOAD_DATA_OFFSET_INDEX]);
}

} // namespace moe_comm

} // namespace torch_ext
@@ -536,6 +545,8 @@ TORCH_LIBRARY_FRAGMENT(trtllm, module)
"moe_a2a_get_combine_payload_tensor(Tensor(a) workspace, int ep_rank, int ep_size, int "
"runtime_max_tokens_per_rank, "
"int combine_payload_offset, ScalarType out_dtype, int hidden_size) -> Tensor(a)");
module.def("moe_a2a_get_aux_data_size(int ep_size, int max_num_tokens) -> int",
&torch_ext::moe_comm::moeA2AGetAuxDataSizeOp);
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, module)
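For reference, once the TensorRT-LLM torch ops are loaded, the newly registered schema can be called directly from Python; a minimal sketch with made-up arguments (8 EP ranks, up to 4096 tokens per rank):

import torch
import tensorrt_llm  # assumption: importing the package registers the trtllm torch ops

aux_bytes = torch.ops.trtllm.moe_a2a_get_aux_data_size(8, 4096)
print(f"aux region size: {aux_bytes} bytes")  # also the offset at which payload data begins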
54 changes: 51 additions & 3 deletions tensorrt_llm/_torch/distributed/moe_alltoall.py
@@ -5,6 +5,9 @@
with proper workspace management and synchronization.
"""

# ruff: noqa: E501

import os
from dataclasses import dataclass
from typing import Dict, Optional

@@ -14,6 +17,7 @@
from tensorrt_llm.bindings import internal as _tllm_internal
from tensorrt_llm.logger import logger as tllm_logger
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.math_utils import pad_up


@dataclass
@@ -36,6 +40,40 @@ class MoeAlltoAll:

_METAINFO_INDEX: Dict[str, int] | None = None

@staticmethod
def get_aux_data_size(ep_size: int, max_num_tokens: int) -> int:
return torch.ops.trtllm.moe_a2a_get_aux_data_size(
ep_size, max_num_tokens)

@staticmethod
def calculate_required_workspace_size(
ep_size: int,
top_k: int,
max_num_tokens: int,
hidden_size: int,
dtype: torch.dtype,
extra_payload_bytes_per_token: int = 0) -> int:
element_size = dtype.itemsize
# Auxiliary data size
aux_size = MoeAlltoAll.get_aux_data_size(ep_size, max_num_tokens)

# Dispatch needs workspace for [ep_size, max_tokens] tokens,
# but due to the variety of quantization recipes, we cannot know the exact size,
# so we conservatively estimate assuming no quantization.
payload_size_dispatch = ep_size * max_num_tokens * (
hidden_size * element_size # (Unquantized) token hidden states
+ top_k * 4 # token_selected_experts
+ top_k * 4 # token_final_scales
+ extra_payload_bytes_per_token # extra payload bytes per token
)

# Required workspace for combine [ep_size, max_tokens] tokens
payload_size_combine = ep_size * max_num_tokens * hidden_size * element_size

# Pad each region up to a multiple of 128 bytes for alignment, matching the C++ torch op implementation.
return pad_up(aux_size, 128) + pad_up(
payload_size_dispatch, 128) + pad_up(payload_size_combine, 128)
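To make the arithmetic above concrete, a worked example with made-up shapes (ep_size=8, top_k=8, max_num_tokens=4096, hidden_size=7168, bfloat16, no extra payload); it reproduces only the two payload terms, since the aux term comes from the C++ op:

ep_size, top_k, max_num_tokens = 8, 8, 4096
hidden_size, element_size = 7168, 2  # bfloat16
payload_size_dispatch = ep_size * max_num_tokens * (hidden_size * element_size + top_k * 4 + top_k * 4)
payload_size_combine = ep_size * max_num_tokens * hidden_size * element_size
print(payload_size_dispatch, payload_size_combine)  # 471859200 469762048, i.e. ~0.88 GiB before the aux region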

@classmethod
def _init_constants(cls):
"""Initialize constants from C++ if not already done."""
@@ -71,8 +109,7 @@ def __init__(
max_num_tokens: int,
top_k: int,
num_experts: int,
# TODO: WE should be able to know the required workspace size if knowing max_num_tokens, ep_size and hidden_size
workspace_size_per_rank: int = 256 * 1024 * 1024,
workspace_size_per_rank: int,
):
"""
Initialize MoeAlltoAll with workspace allocation.
@@ -82,6 +119,17 @@ def __init__(
max_num_tokens: Maximum number of tokens supported. Should be ModelConfig.max_num_tokens.
workspace_size_per_rank: Size of workspace per rank in bytes
"""
# Check for environment variable override
workspace_mb_env = os.environ.get("TRTLLM_MOE_A2A_WORKSPACE_MB")
if workspace_mb_env:
workspace_size_env = int(workspace_mb_env) * 1024 * 1024
tllm_logger.warning(
f"Overriding automatically calculated workspace_size_per_rank ({workspace_size_per_rank} bytes) with "
f"TRTLLM_MOE_A2A_WORKSPACE_MB={workspace_mb_env} ({workspace_size_env} bytes)."
f"Automatically calculated workspace_size_per_rank is conservatively large, please only consider overriding it if you have a specific reason."
)
workspace_size_per_rank = workspace_size_env

# Initialize constants from C++
self._init_constants()

@@ -102,7 +150,7 @@ def __init__(

if self._WORKSPACE is None:
tllm_logger.info(
f"MoE AlltoAll: Allocating workspace with size {workspace_size_per_rank} bytes. ep_rank: {self.ep_rank}, ep_size: {self.ep_size}, max_num_tokens: {self.max_num_tokens}"
f"nvlink_one_sided AlltoAll: Allocating workspace with size {workspace_size_per_rank} bytes. ep_rank: {self.ep_rank}, ep_size: {self.ep_size}, max_num_tokens: {self.max_num_tokens}"
)
mnnvl_mem = MnnvlMemory(mapping, workspace_size_per_rank)
workspace = mnnvl_mem.as_torch_strided_tensor(torch.uint8)
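A hedged usage sketch of the two sizing paths above (the shapes are illustrative, not values from this PR; calculate_required_workspace_size needs the trtllm torch ops to be loaded):

import torch
from tensorrt_llm._torch.distributed.moe_alltoall import MoeAlltoAll

# Automatic sizing from model shapes (illustrative values).
size = MoeAlltoAll.calculate_required_workspace_size(
    ep_size=8, top_k=8, max_num_tokens=4096, hidden_size=7168, dtype=torch.bfloat16)

# Escape hatch: if TRTLLM_MOE_A2A_WORKSPACE_MB is set (e.g. to "512"), the MoeAlltoAll
# constructor ignores whatever is passed as workspace_size_per_rank and logs a warning.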
@@ -87,7 +87,7 @@ def create_strategy(
# Extract parameters from model_config
mapping = model_config.mapping
hidden_size = model_config.pretrained_config.hidden_size
weight_dtype = model_config.torch_dtype
act_dtype = model_config.torch_dtype
quant_config = model_config.quant_config
max_num_tokens = model_config.max_num_tokens
moe_max_num_tokens = model_config.moe_max_num_tokens
@@ -126,8 +126,10 @@ def create_strategy(
mapping,
num_slots,
top_k,
max_num_tokens_per_rank=max_num_tokens,
payload_in_workspace=payload_in_workspace,
max_num_tokens,
payload_in_workspace,
hidden_size=hidden_size,
dtype=act_dtype,
)
logger.info("Selected communication strategy: NVLinkOneSided")
return strategy
@@ -149,13 +151,13 @@ def create_strategy(
logger.debug(f"NVLinkTwoSided not available: {e}")

# Try DeepEP (if enabled and weight dtype is bfloat16)
if os.environ.get("TRTLLM_CAN_USE_DEEP_EP", "0") == "1" and weight_dtype == torch.bfloat16:
if os.environ.get("TRTLLM_CAN_USE_DEEP_EP", "0") == "1" and act_dtype == torch.bfloat16:
try:
strategy = DeepEP(
mapping,
num_slots,
hidden_size,
weight_dtype,
act_dtype,
quant_config,
expert_size_per_partition,
use_cuda_graph,
@@ -171,7 +173,7 @@ def _create_forced_method(
mapping,
num_slots,
hidden_size,
weight_dtype,
act_dtype,
quant_config,
expert_size_per_partition,
max_num_tokens,
@@ -209,7 +211,7 @@ def _create_forced_method(
# Extract parameters from model_config
mapping = model_config.mapping
hidden_size = model_config.pretrained_config.hidden_size
weight_dtype = model_config.torch_dtype
act_dtype = model_config.torch_dtype
quant_config = model_config.quant_config
max_num_tokens = model_config.max_num_tokens
moe_max_num_tokens = model_config.moe_max_num_tokens
@@ -229,21 +231,21 @@
alltoall_result_do_sum=alltoall_result_do_sum,
)
elif method in ["NVLINK_ONE_SIDED"]:
# NVLinkOneSided requires max_num_tokens_per_rank
# max_num_tokens is a per-rank value (as passed from callers like cutlass)
return NVLinkOneSided(
mapping,
num_slots,
top_k,
max_num_tokens_per_rank=max_num_tokens,
payload_in_workspace=payload_in_workspace,
max_num_tokens,
payload_in_workspace,
hidden_size=hidden_size,
dtype=act_dtype,
)
elif method == "DEEPEP":
return DeepEP(
mapping,
num_slots,
hidden_size,
weight_dtype,
act_dtype,
quant_config,
expert_size_per_partition,
use_cuda_graph,
@@ -253,7 +255,7 @@
mapping,
num_slots,
hidden_size,
weight_dtype,
act_dtype,
quant_config,
expert_size_per_partition,
max_num_tokens,
@@ -13,6 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# ruff: noqa: E501


"""
NVLINK One-Sided AllToAll Communication Strategy

@@ -30,6 +33,7 @@
from tensorrt_llm.bindings import internal as _tllm_internal
from tensorrt_llm.logger import logger as tllm_logger
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.math_utils import pad_up

from .base import Communication

@@ -63,6 +67,47 @@ class NVLinkOneSided(Communication):
COMBINE_COMPLETION_FLAGS_OFFSET_INDEX = None
PAYLOAD_DATA_OFFSET_INDEX = None

@staticmethod
def get_aux_data_size(ep_size: int, max_num_tokens: int) -> int:
return torch.ops.trtllm.moe_a2a_get_aux_data_size(ep_size, max_num_tokens)

@staticmethod
def calculate_required_workspace_size(
ep_size: int,
top_k: int,
max_num_tokens: int,
hidden_size: int,
dtype: torch.dtype,
extra_payload_bytes_per_token: int = 0,
) -> int:
element_size = dtype.itemsize
# Auxiliary data size
aux_size = NVLinkOneSided.get_aux_data_size(ep_size, max_num_tokens)

# Dispatch needs workspace for [ep_size, max_tokens] tokens,
# but due to the variety of quantization recipes, we cannot know the exact size,
# so we conservatively estimate assuming no quantization.
payload_size_dispatch = (
ep_size
* max_num_tokens
* (
hidden_size * element_size # (Unquantized) token hidden states
+ top_k * 4 # token_selected_experts
+ top_k * 4 # token_final_scales
+ extra_payload_bytes_per_token # extra payload bytes per token
)
)

# Required workspace for combine [ep_size, max_tokens] tokens
payload_size_combine = ep_size * max_num_tokens * hidden_size * element_size

# Pad each region up to a multiple of 128 bytes for alignment, matching the C++ torch op implementation.
return (
pad_up(aux_size, 128)
+ pad_up(payload_size_dispatch, 128)
+ pad_up(payload_size_combine, 128)
)

@classmethod
def _init_constants(cls):
"""Initialize constants from C++ if not already done."""
@@ -86,9 +131,11 @@ def __init__(
self,
mapping: Mapping,
num_experts: int,
top_k: int = 1,
max_num_tokens_per_rank: Optional[int] = None,
top_k: int,
max_num_tokens_per_rank: int,
payload_in_workspace: bool = False,
hidden_size: Optional[int] = None,
dtype: Optional[torch.dtype] = None,
):
"""
Initialize NVLinkOneSided with workspace allocation.
@@ -99,30 +146,53 @@ def __init__(
top_k: Number of experts per token
max_num_tokens_per_rank: Maximum number of tokens per rank (for workspace allocation)
payload_in_workspace: If True, final_hidden_states is already in workspace
hidden_size: Hidden dimension size (optional, for auto workspace calculation)
dtype: Data type (optional, for auto workspace calculation)
"""
super().__init__(mapping)

if self.mapping.world_size != self.ep_size:
raise RuntimeError("Currently NVLinkOneSided only supports pure EP for MoE.")

# Store needed parameters
self.num_experts = num_experts
self.top_k = top_k

self.max_num_tokens_per_rank = max_num_tokens_per_rank
self.payload_in_workspace = payload_in_workspace

# Initialize constants from C++
self._init_constants()

# Get workspace size from environment variable (default 2048MB to match MoeAlltoAll)
workspace_mb = int(os.environ.get("TRTLLM_MOE_A2A_WORKSPACE_MB", "2048"))
self.workspace_size_per_rank = workspace_mb * 1024 * 1024
# Get workspace size
auto_workspace_size = None
if hidden_size is not None and dtype is not None:
auto_workspace_size = self.calculate_required_workspace_size(
self.ep_size, self.top_k, max_num_tokens_per_rank, hidden_size, dtype
)
workspace_mb_env = os.environ.get("TRTLLM_MOE_A2A_WORKSPACE_MB")
if workspace_mb_env:
self.workspace_size_per_rank = int(workspace_mb_env) * 1024 * 1024
msg = f"NVLinkOneSided: Forcing workspace size to {self.workspace_size_per_rank} bytes (TRTLLM_MOE_A2A_WORKSPACE_MB={workspace_mb_env})."
if auto_workspace_size is not None:
msg += f"Automatically calculated workspace size is {auto_workspace_size} bytes."
msg += "Auto calculation is conservative, so only consider overriding it if you have a specific reason."
tllm_logger.warning(msg)
elif auto_workspace_size is not None:
self.workspace_size_per_rank = auto_workspace_size
else:
tllm_logger.warning(
"NVLinkOneSided: hidden_size and dtype are not provided (which are required for calculating workspace size)."
"Using default workspace size 2048MB."
)
self.workspace_size_per_rank = 2048 * 1024 * 1024

# Initialize or reuse workspace
MnnvlMemory.initialize()

if self._WORKSPACE is None:
tllm_logger.info(
f"MoE AlltoAll: Allocating workspace with size {self.workspace_size_per_rank} bytes. "
f"ep_rank: {self.ep_rank}, ep_size: {self.ep_size}, "
f"max_num_tokens_per_rank: {self.max_num_tokens_per_rank}"
f"NVLinkOneSided: Allocating workspace with size {self.workspace_size_per_rank} bytes."
f"ep_rank: {self.ep_rank}, ep_size: {self.ep_size}, top_k: {self.top_k}, max_num_tokens_per_rank: {self.max_num_tokens_per_rank}"
)
mnnvl_mem = MnnvlMemory(mapping, self.workspace_size_per_rank)
workspace = mnnvl_mem.as_torch_strided_tensor(torch.uint8)
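A hedged construction sketch for the auto-sizing path added above (the import of NVLinkOneSided, the mapping object, and all shapes are assumptions for illustration; a pure-EP Mapping and MNNVL-capable hardware are required at runtime):

import torch

comm = NVLinkOneSided(           # assumed to be imported from this module
    mapping,                     # assumed: a pure-EP Mapping (world_size == moe_ep_size)
    num_experts=256,
    top_k=8,
    max_num_tokens_per_rank=4096,
    payload_in_workspace=False,
    hidden_size=7168,            # together with dtype, enables automatic workspace sizing
    dtype=torch.bfloat16,
)
# Omitting hidden_size/dtype falls back to the 2048 MB default;
# TRTLLM_MOE_A2A_WORKSPACE_MB overrides either path.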
14 changes: 11 additions & 3 deletions tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py
@@ -153,14 +153,22 @@ def __init__(
self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(
model_config.mapping)
elif self.alltoall_method_type == AlltoallMethodType.NVLinkOneSided:
workspace_mb = int(
os.environ.get("TRTLLM_MOE_A2A_WORKSPACE_MB", "2048"))
# Calculate required workspace size
ep_size = self.mapping.moe_ep_size
max_num_tokens = model_config.max_num_tokens
hidden_size = self.hidden_size
dtype = self.dtype or torch.float16

workspace_size = MoeAlltoAll.calculate_required_workspace_size(
ep_size, self.routing_method.experts_per_token,
max_num_tokens, hidden_size, dtype)

self.moe_a2a = MoeAlltoAll(
mapping=self.mapping,
max_num_tokens=model_config.max_num_tokens,
top_k=self.routing_method.experts_per_token,
num_experts=self.num_slots,
workspace_size_per_rank=workspace_mb * 1024 * 1024,
workspace_size_per_rank=workspace_size,
)
elif self.alltoall_method_type == AlltoallMethodType.DeepEP or self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
raise NotImplementedError(