@@ -14,45 +14,44 @@ namespace turbomind::gemm {
template<int TILE_M, int TILE_N, int TILE_K, int BATCH_M, int BATCH_N, int PIPE_M, int PIPE_N>
struct ScaledGmmaFP8_TN {

-    static constexpr auto select_gmma_operation()
-    {
-        static_assert(TILE_M % (BATCH_M * PIPE_M) == 0);
-        static_assert(TILE_N % (BATCH_N * PIPE_N) == 0);
-
-        constexpr int M = TILE_M / (BATCH_M * PIPE_M);
-        constexpr int N = TILE_N / (BATCH_N * PIPE_N);
+    template<int tile_m  = TILE_M,
+             int tile_n  = TILE_N,
+             int batch_m = BATCH_M,
+             int batch_n = BATCH_N,
+             int pipe_m  = PIPE_M,
+             int pipe_n  = PIPE_N>
+    struct select_gmma_operation {
+        static constexpr int M = tile_m / (batch_m * pipe_m);
+        static constexpr int N = tile_n / (batch_n * pipe_n);

+        static_assert(tile_m % (batch_m * pipe_m) == 0);
+        static_assert(tile_n % (batch_n * pipe_n) == 0);

        static_assert(M % 64 == 0);

-        using namespace cute::SM90::GMMA;
-
-        if constexpr (N % 256 == 0) {
-            return MMA_64x256x32_F32E4M3E4M3_SS_TN<>{};
-        }
-        else if constexpr (N % 224 == 0) {
-            return MMA_64x224x32_F32E4M3E4M3_SS_TN<>{};
-        }
-        else if constexpr (N % 192 == 0) {
-            return MMA_64x192x32_F32E4M3E4M3_SS_TN<>{};
-        }
-        else if constexpr (N % 160 == 0) {
-            return MMA_64x160x32_F32E4M3E4M3_SS_TN<>{};
-        }
-        else if constexpr (N % 128 == 0) {
-            return MMA_64x128x32_F32E4M3E4M3_SS_TN<>{};
-        }
-        else if constexpr (N % 96 == 0) {
-            return MMA_64x96x32_F32E4M3E4M3_SS_TN<>{};
-        }
-        else if constexpr (N % 64 == 0) {
-            return MMA_64x64x32_F32E4M3E4M3_SS_TN<>{};
-        }
-        else {
-            static_assert(N == 0, "unsupported configuration");
-        }
-    }
+        using type = std::conditional_t<
+            N % 256 == 0,
+            cute::SM90::GMMA::MMA_64x256x32_F32E4M3E4M3_SS_TN<>,
+            std::conditional_t<
+                N % 224 == 0,
+                cute::SM90::GMMA::MMA_64x224x32_F32E4M3E4M3_SS_TN<>,
+                std::conditional_t<
+                    N % 192 == 0,
+                    cute::SM90::GMMA::MMA_64x192x32_F32E4M3E4M3_SS_TN<>,
+                    std::conditional_t<
+                        N % 160 == 0,
+                        cute::SM90::GMMA::MMA_64x160x32_F32E4M3E4M3_SS_TN<>,
+                        std::conditional_t<
+                            N % 128 == 0,
+                            cute::SM90::GMMA::MMA_64x128x32_F32E4M3E4M3_SS_TN<>,
+                            std::conditional_t<N % 96 == 0,
+                                               cute::SM90::GMMA::MMA_64x96x32_F32E4M3E4M3_SS_TN<>,
+                                               std::conditional_t<N % 64 == 0,
+                                                                  cute::SM90::GMMA::MMA_64x64x32_F32E4M3E4M3_SS_TN<>,
+                                                                  void>>>>>>>;
+        static_assert(!std::is_same_v<type, void>, "unsupported configuration");
+    };

-    using Operation = decltype(select_gmma_operation());
+    using Operation = typename select_gmma_operation<>::type;

    static constexpr typename cute::MMA_Traits<Operation>::Shape_MNK OP_Shape{};

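A note on the refactor above: the old `if constexpr` ladder computed a *value* and recovered its type via `decltype`, while the new `select_gmma_operation` trait computes the *type* directly, so it can feed an alias without materializing an object. Below is a minimal, self-contained sketch of the same `std::conditional_t` selection pattern; the `Atom*` tags are hypothetical stand-ins for the `cute::SM90::GMMA` atoms so the snippet compiles without CUTLASS:

```cpp
#include <type_traits>

// Hypothetical placeholder tags; in the diff these are the real
// MMA_64x{256,128,64}x32_F32E4M3E4M3_SS_TN<> GMMA atoms.
struct Atom256 {};
struct Atom128 {};
struct Atom64 {};

// Pick the widest atom whose N extent divides N, else fall through to void.
template<int N>
using select_atom_t = std::conditional_t<
    N % 256 == 0, Atom256,
    std::conditional_t<
        N % 128 == 0, Atom128,
        std::conditional_t<N % 64 == 0, Atom64, void>>>;

// The largest divisor is tested first, so the widest atom wins.
static_assert(std::is_same_v<select_atom_t<512>, Atom256>);
// 192 = 64 * 3 is divisible by neither 256 nor 128; falls through to Atom64.
static_assert(std::is_same_v<select_atom_t<192>, Atom64>);
// Unsupported sizes collapse to void, which the trait's static_assert rejects.
static_assert(std::is_same_v<select_atom_t<48>, void>);

int main() {}
```

Because the chain tests divisors from largest to smallest, the widest supported GMMA atom whose N extent divides the per-pipe tile N is always preferred.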
@@ -242,11 +241,11 @@ struct ScaledGmmaFP8_TN {
                                int n = ((i_n * PIPE_N) + p_n * BATCH_N) + b_n;
                                func(frag[i_m][i_n][p_m][p_n][b_m][b_n], m, n);
                            }  // BATCH_N
-                        }      // BATCH_M
-                    }          // PIPE_N
-                }              // PIPE_M
-            }                  // ITER_N
-        }                      // ITER_M
+                        }  // BATCH_M
+                    }  // PIPE_N
+                }  // PIPE_M
+            }  // ITER_N
+        }  // ITER_M

    }

    template<class Frag, class Func>
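For orientation, the closing braces in the hunk above are the tail of the fragment traversal over (ITER_M, ITER_N, PIPE_M, PIPE_N, BATCH_M, BATCH_N). A minimal host-side sketch of that loop nest follows, with hypothetical tile constants; the `n` formula is copied from the context line in the diff, and the `m` formula is assumed by symmetry:

```cpp
#include <cstdio>

// Hypothetical loop extents for illustration only.
constexpr int ITER_M = 2, ITER_N = 2, PIPE_M = 1, PIPE_N = 2, BATCH_M = 1, BATCH_N = 2;

// Visit every fragment with its logical (m, n) coordinates, innermost
// loop being BATCH_N, matching the brace comments in the diff above.
template<class Frag, class Func>
void for_each(Frag& frag, Func&& func)
{
    for (int i_m = 0; i_m < ITER_M; ++i_m) {
        for (int i_n = 0; i_n < ITER_N; ++i_n) {
            for (int p_m = 0; p_m < PIPE_M; ++p_m) {
                for (int p_n = 0; p_n < PIPE_N; ++p_n) {
                    for (int b_m = 0; b_m < BATCH_M; ++b_m) {
                        for (int b_n = 0; b_n < BATCH_N; ++b_n) {
                            int m = ((i_m * PIPE_M) + p_m * BATCH_M) + b_m;  // assumed by symmetry
                            int n = ((i_n * PIPE_N) + p_n * BATCH_N) + b_n;  // as in the diff
                            func(frag[i_m][i_n][p_m][p_n][b_m][b_n], m, n);
                        }  // BATCH_N
                    }  // BATCH_M
                }  // PIPE_N
            }  // PIPE_M
        }  // ITER_N
    }  // ITER_M
}

int main()
{
    float frag[ITER_M][ITER_N][PIPE_M][PIPE_N][BATCH_M][BATCH_N]{};
    for_each(frag, [](float& v, int m, int n) {
        v = 1.f;
        std::printf("m=%d n=%d\n", m, n);
    });
}
```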