diff --git a/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py b/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py
new file mode 100644
index 000000000..3d8a7fd14
--- /dev/null
+++ b/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py
@@ -0,0 +1,321 @@
+import torch
+import tilelang.testing
+from tilelang import tvm as tvm
+import tilelang.language as T
+from tilelang.intrinsics import make_mfma_swizzle_layout as make_swizzle_layout
+from tilelang.intrinsics.mfma_macro_generator import (
+    MatrixCoreIntrinEmitter,)
+from tilelang.transform import simplify_prim_func
+
+tilelang.testing.set_random_seed(0)
+
+
+@simplify_prim_func
+def tl_matmul(
+    M,
+    N,
+    K,
+    in_dtype,
+    out_dtype,
+    accum_dtype,
+    a_transposed=False,
+    b_transposed=True,
+    k_pack=1,
+    b_preshuffle=False,
+):
+    assert in_dtype in [
+        "float16",
+        "int8",
+    ], "Currently only float16 and int8 are supported"
+    assert out_dtype in [
+        "float16",
+        "float32",
+        "int32",
+    ], "Currently only float16, float32 and int32 are supported"
+
+    micro_size_x = micro_size_y = micro_size_k = 16
+
+    if in_dtype in {"float8_e4m3fnuz", "int8"}:
+        micro_size_k = 32
+
+    block_row_warps = 2
+    block_col_warps = 2
+    warp_row_tiles = 32
+    warp_col_tiles = 32
+
+    # for preshuffle_b, warp_layout = {1, 4}
+    if b_preshuffle:
+        block_row_warps = 1
+        block_col_warps = 4
+        warp_row_tiles = 128
+        warp_col_tiles = 32
+
+    chunk = 32 * k_pack
+
+    pack_size_k = micro_size_k * k_pack
+
+    shared_scope = "shared"
+    cache_write_shared = False
+
+    block_M = block_row_warps * warp_row_tiles
+    block_N = block_col_warps * warp_col_tiles
+    block_K = chunk
+
+    A_shape = (K, M) if a_transposed else (M, K)
+    if b_preshuffle:
+        B_shape = (N // micro_size_y, K // pack_size_k, micro_size_y,
+                   pack_size_k) if b_transposed else (K // pack_size_k, N // micro_size_y,
+                                                      pack_size_k, micro_size_y)
+    else:
+        B_shape = (N, K) if b_transposed else (K, N)
+    A_shared_shape = (block_K, block_M) if a_transposed else (block_M, block_K)
+    if b_preshuffle:
+        B_shared_shape = (block_N // micro_size_y, block_K // pack_size_k, micro_size_y,
+                          pack_size_k) if b_transposed else (block_K // pack_size_k,
+                                                             block_N // micro_size_y, pack_size_k,
+                                                             micro_size_y)
+    else:
+        B_shared_shape = (block_N, block_K) if b_transposed else (block_K, block_N)
+    C_shared_shape = (
+        block_M // micro_size_x,
+        block_N // micro_size_y,
+        micro_size_x,
+        micro_size_y,
+    )
+
+    warp_size = 64
+    threads = warp_size * (block_row_warps * block_col_warps)
+    local_size_a = (k_pack * micro_size_x * micro_size_k) // warp_size
+    local_size_b = (k_pack * micro_size_y * micro_size_k) // warp_size
+    local_size_c = (micro_size_x * micro_size_y) // warp_size
+    warp_rows = warp_row_tiles // micro_size_x
+    warp_cols = warp_col_tiles // micro_size_y
+
+    # MFMA wrapper that auto-generates the Matrix Core intrinsic code
+    mfma_emitter = MatrixCoreIntrinEmitter(
+        a_dtype=in_dtype,
+        b_dtype=in_dtype,
+        accum_dtype=accum_dtype,
+        a_transposed=a_transposed,
+        b_transposed=b_transposed,
+        block_row_warps=block_row_warps,
+        block_col_warps=block_col_warps,
+        warp_row_tiles=warp_row_tiles,
+        warp_col_tiles=warp_col_tiles,
+        chunk=chunk,
+        k_pack=k_pack,
+        b_preshuffle=b_preshuffle,
+    )
+
+    @T.prim_func
+    def main(
+            A: T.Tensor(A_shape, in_dtype),
+            B: T.Tensor(B_shape, in_dtype),
+            C: T.Tensor((M, N), out_dtype),
+    ):
+        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
+
+            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
+            B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
+            C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
+            A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
+            B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
+            C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
+
+            T.annotate_layout({
+                A_shared: make_swizzle_layout(A_shared),
+            })
+
+            # Improve L2 cache locality
+            T.use_swizzle(panel_size=10)
+
+            T.clear(C_local)
+
+            for ko in T.Pipelined((K // block_K), num_stages=0):
+
+                # Load A into shared memory
+                if a_transposed:
+                    T.copy(A[ko * block_K, by * block_M], A_shared)
+                else:
+                    T.copy(A[by * block_M, ko * block_K], A_shared)
+
+                # Load B into shared memory (tile-by-tile for the preshuffled 4-D layout)
+                if b_preshuffle:
+                    if b_transposed:
+                        for j, k, jj, kk in T.Parallel(block_N // micro_size_y,
+                                                       block_K // pack_size_k, micro_size_y,
+                                                       pack_size_k):
+                            B_shared[j, k, jj, kk] = B[bx * block_N // micro_size_y + j,
+                                                       ko * block_K // pack_size_k + k, jj, kk]
+                    else:
+                        for k, j, kk, jj in T.Parallel(block_K // pack_size_k,
+                                                       block_N // micro_size_y, pack_size_k,
+                                                       micro_size_y):
+                            B_shared[k, j, kk, jj] = B[ko * block_K // pack_size_k + k,
+                                                       bx * block_N // micro_size_y + j, kk, jj]
+                else:
+                    if b_transposed:
+                        T.copy(B[bx * block_N, ko * block_K], B_shared)
+                    else:
+                        T.copy(B[ko * block_K, bx * block_N], B_shared)
+
+                for ki in T.serial(0, (block_K // (k_pack * micro_size_k))):
+
+                    # Load A into fragment
+                    mfma_emitter.ldmatrix_a(
+                        A_local,
+                        A_shared,
+                        ki,
+                    )
+
+                    # Load B into fragment
+                    mfma_emitter.ldmatrix_b(
+                        B_local,
+                        B_shared,
+                        ki,
+                    )
+
+                    # Perform matrix multiplication
+                    mfma_emitter.mfma(A_local, B_local, C_local)
+
+            # Write back the accumulator, optionally staging through shared memory
+            if cache_write_shared:
+                mfma_emitter.stmatrix(
+                    C_local,
+                    C_shared,
+                )
+
+                # Store shared into global
+                for i, j in T.Parallel(block_M, block_N):
+                    C[by * block_M + i, bx * block_N + j] = C_shared[
+                        i // micro_size_x,
+                        j // micro_size_y,
+                        i % micro_size_x,
+                        j % micro_size_y,
+                    ]
+            else:
+                mfma_emitter.stmatrix(
+                    C_local,
+                    C,
+                    pid_m=by,
+                    pid_n=bx,
+                )
+
+    return main
+
+
+def shuffle_weight(
+    x: torch.Tensor,
+    layout=(16, 32),
+    k_pack=1,
+    is_transpose=False,
+) -> torch.Tensor:
+    IN, IK = layout
+    BK = IK * k_pack
+    BN = IN
+
+    N, K = (x.shape[-2], x.shape[-1]) if is_transpose else (x.shape[-1], x.shape[-2])
+    assert N % BN == 0
+    assert K % BK == 0
+
+    # Split into micro-tiles and permute so each micro-tile is contiguous in memory
+    x = x.view(N // BN, BN, K // BK, BK) if is_transpose else x.view(K // BK, BK, N // BN, BN)
+    x = x.permute(0, 2, 1, 3)
+    return x.contiguous()
+
+
+def assert_tl_matmul_correctness(M,
+                                 N,
+                                 K,
+                                 in_dtype,
+                                 out_dtype,
+                                 accum_dtype="float32",
+                                 a_transposed=False,
+                                 b_transposed=True,
+                                 k_pack=1,
+                                 b_preshuffle=False):
+    matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed,
+                       k_pack, b_preshuffle)
+    print(matmul)
+    kernel = tilelang.compile(matmul)
+    src_code = kernel.get_kernel_source()
+    # src_code is the generated kernel source
+    assert src_code is not None
+    A_shape = (K, M) if a_transposed else (M, K)
+    B_shape = (N, K) if b_transposed else (K, N)
+    if in_dtype == "int8":
+        A = torch.randint(-128, 127, A_shape, device="cuda", dtype=torch.int8)
+        B = torch.randint(-128, 127, B_shape, device="cuda", dtype=torch.int8)
+    else:
+        A = torch.rand(A_shape, device="cuda", dtype=getattr(torch, in_dtype))
+        B = torch.rand(B_shape, device="cuda", dtype=getattr(torch, in_dtype))
+
+    C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype))
+
+    B_preshuffle = B
+    if b_preshuffle:
+        B_preshuffle = shuffle_weight(B_preshuffle, k_pack=k_pack, is_transpose=b_transposed)
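+        # The kernel was compiled with b_preshuffle=True, so it expects the 4-D
+        # micro-tiled layout produced by shuffle_weight rather than the flat 2-D B.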
+        kernel(A, B_preshuffle, C)
+    else:
+        kernel(A, B, C)
+
+    print(kernel.get_kernel_source())
+
+    profiler = kernel.get_profiler()
+
+    latency = profiler.do_bench()
+
+    # Ensure that the latency is not None
+    assert latency is not None
+
+    # Compute the reference result in float32 on the logical (non-transposed) views
+    if a_transposed and b_transposed:
+        ref_c = torch.matmul(A.T.to(torch.float32),
+                             B.T.to(torch.float32)).to(getattr(torch, out_dtype))
+    elif a_transposed and not b_transposed:
+        ref_c = torch.matmul(A.T.to(torch.float32),
+                             B.to(torch.float32)).to(getattr(torch, out_dtype))
+    elif not a_transposed and b_transposed:
+        ref_c = torch.matmul(A.to(torch.float32),
+                             B.T.to(torch.float32)).to(getattr(torch, out_dtype))
+    else:
+        ref_c = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(getattr(torch, out_dtype))
+
+    print(C)
+    print(ref_c)
+    torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
+
+
+@tilelang.testing.requires_rocm
+def test_assert_tl_matmul():
+    assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", accum_dtype="int32")
+    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32")
+    assert_tl_matmul_correctness(
+        128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32")
+    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32", k_pack=2)
+
+    assert_tl_matmul_correctness(
+        128, 128, 128, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
+    assert_tl_matmul_correctness(
+        128, 256, 256, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
+    assert_tl_matmul_correctness(
+        128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32", b_preshuffle=True)
+
+    assert_tl_matmul_correctness(
+        128, 256, 256, "int8", "int32", accum_dtype="int32", k_pack=2, b_preshuffle=True)
+    assert_tl_matmul_correctness(
+        128,
+        256,
+        256,
+        "int8",
+        "int32",
+        b_transposed=False,
+        accum_dtype="int32",
+        k_pack=2,
+        b_preshuffle=True)
+
+
+if __name__ == "__main__":
+    tilelang.testing.main()
diff --git a/tilelang/intrinsics/mfma_macro_generator.py b/tilelang/intrinsics/mfma_macro_generator.py
index 7758cdddc..195961144 100644
--- a/tilelang/intrinsics/mfma_macro_generator.py
+++ b/tilelang/intrinsics/mfma_macro_generator.py
@@ -53,6 +53,7 @@ def __init__(
         num_elems_per_byte: int = 1,
         k_pack: Optional[int] = None,
         is_m_first: Optional[bool] = False,
+        b_preshuffle: Optional[bool] = False,
     ):
         self.a_dtype = a_dtype
         self.b_dtype = b_dtype
@@ -72,6 +73,7 @@ def __init__(
         self._initialize_micro_size(self.M_DIM, self.N_DIM, self.k_dim)
         self._initialize_k_pack(k_pack)
         self._initialize_is_m_first(is_m_first)
+        self._initialize_b_preshuffle(b_preshuffle)

         self.warp_rows = warp_row_tiles // self.micro_size_x
         self.warp_cols = warp_col_tiles // self.micro_size_y
@@ -141,6 +143,10 @@ def _initialize_is_m_first(self, is_m_first: Optional[bool] = False):
         if is_m_first is not None:
             self.is_m_first = is_m_first

+    def _initialize_b_preshuffle(self, b_preshuffle: Optional[bool] = False):
+        if b_preshuffle is not None:
+            self.b_preshuffle = b_preshuffle
+
     def get_ldmatrix_index_map(self, is_b=False):
         from .mfma_layout import (
             shared_16x4_to_local_64x1_layout_A,
@@ -288,26 +294,51 @@ def _warp_ldmatrix_b(
         ):
             tx, warp_n, _ = self.extract_thread_binding(thread_binding)

-            if is_transposed:
-                for j in T.serial(warp_cols):
-                    for local_id in T.vectorized(k_pack * local_size_b):
-                        row, col = T.meta_var(reverse_index_map(tx, local_id))
-                        l, r = (
-                            warp_n * warp_col_tiles + j * micro_size_y,
-                            rk * chunk + ki * (k_pack * micro_size_k),
-                        )
-                        B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row,
-                                                                                         r + col]
-            else:
-                for j in T.serial(warp_cols):
-                    for local_id in T.vectorized(k_pack * local_size_b):
-                        row, col = T.meta_var(reverse_index_map(tx, local_id))
-                        l, r = (
-                            rk * chunk + ki * (k_pack * micro_size_k),
-                            warp_n * warp_col_tiles + j * micro_size_y,
-                        )
-                        B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row,
-                                                                                         r + col]
+            # Preshuffled B is 4-D: index whole micro-tiles, then positions inside them
+            if self.b_preshuffle:
+                if is_transposed:
+                    for j in T.serial(warp_cols):
+                        for local_id in T.vectorized(k_pack * local_size_b):
+                            row, col = T.meta_var(reverse_index_map(tx, local_id))
+                            l, r = (
+                                warp_n * warp_cols + j,
+                                rk * (chunk // micro_size_k) + ki,
+                            )
+                            B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r,
+                                                                                             row,
+                                                                                             col]
+                else:
+                    for j in T.serial(warp_cols):
+                        for local_id in T.vectorized(k_pack * local_size_b):
+                            row, col = T.meta_var(reverse_index_map(tx, local_id))
+                            l, r = (
+                                rk * (chunk // micro_size_k) + ki,
+                                warp_n * warp_cols + j,
+                            )
+                            B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r,
+                                                                                             row,
+                                                                                             col]
+            else:
+                if is_transposed:
+                    for j in T.serial(warp_cols):
+                        for local_id in T.vectorized(k_pack * local_size_b):
+                            row, col = T.meta_var(reverse_index_map(tx, local_id))
+                            l, r = (
+                                warp_n * warp_col_tiles + j * micro_size_y,
+                                rk * chunk + ki * (k_pack * micro_size_k),
+                            )
+                            B_local_buf[j * k_pack * local_size_b +
+                                        local_id] = B_shared_buf[l + row, r + col]
+                else:
+                    for j in T.serial(warp_cols):
+                        for local_id in T.vectorized(k_pack * local_size_b):
+                            row, col = T.meta_var(reverse_index_map(tx, local_id))
+                            l, r = (
+                                rk * chunk + ki * (k_pack * micro_size_k),
+                                warp_n * warp_col_tiles + j * micro_size_y,
+                            )
+                            B_local_buf[j * k_pack * local_size_b +
+                                        local_id] = B_shared_buf[l + row, r + col]

         return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk)
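
For reference, the layout produced by shuffle_weight (and consumed by tl_matmul when b_preshuffle=True) can be sanity-checked on the CPU without tilelang or a ROCm device. The sketch below duplicates shuffle_weight from the test file above; the shapes, k_pack value, and the [1, 2] tile index are illustrative choices, not values mandated by the kernel.

# Standalone sketch: verify that shuffle_weight's 4-D output matches the
# B_shape tl_matmul expects for the transposed int8 case, i.e.
# (N // micro_size_y, K // pack_size_k, micro_size_y, pack_size_k)
# with micro_size_y = 16 and pack_size_k = 32 * k_pack. CPU-only.
import torch


def shuffle_weight(x, layout=(16, 32), k_pack=1, is_transpose=False):
    IN, IK = layout
    BK, BN = IK * k_pack, IN
    N, K = (x.shape[-2], x.shape[-1]) if is_transpose else (x.shape[-1], x.shape[-2])
    # Split into (BN, BK) micro-tiles and make each tile contiguous
    x = x.view(N // BN, BN, K // BK, BK) if is_transpose else x.view(K // BK, BK, N // BN, BN)
    return x.permute(0, 2, 1, 3).contiguous()


N, K, k_pack = 128, 256, 2
B = torch.randint(-128, 127, (N, K), dtype=torch.int8)
B_shuffled = shuffle_weight(B, k_pack=k_pack, is_transpose=True)

# (N // 16, K // (32 * k_pack), 16, 32 * k_pack) == (8, 4, 16, 64)
print(B_shuffled.shape)

# Micro-tile [j, k] is a contiguous copy of the matching block of B
assert torch.equal(B_shuffled[1, 2], B[16:32, 128:192])

Because each micro-tile is contiguous after the shuffle, the preshuffled branch of tl_matmul can copy B into the 4-D B_shared buffer tile by tile, and _warp_ldmatrix_b can address whole micro-tiles directly instead of computing offsets into a flat 2-D buffer.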