@@ -313,6 +313,146 @@ def check_sm_version(arch: str) -> int:
     return int(sm_version) if sm_version.isdigit() else -1
 
 
+class MetalMatmul(GPUScheduleRule):
+    """
+    The schedule rule for Metal matmul computation.
+    """
+
+    def apply(  # pylint: disable=too-many-locals,missing-docstring
+        self,
+        func: tir.PrimFunc,
+        target: Target,
+        _: bool,
+    ) -> Optional[tir.Schedule]:
+        from tvm.tir.tensor_intrin.metal import (  # pylint: disable=import-outside-toplevel
+            get_simdgroup_intrin_group,
+        )
+
+        if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target):
+            return None
+        sch = tir.Schedule(func)
+        root_block = analysis.get_root_block(sch)
+        blocks = sch.get_child_blocks(root_block)
+
+        reduction_blocks = get_reduction_blocks(sch, blocks)
+        if reduction_blocks is None:
+            return None
+
+        main_block = reduction_blocks[0]
+        block_stmt = sch.get(main_block)
+        index_maps = get_index_map(block_stmt)
+        if index_maps is None:
+            return None
+        matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
+
+        # Step 0. Configs
+        block_size_x: int = 16
+        block_size_y: int = 16
+        block_size_k: int = 32
+        micro_size: int = 8
+        warp_size: int = 32
+        ty_len: int = 1
+        tz_len: int = 4
+        vector_size: int = 4
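+        # With these tile sizes each threadblock computes a
+        # (ty_len * block_size_x) x (tz_len * block_size_y) = 16 x 64 tile of C,
+        # stepping the reduction dimension in chunks of block_size_k = 32;
+        # micro_size = 8 matches the 8x8x8 Metal simdgroup fragment shape.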
+
+        # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]
+        block = sch.reindex(main_block, ("read", 0))
+        sch.transform_layout(block, ("write", 0), a_index_map)
+        block = sch.reindex(main_block, ("read", 1))
+        sch.transform_layout(block, ("write", 0), b_index_map)
+        block = sch.reindex(main_block, ("write", 0))
+        sch.transform_layout(block, ("read", 0), c_index_map)
+        sch.transform_block_layout(main_block, matmul_index_map)
+
+        # Step 2. Padding for dynamic shape kernels
+        sch.pad_einsum(
+            main_block,
+            [
+                1,
+                ty_len * block_size_x,
+                tz_len * block_size_y,
+                block_size_k,
+            ],
+        )
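+        # pad_einsum pads the S/I/J/K extents up to multiples of the tile sizes
+        # above, so the splits and tensorization below stay valid for dynamic or
+        # odd-shaped workloads.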
+
+        # Step 3. Schedule matmul to use simdgroup intrinsics
+        batch, i, j, k = sch.get_loops(main_block)
+        bx, ty, i0, i1 = sch.split(i, [None, ty_len, block_size_x // micro_size, micro_size])
+        by, tz, j0, j1 = sch.split(j, [None, tz_len, block_size_y // micro_size, micro_size])
+        k0, k1, k2 = sch.split(k, [None, block_size_k // micro_size, micro_size])
+        sch.reorder(bx, by, ty, tz, k0, k1, i0, j0, i1, j1, k2)
+        sch.bind(bx, "blockIdx.x")
+        sch.bind(by, "blockIdx.y")
+        sch.bind(batch, "blockIdx.z")
+        sch.bind(ty, "threadIdx.y")
+        sch.bind(tz, "threadIdx.z")
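+        # blockIdx.x/y tile the output, blockIdx.z walks the batch dimension, and
+        # threadIdx.y/z enumerate the ty_len x tz_len simdgroups within a threadblock.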
+
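+        # Cooperatively stage one operand tile from global to shared memory; every
+        # thread in the block participates and issues vector_size-wide vector loads.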
+        def fetch_to_shared(block, idx):
+            block_read = sch.cache_read(block, idx, "shared")
+            sch.compute_at(block_read, k0, preserve_unit_loops=True)
+            fused = sch.fuse(*sch.get_loops(block_read)[-2:])
+            _, _tz, _ty, _tx, vec = sch.split(fused, [None, tz_len, ty_len, warp_size, vector_size])
+
+            sch.bind(_tz, "threadIdx.z")
+            sch.bind(_ty, "threadIdx.y")
+            sch.bind(_tx, "threadIdx.x")
+            sch.vectorize(vec)
+
+            return block_read
+
+        a_g2s = fetch_to_shared(main_block, 0)
+        b_g2s = fetch_to_shared(main_block, 1)
+
+        auto_inline_producers(sch, a_g2s)
+        auto_inline_producers(sch, b_g2s)
+
+        # Create read caches to load operand tiles from shared memory into simdgroup fragments.
+        A_simdgroup = sch.cache_read(main_block, 0, "metal.simdgroup")
+        B_simdgroup = sch.cache_read(main_block, 1, "metal.simdgroup")
+        sch.compute_at(A_simdgroup, k1)
+        sch.compute_at(B_simdgroup, k1)
+
+        C_simd2s = sch.cache_write(main_block, 0, "metal.simdgroup")
+        C_s2g = sch.cache_write(C_simd2s, 0, "shared")
+        sch.reverse_compute_at(C_simd2s, tz, preserve_unit_loops=True)
+        sch.reverse_compute_at(C_s2g, by, preserve_unit_loops=True)
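+        # C is accumulated in simdgroup fragments, copied per-tile to shared memory
+        # (C_simd2s), and then written back to global memory (C_s2g) in the epilogue.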
+
+        intrin_group = get_simdgroup_intrin_group(
+            load_scope="shared",
+            store_scope="shared",
+            dtype="float16",
+            trans_a=False,
+            trans_b=True,
+        )
+        sch.transform_layout(B_simdgroup, ("write", 0), lambda s, i, j: (s, j, i))
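+        # After normalization B is stored as [S, J, K], i.e. transposed relative to a
+        # plain [K, J] operand, hence trans_b=True and the swapped fragment layout above.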
+
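+        # Split the trailing two spatial loops of a block into 8x8 micro-tiles and
+        # tensorize the inner pair with the given simdgroup intrinsic.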
+        def tensorize_block(block: tir.schedule.BlockRV, intrin: str):
+            *_, i, j = sch.get_loops(block)
+            io, ii = sch.split(i, [None, micro_size])
+            jo, ji = sch.split(j, [None, micro_size])
+            sch.reorder(io, jo, ii, ji)
+            sch.tensorize(ii, intrin)
+
+        C_init = sch.decompose_reduction(main_block, k0)
+        tensorize_block(A_simdgroup, intrin_group["load_a"])
+        tensorize_block(B_simdgroup, intrin_group["load_b"])
+        tensorize_block(C_simd2s, intrin_group["store"])
+        tensorize_block(C_init, intrin_group["init"])
+
+        *_, i, j, k = sch.get_loops(main_block)
+        sch.tensorize(i, intrin_group["compute"])
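+        # The innermost 8x8x8 i/j/k loops now map onto the simdgroup
+        # multiply-accumulate compute intrinsic.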
+
+        auto_inline_consumer_chain(sch, C_s2g)
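+        # Epilogue: cooperatively copy the finished C tile from shared memory back to
+        # global memory with vectorized stores, mirroring fetch_to_shared.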
+        fused = sch.fuse(*sch.get_loops(C_s2g)[-2:])
+        _, _tz, _ty, _tx, vec = sch.split(fused, [None, tz_len, ty_len, warp_size, vector_size])
+        sch.bind(_tz, "threadIdx.z")
+        sch.bind(_ty, "threadIdx.y")
+        sch.bind(_tx, "threadIdx.x")
+        sch.vectorize(vec)
+
+        return sch
+
+
 class MatmulTensorization(GPUScheduleRule):
     """
     The schedule rule for float16 tensor core matmul computation.
@@ -848,6 +988,11 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
                     tensorize_sch = MatmulTensorization().apply(func, target, _)
                 if tensorize_sch is not None:
                     return tensorize_sch
+        elif target.kind.name == "metal":
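+            # Metal simdgroup tensorization may fail (e.g. unsupported dtype or shape);
+            # in that case fall through to the generic schedule below.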
+            try:
+                return MetalMatmul().apply(func, target, _)
+            except:  # pylint: disable=bare-except
+                pass
 
         # Step 2. Get schedule config.
         config = self.get_configs(target)