@@ -83,7 +83,7 @@ class LivenessAnalysis(PyExprVisitor):
8383 """
8484
8585 def __init__ (self , out_tuple_var : relax .Var ) -> None :
86- self .last_appear_in_var_binding = None
86+ self .last_appear_in_var_binding = []
8787 self .out_tuple_var = out_tuple_var
8888 self .var_liveness_end = {}
8989 self .ended_vars = set ()
@@ -132,20 +132,22 @@ def __init__(
132132 self .extra_get_item_params = extra_get_item_params
133133 self .fset_item = fset_item
134134 self .extra_set_item_params = extra_set_item_params
135- # the only input param, which should be a Tuple
136- self .input_tuple_param = None
137135 self .input_params_set = None
138136 self .out_tuple_map = None
139137 self .out_tuple_var = None
140138 self .memory_free_insertion = None
141139
142140 def transform (self , func : relax .Function ) -> relax .Function :
143- self .input_tuple_param = func .params [0 ]
141+ if func .attrs is not None and "num_input" in func .attrs :
142+ num_input = func .attrs ["num_input" ].value
143+ else :
144+ num_input = 0
145+
144146 seq_expr = func .body
145147 self .out_tuple_var = seq_expr .body
146148
147149 # Step 1. collect out_tuple_map and input_params_set
148- forward_collector = ForwardCollector (self .out_tuple_var , self . input_tuple_param )
150+ forward_collector = ForwardCollector (self .out_tuple_var , func . params [ num_input ] )
149151 forward_collector .visit_expr (func )
150152 self .out_tuple_map = forward_collector .out_tuple_map
151153 # input_params_set is the set of binding var for var = params[i]
@@ -157,24 +159,65 @@ def transform(self, func: relax.Function) -> relax.Function:
157159 self .memory_free_insertion = liveness .var_liveness_end
158160
159161 # Step 3. rewrite get item and set item
160- new_body = func .body
161162 if self .fget_item is not None :
162- new_body = LazyInputMutator (self , self .mod ).visit_expr (new_body )
163+ new_func = LazyInputMutator (self , self .mod ).visit_expr (func )
163164
165+ new_body = new_func .body
164166 if self .fset_item is not None :
167+ # The LazyOutputMutator only inspects variable bindings
168+ # for replacement. If the output tuple includes elements
169+ # that do not have a variable binding, such as
170+ # `relax.Const`, these must still produce a call to the
171+ # `"set_item"` function.
172+ leaf_outputs = {
173+ expr : indices
174+ for expr , indices in self .out_tuple_map .items ()
175+ if not isinstance (expr , relax .Var )
176+ }
177+ if leaf_outputs :
178+ new_bindings = [
179+ relax .VarBinding (
180+ relax .Var ("_" , relax .ObjectStructInfo ()),
181+ relax .Call (
182+ relax .ExternFunc (self .fset_item ),
183+ [* self .extra_set_item_params , index , expr ],
184+ None ,
185+ [relax .ObjectStructInfo ()],
186+ ),
187+ )
188+ for expr , indices in leaf_outputs .items ()
189+ for index in indices
190+ ]
191+ new_body = relax .SeqExpr (
192+ [* new_body .blocks , relax .BindingBlock (new_bindings )], new_body .body
193+ )
194+
165195 new_body = LazyOutputMutator (self , self .mod ).visit_expr (new_body )
166196
167197 # Step 4. Add parameters of get_item and set_item (except index) to the function.
168- params = [* self .extra_get_item_params , * self .extra_set_item_params ]
198+ params = [
199+ * func .params [:num_input ],
200+ * self .extra_get_item_params ,
201+ * self .extra_set_item_params ,
202+ ]
169203
170204 # Step 5. Find all shape parameters that should be retained as
171205 # parameters.
172206 symbolic_vars = relax .analysis .defined_symbolic_vars (func )
173207 if symbolic_vars :
208+
209+ def unpack_sinfo (sinfo ):
210+ if isinstance (sinfo , relax .TupleStructInfo ):
211+ for field in sinfo .fields :
212+ yield from unpack_sinfo (field )
213+ else :
214+ yield sinfo
215+
174216 # direct iterate over the struct info annotation
175- for sinfo in self .input_tuple_param .struct_info .fields :
176- if not isinstance (sinfo , relax .TensorStructInfo ):
177- params .append (relax .Var ("symbolic_var_holder" , sinfo ))
217+ for param in func .params [num_input :]:
218+ for sinfo in unpack_sinfo (param .struct_info ):
219+ if not isinstance (sinfo , relax .TensorStructInfo ):
220+ params .append (relax .Var ("symbolic_var_holder" , sinfo ))
178221
179222 return relax .Function (
180223 params ,
@@ -191,22 +234,67 @@ def __init__(self, func_creator, mod: Optional[IRModule] = None) -> None:
191234 self .func_creator = func_creator
192235 super ().__init__ (mod )
193236
194- def visit_tuple_getitem_ (self , op : relax .TupleGetItem ) -> relax .Expr :
195- # rewrite get item
196- tuple_get_item = super ().visit_tuple_getitem_ (op )
197- if tuple_get_item .tuple_value == self .func_creator .input_tuple_param :
237+ def visit_function_ (self , func : relax .Function ) -> relax .Expr :
238+ if func .attrs is not None and "num_input" in func .attrs :
239+ num_input = func .attrs ["num_input" ].value
240+ else :
241+ num_input = 0
242+
243+ params = list (func .params )[num_input :]
244+ if len (params ) == 1 and isinstance (params [0 ].struct_info_ , relax .TupleStructInfo ):
245+ self .tuple_param = params [0 ]
246+ self .params = {}
247+ else :
248+ self .tuple_param = None
249+ self .params = {var : i for i , var in enumerate (params )}
250+ func = relax .Function (
251+ func .params [:num_input ],
252+ func .body ,
253+ func .ret_struct_info ,
254+ is_pure = False ,
255+ attrs = func .attrs ,
256+ span = func .span ,
257+ ).without_attr ("relax.force_pure" )
258+ output = super ().visit_function_ (func )
259+ self .tuple_param = None
260+ self .params = {}
261+ return output
262+
263+ def visit_var_ (self , var : relax .Var ) -> relax .Expr :
264+ if var in self .params :
265+ index = self .params [var ]
266+ get_item_result = self .builder_ .emit (
267+ relax .Call (
268+ relax .ExternFunc (self .func_creator .fget_item ),
269+ self .func_creator .extra_get_item_params + [relax .PrimValue (index )],
270+ None ,
271+ [relax .ObjectStructInfo ()],
272+ )
273+ )
274+ match_cast = relax .MatchCast (var , get_item_result , var .struct_info )
275+ self .builder_ .emit_normalized (match_cast )
276+
277+ del self .params [var ]
278+
279+ return super ().visit_var_ (var )
280+
281+ def visit_tuple_getitem_ (self , node : relax .TupleGetItem ) -> relax .Expr :
282+ sinfo = node .struct_info
283+
284+ node = super ().visit_tuple_getitem_ (node )
285+
286+ if self .tuple_param is not None and node .tuple_value .same_as (self .tuple_param ):
198287 get_item_result = self .builder_ .emit (
199288 relax .Call (
200289 relax .ExternFunc (self .func_creator .fget_item ),
201- self .func_creator .extra_get_item_params
202- + [relax .PrimValue (tuple_get_item .index )],
290+ self .func_creator .extra_get_item_params + [relax .PrimValue (node .index )],
203291 None ,
204292 [relax .ObjectStructInfo ()],
205293 )
206294 )
207- return self .builder_ .match_cast (get_item_result , op . struct_info )
295+ return self .builder_ .match_cast (get_item_result , sinfo )
208296 else :
209- return tuple_get_item
297+ return node
210298
211299
212300@mutator
0 commit comments