Skip to content

Commit 20435b4

Browse files
nv-yunzheqyongwwwaleozlx
authored
update trtllm cutlass moe (#2020)
<!-- .github/pull_request_template.md --> ## 📌 Description <!-- What does this PR do? Briefly describe the changes and why they’re needed. --> ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [x] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * SM90 scatter-based epilogue and broader SM100/SM120 MOE/GEMM coverage; new public enum for GEMM stages and explicit runner instantiations. * **Improvements** * New runtime controls and parameters exposed: dynamic CGA, swap-AB, swizzled-input SF, unpadded hidden-size, and per-GEMM-stage tactic counts; expanded tile/cluster shape options, finalize-epilogue fusion and fusion/swap-aware dispatch; increased runtime debug logging and profiling. * **Bug Fixes** * License/namespace/header cleanups, suppressed compiler warnings, tightened assertions. * **Tests** * MXFP8×MXFP4 test now permits SM120 devices. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: Yong Wu <[email protected]> Co-authored-by: Alex Yang <[email protected]>
1 parent 3cb8f9a commit 20435b4

File tree

44 files changed

+3272
-1828
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+3272
-1828
lines changed

csrc/fused_moe/cutlass_backend/cutlass_fused_moe_instantiation.cu

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
#include "moe_kernels.h"
1919

2020
namespace tensorrt_llm::kernels::cutlass_kernels {
21-
// ==================== Variable batched GEMM specializations ==================================
2221
template class CutlassMoeFCRunner<float, float>;
2322

2423
#ifdef ENABLE_BF16
@@ -38,6 +37,7 @@ template class CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, half, half>;
3837
template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16>;
3938
template class CutlassMoeFCRunner<__nv_bfloat16, __nv_fp8_e4m3, __nv_bfloat16>;
4039
template class CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16>;
40+
template class CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, __nv_bfloat16, __nv_fp8_e4m3>;
4141
#endif
4242
#endif
4343
#ifdef ENABLE_FP4
@@ -54,4 +54,12 @@ template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp4_e2m1, __nv_bfloat16, _
5454
template class CutlassMoeFCRunner<__nv_bfloat16, __nv_fp4_e2m1>;
5555
#endif
5656
#endif
57-
}; // namespace tensorrt_llm::kernels::cutlass_kernels
57+
58+
// Explicit instantiations for finalizeMoeRoutingKernelLauncher to ensure
59+
// symbols are emitted in the JIT library for common data types.
60+
INSTANTIATE_FINALIZE_MOE_ROUTING(half, half, half);
61+
INSTANTIATE_FINALIZE_MOE_ROUTING(float, float, float);
62+
#ifdef ENABLE_BF16
63+
INSTANTIATE_FINALIZE_MOE_ROUTING(__nv_bfloat16, __nv_bfloat16, __nv_bfloat16);
64+
#endif
65+
} // namespace tensorrt_llm::kernels::cutlass_kernels

csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh

Lines changed: 431 additions & 440 deletions
Large diffs are not rendered by default.

csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu

Lines changed: 74 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ class DtypeUtils {
7272
default:
7373
TVM_FFI_ICHECK(false) << "unsupported data type";
7474
}
75+
76+
return nvinfer1::DataType::kFLOAT; // suppress compiler warning
7577
}
7678

7779
private:
@@ -111,6 +113,8 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
111113
TVM_FFI_ICHECK(false) << "Invalid output type " << DLDataTypeToString(output_type)
112114
<< " specified for " << DLDataTypeToString(mActivationDtype);
113115
}
116+
117+
return nullptr; // suppress compiler warning
114118
};
115119

116120
FusedMoeRunner(DLDataType activation_dtype, DLDataType weight_dtype, DLDataType output_dtype,
@@ -219,7 +223,13 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
219223
}
220224

