2727from tvm .tir .analysis import undefined_vars
2828from tvm .tir .schedule .schedule import BlockRV
2929
30- from ..base import analysis
30+ from ..base import analysis , BlockInfo , IterInfo
3131from .base import GPUScheduleRule
3232
3333
@@ -273,6 +273,32 @@ def get_index_map(block: tir.Block) -> Optional[Tuple[tir.IndexMap, ...]]:
273273 )
274274
275275
def get_block_info(sch: tir.Schedule, block: tir.schedule.BlockRV) -> BlockInfo:
    """Collect iteration-domain metadata for ``block`` into a ``BlockInfo``.

    Parameters
    ----------
    sch : tir.Schedule
        The schedule that owns ``block``.
    block : tir.schedule.BlockRV
        The block whose iterator information is gathered.

    Returns
    -------
    BlockInfo
        The block name, one ``IterInfo`` per iterator (kind, var, extent,
        and the corresponding loop RV), the block RV itself, and whether
        the block contains a reduction ("R") iterator.
    """

    def _iter_kind(iter_var: tir.IterVar) -> str:
        # "S": spatial (data-parallel), "R": reduction, "O": anything else.
        return {
            tir.IterVar.DataPar: "S",
            tir.IterVar.CommReduce: "R",
        }.get(iter_var.iter_type, "O")

    # Hoist the repeated sch.get(block) lookups: the underlying block
    # statement is the same for every access below, so fetch it once.
    block_stmt = sch.get(block)
    iter_vars = block_stmt.iter_vars
    return BlockInfo(
        name=block_stmt.name_hint,
        iters=[
            IterInfo(
                kind=_iter_kind(iter_var),
                var=iter_var.var,
                dom=iter_var.dom.extent,
                loop_rv=loop_rv,
            )
            # sch.get_loops returns loops in the same order as iter_vars
            # for a normalized block, pairing each loop RV with its iter.
            for loop_rv, iter_var in zip(sch.get_loops(block), iter_vars)
        ],
        block_rv=block,
        reduction_block=any(_iter_kind(iv) == "R" for iv in iter_vars),
    )
300+
301+
276302def get_reduction_blocks (sch , blocks ) -> bool :
277303 # Get the main computation block
278304 def is_reduction (block : BlockRV ) -> bool :
@@ -914,17 +940,19 @@ def get_configs(self, target: Target) -> Config:
914940 storage_align = True ,
915941 inner_x = False ,
916942 )
917- elif target .kind .name == "opencl" and "android" in str (target .host ):
943+ elif target .kind .name == "opencl" and (
944+ ("android" in str (target .host )) or ("windows" in str (target .host ))
945+ ):
918946 return Matmul .Config (
919- block_size_x = 8 ,
920- block_size_y = 16 ,
947+ block_size_x = 32 ,
948+ block_size_y = 8 ,
921949 vthread_x = 1 ,
922950 vthread_y = 1 ,
923951 micro_size_x = 8 ,
924952 micro_size_y = 2 ,
925953 micro_size_k = 16 ,
926954 vector_size = 8 ,
927- unroll = 64 ,
955+ unroll = 4 ,
928956 use_shared = False ,
929957 storage_align = False ,
930958 inner_x = True ,
@@ -941,6 +969,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
941969 if not isinstance (func , tir .PrimFunc ) or not self .is_target_available (target ):
942970 return None
943971 sch = tir .Schedule (func )
972+ config = self .get_configs (target )
944973 root_block = analysis .get_root_block (sch )
945974 blocks = sch .get_child_blocks (root_block )
946975
@@ -953,9 +982,22 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
953982 index_maps = get_index_map (block_stmt )
954983 if index_maps is None :
955984 return None
956- matmul_index_map , a_index_map , b_index_map , c_index_map = index_maps
985+
986+ main_block_info = get_block_info (sch , main_block )
987+ iter_infos = main_block_info .iters
988+
989+ # Checks if it's a inner reduction by getting the last matrix's inner Index
990+ def is_inner_reduction (block_stmt , iter_infos ):
991+ end_it = block_stmt .reads [- 1 ].region [- 1 ].min
992+ return {it .var : it .kind for it in iter_infos }.get (end_it , "O" ) == "R"
993+
994+ if target .kind .name == "opencl" and not is_inner_reduction (block_stmt , iter_infos ):
995+ ret = self .sch_outer_reduction (sch , config , main_block , blocks )
996+ if ret is not None :
997+ return ret
957998
958999 # Step 0. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]
1000+ matmul_index_map , a_index_map , b_index_map , c_index_map = index_maps
9591001 block = sch .reindex (main_block , ("read" , 0 ))
9601002 sch .transform_layout (block , ("write" , 0 ), a_index_map )
9611003 block = sch .reindex (main_block , ("read" , 1 ))
@@ -994,10 +1036,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
9941036 except : # pylint: disable=bare-except
9951037 pass
9961038
997- # Step 2. Get schedule config.
998- config = self .get_configs (target )
999-
1000- # Step 3. Schedule matmul
1039+ # Step 2. Schedule matmul
10011040 y_kernel_size = config .vthread_y * config .block_size_y * config .micro_size_y
10021041 x_kernel_size = config .vthread_x * config .block_size_x * config .micro_size_x
10031042 if config .inner_x :
@@ -1075,3 +1114,88 @@ def _cooperative_fetch(index, vec_len):
10751114
10761115 sch .decompose_reduction (main_block , ko )
10771116 return sch
1117+
    def sch_outer_reduction(
        self,
        sch: tir.Schedule,
        config: Config,
        reduction_block: tir.schedule.BlockRV,
        blocks: List[tir.schedule.BlockRV],
    ) -> Optional[tir.Schedule]:
        """Apply an outer-reduction matmul schedule to ``reduction_block``.

        Expects exactly four loops ``(mb, ms, n, k)`` where ``mb`` and ``n``
        have static (IntImm) extents and ``ms`` is dynamic (Var). Returns
        ``None`` when the loop pattern does not match, so the caller can
        fall back to the generic matmul schedule.

        Parameters
        ----------
        sch : tir.Schedule
            Schedule being mutated in place.
        config : Config
            Tiling configuration (thread block sizes, vector width, unroll).
        reduction_block : tir.schedule.BlockRV
            The main matmul (reduction) block.
        blocks : List[tir.schedule.BlockRV]
            All child blocks of the root; either ``[matmul]`` or the
            dequantize pipeline ``[compute, dequant, matmul]``.

        Returns
        -------
        Optional[tir.Schedule]
            The scheduled ``sch`` on success, otherwise ``None``.
        """
        reduction_loops = sch.get_loops(reduction_block)
        # Only the 4-loop (mb, ms, n, k) pattern is handled here.
        if not len(reduction_loops) == 4:
            return None

        mb, ms, n, k = reduction_loops
        # Require static batch/N extents and a dynamic (symbolic) ms extent;
        # any other shape falls back to the generic schedule.
        if not (
            isinstance(sch.get(n).extent, tir.IntImm)
            and isinstance(sch.get(mb).extent, tir.IntImm)
            and isinstance(sch.get(ms).extent, tir.Var)
        ):
            return None

        # Tiling factors taken from the target-specific config.
        Threads_X, Threads_Y, VecSize, Unroll_M = (
            config.block_size_x,
            config.block_size_y,
            config.vector_size,
            config.unroll,
        )

        # With a dequantize pipeline the children are
        # [compute, dequant, matmul]; otherwise just [matmul].
        is_dequant_block = len(blocks) > 1
        if is_dequant_block:
            compute_block, dequant_block, matmul_block = blocks
            sch.compute_inline(compute_block)
        else:
            (matmul_block,) = blocks

        # Fuse batch and spatial M into a single M axis.
        m = sch.fuse(mb, ms)

        # Pad so the M and N extents divide evenly by the tile sizes.
        sch.pad_einsum(matmul_block, [1, Threads_Y * Unroll_M, Threads_X * VecSize, 1])

        # Padding introduces a producer (read-side) and consumer
        # (write-side) block around the matmul.
        rmat_block, wmat_block = (
            sch.get_producers(matmul_block)[0],
            sch.get_consumers(matmul_block)[0],
        )
        # Split M/N into (outer, thread, inner) and K into a 4-level tile.
        mo, mi, mu = sch.split(m, [None, Threads_Y, Unroll_M])
        no, ni, nv = sch.split(n, [None, Threads_X, VecSize])
        k0, k1, k2, k3 = sch.split(k, [None, (Threads_X * VecSize) // 32, 4, 8])
        sch.reorder(no, mo, ni, mi, k0, k1, k2, k3, mu, nv)

        # Stage the padded read at k0 in shared memory; keep the matmul
        # accumulator (and dequant, if present) in local scope.
        sch.compute_at(rmat_block, k0)
        if is_dequant_block:
            sch.compute_at(dequant_block, k3)
        sch.reverse_compute_at(wmat_block, mi)
        sch.set_scope(rmat_block, 0, "shared")
        sch.set_scope(matmul_block, 0, "local")
        if is_dequant_block:
            sch.set_scope(dequant_block, 0, "local")

        # Bind the outer tiles to thread blocks and inner tiles to threads.
        sch.bind(mo, "blockIdx.y")
        sch.bind(no, "blockIdx.x")
        sch.bind(mi, "threadIdx.y")
        sch.bind(ni, "threadIdx.x")
        # Vectorize the innermost (nv-width) loops.
        sch.vectorize(sch.get_loops(matmul_block)[-1])
        if is_dequant_block:
            sch.vectorize(sch.get_loops(dequant_block)[-1])

        # Co-operative Memory Fetch: threads along threadIdx.x jointly load
        # the shared tile, each with a vectorized chunk.
        ro, rv = sch.split(sch.get_loops(rmat_block)[-1], [None, VecSize])
        sch.bind(ro, "threadIdx.x")
        sch.vectorize(rv)

        # Vectorize the epilogue write-back as well.
        wv = sch.get_loops(wmat_block)[-1]
        sch.vectorize(wv)

        # Scale and Quant Cache: stage the dequant block's two reads
        # (quantized weights and scales) in local scope at different
        # K levels, vectorized.
        if is_dequant_block:
            qb = sch.cache_read(dequant_block, 0, "local")
            sb = sch.cache_read(dequant_block, 1, "local")
            sch.compute_at(sb, k1)
            sch.compute_at(qb, k2)
            sch.set_scope(sb, 0, "local")
            sch.set_scope(qb, 0, "local")
            sch.vectorize(sch.get_loops(qb)[-1])
            sch.vectorize(sch.get_loops(sb)[-1])

        # Split the reduction's init statement out above the k0 loop.
        sch.decompose_reduction(matmul_block, k0)
        return sch
0 commit comments