@@ -174,7 +174,7 @@ protected:
174174 using WeightStorage = std::conditional_t <WEIGHT_ELEM_PER_BYTE == 2 , uint8_t , WeightType>;
175175 constexpr static int64_t HIDDEN_SIZE_MULTIPLIER = 16 ;
176176 constexpr static int64_t MINIMUM_BYTE_ALIGNMENT
177- = MX_QUANT_WEIGHT ? 64 : 256 / 8 ; // NoSmem requires 256 bits alignment, MX quant requires 64 bytes
177+ = MX_QUANT_WEIGHT ? 64 : 128 / 8 ; // TMA requires 128 bits alignment, MX quant requires 64 bytes
178178 constexpr static int64_t MINIMUM_ALIGNMENT_CONST
179179 = MINIMUM_BYTE_ALIGNMENT * WEIGHT_ELEM_PER_BYTE / sizeof (WeightStorage);
180180 constexpr static int64_t DEFAULT_HIDDEN_SIZE = HIDDEN_SIZE_MULTIPLIER * MINIMUM_ALIGNMENT_CONST;
@@ -1127,23 +1127,33 @@ protected:
11271127 return tactics;
11281128 }
11291129
1130- auto selectTacticsForArch (int sm, bool exact_match = false )
1130+ auto selectTacticsForArch (int sm, bool exact_match = false , bool allow_no_smem = false )
11311131 {
11321132 bool is_tma_warp_specialized = sm >= 90 && !INT_QUANT;
1133+ auto filter_epi_schd = [sm, allow_no_smem](auto & c)
1134+ {
1135+ if (sm >= 100 && sm < 120 )
1136+ {
1137+ return c.epilogue_schedule
1138+ == (allow_no_smem ? tensorrt_llm::cutlass_extensions::EpilogueScheduleType::NO_SMEM
1139+ : tensorrt_llm::cutlass_extensions::EpilogueScheduleType::TMA);
1140+ }
1141+ return true ;
1142+ };
11331143 auto epilogue_fusion_type = (is_tma_warp_specialized && mUseFusedFinalize )
11341144 ? tensorrt_llm::cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE
11351145 : tensorrt_llm::cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::NONE;
11361146 auto smExact = [exact_match, sm](auto & c) { return !exact_match || c.sm_version == sm; };
11371147 auto tactics1 = getFilteredConfigs (sm, MoeGemmId::GEMM_1);
11381148 auto tactics2 = getFilteredConfigs (sm, MoeGemmId::GEMM_2);
11391149 auto it1 = std::find_if (tactics1.begin (), tactics1.end (),
1140- [is_tma_warp_specialized, smExact](auto & c)
1141- { return c.is_tma_warp_specialized == is_tma_warp_specialized && smExact (c); });
1150+ [is_tma_warp_specialized, smExact, filter_epi_schd ](auto & c)
1151+ { return c.is_tma_warp_specialized == is_tma_warp_specialized && smExact (c) && filter_epi_schd (c) ; });
11421152 auto it2 = std::find_if (tactics2.begin (), tactics2.end (),
1143- [is_tma_warp_specialized, epilogue_fusion_type, smExact](auto & c)
1153+ [is_tma_warp_specialized, epilogue_fusion_type, smExact, filter_epi_schd ](auto & c)
11441154 {
11451155 return c.is_tma_warp_specialized == is_tma_warp_specialized
1146- && c.epilogue_fusion_type == epilogue_fusion_type && smExact (c);
1156+ && c.epilogue_fusion_type == epilogue_fusion_type && smExact (c) && filter_epi_schd (c) ;
11471157 });
11481158 if (it1 == tactics1.end () || it2 == tactics2.end ())
11491159 {
@@ -1159,7 +1169,7 @@ protected:
11591169 using ConfigsToTestVec = std::vector<std::pair<tensorrt_llm::cutlass_extensions::CutlassGemmConfig,
11601170 tensorrt_llm::cutlass_extensions::CutlassGemmConfig>>;
11611171
1162- auto getAllTileConfigsToTest ()
1172+ auto getAllTileConfigsToTest (bool allow_no_smem = false )
11631173 {
11641174 if (mOverrideSelectedConfig1 && mOverrideSelectedConfig2 )
11651175 {
@@ -1168,16 +1178,16 @@ protected:
11681178
11691179 int sm = getSMVersion ();
11701180 bool needs_exact_match = sm == 103 && NVFP4;
1171- ConfigsToTestVec tactics = {selectTacticsForArch (sm, needs_exact_match)};
1181+ ConfigsToTestVec tactics = {selectTacticsForArch (sm, needs_exact_match, allow_no_smem )};
11721182 if (sm == 103 && NVFP4)
11731183 {
11741184 // SM103 NVFP4 should also test SM100f kernels
1175- tactics.push_back (selectTacticsForArch (100 , true ));
1185+ tactics.push_back (selectTacticsForArch (100 , true , allow_no_smem ));
11761186 }
11771187 if (sm >= 90 && !ANY_FPX)
11781188 {
11791189 // SM90+ should also grab some configs for SM80 to test them
1180- tactics.push_back (selectTacticsForArch (80 , true ));
1190+ tactics.push_back (selectTacticsForArch (80 , true , allow_no_smem ));
11811191 }
11821192 return tactics;
11831193 }
@@ -1530,6 +1540,9 @@ protected:
15301540 void BasicPermuteTest (
15311541 int k = 1 , int64_t hidden_size = DEFAULT_HIDDEN_SIZE, int64_t num_experts = 4 , int64_t num_tokens = 3 );
15321542
1543+ void BasicPermuteTestInternal (int k = 1 , int64_t hidden_size = DEFAULT_HIDDEN_SIZE, int64_t num_experts = 4 ,
1544+ int64_t num_tokens = 3 , bool allow_no_smem = false );
1545+
15331546 std::vector<int > calcPermuteMapExpertParallel (std::vector<int > const & expected_experts);
15341547
15351548 void ExpertParallelTest (int k = 1 , int64_t hidden_size = DEFAULT_HIDDEN_SIZE, int64_t num_experts = 4 ,
@@ -1650,8 +1663,44 @@ void MixtureOfExpertsTest<TypeParam_>::BasicPermuteTest(
16501663 return ;
16511664 }
16521665 }
1666+ this ->BasicPermuteTestInternal (k, hidden_size, num_experts, num_tokens, false );
1667+ int sm = getSMVersion ();
1668+ if (sm >= 100 && sm < 120 && (!ANY_FP4 || sm == 103 ))
1669+ {
1670+ // Test NO_SMEM for: SM103 all, or SM100 non-FP4
1671+ int64_t minimum_byte_alignment
1672+ = MX_QUANT_WEIGHT ? 64 : 256 / 8 ; // NO_SMEM requires 256 bits alignment, MX quant requires 64 bytes
1673+ int64_t minimum_alignment_const = minimum_byte_alignment * WEIGHT_ELEM_PER_BYTE / sizeof (WeightStorage);
1674+ int64_t default_hidden_size = HIDDEN_SIZE_MULTIPLIER * minimum_alignment_const;
1675+ int64_t deviceMinimumAlignment
1676+ = std::max (minimum_alignment_const, int64_t (WEIGHT_ELEM_PER_BYTE * 32 / sizeof (WeightStorage)));
1677+ int old_hidden_size = hidden_size;
1678+ if (hidden_size == DEFAULT_HIDDEN_SIZE)
1679+ {
1680+ hidden_size = default_hidden_size;
1681+ }
1682+ if (hidden_size == mDeviceMinimumAlignment )
1683+ {
1684+ hidden_size = deviceMinimumAlignment;
1685+ }
1686+ if (hidden_size % minimum_alignment_const != 0 )
1687+ {
1688+ hidden_size = ((hidden_size / minimum_alignment_const) + 1 ) * minimum_alignment_const;
1689+ }
1690+ if (hidden_size != old_hidden_size)
1691+ {
1692+ GTEST_LOG_ (INFO) << " Appending NO_SMEM test with hidden size: " << hidden_size
1693+ << " (was: " << old_hidden_size << " )" ;
1694+ }
1695+ this ->BasicPermuteTestInternal (k, hidden_size, num_experts, num_tokens, true );
1696+ }
1697+ }
16531698
1654- auto test_archs = getAllTileConfigsToTest ();
1699+ template <class TypeParam_ >
1700+ void MixtureOfExpertsTest<TypeParam_>::BasicPermuteTestInternal(
1701+ int k, int64_t hidden_size, int64_t num_experts, int64_t num_tokens, bool allow_no_smem)
1702+ {
1703+ auto test_archs = getAllTileConfigsToTest (allow_no_smem);
16551704 for (auto [gemm1, gemm2] : test_archs)
16561705 {
16571706 mInternalSelectedConfig1 = gemm1;
0 commit comments