@@ -40,11 +40,11 @@ def is_swizzle_128b(self) -> bool:
4040
4141 def swizzle_byte_size (self ) -> int :
4242 if self .is_swizzle_32b ():
43- return 32 // 8
43+ return 32
4444 elif self .is_swizzle_64b ():
45- return 64 // 8
45+ return 64
4646 elif self .is_swizzle_128b ():
47- return 128 // 8
47+ return 128
4848 else :
4949 return 1
5050
@@ -203,12 +203,35 @@ def wgmma(self,
203203 if b_is_k_major :
204204 b_leading_byte_offset = 16
205205 else :
206- # MN Major
206+ # MN Major, K * N
207207 # LBO represents the distance between two atoms along the N dimension
208208 # SBO represents the distance between two atoms along the K dimension
209- b_leading_byte_offset = b_swizzle_mode .swizzle_atom_size ()
210- b_stride_byte_offset = 8 * n_dim * elems_in_bytes
211-
209+ b_n_axis_atoms = n_dim // (b_swizzle_mode .swizzle_byte_size () // elems_in_bytes )
210+ if b_n_axis_atoms <= 1 :
211+ b_leading_byte_offset = 0
212+ else :
213+ b_leading_byte_offset = 8 * 8 * elems_in_bytes * k_dim
214+
215+ if b_n_axis_atoms <= 1 :
216+ b_stride_byte_offset = 8 * elems_in_bytes * n_dim
217+ else :
218+ b_stride_byte_offset = 8 * elems_in_bytes * (b_swizzle_mode .swizzle_byte_size () // elems_in_bytes )
219+
220+
221+ print (f"a_leading_byte_offset: { a_leading_byte_offset >> 4 } " )
222+ print (f"a_stride_byte_offset: { a_stride_byte_offset >> 4 } " )
223+
224+ print (f"b_swizzle_atom_size: { b_swizzle_mode .swizzle_atom_size ()} " )
225+ print (f"b_swizzle_byte_size: { b_swizzle_mode .swizzle_byte_size ()} " )
226+ print (f"m_dim: { m_dim } " )
227+ print (f"n_dim: { n_dim } " )
228+ print (f"k_dim: { k_dim } " )
229+ print (f"micro_size_k: { micro_size_k } " )
230+ print (f"a_leading_byte_offset: { a_leading_byte_offset } " )
231+ print (f"a_stride_byte_offset: { a_stride_byte_offset } " )
232+ print (f"b_leading_byte_offset: { b_leading_byte_offset } " )
233+ print (f"b_stride_byte_offset: { b_stride_byte_offset } " )
234+ # exit()
212235 @T .macro
213236 def _warp_mma (A_buf , B_buf , C_local_buf ):
214237 desc_a = T .alloc_descriptor ()
@@ -222,7 +245,7 @@ def _warp_mma(A_buf, B_buf, C_local_buf):
222245 k_dim_offset = ki * micro_size_k
223246 A_offset = i * 64 * A_buf .shape [
224247 - 1 ] + k_dim_offset if a_is_k_major else ki * micro_size_k * 64 + i * 64 * k_dim
225- B_offset = k_dim_offset if b_is_k_major else k_dim_offset * B_buf . shape [ - 1 ]
248+ B_offset = k_dim_offset if b_is_k_major else k_dim_offset * ( b_swizzle_mode . swizzle_byte_size () // elems_in_bytes )
226249 C_offset = i * warp_cols * local_size_out # 4 warps as an unit
227250 T .ptx_wgmma_ss (accum_dtype , wgmma_prefix , a_is_k_major ,
228251 b_is_k_major , a_dtype_abbrv , b_dtype_abbrv ,
@@ -273,8 +296,16 @@ def wgmma_rs(self,
273296 # MN Major
274297 # LBO represents the distance between two atoms along the N dimension
275298 # SBO represents the distance between two atoms along the K dimension
276- b_leading_byte_offset = b_swizzle_mode .swizzle_atom_size ()
277- b_stride_byte_offset = 8 * n_dim * elems_in_bytes
299+ b_n_axis_atoms = n_dim // (b_swizzle_mode .swizzle_byte_size () // elems_in_bytes )
300+ if b_n_axis_atoms <= 1 :
301+ b_leading_byte_offset = 0
302+ else :
303+ b_leading_byte_offset = 8 * 8 * elems_in_bytes * k_dim
304+
305+ if b_n_axis_atoms <= 1 :
306+ b_stride_byte_offset = 8 * elems_in_bytes * n_dim
307+ else :
308+ b_stride_byte_offset = 8 * elems_in_bytes * (b_swizzle_mode .swizzle_byte_size () // elems_in_bytes )
278309
279310 @T .macro
280311 def _warp_mma (A_buf , B_buf , C_local_buf ):
0 commit comments