diff --git a/README.md b/README.md
index eeef6d401..b7c8e81a4 100644
--- a/README.md
+++ b/README.md
@@ -132,93 +132,95 @@ In this section, you'll learn how to write and execute a straightforward GEMM (m
 Below is an example that demonstrates more advanced features: layout annotation, parallelized copy, and swizzle for improved L2 cache locality. This snippet shows how to adapt your kernel to maximize performance on complex hardware.
 
 ```python
-import tilelang
-import tilelang.language as T
-
 # @tilelang.jit(target="cuda") # target currently can be "cuda" or "hip" or "cpu".
 # if not specified, it will be inferred from the input tensors during compile time
 @tilelang.jit
-def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float):
+def matmul_relu(
+    A, B,
+    block_M: int = 64,
+    block_N: int = 64,
+    block_K: int = 64,
+    dtype: T.dtype = T.float16,
+    accum_dtype: T.dtype = T.float32,
+):
+    # Declare compile-time shape constants
+    M, N, K = T.const('M, N, K')
+
+    # Annotate the input tensor shapes
+    A: T.Tensor[[M, K], dtype]
+    B: T.Tensor[[K, N], dtype]
 
-    @T.prim_func
-    def matmul_relu_kernel(
-        A: T.Tensor((M, K), dtype),
-        B: T.Tensor((K, N), dtype),
-        C: T.Tensor((M, N), dtype),
-    ):
-        # Initialize Kernel Context
-        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
-            A_shared = T.alloc_shared((block_M, block_K), dtype)
-            B_shared = T.alloc_shared((block_K, block_N), dtype)
-            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
+    # Allocate the output tensor
+    C = T.empty([M, N], dtype)
 
-            # Enable rasterization for better L2 cache locality (Optional)
-            # T.use_swizzle(panel_size=10, enable=True)
+    with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
+        A_shared = T.alloc_shared((block_M, block_K), dtype)
+        B_shared = T.alloc_shared((block_K, block_N), dtype)
+        C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
 
-            # Clear local accumulation
-            T.clear(C_local)
+        # Enable rasterization for better L2 cache locality (Optional)
+        # T.use_swizzle(panel_size=10, enable=True)
 
-            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
-                # Copy tile of A
-                # This is a sugar syntax for parallelized copy
-                T.copy(A[by * block_M, ko * block_K], A_shared)
+        # Clear local accumulation
+        T.clear(C_local)
 
-                # Copy tile of B
-                T.copy(B[ko * block_K, bx * block_N], B_shared)
+        for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
+            # Copy tile of A
+            # This is syntactic sugar for a parallelized copy
+            T.copy(A[by * block_M, ko * block_K], A_shared)
 
-                # Perform a tile-level GEMM on the shared buffers
-                # Currently we dispatch to the cute/hip on Nvidia/AMD GPUs
-                T.gemm(A_shared, B_shared, C_local)
+            # Copy tile of B
+            T.copy(B[ko * block_K, bx * block_N], B_shared)
 
-            # relu
-            for i, j in T.Parallel(block_M, block_N):
-                C_local[i, j] = T.max(C_local[i, j], 0)
+            # Perform a tile-level GEMM on the shared buffers
+            # Currently we dispatch to cute/hip backends on NVIDIA/AMD GPUs
+            T.gemm(A_shared, B_shared, C_local)
 
-            # Copy result back to global memory
-            T.copy(C_local, C[by * block_M, bx * block_N])
+        # Apply ReLU
+        for i, j in T.Parallel(block_M, block_N):
+            C_local[i, j] = T.max(C_local[i, j], 0)
 
-    return matmul_relu_kernel
+        # Copy result back to global memory
+        T.copy(C_local, C[by * block_M, bx * block_N])
 
+    # You can write multiple CUDA kernels in one function; they execute sequentially.
+    # Just open another "with T.Kernel(...) as ...:" block.
 
-M = 1024 # M = T.dynamic("m") if you want to use dynamic shape
-N = 1024
-K = 1024
-block_M = 128
-block_N = 128
-block_K = 32
+    # Return the output tensor; you can also return multiple tensors
+    return C
 
-# 1. Define the kernel (matmul) and compile/lower it into an executable module
-matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K)
 
-# 3. Test the kernel in Python with PyTorch data
-import torch
+M, N, K = 1024, 1024, 1024
 
-# Create random input tensors on the GPU
 a = torch.randn(M, K, device="cuda", dtype=torch.float16)
 b = torch.randn(K, N, device="cuda", dtype=torch.float16)
-c = torch.empty(M, N, device="cuda", dtype=torch.float16)
-
-# Run the kernel through the Profiler
-matmul_relu_kernel(a, b, c)
-
-print(c)
-# Reference multiplication using PyTorch
-ref_c = torch.relu(a @ b)
-
-# Validate correctness
-torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
-print("Kernel output matches PyTorch reference.")
-
-# 4. Retrieve and inspect the generated CUDA source (optional)
-# cuda_source = matmul_relu_kernel.get_kernel_source()
-# print("Generated CUDA kernel:\n", cuda_source)
+c_ref = torch.relu(a @ b)
+
+# Call the kernel
+c = matmul_relu(a, b)
+torch.testing.assert_close(c, c_ref, rtol=1e-2, atol=1e-2)
+
+# Call the kernel with overridden compilation constants
+c = matmul_relu(a, b, block_M=128, block_N=128, block_K=64)
+torch.testing.assert_close(c, c_ref, rtol=1e-2, atol=1e-2)
+
+# Retrieve the compiled kernel
+kernel = matmul_relu.compile(a, b)  # use torch.Tensor inputs
+kernel = matmul_relu.compile(  # use T.Tensor as placeholder
+    T.Tensor((M, K), T.float16),
+    T.Tensor((K, N), T.float16)
+)
+kernel = matmul_relu.compile(  # directly specify the shape constants
+    M=M, N=N, K=K,
+    block_M=128, block_N=128, block_K=64
+)
+print(kernel.get_kernel_source())
+c = kernel(a, b)
 
 # 5.Profile latency with kernel
-profiler = matmul_relu_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
-
+profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
 latency = profiler.do_bench()
-
 print(f"Latency: {latency} ms")
 ```