Commit 2b05b5a

generate mma fill/store

1 parent bf23fc5

7 files changed, +134 -460 lines

python/tvm/tir/tensor_intrin/cuda.py

Lines changed: 110 additions & 1 deletion
@@ -16,7 +16,7 @@
 # under the License.
 # pylint: disable=invalid-name,missing-function-docstring
 """Intrinsics for tensorization on NVIDIA GPU."""
-from .. import Cast
+from .. import IntImm, Cast
 from ..._ffi import register_func
 from ...runtime import convert
 from .. import TensorIntrin
@@ -315,6 +315,97 @@ def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
     return mma_sync_desc, mma_sync_impl
 
 
+def get_mma_fill_intrin(dtype, local_size):
+    zero = IntImm("int32", 0).astype(dtype)
+
+    # Assume M = N = 16
+    index_map = shared_16x16_to_ldmatrix_32x8_layout
+
+    @T.prim_func
+    def mma_fill_desc(a: T.handle) -> None:
+        C_warp = T.match_buffer(a, [WARP_SIZE, local_size], dtype=dtype, scope="warp")
+
+        with T.block("root"):
+            T.reads()
+            T.writes(C_warp[0:WARP_SIZE, 0:local_size])
+            for i0, i1 in T.grid(M_DIM, N_DIM):
+                with T.block("C_warp"):
+                    i, j = T.axis.remap("SS", [i0, i1])
+                    thread_id, local_id = index_map(i, j)
+                    T.reads()
+                    T.writes(C_warp[thread_id, local_id])
+                    C_warp[thread_id, local_id] = zero
+
+    @T.prim_func
+    def mma_fill_impl(a: T.handle) -> None:
+        C_warp = T.match_buffer(
+            a, [WARP_SIZE, local_size], dtype=dtype, scope="warp", offset_factor=1
+        )
+
+        with T.block("root"):
+            T.reads()
+            T.writes(C_warp[0:WARP_SIZE, 0:local_size])
+            tx = T.env_thread("threadIdx.x")
+            T.launch_thread(tx, WARP_SIZE)
+
+            T.evaluate(T.mma_fill(local_size, C_warp.data, C_warp.elem_offset, dtype=dtype))
+
+    return mma_fill_desc, mma_fill_impl
+
+
+def get_mma_store_intrin(dtype, local_size):
+    # Assume M = N = 16
+    index_map = shared_16x16_to_ldmatrix_32x8_layout
+
+    @T.prim_func
+    def mma_store_desc(a: T.handle, c: T.handle) -> None:
+        C_warp = T.match_buffer(a, [WARP_SIZE, local_size], dtype=dtype, scope="warp")
+        C = T.match_buffer(c, [M_DIM, N_DIM], dtype=dtype, scope="global")
+
+        with T.block("root"):
+            T.reads(C_warp[0:WARP_SIZE, 0:local_size])
+            T.writes(C[0:M_DIM, 0:N_DIM])
+            for i0, i1 in T.grid(M_DIM, N_DIM):
+                with T.block("C_warp"):
+                    v0, v1 = T.axis.remap("SS", [i0, i1])
+                    thread_id, local_id = index_map(v0, v1)
+                    T.reads(C_warp[thread_id, local_id])
+                    T.writes(C[v0, v1])
+                    C[v0, v1] = C_warp[thread_id, local_id]
+
+    @T.prim_func
+    def mma_store_impl(a: T.handle, c: T.handle) -> None:
+        s0 = T.var("int32")
+        s1 = T.var("int32")
+
+        C_warp = T.match_buffer(
+            a, [WARP_SIZE, local_size], dtype=dtype, scope="warp", offset_factor=1
+        )
+        C = T.match_buffer(
+            c, [M_DIM, N_DIM], dtype=dtype, scope="global", offset_factor=1, strides=[s0, s1]
+        )
+
+        with T.block("root"):
+            T.reads(C_warp[0:WARP_SIZE, 0:local_size])
+            T.writes(C[0:M_DIM, 0:N_DIM])
+            tx = T.env_thread("threadIdx.x")
+            T.launch_thread(tx, WARP_SIZE)
+
+            T.evaluate(
+                T.mma_store(
+                    M_DIM,
+                    N_DIM,
+                    C.access_ptr("w"),
+                    C_warp.data,
+                    C_warp.elem_offset,
+                    s0,
+                    dtype=dtype,
+                )
+            )
+
+    return mma_store_desc, mma_store_impl
+
+
 LDMATRIX_16x16_A_INTRIN = "mma.ldmatrix_16x16_a"
 TensorIntrin.register(LDMATRIX_16x16_A_INTRIN, *get_ldmatrix_intrin(16, "float16", False, False))
 
@@ -352,3 +443,21 @@ def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
 
 MMA_i8i8i32_TRANS_INTRIN = "mma_i8i8i32_trans"
 TensorIntrin.register(MMA_i8i8i32_TRANS_INTRIN, *get_mma_intrin(32, "int32", True))
+
+MMA_fill_16x16_f32_INTRIN = "mma_fill_16x16_f32"
+TensorIntrin.register(MMA_fill_16x16_f32_INTRIN, *get_mma_fill_intrin("float32", 8))
+
+MMA_fill_16x16_f16_INTRIN = "mma_fill_16x16_f16"
+TensorIntrin.register(MMA_fill_16x16_f16_INTRIN, *get_mma_fill_intrin("float16", 8))
+
+MMA_fill_16x16_i32_INTRIN = "mma_fill_16x16_i32"
+TensorIntrin.register(MMA_fill_16x16_i32_INTRIN, *get_mma_fill_intrin("int32", 8))
+
+MMA_store_16x16_f32_INTRIN = "mma_store_16x16_f32"
+TensorIntrin.register(MMA_store_16x16_f32_INTRIN, *get_mma_store_intrin("float32", 8))
+
+MMA_store_16x16_f16_INTRIN = "mma_store_16x16_f16"
+TensorIntrin.register(MMA_store_16x16_f16_INTRIN, *get_mma_store_intrin("float16", 8))
+
+MMA_store_16x16_i32_INTRIN = "mma_store_16x16_i32"
+TensorIntrin.register(MMA_store_16x16_i32_INTRIN, *get_mma_store_intrin("int32", 8))
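
For intuition, mma_fill and mma_store round-trip a 16x16 accumulator tile through the 32x8 per-warp fragment (32 threads, 8 values per thread) described by shared_16x16_to_ldmatrix_32x8_layout. A minimal NumPy sketch of what the two desc bodies above compute, assuming the layout function accepts plain Python ints (its use inside T.grid suggests ordinary index arithmetic):

import numpy as np
from tvm.tir.tensor_intrin.cuda import shared_16x16_to_ldmatrix_32x8_layout

# mma_fill: zero every lane of the 32x8 warp fragment.
C_warp = np.zeros((32, 8), dtype="float32")

# mma_store: scatter the fragment back to a 16x16 tile through the
# same index map that mma_fill_desc / mma_store_desc use above.
C = np.empty((16, 16), dtype="float32")
for i in range(16):
    for j in range(16):
        thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(i, j)
        C[i, j] = C_warp[thread_id, local_id]

assert (C == 0.0).all()

The impl bodies lower the same access pattern to the warp-level T.mma_fill / T.mma_store builtins instead of an explicit loop.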

tests/python/unittest/test_mma_16x8x16_4k_tune.py

Lines changed: 4 additions & 76 deletions
@@ -1,91 +1,19 @@
-import tempfile
 import tvm
-from tvm.script import tir as T
 import tvm.meta_schedule.testing.te_workload as te_workload
 from tvm import te, tir
 from tvm import meta_schedule as ms
 from tvm.tir.tensor_intrin.cuda import (
     LDMATRIX_16x16_A_INTRIN,
     LDMATRIX_16x16_B_INTRIN,
     MMA_f16f16f32_INTRIN,
+    MMA_fill_16x16_f32_INTRIN,
+    MMA_store_16x16_f32_INTRIN,
     shared_16x16_to_ldmatrix_32x8_layout,
 )
 import tvm.testing
 import numpy as np
 
 
-@T.prim_func
-def mma_store_desc(a: T.handle, c: T.handle) -> None:
-    C_warp = T.match_buffer(a, [32, 8], dtype="float32", scope="warp")
-    C = T.match_buffer(c, [16, 16], dtype="float32", scope="global")
-
-    with T.block("root"):
-        T.reads(C_warp[0:32, 0:8])
-        T.writes(C[0:16, 0:16])
-        for i0, i1 in T.grid(16, 16):
-            with T.block("C_warp"):
-                v0, v1 = T.axis.remap("SS", [i0, i1])
-                thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(v0, v1)
-                T.reads(C_warp[thread_id, local_id])
-                T.writes(C[v0, v1])
-                C[v0, v1] = C_warp[thread_id, local_id]
-
-
-@T.prim_func
-def mma_store_impl(a: T.handle, c: T.handle) -> None:
-    s1 = T.var("int32")
-    s0 = T.var("int32")
-
-    C_warp = T.match_buffer(a, [32, 8], dtype="float32", scope="warp", offset_factor=1)
-    C = T.match_buffer(
-        c, [16, 16], dtype="float32", scope="global", offset_factor=1, strides=[s1, s0]
-    )
-
-    with T.block("root"):
-        T.reads(C_warp[0:32, 0:8])
-        T.writes(C[0:16, 0:16])
-        tx = T.env_thread("threadIdx.x")
-        T.launch_thread(tx, 32)
-
-        T.evaluate(
-            T.mma_store(
-                16, 16, C.access_ptr("w"), C_warp.data, C_warp.elem_offset, s1, dtype="float32"
-            )
-        )
-
-
-@T.prim_func
-def mma_fill_desc(a: T.handle) -> None:
-    C_warp = T.match_buffer(a, [32, 8], dtype="float32", scope="warp")
-
-    with T.block("root"):
-        T.reads()
-        T.writes(C_warp[0:32, 0:8])
-        for i0, i1 in T.grid(16, 16):
-            with T.block("C_warp"):
-                i_init, j_init = T.axis.remap("SS", [i0, i1])
-                thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(i_init, j_init)
-                T.reads()
-                T.writes(C_warp[thread_id, local_id])
-                C_warp[thread_id, local_id] = T.float32(0)
-
-
-@T.prim_func
-def mma_fill_impl(a: T.handle) -> None:
-    C_warp = T.match_buffer(a, [32, 8], dtype="float32", scope="warp", offset_factor=1)
-
-    with T.block("root"):
-        T.reads()
-        T.writes(C_warp[0:32, 0:8])
-        tx = T.env_thread("threadIdx.x")
-        T.launch_thread(tx, 32)
-
-        T.evaluate(T.mma_fill(8, C_warp.data, C_warp.elem_offset, dtype="float32"))
-
-
-tir.TensorIntrin.register("mma_store", mma_store_desc, mma_store_impl)
-tir.TensorIntrin.register("mma_fill", mma_fill_desc, mma_fill_impl)
-
 N = 4096
 M = 4096
 K = 4096
@@ -214,8 +142,8 @@ def index_map(i, j):
     sch.tensorize(loop_a, LDMATRIX_16x16_A_INTRIN)
     sch.tensorize(loop_b, LDMATRIX_16x16_B_INTRIN)
     sch.tensorize(sch.get_loops(block_inner)[-3], MMA_f16f16f32_INTRIN)
-    sch.tensorize(sch.get_loops(block_init_c)[-2], "mma_fill")
-    sch.tensorize(sch.get_loops(C_warp)[-2], "mma_store")
+    sch.tensorize(sch.get_loops(block_init_c)[-2], MMA_fill_16x16_f32_INTRIN)
+    sch.tensorize(sch.get_loops(C_warp)[-2], MMA_store_16x16_f32_INTRIN)
 
 
 ir_module = tvm.IRModule({"main": workload})
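
Both the desc bodies above and this test depend on shared_16x16_to_ldmatrix_32x8_layout touching every fragment slot exactly once. A quick sanity check along those lines (a sketch, again assuming the function works on plain Python ints):

from tvm.tir.tensor_intrin.cuda import shared_16x16_to_ldmatrix_32x8_layout

# All 256 (i, j) positions of the 16x16 tile must land on distinct
# (thread_id, local_id) pairs within the 32x8 fragment.
seen = set()
for i in range(16):
    for j in range(16):
        thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(i, j)
        assert 0 <= thread_id < 32 and 0 <= local_id < 8
        seen.add((thread_id, local_id))
assert len(seen) == 256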

tests/python/unittest/test_mma_16x8x16_4k_tune_trans.py

Lines changed: 4 additions & 77 deletions
@@ -1,91 +1,18 @@
-import tempfile
 import tvm
-from tvm.script import tir as T
-import tvm.meta_schedule.testing.te_workload as te_workload
 from tvm import te, tir
 from tvm.tir.tensor_intrin.cuda import (
     LDMATRIX_16x16_A_INTRIN,
     LDMATRIX_16x16_B_TRANS_INTRIN,
     MMA_f16f16f32_TRANS_INTRIN,
+    MMA_fill_16x16_f32_INTRIN,
+    MMA_store_16x16_f32_INTRIN,
     shared_16x16_to_ldmatrix_32x8_layout,
 )
 from tvm import meta_schedule as ms
 import tvm.testing
 import numpy as np
 
 
-@T.prim_func
-def mma_store_desc(a: T.handle, c: T.handle) -> None:
-    C_warp = T.match_buffer(a, [32, 8], dtype="float32", scope="warp")
-    C = T.match_buffer(c, [16, 16], dtype="float32", scope="global")
-
-    with T.block("root"):
-        T.reads(C_warp[0:32, 0:8])
-        T.writes(C[0:16, 0:16])
-        for i0, i1 in T.grid(16, 16):
-            with T.block("C_warp"):
-                v0, v1 = T.axis.remap("SS", [i0, i1])
-                thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(v0, v1)
-                T.reads(C_warp[thread_id, local_id])
-                T.writes(C[v0, v1])
-                C[v0, v1] = C_warp[thread_id, local_id]
-
-
-@T.prim_func
-def mma_store_impl(a: T.handle, c: T.handle) -> None:
-    s1 = T.var("int32")
-    s0 = T.var("int32")
-
-    C_warp = T.match_buffer(a, [32, 8], dtype="float32", scope="warp", offset_factor=1)
-    C = T.match_buffer(
-        c, [16, 16], dtype="float32", scope="global", offset_factor=1, strides=[s1, s0]
-    )
-
-    with T.block("root"):
-        T.reads(C_warp[0:32, 0:8])
-        T.writes(C[0:16, 0:16])
-        tx = T.env_thread("threadIdx.x")
-        T.launch_thread(tx, 32)
-
-        T.evaluate(
-            T.mma_store(
-                16, 16, C.access_ptr("w"), C_warp.data, C_warp.elem_offset, s1, dtype="float32"
-            )
-        )
-
-
-@T.prim_func
-def mma_fill_desc(a: T.handle) -> None:
-    C_warp = T.match_buffer(a, [32, 8], dtype="float32", scope="warp")
-
-    with T.block("root"):
-        T.reads()
-        T.writes(C_warp[0:32, 0:8])
-        for i0, i1 in T.grid(16, 16):
-            with T.block("C_warp"):
-                i_init, j_init = T.axis.remap("SS", [i0, i1])
-                thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(i_init, j_init)
-                T.reads()
-                T.writes(C_warp[thread_id, local_id])
-                C_warp[thread_id, local_id] = T.float32(0)
-
-
-@T.prim_func
-def mma_fill_impl(a: T.handle) -> None:
-    C_warp = T.match_buffer(a, [32, 8], dtype="float32", scope="warp", offset_factor=1)
-
-    with T.block("root"):
-        T.reads()
-        T.writes(C_warp[0:32, 0:8])
-        tx = T.env_thread("threadIdx.x")
-        T.launch_thread(tx, 32)
-
-        T.evaluate(T.mma_fill(8, C_warp.data, C_warp.elem_offset, dtype="float32"))
-
-
-tir.TensorIntrin.register("mma_store", mma_store_desc, mma_store_impl)
-tir.TensorIntrin.register("mma_fill", mma_fill_desc, mma_fill_impl)
-
 N = 4096
 M = 4096
 K = 4096
@@ -231,8 +158,8 @@ def index_map(i, j):
     sch.tensorize(loop_b, LDMATRIX_16x16_B_TRANS_INTRIN)
     sch.tensorize(sch.get_loops(block_inner)[-3], MMA_f16f16f32_TRANS_INTRIN)
 
-    sch.tensorize(sch.get_loops(block_init_c)[-2], "mma_fill")
-    sch.tensorize(sch.get_loops(C_warp)[-2], "mma_store")
+    sch.tensorize(sch.get_loops(block_init_c)[-2], MMA_fill_16x16_f32_INTRIN)
+    sch.tensorize(sch.get_loops(C_warp)[-2], MMA_store_16x16_f32_INTRIN)
 
 
 ir_module = tvm.IRModule({"main": workload})
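
With the intrinsics registered once in cuda.py, both tests tensorize by the exported names rather than re-registering local "mma_fill"/"mma_store" prim funcs per test file. A sketch of resolving a registered intrinsic by hand, assuming TensorIntrin.get is the registry's lookup counterpart to the TensorIntrin.register calls above:

from tvm import tir
from tvm.tir.tensor_intrin.cuda import MMA_fill_16x16_f32_INTRIN

# Resolve the registered (desc, impl) pair by name; sch.tensorize
# performs the same lookup internally when given the string.
intrin = tir.TensorIntrin.get(MMA_fill_16x16_f32_INTRIN)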
