diff --git a/testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py b/testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py
index b8690ce08..e2135744e 100644
--- a/testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py
+++ b/testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py
@@ -234,6 +234,10 @@ def test_assert_tl_matmul():
     assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", accum_dtype="int32")
     assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32")
     assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32", k_pack=2)
+    assert_tl_matmul_correctness(
+        128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32")
+    assert_tl_matmul_correctness(
+        128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32", k_pack=2)
 
 
 if __name__ == "__main__":
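Note: the two added cases exercise the int8 path with B stored K-major rather than the default N-major layout. A minimal sketch of the reference semantics being checked (illustrative only; `ref_matmul` is not part of the patch):

import torch

def ref_matmul(A: torch.Tensor, B: torch.Tensor, b_transposed: bool) -> torch.Tensor:
    # int8 operands are widened to int32, matching accum_dtype="int32".
    A32, B32 = A.to(torch.int32), B.to(torch.int32)
    # b_transposed=True: B is stored (N, K), so C = A @ B.T
    # b_transposed=False: B is stored (K, N), so C = A @ B
    return A32 @ B32.T if b_transposed else A32 @ B32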
diff --git a/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py b/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py
index 3d8a7fd14..73cdc280b 100644
--- a/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py
+++ b/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py
@@ -3,8 +3,7 @@ from tilelang import tvm as tvm
 import tilelang.language as T
 from tilelang.intrinsics import make_mfma_swizzle_layout as make_swizzle_layout
-from tilelang.intrinsics.mfma_macro_generator import (
-    MatrixCoreIntrinEmitter,)
+from tilelang.intrinsics.mfma_macro_generator import MatrixCorePreshuffleIntrinEmitter
 from tilelang.transform import simplify_prim_func
 
 tilelang.testing.set_random_seed(0)
@@ -22,16 +21,8 @@ def tl_matmul(
     b_transposed=True,
     k_pack=1,
     b_preshuffle=False,
+    b_g2l_load=False,
 ):
-    assert in_dtype in [
-        "float16",
-        "int8",
-    ], "Currently only float16 and int8 are supported"
-    assert out_dtype in [
-        "float16",
-        "float32",
-        "int32",
-    ], "Currently only float16, float32 and int32 are supported"
 
     micro_size_x = micro_size_y = micro_size_k = 16
@@ -47,15 +38,14 @@ def tl_matmul(
     if b_preshuffle:
         block_row_warps = 1
         block_col_warps = 4
-        warp_row_tiles = 128
-        warp_col_tiles = 32
+        warp_row_tiles = 64
+        warp_col_tiles = 16
 
-        chunk = 32 * k_pack
+        chunk = 256 * k_pack
 
     pack_size_k = micro_size_k * k_pack
 
     shared_scope = "shared"
-    cache_write_shared = False
 
     block_M = block_row_warps * warp_row_tiles
     block_N = block_col_warps * warp_col_tiles
@@ -68,6 +58,7 @@ def tl_matmul(
                    pack_size_k, micro_size_y)
     else:
         B_shape = (N, K) if b_transposed else (K, N)
+
     A_shared_shape = (block_K, block_M) if a_transposed else (block_M, block_K)
 
@@ -76,12 +67,6 @@ def tl_matmul(
     if b_preshuffle:
         B_shared_shape = (block_N // micro_size_y, block_K // pack_size_k, micro_size_y,
                           pack_size_k) if b_transposed else (block_K // pack_size_k,
                                                              block_N // micro_size_y, pack_size_k,
                                                              micro_size_y)
     else:
         B_shared_shape = (block_N, block_K) if b_transposed else (block_K, block_N)
-    C_shared_shape = (
-        block_M // micro_size_x,
-        block_N // micro_size_y,
-        micro_size_x,
-        micro_size_y,
-    )
 
     warp_size = 64
     threads = warp_size * (block_row_warps * block_col_warps)
@@ -92,7 +77,7 @@ def tl_matmul(
     warp_cols = warp_col_tiles // micro_size_y
 
     # MMA Wrapper to Auto Generate Code for MMA
-    mfma_emitter = MatrixCoreIntrinEmitter(
+    mfma_emitter = MatrixCorePreshuffleIntrinEmitter(
         a_dtype=in_dtype,
         b_dtype=in_dtype,
         accum_dtype=accum_dtype,
@@ -117,7 +102,6 @@ def main(
 
             A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
             B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
-            C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
            A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
             B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
             C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
@@ -126,12 +110,15 @@ def main(
                 A_shared: make_swizzle_layout(A_shared),
             })
 
+            num_ko = K // block_K
+            num_ki = block_K // (k_pack * micro_size_k)
+
             # Improve L2 Cache
             T.use_swizzle(panel_size=10)
 
             T.clear(C_local)
 
-            for ko in T.Pipelined((K // block_K), num_stages=0):
+            for ko in T.Pipelined(num_ko, num_stages=0):
 
                 # Load A into shared memory
                 if a_transposed:
@@ -140,7 +127,7 @@ def main(
                     T.copy(A[by * block_M, ko * block_K], A_shared)
 
                 # Load B into shared memory
-                if b_preshuffle:
+                if b_g2l_load is False:
                     if b_transposed:
                         for j, k, jj, kk in T.Parallel(block_N // micro_size_y,
                                                        block_K // pack_size_k, micro_size_y,
@@ -153,53 +140,37 @@ def main(
                                                        micro_size_y):
                             B_shared[k, j, kk, jj] = B[ko * block_K // pack_size_k + k,
                                                        bx * block_N // micro_size_y + j, kk, jj]
-                else:
-                    if b_transposed:
-                        T.copy(B[bx * block_N, ko * block_K], B_shared)
-                    else:
-                        T.copy(B[ko * block_K, bx * block_N], B_shared)
 
-                for ki in T.serial(0, (block_K // (k_pack * micro_size_k))):
+                for ki in T.serial(0, num_ki):
 
-                    # Load A into fragment
+                    # Load A S2L
                     mfma_emitter.ldmatrix_a(
                         A_local,
                         A_shared,
                         ki,
                     )
 
-                    # Load B into fragment
-                    mfma_emitter.ldmatrix_b(
-                        B_local,
-                        B_shared,
-                        ki,
-                    )
+                    if b_g2l_load:
+                        # Load B G2L
+                        mfma_emitter.ldmatrix_b(B_local, B, ki + ko * num_ki, pid_m=by, pid_n=bx)
+                    else:
+                        # Load B S2L
+                        mfma_emitter.ldmatrix_b(
+                            B_local,
+                            B_shared,
+                            ki,
+                        )
 
                     # Perform Matrix Multiplication
                     mfma_emitter.mfma(A_local, B_local, C_local)
 
             # Perform STMatrix
-            if cache_write_shared:
-                mfma_emitter.stmatrix(
-                    C_local,
-                    C_shared,
-                )
-
-                # Store shared into global
-                for i, j in T.Parallel(block_M, block_N):
-                    C[by * block_M + i, bx * block_N + j] = C_shared[
-                        i // micro_size_x,
-                        j // micro_size_y,
-                        i % micro_size_x,
-                        j % micro_size_y,
-                    ]
-            else:
-                mfma_emitter.stmatrix(
-                    C_local,
-                    C,
-                    pid_m=by,
-                    pid_n=bx,
-                )
+            mfma_emitter.stmatrix(
+                C_local,
+                C,
+                pid_m=by,
+                pid_n=bx,
+            )
 
     return main
@@ -232,9 +203,10 @@ def assert_tl_matmul_correctness(M,
                                  a_transposed=False,
                                  b_transposed=True,
                                  k_pack=1,
-                                 b_preshuffle=False):
+                                 b_preshuffle=False,
+                                 b_g2l_load=False):
     matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed,
-                       k_pack, b_preshuffle)
+                       k_pack, b_preshuffle, b_g2l_load)
     print(matmul)
     kernel = tilelang.compile(matmul)
     src_code = kernel.get_kernel_source()
@@ -285,30 +257,25 @@ def assert_tl_matmul_correctness(M,
     print(C)
     print(ref_c)
 
+    torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
 
 
 @tilelang.testing.requires_rocm
 def test_assert_tl_matmul():
-    assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", accum_dtype="int32")
-    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32")
-    assert_tl_matmul_correctness(
-        128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32")
-    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32", k_pack=2)
     assert_tl_matmul_correctness(
-        128, 128, 128, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
+        256, 256, 256, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
     assert_tl_matmul_correctness(
-        128, 256, 256, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
+        256, 256, 256, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
     assert_tl_matmul_correctness(
-        128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32", b_preshuffle=True)
+        256, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32", b_preshuffle=True)
     assert_tl_matmul_correctness(
-        128, 256, 256, "int8", "int32", accum_dtype="int32", k_pack=2, b_preshuffle=True)
+        256, 256, 512, "int8", "int32", accum_dtype="int32", k_pack=2, b_preshuffle=True)
     assert_tl_matmul_correctness(
-        128,
         256,
         256,
+        512,
         "int8",
         "int32",
         b_transposed=False,
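Note: a sketch of how the 4-D preshuffled B used by these tests relates to a plain (N, K) matrix (`preshuffle_b` is a hypothetical helper, not part of the patch):

import torch

def preshuffle_b(B_nk: torch.Tensor, k_pack: int = 1) -> torch.Tensor:
    # With micro_size_y = micro_size_k = 16 and pack_size_k = 16 * k_pack, the
    # b_transposed=True layout is (N // 16, K // pack_size_k, 16, pack_size_k):
    # element (n, k) lands at [n // 16, k // pack_size_k, n % 16, k % pack_size_k],
    # matching the B_shared[j, k, jj, kk] copy loop in tl_matmul above.
    N, K = B_nk.shape
    pack_size_k = 16 * k_pack
    return (B_nk.reshape(N // 16, 16, K // pack_size_k,
                         pack_size_k).permute(0, 2, 1, 3).contiguous())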
diff --git a/tilelang/intrinsics/mfma_macro_generator.py b/tilelang/intrinsics/mfma_macro_generator.py
index 195961144..12551b193 100644
--- a/tilelang/intrinsics/mfma_macro_generator.py
+++ b/tilelang/intrinsics/mfma_macro_generator.py
@@ -293,52 +293,27 @@ def _warp_ldmatrix_b(
             rk=0,
         ):
             tx, warp_n, _ = self.extract_thread_binding(thread_binding)
+            if is_transposed:
+                for j in T.serial(warp_cols):
+                    for local_id in T.vectorized(k_pack * local_size_b):
+                        row, col = T.meta_var(reverse_index_map(tx, local_id))
+                        l, r = (
+                            warp_n * warp_col_tiles + j * micro_size_y,
+                            rk * chunk + ki * (k_pack * micro_size_k),
+                        )
+                        B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row,
+                                                                                         r + col]
-
-            # 4 dim
-            if self.b_preshuffle:
-                if is_transposed:
-                    for j in T.serial(warp_cols):
-                        for local_id in T.vectorized(k_pack * local_size_b):
-                            row, col = T.meta_var(reverse_index_map(tx, local_id))
-                            l, r = (
-                                warp_n * warp_cols + j,
-                                rk * (chunk // micro_size_k) + ki,
-                            )
-                            B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r,
-                                                                                             row,
-                                                                                             col]
-                else:
-                    for j in T.serial(warp_cols):
-                        for local_id in T.vectorized(k_pack * local_size_b):
-                            row, col = T.meta_var(reverse_index_map(tx, local_id))
-                            l, r = (
-                                rk * (chunk // micro_size_k) + ki,
-                                warp_n * warp_cols + j,
-                            )
-                            B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r,
-                                                                                             row,
-                                                                                             col]
             else:
-                if is_transposed:
-                    for j in T.serial(warp_cols):
-                        for local_id in T.vectorized(k_pack * local_size_b):
-                            row, col = T.meta_var(reverse_index_map(tx, local_id))
-                            l, r = (
-                                warp_n * warp_col_tiles + j * micro_size_y,
-                                rk * chunk + ki * (k_pack * micro_size_k),
-                            )
-                            B_local_buf[j * k_pack * local_size_b +
-                                        local_id] = B_shared_buf[l + row, r + col]
-                else:
-                    for j in T.serial(warp_cols):
-                        for local_id in T.vectorized(k_pack * local_size_b):
-                            row, col = T.meta_var(reverse_index_map(tx, local_id))
-                            l, r = (
-                                rk * chunk + ki * (k_pack * micro_size_k),
-                                warp_n * warp_col_tiles + j * micro_size_y,
-                            )
-                            B_local_buf[j * k_pack * local_size_b +
-                                        local_id] = B_shared_buf[l + row, r + col]
+                for j in T.serial(warp_cols):
+                    for local_id in T.vectorized(k_pack * local_size_b):
+                        row, col = T.meta_var(reverse_index_map(tx, local_id))
+                        l, r = (
+                            rk * chunk + ki * (k_pack * micro_size_k),
+                            warp_n * warp_col_tiles + j * micro_size_y,
+                        )
+                        B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row,
+                                                                                         r + col]
 
         return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk)
@@ -425,3 +400,210 @@ def _warp_stmatrix_global(C_local_buf, C_buf, thread_binding):
 
         return _warp_stmatrix_global(C_local_buf, C_buf, thread_binding) if is_global else _warp_stmatrix_shared(
             C_local_buf, C_buf, thread_binding)
+
+
+class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
+
+    def __init__(
+        self,
+        a_dtype: str = "float16",
+        b_dtype: str = "float16",
+        accum_dtype: str = "float16",
+        a_transposed: bool = False,
+        b_transposed: bool = False,
+        block_row_warps: int = 2,
+        block_col_warps: int = 2,
+        warp_row_tiles: int = 8,
+        warp_col_tiles: int = 8,
+        chunk: int = 16,
+        reduce_k: int = 1,
+        num_elems_per_byte: int = 1,
+        k_pack: Optional[int] = None,
+        is_m_first: Optional[bool] = False,
+        a_preshuffle: Optional[bool] = False,
+        b_preshuffle: Optional[bool] = False,
+    ):
+
+        self.a_dtype = a_dtype
+        self.b_dtype = b_dtype
+        self.accum_dtype = accum_dtype
+        self.a_transposed = a_transposed
+        self.b_transposed = b_transposed
+        # Hint Information
+        self.block_row_warps = block_row_warps
+        self.block_col_warps = block_col_warps
+        self.warp_row_tiles = warp_row_tiles
+        self.warp_col_tiles = warp_col_tiles
+        self.chunk = chunk
+        self._initialize_k_dim(a_dtype)
+        self._initialize_abbrev(a_dtype, b_dtype, accum_dtype)
+        self._initialize_local_size(self.M_DIM, self.N_DIM, self.k_dim, self.WARP_SIZE)
+        self._initialize_mfma_prefix(self.k_dim)
+        self._initialize_micro_size(self.M_DIM, self.N_DIM, self.k_dim)
+        self._initialize_k_pack(k_pack)
+        self._initialize_is_m_first(is_m_first)
+        self._initialize_preshuffle(a_preshuffle, b_preshuffle)
+
+        self.warp_rows = warp_row_tiles // self.micro_size_x
+        self.warp_cols = warp_col_tiles // self.micro_size_y
+        self.reduce_k = reduce_k
+        self.threads = (self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k)
+        self.num_elems_per_byte = num_elems_per_byte
+
+    def _initialize_preshuffle(self, a_preshuffle: bool, b_preshuffle: bool):
+        if a_preshuffle is not None:
+            self.a_preshuffle = a_preshuffle
+        if b_preshuffle is not None:
+            self.b_preshuffle = b_preshuffle
+
+    def ldmatrix_a(self, A_local_buf, A_buf, ki, rk=0, pid_m=None, pid_n=None):
+        warp_rows = self.warp_rows
+        chunk = self.chunk
+        micro_size_k = self.micro_size_k
+        local_size_a = self.local_size_a
+        k_pack = self.k_pack
+        is_transposed = self.a_transposed
+        current_frame = T.KernelLaunchFrame.Current()
+        thread_binding = current_frame.get_thread_binding()
+        _, reverse_index_map = self.get_ldmatrix_index_map(is_b=False)
+        is_global = pid_m is not None and pid_n is not None
+
+        # no preshuffle, use the default implementation
+        if self.a_preshuffle is False:
+            return super().ldmatrix_a(A_local_buf, A_buf, ki, rk)
+
+        @T.macro
+        def _warp_ldmatrix_a_global(
+            A_local_buf,
+            A_buf,
+            ki,
+            thread_binding,
+            rk=0,
+        ):
+            tx, _, warp_m = self.extract_thread_binding(thread_binding)
+            if is_transposed:
+                for i in T.serial(warp_rows):
+                    for local_id in T.vectorized(k_pack * local_size_a):
+                        row, col = T.meta_var(reverse_index_map(tx, local_id))
+                        l, r = (
+                            rk * (chunk // micro_size_k) + ki,
+                            (pid_m * self.block_row_warps + warp_m) * warp_rows + i,
+                        )
+                        A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[l, r, row, col]
+            else:
+                for i in T.serial(warp_rows):
+                    for local_id in T.vectorized(k_pack * local_size_a):
+                        row, col = T.meta_var(reverse_index_map(tx, local_id))
+                        l, r = (
+                            (pid_m * self.block_row_warps + warp_m) * warp_rows + i,
+                            rk * (chunk // micro_size_k) + ki,
+                        )
+                        A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[l, r, row, col]
+
+        @T.macro
+        def _warp_ldmatrix_a_shared(
+            A_local_buf,
+            A_shared_buf,
+            ki,
+            thread_binding,
+            rk=0,
+        ):
+            tx, _, warp_m = self.extract_thread_binding(thread_binding)
+            if is_transposed:
+                for i in T.serial(warp_rows):
+                    for local_id in T.vectorized(k_pack * local_size_a):
+                        row, col = T.meta_var(reverse_index_map(tx, local_id))
+                        l, r = (
+                            rk * (chunk // micro_size_k) + ki,
+                            warp_m * warp_rows + i,
+                        )
+                        A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l, r, row,
+                                                                                         col]
+            else:
+                for i in T.serial(warp_rows):
+                    for local_id in T.vectorized(k_pack * local_size_a):
+                        row, col = T.meta_var(reverse_index_map(tx, local_id))
+                        l, r = (warp_m * warp_rows + i, rk * (chunk // micro_size_k) + ki)
+                        A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l, r, row,
+                                                                                         col]
+
+        return _warp_ldmatrix_a_global(A_local_buf, A_buf, ki, thread_binding,
+                                       rk) if is_global else _warp_ldmatrix_a_shared(
+                                           A_local_buf, A_buf, ki, thread_binding, rk)
+
+    def ldmatrix_b(self, B_local_buf, B_buf, ki, rk=0, pid_m=None, pid_n=None):
+        warp_cols = self.warp_cols
+        chunk = self.chunk
+        micro_size_k = self.micro_size_k
+        local_size_b = self.local_size_b
+        k_pack = self.k_pack
+        is_transposed = self.b_transposed
+        current_frame = T.KernelLaunchFrame.Current()
+        thread_binding = current_frame.get_thread_binding()
+        _, reverse_index_map = self.get_ldmatrix_index_map(is_b=True)
+        is_global = pid_m is not None and pid_n is not None
+
+        if self.b_preshuffle is False:
+            return super().ldmatrix_b(B_local_buf, B_buf, ki, rk, pid_m, pid_n)
+
+        @T.macro
+        def _warp_ldmatrix_b_global(
+            B_local_buf,
+            B_buf,
+            ki,
+            thread_binding,
+            rk=0,
+        ):
+            tx, warp_n, _ = self.extract_thread_binding(thread_binding)
+            if is_transposed:
+                for j in T.serial(warp_cols):
+                    for local_id in T.vectorized(k_pack * local_size_b):
+                        row, col = T.meta_var(reverse_index_map(tx, local_id))
+                        l, r = (
+                            (pid_n * self.block_col_warps + warp_n) * warp_cols + j,
+                            rk * (chunk // micro_size_k) + ki,
+                        )
+                        B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[l, r, row, col]
+            else:
+                for j in T.serial(warp_cols):
+                    for local_id in T.vectorized(k_pack * local_size_b):
+                        row, col = T.meta_var(reverse_index_map(tx, local_id))
+                        l, r = (
+                            rk * (chunk // micro_size_k) + ki,
+                            (pid_n * self.block_col_warps + warp_n) * warp_cols + j,
+                        )
+                        B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[l, r, row, col]
+
+        @T.macro
+        def _warp_ldmatrix_b_shared(
+            B_local_buf,
+            B_shared_buf,
+            ki,
+            thread_binding,
+            rk=0,
+        ):
+            tx, warp_n, _ = self.extract_thread_binding(thread_binding)
+            if is_transposed:
+                for j in T.serial(warp_cols):
+                    for local_id in T.vectorized(k_pack * local_size_b):
+                        row, col = T.meta_var(reverse_index_map(tx, local_id))
+                        l, r = (
+                            warp_n * warp_cols + j,
+                            rk * (chunk // micro_size_k) + ki,
+                        )
+                        B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r, row,
+                                                                                         col]
+            else:
+                for j in T.serial(warp_cols):
+                    for local_id in T.vectorized(k_pack * local_size_b):
+                        row, col = T.meta_var(reverse_index_map(tx, local_id))
+                        l, r = (
+                            rk * (chunk // micro_size_k) + ki,
+                            warp_n * warp_cols + j,
+                        )
+                        B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r, row,
+                                                                                         col]
+
+        return _warp_ldmatrix_b_global(B_local_buf, B_buf, ki, thread_binding,
+                                       rk) if is_global else _warp_ldmatrix_b_shared(
+                                           B_local_buf, B_buf, ki, thread_binding, rk)
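A minimal usage sketch of the new emitter, with tile sizes taken from the b_preshuffle branch of tl_matmul above (illustrative only, not part of the patch):

from tilelang.intrinsics.mfma_macro_generator import MatrixCorePreshuffleIntrinEmitter

k_pack = 2
emitter = MatrixCorePreshuffleIntrinEmitter(
    a_dtype="int8",
    b_dtype="int8",
    accum_dtype="int32",
    a_transposed=False,
    b_transposed=True,
    block_row_warps=1,
    block_col_warps=4,
    warp_row_tiles=64,
    warp_col_tiles=16,
    chunk=256 * k_pack,
    k_pack=k_pack,
    b_preshuffle=True,
)

# Inside the kernel, B fragments come either from shared memory:
#     emitter.ldmatrix_b(B_local, B_shared, ki)
# or, when the kernel is built with b_g2l_load=True, directly from the
# preshuffled global buffer, with the block indices passed explicitly:
#     emitter.ldmatrix_b(B_local, B, ki + ko * num_ki, pid_m=by, pid_n=bx)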