1717# pylint: disable=missing-docstring
1818"""A fallback schedule rule for GPU operators."""
1919# pylint: disable=invalid-name
20-
21- from typing import List , Optional , Union
20+ from typing import List , Optional , Set , Tuple , Union
2221
2322from tvm import tir
2423from tvm ._ffi import get_global_func
2726from tvm .target import Target
2827
2928from ..base import ScheduleRule , normalize_prim_func , try_inline_contiguous_spatial
29+ from . import utils
3030
3131
3232def _get_reduction_expr (block : tir .Block ) -> Optional [tir .PrimExpr ]:
@@ -47,7 +47,7 @@ def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]:
4747
4848def _detect_dominant_read (block : tir .Block ) -> tir .PrimExpr :
4949 dominant_read , read_iters = None , None
50- tir_vars = set ()
50+ tir_vars : Set [ tir . Var ] = set ()
5151 for buffer_region in block .reads :
5252 tir_vars .clear ()
5353
@@ -67,25 +67,37 @@ def _collect_tir_var(e):
6767 return result
6868
6969
70- class DecodeGEMV (ScheduleRule ):
71- def __init__ (self ) -> None :
72- super ().__init__ ()
73- self .get_loop_iter_type = get_global_func ("tir.schedule.GetLoopIterType" )
# Packed FFI function that classifies a loop as spatial ("S") or reduction
# ("R"); hoisted to module level so helpers can share one lookup.
_get_loop_iter_type = get_global_func("tir.schedule.GetLoopIterType")


def _fuse_spatial_reduction_loops(
    sch: tir.Schedule,
    loops: List[tir.schedule.LoopRV],
) -> Tuple[tir.schedule.LoopRV, tir.schedule.LoopRV]:
    """Group `loops` by iterator kind, reorder spatial-before-reduction,
    and fuse each group into a single loop.

    Returns the fused (spatial, reduction) loop pair. Raises RuntimeError
    on a loop that is neither spatial ("S") nor reduction ("R").
    """
    buckets = {"S": [], "R": []}
    for loop in loops:
        kind = _get_loop_iter_type(sch, loop)
        if kind not in buckets:
            raise RuntimeError("Unknown loop type " + str(kind))
        buckets[kind].append(loop)
    spatial, reduction = buckets["S"], buckets["R"]
    sch.reorder(*spatial, *reduction)
    return sch.fuse(*spatial), sch.fuse(*reduction)
