Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 22 additions & 10 deletions tilelang/intrinsics/mma_macro_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ def ldmatrix_a(self, A_local_buf: Buffer, A_shared_buf: Buffer | BufferRegion, k
A_buf = A_region.buffer
A_base0 = A_region.region[-2].min
A_base1 = A_region.region[-1].min
+ A_other = [r.min for r in A_region.region[:-2]]

@T.macro
def _warp_ld_a_fp64(
Expand All @@ -267,9 +268,9 @@ def _warp_ld_a_fp64(
mi = tx // micro_size_k
mk = tx % micro_size_k
if a_transposed:
- A_local_buf[i * local_size_a] = A_buf[A_base0 + wk + mk, A_base1 + wi + mi]
+ A_local_buf[i * local_size_a] = A_buf[tuple(A_other) + (A_base0 + wk + mk, A_base1 + wi + mi)]
else:
- A_local_buf[i * local_size_a] = A_buf[A_base0 + wi + mi, A_base1 + wk + mk]
+ A_local_buf[i * local_size_a] = A_buf[tuple(A_other) + (A_base0 + wi + mi, A_base1 + wk + mk)]

return _warp_ld_a_fp64(A_local_buf, A_region, ki, thread_binding, rk)

Expand Down Expand Up @@ -304,6 +305,7 @@ def mma_load_layout(i, j):
A_buf = A_region.buffer
A_base0 = A_region.region[-2].min
A_base1 = A_region.region[-1].min
+ A_other = [r.min for r in A_region.region[:-2]]
A_stride_last = A_buf.shape[-1]

@T.macro
Expand All @@ -321,7 +323,11 @@ def _warp_ldmatrix_a(
for i in T.serial(warp_rows):
# Assign A_shared_buf_elem
wi, wk = warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * micro_size_k
- A_shared_buf_elem = A_buf[A_base0 + wk, A_base1 + wi] if a_transposed else A_buf[A_base0 + wi, A_base1 + wk]
+ A_shared_buf_elem = (
+ A_buf[tuple(A_other) + (A_base0 + wk, A_base1 + wi)]
+ if a_transposed
+ else A_buf[tuple(A_other) + (A_base0 + wi, A_base1 + wk)]
+ )

if ldmatrix_available:
T.ptx_ldmatrix(
Expand All @@ -338,9 +344,9 @@ def _warp_ldmatrix_a(
for j in T.serial(local_size_a):
mi, mk = mma_load_layout(tx, j)
if a_transposed:
- A_local_buf[i * local_size_a + j] = A_buf[A_base0 + wk + mk, A_base1 + wi + mi]
+ A_local_buf[i * local_size_a + j] = A_buf[tuple(A_other) + (A_base0 + wk + mk, A_base1 + wi + mi)]
else:
- A_local_buf[i * local_size_a + j] = A_buf[A_base0 + wi + mi, A_base1 + wk + mk]
+ A_local_buf[i * local_size_a + j] = A_buf[tuple(A_other) + (A_base0 + wi + mi, A_base1 + wk + mk)]

return _warp_ldmatrix_a(A_local_buf, A_region, ki, thread_binding, rk)

Expand All @@ -361,6 +367,7 @@ def ldmatrix_b(self, B_local_buf: Buffer, B_shared_buf: Buffer | BufferRegion, k
B_buf = B_region.buffer
B_base0 = B_region.region[-2].min
B_base1 = B_region.region[-1].min
+ B_other = [r.min for r in B_region.region[:-2]]

@T.macro
def _warp_ld_b_fp64(
Expand All @@ -377,9 +384,9 @@ def _warp_ld_b_fp64(
mi = tx // micro_size_k
mk = tx % micro_size_k
if b_transposed:
- B_local_buf[j * local_size_b] = B_buf[B_base0 + wi + mi, B_base1 + wk + mk]
+ B_local_buf[j * local_size_b] = B_buf[tuple(B_other) + (B_base0 + wi + mi, B_base1 + wk + mk)]
else:
- B_local_buf[j * local_size_b] = B_buf[B_base0 + wk + mk, B_base1 + wi + mi]
+ B_local_buf[j * local_size_b] = B_buf[tuple(B_other) + (B_base0 + wk + mk, B_base1 + wi + mi)]

return _warp_ld_b_fp64(B_local_buf, B_region, ki, thread_binding, rk)

Expand All @@ -398,6 +405,7 @@ def _warp_ld_b_fp64(
B_buf = B_region.buffer
B_base0 = B_region.region[-2].min
B_base1 = B_region.region[-1].min
+ B_other = [r.min for r in B_region.region[:-2]]
B_stride_last = B_buf.shape[-1]
replicate_b = self.n_dim == 16
# ldmatrix cannot be used for int8 + trans case.
Expand Down Expand Up @@ -436,7 +444,11 @@ def _warp_ldmatrix_b(
)

if ldmatrix_available:
- B_shared_buf_elem = B_buf[B_base0 + wi, B_base1 + wk] if b_transposed else B_buf[B_base0 + wk, B_base1 + wi]
+ B_shared_buf_elem = (
+ B_buf[tuple(B_other) + (B_base0 + wi, B_base1 + wk)]
+ if b_transposed
+ else B_buf[tuple(B_other) + (B_base0 + wk, B_base1 + wi)]
+ )

T.ptx_ldmatrix(
b_dtype,
Expand All @@ -455,9 +467,9 @@ def _warp_ldmatrix_b(
for j in T.serial(local_size_b):
mi, mk = mma_load_layout(tx, j)
if b_transposed:
- B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wi + mi, B_base1 + wk + mk]
+ B_local_buf[i * local_size_b + j] = B_buf[tuple(B_other) + (B_base0 + wi + mi, B_base1 + wk + mk)]
else:
- B_local_buf[i * local_size_b + j] = B_buf[B_base0 + wk + mk, B_base1 + wi + mi]
+ B_local_buf[i * local_size_b + j] = B_buf[tuple(B_other) + (B_base0 + wk + mk, B_base1 + wi + mi)]

return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk)

Expand Down
Loading