From 0b8d13c8b5c34edcc363feb934ac85b320407685 Mon Sep 17 00:00:00 2001
From: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
Date: Wed, 23 Jul 2025 01:34:22 -0700
Subject: [PATCH 1/4] [fix] Fix perf regression caused by MoE autotuner when
 using DeepEPLowLatency

Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
---
 .../_torch/custom_ops/torch_custom_ops.py      | 37 +++++++++++++++----
 .../modules/fused_moe/fused_moe_wide_ep.py     | 13 +++++++
 tensorrt_llm/_torch/utils.py                   |  2 +-
 3 files changed, 43 insertions(+), 9 deletions(-)

diff --git a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
index 60ef215fe38..d3889eb8971 100644
--- a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
+++ b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
@@ -39,7 +39,6 @@ def __init__(
         ep_rank: int,
         cluster_size: int,
         cluster_rank: int,
-        enable_alltoall: bool,
         use_deepseek_fp8_block_scale: bool,
         use_w4a8_group_scaling: bool,
         use_mxfp8_act_scaling: bool,
@@ -55,7 +54,8 @@ def __init__(
         self.ep_rank = ep_rank
         self.cluster_size = cluster_size
         self.cluster_rank = cluster_rank
-        self.enable_alltoall = enable_alltoall
+        # The best tactic is estimated as if alltoall is disabled
+        self.enable_alltoall = False
         self.use_deepseek_fp8_block_scale = use_deepseek_fp8_block_scale
         self.use_w4a8_group_scaling = use_w4a8_group_scaling
         self.use_mxfp8_act_scaling = use_mxfp8_act_scaling
@@ -141,24 +141,45 @@ def fused_moe(
     use_mxfp8_act_scaling: bool = False,
     min_latency_mode: bool = False,
     tune_max_num_tokens: int = 8192,
+    total_valid_tokens: Optional[int] = None,
+    original_top_k: Optional[int] = None,
 ) -> List[torch.Tensor]:

     tuner = AutoTuner.get()
     MoERunner.refine_tuning_config(tune_max_num_tokens)

+    if enable_alltoall:
+        assert total_valid_tokens is not None
+        assert original_top_k is not None
+        if input.shape[0] >= total_valid_tokens:
+            tuner_input = input[:total_valid_tokens].contiguous()
+        else:
+            tuner_input = torch.cat([
+                input,
+                torch.zeros(total_valid_tokens - input.shape[0],
+                            *input.shape[1:],
+                            dtype=input.dtype,
+                            device=input.device)
+            ])
+        tuner_top_k = original_top_k
+    else:
+        assert total_valid_tokens is None
+        assert original_top_k is None
+        tuner_input = input
+        tuner_top_k = token_selected_experts.size(1)
+
     # allocate workspace for profiling
     moe_runner = MoERunner(
         x_dtype=input.dtype,
         weight_dtype=fc1_expert_weights.dtype,
         output_dtype=output_dtype,
-        top_k=token_selected_experts.size(1),
+        top_k=tuner_top_k,
         tp_size=tp_size,
         tp_rank=tp_rank,
         ep_size=ep_size,
         ep_rank=ep_rank,
         cluster_size=cluster_size,
         cluster_rank=cluster_rank,
-        enable_alltoall=enable_alltoall,
         use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
         use_w4a8_group_scaling=use_w4a8_group_scaling,
         use_mxfp8_act_scaling=use_mxfp8_act_scaling,
@@ -170,8 +191,8 @@ def fused_moe(
         [moe_runner],
         MoERunner.tuning_config,
         [
-            input, fc1_expert_weights, fc1_expert_biases, fc2_expert_weights,
-            fc2_expert_biases
+            tuner_input, fc1_expert_weights, fc1_expert_biases,
+            fc2_expert_weights, fc2_expert_biases
         ],
         gemm_idx=1,
     )
@@ -181,8 +202,8 @@ def fused_moe(
         [moe_runner],
         MoERunner.tuning_config,
         [
-            input, fc1_expert_weights, fc1_expert_biases, fc2_expert_weights,
-            fc2_expert_biases
+            tuner_input, fc1_expert_weights, fc1_expert_biases,
+            fc2_expert_weights, fc2_expert_biases
         ],
         gemm_idx=2,
     )
diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
index 81778c28544..087f51619a5 100755
--- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
@@ -437,6 +437,17 @@ def forward_chunk(

         # If alltoall is disabled, we need also disable use_postquant_alltoall
         use_postquant_alltoall = self.use_postquant_alltoall and use_all_to_all
+
+        # Prepare additional information for profiling in case padding is applied in all-to-all
+        if use_all_to_all:
+            if all_rank_num_tokens is not None:
+                total_valid_tokens = sum(all_rank_num_tokens)
+            else:
+                total_valid_tokens = x.shape[0] * self.mapping.tp_size
+            original_top_k = token_selected_slots.shape[1]
+        else:
+            total_valid_tokens = None
+            original_top_k = None
         if use_all_to_all:
             if self.alltoall_method_type == AlltoallMethodType.MNNVL:
                 if self.enable_dummy_allreduce:
@@ -669,6 +680,8 @@ def forward_chunk(
                 use_w4a8_group_scaling=use_w4a8_group_scaling,
                 min_latency_mode=False,
                 tune_max_num_tokens=self.tune_max_num_tokens,
+                total_valid_tokens=total_valid_tokens,
+                original_top_k=original_top_k,
             )

             if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
diff --git a/tensorrt_llm/_torch/utils.py b/tensorrt_llm/_torch/utils.py
index 5710dbdc6ae..4307bbaac57 100644
--- a/tensorrt_llm/_torch/utils.py
+++ b/tensorrt_llm/_torch/utils.py
@@ -239,7 +239,7 @@ def get_last_power_of_2_num_tokens_buckets(max_num_tokens) -> List[int]:
     while m >= 1:
         num_token_buckets.append(m)
         m //= 2
-    return tuple(num_token_buckets)
+    return tuple(num_token_buckets[::-1])


 def fp4_scale_infer_shape(input_shapes: List[List[int]]):

From 21774380884a429bae919c65dffdf76f50e51a93 Mon Sep 17 00:00:00 2001
From: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
Date: Fri, 25 Jul 2025 00:53:44 -0700
Subject: [PATCH 2/4] Minor modifications

Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
---
 .../_torch/custom_ops/torch_custom_ops.py      | 26 ++++++-------------
 .../modules/fused_moe/fused_moe_wide_ep.py     | 14 +++++-----
 2 files changed, 15 insertions(+), 25 deletions(-)

diff --git a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
index d3889eb8971..a55e9ce4f6d 100644
--- a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
+++ b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
@@ -88,7 +88,7 @@ def forward(
     ):
         x, fc1_expert_weights, fc1_expert_biases, fc2_expert_weights, fc2_expert_biases = inputs
         self.fused_moe_runner.run_gemm_profile(
-            x,
+            x.contiguous(),
             fc1_expert_weights,
             fc1_expert_biases,
             fc2_expert_weights,
@@ -141,30 +141,20 @@ def fused_moe(
     use_mxfp8_act_scaling: bool = False,
     min_latency_mode: bool = False,
     tune_max_num_tokens: int = 8192,
-    total_valid_tokens: Optional[int] = None,
-    original_top_k: Optional[int] = None,
+    tuner_num_tokens: Optional[int] = None,
+    tuner_top_k: Optional[int] = None,
 ) -> List[torch.Tensor]:

     tuner = AutoTuner.get()
     MoERunner.refine_tuning_config(tune_max_num_tokens)

     if enable_alltoall:
-        assert total_valid_tokens is not None
-        assert original_top_k is not None
-        if input.shape[0] >= total_valid_tokens:
-            tuner_input = input[:total_valid_tokens].contiguous()
-        else:
-            tuner_input = torch.cat([
-                input,
-                torch.zeros(total_valid_tokens - input.shape[0],
-                            *input.shape[1:],
-                            dtype=input.dtype,
-                            device=input.device)
-            ])
-        tuner_top_k = original_top_k
+        assert tuner_num_tokens is not None
+        assert tuner_top_k is not None
+        tuner_input = input[:tuner_num_tokens]
     else:
-        assert total_valid_tokens is None
-        assert original_top_k is None
+        assert tuner_num_tokens is None
+        assert tuner_top_k is None
         tuner_input = input
         tuner_top_k = token_selected_experts.size(1)

diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
index 087f51619a5..06b0d66cd2f 100755
--- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
@@ -441,13 +441,13 @@ def forward_chunk(
         # Prepare additional information for profiling in case padding is applied in all-to-all
         if use_all_to_all:
             if all_rank_num_tokens is not None:
-                total_valid_tokens = sum(all_rank_num_tokens)
+                tuner_num_tokens = sum(all_rank_num_tokens)
             else:
-                total_valid_tokens = x.shape[0] * self.mapping.tp_size
-            original_top_k = token_selected_slots.shape[1]
+                tuner_num_tokens = x.shape[0] * self.mapping.tp_size
+            tuner_top_k = token_selected_slots.shape[1]
         else:
-            total_valid_tokens = None
-            original_top_k = None
+            tuner_num_tokens = None
+            tuner_top_k = None
         if use_all_to_all:
             if self.alltoall_method_type == AlltoallMethodType.MNNVL:
                 if self.enable_dummy_allreduce:
@@ -680,8 +680,8 @@ def forward_chunk(
                 use_w4a8_group_scaling=use_w4a8_group_scaling,
                 min_latency_mode=False,
                 tune_max_num_tokens=self.tune_max_num_tokens,
-                total_valid_tokens=total_valid_tokens,
-                original_top_k=original_top_k,
+                tuner_num_tokens=tuner_num_tokens,
+                tuner_top_k=tuner_top_k,
             )

             if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(

From 3e946754e583391077b5d9e95e035a4563c1f53a Mon Sep 17 00:00:00 2001
From: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
Date: Fri, 25 Jul 2025 01:55:35 -0700
Subject: [PATCH 3/4] Minor modification

Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
---
 tensorrt_llm/_torch/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorrt_llm/_torch/utils.py b/tensorrt_llm/_torch/utils.py
index 4307bbaac57..15f8e634a58 100644
--- a/tensorrt_llm/_torch/utils.py
+++ b/tensorrt_llm/_torch/utils.py
@@ -229,7 +229,7 @@ def get_power_of_2_num_tokens_buckets(max_num_tokens) -> List[int]:
         num_token_buckets.append(m)
         m //= 2

-    return tuple(num_token_buckets)
+    return tuple(num_token_buckets[::-1])


 def get_last_power_of_2_num_tokens_buckets(max_num_tokens) -> List[int]:

From d7f5461c40fdaa8d3ec56087bb51bbf13b489c4b Mon Sep 17 00:00:00 2001
From: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
Date: Fri, 25 Jul 2025 19:22:01 -0700
Subject: [PATCH 4/4] Minor modifications

Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
---
 tensorrt_llm/_torch/custom_ops/torch_custom_ops.py         | 4 +++-
 tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py | 4 +++-
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
index a55e9ce4f6d..e9e0bb91331 100644
--- a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
+++ b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
@@ -88,7 +88,7 @@ def forward(
     ):
         x, fc1_expert_weights, fc1_expert_biases, fc2_expert_weights, fc2_expert_biases = inputs
         self.fused_moe_runner.run_gemm_profile(
-            x.contiguous(),
+            x,
             fc1_expert_weights,
             fc1_expert_biases,
             fc2_expert_weights,
@@ -148,6 +148,8 @@ def fused_moe(
     tuner = AutoTuner.get()
     MoERunner.refine_tuning_config(tune_max_num_tokens)

+    # Only the non-alltoall case is considered for profiling in the warmup phase.
+    # Therefore, to get the correct tactics during the actual inference, the inputs to the tuner should be the same as when not using alltoall.
     if enable_alltoall:
         assert tuner_num_tokens is not None
         assert tuner_top_k is not None
diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
index 06b0d66cd2f..cb711486b63 100755
--- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
@@ -438,7 +438,9 @@ def forward_chunk(
         # If alltoall is disabled, we need also disable use_postquant_alltoall
         use_postquant_alltoall = self.use_postquant_alltoall and use_all_to_all

-        # Prepare additional information for profiling in case padding is applied in all-to-all
+        # Prepare additional information for profiling in case padding is applied when using alltoall.
+        # Only the non-alltoall case is considered for profiling in the warmup phase.
+        # Therefore, to get the correct tactics during the actual inference, the inputs to the tuner should be the same as when not using alltoall.
         if use_all_to_all:
             if all_rank_num_tokens is not None:
                 tuner_num_tokens = sum(all_rank_num_tokens)
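
Illustrative note (not part of the patch series): the standalone Python sketch below restates the tuner-input handling that the patches converge on, so the intent is easier to follow outside the diff context. The helper names select_tuner_inputs and tuner_hints_for_alltoall are assumptions made for this sketch; only the tuner_num_tokens / tuner_top_k values, the input slicing in fused_moe, and the caller-side computation in forward_chunk come from the patches above.

from typing import List, Optional, Tuple

import torch


def select_tuner_inputs(
    x: torch.Tensor,
    token_selected_experts: torch.Tensor,
    enable_alltoall: bool,
    tuner_num_tokens: Optional[int] = None,
    tuner_top_k: Optional[int] = None,
) -> Tuple[torch.Tensor, int]:
    """Pick the tensor and top_k the autotuner should profile with.

    When alltoall is enabled, the dispatched tensor may be padded, so the
    tuner profiles on the first tuner_num_tokens rows (the pre-dispatch
    token count) with the caller-provided top_k, matching the non-alltoall
    case that was profiled during warmup.
    """
    if enable_alltoall:
        assert tuner_num_tokens is not None and tuner_top_k is not None
        tuner_input = x[:tuner_num_tokens]
    else:
        assert tuner_num_tokens is None and tuner_top_k is None
        tuner_input = x
        tuner_top_k = token_selected_experts.size(1)
    return tuner_input, tuner_top_k


def tuner_hints_for_alltoall(
    x: torch.Tensor,
    token_selected_slots: torch.Tensor,
    use_all_to_all: bool,
    all_rank_num_tokens: Optional[List[int]],
    tp_size: int,
) -> Tuple[Optional[int], Optional[int]]:
    """Caller-side hints, mirroring the logic added to forward_chunk."""
    if use_all_to_all:
        # Sum the per-rank token counts when available; otherwise fall back
        # to the local token count scaled by the TP size.
        if all_rank_num_tokens is not None:
            tuner_num_tokens = sum(all_rank_num_tokens)
        else:
            tuner_num_tokens = x.shape[0] * tp_size
        return tuner_num_tokens, token_selected_slots.shape[1]
    return None, None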