Skip to content

Commit 5308ef1

Browse files
authored
[Transform] Improvements to LazyTransformParams (#16602)
* [Transform] Improvements to LazyTransformParams * Handle non-bundled parameters in LazyTransformParams. * Check for `"num_input"` attribute * Handle relax.Const in LazyTransformParams Prior to this commit, `LazyTransformParams` would only output a call to the `fset_item` function if that element of the output had a corresponding `relax.Binding`. If `relax.Const` appeared in the output, then the call to `fset_item` would be omitted. This commit updates `LazyTransformParams` to check for any non-`Var` elements of the output tuple. * Update based on review comments
1 parent 8fe0164 commit 5308ef1

File tree

2 files changed

+310
-20
lines changed

2 files changed

+310
-20
lines changed

python/tvm/relax/transform/lazy_transform_params.py

Lines changed: 107 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ class LivenessAnalysis(PyExprVisitor):
8383
"""
8484

8585
def __init__(self, out_tuple_var: relax.Var) -> None:
86-
self.last_appear_in_var_binding = None
86+
self.last_appear_in_var_binding = []
8787
self.out_tuple_var = out_tuple_var
8888
self.var_liveness_end = {}
8989
self.ended_vars = set()
@@ -132,20 +132,22 @@ def __init__(
132132
self.extra_get_item_params = extra_get_item_params
133133
self.fset_item = fset_item
134134
self.extra_set_item_params = extra_set_item_params
135-
# the only input param, which should be a Tuple
136-
self.input_tuple_param = None
137135
self.input_params_set = None
138136
self.out_tuple_map = None
139137
self.out_tuple_var = None
140138
self.memory_free_insertion = None
141139

142140
def transform(self, func: relax.Function) -> relax.Function:
143-
self.input_tuple_param = func.params[0]
141+
if func.attrs is not None and "num_input" in func.attrs:
142+
num_input = func.attrs["num_input"].value
143+
else:
144+
num_input = 0
145+
144146
seq_expr = func.body
145147
self.out_tuple_var = seq_expr.body
146148

147149
# Step 1. collect out_tuple_map and input_params_set
148-
forward_collector = ForwardCollector(self.out_tuple_var, self.input_tuple_param)
150+
forward_collector = ForwardCollector(self.out_tuple_var, func.params[num_input])
149151
forward_collector.visit_expr(func)
150152
self.out_tuple_map = forward_collector.out_tuple_map
151153
# input_params_set is the set of binding var for var = params[i]
@@ -157,24 +159,65 @@ def transform(self, func: relax.Function) -> relax.Function:
157159
self.memory_free_insertion = liveness.var_liveness_end
158160

159161
# Step 3. rewrite get item and set item
160-
new_body = func.body
161162
if self.fget_item is not None:
162-
new_body = LazyInputMutator(self, self.mod).visit_expr(new_body)
163+
new_func = LazyInputMutator(self, self.mod).visit_expr(func)
163164

165+
new_body = new_func.body
164166
if self.fset_item is not None:
167+
# The LazyOutputMutator only inspects variable bindings
168+
# for replacement. If the output tuple includes elements
169+
# that do not have a variable binding, such as
170+
# `relax.Const`, these must still produce a call to the
171+
# `"set_item"` function.
172+
leaf_outputs = {
173+
expr: indices
174+
for expr, indices in self.out_tuple_map.items()
175+
if not isinstance(expr, relax.Var)
176+
}
177+
if leaf_outputs:
178+
new_bindings = [
179+
relax.VarBinding(
180+
relax.Var("_", relax.ObjectStructInfo()),
181+
relax.Call(
182+
relax.ExternFunc(self.fset_item),
183+
[*self.extra_set_item_params, index, expr],
184+
None,
185+
[relax.ObjectStructInfo()],
186+
),
187+
)
188+
for expr, indices in leaf_outputs.items()
189+
for index in indices
190+
]
191+
new_body = relax.SeqExpr(
192+
[*new_body.blocks, relax.BindingBlock(new_bindings)], new_body.body
193+
)
194+
165195
new_body = LazyOutputMutator(self, self.mod).visit_expr(new_body)
166196

167197
# Step 4. Add parameters of get_item and set_item (except index) to the function.
168-
params = [*self.extra_get_item_params, *self.extra_set_item_params]
198+
params = [
199+
*func.params[:num_input],
200+
*self.extra_get_item_params,
201+
*self.extra_set_item_params,
202+
]
169203

170204
# Step 5. Find all shape parameters that should be retained as
171205
# parameters.
172206
symbolic_vars = relax.analysis.defined_symbolic_vars(func)
173207
if symbolic_vars:
208+
209+
def unpack_sinfo(sinfo):
210+
if isinstance(sinfo, relax.TupleStructInfo):
211+
for field in sinfo.fields:
212+
yield from unpack_sinfo(field)
213+
else:
214+
yield sinfo
215+
174216
# direct iterate over the struct info annotation
175-
for sinfo in self.input_tuple_param.struct_info.fields:
176-
if not isinstance(sinfo, relax.TensorStructInfo):
177-
params.append(relax.Var("symbolic_var_holder", sinfo))
217+
for param in func.params[num_input:]:
218+
for sinfo in unpack_sinfo(param.struct_info):
219+
if not isinstance(sinfo, relax.TensorStructInfo):
220+
params.append(relax.Var("symbolic_var_holder", sinfo))
178221

179222
return relax.Function(
180223
params,
@@ -191,22 +234,67 @@ def __init__(self, func_creator, mod: Optional[IRModule] = None) -> None:
191234
self.func_creator = func_creator
192235
super().__init__(mod)
193236

194-
def visit_tuple_getitem_(self, op: relax.TupleGetItem) -> relax.Expr:
195-
# rewrite get item
196-
tuple_get_item = super().visit_tuple_getitem_(op)
197-
if tuple_get_item.tuple_value == self.func_creator.input_tuple_param:
237+
def visit_function_(self, func: relax.Function) -> relax.Expr:
238+
if func.attrs is not None and "num_input" in func.attrs:
239+
num_input = func.attrs["num_input"].value
240+
else:
241+
num_input = 0
242+
243+
params = list(func.params)[num_input:]
244+
if len(params) == 1 and isinstance(params[0].struct_info_, relax.TupleStructInfo):
245+
self.tuple_param = params[0]
246+
self.params = {}
247+
else:
248+
self.tuple_param = None
249+
self.params = {var: i for i, var in enumerate(params)}
250+
func = relax.Function(
251+
func.params[:num_input],
252+
func.body,
253+
func.ret_struct_info,
254+
is_pure=False,
255+
attrs=func.attrs,
256+
span=func.span,
257+
).without_attr("relax.force_pure")
258+
output = super().visit_function_(func)
259+
self.tuple_param = None
260+
self.params = {}
261+
return output
262+
263+
def visit_var_(self, var: relax.Var) -> relax.Expr:
264+
if var in self.params:
265+
index = self.params[var]
266+
get_item_result = self.builder_.emit(
267+
relax.Call(
268+
relax.ExternFunc(self.func_creator.fget_item),
269+
self.func_creator.extra_get_item_params + [relax.PrimValue(index)],
270+
None,
271+
[relax.ObjectStructInfo()],
272+
)
273+
)
274+
match_cast = relax.MatchCast(var, get_item_result, var.struct_info)
275+
self.builder_.emit_normalized(match_cast)
276+
277+
del self.params[var]
278+
279+
return super().visit_var_(var)
280+
281+
def visit_tuple_getitem_(self, node: relax.TupleGetItem) -> relax.Expr:
282+
sinfo = node.struct_info
283+
284+
node = super().visit_tuple_getitem_(node)
285+
286+
if self.tuple_param is not None and node.tuple_value.same_as(self.tuple_param):
198287
get_item_result = self.builder_.emit(
199288
relax.Call(
200289
relax.ExternFunc(self.func_creator.fget_item),
201-
self.func_creator.extra_get_item_params
202-
+ [relax.PrimValue(tuple_get_item.index)],
290+
self.func_creator.extra_get_item_params + [relax.PrimValue(node.index)],
203291
None,
204292
[relax.ObjectStructInfo()],
205293
)
206294
)
207-
return self.builder_.match_cast(get_item_result, op.struct_info)
295+
return self.builder_.match_cast(get_item_result, sinfo)
208296
else:
209-
return tuple_get_item
297+
return node
210298

211299

212300
@mutator

0 commit comments

Comments
 (0)