Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 2 additions & 24 deletions benchmarks/bench_trtllm_gen_fused_moe_autotuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
from flashinfer import (
RoutingMethodType,
ActivationType,
GatedActType,
fp4_quantize,
mxfp8_quantize,
)
Expand All @@ -17,7 +17,6 @@
from flashinfer.autotuner import autotune
from flashinfer.testing.utils import bench_gpu_time
from flashinfer.utils import device_support_pdl
from routines.flashinfer_benchmark_utils import enum_type

FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
FLOAT4_E2M1_MAX = 6.0
Expand All @@ -40,7 +39,6 @@ def bench_trtllm_gen_fused_moe_autotuner_fp8(
top_k: int,
warmups: int,
iterations: int,
activation_type: ActivationType,
):
device = torch.device("cuda:0")
enable_pdl = device_support_pdl(device)
Expand Down Expand Up @@ -99,10 +97,6 @@ def bench_trtllm_gen_fused_moe_autotuner_fp8(
)

if is_block_scale:
if activation_type != ActivationType.Swiglu:
raise ValueError(
"Only Swiglu activation is supported for FP8 block scale MoE."
)
fn = lambda: trtllm_fp8_block_scale_moe(
routing_logits,
routing_bias,
Expand Down Expand Up @@ -150,7 +144,6 @@ def bench_trtllm_gen_fused_moe_autotuner_fp8(
RoutingMethodType.TopK.value,
enable_pdl,
num_tokens if tune_max_num_tokens is None else tune_max_num_tokens,
activation_type.value,
)

def bench(do_autotune):
Expand Down Expand Up @@ -182,7 +175,6 @@ def bench_trtllm_gen_fused_moe_autotuner_fp4(
top_k: int,
warmups: int,
iterations: int,
activation_type: ActivationType,
):
device = torch.device("cuda:0")
enable_pdl = device_support_pdl(device)
Expand Down Expand Up @@ -242,10 +234,6 @@ def bench_trtllm_gen_fused_moe_autotuner_fp4(
w13_global_scale = 1.0 / 448.0 / 6.0
w2_global_scale = 1.0 / 448.0 / 6.0
else:
if activation_type == ActivationType.Relu2:
raise ValueError(
"Relu2 activation is supported for FP4 only with 'NvFP4xNvFP4' quant mode"
)
w13, w13_scale = fp4_quantize(
w13, torch.tensor([1.0], device=device), sf_vec_size=32, sf_use_ue8m0=True
)
Expand Down Expand Up @@ -300,7 +288,7 @@ def bench_trtllm_gen_fused_moe_autotuner_fp4(
RoutingMethodType.Renormalize.value,
True,
enable_pdl,
activation_type.value, # act_type
GatedActType.SwiGlu.value, # gated_act_type
None,
num_tokens if tune_max_num_tokens is None else tune_max_num_tokens,
)
Expand Down Expand Up @@ -360,14 +348,6 @@ def bench(do_autotune):
parser.add_argument(
"--iterations", type=int, default=100, help="Number of benchmark iterations"
)
parser.add_argument(
"--activation-type",
type=enum_type(ActivationType),
metavar=str([e.name for e in ActivationType]),
required=False,
default=ActivationType.Swiglu,
help=f"Type of activation function: {[e.name for e in ActivationType]}",
)
args = parser.parse_args()
if args.quant_mode in ["Fp8-Per-Tensor", "Fp8-Block"]:
bench_trtllm_gen_fused_moe_autotuner_fp8(
Expand All @@ -380,7 +360,6 @@ def bench(do_autotune):
args.top_k,
args.warmups,
args.iterations,
args.activation_type,
)
else:
bench_trtllm_gen_fused_moe_autotuner_fp4(
Expand All @@ -393,5 +372,4 @@ def bench(do_autotune):
args.top_k,
args.warmups,
args.iterations,
args.activation_type,
)
16 changes: 0 additions & 16 deletions benchmarks/routines/flashinfer_benchmark_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import argparse
import torch

from flashinfer.testing.utils import set_seed
Expand Down Expand Up @@ -454,18 +453,3 @@ def filter_backends_by_compute_capability(backends, routine, device):
f"[WARNING] {backend} for routine {routine} is not supported on compute capability {compute_capability}. Skipping."
)
return backends


def enum_type(enum_class):
    """Build an argparse ``type=`` converter for members of ``enum_class``.

    The returned callable maps a command-line string to the member of
    ``enum_class`` whose ``name`` matches it, ignoring case.

    Args:
        enum_class: An iterable of members exposing a ``name`` attribute
            (e.g. an ``enum.Enum`` subclass).

    Returns:
        A one-argument callable suitable for ``parser.add_argument(type=...)``.
        It raises ``argparse.ArgumentTypeError`` when the given string does not
        match any member name.
    """
    # The lookup table depends only on enum_class, so build it once here
    # instead of rebuilding it on every conversion.
    members_by_lower_name = {member.name.lower(): member for member in enum_class}

    def converter(value):
        """Return the member whose name equals ``value`` case-insensitively."""
        key = value.lower()
        # Keep the try body to the single lookup that can raise KeyError.
        try:
            return members_by_lower_name[key]
        except KeyError as e:
            raise argparse.ArgumentTypeError(
                f"Invalid value '{value}'. Must be one of: {', '.join([m.name for m in enum_class])}"
            ) from e

    return converter
29 changes: 15 additions & 14 deletions benchmarks/routines/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import torch

import flashinfer
from flashinfer import ActivationType
from flashinfer.autotuner import autotune
from flashinfer.fused_moe import (
trtllm_fp4_block_scale_moe,
Expand All @@ -22,7 +21,6 @@

from .flashinfer_benchmark_utils import (
dtype_str_to_torch_dtype,
enum_type,
get_device,
print_perf_metrics,
filter_backends_by_compute_capability,
Expand Down Expand Up @@ -172,12 +170,12 @@ def parse_moe_args(line, parser):
help="Data type of the weights (before quantization).",
)
parser.add_argument(
"--activation-type",
type=enum_type(ActivationType),
metavar=str([e.name for e in ActivationType]),
"--gated_act",
type=str,
required=False,
default=ActivationType.Swiglu,
help=f"Type of activation function: {[e.name for e in ActivationType]}",
default="swiglu",
choices=["swiglu", "geglu"],
help="Type of gated activation function: swiglu | geglu.",
)
parser.add_argument(
"--autotune",
Expand Down Expand Up @@ -244,6 +242,13 @@ def parse_moe_args(line, parser):
}
args.routing_method_type = routing_method_name_to_type[args.routing_method]

# Normalize gated act type (map string to internal int expected by kernels)
gated_act_name_to_type = {
"swiglu": 0,
"geglu": 1,
}
args.gated_act_type = gated_act_name_to_type[args.gated_act]

if args.verbose >= 1:
print(f"[INFO] {args = }")
return args
Expand Down Expand Up @@ -446,7 +451,7 @@ def testTrtllmFp4BlockScaleMoe(args):
use_shuffled_weight = args.use_shuffled_weight
weight_layout = args.weight_layout
is_cuda_graph_compatible = not args.no_cuda_graph
activation_type = args.activation_type
gated_act_type = args.gated_act_type
res = []

backends = ["trtllm"]
Expand Down Expand Up @@ -605,7 +610,7 @@ def run_fp4_moe(
local_num_experts=local_num_experts,
routed_scaling_factor=routed_scaling_factor,
routing_method_type=routing_method_type,
activation_type=activation_type.value,
gated_act_type=gated_act_type,
do_finalize=True,
)

Expand Down Expand Up @@ -710,7 +715,7 @@ def run_fp4_moe(
cur_res["use_routing_scales_on_input"] = args.use_routing_scales_on_input
cur_res["input_dtype"] = input_dtype
cur_res["weight_dtype"] = weight_dtype
cur_res["activation_type"] = args.activation_type.name
cur_res["gated_act"] = args.gated_act
res.append(cur_res)

return res
Expand Down Expand Up @@ -1466,7 +1471,6 @@ def run_fp8_per_tensor_moe(
output1_scales_gate_scalar,
gemm2_weights_fp8,
output2_scales_scalar,
activation_type,
):
# Note: FP8 per-tensor MOE expects int64_t for n_group/topk_group, not Optional[int64_t]
# So we convert None to 0 to indicate "no groups" mode
Expand All @@ -1489,7 +1493,6 @@ def run_fp8_per_tensor_moe(
routed_scaling_factor=routed_scaling_factor,
use_routing_scales_on_input=use_routing_scales_on_input,
routing_method_type=routing_method_type,
activation_type=activation_type.value,
)

# Benchmark timing
Expand All @@ -1510,7 +1513,6 @@ def run_fp8_per_tensor_moe(
output1_scales_gate_scalar,
gemm2_weights_fp8,
output2_scales_scalar,
args.activation_type,
),
)

Expand Down Expand Up @@ -1562,7 +1564,6 @@ def run_fp8_per_tensor_moe(
cur_res["use_routing_scales_on_input"] = use_routing_scales_on_input
cur_res["input_dtype"] = input_dtype
cur_res["weight_dtype"] = weight_dtype
cur_res["activation_type"] = args.activation_type.name
res.append(cur_res)

return res
10 changes: 2 additions & 8 deletions csrc/trtllm_batched_gemm_runner.cu
Original file line number Diff line number Diff line change
Expand Up @@ -101,16 +101,14 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(
options.mTransposeMmaOutput == mOptions.transposeMmaOutput &&
(!doesRouteImplUseNoRoute(options.mRouteImpl)) == mOptions.routeAct &&
options.mFusedAct == mOptions.fusedAct && options.mIsStaticBatch == mOptions.staticBatch &&
tileSize == mOptions.tileSize && options.mUseShuffledMatrix == mOptions.useShuffledMatrix &&
tileSize == mOptions.tileSize &&
options.mUseShuffledMatrix == mOptions.useShuffledMatrixA &&
options.mLayoutA == mOptions.weightLayout) {
if (options.mFusedAct) {
if (options.mActType != static_cast<batchedGemm::gemmGatedAct::ActType>(mOptions.actType)) {
continue;
}
}
if ((int64_t)options.mEltwiseActType != (int64_t)mOptions.eltwiseActType) {
continue;
}

if (mOptions.transposeMmaOutput && options.mEpilogueTileM == mOptions.epilogueTileM) {
mPassingConfigIndices.push_back(i);
Expand All @@ -124,8 +122,6 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(
<< ", mDtypeB: " << tg::dtypeToString(mOptions.dtypeB)
<< ", mDtypeC: " << tg::dtypeToString(mOptions.dtypeC)
<< ", mUseDeepSeekFp8: " << mOptions.deepSeekFp8
<< ", mActType: " << (int64_t)mOptions.actType
<< ", mEltwiseActType: " << (int64_t)mOptions.eltwiseActType
<< ", mTransposeMmaOutput: " << mOptions.transposeMmaOutput
<< ", mRouteAct: " << mOptions.routeAct << ", mFusedAct: " << mOptions.fusedAct
<< ", mIsStaticBatch: " << mOptions.staticBatch << ", mTileSize: " << mOptions.tileSize;
Expand Down Expand Up @@ -223,8 +219,6 @@ void TrtllmGenBatchedGemmRunner::run(
gemmData.mInputBuffers.mPtrSfB = mOptions.transposeMmaOutput ? sfA : sfB;
gemmData.mInputBuffers.mPtrScaleC = scaleC;
gemmData.mInputBuffers.mPtrScaleGate = scaleGateC;
// For simplicity pass set scaleAct to scaleGateC
gemmData.mInputBuffers.mPtrScaleAct = scaleGateC;
gemmData.mInputBuffers.mPtrPerTokenSfA =
mOptions.transposeMmaOutput ? perTokensSfB : perTokensSfA;
gemmData.mInputBuffers.mPtrPerTokenSfB =
Expand Down
Loading