Skip to content

Commit ce9f545

Browse files
committed
Remove debug logging for wgmma assembly code and refactor swizzle byte size calculations in wgmma macro generator. Enhanced handling of leading and stride byte offsets based on swizzle mode, improving clarity and performance in tensor core intrinsic emissions.
1 parent 51fcf15 commit ce9f545

File tree

2 files changed

+41
-11
lines changed

2 files changed

+41
-11
lines changed

src/target/codegen_cuda.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1358,7 +1358,6 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
13581358
replacer.register_rule("(C)", c_ref + " + " + c_offset);
13591359
replacer.register_rule("(scale_out)", scale_out ? "true" : "false");
13601360
wgmma_asm_code = replacer.rewrite(wgmma_asm_code);
1361-
LOG(INFO) << "wgmma_asm_code: " << wgmma_asm_code;
13621361
this->stream << wgmma_asm_code;
13631362
} else if (op->op.same_as(tl::ptx_wgmma_rs())) {
13641363
// arg 0: dtype

tilelang/intrinsics/wgmma_macro_generator.py

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,11 @@ def is_swizzle_128b(self) -> bool:
4040

4141
def swizzle_byte_size(self) -> int:
4242
if self.is_swizzle_32b():
43-
return 32 // 8
43+
return 32
4444
elif self.is_swizzle_64b():
45-
return 64 // 8
45+
return 64
4646
elif self.is_swizzle_128b():
47-
return 128 // 8
47+
return 128
4848
else:
4949
return 1
5050

@@ -203,12 +203,35 @@ def wgmma(self,
203203
if b_is_k_major:
204204
b_leading_byte_offset = 16
205205
else:
206-
# MN Major
206+
# MN Major, K * N
207207
# LBO represents the distance between two atoms along the N dimension
208208
# SBO represents the distance between two atoms along the K dimension
209-
b_leading_byte_offset = b_swizzle_mode.swizzle_atom_size()
210-
b_stride_byte_offset = 8 * n_dim * elems_in_bytes
211-
209+
b_n_axis_atoms = n_dim // (b_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
210+
if b_n_axis_atoms <= 1:
211+
b_leading_byte_offset = 0
212+
else:
213+
b_leading_byte_offset = 8 * 8 * elems_in_bytes * k_dim
214+
215+
if b_n_axis_atoms <= 1:
216+
b_stride_byte_offset = 8 * elems_in_bytes * n_dim
217+
else:
218+
b_stride_byte_offset = 8 * elems_in_bytes * (b_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
219+
220+
221+
print(f"a_leading_byte_offset: {a_leading_byte_offset >> 4}")
222+
print(f"a_stride_byte_offset: {a_stride_byte_offset >> 4}")
223+
224+
print(f"b_swizzle_atom_size: {b_swizzle_mode.swizzle_atom_size()}")
225+
print(f"b_swizzle_byte_size: {b_swizzle_mode.swizzle_byte_size()}")
226+
print(f"m_dim: {m_dim}")
227+
print(f"n_dim: {n_dim}")
228+
print(f"k_dim: {k_dim}")
229+
print(f"micro_size_k: {micro_size_k}")
230+
print(f"a_leading_byte_offset: {a_leading_byte_offset}")
231+
print(f"a_stride_byte_offset: {a_stride_byte_offset}")
232+
print(f"b_leading_byte_offset: {b_leading_byte_offset}")
233+
print(f"b_stride_byte_offset: {b_stride_byte_offset}")
234+
# exit()
212235
@T.macro
213236
def _warp_mma(A_buf, B_buf, C_local_buf):
214237
desc_a = T.alloc_descriptor()
@@ -222,7 +245,7 @@ def _warp_mma(A_buf, B_buf, C_local_buf):
222245
k_dim_offset = ki * micro_size_k
223246
A_offset = i * 64 * A_buf.shape[
224247
-1] + k_dim_offset if a_is_k_major else ki * micro_size_k * 64 + i * 64 * k_dim
225-
B_offset = k_dim_offset if b_is_k_major else k_dim_offset * B_buf.shape[-1]
248+
B_offset = k_dim_offset if b_is_k_major else k_dim_offset * (b_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
226249
C_offset = i * warp_cols * local_size_out # 4 warps as an unit
227250
T.ptx_wgmma_ss(accum_dtype, wgmma_prefix, a_is_k_major,
228251
b_is_k_major, a_dtype_abbrv, b_dtype_abbrv,
@@ -273,8 +296,16 @@ def wgmma_rs(self,
273296
# MN Major
274297
# LBO represents the distance between two atoms along the N dimension
275298
# SBO represents the distance between two atoms along the K dimension
276-
b_leading_byte_offset = b_swizzle_mode.swizzle_atom_size()
277-
b_stride_byte_offset = 8 * n_dim * elems_in_bytes
299+
b_n_axis_atoms = n_dim // (b_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
300+
if b_n_axis_atoms <= 1:
301+
b_leading_byte_offset = 0
302+
else:
303+
b_leading_byte_offset = 8 * 8 * elems_in_bytes * k_dim
304+
305+
if b_n_axis_atoms <= 1:
306+
b_stride_byte_offset = 8 * elems_in_bytes * n_dim
307+
else:
308+
b_stride_byte_offset = 8 * elems_in_bytes * (b_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
278309

279310
@T.macro
280311
def _warp_mma(A_buf, B_buf, C_local_buf):

0 commit comments

Comments
 (0)