Commit a9ab486

[SM100] Add sm100 GEMM layouts and tcgen05 support (tile-ai#887)
* update sm100 related utcmma, tmem, ld/st256 in src
* update sm100 related utcmma, tmem, ld/st256 in tilelang
* Remove deprecated GEMM examples and related README documentation for SM100 architecture support
* Update GEMM implementation to replace UTCMMA with TCGEN5MMA across relevant files
* Remove gemm_umma.py example and update README to reflect TCGEN5MMA terminology changes
* Update README.md for gemm_sm100 example by removing outdated API sections and streamlining documentation
* Update README and source files to reflect TCGEN5.MMA terminology changes
* Refactor CUDA GEMM header for improved readability
1 parent 50cecdf commit a9ab486

49 files changed: +3063 −185 lines

.clang-tidy

Lines changed: 3 additions & 0 deletions
```diff
@@ -42,7 +42,10 @@ Checks: >
   -cppcoreguidelines-pro-type-static-cast-downcast,
   -performance-unnecessary-value-param,
   -performance-enum-size,
+  -cppcoreguidelines-pro-bounds-pointer-arithmetic,
+  -cppcoreguidelines-pro-bounds-array-to-pointer-decay,
   -clang-analyzer-deadcode.DeadStores,
+  -clang-analyzer-optin.cplusplus.VirtualCall,

 WarningsAsErrors: '*'

```

examples/gemm_sm100/README.md

Lines changed: 106 additions & 0 deletions

# TileLang SM100 Support (Preview)

This directory contains examples for TileLang's experimental SM100 architecture support. **This is a preview version** with limited functionality.

## Current Limitations (Manual Implementation Required)

### 1. Manual TCGEN5.MMA Management
Users must manually handle TCGEN5MMA operations using:
- `T.alloc_tmem()` - Allocate Tensor Memory
- `T.gemm()` with `wg_wait=-1` - Launch TCGEN5MMA without waiting for it to complete
- Manual synchronization with mbarrier

### 2. Manual mbarrier Synchronization
TCGEN5MMA is asynchronous and requires explicit synchronization:
```python
mbar = T.alloc_barrier(1)  # expect-arrive-count = 1
T.gemm(A_shared, B_shared, C_tmem, trans_A, trans_B, mbar=mbar, wg_wait=-1, clear_accum=k == 0)
T.mbarrier_wait_parity(mbar, k % 2)  # Manual phase calculation required: the barrier completes once per iteration, so the parity alternates as k % 2
```

## Examples

### TCGEN5MMA Example (`gemm_tcgen5mma.py`)
Demonstrates TCGEN5MMA operations with:
- Tensor Memory allocation
- Manual mbarrier synchronization
- TCGEN5MMA gemm operations

### Traditional MMA Example (`gemm_mma.py`)
Shows standard MMA operations that work across architectures, for comparison.

## Code Example

The following code is based on `gemm_tcgen5mma.py` and demonstrates TCGEN5MMA matrix multiplication:

```python
import torch
import tilelang
import tilelang.language as T


@T.prim_func
def main(
        A: T.Tensor((M, K), "bfloat16"),
        B: T.Tensor((N, K), "bfloat16"),
        C: T.Tensor((M, N), "bfloat16"),
):
    with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
        # 1. Allocate memory buffers
        A_shared = T.alloc_shared((block_M, block_K), "bfloat16")  # A matrix shared memory
        B_shared = T.alloc_shared((block_N, block_K), "bfloat16")  # B matrix shared memory
        C_tmem = T.alloc_tmem([block_M, block_N], "float")  # TCGEN5MMA output goes to Tensor Memory
        mbar = T.alloc_barrier(1)  # mbarrier synchronization primitive

        C_local = T.alloc_fragment((block_M, block_N), "float")  # Register storage
        C_shared = T.alloc_shared((block_M, block_N), "bfloat16")  # Output shared memory

        # 2. Main computation loop
        for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1):
            # Data loading: global memory to shared memory
            T.copy(A[by * block_M, k * block_K], A_shared)
            T.copy(B[bx * block_N, k * block_K], B_shared)

            # TCGEN5MMA computation: asynchronous launch, output to Tensor Memory
            T.gemm(A_shared, B_shared, C_tmem, trans_A=False, trans_B=True,
                   mbar=mbar, wg_wait=-1, clear_accum=k == 0)

            # Critical: wait for TCGEN5MMA completion
            T.mbarrier_wait_parity(mbar, k % 2)

        # 3. Output processing (in the full example, only a subset of threads performs these copies)
        T.copy(C_tmem, C_local)    # Tensor Memory → registers
        T.copy(C_local, C_shared)  # registers → shared memory

        # 4. Write back to global memory
        T.copy(C_shared, C[by * block_M, bx * block_N])
```

In the full `gemm_tcgen5mma.py`, this `prim_func` is wrapped in a `matmul(...)` factory that closes over `M`, `N`, `K` and the block sizes and returns `main`; the returned function is what gets passed to `tilelang.compile` as `func` below.

### Compilation and Usage

```python
# Parameter setup
M, N, K = 4096, 4096, 8192
block_M, block_N, block_K = 128, 256, 128

# Compile kernel
jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", pass_configs={
    tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,  # Required
    tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,  # Required
})

# Run test
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)
c = jit_kernel(a, b)

# Verify correctness
ref_c = (a.to(torch.float) @ b.T.to(torch.float)).to(torch.bfloat16)
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)

# Performance benchmark
profiler = jit_kernel.get_profiler()
latency = profiler.do_bench()
print(f"Latency: {latency} ms")
print(f"Performance: {2 * M * N * K / (latency/1e3) / 1e12:.2f} TFLOPS")
```
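
For a quick sanity check of the TFLOPS formula used above, here is a small standalone calculation (the 1 ms latency is a made-up placeholder, not a measured result):

```python
M, N, K = 4096, 4096, 8192
flops = 2 * M * N * K             # 2 * 4096 * 4096 * 8192 = 2**38 ≈ 2.75e11 floating-point operations
latency_ms = 1.0                  # hypothetical latency; use profiler.do_bench() for a real number
tflops = flops / (latency_ms / 1e3) / 1e12
print(f"{tflops:.2f} TFLOPS")     # ≈ 274.88 TFLOPS for a 1 ms run
```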

examples/gemm_sm100/gemm_mma.py

Lines changed: 94 additions & 0 deletions
```python
import tilelang
import tilelang.language as T


# add the decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((N, K), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        # Initialize the kernel context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # Clear the local accumulator
            T.clear(C_local)

            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
                # Copy a tile of A.
                # This is sugar syntax for a parallelized copy:
                # for i, k in T.Parallel(block_M, block_K):
                #     A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
                T.copy(A[by * block_M, ko * block_K], A_shared)

                # Copy a tile of B
                T.copy(B[bx * block_N, ko * block_K], B_shared)

                # Perform a tile-level GEMM on the shared buffers.
                # Currently this dispatches to cute/hip on NVIDIA/AMD GPUs.
                T.gemm(A_shared, B_shared, C_local, transpose_B=True)

            # Copy the result back to global memory
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


M = 128  # M = T.symbolic("m") if you want to use a dynamic shape
N = 128
K = 32
block_M = 128
block_N = 128
block_K = 32

# 1. Define the kernel (matmul) and compile/lower it into an executable module
func = matmul(M, N, K, block_M, block_N, block_K)

# 2. Compile the kernel into a torch function
# out_idx specifies the index of the output buffer in the argument list;
# if out_idx is specified, the tensor will be created at runtime.
# target currently can be "cuda", "hip", or "cpu".
jit_kernel = tilelang.compile(
    func,
    out_idx=[2],
    target="cuda",
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    })
print(jit_kernel.get_kernel_source())

# 3. Test the kernel in Python with PyTorch data
import torch

# Create random input tensors on the GPU
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(N, K, device="cuda", dtype=torch.float16)

# Run the compiled kernel
c = jit_kernel(a, b)

print(c)
# Reference multiplication using PyTorch
ref_c = a @ b.T

# Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")

# 4. Retrieve and inspect the generated CUDA source (optional)
# cuda_source = jit_kernel.get_kernel_source()
# print("Generated CUDA kernel:\n", cuda_source)

# 5. Profile latency with the kernel's profiler
profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)

latency = profiler.do_bench()

print(f"Latency: {latency} ms")
```

examples/gemm_sm100/gemm_tcgen5mma.py

Lines changed: 94 additions & 0 deletions

```python
import torch
import tilelang
import tilelang.language as T

tilelang.disable_cache()


def matmul(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    threads,
):
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    @T.prim_func
    def main(
            A: T.Tensor(A_shape, in_dtype),
            B: T.Tensor(B_shape, in_dtype),
            C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype)
            mbar = T.alloc_barrier(1)  # the 1 here is the expect-arrive-count
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            C_shared = T.alloc_shared((block_M, block_N), out_dtype)

            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                T.gemm(
                    A_shared,
                    B_shared,
                    C_tmem,
                    trans_A,
                    trans_B,
                    mbar=mbar,
                    wg_wait=-1,
                    clear_accum=k == 0)
                T.mbarrier_wait_parity(mbar, k % 2)

            if T.get_thread_binding() < 128:
                T.copy(C_tmem, C_local)
                T.copy(C_local, C_shared)

            T.copy(C_shared, C[by * block_M, bx * block_N])

    return main


M, N, K = 4096, 4096, 8192
block_M, block_N, block_K = 128, 256, 128
trans_A, trans_B = False, True
in_dtype, out_dtype, accum_dtype = "bfloat16", "bfloat16", "float"
num_stages = 0
threads = 256

func = matmul(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype,
              accum_dtype, num_stages, threads)
jit_kernel = tilelang.compile(
    func,
    out_idx=[2],
    target="cuda",
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    })

print(jit_kernel.get_kernel_source())

a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)
c = jit_kernel(a, b)
ref_c = (a.to(torch.float) @ b.T.to(torch.float)).to(torch.bfloat16)
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)

profiler = jit_kernel.get_profiler()
latency = profiler.do_bench()
print(f"Latency: {latency} ms")
print(f"Flops: {2 * M * N * K / (latency/1e3) / 1e12} TFLOPS")
```

src/layout/gemm_layouts.cc

Lines changed: 29 additions & 4 deletions
```diff
@@ -13,7 +13,7 @@
 namespace tvm {
 namespace tl {

-static IterVar make_itervar(std::string name, PrimExpr dom) {
+IterVar make_itervar(std::string name, PrimExpr dom) {
   Var var = Var(name, dom->dtype);
   return IterVar(Range(0, dom), var, IterVarType::kDataPar);
 }
@@ -749,16 +749,41 @@ Layout makeGemmABLayoutHopper(int mat_stride, int mat_continuous,
                                         element_size);
   }
   int vector_size = 128 / element_size;
-  if (kfactor == 1 && element_size == 8) // int8 KxN
+  if (mat_continuous % (vector_size * 8) == 0)
+    return makeFullBankSwizzleLayout(mat_stride, mat_continuous, element_size);
+  else if (mat_continuous % (vector_size * 4) == 0)
+    return makeHalfBankSwizzleLayout(mat_stride, mat_continuous, element_size);
+  else if (mat_continuous % (vector_size * 2) == 0)
     return makeQuarterBankSwizzleLayout(mat_stride, mat_continuous,
                                         element_size);
-  else if (mat_continuous % (vector_size * 8) == 0)
+  else if (mat_continuous % vector_size == 0)
+    return makeGemmLayoutLinear(mat_stride, mat_continuous);
+  else
+    ICHECK(0) << "Unsupported layout for Hopper with stride=" << mat_stride
+              << ", continuous=" << mat_continuous
+              << ", element_size=" << element_size << ", kfactor=" << kfactor;
+}
+
+Layout makeGemmABLayoutSm100(int mat_stride, int mat_continuous, int continuity,
+                             int element_size, int kfactor) {
+  if (element_size == 64) {
+    ICHECK(0) << "float64 on sm100 is not supported now";
+  }
+  int vector_size = 128 / element_size;
+  if (mat_continuous % (vector_size * 8) == 0)
     return makeFullBankSwizzleLayout(mat_stride, mat_continuous, element_size);
   else if (mat_continuous % (vector_size * 4) == 0)
     return makeHalfBankSwizzleLayout(mat_stride, mat_continuous, element_size);
-  else
+  else if (mat_continuous % (vector_size * 2) == 0)
     return makeQuarterBankSwizzleLayout(mat_stride, mat_continuous,
                                         element_size);
+  else if (mat_continuous % vector_size == 0)
+    return makeGemmLayoutLinear(mat_stride, mat_continuous);
+  else
+    ICHECK(0) << "Unsupported layout for sm100 with stride=" << mat_stride
+              << ", continuous=" << mat_continuous
+              << ", element_size=" << element_size << ", kfactor=" << kfactor;
+  __builtin_unreachable(); // to prevent compiler warning
 }

 Layout makeGemmABLayoutCDNA(int stride, int continuous, int element_size,
```
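
As a reading aid (not part of the commit), the swizzle-tier selection in `makeGemmABLayoutSm100` above can be summarized with this small Python sketch; `element_size` is in bits, and the returned strings stand in for the layout constructors in `gemm_layouts.cc`:

```python
def select_sm100_ab_layout(mat_continuous: int, element_size: int) -> str:
    """Illustrative mirror of the cascade in makeGemmABLayoutSm100."""
    assert element_size != 64, "float64 on sm100 is not supported now"
    vector_size = 128 // element_size  # elements per 128-bit vector access
    if mat_continuous % (vector_size * 8) == 0:
        return "FullBankSwizzleLayout"
    elif mat_continuous % (vector_size * 4) == 0:
        return "HalfBankSwizzleLayout"
    elif mat_continuous % (vector_size * 2) == 0:
        return "QuarterBankSwizzleLayout"
    elif mat_continuous % vector_size == 0:
        return "GemmLayoutLinear"
    raise ValueError("Unsupported layout for sm100")


# bfloat16 tiles (element_size = 16 bits) with a 128-element contiguous dimension
print(select_sm100_ab_layout(128, 16))  # -> FullBankSwizzleLayout
```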

src/layout/layout.h

Lines changed: 3 additions & 0 deletions
```diff
@@ -131,6 +131,7 @@ class Fragment : public Layout {

 Var InputPlaceholder(size_t idx);
 Var ReplicationPlaceholder();
+IterVar make_itervar(std::string name, PrimExpr dom);

 Fragment makeGemmFragment8x8();
 Fragment makeGemmFragment8x8Transposed();
@@ -166,6 +167,8 @@ Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity,
                         int element_size, int kfactor);
 Layout makeGemmABLayoutHopper(int mat_stride, int mat_continuous,
                               int continuity, int element_size, int kfactor);
+Layout makeGemmABLayoutSm100(int mat_stride, int mat_continuous, int continuity,
+                             int element_size, int kfactor);
 Layout makeGemmABLayoutCDNA(int stride, int continuous, int element_size,
                             int kfactor);

```
