Skip to content

Commit e09cebe

Browse files
committed
Fix build issues
Signed-off-by: djns99 <[email protected]>
1 parent 2580782 commit e09cebe

File tree

9 files changed

+63
-34
lines changed

9 files changed

+63
-34
lines changed

cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkLauncher.cu

Lines changed: 11 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -50,7 +50,7 @@ auto listAllTactics(MoeGemmId gemm_id)
5050
}
5151

5252
template <class BenchClass>
53-
void parseTacticToVectorID(nlohmann::json& tactic, std::vector<int>& tactic_ids)
53+
void parseTacticToVectorID(nlohmann::json& tactic, std::vector<int>& tactic_ids, MoeGemmId gemm_id)
5454
{
5555
if (tactic.is_number_integer())
5656
{
@@ -60,7 +60,7 @@ void parseTacticToVectorID(nlohmann::json& tactic, std::vector<int>& tactic_ids)
6060
{
6161
for (auto c : tactic)
6262
{
63-
parseTacticToVectorID<BenchClass>(c, tactic_ids);
63+
parseTacticToVectorID<BenchClass>(c, tactic_ids, gemm_id);
6464
}
6565
}
6666
else if (tactic.is_string())
@@ -69,7 +69,7 @@ void parseTacticToVectorID(nlohmann::json& tactic, std::vector<int>& tactic_ids)
6969
auto tactic_name = tactic.get<std::string>();
7070
if (tactic_name == "all")
7171
{
72-
auto all_tactics = listAllTactics<BenchClass>();
72+
auto all_tactics = listAllTactics<BenchClass>(gemm_id);
7373
tactic_ids.resize(all_tactics.size());
7474
std::iota(tactic_ids.begin(), tactic_ids.end(), 0);
7575
}
@@ -291,9 +291,14 @@ void argGenLoadFile(benchmark::internal::Benchmark* benchmark)
291291
{
292292
printed = true;
293293
std::cerr << __PRETTY_FUNCTION__ << ": Valid Tactics are:\n";
294-
auto confs = listAllTactics<BenchClass>();
295-
for (auto c : confs)
296-
std::cerr << c.toString();
294+
for (auto gemm_id : {MoeGemmId::GEMM_1, MoeGemmId::GEMM_2})
295+
{
296+
std::cerr << "GEMM " << (int) gemm_id << ":\n";
297+
auto confs = listAllTactics<BenchClass>(gemm_id);
298+
for (auto c : confs)
299+
std::cerr << c.toString();
300+
std::cerr << std::endl;
301+
}
297302
}
298303

299304
continue;

cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h

Lines changed: 11 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -432,7 +432,14 @@ struct CutlassGemmConfig
432432
bool enableCudaKernel = false;
433433
int sm_version = 80; // Use 80 as a catch all for <90
434434
bool is_tma_warp_specialized = false;
435-
bool is_finalize_fusion = false;
435+
436+
enum class EpilogueFusionType : int
437+
{
438+
NONE,
439+
FINALIZE
440+
};
441+
442+
EpilogueFusionType epilogue_fusion_type = EpilogueFusionType::NONE;
436443

437444
CutlassGemmConfig() = default;
438445

@@ -504,7 +511,7 @@ struct CutlassGemmConfig
504511
<< "\n\tcluster shape ID: " << (int) cluster_shape
505512
<< "\n\tmainloop sched: " << (int) mainloop_schedule << "\n\tepi sched: " << (int) epilogue_schedule
506513
<< "\n\tenable cuda kernel: " << (enableCudaKernel ? "true" : "false")
507-
<< "\n\tis_finalize_fusion: " << (is_finalize_fusion ? "true" : "false");
514+
<< "\n\tepilogue fusion type: " << (int) epilogue_fusion_type;
508515
}
509516
else if (tile_config_sm80 != tensorrt_llm::cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic)
510517
{
@@ -536,7 +543,8 @@ inline std::ostream& operator<<(std::ostream& out, CutlassGemmConfig const& conf
536543
<< ", mainloop_schedule_enum: " << int(config.mainloop_schedule)
537544
<< ", epilogue_schedule_enum: " << int(config.epilogue_schedule)
538545
<< ", cluster_shape_enum: " << int(config.cluster_shape)
539-
<< ", enable_cuda_kernel: " << (config.enableCudaKernel ? "true" : "false");
546+
<< ", enable_cuda_kernel: " << (config.enableCudaKernel ? "true" : "false")
547+
<< ", epilogue_fusion_type: " << int(config.epilogue_fusion_type);
540548
}
541549
else
542550
{

cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -297,7 +297,13 @@ class MoeGemmRunner
297297
static std::vector<cutlass_extensions::CutlassGemmConfig> getAmpereConfigs(int sm);
298298

299299
[[nodiscard]] bool isTmaWarpSpecialized(cutlass_extensions::CutlassGemmConfig gemm_config) const;
300-
[[nodiscard]] bool supportsTmaWarpSpecialized() const;
300+
301+
[[nodiscard]] bool supportsTmaWarpSpecialized() const
302+
{
303+
return supportsTmaWarpSpecialized(sm_);
304+
}
305+
306+
[[nodiscard]] static bool supportsTmaWarpSpecialized(int sm);
301307
[[nodiscard]] bool isFusedGatedActivation(cutlass_extensions::CutlassGemmConfig gemm_config,
302308
ActivationType activation_type, int gemm_n, int gemm_k) const;
303309
[[nodiscard]] bool supportsFusedGatedActivation(ActivationType activation_type, int gemm_n, int gemm_k) const;

cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -600,8 +600,8 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
600600
gemm2_config_ = std::move(gemm2_config);
601601
}
602602

603-
static auto& addFinalizeFusionConfigs(
604-
std::vector<cutlass_extensions::CutlassGemmConfig>& configs, bool use_fused_finalize)
603+
static auto addFinalizeFusionConfigs(
604+
std::vector<cutlass_extensions::CutlassGemmConfig>&& configs, bool use_fused_finalize)
605605
{
606606
if (!use_fused_finalize)
607607
return configs;
@@ -612,22 +612,24 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
612612
if (configs[i].is_tma_warp_specialized)
613613
{
614614
configs.push_back(configs[i]);
615-
configs.back().is_finalize_fusion = true;
615+
configs.back().epilogue_fusion_type
616+
= cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE;
616617
}
617618
}
618619
return configs;
619620
}
620621

621622
std::vector<cutlass_extensions::CutlassGemmConfig> getTactics(MoeGemmId gemm_id) override
622623
{
623-
return addFinalizeFusionConfigs(
624+
return Self::addFinalizeFusionConfigs(
624625
moe_gemm_runner_.getConfigs(), gemm_id == MoeGemmId::GEMM_2 && mayHaveFinalizeFused());
625626
}
626627

627628
static std::vector<cutlass_extensions::CutlassGemmConfig> getTactics(int sm, MoeGemmId gemm_id)
628629
{
629630
using RunnerType = decltype(moe_gemm_runner_);
630-
return RunnerType::getConfigs(sm, gemm_id == MoeGemmId::GEMM_2 && mayHaveFinalizeFused(sm));
631+
return Self::addFinalizeFusionConfigs(
632+
RunnerType::getConfigs(sm), gemm_id == MoeGemmId::GEMM_2 && Self::mayHaveFinalizeFused(sm));
631633
}
632634

633635
void runMoe(void const* input_activations, void const* input_sf, bool const swizzled_input_sf,

cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -2847,8 +2847,10 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
28472847
expert_first_token_offset_ = getWsPtr(int64_t{}, "expert_first_token_offset");
28482848

28492849
// We check if the provided config uses fused finalize and disable it if it does not
2850+
bool gemm2_using_finalize_fusion
2851+
= gemm2_config_->epilogue_fusion_type == cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE;
28502852
permuted_token_final_scales_
2851-
= gemm2_config_->using_fused_finalize ? getWsPtr(float{}, "permuted_token_final_scales") : nullptr;
2853+
= gemm2_using_finalize_fusion ? getWsPtr(float{}, "permuted_token_final_scales") : nullptr;
28522854

28532855
bool const is_gated_activation = isGatedActivation(activation_type);
28542856
bool const gemm1_using_fused_moe
@@ -4005,9 +4007,11 @@ CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enable>::
40054007

40064008
bool apply_bias = parallelism_config.tp_rank == 0;
40074009
auto* fc2_bias = apply_bias ? fc2_expert_biases : nullptr;
4010+
bool gemm2_using_finalize_fusion = gemm2_config_->epilogue_fusion_type
4011+
== cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE;
40084012
bool using_fused_finalize
4009-
= use_fused_finalize_ && gemm2_config_->is_finalize_fusion && !use_w4_groupwise && !use_lora;
4010-
TLLM_CHECK_WITH_INFO(using_fused_finalize == gemm2_config_->using_fused_finalize,
4013+
= use_fused_finalize_ && gemm2_using_finalize_fusion && !use_w4_groupwise && !use_lora;
4014+
TLLM_CHECK_WITH_INFO(using_fused_finalize == gemm2_using_finalize_fusion,
40114015
"GEMM2 tactic requests finalize fusion, but the runner is not configured to use it");
40124016
if (using_fused_finalize)
40134017
{

cpp/tensorrt_llm/kernels/internal_cutlass_kernels/include/moe_gemm_kernels.h

Lines changed: 1 addition & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -290,13 +290,7 @@ class MoeGemmRunner
290290
static std::vector<cutlass_extensions::CutlassGemmConfig> getAmpereConfigs(int sm);
291291

292292
[[nodiscard]] bool isTmaWarpSpecialized(cutlass_extensions::CutlassGemmConfig gemm_config) const;
293-
294-
[[nodiscard]] bool supportsTmaWarpSpecialized() const
295-
{
296-
return supportsTmaWarpSpecialized(sm_);
297-
}
298-
299-
[[nodiscard]] static bool supportsTmaWarpSpecialized(int sm);
293+
[[nodiscard]] bool supportsTmaWarpSpecialized() const;
300294
[[nodiscard]] bool isFusedGatedActivation(cutlass_extensions::CutlassGemmConfig gemm_config,
301295
ActivationType activation_type, int gemm_n, int gemm_k) const;
302296
[[nodiscard]] bool supportsFusedGatedActivation(ActivationType activation_type, int gemm_n, int gemm_k) const;

cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -43,6 +43,7 @@ namespace kernels = CUTLASS_MOE_GEMM_KERNELS_NAMESPACE;
4343
using MoeMinLatencyParams = CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::MoeMinLatencyParams;
4444
using MOEParallelismConfig = CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::MOEParallelismConfig;
4545
using QuantParams = CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::QuantParams;
46+
using MoeGemmId = CUTLASS_MOE_GEMM_NAMESPACE::MoeGemmId;
4647
using ActivationType = CUTLASS_MOE_GEMM_NAMESPACE::ActivationType;
4748
using ActivationParams = CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::ActivationParams;
4849
using TmaWarpSpecializedGroupedGemmInput = CUTLASS_MOE_GEMM_NAMESPACE::TmaWarpSpecializedGroupedGemmInput;

cpp/tensorrt_llm/thop/moeOp.cpp

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -48,6 +48,7 @@ namespace common = tensorrt_llm::common;
4848
namespace kernels = CUTLASS_MOE_GEMM_KERNELS_NAMESPACE;
4949
using ActivationParams = CUTLASS_MOE_GEMM_NAMESPACE::ActivationParams;
5050
using ActivationType = CUTLASS_MOE_GEMM_NAMESPACE::ActivationType;
51+
using MoeGemmId = CUTLASS_MOE_GEMM_NAMESPACE::MoeGemmId;
5152
// Always use public header as it is just utility functions and types
5253
using TmaWarpSpecializedGroupedGemmInput = tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput;
5354
using profiler_backend = CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::GemmProfilerBackend;
@@ -571,9 +572,10 @@ class FusedMoeRunner : public torch::CustomClassHolder
571572
return std::make_tuple(output, num_active_experts_per_node, experts_to_token_score, active_expert_global_ids);
572573
}
573574

574-
int64_t getTacticNum(int gemm_idx)
575+
int64_t getTacticNum(int64_t const gemm_idx)
575576
{
576577
std::lock_guard<std::mutex> lock(mMutex);
578+
TORCH_CHECK(gemm_idx == 1 || gemm_idx == 2, "gemm_idx must be 1 or 2");
577579
return (gemm_idx == 1) ? mGemm1Profiles.size() : mGemm2Profiles.size();
578580
}
579581

cpp/tests/unit_tests/kernels/mixtureOfExpertsTest.cu

Lines changed: 15 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1120,15 +1120,17 @@ protected:
11201120
auto selectTacticsForArch(int sm)
11211121
{
11221122
bool is_tma_warp_specialized = sm >= 90 && !INT_QUANT;
1123-
bool is_finalize_fusion = is_tma_warp_specialized && mUseFusedFinalize;
1123+
bool epilogue_fusion_type = is_tma_warp_specialized && mUseFusedFinalize
1124+
? cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE
1125+
: cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::NONE;
11241126
auto tactics1 = getFilteredConfigs(sm, MoeGemmId::GEMM_1);
11251127
auto tactics2 = getFilteredConfigs(sm, MoeGemmId::GEMM_2);
11261128
auto it1 = std::find_if(tactics1.begin(), tactics1.end(),
11271129
[is_tma_warp_specialized](auto& c) { return c.is_tma_warp_specialized == is_tma_warp_specialized; });
11281130
auto it2 = std::find_if(tactics2.begin(), tactics2.end(),
1129-
[is_tma_warp_specialized, is_finalize_fusion](auto& c) {
1131+
[is_tma_warp_specialized, epilogue_fusion_type](auto& c) {
11301132
return c.is_tma_warp_specialized == is_tma_warp_specialized
1131-
&& c.using_fused_finalize == is_finalize_fusion;
1133+
&& c.epilogue_fusion_type == epilogue_fusion_type;
11321134
});
11331135
if (it1 == tactics1.end() || it2 == tactics2.end())
11341136
{
@@ -1175,7 +1177,7 @@ protected:
11751177
if (!tactic1 || !tactic2)
11761178
{
11771179
int sm = getSMVersion();
1178-
std::tie(tactic1, tactic2) = selectTacticsForArch(sm, mUseFusedFinalize);
1180+
std::tie(tactic1, tactic2) = selectTacticsForArch(sm);
11791181
}
11801182
ASSERT_TRUE(tactic1.has_value());
11811183
ASSERT_TRUE(tactic2.has_value());
@@ -1637,8 +1639,9 @@ void MixtureOfExpertsTest<TypeParam_>::BasicPermuteTest(
16371639
auto [expected_experts, token_final_scales] = populateRouting(num_experts, num_tokens, k);
16381640
16391641
runMoEPermute(hidden_input, expected_experts, token_final_scales, hidden_size, num_experts, k);
1640-
bool should_be_deterministic
1641-
= !gemm2.is_finalize_fusion || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120;
1642+
bool is_finalize_fusion
1643+
= gemm2.epilogue_fusion_type == cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE;
1644+
bool should_be_deterministic = !is_finalize_fusion || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120;
16421645
if (should_be_deterministic && !mIsLongTest)
16431646
{
16441647
auto first_iter = getDataFromDevice(mFinalOutput, mTotalTokens * mHiddenSize);
@@ -1904,8 +1907,10 @@ void MixtureOfExpertsTest<TypeParam_>::ParallelismTest(
19041907
// Only need to init the inputs on the first iteration
19051908
runMoEPermute(hidden_input, expected_experts, token_final_scales, hidden_size, num_experts, k,
19061909
MOEParallelismConfig{tp_size, i, ep_size, j}, enable_alltoall);
1910+
bool is_finalize_fusion = gemm2.epilogue_fusion_type
1911+
== cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE;
19071912
bool should_be_deterministic
1908-
= !gemm2.is_finalize_fusion || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120;
1913+
= !is_finalize_fusion || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120;
19091914
if (should_be_deterministic && !mIsLongTest)
19101915
{
19111916
auto first_iter = getDataFromDevice(mFinalOutput, mTotalTokens * mHiddenSize);
@@ -1920,8 +1925,10 @@ void MixtureOfExpertsTest<TypeParam_>::ParallelismTest(
19201925
else
19211926
{
19221927
runMoEPermute(MOEParallelismConfig{tp_size, i, ep_size, j}, enable_alltoall);
1928+
bool is_finalize_fusion = gemm2.epilogue_fusion_type
1929+
== cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE;
19231930
bool should_be_deterministic
1924-
= !gemm2.is_finalize_fusion || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120;
1931+
= !is_finalize_fusion || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120;
19251932
if (should_be_deterministic && !mIsLongTest)
19261933
{
19271934
auto first_iter = getDataFromDevice(mFinalOutput, mTotalTokens * mHiddenSize);

0 commit comments

Comments (0)