Commit 403050b

add 16x8x16 test
1 parent 18e8d73 commit 403050b

File tree: 1 file changed, +390 −0 lines changed

@@ -0,0 +1,390 @@
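# A hand-written tensorization test: schedule a single-tile fp16 x fp16 -> fp32
# matmul (M = N = K = 16) onto the m16n8k16 ("16x8x16") PTX MMA instructions,
# using ldmatrix for the operand loads and warp-scope fragment buffers for the
# accumulator.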
import numpy as np

import tvm
import tvm.testing
import tvm.meta_schedule.testing.te_workload as te_workload
from tvm import te
from tvm.te import create_prim_func
from tvm.tir import Schedule
from tvm.script import tir as T
from tvm import tir

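# The "desc" prim_func below spells out, element by element, how a 16x16 fp16
# tile in shared memory maps onto a 32x8 warp-scope buffer (8 values per lane).
# The matching "impl" replaces that loop nest with a single T.ptx_ldmatrix call
# (four 8x8 .b16 matrices per warp); ldmatrix_b_impl further down uses the
# transposed variant (first argument 1 instead of 0).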
@T.prim_func
def ldmatrix_a_desc(a: T.handle, c: T.handle) -> None:
    A_shared = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared")
    A_warp = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16, scope="warp")

    with T.block("root"):
        T.reads(A_shared[0:16, 0:16])
        T.writes(A_warp[0:32, 0:8])

        for ax0, ax1 in T.grid(16, 16):
            with T.block("A_shared_warp"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(A_shared[v0, v1])
                T.writes(A_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
                A_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2] = A_shared[
                    v0, v1
                ]


@T.prim_func
def ldmatrix_a_impl(a: T.handle, c: T.handle) -> None:
    s1 = T.var("int32")
    s0 = T.var("int32")
    A_shared = T.match_buffer(
        a,
        (16, 16),
        "float16",
        align=128,
        offset_factor=16,
        scope="shared",
        strides=[s1, s0],
    )
    A_warp = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
    with T.block("root"):
        T.reads(A_shared[0:16, 0:16])
        T.writes(A_warp[0:32, 0:8])
        tx = T.env_thread("threadIdx.x")
        T.launch_thread(tx, 32)

        T.evaluate(
            T.ptx_ldmatrix(
                0,
                4,
                ".b16",
                A_warp.data,
                8 * tx,
                A_shared.data,
                16 * (tx % 16) + 8 * (tx // 16),
                dtype="float16",
            )
        )


@T.prim_func
def ldmatrix_b_desc(a: T.handle, c: T.handle) -> None:
    B_shared = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared")
    B_warp = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16, scope="warp")

    with T.block("root"):
        T.reads(B_shared[0:16, 0:16])
        T.writes(B_warp[0:32, 0:8])

        for ax0, ax1 in T.grid(16, 16):
            with T.block("B_shared_warp"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(B_shared[v0, v1])
                T.writes(B_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
                B_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2] = B_shared[
                    v0, v1
                ]


@T.prim_func
def ldmatrix_b_impl(a: T.handle, c: T.handle) -> None:
    s1 = T.var("int32")
    s0 = T.var("int32")
    B_shared = T.match_buffer(
        a,
        (16, 16),
        "float16",
        align=128,
        offset_factor=16,
        scope="shared",
        strides=[s1, s0],
    )
    B_warp = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
    with T.block("root"):
        T.reads(B_shared[0:16, 0:16])
        T.writes(B_warp[0:32, 0:8])
        tx = T.env_thread("threadIdx.x")
        T.launch_thread(tx, 32)

        T.evaluate(
            T.ptx_ldmatrix(
                1,
                4,
                ".b16",
                B_warp.data,
                8 * tx,
                B_shared.data,
                16 * (tx % 16) + 8 * (tx // 16),
                dtype="float16",
            )
        )

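# mma_sync_desc describes one 16x16x16 fp16 -> fp32 multiply-accumulate on the
# warp-level fragments, with all three buffers addressed through the same 32x8
# fragment layout as above. mma_sync_impl lowers it to two m16n8k16 ptx_mma
# calls over the fragment registers, the second offset by 4 elements per lane.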
@T.prim_func
def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
    B = T.match_buffer(b, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
    C = T.match_buffer(c, (32, 8), "float32", align=128, offset_factor=16, scope="warp")

    with T.block("root"):
        T.reads(C[0:32, 0:8], A[0:32, 0:8], B[0:32, 0:8])
        T.writes(C[0:32, 0:8])
        for i, j, k in T.grid(16, 16, 16):
            with T.block("C"):
                i, j, k = T.axis.remap("SSR", [i, j, k])
                T.reads(
                    C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2],
                    A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2],
                    B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 + j % 2],
                )
                T.writes(C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2])
                C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2] = C[
                    i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2
                ] + T.cast(
                    A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2], "float32"
                ) * T.cast(
                    B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 + j % 2], "float32"
                )


@T.prim_func
def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
    B = T.match_buffer(b, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
    C = T.match_buffer(c, (32, 8), "float32", align=128, offset_factor=16, scope="warp")

    with T.block("root"):
        T.reads(C[0:32, 0:8], A[0:32, 0:8], B[0:32, 0:8])
        T.writes(C[0:32, 0:8])
        tx = T.env_thread("threadIdx.x")
        T.launch_thread(tx, 32)

        T.evaluate(
            T.ptx_mma(
                "m16n8k16",
                "row",
                "col",
                "fp16",
                "fp16",
                "fp32",
                A.data,
                A.elem_offset + tx * 8,
                B.data,
                B.elem_offset + tx * 8,
                C.data,
                C.elem_offset + tx * 8,
                False,
                dtype="float32",
            )
        )

        T.evaluate(
            T.ptx_mma(
                "m16n8k16",
                "row",
                "col",
                "fp16",
                "fp16",
                "fp32",
                A.data,
                A.elem_offset + tx * 8 + 4,
                B.data,
                B.elem_offset + tx * 8 + 4,
                C.data,
                C.elem_offset + tx * 8 + 4,
                False,
                dtype="float32",
            )
        )

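# mma_store_desc describes copying the 32x8 accumulator fragment back into a
# 16x16 global buffer; mma_store_impl replaces the loop nest with TVM's
# mma_store builtin.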
@T.prim_func
def mma_store_desc(a: T.handle, c: T.handle) -> None:
    C_warp = T.match_buffer(a, [32, 8], dtype="float32", scope="warp")
    C = T.match_buffer(c, [16, 16], dtype="float32", scope="global")

    with T.block("root"):
        T.reads(C_warp[0:32, 0:8])
        T.writes(C[0:16, 0:16])
        for ax1_0, i0, i1 in T.grid(2, 32, 4):
            with T.block("C_warp"):
                v0 = T.axis.spatial(16, i1 // 2 * 8 + i0 // 4)
                v1 = T.axis.spatial(16, ax1_0 * 8 + i0 % 4 * 2 + i1 % 2)

                T.reads(C_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
                T.writes(C[v0, v1])
                C[v0, v1] = C_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2]


@T.prim_func
def mma_store_impl(a: T.handle, c: T.handle) -> None:
    s1 = T.var("int32")
    s0 = T.var("int32")

    C_warp = T.match_buffer(a, [32, 8], dtype="float32", scope="warp", offset_factor=1)
    C = T.match_buffer(
        c, [16, 16], dtype="float32", scope="global", offset_factor=1, strides=[s1, s0]
    )

    with T.block("root"):
        T.reads(C_warp[0:32, 0:8])
        T.writes(C[0:16, 0:16])
        tx = T.env_thread("threadIdx.x")
        T.launch_thread(tx, 32)

        T.evaluate(
            T.mma_store(
                16, 16, C.access_ptr("w"), C_warp.data, C_warp.elem_offset, s1, dtype="float32"
            )
        )

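# mma_fill_desc/impl zero-initialize the accumulator fragment; the impl uses
# TVM's mma_fill builtin with 8 fp32 elements per lane.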
@T.prim_func
def mma_fill_desc(a: T.handle) -> None:
    C_warp = T.match_buffer(a, [32, 8], dtype="float32", scope="warp")

    with T.block("root"):
        T.reads()
        T.writes(C_warp[0:32, 0:8])
        for i0, i1 in T.grid(32, 8):
            with T.block("C_warp"):
                i = T.axis.spatial(16, i1 // 4 * 8 + i0 // 4)
                j = T.axis.spatial(16, (i0 % 4) * 4 + i1 % 4)
                T.reads()
                T.writes(C_warp[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2])
                C_warp[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2] = T.float32(0)


@T.prim_func
def mma_fill_impl(a: T.handle) -> None:
    C_warp = T.match_buffer(a, [32, 8], dtype="float32", scope="warp", offset_factor=1)

    with T.block("root"):
        T.reads()
        T.writes(C_warp[0:32, 0:8])
        tx = T.env_thread("threadIdx.x")
        T.launch_thread(tx, 32)

        T.evaluate(T.mma_fill(8, C_warp.data, C_warp.elem_offset, dtype="float32"))

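# Register the description/implementation pairs as tensor intrinsics; the names
# used here must match the strings passed to sch.tensorize further down.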
tir.TensorIntrin.register("mma.ldmatrix_a", ldmatrix_a_desc, ldmatrix_a_impl)
tir.TensorIntrin.register("mma.ldmatrix_b", ldmatrix_b_desc, ldmatrix_b_impl)
tir.TensorIntrin.register("mma.mma_sync", mma_sync_desc, mma_sync_impl)
tir.TensorIntrin.register("mma_store", mma_store_desc, mma_store_impl)
tir.TensorIntrin.register("mma_fill", mma_fill_desc, mma_fill_impl)

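# Hand-written TE workload kept for reference (currently unused, see the
# commented-out create_prim_func call below): C[i, j] = sum_k cast(A[i, k]) *
# cast(B[j, k]), i.e. both operands keep the reduction axis innermost.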
def dense(n: int, m: int, k: int):
    a = te.placeholder((n, k), name="A", dtype="float16")
    b = te.placeholder((m, k), name="B", dtype="float16")
    k = te.reduce_axis((0, k), name="k")
    c = te.compute(
        (n, m),
        lambda i, j: te.sum(
            tvm.tir.Cast("float32", a[i, k]) * tvm.tir.Cast("float32", b[j, k]),
            axis=[k],
        ),
        name="C",
    )
    return (a, b, c)


M = N = K = 16
# matmul = create_prim_func(dense(n=16, m=K, k=K))
matmul = create_prim_func(te_workload.matmul_fp16(n=N, m=M, k=K))

sch = Schedule(matmul)
block = sch.get_block("C")

i, j, k = sch.get_loops(block)

use_gpu = True
use_ldmatrix = K == 16 and use_gpu

if use_gpu:
    i1, i2 = sch.split(i, factors=[None, 16])
    sch.bind(i1, "blockIdx.x")
    # sch.bind(i2, "threadIdx.x")

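# Cache operand `idx` of `block` in shared memory; on GPU, anchor the copy at
# the blockIdx.x loop and bind the inner 32 elements of the fused copy loop to
# threadIdx.x so the whole warp cooperates on the load.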
def fetch_to_shared(block, idx):
    block_read = sch.cache_read(block, idx, "shared")
    if use_gpu:
        sch.compute_at(block_read, i1, True)
        warp_size = 32
        loops = sch.get_loops(block_read)
        fused = sch.fuse(*loops[-2:])
        f_0, f_1 = sch.split(fused, factors=[None, warp_size])
        sch.bind(f_1, "threadIdx.x")

    return block_read


A_shared = fetch_to_shared(block, 0)
B_shared = fetch_to_shared(block, 1)


def shared_16x16_to_ldmatrix_32x8_layout(i, j):
    thread_id = 4 * (i % 8) + (j % 8) // 2
    return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 8) % 2
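

# Illustrative check of the mapping above (evaluated here with plain Python
# ints): it is a bijection from the 16x16 tile onto 32 lanes x 8 local
# registers, which is the property the transform_layout calls and the intrinsic
# descriptions above rely on.
assert (
    len({shared_16x16_to_ldmatrix_32x8_layout(i, j) for i in range(16) for j in range(16)})
    == 32 * 8
)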


block = sch.get_block("C")

A_warp = sch.cache_read(block, 0, "warp")

sch.transform_layout(A_warp, 0, "write", index_map=shared_16x16_to_ldmatrix_32x8_layout)

B_warp = sch.cache_read(block, 1, "warp")

sch.transform_layout(B_warp, 0, "write", index_map=shared_16x16_to_ldmatrix_32x8_layout)

sch.tensorize(sch.get_loops(A_warp)[1], "mma.ldmatrix_a")
sch.tensorize(sch.get_loops(B_warp)[1], "mma.ldmatrix_b")

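# Stage the accumulator through a warp-scope buffer as well, then split and
# reorder its loops so they line up with the 2x32x4 grid of mma_store_desc (and,
# after decompose_reduction, the 32x8 grid of mma_fill_desc) before tensorizing.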
C_warp = sch.cache_write(block, 0, "warp")
sch.reverse_compute_at(C_warp, sch.get_loops(block)[0])
sch.transform_layout(C_warp, 0, "read", index_map=shared_16x16_to_ldmatrix_32x8_layout)


warp_loop1, warp_loop2 = sch.get_loops(C_warp)[-2:]
f_0, f_1 = sch.split(warp_loop1, factors=[None, 8])
outer, f_2, f_3 = sch.split(warp_loop2, factors=[2, 4, 2])
sch.reorder(outer, f_1, f_2, f_0, f_3)
fused_1 = sch.fuse(f_1, f_2)
fused_2 = sch.fuse(f_0, f_3)

sch.tensorize(outer, "mma_store")

block_init_c = sch.decompose_reduction(block, sch.get_loops(block)[1])

init_loop1, init_loop2 = sch.get_loops(block_init_c)[-2:]
f_0, f_1 = sch.split(init_loop1, factors=[None, 8])
f_2, f_3 = sch.split(init_loop2, factors=[None, 4])
sch.reorder(f_1, f_2, f_0, f_3)
fused_1 = sch.fuse(f_1, f_2)
fused_2 = sch.fuse(f_0, f_3)
sch.tensorize(fused_1, "mma_fill")

sch.tensorize(sch.get_loops(block)[1], "mma.mma_sync")

print(sch.mod.script())

# lowered = tvm.lower(sch.mod["main"])

# if use_gpu:
#     target = "vulkan -from_device=0"
# else:
#     target = "llvm"

# f = tvm.build(sch.mod["main"], target=target, name="dense")
# dev = tvm.device(target, 0)

# a_np = np.random.uniform(size=(16, K)).astype("float16")
# b_np = np.random.uniform(size=(K, K)).astype("float16")
# c_np = np.dot(a_np.astype("float32"), b_np.astype("float32"))

# a = tvm.nd.array(a_np, dev)
# b = tvm.nd.array(b_np, dev)
# c = tvm.nd.array(np.zeros((16, K), dtype="float32"), dev)

# # print(f.imported_modules[0].get_source())
# f(a, b, c)
# tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
