
Commit 9d2844d

test tensorize without layout transform
1 parent 86ee6da commit 9d2844d

2 files changed: +340, -15 lines

tests/python/unittest/test_mma_16x8x16.py

Lines changed: 8 additions & 15 deletions
@@ -295,24 +295,17 @@ def dense(n: int, m: int, k: int):
 
     i, j, k = sch.get_loops(block)
 
-    use_gpu = True
-    use_ldmatrix = K == 16 and use_gpu
-
-    if use_gpu:
-        i1, i2 = sch.split(i, factors=[None, 16])
-        sch.bind(i1, "blockIdx.x")
-        # sch.bind(i2, "threadIdx.x")
-
+    i1, i2 = sch.split(i, factors=[None, 16])
+    sch.bind(i1, "blockIdx.x")
 
     def fetch_to_shared(block, idx):
         block_read = sch.cache_read(block, idx, "shared")
-        if use_gpu:
-            sch.compute_at(block_read, i1, True)
-            warp_size = 32
-            loops = sch.get_loops(block_read)
-            fused = sch.fuse(*loops[-2:])
-            f_0, f_1 = sch.split(fused, factors=[None, warp_size])
-            sch.bind(f_1, "threadIdx.x")
+        sch.compute_at(block_read, i1, True)
+        warp_size = 32
+        loops = sch.get_loops(block_read)
+        fused = sch.fuse(*loops[-2:])
+        f_0, f_1 = sch.split(fused, factors=[None, warp_size])
+        sch.bind(f_1, "threadIdx.x")
 
         return block_read
Lines changed: 332 additions & 0 deletions
@@ -0,0 +1,332 @@
import numpy as np

import tvm
import tvm.testing
import tvm.meta_schedule.testing.te_workload as te_workload
from tvm import te
from tvm.te import create_prim_func
from tvm.tir import Schedule
from tvm.script import tir as T
from tvm import tir
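

# Each tensor intrinsic below is declared as a (desc, impl) pair: the *_desc
# PrimFunc describes the computation that sch.tensorize pattern-matches
# against, and the *_impl PrimFunc is the warp-level body it is replaced with.
#
# ldmatrix_a: copy a 16x16 fp16 A tile from shared memory into a (32, 8)
# warp-scoped buffer (8 elements per thread) using ptx_ldmatrix.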
@T.prim_func
def ldmatrix_a_desc(a: T.handle, c: T.handle) -> None:
    A_shared = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared")
    A_warp = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, scope="warp")

    with T.block("root"):
        T.reads(A_shared[0:16, 0:16])
        T.writes(A_warp[0:16, 0:16])

        for ax0, ax1 in T.grid(16, 16):
            with T.block("A_shared_warp"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(A_shared[v0, v1])
                T.writes(A_warp[v0, v1])
                A_warp[v0, v1] = A_shared[v0, v1]


@T.prim_func
def ldmatrix_a_impl(a: T.handle, c: T.handle) -> None:
    s1 = T.var("int32")
    s0 = T.var("int32")
    A_shared = T.match_buffer(
        a,
        (16, 16),
        "float16",
        align=128,
        offset_factor=16,
        scope="shared",
        strides=[s1, s0],
    )
    A_warp = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
    with T.block("root"):
        T.reads(A_shared[0:16, 0:16])
        T.writes(A_warp[0:32, 0:8])
        tx = T.env_thread("threadIdx.x")
        T.launch_thread(tx, 32)

        T.evaluate(
            T.ptx_ldmatrix(
                0,
                4,
                ".b16",
                A_warp.data,
                8 * tx,
                A_shared.data,
                16 * (tx % 16) + 8 * (tx // 16),
                dtype="float16",
            )
        )
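

# ldmatrix_b: the same shared -> warp copy for the B tile, but the impl passes
# 1 as the first argument to ptx_ldmatrix (transposed load), since mma_sync
# consumes B in column-major ("col") layout.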
@T.prim_func
def ldmatrix_b_desc(a: T.handle, c: T.handle) -> None:
    B_shared = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared")
    B_warp = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, scope="warp")

    with T.block("root"):
        T.reads(B_shared[0:16, 0:16])
        T.writes(B_warp[0:16, 0:16])

        for ax0, ax1 in T.grid(16, 16):
            with T.block("B_shared_warp"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(B_shared[v0, v1])
                T.writes(B_warp[v0, v1])
                B_warp[v0, v1] = B_shared[v0, v1]


@T.prim_func
def ldmatrix_b_impl(a: T.handle, c: T.handle) -> None:
    s1 = T.var("int32")
    s0 = T.var("int32")
    B_shared = T.match_buffer(
        a,
        (16, 16),
        "float16",
        align=128,
        offset_factor=16,
        scope="shared",
        strides=[s1, s0],
    )
    B_warp = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
    with T.block("root"):
        T.reads(B_shared[0:16, 0:16])
        T.writes(B_warp[0:32, 0:8])
        tx = T.env_thread("threadIdx.x")
        T.launch_thread(tx, 32)

        T.evaluate(
            T.ptx_ldmatrix(
                1,
                4,
                ".b16",
                B_warp.data,
                8 * tx,
                B_shared.data,
                16 * (tx % 16) + 8 * (tx // 16),
                dtype="float16",
            )
        )
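

# mma_sync: the desc is a plain 16x16x16 fp16 x fp16 -> fp32 matmul
# accumulation; the impl issues two m16n8k16 mma.sync operations (via
# T.ptx_mma), offsetting into B and C by 4 elements per thread to cover both
# 16x8 halves of the 16x16 output tile.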
@T.prim_func
def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="warp")
    B = T.match_buffer(b, (16, 16), "float16", align=128, offset_factor=16, scope="warp")
    C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=16, scope="warp")

    with T.block("root"):
        T.reads(C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16])
        T.writes(C[0:16, 0:16])
        for i, j, k in T.grid(16, 16, 16):
            with T.block("C"):
                i, j, k = T.axis.remap("SSR", [i, j, k])
                T.reads(C[i, j], A[i, k], B[k, j])
                T.writes(C[i, j])
                C[i, j] = C[i, j] + T.cast(A[i, k], "float32") * T.cast(B[k, j], "float32")


@T.prim_func
def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
    B = T.match_buffer(b, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
    C = T.match_buffer(c, (32, 8), "float32", align=128, offset_factor=16, scope="warp")

    with T.block("root"):
        T.reads(C[0:32, 0:8], A[0:32, 0:8], B[0:32, 0:8])
        T.writes(C[0:32, 0:8])
        tx = T.env_thread("threadIdx.x")
        T.launch_thread(tx, 32)

        T.evaluate(
            T.ptx_mma(
                "m16n8k16",
                "row",
                "col",
                "fp16",
                "fp16",
                "fp32",
                A.data,
                A.elem_offset + tx * 8,
                B.data,
                B.elem_offset + tx * 8,
                C.data,
                C.elem_offset + tx * 8,
                False,
                dtype="float32",
            )
        )

        T.evaluate(
            T.ptx_mma(
                "m16n8k16",
                "row",
                "col",
                "fp16",
                "fp16",
                "fp32",
                A.data,
                A.elem_offset + tx * 8,
                B.data,
                B.elem_offset + tx * 8 + 4,
                C.data,
                C.elem_offset + tx * 8 + 4,
                False,
                dtype="float32",
            )
        )
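

# mma_store: copy the (32, 8) warp-scoped accumulator fragment back out as a
# 16x16 fp32 tile in global memory; the impl uses the mma_store builtin.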
@T.prim_func
def mma_store_desc(a: T.handle, c: T.handle) -> None:
    C_warp = T.match_buffer(a, [16, 16], dtype="float32", scope="warp")
    C = T.match_buffer(c, [16, 16], dtype="float32", scope="global")

    with T.block("root"):
        T.reads(C_warp[0:16, 0:16])
        T.writes(C[0:16, 0:16])
        for i0, i1 in T.grid(16, 16):
            with T.block("C_warp"):
                v0, v1 = T.axis.remap("SS", [i0, i1])
                T.reads(C_warp[v0, v1])
                T.writes(C[v0, v1])
                C[v0, v1] = C_warp[v0, v1]


@T.prim_func
def mma_store_impl(a: T.handle, c: T.handle) -> None:
    s1 = T.var("int32")
    s0 = T.var("int32")

    C_warp = T.match_buffer(a, [32, 8], dtype="float32", scope="warp", offset_factor=1)
    C = T.match_buffer(
        c, [16, 16], dtype="float32", scope="global", offset_factor=1, strides=[s1, s0]
    )

    with T.block("root"):
        T.reads(C_warp[0:32, 0:8])
        T.writes(C[0:16, 0:16])
        tx = T.env_thread("threadIdx.x")
        T.launch_thread(tx, 32)

        T.evaluate(
            T.mma_store(
                16, 16, C.access_ptr("w"), C_warp.data, C_warp.elem_offset, s1, dtype="float32"
            )
        )
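

# mma_fill: zero-initialize the accumulator; the impl clears the 8 fp32
# values each thread owns with the mma_fill builtin.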
@T.prim_func
def mma_fill_desc(a: T.handle) -> None:
    C_warp = T.match_buffer(a, [16, 16], dtype="float32", scope="warp")

    with T.block("root"):
        T.reads()
        T.writes(C_warp[0:16, 0:16])
        for i0, i1 in T.grid(16, 16):
            with T.block("C_warp"):
                i, j = T.axis.remap("SS", [i0, i1])
                T.reads()
                T.writes(C_warp[i, j])
                C_warp[i, j] = T.float32(0)


@T.prim_func
def mma_fill_impl(a: T.handle) -> None:
    C_warp = T.match_buffer(a, [32, 8], dtype="float32", scope="warp", offset_factor=1)

    with T.block("root"):
        T.reads()
        T.writes(C_warp[0:32, 0:8])
        tx = T.env_thread("threadIdx.x")
        T.launch_thread(tx, 32)

        T.evaluate(T.mma_fill(8, C_warp.data, C_warp.elem_offset, dtype="float32"))
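

# Register each (desc, impl) pair so the schedule below can refer to the
# intrinsics by name in sch.tensorize.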
tir.TensorIntrin.register("mma.ldmatrix_a", ldmatrix_a_desc, ldmatrix_a_impl)
251+
tir.TensorIntrin.register("mma.ldmatrix_b", ldmatrix_b_desc, ldmatrix_b_impl)
252+
tir.TensorIntrin.register("mma.mma_sync", mma_sync_desc, mma_sync_impl)
253+
tir.TensorIntrin.register("mma_store", mma_store_desc, mma_store_impl)
254+
tir.TensorIntrin.register("mma_fill", mma_fill_desc, mma_fill_impl)
255+
256+
257+
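

# A hand-written fp16 dense workload (second operand laid out as (m, k));
# defined here but unused below, where te_workload.matmul_fp16 is used instead.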
def dense(n: int, m: int, k: int):
    a = te.placeholder((n, k), name="A", dtype="float16")
    b = te.placeholder((m, k), name="B", dtype="float16")
    k = te.reduce_axis((0, k), name="k")
    c = te.compute(
        (n, m),
        lambda i, j: te.sum(
            tvm.tir.Cast("float32", a[i, k]) * tvm.tir.Cast("float32", b[j, k]),
            axis=[k],
        ),
        name="C",
    )
    return (a, b, c)
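

# Schedule the 16x16x16 fp16 matmul: bind the outer row tile to blockIdx.x,
# stage A and B through shared memory with threadIdx.x bound to the fused copy
# loops, cache the operands and the accumulator in warp scope, split out the
# init block with decompose_reduction, and tensorize every stage with the
# intrinsics registered above.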
M = N = K = 16
# matmul = create_prim_func(dense(n=16, m=K, k=K))
matmul = create_prim_func(te_workload.matmul_fp16(n=N, m=M, k=K))

sch = Schedule(matmul)
block = sch.get_block("C")

i, j, k = sch.get_loops(block)

i1, i2 = sch.split(i, factors=[None, 16])
sch.bind(i1, "blockIdx.x")


def fetch_to_shared(block, idx):
    block_read = sch.cache_read(block, idx, "shared")
    sch.compute_at(block_read, i1, True)
    warp_size = 32
    loops = sch.get_loops(block_read)
    fused = sch.fuse(*loops[-2:])
    f_0, f_1 = sch.split(fused, factors=[None, warp_size])
    sch.bind(f_1, "threadIdx.x")

    return block_read


A_shared = fetch_to_shared(block, 0)
B_shared = fetch_to_shared(block, 1)

block = sch.get_block("C")

A_warp = sch.cache_read(block, 0, "warp")
B_warp = sch.cache_read(block, 1, "warp")
C_warp = sch.cache_write(block, 0, "warp")
sch.reverse_compute_at(C_warp, sch.get_loops(block)[0])
block_init_c = sch.decompose_reduction(block, sch.get_loops(block)[1])

sch.tensorize(sch.get_loops(A_warp)[1], "mma.ldmatrix_a")
sch.tensorize(sch.get_loops(B_warp)[1], "mma.ldmatrix_b")
sch.tensorize(sch.get_loops(block)[1], "mma.mma_sync")
sch.tensorize(sch.get_loops(C_warp)[1], "mma_store")
sch.tensorize(sch.get_loops(block_init_c)[1], "mma_fill")

print(sch.mod.script())

# lowered = tvm.lower(sch.mod["main"])

target = "cuda"

f = tvm.build(sch.mod["main"], target=target, name="dense")
# dev = tvm.device(target, 0)

# a_np = np.random.uniform(size=(16, K)).astype("float16")
# b_np = np.random.uniform(size=(K, K)).astype("float16")
# c_np = np.dot(a_np.astype("float32"), b_np.astype("float32"))

# a = tvm.nd.array(a_np, dev)
# b = tvm.nd.array(b_np, dev)
# c = tvm.nd.array(np.zeros((16, K), dtype="float32"), dev)

# # print(f.imported_modules[0].get_source())
# f(a, b, c)
# tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