7490
75- def apply ( # pylint: disable=too-many-locals
91+
92+ class DecodeGEMV (ScheduleRule ):
93+ def apply ( # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements
7694 self ,
7795 func : tir .PrimFunc ,
7896 target : Target ,
7997 _ : bool ,
8098 ) -> Union [None , tir .Schedule , List [tir .Schedule ]]:
8199 if not isinstance (func , tir .PrimFunc ):
82100 return None
83-
84- if target .kind .name == "cuda" :
85- len_tx , len_ty = 16 , 16
86- else :
87- len_tx , len_ty = 8 , 8
88-
89101 sch = tir .Schedule (func )
90102 block_infos = try_inline_contiguous_spatial (sch , normalize_prim_func (sch ))
91103
@@ -113,6 +125,7 @@ def apply( # pylint: disable=too-many-locals
113125 return None
114126 iter_to_info = {i .var : i for i in block_info .iters }
115127 s_loops , r_loops , c_loops = [], [], []
128+ c_loop_factor = None
116129 for split in sorted_iter_access .args :
117130 block_var = split .source .source
118131 block_var_info = iter_to_info [block_var ]
@@ -122,71 +135,82 @@ def apply( # pylint: disable=too-many-locals
122135 c_loop_factor = split .lower_factor
123136 loop_rv , c_loop = sch .split (loop_rv , factors = [None , c_loop_factor ])
124137 c_loops .append (c_loop )
125- is_loop_c_reduction = is_inner_reduction
138+ if is_inner_reduction :
139+ c_loop_factor = None
126140 if is_inner_reduction :
127141 r_loops .append (loop_rv )
128142 else :
129143 s_loops .append (loop_rv )
130144
131- if len (c_loops ) > 1 :
145+ if len (c_loops ) > 1 or len ( s_loops ) == 0 or len ( r_loops ) == 0 :
132146 return None
133147 if len (s_loops ) != len ([_ for i in block_info .iters if i .kind == "S" ]):
134148 return None
135- if len (s_loops ) == 0 or len (r_loops ) == 0 :
136- return None
137149
138150 sch .reorder (* s_loops , * r_loops , * c_loops )
139151 s = sch .fuse (* s_loops )
140152 r = sch .fuse (* r_loops )
141-
142- if is_inner_reduction :
143- _ , tx = sch .split (r , factors = [None , len_tx * len_ty ])
144- rf = sch .rfactor (tx , 0 )
145- s , r , tx = sch .get_loops (rf )[:3 ]
146- sch .reorder (s , tx , r )
147- sch .reverse_compute_at (block , s , preserve_unit_loops = True )
148- sch .bind (tx , "threadIdx.x" )
149- sch .bind (s , "blockIdx.x" )
150- else :
151- sch .split (s , factors = [None , len_tx ])
152- _ , ty = sch .split (r , factors = [None , len_ty ])
153- rf = sch .rfactor (ty , 0 )
154- bx , tx , r , ty = sch .get_loops (rf )[:4 ]
155- sch .reorder (bx , tx , ty , r )
156- sch .reverse_compute_at (block , bx , preserve_unit_loops = True )
157- sch .bind (tx , "threadIdx.x" )
158- sch .bind (ty , "threadIdx.y" )
159- sch .bind (bx , "blockIdx.x" )
160-
161- s_loops , r_loops = [], []
162- for loop_rv in sch .get_loops (block )[1 :]:
163- iter_type = self .get_loop_iter_type (sch , loop_rv )
164- if iter_type == "S" :
165- s_loops .append (loop_rv )
166- elif iter_type == "R" :
167- r_loops .append (loop_rv )
168- else :
169- raise RuntimeError ("Unknown loop type " + str (iter_type ))
170- sch .reorder (* s_loops , * r_loops )
171- s_ctr = sch .fuse (* s_loops )
172- r_ctr = sch .fuse (* r_loops )
173-
174- if c_loops and not is_loop_c_reduction :
175- s_ctr , inner = sch .split (s_ctr , factors = [None , c_loop_factor ])
176- sch .reorder (s_ctr , r_ctr , inner )
177-
178153 if is_inner_reduction :
179- sch .bind (r_ctr , "threadIdx.x" )
180- sch .set_scope (rf , 0 , "local" )
181- sch .decompose_reduction (rf , sch .get_loops (rf )[2 ])
154+ self ._sch_inner_reduction (sch , block , target , s , r , c_loop_factor )
182155 else :
183- sch .bind (s_ctr , "threadIdx.x" )
184- sch .bind (r_ctr , "threadIdx.y" )
185- sch .set_scope (rf , 0 , "local" )
186- sch .decompose_reduction (rf , sch .get_loops (rf )[3 ])
187-
156+ self ._sch_inner_spatial (sch , block , target , s , r , c_loop_factor )
188157 if len (block_infos ) == 2 :
189158 sch .set_scope (block , 0 , "local" )
190159 sch .reverse_compute_at (block_infos [1 ].block_rv , sch .get_loops (block )[0 ])
191-
192160 return sch
161+
    def _sch_inner_reduction(  # pylint: disable=too-many-arguments
        self,
        sch: tir.Schedule,
        block: tir.schedule.BlockRV,
        target: Target,
        _: tir.schedule.LoopRV,
        r: tir.schedule.LoopRV,
        unroll_spatial_factor: Optional[int],
    ) -> None:
        """Schedule the case where the innermost (fused) loop is a reduction.

        Mutates `sch` in place: cross-thread-reduces `r` over threadIdx.x via
        rfactor, binds the fused spatial loop to blockIdx.x, and optionally
        splits out an inner unroll loop of length `unroll_spatial_factor`.
        The unused spatial-loop parameter is accepted only to keep the
        signature parallel with `_sch_inner_spatial`.
        """
        # Pick threadIdx.x extent from the target's limits and the extent of r.
        (len_tx,) = utils.suggest_threads_per_block(  # pylint: disable=unbalanced-tuple-unpacking
            target, [sch.get(r)]
        )

        # Split the reduction so the inner part maps to threads, then rfactor
        # it into a separate block so each thread accumulates a partial sum.
        _, tx = sch.split(r, factors=[None, len_tx])
        rf = sch.rfactor(tx, 0)
        s, r, tx = sch.get_loops(rf)[:3]
        sch.reorder(s, tx, r)
        # Place the final cross-thread reduction right after each spatial slice.
        sch.reverse_compute_at(block, s, preserve_unit_loops=True)
        sch.bind(tx, "threadIdx.x")
        sch.bind(s, "blockIdx.x")
        # Normalize the write-back block's loops: spatial fused, reduction fused.
        s_ctr, r_ctr = _fuse_spatial_reduction_loops(sch, sch.get_loops(block)[1:])
        if unroll_spatial_factor:
            s_ctr, inner = sch.split(s_ctr, factors=[None, unroll_spatial_factor])
            sch.reorder(s_ctr, r_ctr, inner)
        sch.bind(r_ctr, "threadIdx.x")
        # Keep the per-thread partial results in thread-local registers.
        sch.set_scope(rf, 0, "local")
        sch.decompose_reduction(rf, sch.get_loops(rf)[2])

    def _sch_inner_spatial(  # pylint: disable=too-many-locals,too-many-arguments
        self,
        sch: tir.Schedule,
        block: tir.schedule.BlockRV,
        target: Target,
        s: tir.schedule.LoopRV,
        r: tir.schedule.LoopRV,
        unroll_spatial_factor: Optional[int],
    ) -> None:
        """Schedule the case where the innermost (fused) loop is spatial.

        Mutates `sch` in place: spatial work is tiled over threadIdx.x and the
        reduction is rfactored over threadIdx.y, so each (tx, ty) pair holds a
        partial result that is combined per ty. `target` is currently unused
        here; the 16x16 thread tile is fixed for every target.
        """
        # Fixed 2-D thread tile: tx covers spatial elements, ty splits the
        # reduction. NOTE(review): not derived from `target` — assumed to fit
        # within any supported GPU's thread limits.
        len_tx, len_ty = 16, 16
        sch.split(s, factors=[None, len_tx])
        # rfactor over the inner reduction split so ty-lanes accumulate
        # independent partial sums.
        _, ty = sch.split(r, factors=[None, len_ty])
        rf = sch.rfactor(ty, 0)
        bx, tx, r, ty = sch.get_loops(rf)[:4]
        sch.reorder(bx, tx, ty, r)
        # Attach the combining block at the blockIdx.x level.
        sch.reverse_compute_at(block, bx, preserve_unit_loops=True)
        sch.bind(tx, "threadIdx.x")
        sch.bind(ty, "threadIdx.y")
        sch.bind(bx, "blockIdx.x")
        # Normalize the write-back block's loops: spatial fused, reduction fused.
        s_ctr, r_ctr = _fuse_spatial_reduction_loops(sch, sch.get_loops(block)[1:])
        if unroll_spatial_factor:
            s_ctr, inner = sch.split(s_ctr, factors=[None, unroll_spatial_factor])
            sch.reorder(s_ctr, r_ctr, inner)
        sch.bind(s_ctr, "threadIdx.x")
        sch.bind(r_ctr, "threadIdx.y")
        # Keep partial results in thread-local registers before combining.
        sch.set_scope(rf, 0, "local")
        sch.decompose_reduction(rf, sch.get_loops(rf)[3])
0 commit comments