11
11
12
12
namespace turbomind ::gemm {
13
13
14
+ namespace {
15
+
16
+ template <int tile>
17
+ struct select_gmma_operation ;
18
+ template <>
19
+ struct select_gmma_operation <256 > {
20
+ using type = cute::SM90::GMMA::MMA_64x256x32_F32E4M3E4M3_SS_TN<>;
21
+ };
22
+ template <>
23
+ struct select_gmma_operation <224 > {
24
+ using type = cute::SM90::GMMA::MMA_64x224x32_F32E4M3E4M3_SS_TN<>;
25
+ };
26
+ template <>
27
+ struct select_gmma_operation <192 > {
28
+ using type = cute::SM90::GMMA::MMA_64x192x32_F32E4M3E4M3_SS_TN<>;
29
+ };
30
+ template <>
31
+ struct select_gmma_operation <160 > {
32
+ using type = cute::SM90::GMMA::MMA_64x160x32_F32E4M3E4M3_SS_TN<>;
33
+ };
34
+ template <>
35
+ struct select_gmma_operation <128 > {
36
+ using type = cute::SM90::GMMA::MMA_64x128x32_F32E4M3E4M3_SS_TN<>;
37
+ };
38
+ template <>
39
+ struct select_gmma_operation <96 > {
40
+ using type = cute::SM90::GMMA::MMA_64x96x32_F32E4M3E4M3_SS_TN<>;
41
+ };
42
+ template <>
43
+ struct select_gmma_operation <64 > {
44
+ using type = cute::SM90::GMMA::MMA_64x64x32_F32E4M3E4M3_SS_TN<>;
45
+ };
46
+
47
+ } // namespace
48
+
14
49
template <int TILE_M, int TILE_N, int TILE_K, int BATCH_M, int BATCH_N, int PIPE_M, int PIPE_N>
15
50
struct ScaledGmmaFP8_TN {
16
-
17
- static constexpr auto select_gmma_operation ()
51
    // Selects the widest GMMA tile extent along N (from the set supported by
    // the SM90 FP8 atoms above: 256 down to 64) that evenly divides the
    // per-iteration N extent. The result indexes select_gmma_operation<> to
    // pick the concrete MMA atom. Evaluated entirely at compile time.
    static constexpr auto select_gmma_size()
    {
        // Tile extents must partition evenly across batches and pipeline stages.
        static_assert(TILE_M % (BATCH_M * PIPE_M) == 0);
        static_assert(TILE_N % (BATCH_N * PIPE_N) == 0);
        // NOTE(review): the definitions of M and N are elided in this view
        // (diff hunk skips the intervening lines) — presumably the per-atom
        // extents derived from TILE_M/TILE_N and the batch/pipe factors;
        // confirm against the full source.
        // All SM90 GMMA atoms used here have a fixed M extent of 64.
        static_assert(M % 64 == 0);
        // Prefer the largest N extent whose atom exists; larger atoms issue
        // fewer instructions per tile.
        if constexpr (N % 256 == 0) {
            return 256;
        }
        else if constexpr (N % 224 == 0) {
            return 224;
        }
        else if constexpr (N % 192 == 0) {
            return 192;
        }
        else if constexpr (N % 160 == 0) {
            return 160;
        }
        else if constexpr (N % 128 == 0) {
            return 128;
        }
        else if constexpr (N % 96 == 0) {
            return 96;
        }
        else if constexpr (N % 64 == 0) {
            return 64;
        }
        else {
            // static_assert on a value-dependent expression so it only fires
            // when this branch is actually instantiated.
            static_assert(N == 0, " unsupported configuration");
        }
    }
54
86
55
    // Concrete GMMA atom chosen for this tile configuration: the selected
    // N extent is mapped to its MMA type via the specialization table.
    using Operation = typename select_gmma_operation<select_gmma_size()>::type;

    // Compile-time (M, N, K) instruction shape of the selected atom, as
    // published by its cute::MMA_Traits.
    static constexpr typename cute::MMA_Traits<Operation>::Shape_MNK OP_Shape{};
58
90
0 commit comments