221225
mProfiler = std::make_shared<kernels::GemmProfilerBackend>();
222-
mAllProfiles = mKernelRunner->getTactics();
226+
// Get tactics for both GEMM1 and GEMM2, combine them
227+
auto gemm1_tactics = mKernelRunner->getTactics(kernels::MoeGemmId::GEMM_1);
228+
auto gemm2_tactics = mKernelRunner->getTactics(kernels::MoeGemmId::GEMM_2);
229+
mGemm1TacticCount = static_cast<int64_t>(gemm1_tactics.size());
230+
mGemm2TacticCount = static_cast<int64_t>(gemm2_tactics.size());
231+
mAllProfiles = gemm1_tactics;
232+
mAllProfiles.insert(mAllProfiles.end(), gemm2_tactics.begin(), gemm2_tactics.end());
223233
TVM_FFI_ICHECK(!mAllProfiles.empty())
224234
<< "No valid tactics available for fused moe op with the requested input combination "
225235
"Activation: "
@@ -367,27 +377,31 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
367377

368378
// TODO: support lora in the future
369379
::tensorrt_llm::kernels::LoraParams lora_params{};
380+
// HACK Define default values for parameters we don't have good values for
381+
bool const swizzled_input_sf = true; // Assume input_sf is swizzled by default
382+
int64_t const unpadded_hidden_size = hidden_size; // Assume no padding by default
383+
bool const use_lora = false; // No lora support yet
370384
#ifdef USING_OSS_CUTLASS_MOE_GEMM
371385
mKernelRunner->runMoe(
372386
input.data_ptr(), input_sf.has_value() ? input_sf.value().data_ptr() : nullptr,
373-
reinterpret_cast<int const*>(token_selected_experts.data_ptr()),
387+
swizzled_input_sf, reinterpret_cast<int const*>(token_selected_experts.data_ptr()),
374388
token_final_scales.has_value()
375389
? reinterpret_cast<float const*>(token_final_scales.value().data_ptr())
376390
: nullptr,
377391
fc1_expert_weights.data_ptr(),
378392
fc1_expert_biases.has_value() ? fc1_expert_biases.value().data_ptr() : nullptr,
379393
activation_params, fc2_expert_weights.data_ptr(),
380394
fc2_expert_biases.has_value() ? fc2_expert_biases.value().data_ptr() : nullptr,
381-
quant_params, num_rows, hidden_size, inter_size, num_experts_total,
395+
quant_params, num_rows, hidden_size, unpadded_hidden_size, inter_size, num_experts_total,
382396
static_cast<int>(experts_per_token),
383397
static_cast<char*>(workspace_info.workspace.data_ptr()), output.data_ptr(),
384398
static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, enable_alltoall,
385-
false, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params,
399+
use_lora, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params,
386400
enable_pdl, stream);
387401
#else
388402
mKernelRunner->runMoe(
389403
input.data_ptr(), input_sf.has_value() ? input_sf.value().data_ptr() : nullptr,
390-
reinterpret_cast<int const*>(token_selected_experts.data_ptr()),
404+
swizzled_input_sf, reinterpret_cast<int const*>(token_selected_experts.data_ptr()),
391405
token_final_scales.has_value()
392406
? reinterpret_cast<float const*>(token_final_scales.value().data_ptr())
393407
: nullptr,
@@ -396,10 +410,10 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
396410
activation_params, fc2_expert_weights.data_ptr(),
397411
fc2_expert_biases.has_value() ? fc2_expert_biases.value().data_ptr() : nullptr,
398412
quant_params, num_rows, hidden_size, inter_size, num_experts_total,
399-
static_cast<int>(experts_per_token), static_cast<char*>(workspace_info.workspace),
400-
output.data_ptr(), static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config,
401-
false, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params,
402-
enable_pdl, stream);
413+
static_cast<int>(experts_per_token),
414+
static_cast<char*>(workspace_info.workspace.data_ptr()), output.data_ptr(),
415+
static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, false, lora_params,
416+
mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, enable_pdl, stream);
403417
#endif
404418
}
405419

@@ -547,39 +561,44 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
547561

