
Commit 441fd19

adding fp16 accum case
1 parent c9d40b6 commit 441fd19

File tree

3 files changed: +869 -1 lines changed

src/target/source/codegen_cuda.cc

Lines changed: 1 addition & 1 deletion
@@ -822,7 +822,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
     this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr, local_elem_offset,
                                             smem_ptr, smem_elem_offset);
   } else if (op->op.same_as(builtin::mma_store())) {
-    int m = Downcast<Integer>(op->args[1])->value;
+    int m = Downcast<Integer>(op->args[0])->value;
     int n = Downcast<Integer>(op->args[1])->value;
     std::string dst = this->PrintExpr(op->args[2]);
     std::string src = this->PrintExpr(op->args[3]);
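
The fix reads m from op->args[0] and n from op->args[1]; previously both came from args[1], so any non-square mma_store would have used the wrong m. For reference, a minimal sketch of the TIR-side call this branch lowers, mirroring mma_store_impl in the new test file below (argument order assumed from that code: m, n, destination pointer, source warp buffer, source element offset, destination row stride s1):

T.evaluate(
    T.mma_store(
        16, 16, C.access_ptr("w"), C_warp.data, C_warp.elem_offset, s1, dtype="float16"
    )
)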
Lines changed: 379 additions & 0 deletions
@@ -0,0 +1,379 @@
import numpy as np

import tvm
import tvm.testing
import tvm.meta_schedule.testing.te_workload as te_workload
from tvm import te
from tvm.te import create_prim_func
from tvm.tir import Schedule
from tvm.script import tir as T
from tvm import tir


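# Each tensor intrinsic below is a (desc, impl) pair: the desc describes the computation on a
# 16x16 fp16 tile in plain TIR, and the impl replaces it with a PTX intrinsic. A 16x16 tile is
# held across a warp as a 32x8 "warp" buffer: tile element (i, j) lives at
# [i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2], the fragment layout produced by
# ldmatrix and consumed by mma.sync.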
@T.prim_func
def ldmatrix_a_desc(a: T.handle, c: T.handle) -> None:
    A_shared = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared")
    A_warp = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16, scope="warp")

    with T.block("root"):
        T.reads(A_shared[0:16, 0:16])
        T.writes(A_warp[0:32, 0:8])

        for ax0, ax1 in T.grid(16, 16):
            with T.block("A_shared_warp"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(A_shared[v0, v1])
                T.writes(A_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
                A_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2] = A_shared[
                    v0, v1
                ]


@T.prim_func
def ldmatrix_a_impl(a: T.handle, c: T.handle) -> None:
    s1 = T.var("int32")
    s0 = T.var("int32")
    A_shared = T.match_buffer(
        a,
        (16, 16),
        "float16",
        align=128,
        offset_factor=16,
        scope="shared",
        strides=[s1, s0],
    )
    A_warp = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
    with T.block("root"):
        T.reads(A_shared[0:16, 0:16])
        T.writes(A_warp[0:32, 0:8])
        tx = T.env_thread("threadIdx.x")
        T.launch_thread(tx, 32)

        T.evaluate(
            T.ptx_ldmatrix(
                0,
                4,
                ".b16",
                A_warp.data,
                8 * tx,
                A_shared.data,
                16 * (tx % 16) + 8 * (tx // 16),
                dtype="float16",
            )
        )


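# The B operand uses the same 32x8 warp layout; the only change in the impl is the first
# argument of T.ptx_ldmatrix (1 instead of 0), which is assumed here to select the transposed
# ldmatrix variant for the column-major ("col") B operand of the mma below.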
@T.prim_func
def ldmatrix_b_desc(a: T.handle, c: T.handle) -> None:
    B_shared = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared")
    B_warp = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16, scope="warp")

    with T.block("root"):
        T.reads(B_shared[0:16, 0:16])
        T.writes(B_warp[0:32, 0:8])

        for ax0, ax1 in T.grid(16, 16):
            with T.block("B_shared_warp"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(B_shared[v0, v1])
                T.writes(B_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
                B_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2] = B_shared[
                    v0, v1
                ]


@T.prim_func
def ldmatrix_b_impl(a: T.handle, c: T.handle) -> None:
    s1 = T.var("int32")
    s0 = T.var("int32")
    B_shared = T.match_buffer(
        a,
        (16, 16),
        "float16",
        align=128,
        offset_factor=16,
        scope="shared",
        strides=[s1, s0],
    )
    B_warp = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
    with T.block("root"):
        T.reads(B_shared[0:16, 0:16])
        T.writes(B_warp[0:32, 0:8])
        tx = T.env_thread("threadIdx.x")
        T.launch_thread(tx, 32)

        T.evaluate(
            T.ptx_ldmatrix(
                1,
                4,
                ".b16",
                B_warp.data,
                8 * tx,
                B_shared.data,
                16 * (tx % 16) + 8 * (tx // 16),
                dtype="float16",
            )
        )


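# The mma_sync desc is a full 16x16x16 fp16 GEMM update (C += A * B) written directly on the
# warp-laid-out 32x8 buffers. The impl issues two m16n8k16 mma.sync calls with fp16 accumulation
# (the "fp16 accum case" of this commit): each lane holds 8 fp16 accumulators, and the "+ 4"
# offsets on B and C select the second 16x8 half of the tile.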
@T.prim_func
def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
    B = T.match_buffer(b, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
    C = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16, scope="warp")

    with T.block("root"):
        T.reads(C[0:32, 0:8], A[0:32, 0:8], B[0:32, 0:8])
        T.writes(C[0:32, 0:8])
        for i, j, k in T.grid(16, 16, 16):
            with T.block("C"):
                i, j, k = T.axis.remap("SSR", [i, j, k])
                T.reads(
                    C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2],
                    A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2],
                    B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 + j % 2],
                )
                T.writes(C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2])
                C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2] = C[
                    i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2
                ] + A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2] * B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 + j % 2]


@T.prim_func
def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
    B = T.match_buffer(b, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
    C = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16, scope="warp")

    with T.block("root"):
        T.reads(C[0:32, 0:8], A[0:32, 0:8], B[0:32, 0:8])
        T.writes(C[0:32, 0:8])
        tx = T.env_thread("threadIdx.x")
        T.launch_thread(tx, 32)

        T.evaluate(
            T.ptx_mma(
                "m16n8k16",
                "row",
                "col",
                "fp16",
                "fp16",
                "fp16",
                A.data,
                A.elem_offset + tx * 8,
                B.data,
                B.elem_offset + tx * 8,
                C.data,
                C.elem_offset + tx * 8,
                False,
                dtype="float16",
            )
        )

        T.evaluate(
            T.ptx_mma(
                "m16n8k16",
                "row",
                "col",
                "fp16",
                "fp16",
                "fp16",
                A.data,
                A.elem_offset + tx * 8,
                B.data,
                B.elem_offset + tx * 8 + 4,
                C.data,
                C.elem_offset + tx * 8 + 4,
                False,
                dtype="float16",
            )
        )


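# mma_store copies the warp-resident 32x8 accumulator back to a 16x16 global buffer. The impl
# lowers to the T.mma_store builtin whose leading (m, n) arguments are the ones the
# codegen_cuda.cc change above now reads from args[0] and args[1].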
@T.prim_func
def mma_store_desc(a: T.handle, c: T.handle) -> None:
    C_warp = T.match_buffer(a, [32, 8], dtype="float16", scope="warp")
    C = T.match_buffer(c, [16, 16], dtype="float16", scope="global")

    with T.block("root"):
        T.reads(C_warp[0:32, 0:8])
        T.writes(C[0:16, 0:16])
        for ax1_0, i0, i1 in T.grid(2, 32, 4):
            with T.block("C_warp"):
                v0 = T.axis.spatial(16, i1 // 2 * 8 + i0 // 4)
                v1 = T.axis.spatial(16, ax1_0 * 8 + i0 % 4 * 2 + i1 % 2)

                T.reads(C_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
                T.writes(C[v0, v1])
                C[v0, v1] = C_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2]


@T.prim_func
def mma_store_impl(a: T.handle, c: T.handle) -> None:
    s1 = T.var("int32")
    s0 = T.var("int32")

    C_warp = T.match_buffer(a, [32, 8], dtype="float16", scope="warp", offset_factor=1)
    C = T.match_buffer(
        c, [16, 16], dtype="float16", scope="global", offset_factor=1, strides=[s1, s0]
    )

    with T.block("root"):
        T.reads(C_warp[0:32, 0:8])
        T.writes(C[0:16, 0:16])
        tx = T.env_thread("threadIdx.x")
        T.launch_thread(tx, 32)

        T.evaluate(
            T.mma_store(
                16, 16, C.access_ptr("w"), C_warp.data, C_warp.elem_offset, s1, dtype="float16"
            )
        )


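# mma_fill zero-initializes the warp-scoped accumulator; the impl lowers to T.mma_fill with
# 8 elements per lane (32 lanes x 8 = the full 16x16 tile of fp16 zeros).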
@T.prim_func
def mma_fill_desc(a: T.handle) -> None:
    C_warp = T.match_buffer(a, [32, 8], dtype="float16", scope="warp")

    with T.block("root"):
        T.reads()
        T.writes(C_warp[0:32, 0:8])
        for i0, i1 in T.grid(32, 8):
            with T.block("C_warp"):
                i = T.axis.spatial(16, i1 // 4 * 8 + i0 // 4)
                j = T.axis.spatial(16, (i0 % 4) * 4 + i1 % 4)
                T.reads()
                T.writes(C_warp[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2])
                C_warp[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2] = T.float16(0)


@T.prim_func
def mma_fill_impl(a: T.handle) -> None:
    C_warp = T.match_buffer(a, [32, 8], dtype="float16", scope="warp", offset_factor=1)

    with T.block("root"):
        T.reads()
        T.writes(C_warp[0:32, 0:8])
        tx = T.env_thread("threadIdx.x")
        T.launch_thread(tx, 32)

        T.evaluate(T.mma_fill(8, C_warp.data, C_warp.elem_offset, dtype="float16"))


tir.TensorIntrin.register("mma.ldmatrix_a", ldmatrix_a_desc, ldmatrix_a_impl)
tir.TensorIntrin.register("mma.ldmatrix_b", ldmatrix_b_desc, ldmatrix_b_impl)
tir.TensorIntrin.register("mma.mma_sync", mma_sync_desc, mma_sync_impl)
tir.TensorIntrin.register("mma_store", mma_store_desc, mma_store_impl)
tir.TensorIntrin.register("mma_fill", mma_fill_desc, mma_fill_impl)


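# A plain TE fp16 matmul; create_prim_func turns it into the TIR PrimFunc that the schedule
# below tensorizes with the intrinsics registered above.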
def matmul_fp16(n, m, k):
    a = te.placeholder((n, k), name="A", dtype="float16")
    b = te.placeholder((k, m), name="B", dtype="float16")
    k = te.reduce_axis((0, k), name="k")

    def f_compute(i, j):
        v_a = a[i, k]
        v_b = b[k, j]
        return te.sum(v_a * v_b, axis=[k])

    c = te.compute((n, m), f_compute, name="C")
    return (a, b, c)


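# The whole workload is a single 16x16x16 GEMM (M = N = K = 16); splitting i into [None, 16]
# leaves a unit-extent outer loop, which is bound to blockIdx.x.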
M = N = K = 16
matmul = create_prim_func(matmul_fp16(n=N, m=M, k=K))

sch = Schedule(matmul)
block = sch.get_block("C")

i, j, k = sch.get_loops(block)

use_gpu = True
use_ldmatrix = K == 16 and use_gpu

i1, i2 = sch.split(i, factors=[None, 16])
sch.bind(i1, "blockIdx.x")


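# Stage one operand of `block` into shared memory, then fuse and split its copy loops so the
# copy is distributed over the 32 threads of the warp (bound to threadIdx.x).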
def fetch_to_shared(block, idx):
    block_read = sch.cache_read(block, idx, "shared")
    if use_gpu:
        sch.compute_at(block_read, i1, True)
        warp_size = 32
        loops = sch.get_loops(block_read)
        fused = sch.fuse(*loops[-2:])
        f_0, f_1 = sch.split(fused, factors=[None, warp_size])
        sch.bind(f_1, "threadIdx.x")

    return block_read


A_shared = fetch_to_shared(block, 0)
B_shared = fetch_to_shared(block, 1)


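# Index map from a 16x16 tile coordinate (i, j) to the (thread_id, local_id) position in the
# 32x8 warp buffer; it matches the indexing used inside the intrinsic descs above.
# Worked example: (i, j) = (9, 5) -> thread_id = 4 * 1 + 5 // 2 = 6, local_id = 0 + 2 + 1 = 3.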
def shared_16x16_to_ldmatrix_32x8_layout(i, j):
    thread_id = 4 * (i % 8) + (j % 8) // 2
    return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 8) % 2


block = sch.get_block("C")

A_warp = sch.cache_read(block, 0, "warp")

sch.transform_layout(A_warp, 0, "write", index_map=shared_16x16_to_ldmatrix_32x8_layout)

B_warp = sch.cache_read(block, 1, "warp")

sch.transform_layout(B_warp, 0, "write", index_map=shared_16x16_to_ldmatrix_32x8_layout)

sch.tensorize(sch.get_loops(A_warp)[1], "mma.ldmatrix_a")
sch.tensorize(sch.get_loops(B_warp)[1], "mma.ldmatrix_b")


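# Stage the accumulator in a warp-scoped buffer with the same 32x8 layout, then reshape the
# copy-back loops (split/reorder/fuse) into the loop structure mma_store_desc expects so the
# `outer` loop can be tensorized with "mma_store".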
C_warp = sch.cache_write(block, 0, "warp")
sch.reverse_compute_at(C_warp, sch.get_loops(block)[0])
sch.transform_layout(C_warp, 0, "read", index_map=shared_16x16_to_ldmatrix_32x8_layout)

warp_loop1, warp_loop2 = sch.get_loops(C_warp)[-2:]
f_0, f_1 = sch.split(warp_loop1, factors=[None, 8])
outer, f_2, f_3 = sch.split(warp_loop2, factors=[2, 4, 2])
sch.reorder(outer, f_1, f_2, f_0, f_3)
fused_1 = sch.fuse(f_1, f_2)
fused_2 = sch.fuse(f_0, f_3)

sch.tensorize(outer, "mma_store")


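# Split the initialization of C out of the reduction, tensorize it with "mma_fill", and finally
# tensorize the main update loop nest with the fp16 "mma.mma_sync" intrinsic.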
block_init_c = sch.decompose_reduction(block, sch.get_loops(block)[1])

init_loop1, init_loop2 = sch.get_loops(block_init_c)[-2:]
f_0, f_1 = sch.split(init_loop1, factors=[None, 8])
f_2, f_3 = sch.split(init_loop2, factors=[None, 4])
sch.reorder(f_1, f_2, f_0, f_3)
fused_1 = sch.fuse(f_1, f_2)
fused_2 = sch.fuse(f_0, f_3)
sch.tensorize(fused_1, "mma_fill")

sch.tensorize(sch.get_loops(block)[1], "mma.mma_sync")


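# Print the scheduled module, build it for CUDA, and check the 16x16 fp16 GEMM result against
# numpy.dot at rtol=1e-3.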
print(sch.mod.script())

# lowered = tvm.lower(sch.mod["main"])

target = "cuda"

f = tvm.build(sch.mod["main"], target=target, name="dense")
dev = tvm.device(target, 0)

a_np = np.random.uniform(size=(16, K)).astype("float16")
b_np = np.random.uniform(size=(K, K)).astype("float16")
c_np = np.dot(a_np, b_np)

a = tvm.nd.array(a_np, dev)
b = tvm.nd.array(b_np, dev)
c = tvm.nd.array(np.zeros((16, K), dtype="float16"), dev)

# print(f.imported_modules[0].get_source())
f(a, b, c)
tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
