import tilelang.testing
import tilelang
import torch


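# GEMM kernel factory: returns a tile-language prim_func specialized for the given
# problem size (M, N, K), tile sizes (block_M, block_N, block_K), layouts, and dtypes.
# The @tilelang.jit decorator wraps it for JIT compilation.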
@tilelang.jit(
    out_idx=-1,  # create the output tensor (the last argument, C) at runtime
    verbose=True,
)
def matmul_kernel_jit(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A=False,
    trans_B=True,
    in_dtype='float16',
    out_dtype='float32',
    accum_dtype='float32',
    num_stages=2,
    threads=128,
):
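    # Global tensor shapes and shared-memory tile shapes depend on whether
    # A and B are stored transposed.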
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Tensor(A_shape, in_dtype),
            B: T.Tensor(B_shape, in_dtype),
            C: T.Tensor((M, N), out_dtype),
    ):
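        # Launch a 2D grid of thread blocks: one block per (block_M, block_N) tile of C.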
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
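            # Shared-memory staging buffers for the A/B tiles and an accumulator
            # fragment for this block's tile of C.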
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
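            # Software-pipelined reduction over K: each iteration stages one tile of A
            # and one tile of B into shared memory, then accumulates into C_local.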
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
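            # Epilogue: write the accumulated tile back to global memory.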
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def test_par_compile():
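    # (M, N, K, block_M, block_N, block_K) configurations to compile.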
    configs = [
        (1024, 1024, 1024, 128, 128, 32),
        (2048, 2048, 2048, 256, 256, 64),
        (4096, 4096, 4096, 64, 64, 128),
    ]
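    # Compile one kernel per configuration; par_compile builds them in parallel.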
    kernels = matmul_kernel_jit.par_compile(configs)
    for (M, N, K, _, _, _), kernel in zip(configs, kernels):
        A = torch.randn(M, K, dtype=torch.float16).cuda()
        B = torch.randn(N, K, dtype=torch.float16).cuda()
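        # trans_B defaults to True, so B is stored as (N, K) and the reference is A @ B.T.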
        ref = (A @ B.T).float()
        C = kernel(A, B)
        tilelang.testing.torch_assert_close(C, ref, rtol=1e-2, atol=1e-2)


if __name__ == "__main__":
    tilelang.testing.main()