548562
// TODO: support lora in the future
549563
::tensorrt_llm::kernels::LoraParams lora_params{};
564+
// HACK Define default values for parameters we don't have good values for
565+
bool const swizzled_input_sf_ml = true; // Assume input_sf is swizzled by default
566+
int64_t const unpadded_hidden_size_ml = hidden_size; // Assume no padding by default
567+
bool const use_lora_ml = false; // No lora support yet
550568
#ifdef USING_OSS_CUTLASS_MOE_GEMM
551569
mKernelRunner->runMoe(
552570
input.data_ptr(), input_sf.has_value() ? input_sf.value().data_ptr() : nullptr,
553-
reinterpret_cast<int const*>(token_selected_experts.data_ptr()),
571+
swizzled_input_sf_ml, reinterpret_cast<int const*>(token_selected_experts.data_ptr()),
554572
token_final_scales.has_value()
555573
? reinterpret_cast<float const*>(token_final_scales.value().data_ptr())
556574
: nullptr,
557575
fc1_expert_weights.data_ptr(),
558576
fc1_expert_biases.has_value() ? fc1_expert_biases.value().data_ptr() : nullptr,
559577
activation_params, fc2_expert_weights.data_ptr(),
560578
fc2_expert_biases.has_value() ? fc2_expert_biases.value().data_ptr() : nullptr,
561-
quant_params, num_rows, hidden_size, inter_size, num_experts_total,
579+
quant_params, num_rows, hidden_size, unpadded_hidden_size_ml, inter_size, num_experts_total,
562580
static_cast<int>(experts_per_token),
563581
static_cast<char*>(workspace_info.workspace.data_ptr()), output.data_ptr(),
564582
static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, enable_alltoall,
565-
false, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params,
583+
use_lora_ml, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params,
566584
enable_pdl, stream);
567585
#else
568586
mKernelRunner->runMoe(
569587
input.data_ptr(), input_sf.has_value() ? input_sf.value().data_ptr() : nullptr,
570-
reinterpret_cast<int const*>(token_selected_experts.data_ptr()),
588+
swizzled_input_sf_ml, reinterpret_cast<int const*>(token_selected_experts.data_ptr()),
571589
token_final_scales.has_value()
572590
? reinterpret_cast<float const*>(token_final_scales.value().data_ptr())
573591
: nullptr,
574592
fc1_expert_weights.data_ptr(),
575593
fc1_expert_biases.has_value() ? fc1_expert_biases.value().data_ptr() : nullptr,
576594
activation_params, fc2_expert_weights.data_ptr(),
577595
fc2_expert_biases.has_value() ? fc2_expert_biases.value().data_ptr() : nullptr,
578-
quant_params, num_rows, hidden_size, inter_size, num_experts_total,
579-
static_cast<int>(experts_per_token), static_cast<char*>(workspace_info.workspace),
580-
output.data_ptr(), static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config,
581-
false, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params,
582-
enable_pdl, stream);
596+
quant_params, num_rows, hidden_size, unpadded_hidden_size_ml, inter_size, num_experts_total,
597+
static_cast<int>(experts_per_token),
598+
static_cast<char*>(workspace_info.workspace.data_ptr()), output.data_ptr(),
599+
static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, false, use_lora_ml,
600+
lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, enable_pdl,
601+
stream);
583602
#endif
584603
}
585604

@@ -641,19 +660,20 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
641660
auto activation_dtype =
642661
(mUseW4GroupScaling && !isWFP4A16Quant()) ? dl_float8_e4m3fn : mActivationDtype;
643662
activation_dtype = isNvfp4Quant() ? dl_int64 : activation_dtype;
663+
int64_t const unpadded_hidden_size_profiler = hidden_size; // HACK no padding by default
644664
#ifdef USING_OSS_CUTLASS_MOE_GEMM
645665
mProfiler->init(*mKernelRunner.get(), mProfiler->mGemmToProfile,
646666
DtypeUtils::dataType(activation_dtype), DtypeUtils::dataType(mWeightDtype),
647667
DtypeUtils::dataType(mOutputDtype), num_experts, static_cast<int>(top_k),
648-
hidden_size, inter_size, group_size, activation_type, USE_BIAS, USE_LORA,
649-
min_latency_mode,
668+
hidden_size, unpadded_hidden_size_profiler, inter_size, group_size,
669+
activation_type, USE_BIAS, USE_LORA, min_latency_mode,
650670
/*need_weights*/ false, parallelism_config, enable_alltoall);
651671
#else
652672
mProfiler->init(*mKernelRunner.get(), mProfiler->mGemmToProfile,
653673
DtypeUtils::dataType(activation_dtype), DtypeUtils::dataType(mWeightDtype),
654674
DtypeUtils::dataType(mOutputDtype), num_experts, static_cast<int>(top_k),
655-
hidden_size, inter_size, group_size, activation_type, USE_BIAS, USE_LORA,
656-
min_latency_mode,
675+
hidden_size, unpadded_hidden_size_profiler, inter_size, group_size,
676+
activation_type, USE_BIAS, USE_LORA, min_latency_mode,
657677
/*need_weights*/ false, parallelism_config);
658678
#endif
659679

