From 662c7abaefc410afbe860cc56f8cbb463819c388 Mon Sep 17 00:00:00 2001 From: xwhzz Date: Thu, 5 Feb 2026 02:40:46 +0000 Subject: [PATCH 1/2] [BugFix] Update buffer access in TensorCoreIntrinEmitter to handle variable dimensions correctly --- tilelang/intrinsics/mma_macro_generator.py | 24 +++++++++++++--------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/tilelang/intrinsics/mma_macro_generator.py b/tilelang/intrinsics/mma_macro_generator.py index 4f8a0ff1a..ca27572b2 100644 --- a/tilelang/intrinsics/mma_macro_generator.py +++ b/tilelang/intrinsics/mma_macro_generator.py @@ -251,6 +251,7 @@ def ldmatrix_a(self, A_local_buf: Buffer, A_shared_buf: Buffer | BufferRegion, k A_buf = A_region.buffer A_base0 = A_region.region[-2].min A_base1 = A_region.region[-1].min + A_other = [r.min for r in A_region.region[:-2]] @T.macro def _warp_ld_a_fp64( @@ -267,9 +268,9 @@ def _warp_ld_a_fp64( mi = tx // micro_size_k mk = tx % micro_size_k if a_transposed: - A_local_buf[i * local_size_a] = A_buf[A_base0 + wk + mk, A_base1 + wi + mi] + A_local_buf[i * local_size_a] = A_buf[*A_other, A_base0 + wk + mk, A_base1 + wi + mi] else: - A_local_buf[i * local_size_a] = A_buf[A_base0 + wi + mi, A_base1 + wk + mk] + A_local_buf[i * local_size_a] = A_buf[*A_other, A_base0 + wi + mi, A_base1 + wk + mk] return _warp_ld_a_fp64(A_local_buf, A_region, ki, thread_binding, rk) @@ -304,6 +305,7 @@ def mma_load_layout(i, j): A_buf = A_region.buffer A_base0 = A_region.region[-2].min A_base1 = A_region.region[-1].min + A_other = [r.min for r in A_region.region[:-2]] A_stride_last = A_buf.shape[-1] @T.macro @@ -321,7 +323,7 @@ def _warp_ldmatrix_a( for i in T.serial(warp_rows): # Assign A_shared_buf_elem wi, wk = warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * micro_size_k - A_shared_buf_elem = A_buf[A_base0 + wk, A_base1 + wi] if a_transposed else A_buf[A_base0 + wi, A_base1 + wk] + A_shared_buf_elem = A_buf[*A_other, A_base0 + wk, A_base1 + wi] if a_transposed else A_buf[*A_other, A_base0 + wi, A_base1 + wk] if ldmatrix_available: T.ptx_ldmatrix( @@ -338,9 +340,9 @@ def _warp_ldmatrix_a( for j in T.serial(local_size_a): mi, mk = mma_load_layout(tx, j) if a_transposed: - A_local_buf[i * local_size_a + j] = A_buf[A_base0 + wk + mk, A_base1 + wi + mi] + A_local_buf[i * local_size_a + j] = A_buf[*A_other, A_base0 + wk + mk, A_base1 + wi + mi] else: - A_local_buf[i * local_size_a + j] = A_buf[A_base0 + wi + mi, A_base1 + wk + mk] + A_local_buf[i * local_size_a + j] = A_buf[*A_other, A_base0 + wi + mi, A_base1 + wk + mk] return _warp_ldmatrix_a(A_local_buf, A_region, ki, thread_binding, rk) @@ -361,6 +363,7 @@ def ldmatrix_b(self, B_local_buf: Buffer, B_shared_buf: Buffer | BufferRegion, k B_buf = B_region.buffer B_base0 = B_region.region[-2].min B_base1 = B_region.region[-1].min + B_other = [r.min for r in B_region.region[:-2]] @T.macro def _warp_ld_b_fp64( @@ -377,9 +380,9 @@ def _warp_ld_b_fp64( mi = tx // micro_size_k mk = tx % micro_size_k if b_transposed: - B_local_buf[j * local_size_b] = B_buf[B_base0 + wi + mi, B_base1 + wk + mk] + B_local_buf[j * local_size_b] = B_buf[*B_other, B_base0 + wi + mi, B_base1 + wk + mk] else: - B_local_buf[j * local_size_b] = B_buf[B_base0 + wk + mk, B_base1 + wi + mi] + B_local_buf[j * local_size_b] = B_buf[*B_other, B_base0 + wk + mk, B_base1 + wi + mi] return _warp_ld_b_fp64(B_local_buf, B_region, ki, thread_binding, rk) @@ -398,6 +401,7 @@ def _warp_ld_b_fp64( B_buf = B_region.buffer B_base0 = B_region.region[-2].min B_base1 = B_region.region[-1].min + B_other = [r.min for r in B_region.region[:-2]] B_stride_last = B_buf.shape[-1] replicate_b = self.n_dim == 16 # ldmatrix cannot be used for int8 + trans case. @@ -436,7 +440,7 @@ def _warp_ldmatrix_b( ) if ldmatrix_available: - B_shared_buf_elem = B_buf[B_base0 + wi, B_base1 + wk] if b_transposed else B_buf[B_base0 + wk, B_base1 + wi] + B_shared_buf_elem = B_buf[*B_other, B_base0 + wi, B_base1 + wk] if b_transposed else B_buf[*B_other, B_base0 + wk, B_base1 + wi] T.ptx_ldmatrix( b_dtype, @@ -455,9 +459,9 @@ def _warp_ldmatrix_b( for j in T.serial(local_size_b): mi, mk = mma_load_layout(tx, j) if b_transposed: - B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wi + mi, B_base1 + wk + mk] + B_local_buf[i * local_size_b + j] = B_buf[*B_other, B_base0 + wi + mi, B_base1 + wk + mk] else: - B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wk + mk, B_base1 + wi + mi] + B_local_buf[i * local_size_b + j] = B_buf[*B_other, B_base0 + wk + mk, B_base1 + wi + mi] return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk) From ae984405342b89f68e67718450c151968e32bbb7 Mon Sep 17 00:00:00 2001 From: xwhzz Date: Thu, 5 Feb 2026 02:52:29 +0000 Subject: [PATCH 2/2] lint fix --- tilelang/intrinsics/mma_macro_generator.py | 28 ++++++++++++++-------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/tilelang/intrinsics/mma_macro_generator.py b/tilelang/intrinsics/mma_macro_generator.py index ca27572b2..14831050f 100644 --- a/tilelang/intrinsics/mma_macro_generator.py +++ b/tilelang/intrinsics/mma_macro_generator.py @@ -268,9 +268,9 @@ def _warp_ld_a_fp64( mi = tx // micro_size_k mk = tx % micro_size_k if a_transposed: - A_local_buf[i * local_size_a] = A_buf[*A_other, A_base0 + wk + mk, A_base1 + wi + mi] + A_local_buf[i * local_size_a] = A_buf[tuple(A_other) + (A_base0 + wk + mk, A_base1 + wi + mi)] else: - A_local_buf[i * local_size_a] = A_buf[*A_other, A_base0 + wi + mi, A_base1 + wk + mk] + A_local_buf[i * local_size_a] = A_buf[tuple(A_other) + (A_base0 + wi + mi, A_base1 + wk + mk)] return _warp_ld_a_fp64(A_local_buf, A_region, ki, thread_binding, rk) @@ -323,7 +323,11 @@ def _warp_ldmatrix_a( for i in T.serial(warp_rows): # Assign A_shared_buf_elem wi, wk = warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * micro_size_k - A_shared_buf_elem = A_buf[*A_other, A_base0 + wk, A_base1 + wi] if a_transposed else A_buf[*A_other, A_base0 + wi, A_base1 + wk] + A_shared_buf_elem = ( + A_buf[tuple(A_other) + (A_base0 + wk, A_base1 + wi)] + if a_transposed + else A_buf[tuple(A_other) + (A_base0 + wi, A_base1 + wk)] + ) if ldmatrix_available: T.ptx_ldmatrix( @@ -340,9 +344,9 @@ def _warp_ldmatrix_a( for j in T.serial(local_size_a): mi, mk = mma_load_layout(tx, j) if a_transposed: - A_local_buf[i * local_size_a + j] = A_buf[*A_other, A_base0 + wk + mk, A_base1 + wi + mi] + A_local_buf[i * local_size_a + j] = A_buf[tuple(A_other) + (A_base0 + wk + mk, A_base1 + wi + mi)] else: - A_local_buf[i * local_size_a + j] = A_buf[*A_other, A_base0 + wi + mi, A_base1 + wk + mk] + A_local_buf[i * local_size_a + j] = A_buf[tuple(A_other) + (A_base0 + wi + mi, A_base1 + wk + mk)] return _warp_ldmatrix_a(A_local_buf, A_region, ki, thread_binding, rk) @@ -380,9 +384,9 @@ def _warp_ld_b_fp64( mi = tx // micro_size_k mk = tx % micro_size_k if b_transposed: - B_local_buf[j * local_size_b] = B_buf[*B_other, B_base0 + wi + mi, B_base1 + wk + mk] + B_local_buf[j * local_size_b] = B_buf[tuple(B_other) + (B_base0 + wi + mi, B_base1 + wk + mk)] else: - B_local_buf[j * local_size_b] = B_buf[*B_other, B_base0 + wk + mk, B_base1 + wi + mi] + B_local_buf[j * local_size_b] = B_buf[tuple(B_other) + (B_base0 + wk + mk, B_base1 + wi + mi)] return _warp_ld_b_fp64(B_local_buf, B_region, ki, thread_binding, rk) @@ -440,7 +444,11 @@ def _warp_ldmatrix_b( ) if ldmatrix_available: - B_shared_buf_elem = B_buf[*B_other, B_base0 + wi, B_base1 + wk] if b_transposed else B_buf[*B_other, B_base0 + wk, B_base1 + wi] + B_shared_buf_elem = ( + B_buf[tuple(B_other) + (B_base0 + wi, B_base1 + wk)] + if b_transposed + else B_buf[tuple(B_other) + (B_base0 + wk, B_base1 + wi)] + ) T.ptx_ldmatrix( b_dtype, @@ -459,9 +467,9 @@ def _warp_ldmatrix_b( for j in T.serial(local_size_b): mi, mk = mma_load_layout(tx, j) if b_transposed: - B_local_buf[i * local_size_b + j] = B_buf[*B_other, B_base0 + wi + mi, B_base1 + wk + mk] + B_local_buf[i * local_size_b + j] = B_buf[tuple(B_other) + (B_base0 + wi + mi, B_base1 + wk + mk)] else: - B_local_buf[i * local_size_b + j] = B_buf[*B_other, B_base0 + wk + mk, B_base1 + wi + mi] + B_local_buf[i * local_size_b + j] = B_buf[tuple(B_other) + (B_base0 + wk + mk, B_base1 + wi + mi)] return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk)