@@ -364,13 +364,6 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
364364 if reduction_blocks is None :
365365 return None
366366
367- main_block = reduction_blocks [0 ]
368- block_stmt = sch .get (main_block )
369- index_maps = get_index_map (block_stmt )
370- if index_maps is None :
371- return None
372- matmul_index_map , a_index_map , b_index_map , c_index_map = index_maps
373-
374367 # Step 0. Configs
375368 block_size_x : int = 16
376369 block_size_y : int = 16
@@ -382,12 +375,20 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
382375 vector_size : int = 4
383376
384377 # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]
385- block = sch .reindex (main_block , ("read" , 0 ))
386- sch .transform_layout (block , ("write" , 0 ), a_index_map )
387- block = sch .reindex (main_block , ("read" , 1 ))
388- sch .transform_layout (block , ("write" , 0 ), b_index_map )
389- block = sch .reindex (main_block , ("write" , 0 ))
390- sch .transform_layout (block , ("read" , 0 ), c_index_map )
378+ # Reindex first and then analyze the index map
379+ main_block = reduction_blocks [0 ]
380+ reindex_a = sch .reindex (main_block , ("read" , 0 ))
381+ reindex_b = sch .reindex (main_block , ("read" , 1 ))
382+ reindex_c = sch .reindex (main_block , ("write" , 0 ))
383+
384+ index_maps = get_index_map (sch .get (main_block ))
385+ if index_maps is None :
386+ return None
387+ matmul_index_map , a_index_map , b_index_map , c_index_map = index_maps
388+
389+ sch .transform_layout (reindex_a , ("write" , 0 ), a_index_map )
390+ sch .transform_layout (reindex_b , ("write" , 0 ), b_index_map )
391+ sch .transform_layout (reindex_c , ("read" , 0 ), c_index_map )
391392 sch .transform_block_layout (main_block , matmul_index_map )
392393
393394 # Step 2. Padding for dynamic shape kernels
@@ -508,13 +509,6 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
508509 if reduction_blocks is None :
509510 return None
510511
511- main_block = reduction_blocks [0 ]
512- block_stmt = sch .get (main_block )
513- index_maps = get_index_map (block_stmt )
514- if index_maps is None :
515- return None
516- matmul_index_map , a_index_map , b_index_map , c_index_map = index_maps
517-
518512 # Start Schedule
519513 # Step 0. Get schedule config.
520514 # NOTE: we can analyze the config by the hardware spec in the future
@@ -539,12 +533,20 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
539533 k_pad_factor = k_factors [1 ]
540534
541535 # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]
542- block = sch .reindex (main_block , ("read" , 0 ))
543- sch .transform_layout (block , ("write" , 0 ), a_index_map )
544- block = sch .reindex (main_block , ("read" , 1 ))
545- sch .transform_layout (block , ("write" , 0 ), b_index_map )
546- block = sch .reindex (main_block , ("write" , 0 ))
547- sch .transform_layout (block , ("read" , 0 ), c_index_map )
536+ # Reindex first and then analyze the index map
537+ main_block = reduction_blocks [0 ]
538+ reindex_a = sch .reindex (main_block , ("read" , 0 ))
539+ reindex_b = sch .reindex (main_block , ("read" , 1 ))
540+ reindex_c = sch .reindex (main_block , ("write" , 0 ))
541+
542+ index_maps = get_index_map (sch .get (main_block ))
543+ if index_maps is None :
544+ return None
545+ matmul_index_map , a_index_map , b_index_map , c_index_map = index_maps
546+
547+ sch .transform_layout (reindex_a , ("write" , 0 ), a_index_map )
548+ sch .transform_layout (reindex_b , ("write" , 0 ), b_index_map )
549+ sch .transform_layout (reindex_c , ("read" , 0 ), c_index_map )
548550 sch .transform_block_layout (main_block , matmul_index_map )
549551
550552 # Step 2. Padding for dynamic shape kernels
@@ -729,13 +731,6 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
729731 if reduction_blocks is None :
730732 return None
731733
732- main_block = reduction_blocks [0 ]
733- block_stmt = sch .get (main_block )
734- index_maps = get_index_map (block_stmt )
735- if index_maps is None :
736- return None
737- matmul_index_map , a_index_map , b_index_map , c_index_map = index_maps
738-
739734 # Start Schedule
740735 # Step 0. Get schedule config.
741736 # NOTE: we can analyze the config by the hardware spec in the future
@@ -760,12 +755,20 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
760755 k_pad_factor = k_factors [1 ]
761756
762757 # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]
763- block = sch .reindex (main_block , ("read" , 0 ))
764- sch .transform_layout (block , ("write" , 0 ), a_index_map )
765- block = sch .reindex (main_block , ("read" , 1 ))
766- sch .transform_layout (block , ("write" , 0 ), b_index_map )
767- block = sch .reindex (main_block , ("write" , 0 ))
768- sch .transform_layout (block , ("read" , 0 ), c_index_map )
758+ # Reindex first and then analyze the index map
759+ main_block = reduction_blocks [0 ]
760+ reindex_a = sch .reindex (main_block , ("read" , 0 ))
761+ reindex_b = sch .reindex (main_block , ("read" , 1 ))
762+ reindex_c = sch .reindex (main_block , ("write" , 0 ))
763+
764+ index_maps = get_index_map (sch .get (main_block ))
765+ if index_maps is None :
766+ return None
767+ matmul_index_map , a_index_map , b_index_map , c_index_map = index_maps
768+
769+ sch .transform_layout (reindex_a , ("write" , 0 ), a_index_map )
770+ sch .transform_layout (reindex_b , ("write" , 0 ), b_index_map )
771+ sch .transform_layout (reindex_c , ("read" , 0 ), c_index_map )
769772 sch .transform_block_layout (main_block , matmul_index_map )
770773
771774 # Step 2. Padding for dynamic shape kernels
@@ -979,9 +982,6 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
979982
980983 main_block = reduction_blocks [0 ]
981984 block_stmt = sch .get (main_block )
982- index_maps = get_index_map (block_stmt )
983- if index_maps is None :
984- return None
985985
986986 main_block_info = get_block_info (sch , main_block )
987987 iter_infos = main_block_info .iters
@@ -1000,13 +1000,19 @@ def is_inner_reduction(block_stmt, iter_infos):
10001000 return ret
10011001
10021002 # Step 0. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]
1003+ # Reindex first and then analyze the index map
1004+ reindex_a = sch .reindex (main_block , ("read" , 0 ))
1005+ reindex_b = sch .reindex (main_block , ("read" , 1 ))
1006+ reindex_c = sch .reindex (main_block , ("write" , 0 ))
1007+
1008+ index_maps = get_index_map (sch .get (main_block ))
1009+ if index_maps is None :
1010+ return None
10031011 matmul_index_map , a_index_map , b_index_map , c_index_map = index_maps
1004- block = sch .reindex (main_block , ("read" , 0 ))
1005- sch .transform_layout (block , ("write" , 0 ), a_index_map )
1006- block = sch .reindex (main_block , ("read" , 1 ))
1007- sch .transform_layout (block , ("write" , 0 ), b_index_map )
1008- block = sch .reindex (main_block , ("write" , 0 ))
1009- sch .transform_layout (block , ("read" , 0 ), c_index_map )
1012+
1013+ sch .transform_layout (reindex_a , ("write" , 0 ), a_index_map )
1014+ sch .transform_layout (reindex_b , ("write" , 0 ), b_index_map )
1015+ sch .transform_layout (reindex_c , ("read" , 0 ), c_index_map )
10101016 sch .transform_block_layout (main_block , matmul_index_map )
10111017
10121018 # Step 1. Check Tensor Core support
0 commit comments