@@ -691,6 +711,10 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
691711
});
692712
} else if (name == "get_tactic_num") {
693713
return Function::FromTyped([this]() -> int64_t { return getTacticNum(); });
714+
} else if (name == "get_gemm1_tactic_count") {
715+
return Function::FromTyped([this]() -> int64_t { return mGemm1TacticCount; });
716+
} else if (name == "get_gemm2_tactic_count") {
717+
return Function::FromTyped([this]() -> int64_t { return mGemm2TacticCount; });
694718
} else if (name == "run_moe") {
695719
return Function::FromTyped(
696720
[this](TensorView output, TensorView input, TensorView token_selected_experts,
@@ -758,6 +782,8 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
758782

759783
using Profile = tensorrt_llm::cutlass_extensions::CutlassGemmConfig;
760784
std::vector<Profile> mAllProfiles;
785+
int64_t mGemm1TacticCount{0};
786+
int64_t mGemm2TacticCount{0};
761787

762788
void setRunnerProfiles(Optional<Array<int64_t>> profile_ids) {
763789
if (mUseDeepSeekFP8BlockScaling) {
@@ -771,13 +797,34 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
771797
}
772798

773799
auto best_gemm1_profile = mAllProfiles.front();
774-
auto best_gemm2_profile = mAllProfiles.front();
800+
// Default GEMM2 profile should come from the GEMM2 subrange if present
801+
auto best_gemm2_profile =
802+
(mGemm2TacticCount > 0 && mAllProfiles.size() > static_cast<size_t>(mGemm1TacticCount))
803+
? mAllProfiles.at(mGemm1TacticCount)
804+
: mAllProfiles.front();
775805
if (profile_ids.has_value()) {
776806
TVM_FFI_ICHECK_EQ(profile_ids.value().size(), 2) << "Expecting 2 profile ids";
777-
best_gemm1_profile = profile_ids.value()[0] == -1 ? best_gemm1_profile
778-
: mAllProfiles.at(profile_ids.value()[0]);
779-
best_gemm2_profile = profile_ids.value()[1] == -1 ? best_gemm2_profile
780-
: mAllProfiles.at(profile_ids.value()[1]);
807+
// GEMM1 index: accept absolute index; otherwise if clearly out of combined range, keep
808+
// default
809+
auto id1 = profile_ids.value()[0];
810+
if (id1 != -1) {
811+
TVM_FFI_ICHECK(id1 >= 0 && id1 < mGemm1TacticCount) << "Invalid gemm1 profile id: " << id1;
812+
best_gemm1_profile = mAllProfiles.at(id1);
813+
}
814+
815+
// GEMM2 index: support both absolute (combined) and relative (within GEMM2 subrange) ids
816+
auto id2 = profile_ids.value()[1];
817+
if (id2 != -1) {
818+
int64_t absolute_id2 = id2;
819+
// If id2 appears relative to GEMM2 subrange, offset it
820+
if (id2 >= 0 && id2 < mGemm2TacticCount) {
821+
absolute_id2 = mGemm1TacticCount + id2;
822+
}
823+
TVM_FFI_ICHECK(absolute_id2 >= 0 &&
824+
absolute_id2 < static_cast<int64_t>(mAllProfiles.size()))
825+
<< "Invalid gemm2 profile id: " << id2;
826+
best_gemm2_profile = mAllProfiles.at(absolute_id2);
827+
}
781828
}
782829
mKernelRunner->setTactic(best_gemm1_profile, best_gemm2_profile);
783830
}

csrc/nv_internal/include/tensorrt_llm/common/cudaUtils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1181,6 +1181,9 @@ using Int = ConstExprWrapper<int, VALUE>;
11811181
template <bool VALUE>
11821182
using Bool = ConstExprWrapper<bool, VALUE>;
11831183

1184+
template <bool VALUE>
1185+
using ConstBool = ConstExprWrapper<bool, VALUE>;
1186+
11841187
template <typename T>
11851188
struct TmaDescType;
11861189

0 commit comments

Comments
 (0)