Commit 93f9fe7

add 4k test

1 parent 3689ef7 commit 93f9fe7

1 file changed: 350 additions, 0 deletions

@@ -0,0 +1,350 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te, tir
from tvm.script import tir as T
import tvm.testing
import numpy as np


@T.prim_func
def ldmatrix_a_desc(a: T.handle, c: T.handle) -> None:
    A_shared = T.match_buffer(
        a, (16, 8), "float16", align=128, offset_factor=16, scope="shared"
    )
    A_warp = T.match_buffer(
        c, (32, 4), "float16", align=128, offset_factor=16, scope="warp"
    )

    with T.block("root"):
        T.reads(A_shared[0:16, 0:8])
        T.writes(A_warp[0:32, 0:4])

        for ax0, ax1 in T.grid(16, 8):
            with T.block("A_shared_warp"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(A_shared[v0, v1])
                T.writes(A_warp[v0 % 8 * 4 + v1 // 2, v0 // 8 * 2 + v1 % 2])
                A_warp[v0 % 8 * 4 + v1 // 2, v0 // 8 * 2 + v1 % 2] = A_shared[v0, v1]


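# ldmatrix_a_desc above spells out, element by element, the warp-level layout that
# the ptx_ldmatrix call in ldmatrix_a_impl below produces: the 16x8 fp16 tile
# (128 elements) is spread across the 32 lanes of a warp, 4 half-words per lane.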
@T.prim_func
def ldmatrix_a_impl(a: T.handle, c: T.handle) -> None:
    s1 = T.var("int32")
    s0 = T.var("int32")
    A_shared = T.match_buffer(
        a,
        (16, 8),
        "float16",
        align=128,
        offset_factor=16,
        scope="shared",
        strides=[s1, s0],
    )
    A_warp = T.match_buffer(
        c, (32, 4), "float16", align=128, offset_factor=16, scope="warp"
    )
    with T.block("root"):
        T.reads(A_shared[0:16, 0:8])
        T.writes(A_warp[0:32, 0:4])
        tx = T.env_thread("threadIdx.x")
        T.launch_thread(tx, 32)

        T.evaluate(
            T.ptx_ldmatrix(
                0,
                2,
                ".b16",
                A_warp.data,
                4 * tx,
                A_shared.data,
                8 * (tx % 16),
                dtype="float16",
            )
        )


@T.prim_func
def ldmatrix_b_desc(a: T.handle, c: T.handle) -> None:
    B_shared = T.match_buffer(
        a, (8, 8), "float16", align=128, offset_factor=16, scope="shared"
    )
    B_shared_warp = T.match_buffer(
        c, (32, 2), "float16", align=128, offset_factor=16, scope="warp"
    )

    with T.block("root"):
        T.reads(B_shared[0:8, 0:8])
        T.writes(B_shared_warp[0:32, 0:2])

        for ax0, ax1 in T.grid(8, 8):
            with T.block("B_shared_warp"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(B_shared[v0, v1])
                T.writes(B_shared_warp[v1 * 4 + v0 // 2, v0 % 2])
                B_shared_warp[v1 * 4 + v0 // 2, v0 % 2] = B_shared[v0, v1]


@T.prim_func
def ldmatrix_b_impl(a: T.handle, c: T.handle) -> None:
    s1 = T.var("int32")
    s0 = T.var("int32")
    B_shared = T.match_buffer(
        a,
        (8, 8),
        "float16",
        align=128,
        offset_factor=16,
        scope="shared",
        strides=[s1, s0],
    )
    B_warp = T.match_buffer(
        c, (32, 2), "float16", align=128, offset_factor=16, scope="warp"
    )
    with T.block("root"):
        T.reads(B_shared[0:8, 0:8])
        T.writes(B_warp[0:32, 0:2])
        tx = T.env_thread("threadIdx.x")
        T.launch_thread(tx, 32)

        T.evaluate(
            T.ptx_ldmatrix(
                0,
                1,
                ".b16",
                B_warp.data,
                2 * tx,
                B_shared.data,
                8 * (tx % 8),
                dtype="float16",
            )
        )


@T.prim_func
def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, [32, 4], dtype="float16", scope="warp")
    B = T.match_buffer(b, [32, 2], dtype="float16", scope="warp")
    C = T.match_buffer(c, [32, 4], dtype="float32", scope="warp")
    with T.block("root"):
        T.reads(C[0:32, 0:4], A[0:32, 0:4], B[0:32, 0:2])
        T.writes(C[0:32, 0:4])
        for i0, i1, i2 in T.grid(16, 8, 8):
            with T.block("C"):
                i, j, k = T.axis.remap("SSR", [i0, i1, i2])

                T.reads(
                    C[i % 8 * 4 + j % 8 // 2, i % 16 // 8 * 2 + j % 2],
                    A[i % 8 * 4 + k % 8 // 2, i % 16 // 8 * 2 + k % 2],
                    B[k % 8 * 4 + j % 8 // 2, j % 2],
                )
                T.writes(C[i % 8 * 4 + j % 8 // 2, i % 16 // 8 * 2 + j % 2])
                C[i % 8 * 4 + j % 8 // 2, i % 16 // 8 * 2 + j % 2] = C[
                    i % 8 * 4 + j % 8 // 2, i % 16 // 8 * 2 + j % 2
                ] + T.cast(
                    A[i % 8 * 4 + k % 8 // 2, i % 16 // 8 * 2 + k % 2], "float32"
                ) * T.cast(
                    B[k % 8 * 4 + j % 8 // 2, j % 2], "float32"
                )


@T.prim_func
def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (32, 4), "float16", align=128, offset_factor=1, scope="warp")
    B = T.match_buffer(b, (32, 2), "float16", align=128, offset_factor=1, scope="warp")
    C = T.match_buffer(c, (32, 4), "float32", align=128, offset_factor=1, scope="warp")

    with T.block("root"):
        T.reads(C[0:32, 0:4], A[0:32, 0:4], B[0:32, 0:2])
        T.writes(C[0:32, 0:4])
        tx = T.env_thread("threadIdx.x")
        T.launch_thread(tx, 32)
        T.evaluate(
            T.ptx_mma(
                "m16n8k8",
                "row",
                "col",
                "fp16",
                "fp16",
                "fp32",
                A.data,
                A.elem_offset + tx * 4,
                B.data,
                B.elem_offset + tx * 2,
                C.data,
                C.elem_offset + tx * 4,
                False,
                dtype="float32",
            )
        )


tir.TensorIntrin.register("mma.ldmatrix_a", ldmatrix_a_desc, ldmatrix_a_impl)
tir.TensorIntrin.register("mma.ldmatrix_b", ldmatrix_b_desc, ldmatrix_b_impl)
tir.TensorIntrin.register("mma_sync", mma_sync_desc, mma_sync_impl)


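# A small NumPy sanity check of the index maps used above (a sketch, not part of
# the schedule below and safe to remove). It packs a 16x8 A tile and an 8x8 B tile
# with the same index expressions as ldmatrix_a_desc / ldmatrix_b_desc, accumulates
# with the mma_sync_desc expression, and compares against a plain matmul. The
# function name and all locals are used only here.
def _check_warp_fragment_layouts():
    A_tile = np.random.randint(-8, 8, size=(16, 8))
    B_tile = np.random.randint(-8, 8, size=(8, 8))  # [j, k] layout, like B in dense() below
    a_frag = np.zeros((32, 4), dtype=np.int64)
    b_frag = np.zeros((32, 2), dtype=np.int64)
    c_frag = np.zeros((32, 4), dtype=np.int64)
    for i in range(16):
        for k in range(8):
            a_frag[i % 8 * 4 + k // 2, i // 8 * 2 + k % 2] = A_tile[i, k]
    for j in range(8):
        for k in range(8):
            b_frag[k * 4 + j // 2, j % 2] = B_tile[j, k]
    for i in range(16):
        for j in range(8):
            for k in range(8):
                c_frag[i % 8 * 4 + j // 2, i // 8 * 2 + j % 2] += (
                    a_frag[i % 8 * 4 + k // 2, i // 8 * 2 + k % 2]
                    * b_frag[k * 4 + j // 2, j % 2]
                )
    C_out = np.zeros((16, 8), dtype=np.int64)
    for i in range(16):
        for j in range(8):
            C_out[i, j] = c_frag[i % 8 * 4 + j // 2, i // 8 * 2 + j % 2]
    assert np.array_equal(C_out, A_tile @ B_tile.T)


_check_warp_fragment_layouts()

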
def dense(n: int, m: int, k: int):
    a = te.placeholder((n, k), name="A", dtype="float16")
    b = te.placeholder((m, k), name="B", dtype="float16")
    k = te.reduce_axis((0, k), name="k")
    c = te.compute(
        (n, m),
        lambda i, j: te.sum(
            tvm.tir.Cast("float32", a[i, k]) * tvm.tir.Cast("float32", b[j, k]),
            axis=[k],
        ),
        name="C",
    )
    return (a, b, c)


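# The test below schedules a 4096x4096x4096 fp16 matmul with fp32 accumulation
# (C = A @ B^T, since both placeholders carry the reduction axis as their second
# dimension): tiles of A and B are staged through shared memory, loaded into warp
# buffers via the ldmatrix intrinsics registered above, and the inner 16x8x8 block
# is tensorized into a single m16n8k8 ptx_mma call.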
def test_integration_matmul():
    N = 4096
    M = 4096
    K = 4096

    workload = te.create_prim_func(dense(n=N, m=M, k=K))

    def schedule(sch: tir.Schedule):
        block = sch.get_block("C")
        i, j, k = sch.get_loops(block)

        i, i_tc = sch.split(i, factors=[None, 16])
        j, j_tc = sch.split(j, factors=[None, 8])
        k_outer, k_tc = sch.split(k, factors=[None, 8])

        sch.reorder(
            # fmt: off
            i, j, k_outer,
            # tensor core
            i_tc, j_tc, k_tc
            # fmt: on
        )

        block_outer = sch.blockize(i_tc)
        block, _ = block_outer, block

        sch.bind(sch.fuse(i, j), "blockIdx.x")

        def fetch_to_shared(block, idx, ndim):
            block_read = sch.cache_read(block, idx, "shared")
            sch.compute_at(block_read, k_outer)
            warp_size = 32
            fused = sch.fuse(*sch.get_loops(block_read)[-ndim:])
            f_0, f_1 = sch.split(fused, factors=[None, warp_size])
            sch.bind(f_1, "threadIdx.x")

        fetch_to_shared(block, 0, 2)
        fetch_to_shared(block, 1, 2)

        # fetch to A_warp 16 * 8 -> 32 * 4
        A_warp = sch.cache_read(block, 0, "warp")

        def lambda_a(i, j):
            i_0 = i // 16
            j_0 = j // 8

            i = i % 16
            j = j % 8
            return (
                i_0,
                j_0,
                (i % 8) * 4 + (j % 8) // 2,
                4 * (j // 8) + (i // 8) * 2 + (j % 8) % 2,
            )

        sch.transform_layout(A_warp, 0, "write", index_map=lambda_a)

        sch.tensorize(sch.get_loops(A_warp)[2], "mma.ldmatrix_a")

        def lambda_b(i, j):
            i_0 = i // 8
            j_0 = j // 8
            i = i % 8
            j = j % 8
            return i_0, j_0, i // 2 + j * 4, i % 2

        B_warp = sch.cache_read(block, 1, "warp")
        sch.transform_layout(
            B_warp,
            0,
            "write",
            index_map=lambda_b,
        )
        sch.tensorize(sch.get_loops(B_warp)[2], "mma.ldmatrix_b")

        # fetch to C_warp 16 * 8 -> 32 * 4
        C_warp = sch.cache_write(block, 0, "warp")
        # sch.reverse_compute_at(C_warp, sch.get_loops(block)[0])
        # need to do a reverse_compute_at to place it under blockIdx.x

        sch.transform_layout(
            C_warp,
            0,
            "read",
            index_map=lambda_a,
        )

        warp_loop1, warp_loop2 = sch.get_loops(C_warp)[-2:]
        f_0, f_1 = sch.split(warp_loop1, factors=[None, 8])
        f_2, f_3 = sch.split(warp_loop2, factors=[None, 2])
        sch.reorder(f_1, f_2, f_0, f_3)
        fused_1 = sch.fuse(f_1, f_2)
        fused_2 = sch.fuse(f_0, f_3)
        sch.bind(fused_1, "threadIdx.x")

        block_init_c = sch.decompose_reduction(block, sch.get_loops(block)[1])

        block_init_c = sch.get_block("C_init")
        init_loop1, init_loop2 = sch.get_loops(block_init_c)[-2:]
        f_0, f_1 = sch.split(init_loop1, factors=[None, 8])
        f_2, f_3 = sch.split(init_loop2, factors=[None, 2])
        sch.reorder(f_1, f_2, f_0, f_3)
        fused_1 = sch.fuse(f_1, f_2)
        fused_2 = sch.fuse(f_0, f_3)
        sch.bind(fused_1, "threadIdx.x")

        block = sch.get_block("C")

        i1, _, _ = sch.get_loops(block)[-3:]

        sch.tensorize(i1, "mma_sync")


    sch = tir.Schedule(workload)
    schedule(sch)

    print(sch.mod["main"].script())

    target = "cuda"
    f = tvm.build(sch.mod["main"], target=target, name="dense")
    dev = tvm.device("cuda", 0)
    a_np = np.random.uniform(size=(N, K)).astype("float16")
    b_np = np.random.uniform(size=(M, K)).astype("float16")
    c_np = np.dot(a_np.astype("float32"), b_np.transpose().astype("float32"))
    a = tvm.nd.array(a_np, dev)
    b = tvm.nd.array(b_np, dev)
    c = tvm.nd.array(np.zeros((N, M), dtype="float32"), dev)
    f(a, b, c)
    print(f.imported_modules[0].get_source())
    tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)


if __name__ == "__main__":
    test_integration_matmul()
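
# A possible follow-up, sketched here as a comment only: timing the generated
# 4096x4096x4096 kernel with TVM's built-in evaluator, reusing the `f`, `dev`,
# `a`, `b`, `c`, and sizes created inside test_integration_matmul():
#
#     evaluator = f.time_evaluator(f.entry_name, dev, number=10)
#     sec = evaluator(a, b, c).mean
#     print("%.3f ms, %.2f TFLOPS" % (sec * 1e3, 2 * N * M * K / sec / 1e12))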
