Skip to content
7 changes: 2 additions & 5 deletions benchmarks/bench_trtllm_gen_fused_moe_autotuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from flashinfer.fused_moe import trtllm_fp4_block_scale_moe
from flashinfer.autotuner import autotune
from flashinfer.testing.utils import bench_gpu_time
from flashinfer.utils import device_support_pdl, calculate_tile_tokens_dim
from flashinfer.utils import device_support_pdl


def bench_trtllm_gen_fused_moe_autotuner(
Expand Down Expand Up @@ -99,9 +99,6 @@ def bench_trtllm_gen_fused_moe_autotuner(
bias13 = torch.randn(num_experts, intermediate_size * 2, device=device) * 10
bias2 = torch.randn(num_experts, intermediate_size * 2, device=device) * 10

tile_tokens_dim = calculate_tile_tokens_dim(
num_tokens, num_experts, top_k, 64 if quant_mode == "MxFP4xBf16" else 128
)
output1_scale_scalar = torch.tensor(
[hidden_states_global_scale * w13_global_scale] * num_experts, device=device
)
Expand Down Expand Up @@ -136,7 +133,7 @@ def bench_trtllm_gen_fused_moe_autotuner(
0, # local_expert_offset
num_experts,
None, # routed_scaling_factor
tile_tokens_dim,
None, # tile_tokens_dim
RoutingMethodType.Renormalize.value,
True,
enable_pdl,
Expand Down
Loading
Loading