Skip to content

Commit 5afb5f0

Browse files
committed
ldmatrix intrin generation with meta programming
1 parent fb62abb commit 5afb5f0

File tree

8 files changed

+246
-739
lines changed

8 files changed

+246
-739
lines changed

python/tvm/script/parser.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -617,7 +617,6 @@ def transform_Assign(self, node):
617617
if node.ty is None and hasattr(value, "dtype"):
618618
var_ty = value.dtype
619619
else:
620-
print(node.ty, ast_var)
621620
var_ty = self.parse_type(node.ty, ast_var)
622621

623622
var = tvm.te.var(
Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=invalid-name,missing-function-docstring
18+
"""Intrinsics for tensorization on NVIDIA GPU."""
19+
from ..._ffi import register_func
20+
from ...runtime import convert
21+
from .. import TensorIntrin
22+
from tvm.script import tir as T
23+
24+
25+
def shared_16x16_to_ldmatrix_32x8_layout(i, j):
    """Map a (row, col) index of a 16x16 tile to (thread_id, local_id).

    Each of the 32 lanes in a warp ends up owning 8 elements of the tile;
    the returned pair identifies the owning lane and the slot within it.
    """
    lane = (i % 8) * 4 + (j % 8) // 2
    slot = (j // 8) * 4 + (i // 8) * 2 + j % 2
    return lane, slot
28+
29+
30+
def shared_16x32_to_ldmatrix_32x16_layout(i, j):
    """Map a (row, col) index of a 16x32 tile to (thread_id, local_id).

    Each of the 32 lanes in a warp owns 16 elements of the tile; the
    returned pair identifies the owning lane and the slot within it.
    """
    lane = (i % 8) * 4 + (j % 16) // 4
    slot = (j // 16) * 8 + (i // 8) * 4 + j % 4
    return lane, slot
33+
34+
35+
def shared_32x16_to_ldmatrix_32x16_layout(i, j):
36+
thread_id = (i % 4) + 4 * (j % 8)
37+
return thread_id, 8 * (j // 8) + (i // 16) * 4 + i % 4
38+
39+
40+
@register_func("tir.index_map.shared_16x16_to_ldmatrix_32x8_layout")
def index_map_shared_16x16_to_ldmatrix_32x8_layout(i, j):
    """FFI-registered wrapper: return the ldmatrix (thread_id, local_id)
    pair for (i, j) as a TVM array so it can cross the C++/Python boundary."""
    return convert(list(shared_16x16_to_ldmatrix_32x8_layout(i, j)))
44+
45+
46+
# Alias: "lift" a Python int into a TVM expression so it can appear
# inside TVMScript bodies below.
lift = convert

# Number of rows of the (non-K) matrix dimension handled per intrinsic.
M_DIM = 16
# Number of threads in one CUDA warp.
WARP_SIZE = 32
HALF_WARP = WARP_SIZE // 2
# HALF_WARP pre-lifted to a TVM expression for use in the offset lambdas.
HALF_WARP_expr = lift(HALF_WARP)
52+
53+
54+
def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed):
    """Build a (desc, impl) pair of TIR prim funcs for loading a shared-memory
    tile into the warp-level register layout via the PTX ``ldmatrix`` instruction.

    Parameters
    ----------
    k_dim : int
        Size of the K (reduction) dimension; 16 (float16) or 32 (int8).
    dtype : str
        Element type of the tile: "float16" when k_dim == 16, "int8" when
        k_dim == 32.
    is_b : bool
        True when generating the intrinsic for the B operand of the MMA.
    transposed : bool
        True when the operand is stored transposed; only supported for B.

    Returns
    -------
    tuple
        (desc, impl) T.prim_func pair suitable for ``TensorIntrin.register``.
    """
    # Elements each of the 32 lanes holds after the load.
    local_size = (M_DIM * k_dim) // WARP_SIZE
    shared_offset = None
    index_map = None

    if transposed:
        assert is_b, "Transposed A matrix not supported"

    # Non-transposed B is consumed column-major by the MMA instruction.
    ldmatrix_col_major = is_b and not transposed

    if k_dim == 16:
        assert dtype == "float16"

        index_map = shared_16x16_to_ldmatrix_32x8_layout

        if transposed:
            shared_offset = (
                lambda tx, stride: stride * 8 * (tx // HALF_WARP_expr)
                + stride * (tx % 8)
                + 8 * ((tx % HALF_WARP_expr) // 8)
            )
        else:
            shared_offset = lambda tx, stride: stride * (tx % HALF_WARP_expr) + 8 * (
                tx // HALF_WARP_expr
            )
    elif k_dim == 32:
        assert dtype == "int8"

        if ldmatrix_col_major:
            # Removed leftover debug print("foo") from this branch.
            index_map = shared_32x16_to_ldmatrix_32x16_layout
            shared_offset = (
                lambda _, stride: stride
            )  # dummy offset, ldmatrix cannot be used for int8 + trans case
        elif is_b and transposed:
            index_map = shared_16x32_to_ldmatrix_32x16_layout
            shared_offset = (
                lambda tx, stride: stride * 8 * (tx // HALF_WARP_expr)
                + (tx % 8) * stride
                + 16 * ((tx % HALF_WARP_expr) // 8)
            )
        else:
            index_map = shared_16x32_to_ldmatrix_32x16_layout
            shared_offset = lambda tx, stride: stride * (tx % 16) + 16 * (tx // 16)
    else:
        assert False, "Unsupported k dim"

    assert index_map and shared_offset

    # Non-transposed B is stored (k_dim, M_DIM) in shared memory; every other
    # case is stored (M_DIM, k_dim).
    if is_b and not transposed:
        row_dim = k_dim
        col_dim = M_DIM
    else:
        row_dim = M_DIM
        col_dim = k_dim

    shmem_shape = (row_dim, col_dim)

    @T.prim_func
    def ldmatrix_desc(warp_handle: T.handle, shared_handle: T.handle) -> None:
        """Describe the load semantics element-by-element for tensorization matching."""
        shared = T.match_buffer(
            shared_handle, shmem_shape, dtype, align=128, offset_factor=16, scope="shared"
        )
        warp = T.match_buffer(
            warp_handle, (WARP_SIZE, local_size), dtype, align=128, offset_factor=16, scope="warp"
        )

        with T.block("root"):
            T.reads(shared[0:row_dim, 0:col_dim])
            T.writes(warp[0:WARP_SIZE, 0:local_size])

            for ax0, ax1 in T.grid(row_dim, col_dim):
                with T.block("shared_warp"):
                    v0, v1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(shared[v0, v1])

                    thread_id, local_id = index_map(v0, v1)
                    T.writes(warp[thread_id, local_id])
                    warp[thread_id, local_id] = shared[v0, v1]

    @T.prim_func
    def ldmatrix_impl(warp_handle: T.handle, shared_handle: T.handle) -> None:
        """Lower the load to a single ptx_ldmatrix call per warp."""
        s0 = T.var("int32")
        s1 = T.var("int32")
        shared = T.match_buffer(
            shared_handle,
            shmem_shape,
            dtype,
            align=128,
            offset_factor=16,
            scope="shared",
            strides=[s0, s1],
        )
        warp = T.match_buffer(
            warp_handle, (WARP_SIZE, local_size), dtype, align=128, offset_factor=16, scope="warp"
        )

        with T.block("root"):
            T.reads(shared[0:row_dim, 0:col_dim])
            T.writes(warp[0:WARP_SIZE, 0:local_size])
            tx = T.env_thread("threadIdx.x")
            T.launch_thread(tx, WARP_SIZE)

            T.evaluate(
                T.ptx_ldmatrix(
                    ldmatrix_col_major,
                    4,  # Always load 4 matrices
                    ".b16",
                    warp.data,
                    warp.elem_offset + lift(local_size) * tx,
                    shared.access_ptr("r"),
                    shared_offset(tx, s0),
                    dtype=dtype,
                )
            )

    return ldmatrix_desc, ldmatrix_impl
173+
174+
175+
# Register one named ldmatrix intrinsic per (tile shape, dtype, operand,
# transpose) combination so schedules can tensorize by name.

# float16 A operand, 16x16 tile.
LDMATRIX_16x16_A_INTRIN = "mma.ldmatrix_16x16_a"
TensorIntrin.register(LDMATRIX_16x16_A_INTRIN, *get_ldmatrix_intrin(16, "float16", False, False))

# float16 B operand, 16x16 tile, non-transposed.
LDMATRIX_16x16_B_INTRIN = "mma.ldmatrix_16x16_b"
TensorIntrin.register(LDMATRIX_16x16_B_INTRIN, *get_ldmatrix_intrin(16, "float16", True, False))

# float16 B operand, 16x16 tile, transposed.
LDMATRIX_16x16_B_TRANS_INTRIN = "mma.ldmatrix_16x16_b_trans"
TensorIntrin.register(
    LDMATRIX_16x16_B_TRANS_INTRIN, *get_ldmatrix_intrin(16, "float16", True, True)
)

# int8 A operand, 16x32 tile.
LDMATRIX_16x32_A_INTRIN = "mma.ldmatrix_16x32_a"
TensorIntrin.register(LDMATRIX_16x32_A_INTRIN, *get_ldmatrix_intrin(32, "int8", False, False))

# int8 B operand, 32x16 tile, non-transposed.
LDMATRIX_32x16_B_INTRIN = "mma.ldmatrix_32x16_b"
TensorIntrin.register(LDMATRIX_32x16_B_INTRIN, *get_ldmatrix_intrin(32, "int8", True, False))

# int8 B operand, 16x32 tile, transposed.
LDMATRIX_16x32_B_TRANS_INTRIN = "mma.ldmatrix_16x32_b_trans"
TensorIntrin.register(LDMATRIX_16x32_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, "int8", True, True))

tests/python/unittest/test_mma_16x8x16_4k_tune.py

Lines changed: 7 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -4,126 +4,15 @@
44
import tvm.meta_schedule.testing.te_workload as te_workload
55
from tvm import te, tir
66
from tvm import meta_schedule as ms
7+
from tvm.tir.tensor_intrin.cuda import (
8+
LDMATRIX_16x16_A_INTRIN,
9+
LDMATRIX_16x16_B_INTRIN,
10+
shared_16x16_to_ldmatrix_32x8_layout,
11+
)
712
import tvm.testing
813
import numpy as np
914

1015

11-
def shared_16x16_to_ldmatrix_32x8_layout(i, j):
12-
thread_id = 4 * (i % 8) + (j % 8) // 2
13-
return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 2)
14-
15-
16-
@tvm._ffi.register_func("tir.index_map.shared_16x16_to_ldmatrix_32x8_layout")
17-
def index_map_shared_16x16_to_ldmatrix_32x8_layout(i, j):
18-
thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(i, j)
19-
return tvm.runtime.convert([thread_id, local_id])
20-
21-
22-
@T.prim_func
23-
def ldmatrix_a_desc(a: T.handle, c: T.handle) -> None:
24-
A_shared = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared")
25-
A_warp = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
26-
27-
with T.block("root"):
28-
T.reads(A_shared[0:16, 0:16])
29-
T.writes(A_warp[0:32, 0:8])
30-
31-
for ax0, ax1 in T.grid(16, 16):
32-
with T.block("A_shared_warp"):
33-
v0, v1 = T.axis.remap("SS", [ax0, ax1])
34-
T.reads(A_shared[v0, v1])
35-
36-
thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(v0, v1)
37-
T.writes(A_warp[thread_id, local_id])
38-
A_warp[thread_id, local_id] = A_shared[v0, v1]
39-
40-
41-
@T.prim_func
42-
def ldmatrix_a_impl(a: T.handle, c: T.handle) -> None:
43-
s1 = T.var("int32")
44-
s0 = T.var("int32")
45-
A_shared = T.match_buffer(
46-
a,
47-
(16, 16),
48-
"float16",
49-
align=128,
50-
offset_factor=16,
51-
scope="shared",
52-
strides=[s1, s0],
53-
)
54-
A_warp = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
55-
with T.block("root"):
56-
T.reads(A_shared[0:16, 0:16])
57-
T.writes(A_warp[0:32, 0:8])
58-
tx = T.env_thread("threadIdx.x")
59-
T.launch_thread(tx, 32)
60-
61-
T.evaluate(
62-
T.ptx_ldmatrix(
63-
0,
64-
4,
65-
".b16",
66-
A_warp.data,
67-
A_warp.elem_offset + 8 * tx,
68-
A_shared.access_ptr("r"),
69-
s1 * (tx % 16) + 8 * (tx // 16),
70-
dtype="float16",
71-
)
72-
)
73-
74-
75-
@T.prim_func
76-
def ldmatrix_b_desc(a: T.handle, c: T.handle) -> None:
77-
B_shared = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared")
78-
B_warp = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
79-
80-
with T.block("root"):
81-
T.reads(B_shared[0:16, 0:16])
82-
T.writes(B_warp[0:32, 0:8])
83-
84-
for ax0, ax1 in T.grid(16, 16):
85-
with T.block("B_shared_warp"):
86-
v0, v1 = T.axis.remap("SS", [ax0, ax1])
87-
T.reads(B_shared[v0, v1])
88-
thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(v0, v1)
89-
T.writes(B_warp[thread_id, local_id])
90-
B_warp[thread_id, local_id] = B_shared[v0, v1]
91-
92-
93-
@T.prim_func
94-
def ldmatrix_b_impl(a: T.handle, c: T.handle) -> None:
95-
s1 = T.var("int32")
96-
s0 = T.var("int32")
97-
B_shared = T.match_buffer(
98-
a,
99-
(16, 16),
100-
"float16",
101-
align=128,
102-
offset_factor=16,
103-
scope="shared",
104-
strides=[s1, s0],
105-
)
106-
B_warp = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
107-
with T.block("root"):
108-
T.reads(B_shared[0:16, 0:16])
109-
T.writes(B_warp[0:32, 0:8])
110-
tx = T.env_thread("threadIdx.x")
111-
T.launch_thread(tx, 32)
112-
113-
T.evaluate(
114-
T.ptx_ldmatrix(
115-
1,
116-
4,
117-
".b16",
118-
B_warp.data,
119-
B_warp.elem_offset + 8 * tx,
120-
B_shared.access_ptr("r"),
121-
s1 * (tx % 16) + 8 * (tx // 16),
122-
dtype="float16",
123-
)
124-
)
125-
126-
12716
@T.prim_func
12817
def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
12918
A = T.match_buffer(a, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
@@ -271,8 +160,6 @@ def mma_fill_impl(a: T.handle) -> None:
271160
T.evaluate(T.mma_fill(8, C_warp.data, C_warp.elem_offset, dtype="float32"))
272161

273162

274-
tir.TensorIntrin.register("mma.ldmatrix_a", ldmatrix_a_desc, ldmatrix_a_impl)
275-
tir.TensorIntrin.register("mma.ldmatrix_b", ldmatrix_b_desc, ldmatrix_b_impl)
276163
tir.TensorIntrin.register("mma_sync", mma_sync_desc, mma_sync_impl)
277164
tir.TensorIntrin.register("mma_store", mma_store_desc, mma_store_impl)
278165
tir.TensorIntrin.register("mma_fill", mma_fill_desc, mma_fill_impl)
@@ -402,8 +289,8 @@ def index_map(i, j):
402289
sch.transform_layout(B_warp, 0, "write", index_map)
403290
sch.transform_layout(C_warp, 0, "read", index_map)
404291

405-
sch.tensorize(loop_a, "mma.ldmatrix_a")
406-
sch.tensorize(loop_b, "mma.ldmatrix_b")
292+
sch.tensorize(loop_a, LDMATRIX_16x16_A_INTRIN)
293+
sch.tensorize(loop_b, LDMATRIX_16x16_B_INTRIN)
407294
sch.tensorize(sch.get_loops(block_inner)[-3], "mma_sync")
408295
sch.tensorize(sch.get_loops(block_init_c)[-2], "mma_fill")
409296
sch.tensorize(sch.get_loops(C_warp)[-2], "mma_store")

0 commit comments

Comments
 (0)