
Commit c0abab7

Author: Siyuan Feng
[TIR][DLight] Enable SimdGroup op for Metal (#17112)
1 parent 4ef9011 commit c0abab7

11 files changed: +1124 −7 lines

include/tvm/tir/builtin.h

Lines changed: 43 additions & 1 deletion

@@ -746,7 +746,7 @@ TVM_DLL const Op& create_barriers();
 TVM_DLL const Op& mma_store();
 
 /*!
- * \brief tvm intrinsic for zero-initalizing an MMA accumulation registor.
+ * \brief tvm intrinsic for zero-initializing an MMA accumulation register.
  * For example, if each thread in a warp of size 32 has 8 elements from the A matrix in
  * m16xn8xk16 MMA in its registers, this intrinsic can be used to zero-initialize its
  * 4 accumulation registers.
@@ -758,6 +758,48 @@ TVM_DLL const Op& mma_store();
  */
 TVM_DLL const Op& mma_fill();
 
+// Metal SimdGroup matrix intrinsics
+
+/*!
+ * \brief tvm intrinsic for initializing a simdgroup matrix with a given value.
+ * \note Only the 8x8 shape is supported by the Metal spec and TVM, but we still keep the
+ * shape as parameters to match the Metal spec interface.
+ *
+ * void make_filled_simdgroup_matrix(Var d, PrimExpr index, PrimExpr value,
+ *                                   int col = 8, int row = 8);
+ */
+TVM_DLL const Op& make_filled_simdgroup_matrix();
+
+/*!
+ * \brief tvm intrinsic for loading data from device memory or threadgroup memory to a simdgroup.
+ * \note Only the 8x8 shape is supported by the Metal spec and TVM, but we still keep the
+ * shape as parameters to match the Metal spec interface.
+ *
+ * void simdgroup_load(Var d, PrimExpr index, PrimExpr ptr, PrimExpr stride,
+ *                     int col = 8, int row = 8, bool transpose_matrix = false);
+ */
+TVM_DLL const Op& simdgroup_load();
+
+/*!
+ * \brief tvm intrinsic for storing data from a simdgroup to device memory or threadgroup memory.
+ * \note Only the 8x8 shape is supported by the Metal spec and TVM, but we still keep the
+ * shape as parameters to match the Metal spec interface.
+ *
+ * void simdgroup_store(Var d, PrimExpr index, PrimExpr ptr, PrimExpr stride,
+ *                      int col = 8, int row = 8, bool transpose_matrix = false);
+ */
+TVM_DLL const Op& simdgroup_store();
+
+/*!
+ * \brief tvm intrinsic for multiplying and accumulating two matrices in a simdgroup.
+ * \note Only the 8x8 shape is supported by the Metal spec and TVM, but we still keep the
+ * shape as parameters to match the Metal spec interface.
+ *
+ * void simdgroup_mma(Var d, PrimExpr index_d, Var a, PrimExpr index_a,
+ *                    Var b, PrimExpr index_b, Var c, PrimExpr index_c);
+ */
+TVM_DLL const Op& simdgroup_multiply_accumulate();
+
 // TODO(tvm-team) replace the usage of the vector operations by Shuffle.
 /*!
  * \brief Get the high level half of the vector
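Taken together, these four intrinsics cover a full 8x8 fragment pipeline: fill an accumulator, load the A and B tiles, multiply-accumulate, and store the result. Below is a minimal sketch (not part of this commit) of that sequence in TVMScript, using the T.* bindings registered later in this diff; the buffer shapes, the fragment allocation, and the access_ptr-based addressing are illustrative assumptions inferred from the doc comments above, not verified code.

# Hedged sketch: one 8x8 simdgroup GEMM tile via the new intrinsics.
from tvm.script import tir as T

@T.prim_func
def simdgroup_gemm_8x8(a: T.handle, b: T.handle, c: T.handle):
    A = T.match_buffer(a, (8, 8), "float16")
    B = T.match_buffer(b, (8, 8), "float16")
    C = T.match_buffer(c, (8, 8), "float16")
    # One 8x8 fragment each for A, B, and the accumulator (assumed scope name).
    A_frag = T.alloc_buffer((8, 8), "float16", scope="metal.simdgroup")
    B_frag = T.alloc_buffer((8, 8), "float16", scope="metal.simdgroup")
    C_frag = T.alloc_buffer((8, 8), "float16", scope="metal.simdgroup")
    # (d, index, value, col, row): zero-fill the accumulator fragment.
    T.make_filled_simdgroup_matrix(C_frag.data, 0, T.float16(0), 8, 8)
    # (d, index, ptr, stride, col, row, transpose_matrix): load the input tiles.
    T.simdgroup_load(A_frag.data, 0, A.access_ptr("r"), 8, 8, 8, False)
    T.simdgroup_load(B_frag.data, 0, B.access_ptr("r"), 8, 8, 8, False)
    # (d, index_d, a, index_a, b, index_b, c, index_c): C_frag += A_frag * B_frag.
    T.simdgroup_multiply_accumulate(
        C_frag.data, 0, A_frag.data, 0, B_frag.data, 0, C_frag.data, 0
    )
    # Write the accumulator back with the same (ptr, stride) convention.
    T.simdgroup_store(C_frag.data, 0, C.access_ptr("w"), 8, 8, 8, False)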

python/tvm/dlight/gpu/matmul.py

Lines changed: 145 additions & 0 deletions

@@ -313,6 +313,146 @@ def check_sm_version(arch: str) -> int:
     return int(sm_version) if sm_version.isdigit() else -1
 
 
+class MetalMatmul(GPUScheduleRule):
+    """
+    The schedule rule for Metal matmul computation.
+    """
+
+    def apply(  # pylint: disable=too-many-locals,missing-docstring
+        self,
+        func: tir.PrimFunc,
+        target: Target,
+        _: bool,
+    ) -> Optional[tir.Schedule]:
+        from tvm.tir.tensor_intrin.metal import (  # pylint: disable=import-outside-toplevel
+            get_simdgroup_intrin_group,
+        )
+
+        if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target):
+            return None
+        sch = tir.Schedule(func)
+        root_block = analysis.get_root_block(sch)
+        blocks = sch.get_child_blocks(root_block)
+
+        reduction_blocks = get_reduction_blocks(sch, blocks)
+        if reduction_blocks is None:
+            return None
+
+        main_block = reduction_blocks[0]
+        block_stmt = sch.get(main_block)
+        index_maps = get_index_map(block_stmt)
+        if index_maps is None:
+            return None
+        matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
+
+        # Step 0. Configs
+        block_size_x: int = 16
+        block_size_y: int = 16
+        block_size_k: int = 32
+        micro_size: int = 8
+        warp_size: int = 32
+        ty_len: int = 1
+        tz_len: int = 4
+        vector_size: int = 4
+
+        # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]
+        block = sch.reindex(main_block, ("read", 0))
+        sch.transform_layout(block, ("write", 0), a_index_map)
+        block = sch.reindex(main_block, ("read", 1))
+        sch.transform_layout(block, ("write", 0), b_index_map)
+        block = sch.reindex(main_block, ("write", 0))
+        sch.transform_layout(block, ("read", 0), c_index_map)
+        sch.transform_block_layout(main_block, matmul_index_map)
+
+        # Step 2. Padding for dynamic shape kernels
+        sch.pad_einsum(
+            main_block,
+            [
+                1,
+                ty_len * block_size_x,
+                tz_len * block_size_y,
+                block_size_k,
+            ],
+        )
+
+        # Step 3. Schedule matmul to use simdgroup intrinsics
+        batch, i, j, k = sch.get_loops(main_block)
+        bx, ty, i0, i1 = sch.split(i, [None, ty_len, block_size_x // micro_size, micro_size])
+        by, tz, j0, j1 = sch.split(j, [None, tz_len, block_size_y // micro_size, micro_size])
+        k0, k1, k2 = sch.split(k, [None, block_size_k // micro_size, micro_size])
+        sch.reorder(bx, by, ty, tz, k0, k1, i0, j0, i1, j1, k2)
+        sch.bind(bx, "blockIdx.x")
+        sch.bind(by, "blockIdx.y")
+        sch.bind(batch, "blockIdx.z")
+        sch.bind(ty, "threadIdx.y")
+        sch.bind(tz, "threadIdx.z")
+
+        def fetch_to_shared(block, idx):
+            block_read = sch.cache_read(block, idx, "shared")
+            sch.compute_at(block_read, k0, preserve_unit_loops=True)
+            fused = sch.fuse(*sch.get_loops(block_read)[-2:])
+            _, _tz, _ty, _tx, vec = sch.split(fused, [None, tz_len, ty_len, warp_size, vector_size])
+
+            sch.bind(_tz, "threadIdx.z")
+            sch.bind(_ty, "threadIdx.y")
+            sch.bind(_tx, "threadIdx.x")
+            sch.vectorize(vec)
+
+            return block_read
+
+        a_g2s = fetch_to_shared(main_block, 0)
+        b_g2s = fetch_to_shared(main_block, 1)
+
+        auto_inline_producers(sch, a_g2s)
+        auto_inline_producers(sch, b_g2s)
+
+        # create read caches to load matrices from shared memory to simdgroup fragments
+        A_simdgroup = sch.cache_read(main_block, 0, "metal.simdgroup")
+        B_simdgroup = sch.cache_read(main_block, 1, "metal.simdgroup")
+        sch.compute_at(A_simdgroup, k1)
+        sch.compute_at(B_simdgroup, k1)
+
+        C_simd2s = sch.cache_write(main_block, 0, "metal.simdgroup")
+        C_s2g = sch.cache_write(C_simd2s, 0, "shared")
+        sch.reverse_compute_at(C_simd2s, tz, preserve_unit_loops=True)
+        sch.reverse_compute_at(C_s2g, by, preserve_unit_loops=True)
+
+        intrin_group = get_simdgroup_intrin_group(
+            load_scope="shared",
+            store_scope="shared",
+            dtype="float16",
+            trans_a=False,
+            trans_b=True,
+        )
+        sch.transform_layout(B_simdgroup, ("write", 0), lambda s, i, j: (s, j, i))
+
+        def tensorize_block(block: tir.schedule.BlockRV, intrin: str):
+            *_, i, j = sch.get_loops(block)
+            io, ii = sch.split(i, [None, micro_size])
+            jo, ji = sch.split(j, [None, micro_size])
+            sch.reorder(io, jo, ii, ji)
+            sch.tensorize(ii, intrin)
+
+        C_init = sch.decompose_reduction(main_block, k0)
+        tensorize_block(A_simdgroup, intrin_group["load_a"])
+        tensorize_block(B_simdgroup, intrin_group["load_b"])
+        tensorize_block(C_simd2s, intrin_group["store"])
+        tensorize_block(C_init, intrin_group["init"])
+
+        *_, i, j, k = sch.get_loops(main_block)
+        sch.tensorize(i, intrin_group["compute"])
+
+        auto_inline_consumer_chain(sch, C_s2g)
+        fused = sch.fuse(*sch.get_loops(C_s2g)[-2:])
+        _, _tz, _ty, _tx, vec = sch.split(fused, [None, tz_len, ty_len, warp_size, vector_size])
+        sch.bind(_tz, "threadIdx.z")
+        sch.bind(_ty, "threadIdx.y")
+        sch.bind(_tx, "threadIdx.x")
+        sch.vectorize(vec)
+
+        return sch
+
+
 class MatmulTensorization(GPUScheduleRule):
     """
     The schedule rule for float16 tensor core matmul computation.

@@ -848,6 +988,11 @@ def apply(  # pylint: disable=too-many-locals,missing-docstring
         tensorize_sch = MatmulTensorization().apply(func, target, _)
        if tensorize_sch is not None:
             return tensorize_sch
+        elif target.kind.name == "metal":
+            try:
+                return MetalMatmul().apply(func, target, _)
+            except:  # pylint: disable=bare-except
+                pass
 
         # Step 2. Get schedule config.
         config = self.get_configs(target)
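With the dispatch above in place, a metal target routes plain matmuls through MetalMatmul before falling back to the generic config path. Below is a hedged end-to-end sketch using dlight's existing default-schedule pass; the 256x256x256 float16 workload is an illustrative assumption, not a test from this commit.

# Hedged sketch: scheduling a float16 matmul with dlight on a Metal target.
import tvm
from tvm import dlight as dl
from tvm.script import tir as T

@T.prim_func
def matmul(A: T.Buffer((256, 256), "float16"),
           B: T.Buffer((256, 256), "float16"),
           C: T.Buffer((256, 256), "float16")):
    for i, j, k in T.grid(256, 256, 256):
        with T.block("C"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = T.float16(0)
            C[vi, vj] += A[vi, vk] * B[vk, vj]

mod = tvm.IRModule({"main": matmul})
with tvm.target.Target("metal"):
    # Matmul() tries MetalMatmul for metal targets, per the dispatch above.
    mod = dl.ApplyDefaultSchedule(dl.gpu.Matmul())(mod)
print(mod.script())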

python/tvm/script/ir_builder/tir/ir.py

Lines changed: 8 additions & 0 deletions

@@ -1887,6 +1887,10 @@ def wrapped(*args, **kwargs):
 ptx_arrive_barrier = _op_wrapper(_tir_op.ptx_arrive_barrier)
 ptx_arrive_barrier_expect_tx = _op_wrapper(_tir_op.ptx_arrive_barrier_expect_tx)
 ptx_wait_barrier = _op_wrapper(_tir_op.ptx_wait_barrier)
+make_filled_simdgroup_matrix = _op_wrapper(_tir_op.make_filled_simdgroup_matrix)
+simdgroup_load = _op_wrapper(_tir_op.simdgroup_load)
+simdgroup_store = _op_wrapper(_tir_op.simdgroup_store)
+simdgroup_multiply_accumulate = _op_wrapper(_tir_op.simdgroup_multiply_accumulate)
 create_barriers = _op_wrapper(_tir_op.create_barriers)
 assume = _op_wrapper(_tir_op.assume)
 undef = _op_wrapper(_tir_op.undef)

@@ -2177,6 +2181,10 @@ def wrapped(*args, **kwargs):
     "ptx_arrive_barrier",
     "ptx_arrive_barrier_expect_tx",
     "ptx_wait_barrier",
+    "make_filled_simdgroup_matrix",
+    "simdgroup_load",
+    "simdgroup_store",
+    "simdgroup_multiply_accumulate",
     "create_barriers",
     "mma_store",
     "mma_fill",

python/tvm/tir/__init__.py

Lines changed: 6 additions & 0 deletions

@@ -73,6 +73,12 @@
     ptx_wait_barrier,
     create_barriers,
 )
+from .op import (
+    make_filled_simdgroup_matrix,
+    simdgroup_load,
+    simdgroup_multiply_accumulate,
+    simdgroup_store,
+)
 from .op import vectorlow, vectorhigh, vectorcombine
 from .op import infinity, reinterpret
 from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz
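These re-exports make the wrappers callable straight from tvm.tir, where each builds a tir.Call to the corresponding builtin. A small hedged sketch follows; the fragment buffer declaration is an illustrative assumption, and the argument order follows the builtin.h doc comments above.

# Hedged sketch: building one of the new intrinsic calls as a TIR expression.
import tvm
from tvm import tir

# Declare an 8x8 fragment in the new "metal.simdgroup" scope (illustrative).
frag = tir.decl_buffer((8, 8), "float16", name="frag", scope="metal.simdgroup")
# (d, index, value, col, row): the wrapper returns a PrimExpr Call node.
call = tir.make_filled_simdgroup_matrix(frag.data, 0, tir.const(0, "float16"), 8, 8)
print(call)  # a tir.Call to the "tir.make_filled_simdgroup_matrix" builtin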
