Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,10 @@ def test_assert_tl_matmul():
assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", accum_dtype="int32")
assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32")
assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32", k_pack=2)
assert_tl_matmul_correctness(
128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32")
assert_tl_matmul_correctness(
128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32", k_pack=2)


if __name__ == "__main__":
Expand Down
111 changes: 39 additions & 72 deletions testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
from tilelang import tvm as tvm
import tilelang.language as T
from tilelang.intrinsics import make_mfma_swizzle_layout as make_swizzle_layout
from tilelang.intrinsics.mfma_macro_generator import (
MatrixCoreIntrinEmitter,)
from tilelang.intrinsics.mfma_macro_generator import MatrixCorePreshuffleIntrinEmitter
from tilelang.transform import simplify_prim_func

tilelang.testing.set_random_seed(0)
Expand All @@ -22,16 +21,8 @@ def tl_matmul(
b_transposed=True,
k_pack=1,
b_preshuffle=False,
b_g2l_load=False,
):
assert in_dtype in [
"float16",
"int8",
], "Currently only float16 and int8 are supported"
assert out_dtype in [
"float16",
"float32",
"int32",
], "Currently only float16, float32 and int32 are supported"

micro_size_x = micro_size_y = micro_size_k = 16

Expand All @@ -47,15 +38,14 @@ def tl_matmul(
if b_preshuffle:
block_row_warps = 1
block_col_warps = 4
warp_row_tiles = 128
warp_col_tiles = 32
warp_row_tiles = 64
warp_col_tiles = 16

chunk = 32 * k_pack
chunk = 256 * k_pack

pack_size_k = micro_size_k * k_pack

shared_scope = "shared"
cache_write_shared = False

block_M = block_row_warps * warp_row_tiles
block_N = block_col_warps * warp_col_tiles
Expand All @@ -68,6 +58,7 @@ def tl_matmul(
pack_size_k, micro_size_y)
else:
B_shape = (N, K) if b_transposed else (K, N)

A_shared_shape = (block_K, block_M) if a_transposed else (block_M, block_K)
if b_preshuffle:
B_shared_shape = (block_N // micro_size_y, block_K // pack_size_k, micro_size_y,
Expand All @@ -76,12 +67,6 @@ def tl_matmul(
micro_size_y)
else:
B_shared_shape = (block_N, block_K) if b_transposed else (block_K, block_N)
C_shared_shape = (
block_M // micro_size_x,
block_N // micro_size_y,
micro_size_x,
micro_size_y,
)

warp_size = 64
threads = warp_size * (block_row_warps * block_col_warps)
Expand All @@ -92,7 +77,7 @@ def tl_matmul(
warp_cols = warp_col_tiles // micro_size_y

# MMA Wrapper to Auto Generate Code for MMA
mfma_emitter = MatrixCoreIntrinEmitter(
mfma_emitter = MatrixCorePreshuffleIntrinEmitter(
a_dtype=in_dtype,
b_dtype=in_dtype,
accum_dtype=accum_dtype,
Expand All @@ -117,7 +102,6 @@ def main(

A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
Expand All @@ -126,12 +110,15 @@ def main(
A_shared: make_swizzle_layout(A_shared),
})

num_ko = K // block_K
num_ki = block_K // (k_pack * micro_size_k)

# Improve L2 Cache
T.use_swizzle(panel_size=10)

T.clear(C_local)

for ko in T.Pipelined((K // block_K), num_stages=0):
for ko in T.Pipelined(num_ko, num_stages=0):

# Load A into shared memory
if a_transposed:
Expand All @@ -140,7 +127,7 @@ def main(
T.copy(A[by * block_M, ko * block_K], A_shared)

# Load B into shared memory
if b_preshuffle:
if b_g2l_load is False:
if b_transposed:
for j, k, jj, kk in T.Parallel(block_N // micro_size_y,
block_K // pack_size_k, micro_size_y,
Expand All @@ -153,53 +140,37 @@ def main(
micro_size_y):
B_shared[k, j, kk, jj] = B[ko * block_K // pack_size_k + k,
bx * block_N // micro_size_y + j, kk, jj]
else:
if b_transposed:
T.copy(B[bx * block_N, ko * block_K], B_shared)
else:
T.copy(B[ko * block_K, bx * block_N], B_shared)

for ki in T.serial(0, (block_K // (k_pack * micro_size_k))):
for ki in T.serial(0, num_ki):

# Load A into fragment
# Load A S2L
mfma_emitter.ldmatrix_a(
A_local,
A_shared,
ki,
)

# Load B into fragment
mfma_emitter.ldmatrix_b(
B_local,
B_shared,
ki,
)
if b_g2l_load:
# Load B G2L
mfma_emitter.ldmatrix_b(B_local, B, ki + ko * num_ki, pid_m=by, pid_n=bx)
else:
# Load B S2L
mfma_emitter.ldmatrix_b(
B_local,
B_shared,
ki,
)

# Perform Matrix Multiplication
mfma_emitter.mfma(A_local, B_local, C_local)

# Perform STMatrix
if cache_write_shared:
mfma_emitter.stmatrix(
C_local,
C_shared,
)

# Store shared into global
for i, j in T.Parallel(block_M, block_N):
C[by * block_M + i, bx * block_N + j] = C_shared[
i // micro_size_x,
j // micro_size_y,
i % micro_size_x,
j % micro_size_y,
]
else:
mfma_emitter.stmatrix(
C_local,
C,
pid_m=by,
pid_n=bx,
)
mfma_emitter.stmatrix(
C_local,
C,
pid_m=by,
pid_n=bx,
)

return main

Expand Down Expand Up @@ -232,9 +203,10 @@ def assert_tl_matmul_correctness(M,
a_transposed=False,
b_transposed=True,
k_pack=1,
b_preshuffle=False):
b_preshuffle=False,
b_g2l_load=False):
matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed,
k_pack, b_preshuffle)
k_pack, b_preshuffle, b_g2l_load)
print(matmul)
kernel = tilelang.compile(matmul)
src_code = kernel.get_kernel_source()
Expand Down Expand Up @@ -285,30 +257,25 @@ def assert_tl_matmul_correctness(M,

print(C)
print(ref_c)

torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)


@tilelang.testing.requires_rocm
def test_assert_tl_matmul():
assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", accum_dtype="int32")
assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32")
assert_tl_matmul_correctness(
128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32")
assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32", k_pack=2)

assert_tl_matmul_correctness(
128, 128, 128, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
256, 256, 256, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
assert_tl_matmul_correctness(
128, 256, 256, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
256, 256, 256, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
assert_tl_matmul_correctness(
128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32", b_preshuffle=True)
256, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32", b_preshuffle=True)

assert_tl_matmul_correctness(
128, 256, 256, "int8", "int32", accum_dtype="int32", k_pack=2, b_preshuffle=True)
256, 256, 512, "int8", "int32", accum_dtype="int32", k_pack=2, b_preshuffle=True)
assert_tl_matmul_correctness(
128,
256,
256,
512,
"int8",
"int32",
b_transposed=False,
Expand Down
Loading
Loading