Skip to content

Commit c1926da

Browse files
committed
update moe gtest
Signed-off-by: Xiwen Yu <[email protected]>
1 parent 36c0415 commit c1926da

File tree

1 file changed: +60 additions, -11 deletions

cpp/tests/unit_tests/kernels/mixtureOfExpertsTest.cu

Lines changed: 60 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ protected:
174174
using WeightStorage = std::conditional_t<WEIGHT_ELEM_PER_BYTE == 2, uint8_t, WeightType>;
175175
constexpr static int64_t HIDDEN_SIZE_MULTIPLIER = 16;
176176
constexpr static int64_t MINIMUM_BYTE_ALIGNMENT
177-
= MX_QUANT_WEIGHT ? 64 : 256 / 8; // NoSmem requires 256 bits alignment, MX quant requires 64 bytes
177+
= MX_QUANT_WEIGHT ? 64 : 128 / 8; // TMA requires 128 bits alignment, MX quant requires 64 bytes
178178
constexpr static int64_t MINIMUM_ALIGNMENT_CONST
179179
= MINIMUM_BYTE_ALIGNMENT * WEIGHT_ELEM_PER_BYTE / sizeof(WeightStorage);
180180
constexpr static int64_t DEFAULT_HIDDEN_SIZE = HIDDEN_SIZE_MULTIPLIER * MINIMUM_ALIGNMENT_CONST;
@@ -1127,23 +1127,33 @@ protected:
11271127
return tactics;
11281128
}
11291129
1130-
auto selectTacticsForArch(int sm, bool exact_match = false)
1130+
auto selectTacticsForArch(int sm, bool exact_match = false, bool allow_no_smem = false)
11311131
{
11321132
bool is_tma_warp_specialized = sm >= 90 && !INT_QUANT;
1133+
auto filter_epi_schd = [sm, allow_no_smem](auto& c)
1134+
{
1135+
if (sm >= 100 && sm < 120)
1136+
{
1137+
return c.epilogue_schedule
1138+
== (allow_no_smem ? tensorrt_llm::cutlass_extensions::EpilogueScheduleType::NO_SMEM
1139+
: tensorrt_llm::cutlass_extensions::EpilogueScheduleType::TMA);
1140+
}
1141+
return true;
1142+
};
11331143
auto epilogue_fusion_type = (is_tma_warp_specialized && mUseFusedFinalize)
11341144
? tensorrt_llm::cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE
11351145
: tensorrt_llm::cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::NONE;
11361146
auto smExact = [exact_match, sm](auto& c) { return !exact_match || c.sm_version == sm; };
11371147
auto tactics1 = getFilteredConfigs(sm, MoeGemmId::GEMM_1);
11381148
auto tactics2 = getFilteredConfigs(sm, MoeGemmId::GEMM_2);
11391149
auto it1 = std::find_if(tactics1.begin(), tactics1.end(),
1140-
[is_tma_warp_specialized, smExact](auto& c)
1141-
{ return c.is_tma_warp_specialized == is_tma_warp_specialized && smExact(c); });
1150+
[is_tma_warp_specialized, smExact, filter_epi_schd](auto& c)
1151+
{ return c.is_tma_warp_specialized == is_tma_warp_specialized && smExact(c) && filter_epi_schd(c); });
11421152
auto it2 = std::find_if(tactics2.begin(), tactics2.end(),
1143-
[is_tma_warp_specialized, epilogue_fusion_type, smExact](auto& c)
1153+
[is_tma_warp_specialized, epilogue_fusion_type, smExact, filter_epi_schd](auto& c)
11441154
{
11451155
return c.is_tma_warp_specialized == is_tma_warp_specialized
1146-
&& c.epilogue_fusion_type == epilogue_fusion_type && smExact(c);
1156+
&& c.epilogue_fusion_type == epilogue_fusion_type && smExact(c) && filter_epi_schd(c);
11471157
});
11481158
if (it1 == tactics1.end() || it2 == tactics2.end())
11491159
{
@@ -1159,7 +1169,7 @@ protected:
11591169
using ConfigsToTestVec = std::vector<std::pair<tensorrt_llm::cutlass_extensions::CutlassGemmConfig,
11601170
tensorrt_llm::cutlass_extensions::CutlassGemmConfig>>;
11611171
1162-
auto getAllTileConfigsToTest()
1172+
auto getAllTileConfigsToTest(bool allow_no_smem = false)
11631173
{
11641174
if (mOverrideSelectedConfig1 && mOverrideSelectedConfig2)
11651175
{
@@ -1168,16 +1178,16 @@ protected:
11681178
11691179
int sm = getSMVersion();
11701180
bool needs_exact_match = sm == 103 && NVFP4;
1171-
ConfigsToTestVec tactics = {selectTacticsForArch(sm, needs_exact_match)};
1181+
ConfigsToTestVec tactics = {selectTacticsForArch(sm, needs_exact_match, allow_no_smem)};
11721182
if (sm == 103 && NVFP4)
11731183
{
11741184
// SM103 NVFP4 should also test SM100f kernels
1175-
tactics.push_back(selectTacticsForArch(100, true));
1185+
tactics.push_back(selectTacticsForArch(100, true, allow_no_smem));
11761186
}
11771187
if (sm >= 90 && !ANY_FPX)
11781188
{
11791189
// SM90+ should also grab some configs for SM80 to test them
1180-
tactics.push_back(selectTacticsForArch(80, true));
1190+
tactics.push_back(selectTacticsForArch(80, true, allow_no_smem));
11811191
}
11821192
return tactics;
11831193
}
@@ -1530,6 +1540,9 @@ protected:
15301540
void BasicPermuteTest(
15311541
int k = 1, int64_t hidden_size = DEFAULT_HIDDEN_SIZE, int64_t num_experts = 4, int64_t num_tokens = 3);
15321542
1543+
void BasicPermuteTestInternal(int k = 1, int64_t hidden_size = DEFAULT_HIDDEN_SIZE, int64_t num_experts = 4,
1544+
int64_t num_tokens = 3, bool allow_no_smem = false);
1545+
15331546
std::vector<int> calcPermuteMapExpertParallel(std::vector<int> const& expected_experts);
15341547
15351548
void ExpertParallelTest(int k = 1, int64_t hidden_size = DEFAULT_HIDDEN_SIZE, int64_t num_experts = 4,
@@ -1650,8 +1663,44 @@ void MixtureOfExpertsTest<TypeParam_>::BasicPermuteTest(
16501663
return;
16511664
}
16521665
}
1666+
this->BasicPermuteTestInternal(k, hidden_size, num_experts, num_tokens, false);
1667+
int sm = getSMVersion();
1668+
if (sm >= 100 && sm < 120 && (!ANY_FP4 || sm == 103))
1669+
{
1670+
// Test NO_SMEM for: SM103 all, or SM100 non-FP4
1671+
int64_t minimum_byte_alignment
1672+
= MX_QUANT_WEIGHT ? 64 : 256 / 8; // NO_SMEM requires 256 bits alignment, MX quant requires 64 bytes
1673+
int64_t minimum_alignment_const = minimum_byte_alignment * WEIGHT_ELEM_PER_BYTE / sizeof(WeightStorage);
1674+
int64_t default_hidden_size = HIDDEN_SIZE_MULTIPLIER * minimum_alignment_const;
1675+
int64_t deviceMinimumAlignment
1676+
= std::max(minimum_alignment_const, int64_t(WEIGHT_ELEM_PER_BYTE * 32 / sizeof(WeightStorage)));
1677+
int old_hidden_size = hidden_size;
1678+
if (hidden_size == DEFAULT_HIDDEN_SIZE)
1679+
{
1680+
hidden_size = default_hidden_size;
1681+
}
1682+
if (hidden_size == mDeviceMinimumAlignment)
1683+
{
1684+
hidden_size = deviceMinimumAlignment;
1685+
}
1686+
if (hidden_size % minimum_alignment_const != 0)
1687+
{
1688+
hidden_size = ((hidden_size / minimum_alignment_const) + 1) * minimum_alignment_const;
1689+
}
1690+
if (hidden_size != old_hidden_size)
1691+
{
1692+
GTEST_LOG_(INFO) << "Appending NO_SMEM test with hidden size: " << hidden_size
1693+
<< " (was: " << old_hidden_size << ")";
1694+
}
1695+
this->BasicPermuteTestInternal(k, hidden_size, num_experts, num_tokens, true);
1696+
}
1697+
}
16531698
1654-
auto test_archs = getAllTileConfigsToTest();
1699+
template <class TypeParam_>
1700+
void MixtureOfExpertsTest<TypeParam_>::BasicPermuteTestInternal(
1701+
int k, int64_t hidden_size, int64_t num_experts, int64_t num_tokens, bool allow_no_smem)
1702+
{
1703+
auto test_archs = getAllTileConfigsToTest(allow_no_smem);
16551704
for (auto [gemm1, gemm2] : test_archs)
16561705
{
16571706
mInternalSelectedConfig1 = gemm1;

Comments (0)