tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (21 additions, 8 deletions)
@@ -39,7 +39,6 @@ def __init__(
         ep_rank: int,
         cluster_size: int,
         cluster_rank: int,
-        enable_alltoall: bool,
         use_deepseek_fp8_block_scale: bool,
         use_w4a8_group_scaling: bool,
         use_mxfp8_act_scaling: bool,
@@ -55,7 +54,8 @@ def __init__(
         self.ep_rank = ep_rank
         self.cluster_size = cluster_size
         self.cluster_rank = cluster_rank
-        self.enable_alltoall = enable_alltoall
+        # The best tactic is estimated as if alltoall is disabled
+        self.enable_alltoall = False
         self.use_deepseek_fp8_block_scale = use_deepseek_fp8_block_scale
         self.use_w4a8_group_scaling = use_w4a8_group_scaling
         self.use_mxfp8_act_scaling = use_mxfp8_act_scaling
@@ -141,24 +141,37 @@ def fused_moe(
     use_mxfp8_act_scaling: bool = False,
     min_latency_mode: bool = False,
     tune_max_num_tokens: int = 8192,
+    tuner_num_tokens: Optional[int] = None,
+    tuner_top_k: Optional[int] = None,
 ) -> List[torch.Tensor]:
 
     tuner = AutoTuner.get()
     MoERunner.refine_tuning_config(tune_max_num_tokens)
 
+    # Only the non-alltoall case is considered for profiling in the warmup phase.
+    # Therefore, to get the correct tactics during the actual inference, the inputs to the tuner should be the same as when not using alltoall.
+    if enable_alltoall:
+        assert tuner_num_tokens is not None
+        assert tuner_top_k is not None
+        tuner_input = input[:tuner_num_tokens]
+    else:
+        assert tuner_num_tokens is None
+        assert tuner_top_k is None
+        tuner_input = input
+        tuner_top_k = token_selected_experts.size(1)
+
     # allocate workspace for profiling
     moe_runner = MoERunner(
         x_dtype=input.dtype,
         weight_dtype=fc1_expert_weights.dtype,
         output_dtype=output_dtype,
-        top_k=token_selected_experts.size(1),
+        top_k=tuner_top_k,
         tp_size=tp_size,
         tp_rank=tp_rank,
         ep_size=ep_size,
         ep_rank=ep_rank,
         cluster_size=cluster_size,
         cluster_rank=cluster_rank,
-        enable_alltoall=enable_alltoall,
         use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
         use_w4a8_group_scaling=use_w4a8_group_scaling,
         use_mxfp8_act_scaling=use_mxfp8_act_scaling,
@@ -170,8 +183,8 @@ def fused_moe(
         [moe_runner],
         MoERunner.tuning_config,
         [
-            input, fc1_expert_weights, fc1_expert_biases, fc2_expert_weights,
-            fc2_expert_biases
+            tuner_input, fc1_expert_weights, fc1_expert_biases,
+            fc2_expert_weights, fc2_expert_biases
         ],
         gemm_idx=1,
     )
Expand All @@ -181,8 +194,8 @@ def fused_moe(
[moe_runner],
MoERunner.tuning_config,
[
input, fc1_expert_weights, fc1_expert_biases, fc2_expert_weights,
fc2_expert_biases
tuner_input, fc1_expert_weights, fc1_expert_biases,
fc2_expert_weights, fc2_expert_biases
],
gemm_idx=2,
)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py (15 additions, 0 deletions)
@@ -437,6 +437,19 @@ def forward_chunk(
 
         # If alltoall is disabled, we need also disable use_postquant_alltoall
        use_postquant_alltoall = self.use_postquant_alltoall and use_all_to_all
+
+        # Prepare additional information for profiling in case padding is applied when using alltoall.
+        # Only the non-alltoall case is considered for profiling in the warmup phase.
+        # Therefore, to get the correct tactics during the actual inference, the inputs to the tuner should be the same as when not using alltoall.
+        if use_all_to_all:
+            if all_rank_num_tokens is not None:
+                tuner_num_tokens = sum(all_rank_num_tokens)
+            else:
+                tuner_num_tokens = x.shape[0] * self.mapping.tp_size
+            tuner_top_k = token_selected_slots.shape[1]
+        else:
+            tuner_num_tokens = None
+            tuner_top_k = None
         if use_all_to_all:
             if self.alltoall_method_type == AlltoallMethodType.MNNVL:
                 if self.enable_dummy_allreduce:
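When alltoall is active, the tuner hint is the global token count the non-alltoall path would have produced: the sum of per-rank token counts, or local tokens times tp_size when the per-rank list is unavailable. A tiny worked example of that arithmetic (the numbers are made up for illustration):

# Hedged sketch of the derivation above; values are hypothetical.
all_rank_num_tokens = [3, 5, 2, 6]           # tokens on each EP rank this iteration
tuner_num_tokens = sum(all_rank_num_tokens)  # -> 16, what the non-alltoall path would concatenate
# Fallback when the per-rank list is not known: local_num_tokens * tp_size
local_num_tokens, tp_size = 6, 4
fallback_tuner_num_tokens = local_num_tokens * tp_size  # -> 24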
@@ -669,6 +682,8 @@ def forward_chunk(
             use_w4a8_group_scaling=use_w4a8_group_scaling,
             min_latency_mode=False,
             tune_max_num_tokens=self.tune_max_num_tokens,
+            tuner_num_tokens=tuner_num_tokens,
+            tuner_top_k=tuner_top_k,
         )
 
         if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
tensorrt_llm/_torch/utils.py (2 additions, 2 deletions)
@@ -229,7 +229,7 @@ def get_power_of_2_num_tokens_buckets(max_num_tokens) -> List[int]:
         num_token_buckets.append(m)
         m //= 2
 
-    return tuple(num_token_buckets)
+    return tuple(num_token_buckets[::-1])
 
 
 def get_last_power_of_2_num_tokens_buckets(max_num_tokens) -> List[int]:
@@ -239,7 +239,7 @@ def get_last_power_of_2_num_tokens_buckets(max_num_tokens) -> List[int]:
     while m >= 1:
         num_token_buckets.append(m)
         m //= 2
-    return tuple(num_token_buckets)
+    return tuple(num_token_buckets[::-1])
 
 
 def fp4_scale_infer_shape(input_shapes: List[List[int]]):
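Both bucket helpers build their list by repeatedly halving, so before this change they returned the token buckets in descending order; appending [::-1] makes the returned tuple ascending. A minimal sketch of the effect, where m0 stands in for the starting bucket size the real helpers derive from max_num_tokens:

def power_of_2_buckets(m0: int):
    # Illustrative stand-in for the halving loop in the two utils.py helpers.
    num_token_buckets = []
    m = m0
    while m >= 1:
        num_token_buckets.append(m)
        m //= 2
    # Before the change: (8, 4, 2, 1) for m0 = 8; after: (1, 2, 4, 8).
    return tuple(num_token_buckets[::-1])

assert power_of_2_buckets(8) == (1, 2, 4, 8)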