
Commit d99853b

[Language] Add Correctness and performance check scripts for V2 (#1174)
* fix
* lint fix
* fix
* lint fix
* fix
* upd
1 parent aef0a6b commit d99853b

File tree

9 files changed: +878, -60 lines


maint/gemm_v2/correctness_evaluation.py

Lines changed: 726 additions & 0 deletions
Large diffs are not rendered by default.

maint/gemm_v2/latency.py

Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
import tilelang
import tilelang.language as T
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--use_v2", action="store_true")
args = parser.parse_args()

use_v2 = args.use_v2


# @tilelang.jit(target="cuda")
# target can currently be "cuda", "hip", or "cpu";
# if not specified, it is inferred from the input tensors at compile time
@tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def matmul_relu_kernel(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        # Initialize the kernel context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # Enable rasterization for better L2 cache locality (optional)
            # T.use_swizzle(panel_size=10, enable=True)

            # Clear the local accumulator
            T.clear(C_local)

            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                # Copy a tile of A; T.copy is sugar for a parallelized copy
                T.copy(A[by * block_M, ko * block_K], A_shared)

                # Copy a tile of B
                T.copy(B[ko * block_K, bx * block_N], B_shared)

                # Perform a tile-level GEMM on the shared buffers.
                # Currently this dispatches to cute on NVIDIA GPUs and hip on AMD GPUs.
                if use_v2:
                    T.gemm_v2(A_shared, B_shared, C_local)
                else:
                    T.gemm_v1(A_shared, B_shared, C_local)

            # ReLU
            for i, j in T.Parallel(block_M, block_N):
                C_local[i, j] = T.max(C_local[i, j], 0)

            # Copy the result back to global memory
            T.copy(C_local, C[by * block_M, bx * block_N])

    return matmul_relu_kernel


M = 16384  # M = T.dynamic("m") if you want to use a dynamic shape
N = 16384
K = 16384
block_M = 128
block_N = 128
block_K = 64

# 1. Define the kernel (matmul) and compile/lower it into an executable module
matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K)

# 2. Test the kernel in Python with PyTorch data
import torch

# Create random input tensors on the GPU
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16)
c = torch.empty(M, N, device="cuda", dtype=torch.float16)

# Run the kernel
matmul_relu_kernel(a, b, c)

print(c)
# Reference multiplication using PyTorch
ref_c = torch.relu(a @ b)

# Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")

# 3. Retrieve and inspect the generated CUDA source (optional)
# cuda_source = matmul_relu_kernel.get_kernel_source()
# print("Generated CUDA kernel:\n", cuda_source)

# 4. Profile kernel latency
profiler = matmul_relu_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)

latency = profiler.do_bench()

print(f"Latency: {latency} ms")

src/op/gemm.cc

Lines changed: 0 additions & 2 deletions
@@ -122,8 +122,6 @@ bool GemmNode::AllowWGMMA(int block_size, Target target) const {
 GemmInst GemmNode::GetGemmInst(int block_size, Target target) const {
   bool allow_tcgen5mma = AllowTCGEN5MMA(target);
   bool allow_wgmma = AllowWGMMA(block_size, target);
-  LOG(INFO) << "allow_tcgen5mma: " << allow_tcgen5mma
-            << ", allow_wgmma: " << allow_wgmma;
   if (allow_tcgen5mma) {
     return GemmInst::kTCGEN5MMA;
   } else if (allow_wgmma) {

src/target/codegen_cuda.cc

Lines changed: 38 additions & 18 deletions
@@ -1749,10 +1749,19 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
       "reinterpret_cast<const (ARegType)*>((A_ptr) + (A_offset)), "
       "reinterpret_cast<const (BRegType)*>((B_ptr) + (B_offset)));\n";
   tl::codegen::Replacer replacer;
+  std::string AType = tl::codegen::ptx::DTypeEnumToString(dtype_a_enum);
+  if (AType == "tl::DataType::kFloat32") {
+    AType = "tl::DataType::kTensorFloat32";
+  }
+  std::string BType = tl::codegen::ptx::DTypeEnumToString(dtype_b_enum);
+  if (BType == "tl::DataType::kFloat32") {
+    BType = "tl::DataType::kTensorFloat32";
+  }
+
   replacer.register_rule("(AType)",
-                         tl::codegen::ptx::DTypeEnumToString(dtype_a_enum));
+                         tl::codegen::ptx::DTypeEnumToString(AType));
   replacer.register_rule("(BType)",
-                         tl::codegen::ptx::DTypeEnumToString(dtype_b_enum));
+                         tl::codegen::ptx::DTypeEnumToString(BType));
   replacer.register_rule("(CType)",
                          tl::codegen::ptx::DTypeEnumToString(dtype_c_enum));
   replacer.register_rule("(M)", std::to_string(m));

@@ -1838,16 +1847,12 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
   std::string B_offset = this->PrintExpr(op->args[9]);
   std::string c_ref = this->PrintExpr(op->args[10]);
   std::string c_offset = this->PrintExpr(op->args[11]);
-  bool scale_out = Downcast<Bool>(op->args[12])->value;
+  std::string scale_out = this->PrintExpr(op->args[12]);
   bool scale_in_a = Downcast<Bool>(op->args[13])->value;
   bool scale_in_b = Downcast<Bool>(op->args[14])->value;

   const bool a_is_shared = true;
   this->PrintIndent();
-  std::string asm_code = PrintWGMMAAssembly(
-      shape, a_is_k_major, b_is_k_major, A_dtype, B_dtype, C_dtype, a_desc,
-      A_offset, b_desc, B_offset, c_ref, c_offset, scale_out, scale_in_a,
-      scale_in_b, a_is_shared, "", "", "", false);
   auto [m, n, k] = tl::codegen::ptx::ParseMMAShape(shape);
   need_wgmma_instruction_h_ = true;
   std::string wgmma_asm_code =

@@ -1856,10 +1861,18 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
       "uint64_t((desc_b) + (B_offset)), ((uint32_t*)((C))), (scale_out));\n";
   // replace patterns
   tl::codegen::Replacer replacer;
-  replacer.register_rule("(AType)",
-                         tl::codegen::ptx::DTypeEnumToString(A_dtype));
-  replacer.register_rule("(BType)",
-                         tl::codegen::ptx::DTypeEnumToString(B_dtype));
+
+  std::string AType = tl::codegen::ptx::DTypeEnumToString(A_dtype);
+  if (AType == "tl::DataType::kFloat32") {
+    AType = "tl::DataType::kTensorFloat32";
+  }
+  std::string BType = tl::codegen::ptx::DTypeEnumToString(B_dtype);
+  if (BType == "tl::DataType::kFloat32") {
+    BType = "tl::DataType::kTensorFloat32";
+  }
+
+  replacer.register_rule("(AType)", AType);
+  replacer.register_rule("(BType)", BType);
   replacer.register_rule("(CType)",
                          tl::codegen::ptx::DTypeEnumToString(C_dtype));
   replacer.register_rule("(M)", std::to_string(m));

@@ -1874,7 +1887,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
   replacer.register_rule("(desc_b)", b_desc);
   replacer.register_rule("(B_offset)", B_offset);
   replacer.register_rule("(C)", c_ref + " + " + c_offset);
-  replacer.register_rule("(scale_out)", scale_out ? "true" : "false");
+  replacer.register_rule("(scale_out)", scale_out);
   wgmma_asm_code = replacer.rewrite(wgmma_asm_code);
   this->stream << wgmma_asm_code;
 } else if (op->op.same_as(tl::ptx_wgmma_rs())) {

@@ -1904,7 +1917,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
   std::string B_offset = this->PrintExpr(op->args[8]);
   std::string c_ref = this->PrintExpr(op->args[9]);
   std::string c_offset = this->PrintExpr(op->args[10]);
-  bool scale_out = Downcast<Bool>(op->args[11])->value;
+  std::string scale_out = this->PrintExpr(op->args[11]);
   bool scale_in_a = Downcast<Bool>(op->args[12])->value;
   bool scale_in_b = Downcast<Bool>(op->args[13])->value;

@@ -1924,10 +1937,17 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
       "(scale_out));\n";

   tl::codegen::Replacer replacer;
-  replacer.register_rule("(AType)",
-                         tl::codegen::ptx::DTypeEnumToString(dtype_a_enum));
-  replacer.register_rule("(BType)",
-                         tl::codegen::ptx::DTypeEnumToString(dtype_b_enum));
+  std::string AType = tl::codegen::ptx::DTypeEnumToString(A_dtype);
+  if (AType == "tl::DataType::kFloat32") {
+    AType = "tl::DataType::kTensorFloat32";
+  }
+  std::string BType = tl::codegen::ptx::DTypeEnumToString(B_dtype);
+  if (BType == "tl::DataType::kFloat32") {
+    BType = "tl::DataType::kTensorFloat32";
+  }
+
+  replacer.register_rule("(AType)", AType);
+  replacer.register_rule("(BType)", BType);
   replacer.register_rule("(CType)",
                          tl::codegen::ptx::DTypeEnumToString(dtype_c_enum));
   replacer.register_rule("(M)", std::to_string(m));

@@ -1943,7 +1963,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
   replacer.register_rule("(B_offset)", B_offset);
   replacer.register_rule("(C_ptr)", c_ref);
   replacer.register_rule("(C_offset)", c_offset);
-  replacer.register_rule("(scale_out)", scale_out ? "true" : "false");
+  replacer.register_rule("(scale_out)", scale_out);
   wgmma_call = replacer.rewrite(wgmma_call);
   this->stream << wgmma_call;
 } else if (op->op.same_as(tl::ptx_tcgen05_mma_ss())) {

src/tl_templates/cuda/instruction/mma.h

Lines changed: 9 additions & 0 deletions
@@ -127,6 +127,15 @@ TL_DEFINE_MMA_DISPATCHER(kFloat8_e5m2, kFloat8_e5m2, kFloat16, 16, 8, 32, false,
 TL_DEFINE_MMA_DISPATCHER(kFloat8_e5m2, kFloat8_e5m2, kFloat32, 16, 8, 32, false,
                          true, false, cute::SM89_16x8x32_F32E5M2E5M2F32_TN)

+// TF32 inputs (FP32 math on Tensor Cores)
+// Support both k=4 and k=8 variants on SM80
+TL_DEFINE_MMA_DISPATCHER(kTensorFloat32, kTensorFloat32, kFloat32, 16, 8, 4,
+                         false, true, false,
+                         cute::SM80_16x8x4_F32TF32TF32F32_TN)
+TL_DEFINE_MMA_DISPATCHER(kTensorFloat32, kTensorFloat32, kFloat32, 16, 8, 8,
+                         false, true, false,
+                         cute::SM80_16x8x8_F32TF32TF32F32_TN)
+
 #undef TL_DEFINE_MMA_DISPATCHER

 } // namespace detail
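
With these dispatchers registered, float32 operands (remapped to kTensorFloat32 by the codegen_cuda.cc change above) can reach SM80 Tensor-Core MMA. A minimal Python sketch of a float32 tile GEMM that would exercise this path, assuming the same tilelang API as latency.py above; the shapes, block sizes, and staging here are illustrative, not from this commit:

import tilelang
import tilelang.language as T


@tilelang.jit
def matmul_fp32(M, N, K, block_M, block_N, block_K):

    @T.prim_func
    def kernel(
            A: T.Tensor((M, K), "float32"),
            B: T.Tensor((K, N), "float32"),
            C: T.Tensor((M, N), "float32"),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), "float32")
            B_shared = T.alloc_shared((block_K, block_N), "float32")
            C_local = T.alloc_fragment((block_M, block_N), "float32")
            T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=2):
                T.copy(A[by * block_M, ko * block_K], A_shared)
                T.copy(B[ko * block_K, bx * block_N], B_shared)
                # float32 inputs are rewritten to TF32 during codegen, so this
                # GEMM should dispatch to the SM80 TF32 MMA entries added above
                T.gemm_v2(A_shared, B_shared, C_local)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return kernel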

testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py

Lines changed: 1 addition & 0 deletions
@@ -397,6 +397,7 @@ def test_gemm_sr():
     run_gemm_sr(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2)

     # float32 tests
+    # TODO(lei): fix in future
     run_gemm_sr(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2)
     run_gemm_sr(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2)
     run_gemm_sr(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2)

testing/python/transform/test_tilelang_transform_inject_fence_proxy.py

Lines changed: 0 additions & 38 deletions
@@ -186,43 +186,5 @@ def visit(node):
     assert order.index("tl.fence_proxy_async") < order.index("tl.ptx_wgmma_ss")


-def test_wgmma_after_descriptor():
-
-    @T.prim_func
-    def before():
-        with T.Kernel(1):
-            desc_a = T.decl_buffer((1,), "uint64", scope="local.descriptor.wgmma")
-            desc_b = T.decl_buffer((1,), "uint64", scope="local.descriptor.wgmma")
-            C_local = T.decl_buffer((32,), "float16", scope="local")
-            T.initialize_wgmma_descriptor(desc_a, T.uint64(0), 2, 1, 32)
-            T.initialize_wgmma_descriptor(desc_b, T.uint64(0), 2, 1, 32)
-            T.warpgroup_arrive()
-            T.ptx_wgmma_ss("float16", "m64n64k16", T.bool(True), T.bool(True), "fp16", "fp16",
-                           "fp16", desc_a.data, T.int32(0), desc_b.data, T.int32(0), C_local.data,
-                           T.int32(0), T.bool(True), 1, 1)
-
-    mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main"))
-    mod = tvm.tir.transform.BindTarget(auto_target)(mod)
-    mod = tl.transform.InjectFenceProxy()(mod)
-
-    fence_count = 0
-    order = []
-
-    def visit(node):
-        nonlocal fence_count
-        if isinstance(node, tir.Evaluate):
-            call = node.value
-            if isinstance(call, tir.Call):
-                name = getattr(call.op, "name", "")
-                order.append(name)
-                if name == "tl.fence_proxy_async":
-                    fence_count += 1
-
-    tir.stmt_functor.post_order_visit(mod["main"].body, visit)
-    assert fence_count >= 1
-    assert "tl.warpgroup_arrive" in order
-    assert order.index("tl.fence_proxy_async") < order.index("tl.warpgroup_arrive")
-
-
 if __name__ == "__main__":
     tilelang.testing.main()

tilelang/language/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@
     alloc_tcgen05_instr_desc,  # noqa: F401
 )
 from .copy import copy, c2d_im2col  # noqa: F401
-from .gemm import GemmWarpPolicy, gemm, gemm_v2  # noqa: F401
+from .gemm import GemmWarpPolicy, gemm, gemm_v1, gemm_v2  # noqa: F401
 from .experimental.gemm_sp import gemm_sp  # noqa: F401
 from .fill import fill, clear  # noqa: F401
 from .reduce import (

tilelang/language/gemm.py

Lines changed: 4 additions & 1 deletion
@@ -7,7 +7,7 @@
 from tilelang.utils.language import get_buffer_region_from_load


-def gemm(
+def gemm_v1(
     A: tir.Buffer | tir.Var,
     B: tir.Buffer | tir.Var,
     C: tir.Buffer | tir.Var,

@@ -432,3 +432,6 @@ def retrieve_offset(object: tir.Buffer | tir.BufferRegion) -> tir.PrimExpr:
         C_coords[0],
         C_coords[1],
     )
+
+
+gemm = gemm_v1
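
Since gemm is rebound to gemm_v1, existing calls to T.gemm keep their old behavior, and both versions remain addressable by name (latency.py above selects between them via --use_v2). A minimal check of the aliasing, assuming only the exports added in this commit:

import tilelang.language as T

# gemm is re-exported as an alias of the renamed gemm_v1 ...
assert T.gemm is T.gemm_v1
# ... while gemm_v2 is a separate entry point
assert T.gemm is not T.gemm_v2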
