@@ -364,13 +364,6 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
         if reduction_blocks is None:
             return None
 
-        main_block = reduction_blocks[0]
-        block_stmt = sch.get(main_block)
-        index_maps = get_index_map(block_stmt)
-        if index_maps is None:
-            return None
-        matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
-
         # Step 0. Configs
         block_size_x: int = 16
         block_size_y: int = 16
@@ -382,12 +375,19 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
         vector_size: int = 4
 
         # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]
-        block = sch.reindex(main_block, ("read", 0))
-        sch.transform_layout(block, ("write", 0), a_index_map)
-        block = sch.reindex(main_block, ("read", 1))
-        sch.transform_layout(block, ("write", 0), b_index_map)
-        block = sch.reindex(main_block, ("write", 0))
-        sch.transform_layout(block, ("read", 0), c_index_map)
+        # Reindex first and then analyze the index map
+        main_block = reduction_blocks[0]
+        reindex_a = sch.reindex(main_block, ("read", 0))
+        reindex_b = sch.reindex(main_block, ("read", 1))
+        reindex_c = sch.reindex(main_block, ("write", 0))
+
+        index_maps = get_index_map(sch.get(main_block))
+        assert index_maps is not None
+        matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
+
+        sch.transform_layout(reindex_a, ("write", 0), a_index_map)
+        sch.transform_layout(reindex_b, ("write", 0), b_index_map)
+        sch.transform_layout(reindex_c, ("read", 0), c_index_map)
         sch.transform_block_layout(main_block, matmul_index_map)
 
         # Step 2. Padding for dynamic shape kernels
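For context on why the order matters, here is a minimal, self-contained sketch of the pattern this change enforces: reindex the reads and the write first, then derive the index maps from the block in its rewritten form. The toy matmul PrimFunc and the 128x128 shapes below are illustrative only and are not part of the patch, and the import path assumes get_index_map is the analysis helper defined earlier in this file (python/tvm/dlight/gpu/matmul.py):

import tvm
from tvm.script import tir as T

# Assumed import path for the helper used throughout this diff.
from tvm.dlight.gpu.matmul import get_index_map


@T.prim_func
def matmul(a: T.handle, b: T.handle, c: T.handle):
    A = T.match_buffer(a, (128, 128), "float32")
    B = T.match_buffer(b, (128, 128), "float32")
    C = T.match_buffer(c, (128, 128), "float32")
    for i, j, k in T.grid(128, 128, 128):
        with T.block("C"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = T.float32(0)
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]


sch = tvm.tir.Schedule(matmul)
main_block = sch.get_block("C")

# Reindex first: each call inserts a staging block and rewrites the main
# block's buffer accesses, so index maps computed before this point would
# describe a block that no longer exists in that form.
reindex_a = sch.reindex(main_block, ("read", 0))
reindex_b = sch.reindex(main_block, ("read", 1))
reindex_c = sch.reindex(main_block, ("write", 0))

# Analyze the block only in its post-reindex form.
index_maps = get_index_map(sch.get(main_block))
assert index_maps is not None
matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps

sch.transform_layout(reindex_a, ("write", 0), a_index_map)
sch.transform_layout(reindex_b, ("write", 0), b_index_map)
sch.transform_layout(reindex_c, ("read", 0), c_index_map)
sch.transform_block_layout(main_block, matmul_index_map)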
@@ -508,13 +508,6 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
         if reduction_blocks is None:
             return None
 
-        main_block = reduction_blocks[0]
-        block_stmt = sch.get(main_block)
-        index_maps = get_index_map(block_stmt)
-        if index_maps is None:
-            return None
-        matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
-
         # Start Schedule
         # Step 0. Get schedule config.
         # NOTE: we can analyze the config by the hardware spec in the future
@@ -539,12 +532,19 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
         k_pad_factor = k_factors[1]
 
         # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]
-        block = sch.reindex(main_block, ("read", 0))
-        sch.transform_layout(block, ("write", 0), a_index_map)
-        block = sch.reindex(main_block, ("read", 1))
-        sch.transform_layout(block, ("write", 0), b_index_map)
-        block = sch.reindex(main_block, ("write", 0))
-        sch.transform_layout(block, ("read", 0), c_index_map)
+        # Reindex first and then analyze the index map
+        main_block = reduction_blocks[0]
+        reindex_a = sch.reindex(main_block, ("read", 0))
+        reindex_b = sch.reindex(main_block, ("read", 1))
+        reindex_c = sch.reindex(main_block, ("write", 0))
+
+        index_maps = get_index_map(sch.get(main_block))
+        assert index_maps is not None
+        matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
+
+        sch.transform_layout(reindex_a, ("write", 0), a_index_map)
+        sch.transform_layout(reindex_b, ("write", 0), b_index_map)
+        sch.transform_layout(reindex_c, ("read", 0), c_index_map)
         sch.transform_block_layout(main_block, matmul_index_map)
 
         # Step 2. Padding for dynamic shape kernels
@@ -729,13 +729,6 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
         if reduction_blocks is None:
             return None
 
-        main_block = reduction_blocks[0]
-        block_stmt = sch.get(main_block)
-        index_maps = get_index_map(block_stmt)
-        if index_maps is None:
-            return None
-        matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
-
         # Start Schedule
         # Step 0. Get schedule config.
         # NOTE: we can analyze the config by the hardware spec in the future
@@ -760,12 +753,19 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
         k_pad_factor = k_factors[1]
 
         # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]
-        block = sch.reindex(main_block, ("read", 0))
-        sch.transform_layout(block, ("write", 0), a_index_map)
-        block = sch.reindex(main_block, ("read", 1))
-        sch.transform_layout(block, ("write", 0), b_index_map)
-        block = sch.reindex(main_block, ("write", 0))
-        sch.transform_layout(block, ("read", 0), c_index_map)
+        # Reindex first and then analyze the index map
+        main_block = reduction_blocks[0]
+        reindex_a = sch.reindex(main_block, ("read", 0))
+        reindex_b = sch.reindex(main_block, ("read", 1))
+        reindex_c = sch.reindex(main_block, ("write", 0))
+
+        index_maps = get_index_map(sch.get(main_block))
+        assert index_maps is not None
+        matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
+
+        sch.transform_layout(reindex_a, ("write", 0), a_index_map)
+        sch.transform_layout(reindex_b, ("write", 0), b_index_map)
+        sch.transform_layout(reindex_c, ("read", 0), c_index_map)
         sch.transform_block_layout(main_block, matmul_index_map)
 
         # Step 2. Padding for dynamic shape kernels
@@ -979,12 +979,11 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
 
         main_block = reduction_blocks[0]
         block_stmt = sch.get(main_block)
-        index_maps = get_index_map(block_stmt)
-        if index_maps is None:
-            return None
 
         main_block_info = get_block_info(sch, main_block)
         iter_infos = main_block_info.iters
+        if not get_index_map(block_stmt):
+            return None
 
         # Checks if it's an inner reduction by getting the last matrix's inner index
         def is_inner_reduction(block_stmt, iter_infos):
@@ -1000,13 +999,18 @@ def is_inner_reduction(block_stmt, iter_infos):
             return ret
 
         # Step 0. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]
+        # Reindex first and then analyze the index map
+        reindex_a = sch.reindex(main_block, ("read", 0))
+        reindex_b = sch.reindex(main_block, ("read", 1))
+        reindex_c = sch.reindex(main_block, ("write", 0))
+
+        index_maps = get_index_map(sch.get(main_block))
+        assert index_maps is not None
         matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
-        block = sch.reindex(main_block, ("read", 0))
-        sch.transform_layout(block, ("write", 0), a_index_map)
-        block = sch.reindex(main_block, ("read", 1))
-        sch.transform_layout(block, ("write", 0), b_index_map)
-        block = sch.reindex(main_block, ("write", 0))
-        sch.transform_layout(block, ("read", 0), c_index_map)
+
+        sch.transform_layout(reindex_a, ("write", 0), a_index_map)
+        sch.transform_layout(reindex_b, ("write", 0), b_index_map)
+        sch.transform_layout(reindex_c, ("read", 0), c_index_map)
         sch.transform_block_layout(main_block, matmul_index_map)
 
         # Step 1. Check Tensor Core support
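A nuance in the last hunk: the pre-reindex get_index_map(block_stmt) call survives only as an early feasibility probe, while the maps actually applied are recomputed after reindexing. The assert (rather than another return None) presumably reflects that by that point the reindex calls have already mutated the schedule, so bailing out cleanly is no longer an option. Condensed, and with names as in the patch, the control flow reads:

        if not get_index_map(block_stmt):  # probe the original block; nothing mutated yet
            return None                    # rule simply does not apply to this block
        # ... reindex reads/writes, mutating the schedule ...
        index_maps = get_index_map(sch.get(main_block))
        assert index_maps is not None      # probe passed, so analysis is expected to succeed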