Merged
78 changes: 36 additions & 42 deletions README.md
@@ -123,35 +123,24 @@ Below is an example that demonstrates more advanced features: layout annotation,
 ```python
 import tilelang
 import tilelang.language as T
-# `make_mma_swizzle_layout` is a python defined layout function
-# specifically designed for for MMA operations
-# which ensures the consistency with the nvidia CUTLASS Library.
-# to avoid bank conflicts and maximize the performance.
-from tilelang.intrinsics import (
-    make_mma_swizzle_layout as make_swizzle_layout,)
 
-# add decorator @tilelang.jit if you want to return a torch function
-# @tilelang.jit
 
+# @tilelang.jit(target="cuda")
+# target currently can be "cuda" or "hip" or "cpu".
+# if not specified, it will be inferred from the input tensors during compile time
+@tilelang.jit
 def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
 
     @T.prim_func
-    def main(
+    def matmul_relu_kernel(
             A: T.Tensor((M, K), dtype),
             B: T.Tensor((K, N), dtype),
             C: T.Tensor((M, N), dtype),
     ):
         # Initialize Kernel Context
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
             A_shared = T.alloc_shared((block_M, block_K), dtype)
             B_shared = T.alloc_shared((block_K, block_N), dtype)
-            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
 
-            # Apply layout optimizations or define your own layout (Optional)
-            # If not specified, we will deduce the layout automatically
-            # T.annotate_layout({
-            #     A_shared: make_swizzle_layout(A_shared),
-            #     B_shared: make_swizzle_layout(B_shared),
-            # })
+            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
 
             # Enable rasterization for better L2 cache locality (Optional)
             # T.use_swizzle(panel_size=10, enable=True)
@@ -164,53 +153,58 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
             # This is a sugar syntax for parallelized copy
             T.copy(A[by * block_M, ko * block_K], A_shared)
 
-            # Demonstrate parallelized copy from global to shared for B
-            for k, j in T.Parallel(block_K, block_N):
-                B_shared[k, j] = B[ko * block_K + k, bx * block_N + j]
+            # Copy tile of B
+            T.copy(B[ko * block_K, bx * block_N], B_shared)
 
             # Perform a tile-level GEMM on the shared buffers
             # Currently we dispatch to the cute/hip on Nvidia/AMD GPUs
             T.gemm(A_shared, B_shared, C_local)
 
+            # relu
+            for i, j in T.Parallel(block_M, block_N):
+                C_local[i, j] = T.max(C_local[i, j], 0)
+
         # Copy result back to global memory
         T.copy(C_local, C[by * block_M, bx * block_N])
 
-    return main
+    return matmul_relu_kernel
 
 
-# 1. Define the kernel (matmul) with the desired dimensions
-func = matmul(1024, 1024, 1024, 128, 128, 32)
+M = 1024  # M = T.symbolic("m") if you want to use dynamic shape
+N = 1024
+K = 1024
+block_M = 128
+block_N = 128
+block_K = 32
 
-# 2. Compile the kernel into a torch function
-# out_idx specifies the index of the output buffer in the argument list
-# if out_idx is specified, the tensor will be created during runtime
-# target currently can be "cuda" or "hip" or "cpu".
-jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda")
+# 1. Define the kernel (matmul) and compile/lower it into an executable module
+matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K)
 
 # 3. Test the kernel in Python with PyTorch data
Contributor review comment (medium):

The numbering of steps in the guide has a gap, jumping from step 1 to step 3. To avoid confusion for new users, this should be corrected to step 2.

Suggested change:
-# 3. Test the kernel in Python with PyTorch data
+# 2. Test the kernel in Python with PyTorch data

 import torch
 
 # Create random input tensors on the GPU
-a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
-b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
+a = torch.randn(M, K, device="cuda", dtype=torch.float16)
+b = torch.randn(K, N, device="cuda", dtype=torch.float16)
+c = torch.empty(M, N, device="cuda", dtype=torch.float16)
 
-# Run the kernel through the JIT-compiled function
-c = jit_kernel(a, b)
+# Run the kernel through the Profiler
+matmul_relu_kernel(a, b, c)
 
 print(c)
 # Reference multiplication using PyTorch
-ref_c = a @ b
+ref_c = torch.relu(a @ b)
 
 # Validate correctness
 torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
 print("Kernel output matches PyTorch reference.")
 
 # 4. Retrieve and inspect the generated CUDA source (optional)
-cuda_source = jit_kernel.get_kernel_source()
-print("Generated CUDA kernel:\n", cuda_source)
+# cuda_source = jit_kernel.get_kernel_source()
Contributor review comment (medium):

The variable jit_kernel is no longer used in this example; it has been replaced by matmul_relu_kernel. This commented-out line should be updated to use the correct variable so that it works correctly if a user decides to uncomment it.

Suggested change:
-# cuda_source = jit_kernel.get_kernel_source()
+# cuda_source = matmul_relu_kernel.get_kernel_source()
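For readers who do uncomment it, the fixed lines would run as follows (a sketch assuming the kernel object returned by the @tilelang.jit-decorated matmul exposes get_kernel_source(), as the pre-PR README already used):

```python
# Uncommented form of the suggestion above.
cuda_source = matmul_relu_kernel.get_kernel_source()
print("Generated CUDA kernel:\n", cuda_source)
```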

# print("Generated CUDA kernel:\n", cuda_source)

# 5.Pofile latency with the profiler
profiler = jit_kernel.get_profiler()
# 5.Profile latency with kernel
profiler = matmul_relu_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)

latency = profiler.do_bench()

Expand Down
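Aside: the T.symbolic comment in the new listing hints at dynamic shapes. A minimal sketch of that variant, under the assumption that T.symbolic takes a dimension name and that the matmul factory accepts symbolic sizes:

```python
# Hypothetical dynamic-shape variant (not in this PR): one compiled
# kernel serves any row count M, which is bound at call time.
import tilelang.language as T

M_dyn = T.symbolic("m")
dyn_kernel = matmul(M_dyn, 1024, 1024, 128, 128, 32)
```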
45 changes: 15 additions & 30 deletions examples/quickstart.py
@@ -1,20 +1,15 @@
 import tilelang
 import tilelang.language as T
 
-# `make_mma_swizzle_layout` is a python defined layout function
-# specifically designed for MMA operations
-# which ensures the consistency with the nvidia CUTLASS Library.
-# to avoid bank conflicts and maximize the performance.
-from tilelang.intrinsics import (
-    make_mma_swizzle_layout as make_swizzle_layout,)  # noqa: F401
-
-
-# add decorator @tilelang.jit if you want to return a torch function
-# @tilelang.jit
 # @tilelang.jit(target="cuda")
 # target currently can be "cuda" or "hip" or "cpu".
 # if not specified, it will be inferred from the input tensors during compile time
+@tilelang.jit
 def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
 
     @T.prim_func
-    def main(
+    def matmul_relu_kernel(
             A: T.Tensor((M, K), dtype),
             B: T.Tensor((K, N), dtype),
             C: T.Tensor((M, N), dtype),
@@ -25,13 +20,6 @@ def main(
             B_shared = T.alloc_shared((block_K, block_N), dtype)
             C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
 
-            # Apply layout optimizations or define your own layout (Optional)
-            # If not specified, we will deduce the layout automatically
-            # T.annotate_layout({
-            #     A_shared: make_swizzle_layout(A_shared),
-            #     B_shared: make_swizzle_layout(B_shared),
-            # })
-
             # Enable rasterization for better L2 cache locality (Optional)
             # T.use_swizzle(panel_size=10, enable=True)
@@ -41,8 +29,6 @@ def main(
         for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
             # Copy tile of A
             # This is a sugar syntax for parallelized copy
-            # for i, k in T.Parallel(M, block_K):
-            #     A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
             T.copy(A[by * block_M, ko * block_K], A_shared)
 
             # Copy tile of B
@@ -52,10 +38,14 @@ def main(
             # Currently we dispatch to the cute/hip on Nvidia/AMD GPUs
             T.gemm(A_shared, B_shared, C_local)
 
+            # relu
+            for i, j in T.Parallel(block_M, block_N):
+                C_local[i, j] = T.max(C_local[i, j], 0)
+
         # Copy result back to global memory
         T.copy(C_local, C[by * block_M, bx * block_N])
 
-    return main
+    return matmul_relu_kernel
 
 
 M = 1024  # M = T.symbolic("m") if you want to use dynamic shape
@@ -66,27 +56,22 @@ def main(
 block_K = 32
 
 # 1. Define the kernel (matmul) and compile/lower it into an executable module
-func = matmul(M, N, K, block_M, block_N, block_K)
-
-# 2. Compile the kernel into a torch function
-# out_idx specifies the index of the output buffer in the argument list
-# if out_idx is specified, the tensor will be created during runtime
-# target currently can be "cuda" or "hip" or "cpu".
-jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda")
+matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K)
 
 # 3. Test the kernel in Python with PyTorch data
 import torch
 
 # Create random input tensors on the GPU
 a = torch.randn(M, K, device="cuda", dtype=torch.float16)
 b = torch.randn(K, N, device="cuda", dtype=torch.float16)
+c = torch.empty(M, N, device="cuda", dtype=torch.float16)
 
 # Run the kernel through the Profiler
-c = jit_kernel(a, b)
+matmul_relu_kernel(a, b, c)
 
 print(c)
 # Reference multiplication using PyTorch
-ref_c = a @ b
+ref_c = torch.relu(a @ b)
 
 # Validate correctness
 torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
@@ -97,7 +82,7 @@ def main(
 # print("Generated CUDA kernel:\n", cuda_source)
 
 # 5.Profile latency with kernel
-profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
+profiler = matmul_relu_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
 
 latency = profiler.do_bench()
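A natural follow-up to the benchmark step is converting the measured latency into throughput; a small sketch, assuming do_bench() reports the mean latency in milliseconds:

```python
# Not part of the diff: derive throughput from the measured latency.
# Assumption: profiler.do_bench() returns milliseconds.
total_flops = 2 * M * N * K  # one multiply and one add per MAC
tflops = total_flops / (latency * 1e-3) / 1e12
print(f"Latency: {latency:.3f} ms, throughput: {tflops:.2f} TFLOPS")
```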
