
Commit 0cd4848

add FP8 autotune and update deprecation warning

Signed-off-by: jiahanc <[email protected]>
1 parent c2e4bdd

File tree

3 files changed: +384 -146 lines changed


csrc/trtllm_fused_moe_kernel_launcher.cu

Lines changed: 144 additions & 72 deletions
@@ -59,8 +59,8 @@ inline int32_t nextPowerOfTwo(float value) {
 }
 
 std::set<int32_t> computeSelectedTileN(std::vector<int32_t> const& supported_tile_nums,
-                                       int64_t const num_tokens, int64_t const top_k,
-                                       int64_t const num_local_experts) {
+                                       int64_t const num_tokens, int64_t const top_k,
+                                       int64_t const num_local_experts) {
   float const avg_tokens_per_expert = static_cast<float>(num_tokens * top_k) / num_local_experts;
   int32_t tile_tokens_dim = std::clamp(nextPowerOfTwo(avg_tokens_per_expert),
                                        supported_tile_nums.front(), supported_tile_nums.back());
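Note: for reference, the tile-size heuristic visible in this hunk can be reproduced in isolation. The sketch below is illustrative only: it assumes nextPowerOfTwo rounds a value up to the next power of two and that the supported tile sizes are {8, 16, 32, 64} (as used by the FP8 entry points later in this commit); the rest of computeSelectedTileN, which returns a set of candidate tiles, is not shown in this hunk.

// Standalone sketch of the tile_tokens_dim heuristic above (illustrative only).
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <vector>

// Assumed behavior of nextPowerOfTwo: smallest power of two >= value.
static int32_t nextPowerOfTwoSketch(float value) {
  return static_cast<int32_t>(std::pow(2.0f, std::ceil(std::log2(std::max(value, 1.0f)))));
}

int main() {
  std::vector<int32_t> const supported_tile_nums = {8, 16, 32, 64};
  int64_t const num_tokens = 512, top_k = 8, num_local_experts = 128;

  // Average number of routed tokens per local expert: 512 * 8 / 128 = 32.
  float const avg_tokens_per_expert =
      static_cast<float>(num_tokens * top_k) / num_local_experts;
  // Round up to a power of two, clamped into the supported range [8, 64].
  int32_t const tile_tokens_dim =
      std::clamp(nextPowerOfTwoSketch(avg_tokens_per_expert), supported_tile_nums.front(),
                 supported_tile_nums.back());
  std::cout << "tile_tokens_dim = " << tile_tokens_dim << "\n";  // prints 32
  return 0;
}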
@@ -82,7 +82,9 @@ void trtllm_fp8_per_tensor_scale_moe_launcher(
     int64_t const intermediate_size, int64_t const local_expert_offset,
     int64_t const local_num_experts, Optional<double> const routed_scaling_factor,
     bool const use_routing_scales_on_input, int64_t const tile_tokens_dim,
-    int64_t const routing_method_type, bool enable_pdl) {
+    int64_t const routing_method_type,
+    tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner& moe_runner, int64_t moeConfigIndex,
+    bool enable_pdl) {
   static const std::tuple<int, int> device_props = [hidden_states] {
     int major, minor;
     cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor,
@@ -160,6 +162,7 @@ void trtllm_fp8_per_tensor_scale_moe_launcher(
   } else {
     TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported input dtype for MoE.";
   }
+  args.mDtypeOut = btg::Dtype::Bfloat16;  // Output is always bfloat16 for fp8 per-tensor scale
 
   args.routing_logits = routing_logits.data_ptr();
   auto const routing_bias_dtype =
@@ -194,6 +197,13 @@ void trtllm_fp8_per_tensor_scale_moe_launcher(
   int32_t max_num_padded_tokens =
       tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxPermutedPaddedCount(
           args.num_tokens, top_k, num_experts, tile_tokens_dim);
+  int32_t max_num_padded_tokens_gemm1 =
+      tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount(
+          max_num_padded_tokens, args.intermediate_size, btg::dtypeGetNumBits(args.mDtypeElt));
+  int32_t max_num_padded_tokens_gemm2 =
+      tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount(
+          max_num_padded_tokens, args.hidden_size, btg::dtypeGetNumBits(args.mDtypeOut));
+
   Tensor total_num_padded_tokens = alloc_tensor({1}, dl_int32, routing_logits.device());
   Tensor expanded_idx_to_permuted_idx =
       alloc_tensor({args.num_tokens * args.top_k}, dl_int32, routing_logits.device());
@@ -210,16 +220,17 @@ void trtllm_fp8_per_tensor_scale_moe_launcher(
                    routing_logits.device());
 
   // allocate workspace for activation/gemm/finalize kernels
-  Tensor gemm1_output =
-      alloc_tensor({max_num_padded_tokens, 2 * intermediate_size}, dl_uint8, hidden_states.device());
-  Tensor gemm1_output_scale = alloc_tensor({2 * intermediate_size / 128, max_num_padded_tokens},
-                                           dl_float32, hidden_states.device());
-  Tensor activation_output =
-      alloc_tensor({max_num_padded_tokens, intermediate_size}, dl_uint8, hidden_states.device());
-  Tensor activation_output_scale = alloc_tensor({intermediate_size / 128, max_num_padded_tokens},
-                                                dl_float32, hidden_states.device());
-  Tensor gemm2_output =
-      alloc_tensor({max_num_padded_tokens, args.hidden_size}, dl_bfloat16, hidden_states.device());
+  Tensor gemm1_output = alloc_tensor({max_num_padded_tokens_gemm1, 2 * intermediate_size}, dl_uint8,
+                                     hidden_states.device());
+  Tensor gemm1_output_scale =
+      alloc_tensor({2 * intermediate_size / 128, max_num_padded_tokens_gemm1}, dl_float32,
+                   hidden_states.device());
+  Tensor activation_output = alloc_tensor({max_num_padded_tokens_gemm1, intermediate_size},
+                                          dl_uint8, hidden_states.device());
+  Tensor activation_output_scale = alloc_tensor(
+      {intermediate_size / 128, max_num_padded_tokens_gemm1}, dl_float32, hidden_states.device());
+  Tensor gemm2_output = alloc_tensor({max_num_padded_tokens_gemm2, args.hidden_size}, dl_bfloat16,
+                                     hidden_states.device());
 
   int32_t max_num_ctas = tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxNumCtasInBatchDim(
       args.num_tokens, args.top_k, args.num_experts, tile_tokens_dim);
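Note: to make the resized workspace buffers concrete, the arithmetic below spells out the byte counts behind the shapes in this hunk. Every dimension is an assumed example value (intermediate_size, hidden_size, and the padded token counts come from the model and the routing result at run time); only the shapes and element widths mirror the allocations above.

// Illustrative size arithmetic for the FP8 per-tensor workspace allocations (assumed dimensions).
#include <cstdint>
#include <iostream>

int main() {
  int64_t const intermediate_size = 2048;           // assumed
  int64_t const hidden_size = 4096;                 // assumed
  int64_t const max_num_padded_tokens_gemm1 = 256;  // assumed
  int64_t const max_num_padded_tokens_gemm2 = 256;  // assumed

  // gemm1_output: fp8 stored as uint8, shape [tokens_gemm1, 2 * intermediate_size].
  int64_t const gemm1_output_bytes = max_num_padded_tokens_gemm1 * 2 * intermediate_size;
  // gemm1_output_scale: fp32, one scale per 128-wide block, shape [2 * intermediate_size / 128, tokens_gemm1].
  int64_t const gemm1_scale_bytes =
      (2 * intermediate_size / 128) * max_num_padded_tokens_gemm1 * static_cast<int64_t>(sizeof(float));
  // gemm2_output: bfloat16 (2 bytes each), shape [tokens_gemm2, hidden_size].
  int64_t const gemm2_output_bytes = max_num_padded_tokens_gemm2 * hidden_size * 2;

  std::cout << "gemm1_output:       " << gemm1_output_bytes << " bytes\n"   // 1048576
            << "gemm1_output_scale: " << gemm1_scale_bytes << " bytes\n"    // 32768
            << "gemm2_output:       " << gemm2_output_bytes << " bytes\n";  // 2097152
  return 0;
}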
@@ -289,7 +300,8 @@ void trtllm_fp8_per_tensor_scale_moe_launcher(
 
   // setup workspace
   workspace.total_num_padded_tokens = static_cast<int*>(total_num_padded_tokens.data_ptr());
-  workspace.total_max_padded_tokens = max_num_padded_tokens;
+  workspace.total_max_padded_tokens =
+      std::max(max_num_padded_tokens_gemm1, max_num_padded_tokens_gemm2);
   workspace.ProjUpTileN = tile_tokens_dim;
   workspace.routing_expert_indexes = static_cast<int*>(expert_indexes.data_ptr());
   workspace.permuted_idx_size = static_cast<int*>(total_num_padded_tokens.data_ptr());
@@ -315,13 +327,6 @@ void trtllm_fp8_per_tensor_scale_moe_launcher(
   args.output = output.data_ptr();
   args.output_scale = nullptr;
 
-  tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner moe_runner(
-      args.mDtypeElt, args.mUseDeepSeekFp8, tile_tokens_dim, /*useShuffledMatrixA*/ true);
-
-  auto const moeConfigIndex =
-      moe_runner.getDefaultValidConfigIndex(args.top_k, args.hidden_size, args.intermediate_size,
-                                            args.local_num_experts, args.num_tokens);
-
   auto workspace_sizes = moe_runner.getWorkspaceSizeInBytes(args, moeConfigIndex);
   Tensor workspace_fc1 =
       alloc_tensor({std::get<0>(workspace_sizes)}, dl_int8, hidden_states.device());
@@ -341,16 +346,56 @@ void trtllm_fp8_per_tensor_scale_moe(
     TensorView output2_scales_scalar, TensorView output, int64_t num_experts, int64_t top_k,
     Optional<int64_t> n_group, Optional<int64_t> topk_group, int64_t intermediate_size,
     int64_t local_expert_offset, int64_t local_num_experts, Optional<double> routed_scaling_factor,
-    bool use_routing_scales_on_input, int64_t tile_tokens_dim, int64_t routing_method_type,
-    bool enable_pdl) {
+    bool use_routing_scales_on_input, int64_t routing_method_type, bool enable_pdl,
+    Array<int64_t> config_index) {
   auto dtype = hidden_states.dtype();
   if (dtype == dl_float16 || dtype == dl_bfloat16 || dtype == dl_float8_e4m3fn) {
+    using RunnerType = tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner;
+
+    // Convert PyTorch dtype to TensorRT-LLM dtype
+    btg::Dtype mDtypeElt;
+    if (dtype == dl_float16) {
+      mDtypeElt = btg::Dtype::Fp16;
+    } else if (dtype == dl_bfloat16) {
+      mDtypeElt = btg::Dtype::Bfloat16;
+    } else if (dtype == dl_float8_e4m3fn) {
+      mDtypeElt = btg::Dtype::E4m3;
+    } else {
+      TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported input dtype for MoE.";
+    }
+
+    auto const num_tokens = hidden_states.size(0);
+    auto const hidden_size = hidden_states.size(1);
+    bool mUseDeepSeekFp8{false};  // FP8 per-tensor doesn't use DeepSeek FP8
+
+    std::vector<int32_t> mSupportedTileN = {8, 16, 32, 64};
+    std::set<int32_t> selected_tile_nums =
+        computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts);
+
+    // Build runners for all supported tile sizes
+    std::unordered_map<int32_t, std::unique_ptr<RunnerType>> mRunners;
+    for (int32_t tile_N : selected_tile_nums) {
+      // Always use the two-parameter constructor for consistency
+      mRunners.emplace(tile_N, std::make_unique<RunnerType>(mDtypeElt, mUseDeepSeekFp8, tile_N,
+                                                            /*useShuffledMatrixA*/ true));
+    }
+
+    // moeConfigIndex corresponds to pair (tile_N, config)
+    int64_t tile_N = config_index[0];
+    int64_t config = config_index[1];
+    // Autotuner has requested a default or 'fallback' config index
+    if (tile_N == -1 || config == -1) {
+      tile_N = *selected_tile_nums.begin();
+      config = mRunners[tile_N]->getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size,
+                                                            local_num_experts, num_tokens);
+    }
+
     trtllm_fp8_per_tensor_scale_moe_launcher(
         routing_logits, routing_bias, hidden_states, gemm1_weights, output1_scales_scalar,
         output1_scales_gate_scalar, gemm2_weights, output2_scales_scalar, output, num_experts,
         top_k, n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts,
-        routed_scaling_factor, use_routing_scales_on_input, tile_tokens_dim, routing_method_type,
-        enable_pdl);
+        routed_scaling_factor, use_routing_scales_on_input, tile_N, routing_method_type,
+        *mRunners[tile_N], config, enable_pdl);
   } else {
     TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported input dtype.";
   }
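Note: as a usage sketch, callers of the updated entry point select a kernel configuration through the two-element config_index: {-1, -1} requests the default/fallback path handled above, while a tuned (tile_N, config) pair picks a specific configuration. The snippet below is schematic rather than a complete program: TensorView, Optional, and Array come from the TVM FFI bindings used by this file, and every tensor and scalar argument is a placeholder.

// Schematic caller-side view of the new config_index argument (placeholders throughout).

// Default / fallback: the function picks the smallest selected tile_N and the runner's
// default config for the current problem shape, as in the branch added above.
Array<int64_t> config_index = {-1, -1};

// Tuned: pass a (tile_N, config) pair previously returned by trtllm_get_valid_moe_configs
// and found fastest by the autotuner, e.g. {16, best_config}.
// Array<int64_t> config_index = {16, best_config};

trtllm_fp8_per_tensor_scale_moe(
    routing_logits, routing_bias, hidden_states, gemm1_weights, output1_scales_scalar,
    output1_scales_gate_scalar, gemm2_weights, output2_scales_scalar, output, num_experts, top_k,
    n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts,
    routed_scaling_factor, use_routing_scales_on_input, routing_method_type, enable_pdl,
    config_index);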
@@ -651,16 +696,14 @@ void trtllm_fp8_block_scale_moe_launcher(
       enable_pdl);
 }
 
-void trtllm_fp8_block_scale_moe(TensorView routing_logits, Optional<TensorView> routing_bias,
-                                TensorView hidden_states, TensorView hidden_states_scale,
-                                TensorView gemm1_weights, TensorView gemm1_weights_scale,
-                                TensorView gemm2_weights, TensorView gemm2_weights_scale,
-                                TensorView output, int64_t num_experts, int64_t top_k,
-                                Optional<int64_t> n_group, Optional<int64_t> topk_group,
-                                int64_t intermediate_size, int64_t local_expert_offset,
-                                int64_t local_num_experts, Optional<double> routed_scaling_factor,
-                                int64_t tile_tokens_dim, int64_t routing_method_type,
-                                bool use_shuffled_weight, int64_t weight_layout, bool enable_pdl) {
+void trtllm_fp8_block_scale_moe(
+    TensorView routing_logits, Optional<TensorView> routing_bias, TensorView hidden_states,
+    TensorView hidden_states_scale, TensorView gemm1_weights, TensorView gemm1_weights_scale,
+    TensorView gemm2_weights, TensorView gemm2_weights_scale, TensorView output,
+    int64_t num_experts, int64_t top_k, Optional<int64_t> n_group, Optional<int64_t> topk_group,
+    int64_t intermediate_size, int64_t local_expert_offset, int64_t local_num_experts,
+    Optional<double> routed_scaling_factor, int64_t routing_method_type, bool use_shuffled_weight,
+    int64_t weight_layout, bool enable_pdl, Array<int64_t> config_index) {
   auto dtype = hidden_states.dtype();
   if (dtype == dl_float16 || dtype == dl_bfloat16 || dtype == dl_float8_e4m3fn) {
     using RunnerType = tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner;
@@ -671,24 +714,36 @@ void trtllm_fp8_block_scale_moe(TensorView routing_logits, Optional<TensorView>
     TVM_FFI_ICHECK(0 <= weight_layout && weight_layout <= 2)
         << "the value of weight_layout is not recognized";
 
-    // Properly initialize the runner using make_unique like in the original code
-    auto mRunner = std::make_unique<RunnerType>(
-        mDtypeElt, mUseDeepSeekFp8, tile_tokens_dim, use_shuffled_weight,
-        static_cast<batchedGemm::gemm::MatrixLayout>(weight_layout));
-
-    // Always use fallback config (equivalent to moeConfigIndex == -1 case from original code)
     auto const num_tokens = hidden_states.size(0);
     auto const hidden_size = hidden_states.size(1);
 
-    int64_t moeConfigIndex = mRunner->getDefaultValidConfigIndex(
-        top_k, hidden_size, intermediate_size, local_num_experts, num_tokens);
+    std::vector<int32_t> mSupportedTileN = {8, 16, 32, 64};
+    std::set<int32_t> selected_tile_nums =
+        computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts);
+
+    // Build runners for all supported tile sizes
+    std::unordered_map<int32_t, std::unique_ptr<RunnerType>> mRunners;
+    for (int32_t tile_N : selected_tile_nums) {
+      mRunners.emplace(tile_N, std::make_unique<RunnerType>(
+                                   mDtypeElt, mUseDeepSeekFp8, tile_N, use_shuffled_weight,
+                                   static_cast<batchedGemm::gemm::MatrixLayout>(weight_layout)));
+    }
+
+    // moeConfigIndex corresponds to pair (tile_N, config)
+    int64_t tile_N = config_index[0];
+    int64_t config = config_index[1];
+    // Autotuner has requested a default or 'fallback' config index
+    if (tile_N == -1 || config == -1) {
+      tile_N = *selected_tile_nums.begin();
+      config = mRunners[tile_N]->getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size,
+                                                            local_num_experts, num_tokens);
+    }
 
     trtllm_fp8_block_scale_moe_launcher(
         routing_logits, routing_bias, hidden_states, hidden_states_scale, gemm1_weights,
         gemm1_weights_scale, gemm2_weights, gemm2_weights_scale, output, num_experts, top_k,
         n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts,
-        routed_scaling_factor, tile_tokens_dim, routing_method_type, *mRunner, moeConfigIndex,
-        enable_pdl);
+        routed_scaling_factor, tile_N, routing_method_type, *mRunners[tile_N], config, enable_pdl);
   } else {
     TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported hidden state dtype.";
   }
@@ -1184,29 +1239,29 @@ Array<Tensor> trtllm_fp4_block_scale_moe(
       computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts);
   // Build runners for all supported tile sizes
   std::unordered_map<int32_t, std::unique_ptr<RunnerType>> mRunners;
-  for (int32_t tileN : selected_tile_nums) {
-    mRunners.emplace(tileN,
-                     std::make_unique<RunnerType>(mDtypeAct, mDtypeWeights, mUseDeepSeekFp8, tileN,
+  for (int32_t tile_N : selected_tile_nums) {
+    mRunners.emplace(tile_N,
+                     std::make_unique<RunnerType>(mDtypeAct, mDtypeWeights, mUseDeepSeekFp8, tile_N,
                                                   static_cast<GatedActType>(gated_act_type),
                                                   /*useShuffledMatrixA*/ true));
   }
 
-  // moeConfigIndex corresponds to pair (tileN, config)
-  int64_t tileN = config_index[0];
+  // moeConfigIndex corresponds to pair (tile_N, config)
+  int64_t tile_N = config_index[0];
   int64_t config = config_index[1];
   // Autotuner has requested a default or 'fallback' config index
-  if (tileN == -1 || config == -1) {
-    tileN = *selected_tile_nums.begin();
-    config = mRunners[tileN]->getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size,
-                                                         local_num_experts, num_tokens);
+  if (tile_N == -1 || config == -1) {
+    tile_N = *selected_tile_nums.begin();
+    config = mRunners[tile_N]->getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size,
+                                                          local_num_experts, num_tokens);
   }
   return trtllm_fp4_block_scale_moe_launcher(
       routing_logits, topk_ids, expert_weights, routing_bias, hidden_states, hidden_states_scale,
       gemm1_weights, gemm1_weights_scale, gemm1_bias, gemm1_alpha, gemm1_beta, gemm1_clamp_limit,
       gemm2_weights, gemm2_weights_scale, gemm2_bias, output1_scales_scalar,
       output1_scales_gate_scalar, output2_scales_scalar, num_experts, top_k, n_group, topk_group,
-      intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, tileN,
-      routing_method_type, do_finalize, *mRunners[tileN], mDtypeAct, mDtypeWeights, config,
+      intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, tile_N,
+      routing_method_type, do_finalize, *mRunners[tile_N], mDtypeAct, mDtypeWeights, config,
       enable_pdl, output);
 }

@@ -1218,41 +1273,58 @@ int64_t trtllm_get_default_moe_configs(int64_t const dtype_act_, int64_t const d
   auto dtype_act = static_cast<btg::Dtype>(dtype_act_);
   auto dtype_weights = static_cast<btg::Dtype>(dtype_weights_);
   std::vector<int32_t> supported_tile_nums = {8, 16, 32, 64};
-  if (dtype_act != btg::Dtype::Bfloat16) {
+  if ((dtype_weights == btg::Dtype::E2m1 || dtype_weights == btg::Dtype::MxE2m1) &&
+      dtype_act != btg::Dtype::Bfloat16) {
     supported_tile_nums.push_back(128);
   }
   std::set<int32_t> selected_tile_nums =
       computeSelectedTileN(supported_tile_nums, num_tokens, top_k, num_local_experts);
-  tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner moe_runner(
-      dtype_act, dtype_weights, useDeepSeekFp8, *selected_tile_nums.begin(),
-      static_cast<GatedActType>(gated_act_type), /*useShuffledMatrixA*/ true);
-  return moe_runner.getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size,
-                                               num_local_experts, num_tokens);
+
+  std::unique_ptr<tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner> moe_runner =
+      std::make_unique<tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner>(
+          dtype_act, dtype_weights, useDeepSeekFp8, *selected_tile_nums.begin(),
+          static_cast<GatedActType>(gated_act_type), /*useShuffledMatrixA*/ true);
+
+  return moe_runner->getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size,
+                                                num_local_experts, num_tokens);
 }
 
 Array<Array<int64_t>> trtllm_get_valid_moe_configs(
     int64_t const dtype_act_, int64_t const dtype_weights_, bool const useDeepSeekFp8,
     int64_t const top_k, int64_t const hidden_size, int64_t const intermediate_size,
-    int64_t const num_local_experts, int64_t const gated_act_type, int64_t const num_tokens) {
-  // returns (tileN, config)
+    int64_t const num_local_experts, int64_t const gated_act_type, bool const use_shuffled_weight,
+    int64_t const weight_layout, int64_t const num_tokens) {
+  // returns (tile_N, config)
   Array<Array<int64_t>> valid_configs;
   auto dtype_act = static_cast<btg::Dtype>(dtype_act_);
   auto dtype_weights = static_cast<btg::Dtype>(dtype_weights_);
   std::vector<int32_t> supported_tile_nums = {8, 16, 32, 64};
-  if (dtype_act != btg::Dtype::Bfloat16) {
+  if ((dtype_weights == btg::Dtype::E2m1 || dtype_weights == btg::Dtype::MxE2m1) &&
+      dtype_act != btg::Dtype::Bfloat16) {
    supported_tile_nums.push_back(128);
   }
   std::set<int32_t> selected_tile_nums =
      computeSelectedTileN(supported_tile_nums, num_tokens, top_k, num_local_experts);
 
-  for (int32_t tileN : selected_tile_nums) {
-    tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner moe_runner(
-        dtype_act, dtype_weights, useDeepSeekFp8, tileN, static_cast<GatedActType>(gated_act_type),
-        /*useShuffledMatrixA*/ true);
-    auto cfgs = moe_runner.getValidConfigIndices(top_k, hidden_size, intermediate_size,
-                                                 num_local_experts, num_tokens);
+  for (int32_t tile_N : selected_tile_nums) {
+    std::unique_ptr<tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner> moe_runner;
+
+    if (dtype_weights == btg::Dtype::E4m3 && dtype_act == btg::Dtype::E4m3) {
+      // FP8 block scale MOE runner
+      moe_runner = std::make_unique<tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner>(
+          dtype_weights, useDeepSeekFp8, tile_N, use_shuffled_weight,
+          static_cast<batchedGemm::gemm::MatrixLayout>(weight_layout));
+    } else {
+      // FP4 block scale MOE runner
+      moe_runner = std::make_unique<tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner>(
+          dtype_act, dtype_weights, useDeepSeekFp8, tile_N,
+          static_cast<GatedActType>(gated_act_type),
+          /*useShuffledMatrixA*/ true);
+    }
+    auto cfgs = moe_runner->getValidConfigIndices(top_k, hidden_size, intermediate_size,
+                                                  num_local_experts, num_tokens);
     for (auto cfg : cfgs) {
-      valid_configs.push_back({tileN, cfg});
+      valid_configs.push_back({tile_N, cfg});
     }
   }
   return valid_configs;
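Note: putting the pieces together, an autotuner built on these hooks enumerates (tile_N, config) candidates with trtllm_get_valid_moe_configs, times the MoE call under each candidate, and reuses the fastest pair as config_index on later calls. The loop below is a schematic sketch: time_kernel_ms and run_moe_once are hypothetical helpers standing in for a CUDA-event-based timer and the actual MoE invocation; only the config-enumeration call and the {-1, -1} fallback convention come from this file.

// Schematic autotuning sweep over the (tile_N, config) pairs exposed in this commit.
// time_kernel_ms(...) and run_moe_once(...) are hypothetical helpers, not part of this file.
Array<Array<int64_t>> candidates = trtllm_get_valid_moe_configs(
    dtype_act, dtype_weights, /*useDeepSeekFp8=*/false, top_k, hidden_size, intermediate_size,
    num_local_experts, gated_act_type, use_shuffled_weight, weight_layout, num_tokens);

Array<int64_t> best_config_index = {-1, -1};  // fallback: let the entry point pick a default
double best_ms = 1e30;                        // effectively +infinity
for (auto cand : candidates) {                // cand = {tile_N, config}
  double const ms = time_kernel_ms([&] { run_moe_once(cand); });
  if (ms < best_ms) {
    best_ms = ms;
    best_config_index = cand;
  }
}
// best_config_index is then passed as config_index to trtllm_fp8_per_tensor_scale_moe,
// trtllm_fp8_block_scale_moe, or trtllm_fp4_block_scale_moe on subsequent calls.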
