Merged

31 commits
fda3ec0
Using ActivationType instead of GatedActType, added compiled kernels,…
amitz-nv Feb 1, 2026
9e7d911
Add actType and eltwiseActType to 'no kernel found' message, move is_…
amitz-nv Feb 1, 2026
ce704d1
Update remaining GatedActType uses to ActivationType, remove GatedAct…
amitz-nv Feb 1, 2026
954eaeb
Use ActivationType in benchmarks, add missing activation_type argument
amitz-nv Feb 1, 2026
158b8fd
Minor fixes
amitz-nv Feb 1, 2026
3a6b9f4
Fix activation_type default value to Swiglu on trtllm_fp4_block_scale…
amitz-nv Feb 1, 2026
02a2502
Minor improvement
amitz-nv Feb 1, 2026
7b17ca1
Support non-gated activation in NVFP4 block scale MoE
amitz-nv Feb 1, 2026
5d0fa51
Rename useShuffledMatrixA to useShuffledMatrix (remove the 'A' suffix)
amitz-nv Feb 1, 2026
b74a7f1
Add FP4_NVFP4_NVFP4 parameterization to test_llama4_routing, update t…
amitz-nv Feb 1, 2026
52e0828
Increase supported topK and num experts in deepseek routing for nemotron
amitz-nv Feb 1, 2026
23348b2
Commit more files for increase supported topK and num experts in deep…
amitz-nv Feb 1, 2026
62d0489
Fix formatting
amitz-nv Feb 1, 2026
6c0409a
Change TODO to comment
amitz-nv Feb 1, 2026
ea95fb0
Change default activation_type to Swiglu
amitz-nv Feb 1, 2026
b7bbb7f
Restore intermediate size factor of 2 for gated activation in getWork…
amitz-nv Feb 1, 2026
8204fb5
Formatting fixes
amitz-nv Feb 1, 2026
3e0b77c
Treat SwigluBias as gated activation
amitz-nv Feb 1, 2026
cbf66c5
Fix use of ActivationType enum in CLI
amitz-nv Feb 1, 2026
e370ab2
Fix activation-type command line argument handling in benchmarks
amitz-nv Feb 1, 2026
c114994
Fix choices of activation-type command line argument handling in benc…
amitz-nv Feb 1, 2026
1b0b5f7
GEMM (non batched) still has mUseShuffledMatrixA member (with 'A' suf…
amitz-nv Feb 1, 2026
bf88c7b
Update bench_trtllm_gen_fused_moe_autotuner.py to support more activa…
amitz-nv Feb 1, 2026
f5ac485
Revert activation_Type check in bench_trtllm_gen_fused_moe_autotuner.…
amitz-nv Feb 1, 2026
370579c
Include activation type in results in benchmarks/routings/moe.py
amitz-nv Feb 1, 2026
f7c2df5
Remove bad num experts check in csrc/trtllm_fused_moe_routing_deepsee…
amitz-nv Feb 2, 2026
9b69e42
Skip test cases with unnecessary parameterization combinations
amitz-nv Feb 2, 2026
b4edcfe
Fix ignoring compatible_activation_types in test when it's not defined
amitz-nv Feb 2, 2026
d5887ab
Fix data.mTopK value check in deepseek routing according to the relev…
amitz-nv Feb 2, 2026
67c0ea4
Add topK<=numExperts check to deepseek routing
amitz-nv Feb 3, 2026
df1ae03
Minor fix of passing activation_type in test_trtllm_gen_fused_moe.py
amitz-nv Feb 3, 2026
26 changes: 24 additions & 2 deletions benchmarks/bench_trtllm_gen_fused_moe_autotuner.py
@@ -4,7 +4,7 @@
import numpy as np
from flashinfer import (
RoutingMethodType,
GatedActType,
ActivationType,
fp4_quantize,
mxfp8_quantize,
)
@@ -17,6 +17,7 @@
from flashinfer.autotuner import autotune
from flashinfer.testing.utils import bench_gpu_time
from flashinfer.utils import device_support_pdl
from routines.flashinfer_benchmark_utils import enum_type

FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
FLOAT4_E2M1_MAX = 6.0
@@ -39,6 +40,7 @@ def bench_trtllm_gen_fused_moe_autotuner_fp8(
top_k: int,
warmups: int,
iterations: int,
activation_type: ActivationType,
):
device = torch.device("cuda:0")
enable_pdl = device_support_pdl(device)
@@ -97,6 +99,10 @@ def bench_trtllm_gen_fused_moe_autotuner_fp8(
)

if is_block_scale:
if activation_type != ActivationType.Swiglu:
raise ValueError(
"Only Swiglu activation is supported for FP8 block scale MoE."
)
fn = lambda: trtllm_fp8_block_scale_moe(
routing_logits,
routing_bias,
@@ -144,6 +150,7 @@ def bench_trtllm_gen_fused_moe_autotuner_fp8(
RoutingMethodType.TopK.value,
enable_pdl,
num_tokens if tune_max_num_tokens is None else tune_max_num_tokens,
activation_type.value,
)

def bench(do_autotune):
@@ -175,6 +182,7 @@ def bench_trtllm_gen_fused_moe_autotuner_fp4(
top_k: int,
warmups: int,
iterations: int,
activation_type: ActivationType,
):
device = torch.device("cuda:0")
enable_pdl = device_support_pdl(device)
@@ -234,6 +242,10 @@ def bench_trtllm_gen_fused_moe_autotuner_fp4(
w13_global_scale = 1.0 / 448.0 / 6.0
w2_global_scale = 1.0 / 448.0 / 6.0
else:
if activation_type == ActivationType.Relu2:
raise ValueError(
"Relu2 activation is supported for FP4 only with 'NvFP4xNvFP4' quant mode"
)
w13, w13_scale = fp4_quantize(
w13, torch.tensor([1.0], device=device), sf_vec_size=32, sf_use_ue8m0=True
)
@@ -288,7 +300,7 @@ def bench_trtllm_gen_fused_moe_autotuner_fp4(
RoutingMethodType.Renormalize.value,
True,
enable_pdl,
GatedActType.SwiGlu.value, # gated_act_type
activation_type.value, # act_type
None,
num_tokens if tune_max_num_tokens is None else tune_max_num_tokens,
)
@@ -348,6 +360,14 @@ def bench(do_autotune):
parser.add_argument(
"--iterations", type=int, default=100, help="Number of benchmark iterations"
)
parser.add_argument(
"--activation-type",
type=enum_type(ActivationType),
metavar=str([e.name for e in ActivationType]),
required=False,
default=ActivationType.Swiglu,
help=f"Type of activation function: {[e.name for e in ActivationType]}",
)
args = parser.parse_args()
if args.quant_mode in ["Fp8-Per-Tensor", "Fp8-Block"]:
bench_trtllm_gen_fused_moe_autotuner_fp8(
@@ -360,6 +380,7 @@
args.top_k,
args.warmups,
args.iterations,
args.activation_type,
)
else:
bench_trtllm_gen_fused_moe_autotuner_fp4(
@@ -372,4 +393,5 @@ def bench(do_autotune):
args.top_k,
args.warmups,
args.iterations,
args.activation_type,
)
16 changes: 16 additions & 0 deletions benchmarks/routines/flashinfer_benchmark_utils.py
@@ -1,3 +1,4 @@
import argparse
import torch

from flashinfer.testing.utils import set_seed
@@ -453,3 +454,18 @@ def filter_backends_by_compute_capability(backends, routine, device):
f"[WARNING] {backend} for routine {routine} is not supported on compute capability {compute_capability}. Skipping."
)
return backends


def enum_type(enum_class):
"""Generic factory for argparse enum types."""

def converter(value):
try:
lower_name_to_member = {m.name.lower(): m for m in enum_class}
return lower_name_to_member[value.lower()]
except KeyError as e:
raise argparse.ArgumentTypeError(
f"Invalid value '{value}'. Must be one of: {', '.join([m.name for m in enum_class])}"
) from e

return converter
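
For context, a minimal, self-contained sketch of how this helper is meant to plug into argparse. The `ActivationType` stand-in below (its member names and values) is only illustrative; the real enum is imported from flashinfer and its values may differ.

```python
import argparse
from enum import Enum


# Illustrative stand-in; the real ActivationType comes from flashinfer and its
# member values may differ.
class ActivationType(Enum):
    Swiglu = 0
    Relu2 = 1


def enum_type(enum_class):
    """Same shape as the helper added in this file: map case-insensitive
    member names to enum members, with a readable argparse error on a miss."""

    def converter(value):
        try:
            lower_name_to_member = {m.name.lower(): m for m in enum_class}
            return lower_name_to_member[value.lower()]
        except KeyError as e:
            raise argparse.ArgumentTypeError(
                f"Invalid value '{value}'. Must be one of: "
                f"{', '.join(m.name for m in enum_class)}"
            ) from e

    return converter


parser = argparse.ArgumentParser()
parser.add_argument(
    "--activation-type",
    type=enum_type(ActivationType),
    default=ActivationType.Swiglu,
    help=f"One of: {[e.name for e in ActivationType]}",
)

# Case-insensitive spellings resolve to the enum member itself.
args = parser.parse_args(["--activation-type", "relu2"])
assert args.activation_type is ActivationType.Relu2
print(args.activation_type.name, args.activation_type.value)  # -> Relu2 1
```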
29 changes: 14 additions & 15 deletions benchmarks/routines/moe.py
@@ -5,6 +5,7 @@
import torch

import flashinfer
from flashinfer import ActivationType
from flashinfer.autotuner import autotune
from flashinfer.fused_moe import (
trtllm_fp4_block_scale_moe,
@@ -21,6 +22,7 @@

from .flashinfer_benchmark_utils import (
dtype_str_to_torch_dtype,
enum_type,
get_device,
print_perf_metrics,
filter_backends_by_compute_capability,
@@ -170,12 +172,12 @@ def parse_moe_args(line, parser):
help="Data type of the weights (before quantization).",
)
parser.add_argument(
"--gated_act",
type=str,
"--activation-type",
type=enum_type(ActivationType),
metavar=str([e.name for e in ActivationType]),
required=False,
default="swiglu",
choices=["swiglu", "geglu"],
help="Type of gated activation function: swiglu | geglu.",
default=ActivationType.Swiglu,
help=f"Type of activation function: {[e.name for e in ActivationType]}",
)
parser.add_argument(
"--autotune",
@@ -242,13 +244,6 @@ def parse_moe_args(line, parser):
}
args.routing_method_type = routing_method_name_to_type[args.routing_method]

# Normalize gated act type (map string to internal int expected by kernels)
gated_act_name_to_type = {
"swiglu": 0,
"geglu": 1,
}
args.gated_act_type = gated_act_name_to_type[args.gated_act]

if args.verbose >= 1:
print(f"[INFO] {args = }")
return args
@@ -451,7 +446,7 @@ def testTrtllmFp4BlockScaleMoe(args):
use_shuffled_weight = args.use_shuffled_weight
weight_layout = args.weight_layout
is_cuda_graph_compatible = not args.no_cuda_graph
gated_act_type = args.gated_act_type
activation_type = args.activation_type
res = []

backends = ["trtllm"]
@@ -610,7 +605,7 @@ def run_fp4_moe(
local_num_experts=local_num_experts,
routed_scaling_factor=routed_scaling_factor,
routing_method_type=routing_method_type,
gated_act_type=gated_act_type,
activation_type=activation_type.value,
do_finalize=True,
)

@@ -715,7 +710,7 @@ def run_fp4_moe(
cur_res["use_routing_scales_on_input"] = args.use_routing_scales_on_input
cur_res["input_dtype"] = input_dtype
cur_res["weight_dtype"] = weight_dtype
cur_res["gated_act"] = args.gated_act
cur_res["activation_type"] = args.activation_type.name
res.append(cur_res)

return res
@@ -1471,6 +1466,7 @@ def run_fp8_per_tensor_moe(
output1_scales_gate_scalar,
gemm2_weights_fp8,
output2_scales_scalar,
activation_type,
):
# Note: FP8 per-tensor MOE expects int64_t for n_group/topk_group, not Optional[int64_t]
# So we convert None to 0 to indicate "no groups" mode
@@ -1493,6 +1489,7 @@
routed_scaling_factor=routed_scaling_factor,
use_routing_scales_on_input=use_routing_scales_on_input,
routing_method_type=routing_method_type,
activation_type=activation_type.value,
)

# Benchmark timing
@@ -1513,6 +1510,7 @@
output1_scales_gate_scalar,
gemm2_weights_fp8,
output2_scales_scalar,
args.activation_type,
),
)

@@ -1564,6 +1562,7 @@ def run_fp8_per_tensor_moe(
cur_res["use_routing_scales_on_input"] = use_routing_scales_on_input
cur_res["input_dtype"] = input_dtype
cur_res["weight_dtype"] = weight_dtype
cur_res["activation_type"] = args.activation_type.name
res.append(cur_res)

return res
10 changes: 8 additions & 2 deletions csrc/trtllm_batched_gemm_runner.cu
@@ -101,14 +101,16 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(
options.mTransposeMmaOutput == mOptions.transposeMmaOutput &&
(!doesRouteImplUseNoRoute(options.mRouteImpl)) == mOptions.routeAct &&
options.mFusedAct == mOptions.fusedAct && options.mIsStaticBatch == mOptions.staticBatch &&
tileSize == mOptions.tileSize &&
options.mUseShuffledMatrix == mOptions.useShuffledMatrixA &&
tileSize == mOptions.tileSize && options.mUseShuffledMatrix == mOptions.useShuffledMatrix &&
options.mLayoutA == mOptions.weightLayout) {
if (options.mFusedAct) {
if (options.mActType != static_cast<batchedGemm::gemmGatedAct::ActType>(mOptions.actType)) {
continue;
}
}
if ((int64_t)options.mEltwiseActType != (int64_t)mOptions.eltwiseActType) {
continue;
}

if (mOptions.transposeMmaOutput && options.mEpilogueTileM == mOptions.epilogueTileM) {
mPassingConfigIndices.push_back(i);
@@ -122,6 +124,8 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(
<< ", mDtypeB: " << tg::dtypeToString(mOptions.dtypeB)
<< ", mDtypeC: " << tg::dtypeToString(mOptions.dtypeC)
<< ", mUseDeepSeekFp8: " << mOptions.deepSeekFp8
<< ", mActType: " << (int64_t)mOptions.actType
<< ", mEltwiseActType: " << (int64_t)mOptions.eltwiseActType
<< ", mTransposeMmaOutput: " << mOptions.transposeMmaOutput
<< ", mRouteAct: " << mOptions.routeAct << ", mFusedAct: " << mOptions.fusedAct
<< ", mIsStaticBatch: " << mOptions.staticBatch << ", mTileSize: " << mOptions.tileSize;
@@ -219,6 +223,8 @@ void TrtllmGenBatchedGemmRunner::run(
gemmData.mInputBuffers.mPtrSfB = mOptions.transposeMmaOutput ? sfA : sfB;
gemmData.mInputBuffers.mPtrScaleC = scaleC;
gemmData.mInputBuffers.mPtrScaleGate = scaleGateC;
// For simplicity pass set scaleAct to scaleGateC
gemmData.mInputBuffers.mPtrScaleAct = scaleGateC;
Comment on lines +226 to +227
Contributor
medium

The comment "For simplicity pass set scaleAct to scaleGateC" suggests this might be a temporary solution. While this might work for the current set of activation functions (e.g., if Relu2 doesn't use mPtrScaleAct), it could lead to latent bugs if new element-wise activations are added that require a specific scaleAct value different from scaleGateC.

To improve clarity and prevent future issues, consider passing scaleAct as a separate parameter to the run function and setting mPtrScaleAct accordingly. If scaleGateC is indeed the correct value for all cases, a more detailed comment explaining why would be beneficial.
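
A minimal sketch of the shape that suggestion could take. `InputBuffers`, `GemmData`, and the `setScales` helper below are hypothetical stand-ins, not the actual flashinfer runner types or API; only the fallback logic is the point.

```cpp
// Hypothetical stand-ins for the runner's buffer structs; only the three
// scale pointers relevant to this comment are modelled.
struct InputBuffers {
  void const* mPtrScaleC = nullptr;
  void const* mPtrScaleGate = nullptr;
  void const* mPtrScaleAct = nullptr;
};

struct GemmData {
  InputBuffers mInputBuffers;
};

// Sketch of the suggestion: accept a dedicated scaleAct and fall back to
// scaleGateC only when the caller does not supply one.
void setScales(GemmData& gemmData, void const* scaleC, void const* scaleGateC,
               void const* scaleAct = nullptr) {
  gemmData.mInputBuffers.mPtrScaleC = scaleC;
  gemmData.mInputBuffers.mPtrScaleGate = scaleGateC;
  gemmData.mInputBuffers.mPtrScaleAct = (scaleAct != nullptr) ? scaleAct : scaleGateC;
}

int main() {
  GemmData data;
  float scaleGate = 1.0f, scaleAct = 0.5f;
  setScales(data, nullptr, &scaleGate, &scaleAct);  // activation gets its own scale
  setScales(data, nullptr, &scaleGate);             // current fallback behaviour kept
  return 0;
}
```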

gemmData.mInputBuffers.mPtrPerTokenSfA =
mOptions.transposeMmaOutput ? perTokensSfB : perTokensSfA;
gemmData.mInputBuffers.mPtrPerTokenSfB =