diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index b2aaa3e331a1..c017f0cda738 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -852,7 +852,7 @@ def LiftTransformParams() -> tvm.ir.transform.Pass: return _ffi_api.LiftTransformParams() # type: ignore -def BundleModelParams() -> tvm.ir.transform.Pass: +def BundleModelParams(param_tuple_name: Optional[str] = None) -> tvm.ir.transform.Pass: """Bundle several model parameters into a single tuple paramters For each function, if the function has the attribute "num_input", @@ -860,13 +860,20 @@ def BundleModelParams() -> tvm.ir.transform.Pass: Run-time parameters (e.g. activations) are the first `num_input` parameters, and the remainder are compile-time weights. + Parameters + ---------- + param_tuple_name: Optional[str] + + The name of the tuple parameter. If unspecified, defaults to + "model_params". + Returns ------- ret : tvm.transform.Pass The registered pass for lifting transformation of parameters. """ - return _ffi_api.BundleModelParams() # type: ignore + return _ffi_api.BundleModelParams(param_tuple_name) # type: ignore def LegalizeOps( diff --git a/src/relax/transform/bundle_model_params.cc b/src/relax/transform/bundle_model_params.cc index a9cb719d26d9..f5798049efa1 100644 --- a/src/relax/transform/bundle_model_params.cc +++ b/src/relax/transform/bundle_model_params.cc @@ -35,7 +35,8 @@ namespace relax { class ModelParamBundler : public ExprMutator { public: - ModelParamBundler() {} + explicit ModelParamBundler(Optional param_tuple_name) + : param_tuple_name_(param_tuple_name) {} Expr VisitExpr_(const FunctionNode* op) override { Function func = GetRef(op); @@ -59,7 +60,7 @@ class ModelParamBundler : public ExprMutator { param_tuple.push_back(GetStructInfo(func->params[i])); } - Var var_param_tuple("model_params", TupleStructInfo(param_tuple)); + Var var_param_tuple(param_tuple_name_.value_or("model_params"), TupleStructInfo(param_tuple)); params.push_back(var_param_tuple); for (size_t i = num_input; i < func->params.size(); i++) { @@ -81,21 +82,22 @@ class ModelParamBundler : public ExprMutator { } private: + Optional param_tuple_name_; Map var_to_expr_; }; -Function BundleModelParams(const Function& func) { - ModelParamBundler mutator; +Function BundleModelParams(const Function& func, Optional param_tuple_name) { + ModelParamBundler mutator(param_tuple_name); return Downcast(mutator(func)); } namespace transform { -Pass BundleModelParams() { +Pass BundleModelParams(Optional param_tuple_name) { runtime::TypedPackedFunc pass_func = [=](IRModule mod, PassContext pc) { IRModule updates; - ModelParamBundler mutator; + ModelParamBundler mutator(param_tuple_name); for (const auto& [gvar, func] : mod->functions) { if (auto opt = func.as()) { diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h index 802099f0ab4b..1ad714972c2d 100644 --- a/src/relax/transform/utils.h +++ b/src/relax/transform/utils.h @@ -429,9 +429,12 @@ Expr CanonicalizeBindings(const Expr& expr); * * \param func The function to be updated. * + * \param param_tuple_name The name of the tuple parameter. If + * unspecified, defaults to "model_params" + * * \ret The updated function. */ -Function BundleModelParams(const Function& func); +Function BundleModelParams(const Function& func, Optional param_tuple_name = NullOpt); } // namespace relax } // namespace tvm diff --git a/tests/python/relax/test_transform_bundle_model_params.py b/tests/python/relax/test_transform_bundle_model_params.py index e3528cc357e4..415a883f1638 100644 --- a/tests/python/relax/test_transform_bundle_model_params.py +++ b/tests/python/relax/test_transform_bundle_model_params.py @@ -193,5 +193,45 @@ def main( assert binding.var.name_hint == expected_binding.var.name_hint +def test_bundled_param_name(): + """The tuple parameter can have an explicit name""" + + @tvm.script.ir_module + class Before: + @R.function + def main( + a: R.Tensor([16], "float32"), + b: R.Tensor([16], "float32"), + c: R.Tensor([16], "float32"), + ) -> R.Tensor([16], "float32"): + R.func_attr({"num_input": 1}) + expr = a + expr = R.add(expr, b) + expr = R.add(expr, c) + return expr + + @tvm.script.ir_module + class Expected: + @R.function + def main( + a: R.Tensor([16], "float32"), + custom_tuple_name: R.Tuple(R.Tensor([16], "float32"), R.Tensor([16], "float32")), + ) -> R.Tensor([16], "float32"): + R.func_attr({"num_input": 1}) + expr = a + b = custom_tuple_name[0] + expr = R.add(expr, b) + c = custom_tuple_name[1] + expr = R.add(expr, c) + return expr + + mod = Before + after = relax.transform.BundleModelParams("custom_tuple_name")(mod) + tvm.ir.assert_structural_equal(after, Expected) + + for param, expected_param in zip(after["main"].params, Expected["main"].params): + assert param.name_hint == expected_param.name_hint + + if __name__ == "__main__": tvm.testing.main()