
Commit f6aadbf

Add 16x8x8 MMA + LDMatrix test
1 parent 4cf6b20 commit f6aadbf

1 file changed: 322 additions, 0 deletions
@@ -0,0 +1,322 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te, tir
from tvm.script import tir as T
import tvm.testing
import numpy as np


# Descriptor: copy a 16x8 fp16 tile from shared memory into the 32x4
# per-thread layout expected by the m16n8k8 mma A operand.
@T.prim_func
def ldmatrix_a_desc(a: T.handle, c: T.handle) -> None:
    A_shared = T.match_buffer(
        a, (16, 8), "float16", align=128, offset_factor=16, scope="shared"
    )
    A_warp = T.match_buffer(
        c, (32, 4), "float16", align=128, offset_factor=16, scope="warp"
    )

    with T.block("root"):
        T.reads(A_shared[0:16, 0:8])
        T.writes(A_warp[0:32, 0:4])

        for ax0, ax1 in T.grid(16, 8):
            with T.block("A_shared_warp"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(A_shared[v0, v1])
                T.writes(
                    A_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2]
                )
                A_warp[
                    v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2
                ] = A_shared[v0, v1]


@T.prim_func
def ldmatrix_a_impl(a: T.handle, c: T.handle) -> None:
    s1 = T.var("int32")
    s0 = T.var("int32")
    A_shared = T.match_buffer(
        a,
        (16, 8),
        "float16",
        align=128,
        offset_factor=16,
        scope="shared",
        strides=[s1, s0],
    )
    A_warp = T.match_buffer(
        c, (32, 4), "float16", align=128, offset_factor=16, scope="warp"
    )
    with T.block("root"):
        T.reads(A_shared[0:16, 0:8])
        T.writes(A_warp[0:32, 0:4])
        tx = T.env_thread("threadIdx.x")
        T.launch_thread(tx, 32)

        # ldmatrix.x2 of .b16 tiles: each of the 32 threads receives
        # 4 fp16 values of the 16x8 A fragment.
        T.evaluate(
            T.ptx_ldmatrix(
                0,
                2,
                ".b16",
                A_warp.data,
                4 * tx,
                A_shared.data,
                8 * (tx % 16),
                dtype="float16",
            )
        )


# Descriptor: copy an 8x8 fp16 tile (the m16n8k8 B operand) from shared
# memory into the 32x2 per-thread warp layout.
@T.prim_func
def ldmatrix_b_desc(a: T.handle, c: T.handle) -> None:
    B_shared = T.match_buffer(
        a, (8, 8), "float16", align=128, offset_factor=16, scope="shared"
    )
    B_shared_warp = T.match_buffer(
        c, (32, 2), "float16", align=128, offset_factor=16, scope="warp"
    )

    with T.block("root"):
        T.reads(B_shared[0:8, 0:8])
        T.writes(B_shared_warp[0:32, 0:2])

        for ax0, ax1 in T.grid(8, 8):
            with T.block("B_shared_warp"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(B_shared[v0, v1])
                T.writes(B_shared_warp[v1 * 4 + v0 // 2, v0 % 2])
                B_shared_warp[v1 * 4 + v0 // 2, v0 % 2] = B_shared[v0, v1]


@T.prim_func
def ldmatrix_b_impl(a: T.handle, c: T.handle) -> None:
    s1 = T.var("int32")
    s0 = T.var("int32")
    B_shared = T.match_buffer(
        a,
        (8, 8),
        "float16",
        align=128,
        offset_factor=16,
        scope="shared",
        strides=[s1, s0],
    )
    B_warp = T.match_buffer(
        c, (32, 2), "float16", align=128, offset_factor=16, scope="warp"
    )
    with T.block("root"):
        T.reads(B_shared[0:8, 0:8])
        T.writes(B_warp[0:32, 0:2])
        tx = T.env_thread("threadIdx.x")
        T.launch_thread(tx, 32)

        # ldmatrix.x1 of a single 8x8 .b16 tile: 2 fp16 values per thread.
        T.evaluate(
            T.ptx_ldmatrix(
                0,
                1,
                ".b16",
                B_warp.data,
                2 * tx,
                B_shared.data,
                8 * (tx % 8),
                dtype="float16",
            )
        )


@T.prim_func
def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, [32, 4], dtype="float16", scope="warp")
    B = T.match_buffer(b, [32, 2], dtype="float16", scope="warp")
    C = T.match_buffer(c, [32, 4], dtype="float32", scope="warp")
    with T.block("root"):
        T.reads(C[0:32, 0:4], A[0:32, 0:4], B[0:32, 0:2])
        T.writes(C[0:32, 0:4])
        for i0, i1, i2 in T.grid(16, 8, 8):
            with T.block("C"):
                i, j, k = T.axis.remap("SSR", [i0, i1, i2])

                T.reads(C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2], A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2], B[k * 4 + j // 2, j % 2])
                T.writes(C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2])
                C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2] = C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2] + T.cast(A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2], "float32") * T.cast(B[k * 4 + j // 2, j % 2], "float32")


@T.prim_func
def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (32, 4), "float16", align=128, offset_factor=1, scope="warp")
    B = T.match_buffer(b, (32, 2), "float16", align=128, offset_factor=1, scope="warp")
    C = T.match_buffer(c, (32, 4), "float32", align=128, offset_factor=1, scope="warp")

    with T.block("root"):
        T.reads(C[0:32, 0:4], A[0:32, 0:4], B[0:32, 0:2])
        T.writes(C[0:32, 0:4])
        tx = T.env_thread("threadIdx.x")
        T.launch_thread(tx, 32)
        # One warp-wide mma.sync m16n8k8 (row x col, fp16 inputs, fp32 accumulate):
        # each thread holds 4 fp16 A values, 2 fp16 B values, and 4 fp32 C values.
        T.evaluate(
            T.ptx_mma(
                "m16n8k8",
                "row",
                "col",
                "fp16",
                "fp16",
                "fp32",
                A.data,
                A.elem_offset + tx * 4,
                B.data,
                B.elem_offset + tx * 2,
                C.data,
                C.elem_offset + tx * 4,
                False,
                dtype="float32",
            )
        )


tir.TensorIntrin.register("mma.ldmatrix_a", ldmatrix_a_desc, ldmatrix_a_impl)
tir.TensorIntrin.register("mma.ldmatrix_b", ldmatrix_b_desc, ldmatrix_b_impl)
tir.TensorIntrin.register("mma_sync", mma_sync_desc, mma_sync_impl)


def dense(n: int, m: int, k: int):
    a = te.placeholder((n, k), name="A", dtype="float16")
    b = te.placeholder((m, k), name="B", dtype="float16")
    k = te.reduce_axis((0, k), name="k")
    c = te.compute(
        (n, m),
        lambda i, j: te.sum(
            tvm.tir.Cast("float32", a[i, k]) * tvm.tir.Cast("float32", b[j, k]),
            axis=[k],
        ),
        name="C",
    )
    return (a, b, c)


def test_integration_matmul():
    N = 16
    M = 8
    K = 8

    workload = te.create_prim_func(dense(n=N, m=M, k=K))

    def schedule(sch: tir.Schedule):
        block = sch.get_block("C")
        i, j, k = sch.get_loops(block)

        # Step 2. Rule-Multi-Level-Tiling
        i1, i2 = sch.split(i, factors=[None, 16])
        sch.bind(i1, "blockIdx.x")

        def fetch_to_shared(block, idx, ndim):
            block_read = sch.cache_read(block, idx, "shared")
            sch.compute_at(block_read, i1)
            warp_size = 32
            fused = sch.fuse(*sch.get_loops(block_read)[-ndim:])
            f_0, f_1 = sch.split(fused, factors=[None, warp_size])
            sch.bind(f_1, "threadIdx.x")

        fetch_to_shared(block, 0, 2)
        fetch_to_shared(block, 1, 2)

        # fetch to A_warp: 16 * 8 -> 32 * 4
        A_warp = sch.cache_read(block, 0, "warp")
        sch.transform_layout(
            A_warp,
            0,
            "write",
            index_map=lambda i, j: (
                (i % 8) * 4 + (j % 8) // 2,
                4 * (j // 8) + (i // 8) * 2 + (j % 8) % 2,
            ),
        )

        sch.tensorize(sch.get_loops(A_warp)[1], "mma.ldmatrix_a")

        B_warp = sch.cache_read(block, 1, "warp")
        sch.transform_layout(
            B_warp,
            0,
            "write",
            index_map=lambda i, j: (i // 2 + j * 4, i % 2),
        )
        sch.tensorize(sch.get_loops(B_warp)[1], "mma.ldmatrix_b")

        # fetch to C_warp: 16 * 8 -> 32 * 4
        C_warp = sch.cache_write(block, 0, "warp")
        # reverse_compute_at is needed to place the write-back under blockIdx.x
        sch.reverse_compute_at(C_warp, sch.get_loops(block)[0])
        sch.transform_layout(
            C_warp,
            0,
            "read",
            index_map=lambda i, j: (
                (i % 8) * 4 + (j % 8) // 2,
                4 * (j // 8) + (i // 8) * 2 + (j % 8) % 2,
            ),
        )
        warp_loop1, warp_loop2 = sch.get_loops(C_warp)[-2:]
        f_0, f_1 = sch.split(warp_loop1, factors=[None, 8])
        f_2, f_3 = sch.split(warp_loop2, factors=[None, 2])
        sch.reorder(f_1, f_2, f_0, f_3)
        fused_1 = sch.fuse(f_1, f_2)
        fused_2 = sch.fuse(f_0, f_3)
        sch.bind(fused_1, "threadIdx.x")

        # Decompose -> separate C_init from C_warp
        loop = sch.get_loops(block)[1]
        block_init_c = sch.decompose_reduction(block, loop)

        # C_init: 16 * 8 -> 32 * 4
        # the binding is already transformed by the previous step,
        # so only split/reorder/fuse is needed here
        C_init = block_init_c
        init_loop1, init_loop2 = sch.get_loops(C_init)[-2:]
        f_0, f_1 = sch.split(init_loop1, factors=[None, 8])
        f_2, f_3 = sch.split(init_loop2, factors=[None, 2])
        sch.reorder(f_1, f_2, f_0, f_3)
        fused_1 = sch.fuse(f_1, f_2)
        fused_2 = sch.fuse(f_0, f_3)
        sch.bind(fused_1, "threadIdx.x")

        # tensorize
        i0, i1, i2, i3 = sch.get_loops(block)
        sch.tensorize(i1, "mma_sync")

    sch = tir.Schedule(workload)
    schedule(sch)

    print(sch.mod["main"].script())

    target = "cuda"
    f = tvm.build(sch.mod["main"], target=target, name="dense")
    dev = tvm.device("cuda", 0)
    a_np = np.random.uniform(size=(N, K)).astype("float16")
    b_np = np.random.uniform(size=(M, K)).astype("float16")
    c_np = np.dot(a_np.astype("float32"), b_np.transpose().astype("float32"))
    a = tvm.nd.array(a_np, dev)
    b = tvm.nd.array(b_np, dev)
    c = tvm.nd.array(np.zeros((N, M), dtype="float32"), dev)
    f(a, b, c)
    print(f.imported_modules[0].get_source())
    tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)


if __name__ == "__main__":
    test_integration_matmul()
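
Side note on the layout maps above: the A-operand map (16, 8) -> (32, 4) and the B-operand map (8, 8) -> (32, 2) should each be bijections onto the warp's (thread, register) slots. A minimal NumPy sketch that checks this, using the same index expressions as ldmatrix_a_desc and ldmatrix_b_desc (the helper names a_map and b_map are only for illustration, not part of the commit):

import numpy as np

def a_map(i, j):
    # (16, 8) A tile -> (32, 4) warp layout, as in ldmatrix_a_desc / transform_layout
    return i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2

def b_map(i, j):
    # (8, 8) B tile -> (32, 2) warp layout, as in ldmatrix_b_desc
    return j * 4 + i // 2, i % 2

A_warp = np.full((32, 4), -1)
for i in range(16):
    for j in range(8):
        A_warp[a_map(i, j)] = i * 8 + j
# every (thread, register) slot is written exactly once
assert np.array_equal(np.sort(A_warp.ravel()), np.arange(128))

B_warp = np.full((32, 2), -1)
for i in range(8):
    for j in range(8):
        B_warp[b_map(i, j)] = i * 8 + j
assert np.array_equal(np.sort(B_warp.ravel()), np.arange(64))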
