130 changes: 66 additions & 64 deletions README.md
@@ -132,93 +132,95 @@ In this section, you'll learn how to write and execute a straightforward GEMM (matrix multiplication)
Below is an example that demonstrates more advanced features: layout annotation, parallelized copy, and swizzle for improved L2 cache locality. This snippet shows how to adapt your kernel to maximize performance on complex hardware.

```python
import tilelang
import tilelang.language as T

# @tilelang.jit(target="cuda")
# target currently can be "cuda" or "hip" or "cpu".
# if not specified, it will be inferred from the input tensors during compile time
@tilelang.jit
def matmul_relu(
    A, B,
    block_M: int = 64,
    block_N: int = 64,
    block_K: int = 64,
    dtype: T.dtype = T.float16,
    accum_dtype: T.dtype = T.float32,
):
    # Declare compile-time shape constants
    M, N, K = T.const('M, N, K')

    # Annotate the input tensor shapes
    A: T.Tensor[[M, K], dtype]
    B: T.Tensor[[K, N], dtype]

    # Allocate the output tensor
    C = T.empty([M, N], dtype)

    # Initialize the kernel context
    with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
        A_shared = T.alloc_shared((block_M, block_K), dtype)
        B_shared = T.alloc_shared((block_K, block_N), dtype)
        C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

        # Enable rasterization for better L2 cache locality (optional)
        # T.use_swizzle(panel_size=10, enable=True)

        # Clear the local accumulator
        T.clear(C_local)

        for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
            # Copy a tile of A
            # This is sugar syntax for a parallelized copy
            T.copy(A[by * block_M, ko * block_K], A_shared)

            # Copy a tile of B
            T.copy(B[ko * block_K, bx * block_N], B_shared)

            # Perform a tile-level GEMM on the shared buffers
            # Currently we dispatch to cute/hip on NVIDIA/AMD GPUs
            T.gemm(A_shared, B_shared, C_local)

        # ReLU
        for i, j in T.Parallel(block_M, block_N):
            C_local[i, j] = T.max(C_local[i, j], 0)

        # Copy the result back to global memory
        T.copy(C_local, C[by * block_M, bx * block_N])

    # You can write multiple CUDA kernels in one function; they execute sequentially:
    # with T.Kernel(...) as ...

    # Return the tensor; you can also return multiple tensors
    return C


# Test the kernel in Python with PyTorch data
import torch

M, N, K = 1024, 1024, 1024

# Create random input tensors on the GPU
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16)

# Reference multiplication using PyTorch
c_ref = torch.relu(a @ b)

# Call the kernel and validate correctness
c = matmul_relu(a, b)
torch.testing.assert_close(c, c_ref, rtol=1e-2, atol=1e-2)

# Call the kernel with overridden compilation constants
c = matmul_relu(a, b, block_M=128, block_N=128, block_K=64)
torch.testing.assert_close(c, c_ref, rtol=1e-2, atol=1e-2)

# Retrieve the compiled kernel
kernel = matmul_relu.compile(a, b)  # use torch.Tensor
kernel = matmul_relu.compile(  # use T.Tensor as a placeholder
    T.Tensor((M, K), T.float16),
    T.Tensor((K, N), T.float16)
)
kernel = matmul_relu.compile(  # directly specify the shape constants
    M=M, N=N, K=K,
    block_M=128, block_N=128, block_K=64
)
print(kernel.get_kernel_source())
c = kernel(a, b)

# Profile latency with the compiled kernel
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
latency = profiler.do_bench()

print(f"Latency: {latency} ms")
```
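In the listing above, the compile target and the rasterization hint are shown only as comments. Below is a minimal sketch of the same kernel with those two options made explicit. It is an illustrative variant, not an additional API: it assumes `T.use_swizzle(panel_size=10, enable=True)` is valid inside the kernel context exactly as the commented-out line suggests, and that `target="cuda"` matches your device.

```python
import tilelang
import tilelang.language as T

# Same kernel as above, with the compile target pinned and the optional
# swizzle/rasterization hint enabled instead of commented out.
@tilelang.jit(target="cuda")  # target can be "cuda", "hip", or "cpu"
def matmul_relu_swizzled(
    A, B,
    block_M: int = 64,
    block_N: int = 64,
    block_K: int = 64,
    dtype: T.dtype = T.float16,
    accum_dtype: T.dtype = T.float32,
):
    M, N, K = T.const('M, N, K')
    A: T.Tensor[[M, K], dtype]
    B: T.Tensor[[K, N], dtype]
    C = T.empty([M, N], dtype)

    with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
        A_shared = T.alloc_shared((block_M, block_K), dtype)
        B_shared = T.alloc_shared((block_K, block_N), dtype)
        C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

        # Rasterize thread blocks in panels of 10 for better L2 cache locality
        T.use_swizzle(panel_size=10, enable=True)

        T.clear(C_local)
        for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
            T.copy(A[by * block_M, ko * block_K], A_shared)
            T.copy(B[ko * block_K, bx * block_N], B_shared)
            T.gemm(A_shared, B_shared, C_local)

        for i, j in T.Parallel(block_M, block_N):
            C_local[i, j] = T.max(C_local[i, j], 0)

        T.copy(C_local, C[by * block_M, bx * block_N])

    return C
```

Whether a panel size of 10 is the right choice depends on the GPU and problem shape; treat it as a starting point and measure with the profiler's `do_bench()` as shown above.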
