Merged
142 changes: 125 additions & 17 deletions benchmarks/bench_trtllm_gen_fused_moe_autotuner.py
@@ -8,13 +8,109 @@
fp4_quantize,
mxfp8_quantize,
)
from flashinfer.fused_moe import trtllm_fp4_block_scale_moe
from flashinfer.fused_moe import (
trtllm_fp4_block_scale_moe,
trtllm_fp8_per_tensor_scale_moe,
)
from flashinfer.autotuner import autotune
from flashinfer.testing.utils import bench_gpu_time
from flashinfer.utils import device_support_pdl

FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
FLOAT4_E2M1_MAX = 6.0


def fp8_quantize(x):
max = x.float().abs().nan_to_num().max()
scale = FLOAT8_E4M3_MAX / max
x = (x * scale).to(torch.float8_e4m3fn)
return x, 1.0 / scale
Comment on lines +23 to +27
Contributor

⚠️ Potential issue | 🔴 Critical

Guard FP8 quantization against all-zero inputs
All-zero inputs make max zero, scale infinite, and the quantized tensor NaN (0 × ∞), which breaks the benchmark when buffers start cleared. Please handle the zero case before inverting the scale.

 def fp8_quantize(x):
-    max = x.float().abs().nan_to_num().max()
-    scale = FLOAT8_E4M3_MAX / max
-    x = (x * scale).to(torch.float8_e4m3fn)
-    return x, 1.0 / scale
+    max_val = x.float().abs().nan_to_num().max()
+    if max_val == 0:
+        return torch.zeros_like(x, dtype=torch.float8_e4m3fn), 1.0
+    scale = FLOAT8_E4M3_MAX / max_val
+    quantized = (x * scale).to(torch.float8_e4m3fn)
+    return quantized, 1.0 / scale
🤖 Prompt for AI Agents
In benchmarks/bench_trtllm_gen_fused_moe_autotuner.py around lines 21 to 25, the
fp8_quantize function divides by max which can be zero for all-zero inputs;
guard against that by computing max = x.float().abs().nan_to_num().max(), then
check if max == 0 (or torch.isclose(max, torch.tensor(0., device=max.device)))
before inverting it; if it is zero, return the input cast to torch.float8_e4m3fn
(or an all-zero tensor of the same shape) and a safe inverse scale (e.g. 1.0),
otherwise compute scale = FLOAT8_E4M3_MAX / max and proceed with quantization
and return x.to(torch.float8_e4m3fn) and 1.0/scale.


def bench_trtllm_gen_fused_moe_autotuner(

def bench_trtllm_gen_fused_moe_autotuner_fp8(
tune_max_num_tokens: Optional[int],
quant_mode: Literal["Fp8-Per-Tensor"],
Contributor

⚠️ Potential issue | 🟡 Minor

Remove unused parameter.

The quant_mode parameter is not used within the function body. If it's intended for future use or external validation, consider adding a comment explaining its purpose.

Apply this diff if the parameter is not needed:

 def bench_trtllm_gen_fused_moe_autotuner_fp8(
     tune_max_num_tokens: Optional[int],
-    quant_mode: Literal["Fp8-Per-Tensor"],
     num_tokens: int,
     num_experts: int,
     hidden_size: int,
     intermediate_size: int,
     top_k: int,
     warmups: int,
     iterations: int,
 ):
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
quant_mode: Literal["Fp8-Per-Tensor"],
def bench_trtllm_gen_fused_moe_autotuner_fp8(
tune_max_num_tokens: Optional[int],
num_tokens: int,
num_experts: int,
hidden_size: int,
intermediate_size: int,
top_k: int,
warmups: int,
iterations: int,
):
🧰 Tools
🪛 Ruff (0.14.3)

29-29: Unused function argument: quant_mode

(ARG001)

🤖 Prompt for AI Agents
In benchmarks/bench_trtllm_gen_fused_moe_autotuner.py around line 29, the
function signature includes an unused parameter quant_mode:
Literal["Fp8-Per-Tensor"]; remove this parameter from the signature and any
references to it, or if it is intentionally reserved for future use, keep it but
add a clear comment above the parameter explaining its purpose and why it is
unused (e.g., "reserved for future quantization modes"); update any call sites
if you remove it to avoid breaking callers.

num_tokens: int,
num_experts: int,
hidden_size: int,
intermediate_size: int,
top_k: int,
warmups: int,
iterations: int,
):
device = torch.device("cuda:0")
enable_pdl = device_support_pdl(device)
routing_logits = torch.rand(num_tokens, num_experts, device=device).to(
torch.bfloat16
)
hidden_states = torch.randn(num_tokens, hidden_size, device=device).to(
torch.bfloat16
)
w13 = torch.randn(
num_experts, intermediate_size * 2, hidden_size, device=device
).to(torch.bfloat16)
w2 = torch.randn(num_experts, hidden_size, intermediate_size, device=device).to(
torch.bfloat16
)

hidden_states, hidden_states_scale = fp8_quantize(hidden_states)
w13, w13_scale = fp8_quantize(w13)
w2, w2_scale = fp8_quantize(w2)

output1_scale_scalar = torch.tensor(
[hidden_states_scale * w13_scale] * num_experts, device=device
)
output1_scales_gate_scalar = torch.ones(
num_experts, device=device, dtype=torch.float32
)
output2_scale_scalar = torch.tensor(
[hidden_states_scale * w2_scale] * num_experts, device=device
)
Comment on lines +60 to +68
Contributor

⚠️ Potential issue | 🔴 Critical

Construct FP8 scale vectors without CPU conversion errors.

Lines 60-68 call torch.tensor([...], device=device) on CUDA scalars, which raises TypeError: can't convert CUDA tensor to numpy(). That stops the FP8 path before benchmarking. Build the vectors on device without Python lists.

Apply this diff to keep the scales on CUDA:

-    output1_scale_scalar = torch.tensor(
-        [hidden_states_scale * w13_scale] * num_experts, device=device
-    )
+    scale_prod_1 = (hidden_states_scale * w13_scale).item()
+    output1_scale_scalar = torch.full(
+        (num_experts,),
+        scale_prod_1,
+        device=device,
+        dtype=torch.float32,
+    )
@@
-    output2_scale_scalar = torch.tensor(
-        [hidden_states_scale * w2_scale] * num_experts, device=device
-    )
+    scale_prod_2 = (hidden_states_scale * w2_scale).item()
+    output2_scale_scalar = torch.full(
+        (num_experts,),
+        scale_prod_2,
+        device=device,
+        dtype=torch.float32,
+    )
🤖 Prompt for AI Agents
In benchmarks/bench_trtllm_gen_fused_moe_autotuner.py around lines 60 to 68, the
code creates FP8 scale vectors using Python lists of CUDA scalars which triggers
"can't convert CUDA tensor to numpy()" on CUDA; replace those list constructions
with device-native tensor factories (e.g., use torch.full or torch.ones with
shape (num_experts,) and the desired dtype/device) to produce
output1_scale_scalar and output2_scale_scalar directly on the CUDA device (and
keep output1_scales_gate_scalar as torch.ones on device with correct dtype).
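
As a minimal sketch of the device-native construction suggested above (the helper name is made up for illustration): .item() pulls the product of the two 0-dim CUDA scale tensors back to a host float, and torch.full then builds each per-expert vector directly on the target device, avoiding the list-of-CUDA-tensors path that torch.tensor() rejects.

import torch

def make_fp8_scale_vectors(hidden_states_scale, w13_scale, w2_scale, num_experts, device):
    # hidden_states_scale / w13_scale / w2_scale are the 0-dim CUDA tensors
    # returned by fp8_quantize; .item() converts their products to host floats.
    output1_scale_scalar = torch.full(
        (num_experts,), (hidden_states_scale * w13_scale).item(),
        device=device, dtype=torch.float32,
    )
    output1_scales_gate_scalar = torch.ones(num_experts, device=device, dtype=torch.float32)
    output2_scale_scalar = torch.full(
        (num_experts,), (hidden_states_scale * w2_scale).item(),
        device=device, dtype=torch.float32,
    )
    return output1_scale_scalar, output1_scales_gate_scalar, output2_scale_scalar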


fn = lambda: trtllm_fp8_per_tensor_scale_moe(
routing_logits,
None, # routing_bias
hidden_states,
w13,
output1_scale_scalar,
output1_scales_gate_scalar,
w2,
output2_scale_scalar,
num_experts,
top_k,
None, # n_group
None, # topk_group
intermediate_size,
0, # local_expert_offset
num_experts,
1.0, # routed_scaling_factor
False, # use_routing_scales_on_input
None,
RoutingMethodType.TopK.value,
enable_pdl,
num_tokens if tune_max_num_tokens is None else tune_max_num_tokens,
)

def bench(do_autotune):
with autotune(do_autotune):
fn()
ms_list = bench_gpu_time(
fn,
dry_run_iters=warmups,
repeat_iters=iterations,
)
median_ms = np.median(ms_list)
return median_ms

ms = bench(do_autotune=False)
ms_tuned = bench(do_autotune=True)
print(
f"num tokens: {num_tokens}, num experts: {num_experts}, hidden size: {hidden_size}, intermediate size: {intermediate_size}, top k: {top_k}"
)
print(f"No autotune: {ms:.3f} ms; with autotune: {ms_tuned:.3f} ms")


def bench_trtllm_gen_fused_moe_autotuner_fp4(
tune_max_num_tokens: Optional[int],
quant_mode: Literal["NvFP4xNvFP4", "MxFP4xMxFP8", "MxFP4xBf16"],
num_tokens: int,
@@ -143,12 +239,11 @@ def bench_trtllm_gen_fused_moe_autotuner(
)

def bench(do_autotune):
# warmup
with autotune(do_autotune):
for _ in range(warmups):
fn()
fn()
ms_list = bench_gpu_time(
fn,
dry_run_iters=warmups,
repeat_iters=iterations,
)
median_ms = np.median(ms_list)
@@ -168,7 +263,7 @@ def bench(do_autotune):
"--quant-mode",
type=str,
default="MxFP4xMxFP8",
choices=["NvFP4xNvFP4", "MxFP4xMxFP8", "MxFP4xBf16"],
choices=["NvFP4xNvFP4", "MxFP4xMxFP8", "MxFP4xBf16", "Fp8-Per-Tensor"],
help="Quantization mode",
)
parser.add_argument("--num-tokens", type=int, default=512, help="Number of tokens")
@@ -193,14 +288,27 @@ def bench(do_autotune):
"--iterations", type=int, default=100, help="Number of benchmark iterations"
)
args = parser.parse_args()
bench_trtllm_gen_fused_moe_autotuner(
args.tune_max_num_tokens,
args.quant_mode,
args.num_tokens,
args.num_experts,
args.hidden_size,
args.intermediate_size,
args.top_k,
args.warmups,
args.iterations,
)
if args.quant_mode == "Fp8-Per-Tensor":
bench_trtllm_gen_fused_moe_autotuner_fp8(
args.tune_max_num_tokens,
args.quant_mode,
args.num_tokens,
args.num_experts,
args.hidden_size,
args.intermediate_size,
args.top_k,
args.warmups,
args.iterations,
)
else:
bench_trtllm_gen_fused_moe_autotuner_fp4(
args.tune_max_num_tokens,
args.quant_mode,
args.num_tokens,
args.num_experts,
args.hidden_size,
args.intermediate_size,
args.top_k,
args.warmups,
args.iterations,
)
19 changes: 16 additions & 3 deletions csrc/trtllm_batched_gemm_runner.cu
@@ -144,6 +144,10 @@ size_t TrtllmGenBatchedGemmRunner::getWorkspaceSizeInBytes(
gemmData.mProblemDimensions.mWorldSize = 1;
gemmData.mProblemDimensions.mMaxNumCtasInTokenDim = maxNumCtasInBatchDim;

gemmData.mProblemDimensions.mValidM = gemmData.mProblemDimensions.mM;
gemmData.mProblemDimensions.mValidN = gemmData.mProblemDimensions.mN;
gemmData.mProblemDimensions.mValidK = gemmData.mProblemDimensions.mK;

auto bmm = BatchedGemmInterface();

auto const configs = bmm.getBatchedGemmConfigs();
@@ -239,6 +243,10 @@ void TrtllmGenBatchedGemmRunner::run(
int32_t multiProcessorCount;
cudaDeviceGetAttribute(&multiProcessorCount, cudaDevAttrMultiProcessorCount, device);

gemmData.mProblemDimensions.mValidM = gemmData.mProblemDimensions.mM;
gemmData.mProblemDimensions.mValidN = gemmData.mProblemDimensions.mN;
gemmData.mProblemDimensions.mValidK = gemmData.mProblemDimensions.mK;

// FIXME once we start using all-reduce in the epilogue of the bmm this can be moved elsewhere
bmm.runInitBeforeWorldSync(config, gemmData, static_cast<void*>(stream));

@@ -327,6 +335,10 @@ std::vector<int64_t> TrtllmGenBatchedGemmRunner::getValidConfigIndices(
gemmData.mProblemDimensions.mWorldSize = 1;
gemmData.mProblemDimensions.mMaxNumCtasInTokenDim = maxNumCtasInBatchDim;

gemmData.mProblemDimensions.mValidM = gemmData.mProblemDimensions.mM;
gemmData.mProblemDimensions.mValidN = gemmData.mProblemDimensions.mN;
gemmData.mProblemDimensions.mValidK = gemmData.mProblemDimensions.mK;

auto cmpFunc = [&configs, &gemmData, &bmm, &multiProcessorCount](int64_t idx0, int64_t idx1) {
auto const& optionsA = configs[idx0].mOptions;
auto const& optionsB = configs[idx1].mOptions;
@@ -387,8 +399,7 @@ std::vector<int64_t> TrtllmGenBatchedGemmRunner::getValidConfigIndices(
// Filter out invalid configs.
std::vector<int64_t> validConfigIndices;
for (auto const& configIndex : prioritizedIndices) {
auto const& config = configs[configIndex];
auto isValidConfig = bmm.isValidConfig(config, gemmData);
auto isValidConfig = bmm.isValidConfig(configs[configIndex], gemmData);
if (isValidConfig) {
validConfigIndices.push_back(configIndex);
}
@@ -435,7 +446,9 @@ bool TrtllmGenBatchedGemmRunner::isValidConfigIndex(int32_t configIndex, int32_t

auto const& config = configs[configIndex];

return bmm.isValidConfig(config, gemmData);
// FIXME: temporarily disable split-k as renormalize routing plus expert number 256 failed in
// trtllm-gen ac83afb
return bmm.isValidConfig(config, gemmData) && config.mOptions.mClusterDimZ == 1;
}

} // namespace kernels
42 changes: 34 additions & 8 deletions csrc/trtllm_fused_moe_kernel_launcher.cu
@@ -63,13 +63,22 @@ std::set<int32_t> computeSelectedTileN(std::vector<int32_t> const& supported_til
int64_t const num_tokens, int64_t const top_k,
int64_t const num_local_experts) {
float const avg_tokens_per_expert = static_cast<float>(num_tokens * top_k) / num_local_experts;
// assume supported_tile_nums is sorted
int32_t tile_tokens_dim = std::clamp(nextPowerOfTwo(avg_tokens_per_expert),
supported_tile_nums.front(), supported_tile_nums.back());

std::set<int32_t> selected_tile_nums = {
std::max(supported_tile_nums.front(), tile_tokens_dim / 2), tile_tokens_dim,
std::min(supported_tile_nums.back(), tile_tokens_dim * 2),
std::min(supported_tile_nums.back(), tile_tokens_dim * 4)};
auto it = std::find(supported_tile_nums.begin(), supported_tile_nums.end(), tile_tokens_dim);

std::set<int32_t> selected_tile_nums;
selected_tile_nums.insert(tile_tokens_dim);
if (std::next(it) != supported_tile_nums.end()) {
selected_tile_nums.insert(*std::next(it));
if (std::next(std::next(it)) != supported_tile_nums.end()) {
selected_tile_nums.insert(*std::next(std::next(it)));
}
}
if (it != supported_tile_nums.begin()) {
selected_tile_nums.insert(*std::prev(it));
}

Comment on lines 67 to 82
Contributor
coderabbitai bot, Nov 3, 2025

⚠️ Potential issue | 🔴 Critical

Guard against missing tile entry before iterating neighbors

tile_tokens_dim is clamped to the numeric range of supported_tile_nums, but the clamped value is not guaranteed to be present in the container. If the closest supported value is different (e.g., the list is {16, 24, 40, 64} and nextPowerOfTwo returns 32), std::find returns end(). The very next statement calls std::next(it), invoking undefined behaviour and potentially crashing the process. Please snap tile_tokens_dim to an actual supported entry (e.g., via std::lower_bound) before walking neighbours.

-  auto it = std::find(supported_tile_nums.begin(), supported_tile_nums.end(), tile_tokens_dim);
-
-  std::set<int32_t> selected_tile_nums;
-  selected_tile_nums.insert(tile_tokens_dim);
+  auto it =
+      std::lower_bound(supported_tile_nums.begin(), supported_tile_nums.end(), tile_tokens_dim);
+  if (it == supported_tile_nums.end()) {
+    it = std::prev(supported_tile_nums.end());
+  }
+  tile_tokens_dim = *it;
+
+  std::set<int32_t> selected_tile_nums;
+  selected_tile_nums.insert(tile_tokens_dim);
   if (std::next(it) != supported_tile_nums.end()) {
     selected_tile_nums.insert(*std::next(it));
     if (std::next(std::next(it)) != supported_tile_nums.end()) {
       selected_tile_nums.insert(*std::next(std::next(it)));
     }
   }
   if (it != supported_tile_nums.begin()) {
     selected_tile_nums.insert(*std::prev(it));
   }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
int32_t tile_tokens_dim = std::clamp(nextPowerOfTwo(avg_tokens_per_expert),
supported_tile_nums.front(), supported_tile_nums.back());
std::set<int32_t> selected_tile_nums = {
std::max(supported_tile_nums.front(), tile_tokens_dim / 2), tile_tokens_dim,
std::min(supported_tile_nums.back(), tile_tokens_dim * 2),
std::min(supported_tile_nums.back(), tile_tokens_dim * 4)};
auto it = std::find(supported_tile_nums.begin(), supported_tile_nums.end(), tile_tokens_dim);
std::set<int32_t> selected_tile_nums;
selected_tile_nums.insert(tile_tokens_dim);
if (std::next(it) != supported_tile_nums.end()) {
selected_tile_nums.insert(*std::next(it));
if (std::next(std::next(it)) != supported_tile_nums.end()) {
selected_tile_nums.insert(*std::next(std::next(it)));
}
}
if (it != supported_tile_nums.begin()) {
selected_tile_nums.insert(*std::prev(it));
}
int32_t tile_tokens_dim = std::clamp(nextPowerOfTwo(avg_tokens_per_expert),
supported_tile_nums.front(), supported_tile_nums.back());
auto it =
std::lower_bound(supported_tile_nums.begin(), supported_tile_nums.end(), tile_tokens_dim);
if (it == supported_tile_nums.end()) {
it = std::prev(supported_tile_nums.end());
}
tile_tokens_dim = *it;
std::set<int32_t> selected_tile_nums;
selected_tile_nums.insert(tile_tokens_dim);
if (std::next(it) != supported_tile_nums.end()) {
selected_tile_nums.insert(*std::next(it));
if (std::next(std::next(it)) != supported_tile_nums.end()) {
selected_tile_nums.insert(*std::next(std::next(it)));
}
}
if (it != supported_tile_nums.begin()) {
selected_tile_nums.insert(*std::prev(it));
}
🤖 Prompt for AI Agents
In csrc/trtllm_fused_moe_kernel_launcher.cu around lines 67-82, the code uses
std::find on supported_tile_nums with a clamped tile_tokens_dim but that value
may not exist in the container and iterating neighbors from an end() iterator is
undefined; fix by snapping tile_tokens_dim to an actual supported entry via
std::lower_bound: use lower_bound to get the first element >= tile_tokens_dim,
if lower_bound == supported_tile_nums.end() set the iterator to std::prev(end())
(the largest supported value), otherwise keep the lower_bound result (or if you
want the closest neighbor compare the lower_bound and its previous element and
pick the nearer value), then set tile_tokens_dim = *it and proceed to insert *it
and its neighbors into selected_tile_nums while guarding with iterator boundary
checks (only call std::next/std::prev after verifying they are not
end()/begin()).

Collaborator

Good catch. We should make sure supported_tile_nums contains consecutive powers of 2.

Contributor

@IwakuraRein, I appreciate you acknowledging the issue! However, I notice that the current implementation already includes non-power-of-2 values in supported_tile_nums (e.g., 192 at line 256), so strictly enforcing "consecutive powers of 2" would require broader changes.

More importantly, even if we constrain supported_tile_nums to powers of 2, the current code still has a vulnerability: nextPowerOfTwo(avg_tokens_per_expert) can produce a power-of-2 value that's clamped to the numeric range but doesn't exist in the container. For example:

  • If supported_tile_nums = {8, 16, 64, 128} (missing 32)
  • And nextPowerOfTwo(...) returns 32
  • std::clamp(32, 8, 128) → 32
  • std::find returns end() → ❌ UB when calling std::next(it)

Recommendation:

I suggest implementing CodeRabbit's defensive fix using std::lower_bound, which gracefully handles any input:

auto it = std::lower_bound(supported_tile_nums.begin(), supported_tile_nums.end(), tile_tokens_dim);
if (it == supported_tile_nums.end()) {
  it = std::prev(it);  // Use largest supported value
}
tile_tokens_dim = *it;  // Snap to actual supported value

This ensures robustness regardless of the supported tile set. If you want to enforce a design constraint (consecutive powers of 2), consider adding a debug assertion in the callers that construct supported_tile_nums, but the runtime logic should still be defensive.

Would you like me to open an issue to track adding this fix?

return selected_tile_nums;
}
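
To make the behaviour of the lower_bound snap concrete, here is a small Python model of the selection logic discussed above (the production code is the C++ in this file; bisect_left stands in for std::lower_bound, and the nextPowerOfTwo approximation and the function name are made up for the sketch).

import bisect
import math

def select_tile_nums(supported, num_tokens, top_k, num_local_experts):
    avg = (num_tokens * top_k) / num_local_experts
    tile = 1 << max(0, math.ceil(math.log2(max(avg, 1.0))))   # rough nextPowerOfTwo
    tile = min(max(tile, supported[0]), supported[-1])        # clamp to the numeric range
    i = bisect.bisect_left(supported, tile)                   # lower_bound
    if i == len(supported):
        i -= 1                                                # snap to the largest supported entry
    selected = {supported[i]}
    if i + 1 < len(supported):
        selected.add(supported[i + 1])
        if i + 2 < len(supported):
            selected.add(supported[i + 2])
    if i > 0:
        selected.add(supported[i - 1])
    return sorted(selected)

# With a gap in the supported list, 32 snaps up to 64 instead of walking off std::find's end():
print(select_tile_nums([8, 16, 64, 128], num_tokens=512, top_k=8, num_local_experts=128))
# -> [16, 64, 128]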
@@ -369,7 +378,7 @@ void trtllm_fp8_per_tensor_scale_moe(
auto const hidden_size = hidden_states.size(1);
bool mUseDeepSeekFp8{false}; // FP8 per-tensor doesn't use DeepSeek FP8

std::vector<int32_t> mSupportedTileN = {8, 16, 32, 64, 128};
std::vector<int32_t> mSupportedTileN = {8, 16, 32, 64, 128, 192, 256};
std::set<int32_t> selected_tile_nums =
computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts);

@@ -718,7 +727,7 @@ void trtllm_fp8_block_scale_moe(
auto const num_tokens = hidden_states.size(0);
auto const hidden_size = hidden_states.size(1);

std::vector<int32_t> mSupportedTileN = {8, 16, 32, 64};
std::vector<int32_t> mSupportedTileN = {8, 16, 32, 64, 128};
std::set<int32_t> selected_tile_nums =
computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts);

@@ -1228,6 +1237,11 @@ Array<Tensor> trtllm_fp4_block_scale_moe(
if (mDtypeAct != btg::Dtype::Bfloat16) {
mSupportedTileN.push_back(128);
}
if ((mDtypeAct == btg::Dtype::MxE4m3 && mDtypeWeights == btg::Dtype::MxE2m1) ||
(mDtypeAct == btg::Dtype::E2m1 && mDtypeWeights == btg::Dtype::E2m1)) {
// MxFP4 x MxFP4 or NvFP4 x NvFP4
mSupportedTileN.push_back(256);
}
std::set<int32_t> selected_tile_nums =
computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts);
// Build runners for all supported tile sizes
@@ -1305,8 +1319,20 @@ Array<Array<int64_t>> trtllm_get_valid_moe_configs(
bool is_fp8_per_tensor =
dtype_weights == btg::Dtype::E4m3 && dtype_act == btg::Dtype::E4m3 && !useDeepSeekFp8;

if (is_fp4_without_bf16_act || is_fp8_per_tensor) {
if (useDeepSeekFp8) {
supported_tile_nums.push_back(128);
} else if (is_fp8_per_tensor) {
supported_tile_nums.push_back(128);
supported_tile_nums.push_back(192);
supported_tile_nums.push_back(256);
} else if (is_fp4_without_bf16_act) {
supported_tile_nums.push_back(128);
}

if ((dtype_act == btg::Dtype::MxE4m3 && dtype_weights == btg::Dtype::MxE2m1) ||
(dtype_act == btg::Dtype::E2m1 && dtype_weights == btg::Dtype::E2m1)) {
// MxFP4 x MxFP4 or NvFP4 x NvFP4
supported_tile_nums.push_back(256);
}
std::set<int32_t> selected_tile_nums =
computeSelectedTileN(supported_tile_nums, num_tokens, top_k, num_local_experts);
38 changes: 30 additions & 8 deletions csrc/trtllm_fused_moe_routing_deepseek.cu
@@ -392,7 +392,14 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts)
// Compute the runtime config for projections
// Whether or not an expert is local is taken into account when smemExpertCount is computed
// so we do not need to take it into account here.
const int32_t numCta = divUpLog2<int32_t>(count, params.mPaddingLog2);

int32_t numCta;
if constexpr (KernelParams::isPow2) {
numCta = divUpLog2<int32_t>(count, params.mPaddingLog2);
} else {
numCta = divUpTileN<int32_t>(count, params.mTileTokensDim);
}

int32_t ctaOffset;
int32_t numNonExitingCtas;
Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas);
Expand All @@ -401,14 +408,31 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts)
const int32_t localExpertIdx =
(threadIdx.x - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2;
params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx;
params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] =
min(mulLog2<int32_t>(ctaOffset + cta + 1, params.mPaddingLog2),
mulLog2<int32_t>(ctaOffset, params.mPaddingLog2) + count);
int32_t mnLimit1;
int32_t mnLimit2;
if constexpr (KernelParams::isPow2) {
mnLimit1 = mulLog2<int32_t>(ctaOffset + cta + 1, params.mPaddingLog2);
mnLimit2 = mulLog2<int32_t>(ctaOffset, params.mPaddingLog2) + count;
} else {
mnLimit1 = mulTileN<int32_t>(ctaOffset + cta + 1, params.mTileTokensDim);
mnLimit2 = mulTileN<int32_t>(ctaOffset, params.mTileTokensDim) + count;
}
params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = min(mnLimit1, mnLimit2);
}

// get the padded offset associated with this expert
const int32_t offset = mulLog2<int32_t>(ctaOffset, params.mPaddingLog2);
const int32_t permutedIdxSize = mulLog2<int32_t>(numNonExitingCtas, params.mPaddingLog2);
int32_t offset;
if constexpr (KernelParams::isPow2) {
offset = mulLog2<int32_t>(ctaOffset, params.mPaddingLog2);
} else {
offset = mulTileN<int32_t>(ctaOffset, params.mTileTokensDim);
}
int32_t permutedIdxSize;
if constexpr (KernelParams::isPow2) {
permutedIdxSize = mulLog2<int32_t>(numNonExitingCtas, params.mPaddingLog2);
} else {
permutedIdxSize = mulTileN<int32_t>(numNonExitingCtas, params.mTileTokensDim);
}

// write out padded count
if (gridBlockIdx == 0 && warpIdx == NumThreads / WarpSize - 1 && cute::elect_one_sync()) {
@@ -542,8 +566,6 @@ void runImpl(Data& data, void* stream) {
}
FLASHINFER_CHECK(data.mNumExperts % 4 == 0,
"Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts);
FLASHINFER_CHECK(data.mPaddingLog2 < 8, "Routing kernel expects padding log2 < 8, got %d",
data.mPaddingLog2);

int const numBlocks = data.mNumTokens;
int const numThreadsHist = getMaxNumExperts(data.mNumExperts);