@@ -5,6 +5,12 @@ tags: [cutlass, cute]
55excerpt : intro for cute mma
66---
77
8+ [TOC]
9+
10+ ## arch
11+ ### mma
12+ ### copy
13+
814## MMA
915``` cpp
1016struct SM80_16x8x8_F32F16F16F32_TN
@@ -32,6 +38,52 @@ struct SM80_16x8x8_F32F16F16F32_TN
3238 "f"(c0), "f"(c1), "f"(c2), "f"(c3));
3339 }
3440};
41+
42+
43+ // (T32,V1) -> (M8,N8): thread mode of shape (4,8) = 32 threads, 1 value per thread
44+ using SM80_8x4 = Layout<Shape <Shape < _4,_8>,_1>,
45+ Stride<Stride< _8,_1>,_0>>;
46+ // (T32,V2) -> (M8,N8): thread mode of shape (4,8) = 32 threads, 2 values per thread
47+ using SM80_8x8_Row = Layout<Shape <Shape < _4,_8>,_2>,
48+ Stride<Stride<_16,_1>,_8>>;
49+ // (T32,V4) -> (M8,N16): thread mode of shape (4,8) = 32 threads, 4 values per thread
50+ using SM80_8x16_Row = Layout<Shape <Shape < _4,_8>,_4>,
51+ Stride<Stride<_32,_1>,_8>>;
52+ // (T32,V4) -> (M16,N8): thread mode of shape (4,8) = 32 threads, value mode of shape (2,2) = 4 values per thread
53+ using SM80_16x8_Row = Layout<Shape <Shape < _4,_8>,Shape < _2,_2>>,
54+ Stride<Stride<_32,_1>,Stride<_16,_8>>>;
55+
56+ ////////////////////////////////////////////
57+ //////// fp16 = fp16 * fp16 + fp16 /////////
58+ ////////////////////////////////////////////
59+ template <> // Full specialization of MMA_Traits for the fp16-accumulate 16x8x8 operation
60+ struct MMA_Traits<SM80_16x8x8_F16F16F16F16_TN>
61+ {
62+ using ValTypeD = half_t; // all four operands D, A, B, C are half_t
63+ using ValTypeA = half_t;
64+ using ValTypeB = half_t;
65+ using ValTypeC = half_t;
66+
67+ using Shape_MNK = Shape<_16,_8,_8>; // logical MMA shape: M=16, N=8, K=8
68+ using ThrID = Layout<_32>; // 32 logical thread ids
69+ using ALayout = SM80_16x8_Row; // (tid,vid) -> (m,k) over the 16x8 A tile
70+ using BLayout = SM80_8x8_Row; // (tid,vid) -> (n,k) over the 8x8 B tile
71+ using CLayout = SM80_16x8_Row; // (tid,vid) -> (m,n) over the 16x8 C tile
72+ };
73+
74+ //////////////////////////////////////////
75+ /////// fp32 = fp16 * fp16 + fp32 ////////
76+ //////////////////////////////////////////
77+ template <> // Full specialization of MMA_Traits for the fp32-accumulate 16x8x8 operation
78+ struct MMA_Traits<SM80_16x8x8_F32F16F16F32_TN>
79+ : MMA_Traits<SM80_16x8x8_F16F16F16F16_TN> // inherits Shape_MNK, ThrID and the A/B/C layouts; only the value types are redefined below
80+ {
81+ using ValTypeD = float; // result D is float
82+ using ValTypeA = half_t; // inputs A and B stay half_t
83+ using ValTypeB = half_t;
84+ using ValTypeC = float; // accumulator input C is float
85+ };
86+
3587```
3688### MMA Operation
3789- Operation 结构体名称
@@ -43,4 +95,168 @@ struct SM80_16x8x8_F32F16F16F32_TN
4395 - F32F16F16F32 分别指四个矩阵操作数的元素类型。MMA 用于计算 D=A*B+C, 对应数据类型从左到右读取(D-F32, A-F16, B-F16, C-F32). 对应 ptx 指令名称为 .f32.f16.f16.f32
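4496 - NT 代表 A 矩阵 column major(M-major), B 矩阵 row major(N-major), 对应 ptx 指令为 .col.row; 而上文示例 SM80_16x8x8_F32F16F16F32_TN 中的 TN 则代表 A 矩阵 row major(K-major), B 矩阵 column major(K-major), 对应 ptx 指令为 .row.col.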
4597
46- ### MMA_Traits
98+ ### MMA_Traits
99+ ```cpp
100+ template <class MMAOperation, class... MMAOpArgs> // Primary template: used only when no specialization matches MMAOperation
101+ struct MMA_Traits
102+ {
103+ static_assert(sizeof(MMAOperation) == 0, "MMA_Traits not implemented for this MMA_Operation."); // condition depends on MMAOperation and is never true, so instantiating the primary template fails with this message
104+ };
105+
106+ template <class D, class A, class B, class C> // Partial specialization for UniversalFMA: value types come straight from its template arguments
107+ struct MMA_Traits<UniversalFMA<D,A,B,C>>
108+ {
109+ using ValTypeD = D;
110+ using ValTypeA = A;
111+ using ValTypeB = B;
112+ using ValTypeC = C;
113+
114+ // Logical shape of the MMA (M=1, N=1, K=1)
115+ using Shape_MNK = Shape<_1,_1,_1>;
116+
117+ // Logical thread id (tid) -> tidx (a single thread)
118+ using ThrID = Layout<_1>;
119+
120+ // (Logical thread id (tid), Logical value id (vid)) -> coord
121+
122+ // (tid,vid) -> (m,k)
123+ using ALayout = Layout<Shape<_1,_1>>;
124+ // (tid,vid) -> (n,k)
125+ using BLayout = Layout<Shape<_1,_1>>;
126+ // (tid,vid) -> (m,n)
127+ using CLayout = Layout<Shape<_1,_1>>;
128+ };
129+
130+ // Extract an MMA_Op from an MMA_Traits: MMA_Op<MMA_Traits<Op, Args...>>::type is Op
131+ template <class MMA_Traits>
132+ struct MMA_Op {}; // primary template: no nested `type` when the argument is not an MMA_Traits specialization
133+
134+ template <class MMA_Op_Arg, class... Args>
135+ struct MMA_Op<MMA_Traits<MMA_Op_Arg, Args...>> {
136+ using type = MMA_Op_Arg; // first template argument of the MMA_Traits
137+ };
138+ ```
139+ ### TiledMMA
140+
141+ ## Atom
142+ ### MMA_Atom
143+ ``` cpp
144+ template <class ... Args>
145+ struct MMA_Atom; // declaration only; behavior comes from the specializations
146+
147+ template <class MMAOperation > // Specialization for a bare operation type
148+ struct MMA_Atom<MMAOperation > : MMA_Atom<MMA_Traits<MMAOperation >> // wraps the operation in MMA_Traits and inherits from that atom
149+ {};
150+
151+ template <class MMAOperation , class... Args> // Specialization for an atom given as MMA_Traits<Operation, Args...>
152+ struct MMA_Atom<MMA_Traits<MMAOperation, Args...>>
153+ : MMA_Traits<MMAOperation, Args...> // inherits from the traits, so traits members are also members of the atom
154+ {
155+ using MMA_Op = MMAOperation;
156+ using Traits = MMA_Traits<MMAOperation, Args...>;
157+
158+ // Element value types from the MMA_Traits
159+ using ValTypeD = typename Traits::ValTypeD;
160+ using ValTypeA = typename Traits::ValTypeA;
161+ using ValTypeB = typename Traits::ValTypeB;
162+ using ValTypeC = typename Traits::ValTypeC;
163+
164+ // Thr-Val layouts from the MMA_Traits (re-exported under LayoutX_TV names)
165+ using Shape_MNK = typename Traits::Shape_MNK;
166+ using ThrID = typename Traits::ThrID;
167+ using LayoutC_TV = typename Traits::CLayout;
168+ using LayoutA_TV = typename Traits::ALayout;
169+ using LayoutB_TV = typename Traits::BLayout;
170+
171+ // Fragment value types from the MMA_Traits (optional, defaults to Val type)
172+ using FrgTypeD = typename detail::FrgTypeC_or_Default<Traits >::type; // note: D is resolved through the C selector (FrgTypeC_or_Default)
173+ using FrgTypeA = typename detail::FrgTypeA_or_Default<Traits >::type;
174+ using FrgTypeB = typename detail::FrgTypeB_or_Default<Traits >::type;
175+ using FrgTypeC = typename detail::FrgTypeC_or_Default<Traits >::type;
176+ };
177+
178+ template <class TiledMMA , class ThrCoord>
179+ struct ThrMMA; // forward declaration only
180+
181+ // @tparam MMA_Atom The MMA_Atom to use in the TiledMMA
182+ // @tparam AtomLayoutMNK The MNK-tiling of the Atom to be performed.
183+ // @tparam PermutationMNK Permutations to apply to each MNK-mode before tiling for the Atom.
184+ template <class MMA_Atom,
185+ class AtomLayoutMNK,
186+ class PermutationMNK = Tile<Underscore,Underscore,Underscore>> // default: an Underscore in each of the three modes
187+ struct TiledMMA : MMA_Atom // inherits from the atom, so atom members are reachable through the TiledMMA
188+ {
189+ using Atom = MMA_Atom;
190+ using AtomShape_MNK = typename MMA_Atom::Shape_MNK;
191+ using AtomThrID = typename MMA_Atom::ThrID;
192+ using AtomLayoutC_TV = typename MMA_Atom::LayoutC_TV;
193+ using AtomLayoutA_TV = typename MMA_Atom::LayoutA_TV;
194+ using AtomLayoutB_TV = typename MMA_Atom::LayoutB_TV;
195+
196+ static_assert ( rank_v<AtomLayoutMNK > == 3, "TiledMMA requires rank-3 AtomLayoutMNK");
197+ static_assert( rank_v<PermutationMNK > == 3, "TiledMMA requires rank-3 PermutationMNK");
198+ static_assert( is_tuple<PermutationMNK >::value, "TiledMMA requires independent permutations of MNK.");
199+ static_assert(is_static<PermutationMNK >::value, "TiledMMA requires static permutations of MNK.");
200+
201+ using ThrLayoutVMNK = decltype(tiled_product(AtomThrID{}, AtomLayoutMNK{})); // type of the atom's thread ids tiled by the MNK atom layout
202+ ThrLayoutVMNK thr_layout_vmnk_ ; // data member holding that thread layout
203+
204+ ... // remaining members elided in this excerpt
205+ };
206+
207+ template <class TiledMMA, class ThrVMNK>
208+ struct ThrMMA : TiledMMA // definition of ThrMMA; inherits from the TiledMMA it is created from
209+ {
210+ ... // members elided in this excerpt
211+ };
212+ ```
213+
214+ - make_tiled_mma
215+
216+ ### Copy_Atom
217+ ```cpp
218+ template <class... Args>
219+ struct Copy_Atom; // declaration only; behavior comes from the specializations
220+
221+ template <class CopyOperation, class CopyInternalType> // Specialization for a bare copy operation plus an internal type
222+ struct Copy_Atom<CopyOperation, CopyInternalType> : Copy_Atom<Copy_Traits<CopyOperation>, CopyInternalType> // wraps the operation in Copy_Traits and inherits from that atom
223+ {};
224+
225+ template <class... Args, class CopyInternalType> // Specialization for an atom given as Copy_Traits<Args...>
226+ struct Copy_Atom<Copy_Traits<Args...>, CopyInternalType>
227+ : Copy_Traits<Args...> // inherits from the traits, so traits members are also members of the atom
228+ {
229+ ... // members elided in this excerpt
230+ };
231+
232+ template <class TiledCopy, class ThrIdx>
233+ struct ThrCopy; // forward declaration only
234+
235+ template <class Copy_Atom,
236+ class LayoutCopy_TV, // (tid,vid) -> coord [Need not be 2D...]
237+ class ShapeTiler_MN> // coord space
238+ struct TiledCopy : Copy_Atom // inherits from the atom, so atom members are reachable through the TiledCopy
239+ {
240+ ... // members elided in this excerpt
241+ };
242+
243+ template <class TiledCopy, class ThrIdx>
244+ struct ThrCopy // definition of ThrCopy; unlike ThrMMA it does not inherit from its TiledCopy parameter
245+ {
246+ ... // members elided in this excerpt
247+ };
248+
249+ template <class... Args, // Builds a TiledCopy from a Copy_Atom, a thread-value layout type and a tiler type
250+ class LayoutCopy_TV,
251+ class Tiler>
252+ CUTE_HOST_DEVICE
253+ auto
254+ make_tiled_copy_impl(Copy_Atom<Args...> const& atom,
255+ LayoutCopy_TV const&, // unnamed: only the type is used
256+ Tiler const&) // unnamed: only the type is used
257+ {
258+ return TiledCopy<Copy_Atom<Args...>, LayoutCopy_TV, Tiler>{atom}; // the atom is the only value passed on; layout and tiler enter as template arguments
259+ }
260+ ```
261+
262+ - make_tiled_copy
0 commit comments