1010 make_full_bank_swizzled_layout ,
1111 make_half_bank_swizzled_layout ,
1212 make_quarter_bank_swizzled_layout ,
13+ make_linear_layout ,
1314)
1415from tvm .runtime import convert
1516from tilelang .intrinsics .mma_layout import (shared_16x8_to_mma_32x4_layout_sr_a ,
@@ -131,13 +132,20 @@ def _initialize_micro_size(self, m_dim: int = 16, k_dim: int = 16):
131132 self .micro_size_k = k_dim
132133
def _determinate_swizzle_mode(self, buffer: Buffer, layout: Layout) -> SwizzleMode:
    """Classify ``layout`` as one of the supported swizzle modes.

    ``layout`` is compared, in order, against a linear layout and the
    quarter-, half- and full-bank swizzled layouts. Each candidate is
    built from the last two extents of ``buffer.shape`` and the bit
    width of ``buffer.dtype``. The first match decides the result.

    Args:
        buffer: Buffer whose last two shape extents and dtype
            parameterise the candidate layouts.
        layout: Layout to classify, or ``None`` when no layout is set.

    Returns:
        ``SwizzleMode.NONE`` when ``layout`` is ``None`` or equals the
        linear layout; otherwise ``SWIZZLE_32B``, ``SWIZZLE_64B`` or
        ``SWIZZLE_128B`` for the quarter-, half- and full-bank swizzled
        layouts respectively.

    Raises:
        ValueError: If ``layout`` matches none of the candidates.
    """
    # Decide the "no layout" case first, so the buffer's shape and
    # dtype are only inspected when a comparison is actually needed.
    if layout is None:
        return SwizzleMode.NONE

    # Original author's note: same behavior as
    # src/layout/gemm_layouts.cc::makeGemmABLayoutHopper
    mat_stride = int(buffer.shape[-2])
    mat_continuous = int(buffer.shape[-1])
    element_size = DataType(buffer.dtype).bits  # element width in bits

    # Candidates are constructed lazily, one at a time, so a later
    # layout is only built when every earlier comparison has failed.
    if layout.is_equal(make_linear_layout(mat_stride, mat_continuous)):
        return SwizzleMode.NONE
    if layout.is_equal(
            make_quarter_bank_swizzled_layout(mat_stride, mat_continuous, element_size)):
        return SwizzleMode.SWIZZLE_32B
    if layout.is_equal(
            make_half_bank_swizzled_layout(mat_stride, mat_continuous, element_size)):
        return SwizzleMode.SWIZZLE_64B
    if layout.is_equal(
            make_full_bank_swizzled_layout(mat_stride, mat_continuous, element_size)):
        return SwizzleMode.SWIZZLE_128B
    raise ValueError(f"Unsupported swizzle mode: {layout}")
@@ -173,7 +181,11 @@ def wgmma(self,
173181 a_swizzle_mode = self ._determinate_swizzle_mode (A_buf , self .a_shared_layout )
174182 b_swizzle_mode = self ._determinate_swizzle_mode (B_buf , self .b_shared_layout )
175183
176- elems_in_bytes = DataType (self .a_dtype ).bits // 8
184+ elems_in_bits = DataType (self .a_dtype ).bits
185+ elems_in_bytes = elems_in_bits // 8
186+
187+ a_swizzle_atom_elems = a_swizzle_mode .swizzle_byte_size () // elems_in_bytes
188+ b_swizzle_atom_elems = n_dim if b_swizzle_mode .is_none () else b_swizzle_mode .swizzle_byte_size () // elems_in_bytes
177189
178190 # by default, we utilize non-swizzle layout offset
179191 a_leading_byte_offset = (8 * 8 * elems_in_bytes ) if a_is_k_major else (8 * m_dim *
@@ -186,52 +198,59 @@ def wgmma(self,
186198 # https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset
187199 if a_is_k_major :
188200 a_leading_byte_offset = 16
201+ a_stride_byte_offset = 8 * a_swizzle_mode .swizzle_byte_size ()
189202 else :
190203 # MN Major
191204 # LBO represents the distance between two atoms along the M dimension
192205 # SBO represents the distance between two atoms along the K dimension
193- a_leading_byte_offset = a_swizzle_mode .swizzle_atom_size ()
194- a_stride_byte_offset = 8 * 64 * elems_in_bytes
206+ a_m_axis_atoms = m_dim // a_swizzle_atom_elems
207+ if a_m_axis_atoms <= 1 :
208+ a_leading_byte_offset = 0
209+ else :
210+ a_leading_byte_offset = 8 * a_swizzle_mode .swizzle_atom_size () * (a_swizzle_mode .swizzle_byte_size () // elems_in_bytes )
211+
212+ if a_m_axis_atoms <= 1 :
213+ a_stride_byte_offset = 8 * elems_in_bytes * m_dim
214+ else :
215+ a_stride_byte_offset = 8 * elems_in_bytes * a_swizzle_atom_elems
195216
196217 b_leading_byte_offset = (8 * 8 * elems_in_bytes ) if b_is_k_major else (8 * n_dim *
197218 elems_in_bytes )
198- b_stride_byte_offset = (8 * k_dim * elems_in_bytes ) if b_is_k_major else (8 * 8 *
199- elems_in_bytes )
219+ b_stride_byte_offset = (8 * k_dim * elems_in_bytes ) if b_is_k_major else (
220+ 0 if n_dim == 8 else (8 * 8 * elems_in_bytes )
221+ )
200222 if not b_swizzle_mode .is_none ():
201223 # swizzle mode doesn't require LBO/SBO to be 1
202224 # https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset
203225 if b_is_k_major :
204226 b_leading_byte_offset = 16
227+ b_stride_byte_offset = 8 * b_swizzle_mode .swizzle_byte_size ()
205228 else :
206229 # MN Major, K * N
207230 # LBO represents the distance between two atoms along the N dimension
208231 # SBO represents the distance between two atoms along the K dimension
209- b_n_axis_atoms = n_dim // ( b_swizzle_mode . swizzle_byte_size () // elems_in_bytes )
232+ b_n_axis_atoms = n_dim // b_swizzle_atom_elems
210233 if b_n_axis_atoms <= 1 :
211234 b_leading_byte_offset = 0
212235 else :
213236 b_leading_byte_offset = 8 * 8 * elems_in_bytes * k_dim
214-
215237 if b_n_axis_atoms <= 1 :
216238 b_stride_byte_offset = 8 * elems_in_bytes * n_dim
217239 else :
218- b_stride_byte_offset = 8 * elems_in_bytes * (b_swizzle_mode .swizzle_byte_size () // elems_in_bytes )
219-
220-
240+ b_stride_byte_offset = 8 * elems_in_bytes * b_swizzle_atom_elems
221241 print (f"a_leading_byte_offset: { a_leading_byte_offset >> 4 } " )
222242 print (f"a_stride_byte_offset: { a_stride_byte_offset >> 4 } " )
243+ print (f"b_leading_byte_offset: { b_leading_byte_offset >> 4 } " )
244+ print (f"b_stride_byte_offset: { b_stride_byte_offset >> 4 } " )
223245
224246 print (f"b_swizzle_atom_size: { b_swizzle_mode .swizzle_atom_size ()} " )
225247 print (f"b_swizzle_byte_size: { b_swizzle_mode .swizzle_byte_size ()} " )
226- print (f"m_dim: { m_dim } " )
227- print (f"n_dim: { n_dim } " )
228- print (f"k_dim: { k_dim } " )
229- print (f"micro_size_k: { micro_size_k } " )
230- print (f"a_leading_byte_offset: { a_leading_byte_offset } " )
231- print (f"a_stride_byte_offset: { a_stride_byte_offset } " )
232- print (f"b_leading_byte_offset: { b_leading_byte_offset } " )
233- print (f"b_stride_byte_offset: { b_stride_byte_offset } " )
234- # exit()
248+
249+ # for example, if [n, k] where k is 128, we should split it into 2 atoms
250+ # where max specially handles the case when n_dim is 8.
251+ ak_atom_size = max (a_swizzle_atom_elems // micro_size_k , 1 )
252+ bk_atom_size = max (b_swizzle_atom_elems // micro_size_k , 1 )
253+
235254 @T .macro
236255 def _warp_mma (A_buf , B_buf , C_local_buf ):
237256 desc_a = T .alloc_descriptor ()
@@ -242,10 +261,8 @@ def _warp_mma(A_buf, B_buf, C_local_buf):
242261 int (b_leading_byte_offset >> 4 ), int (b_stride_byte_offset >> 4 ))
243262 for ki in T .serial (0 , (k_dim // micro_size_k )):
244263 for i in T .serial (m_dim // 64 ):
245- k_dim_offset = ki * micro_size_k
246- A_offset = i * 64 * A_buf .shape [
247- - 1 ] + k_dim_offset if a_is_k_major else ki * micro_size_k * 64 + i * 64 * k_dim
248- B_offset = k_dim_offset if b_is_k_major else k_dim_offset * (b_swizzle_mode .swizzle_byte_size () // elems_in_bytes )
264+ A_offset = (ki % ak_atom_size ) * micro_size_k + i * 64 * a_swizzle_atom_elems + (ki // ak_atom_size ) * m_dim * a_swizzle_atom_elems if a_is_k_major else i * 64 * k_dim + ki * a_swizzle_atom_elems * micro_size_k
265+ B_offset = (ki // bk_atom_size ) * n_dim * b_swizzle_atom_elems + (ki % bk_atom_size ) * micro_size_k if b_is_k_major else ki * b_swizzle_atom_elems * micro_size_k
249266 C_offset = i * warp_cols * local_size_out # 4 warps as an unit
250267 T .ptx_wgmma_ss (accum_dtype , wgmma_prefix , a_is_k_major ,
251268 b_is_k_major , a_dtype_abbrv , b_dtype_abbrv ,
@@ -300,7 +317,7 @@ def wgmma_rs(self,
300317 if b_n_axis_atoms <= 1 :
301318 b_leading_byte_offset = 0
302319 else :
303- b_leading_byte_offset = 8 * 8 * elems_in_bytes * k_dim
320+ b_leading_byte_offset = 8 * b_swizzle_mode . swizzle_atom_size () * ( b_swizzle_mode . swizzle_byte_size () // elems_in_bytes )
304321
305322 if b_n_axis_atoms <= 1 :
306323 b_stride_byte_offset = 8 * elems_in_bytes * n_dim
0 commit comments