9 changes: 6 additions & 3 deletions tensorrt_llm/_torch/models/modeling_gpt_oss.py
@@ -140,6 +140,9 @@ def __init__(
self.config = config # Store config as instance variable
pretrained_config = config.pretrained_config
self.num_experts = pretrained_config.num_local_experts
moe_load_balancer_config = config.moe_load_balancer
self.num_slots = moe_load_balancer_config.num_slots if moe_load_balancer_config and moe_load_balancer_config.num_slots else self.num_experts

self.layer_idx = layer_idx
self.enable_attention_dp = config.mapping.enable_attention_dp
self.mapping = config.mapping
@@ -162,13 +165,13 @@ def __init__(
if config.moe_backend.upper() == "TRTLLM" else torch.float32)

self.swiglu_alpha = torch.tensor(
[1.702] * (self.num_experts // config.mapping.moe_ep_size),
[1.702] * (self.num_slots // config.mapping.moe_ep_size),
dtype=torch.float32).cuda()
self.swiglu_beta = torch.tensor(
[1.0] * (self.num_experts // config.mapping.moe_ep_size),
[1.0] * (self.num_slots // config.mapping.moe_ep_size),
dtype=torch.float32).cuda()
self.swiglu_limit = torch.tensor(
[7.0] * (self.num_experts // config.mapping.moe_ep_size),
[7.0] * (self.num_slots // config.mapping.moe_ep_size),
dtype=torch.float32).cuda()
# Prepare MoE creation parameters
moe_params = {
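The hunk above derives the number of expert slots from the optional MoE load balancer config, falling back to the plain expert count, and then sizes the per-rank swiglu tensors by slots instead of experts, presumably so that redundant expert slots get their own activation constants. A minimal sketch of that fallback, using the attribute names shown in the diff (the helper function itself is illustrative, not part of this PR):

def resolve_num_slots(num_experts, moe_load_balancer_config):
    # Mirrors the fallback in the hunk above: use the load balancer's slot
    # count when one is configured and it specifies num_slots, otherwise
    # keep one slot per expert.
    if moe_load_balancer_config and getattr(moe_load_balancer_config, "num_slots", None):
        return moe_load_balancer_config.num_slots
    return num_experts

num_experts, moe_ep_size = 128, 4
num_slots = resolve_num_slots(num_experts, None)   # no balancer -> 128
per_rank = num_slots // moe_ep_size                # length of swiglu_alpha/beta/limit per rank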
5 changes: 4 additions & 1 deletion tensorrt_llm/_torch/modules/fused_moe/create_moe.py
@@ -79,7 +79,9 @@ def create_moe(

moe_load_balancer = get_moe_load_balancer()
if moe_load_balancer is not None:
assert moe_cls == WideEPMoE, "MoE Load Balance is only supported in WideEPMoE now."
assert moe_cls in [
WideEPMoE, CutlassFusedMoE, TRTLLMGenFusedMoE
], "MoE Load Balance is only supported in WideEPMoE, CutlassFusedMoE and TRTLLMGenFusedMoE now."

if bias:
assert moe_cls in [CutlassFusedMoE, TritonFusedMoE, TRTLLMGenFusedMoE
@@ -106,6 +108,7 @@
dtype=dtype,
reduce_results=reduce_results,
model_config=model_config,
aux_stream_dict=aux_stream_dict,
weight_loading_mode=weight_loading_mode,
bias=bias,
layer_idx=layer_idx,
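Under an active load balancer the factory now accepts three backends instead of only WideEPMoE. A small stand-alone rendering of that guard, with stub classes in place of the real imports (stubs are for illustration only):

class WideEPMoE: ...
class CutlassFusedMoE: ...
class TRTLLMGenFusedMoE: ...
class VanillaMoE: ...   # stand-in for any backend outside the supported set

_LOAD_BALANCER_CAPABLE = (WideEPMoE, CutlassFusedMoE, TRTLLMGenFusedMoE)

def assert_load_balancer_supported(moe_cls):
    # Same shape as the widened assert above: anything outside the tuple
    # still fails fast when a load balancer is configured.
    assert moe_cls in _LOAD_BALANCER_CAPABLE, (
        "MoE Load Balance is only supported in WideEPMoE, CutlassFusedMoE "
        "and TRTLLMGenFusedMoE now.")

assert_load_balancer_supported(TRTLLMGenFusedMoE)   # passes
# assert_load_balancer_supported(VanillaMoE)        # would raise AssertionError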
13 changes: 7 additions & 6 deletions tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py
@@ -141,12 +141,13 @@ def __init__(
)

def forward_chunk(
self,
x: Union[torch.Tensor, Fp4QuantizedTensor],
router_logits: torch.Tensor,
output_dtype: Optional[torch.dtype] = None,
all_rank_num_tokens: Optional[List[int]] = None,
use_dp_padding: Optional[bool] = None,
self,
x: Union[torch.Tensor, Fp4QuantizedTensor],
router_logits: torch.Tensor,
output_dtype: Optional[torch.dtype] = None,
all_rank_num_tokens: Optional[List[int]] = None,
use_dp_padding: Optional[bool] = None,
repeating_info: tuple = (True, True),
) -> torch.Tensor:
if isinstance(x, Fp4QuantizedTensor):
assert output_dtype is not None
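forward_chunk gains a repeating_info argument defaulting to (True, True). Judging by the first/last-call flags used in the TRTLLMGen path later in this diff, the two booleans presumably mark whether the chunk is the first and the last of a repeated sequence; a hedged sketch of how a chunked caller might thread it through (the stub module is only for illustration):

class _MoEStub:
    # Minimal stand-in with the same keyword so the loop below runs as-is.
    def forward_chunk(self, x, router_logits, repeating_info=(True, True)):
        return x

def run_chunks(moe, x_chunks, logits_chunks):
    outputs = []
    last = len(x_chunks) - 1
    for idx, (x, logits) in enumerate(zip(x_chunks, logits_chunks)):
        # (is_first_call, is_last_call) for this chunk; a single un-chunked
        # call keeps the default (True, True).
        outputs.append(moe.forward_chunk(x, logits, repeating_info=(idx == 0, idx == last)))
    return outputs

run_chunks(_MoEStub(), [1, 2, 3], ["a", "b", "c"])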
160 changes: 101 additions & 59 deletions tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py

Large diffs are not rendered by default.

80 changes: 63 additions & 17 deletions tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
@@ -13,8 +13,9 @@
from ...custom_ops.trtllm_gen_custom_ops import \
fp4_block_scale_fake_output_without_finalize
from ...distributed import allgather
from ...expert_statistic import ExpertStatistic
from ...model_config import ModelConfig
from ...utils import Fp4QuantizedTensor, ceil_div
from ...utils import AuxStreamType, Fp4QuantizedTensor, ceil_div
from .interface import AlltoallMethodType, MoE, MoEWeightLoadingMode
from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethod,
NVFP4TRTLLMGenFusedMoEMethod,
@@ -37,6 +38,7 @@ class TRTLLMGenFusedMoE(MoE):
dtype (Optional[torch.dtype]): Data type for the weights.
reduce_results (bool): Whether to reduce the results across devices.
model_config (ModelConfig): Configuration object for the model.
aux_stream_dict (Optional[Dict[AuxStreamType, torch.cuda.Stream]]): Auxiliary CUDA streams for overlapping.

MoE torch custom op:
Only support min-latency mode now (SM100 Blackwell only).
@@ -66,6 +68,8 @@ def __init__(
dtype: Optional[torch.dtype] = None,
reduce_results: bool = False,
model_config: ModelConfig = ModelConfig(),
aux_stream_dict: Optional[Dict[AuxStreamType,
torch.cuda.Stream]] = None,
weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode.
VANILLA,
layer_idx: Optional[int] = None,
@@ -82,6 +86,7 @@
dtype=dtype,
reduce_results=reduce_results,
model_config=model_config,
aux_stream_dict=aux_stream_dict,
weight_loading_mode=weight_loading_mode,
bias=bias,
swiglu_alpha=swiglu_alpha,
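aux_stream_dict is now threaded from create_moe through the TRTLLMGenFusedMoE constructor into the base class, giving the backend auxiliary CUDA streams to overlap side work with the main stream. A hedged sketch of what a caller could pass; the AuxStreamType member name below is a placeholder, since the real enum members are not shown in this diff, and creating streams requires a CUDA device:

import torch
from enum import Enum

class AuxStreamType(Enum):
    # Placeholder member; the actual names live in tensorrt_llm/_torch/utils.
    MoeOverlap = "moe_overlap"

def build_aux_stream_dict():
    # One side stream per aux-stream type; modules can launch balancer or
    # statistics work on these to overlap with the default stream.
    return {AuxStreamType.MoeOverlap: torch.cuda.Stream()}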
@@ -97,19 +102,11 @@ def __init__(

assert not self.smart_router, "Smart router is not supported in TRTLLMGenFusedMoE."

self.num_slots = self.num_experts
self.expert_size_per_partition = self.num_experts // self.ep_size
self.initial_global_assignments = [
(ep_rank * self.num_experts // self.ep_size + local_slot_id) %
self.num_experts for ep_rank in range(self.ep_size)
for local_slot_id in range(self.expert_size_per_partition)
]
self.slot_start = self.ep_rank * self.expert_size_per_partition
self.slot_end = self.slot_start + self.expert_size_per_partition
self.initial_local_expert_ids = self.initial_global_assignments[
self.slot_start:self.slot_end]
assert len(
self.initial_local_expert_ids) == self.expert_size_per_partition
# Note: Load balancer initialization is handled by base class _init_load_balancer()
# If no load balancer is available, the base class will set:
# - self.num_slots = self.num_experts
# - self.expert_size_per_partition = self.num_experts // self.ep_size
# - self.initial_global_assignments, self.slot_start, self.slot_end, etc.

# TODO: AlltoAll code is largely duplicated with WideEPMoE. Consider refactor and reuse in the future.
self.alltoall_method_type = self.select_alltoall_method_type()
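The explicit slot bookkeeping removed here now comes from the base class's _init_load_balancer(). For reference, the no-load-balancer default it reproduces is the same contiguous layout the deleted lines computed; as a stand-alone function:

def default_slot_layout(num_experts, ep_size, ep_rank):
    # Without a load balancer: one slot per expert, split contiguously
    # across EP ranks (mirrors the deleted initial_global_assignments).
    expert_size_per_partition = num_experts // ep_size
    initial_global_assignments = [
        (r * num_experts // ep_size + local_slot_id) % num_experts
        for r in range(ep_size)
        for local_slot_id in range(expert_size_per_partition)
    ]
    slot_start = ep_rank * expert_size_per_partition
    slot_end = slot_start + expert_size_per_partition
    return initial_global_assignments[slot_start:slot_end]

print(default_slot_layout(num_experts=8, ep_size=4, ep_rank=1))   # [2, 3]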
@@ -136,7 +133,7 @@ def __init__(
mapping=self.mapping,
max_num_tokens=model_config.max_num_tokens,
top_k=self.routing_method.experts_per_token,
num_experts=self.num_experts,
num_experts=self.num_slots,
workspace_size_per_rank=workspace_mb * 1024 * 1024,
)
else:
Expand Down Expand Up @@ -183,6 +180,10 @@ def select_alltoall_method_type(self) -> AlltoallMethodType:

return AlltoallMethodType.MNNVL

def _supports_load_balancer(self) -> bool:
"""TRTLLMGenFusedMoE supports load balancer."""
return True

@cached_property
def enable_alltoall(self):
""" enable_alltoall (bool): whether to enable alltoall instead of allgather/reducescatter
@@ -340,14 +341,39 @@ def forward_impl(
x_col = x.shape[1]
token_count = x.shape[0]
alltoall_info = None
# Determine if this is first/last call (TRTLLMGenFusedMoE doesn't use chunking)
is_first_call = self.repeat_idx == 0
is_last_call = self.repeat_idx == self.repeat_count - 1

if post_quant_comm:
# Start GPU stage for first call
self._load_balancer_start_wait_gpu_stage(is_first_call)
token_selected_experts, token_final_scales = self.routing_method.apply(
router_logits)
token_selected_experts = token_selected_experts.to(torch.int32)
if token_final_scales is not None:
token_final_scales = token_final_scales.to(torch.bfloat16)

self._load_balancer_done_wait_gpu_stage(is_first_call)

ignore_allreduce = self.enable_alltoall and self.alltoall_method_type == AlltoallMethodType.MNNVL and self.moe_alltoall_backend == "mnnvllatency"
self._load_balancer_update_statistic(
token_selected_experts,
is_first_call,
is_last_call,
ignore_allreduce=ignore_allreduce)

# Route tokens to slots
token_selected_slots = self._load_balancer_route(
token_selected_experts, self.use_dp)

# Update expert statistics
ExpertStatistic.set_layer(self.layer_idx)
ExpertStatistic.maybe_add_info(self.num_slots, token_selected_slots)

# Use routed slots for subsequent processing
token_selected_experts = token_selected_slots

x, x_sf, x_row, x_col = self._quantize_for_post_quant_comm(x)

if self.enable_alltoall:
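Condensed, the routing block added above runs a fixed sequence per call: derive is_first_call/is_last_call from the repeat counter, bracket the routing computation with the GPU-stage hooks, feed the selected experts into the balancer statistics, map experts to slots, record them via ExpertStatistic, and continue with the slot ids. The sketch below keeps only that ordering; the stub balancer and the plain top-k routing are stand-ins for the real _load_balancer_* hooks and routing_method:

import torch

class _BalancerStub:
    # No-op stand-ins for the _load_balancer_* hooks used above.
    def start_wait_gpu_stage(self, is_first_call): pass
    def done_wait_gpu_stage(self, is_first_call): pass
    def update_statistic(self, experts, is_first_call, is_last_call, ignore_allreduce): pass
    def route(self, experts):
        return experts   # identity mapping: slot i == expert i

def route_tokens(balancer, router_logits, top_k, repeat_idx, repeat_count):
    is_first_call = repeat_idx == 0
    is_last_call = repeat_idx == repeat_count - 1

    balancer.start_wait_gpu_stage(is_first_call)                 # 1. enter GPU stage
    token_selected_experts = torch.topk(router_logits, top_k, dim=-1).indices.to(torch.int32)
    balancer.done_wait_gpu_stage(is_first_call)                  # 2. leave GPU stage

    balancer.update_statistic(token_selected_experts,            # 3. routing statistics
                              is_first_call, is_last_call, ignore_allreduce=False)
    token_selected_slots = balancer.route(token_selected_experts)  # 4. experts -> slots
    return token_selected_slots                                  # 5. downstream uses slot ids

route_tokens(_BalancerStub(), torch.randn(4, 8), top_k=2, repeat_idx=0, repeat_count=1)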
@@ -364,9 +390,14 @@

if self.moe_alltoall_backend == "mnnvllatency":
assert self.alltoall_prepare_workspace is not None, "alltoall_prepare_workspace should be initialized"
alltoall_info, _ = MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather(
if is_last_call:
loadbalancer_local_statistic_info = self._load_balancer_get_local_statistic_tensor(
)
else:
loadbalancer_local_statistic_info = None
alltoall_info, gathered_loadbalancer_local_statistic_info = MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather(
token_selected_experts,
None,
loadbalancer_local_statistic_info,
self.alltoall_prepare_workspace,
runtime_max_tokens_per_rank,
self.ep_rank,
Expand All @@ -375,6 +406,11 @@ def forward_impl(
self.num_slots,
top_k,
)
if gathered_loadbalancer_local_statistic_info is not None:
gathered_loadbalancer_local_statistic_info = gathered_loadbalancer_local_statistic_info.view(
(self.mapping.moe_ep_size, self.num_experts))
self._load_balancer_update_statistic_with_gathered_statistic(
gathered_loadbalancer_local_statistic_info)

if x_sf is not None:
x_sf = x_sf.view(x_row,
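On the MNNVL latency backend, the last call of the repeat cycle now piggybacks the balancer's local statistics on the alltoall prepare: the local tensor goes in, and the gathered buffer comes back with one row of per-expert counts from every EP rank, reshaped to (moe_ep_size, num_experts) before being handed back to the balancer. A shape-only sketch with made-up sizes:

import torch

moe_ep_size, num_experts = 4, 16

# Flat gathered buffer as returned by the prepare call, then the reshape
# performed above before _load_balancer_update_statistic_with_gathered_statistic.
gathered_flat = torch.zeros(moe_ep_size * num_experts, dtype=torch.int32)
gathered = gathered_flat.view(moe_ep_size, num_experts)

# e.g. a balancer could reduce across ranks to get global per-expert load
per_expert_load = gathered.sum(dim=0)
assert per_expert_load.shape == (num_experts,)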
@@ -716,6 +752,9 @@ def forward_impl(
"TRTLLMGenFusedMoE only supports fp8_block_scaling, nvfp4, w4a16_mxfp4, w4a8_mxfp4_mxfp8 and w4a8_mxfp4_fp8 dtypes."
)

# Handle load balancer CPU stage if needed
self._load_balancer_start_set_cpu_stage(is_last_call)

# Combine results if using alltoall
if self.enable_alltoall:
if self.moe_alltoall_backend == "mnnvllatency":
@@ -763,10 +802,17 @@ def forward_impl(
use_dp_padding=use_dp_padding,
)

self._load_balancer_done_set_cpu_stage(is_last_call)

if use_dp_padding:
rank = self.mapping.tp_rank
final_hidden_states = final_hidden_states[:
all_rank_num_tokens[rank]]

# Update repeat index for load balancer
if self.layer_load_balancer:
self.repeat_idx = 0 if self.repeat_idx == self.repeat_count - 1 else self.repeat_idx + 1

return final_hidden_states

def forward_fake(
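The tail of forward_impl brackets the combine/reduce step with the CPU-stage hooks and then advances the repeat counter, so the is_first_call/is_last_call flags derived at the top of the function stay correct across repeated invocations of the same layer. A stand-alone sketch of just the counter behaviour added at the end:

def advance_repeat_idx(repeat_idx, repeat_count):
    # Same wrap-around as the line added above: 0, 1, ..., repeat_count - 1, 0, ...
    return 0 if repeat_idx == repeat_count - 1 else repeat_idx + 1

repeat_count, idx = 3, 0
for _ in range(6):
    is_first_call, is_last_call = idx == 0, idx == repeat_count - 1
    print(idx, is_first_call, is_last_call)
    idx = advance_repeat_idx(idx, repeat_count)
# 0 True False / 1 False False / 2 False True / 0 True False / ...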