Skip to content

Commit fcfc05b

Browse files
authored
[Transform] Allow explicit name of bundled model parameters (#16597)
In `BundleModelParams`, allow the user to specify a name for the tuple parameters. If unspecified, defaults to the previous name `"model_params"`.
1 parent 5308ef1 commit fcfc05b

File tree

4 files changed

+61
-9
lines changed

4 files changed

+61
-9
lines changed

python/tvm/relax/transform/transform.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -852,21 +852,28 @@ def LiftTransformParams() -> tvm.ir.transform.Pass:
852852
return _ffi_api.LiftTransformParams() # type: ignore
853853

854854

855-
def BundleModelParams() -> tvm.ir.transform.Pass:
855+
def BundleModelParams(param_tuple_name: Optional[str] = None) -> tvm.ir.transform.Pass:
856856
"""Bundle several model parameters into a single tuple paramters
857857
858858
For each function, if the function has the attribute "num_input",
859859
separate between run-time parameters and compile-time weights.
860860
Run-time parameters (e.g. activations) are the first `num_input`
861861
parameters, and the remainder are compile-time weights.
862862
863+
Parameters
864+
----------
865+
param_tuple_name: Optional[str]
866+
867+
The name of the tuple parameter. If unspecified, defaults to
868+
"model_params".
869+
863870
Returns
864871
-------
865872
ret : tvm.transform.Pass
866873
The registered pass for lifting transformation of parameters.
867874
868875
"""
869-
return _ffi_api.BundleModelParams() # type: ignore
876+
return _ffi_api.BundleModelParams(param_tuple_name) # type: ignore
870877

871878

872879
def LegalizeOps(

src/relax/transform/bundle_model_params.cc

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ namespace relax {
3535

3636
class ModelParamBundler : public ExprMutator {
3737
public:
38-
ModelParamBundler() {}
38+
explicit ModelParamBundler(Optional<String> param_tuple_name)
39+
: param_tuple_name_(param_tuple_name) {}
3940

4041
Expr VisitExpr_(const FunctionNode* op) override {
4142
Function func = GetRef<Function>(op);
@@ -59,7 +60,7 @@ class ModelParamBundler : public ExprMutator {
5960
param_tuple.push_back(GetStructInfo(func->params[i]));
6061
}
6162

62-
Var var_param_tuple("model_params", TupleStructInfo(param_tuple));
63+
Var var_param_tuple(param_tuple_name_.value_or("model_params"), TupleStructInfo(param_tuple));
6364
params.push_back(var_param_tuple);
6465

6566
for (size_t i = num_input; i < func->params.size(); i++) {
@@ -81,21 +82,22 @@ class ModelParamBundler : public ExprMutator {
8182
}
8283

8384
private:
85+
Optional<String> param_tuple_name_;
8486
Map<Var, Expr> var_to_expr_;
8587
};
8688

87-
Function BundleModelParams(const Function& func) {
88-
ModelParamBundler mutator;
89+
Function BundleModelParams(const Function& func, Optional<String> param_tuple_name) {
90+
ModelParamBundler mutator(param_tuple_name);
8991
return Downcast<Function>(mutator(func));
9092
}
9193

9294
namespace transform {
93-
Pass BundleModelParams() {
95+
Pass BundleModelParams(Optional<String> param_tuple_name) {
9496
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule mod,
9597
PassContext pc) {
9698
IRModule updates;
9799

98-
ModelParamBundler mutator;
100+
ModelParamBundler mutator(param_tuple_name);
99101

100102
for (const auto& [gvar, func] : mod->functions) {
101103
if (auto opt = func.as<relax::Function>()) {

src/relax/transform/utils.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,9 +429,12 @@ Expr CanonicalizeBindings(const Expr& expr);
429429
*
430430
* \param func The function to be updated.
431431
*
432+
* \param param_tuple_name The name of the tuple parameter. If
433+
* unspecified, defaults to "model_params"
434+
*
432435
* \ret The updated function.
433436
*/
434-
Function BundleModelParams(const Function& func);
437+
Function BundleModelParams(const Function& func, Optional<String> param_tuple_name = NullOpt);
435438

436439
} // namespace relax
437440
} // namespace tvm

tests/python/relax/test_transform_bundle_model_params.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,5 +193,45 @@ def main(
193193
assert binding.var.name_hint == expected_binding.var.name_hint
194194

195195

196+
def test_bundled_param_name():
197+
"""The tuple parameter can have an explicit name"""
198+
199+
@tvm.script.ir_module
200+
class Before:
201+
@R.function
202+
def main(
203+
a: R.Tensor([16], "float32"),
204+
b: R.Tensor([16], "float32"),
205+
c: R.Tensor([16], "float32"),
206+
) -> R.Tensor([16], "float32"):
207+
R.func_attr({"num_input": 1})
208+
expr = a
209+
expr = R.add(expr, b)
210+
expr = R.add(expr, c)
211+
return expr
212+
213+
@tvm.script.ir_module
214+
class Expected:
215+
@R.function
216+
def main(
217+
a: R.Tensor([16], "float32"),
218+
custom_tuple_name: R.Tuple(R.Tensor([16], "float32"), R.Tensor([16], "float32")),
219+
) -> R.Tensor([16], "float32"):
220+
R.func_attr({"num_input": 1})
221+
expr = a
222+
b = custom_tuple_name[0]
223+
expr = R.add(expr, b)
224+
c = custom_tuple_name[1]
225+
expr = R.add(expr, c)
226+
return expr
227+
228+
mod = Before
229+
after = relax.transform.BundleModelParams("custom_tuple_name")(mod)
230+
tvm.ir.assert_structural_equal(after, Expected)
231+
232+
for param, expected_param in zip(after["main"].params, Expected["main"].params):
233+
assert param.name_hint == expected_param.name_hint
234+
235+
196236
if __name__ == "__main__":
197237
tvm.testing.main()

0 commit comments

Comments
 (0)