Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions python/tvm/relax/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -852,21 +852,28 @@ def LiftTransformParams() -> tvm.ir.transform.Pass:
return _ffi_api.LiftTransformParams() # type: ignore


def BundleModelParams() -> tvm.ir.transform.Pass:
def BundleModelParams(param_tuple_name: Optional[str] = None) -> tvm.ir.transform.Pass:
"""Bundle several model parameters into a single tuple paramters

For each function, if the function has the attribute "num_input",
separate between run-time parameters and compile-time weights.
Run-time parameters (e.g. activations) are the first `num_input`
parameters, and the remainder are compile-time weights.

Parameters
----------
param_tuple_name: Optional[str]

The name of the tuple parameter. If unspecified, defaults to
"model_params".

Returns
-------
ret : tvm.transform.Pass
The registered pass for lifting transformation of parameters.

"""
return _ffi_api.BundleModelParams() # type: ignore
return _ffi_api.BundleModelParams(param_tuple_name) # type: ignore


def LegalizeOps(
Expand Down
14 changes: 8 additions & 6 deletions src/relax/transform/bundle_model_params.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ namespace relax {

class ModelParamBundler : public ExprMutator {
public:
ModelParamBundler() {}
explicit ModelParamBundler(Optional<String> param_tuple_name)
: param_tuple_name_(param_tuple_name) {}

Expr VisitExpr_(const FunctionNode* op) override {
Function func = GetRef<Function>(op);
Expand All @@ -59,7 +60,7 @@ class ModelParamBundler : public ExprMutator {
param_tuple.push_back(GetStructInfo(func->params[i]));
}

Var var_param_tuple("model_params", TupleStructInfo(param_tuple));
Var var_param_tuple(param_tuple_name_.value_or("model_params"), TupleStructInfo(param_tuple));
params.push_back(var_param_tuple);

for (size_t i = num_input; i < func->params.size(); i++) {
Expand All @@ -81,21 +82,22 @@ class ModelParamBundler : public ExprMutator {
}

private:
Optional<String> param_tuple_name_;
Map<Var, Expr> var_to_expr_;
};

Function BundleModelParams(const Function& func) {
ModelParamBundler mutator;
Function BundleModelParams(const Function& func, Optional<String> param_tuple_name) {
ModelParamBundler mutator(param_tuple_name);
return Downcast<Function>(mutator(func));
}

namespace transform {
Pass BundleModelParams() {
Pass BundleModelParams(Optional<String> param_tuple_name) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule mod,
PassContext pc) {
IRModule updates;

ModelParamBundler mutator;
ModelParamBundler mutator(param_tuple_name);

for (const auto& [gvar, func] : mod->functions) {
if (auto opt = func.as<relax::Function>()) {
Expand Down
5 changes: 4 additions & 1 deletion src/relax/transform/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -429,9 +429,12 @@ Expr CanonicalizeBindings(const Expr& expr);
*
* \param func The function to be updated.
*
* \param param_tuple_name The name of the tuple parameter. If
* unspecified, defaults to "model_params"
*
* \ret The updated function.
*/
Function BundleModelParams(const Function& func);
Function BundleModelParams(const Function& func, Optional<String> param_tuple_name = NullOpt);

} // namespace relax
} // namespace tvm
Expand Down
40 changes: 40 additions & 0 deletions tests/python/relax/test_transform_bundle_model_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,5 +193,45 @@ def main(
assert binding.var.name_hint == expected_binding.var.name_hint


def test_bundled_param_name():
"""The tuple parameter can have an explicit name"""

@tvm.script.ir_module
class Before:
@R.function
def main(
a: R.Tensor([16], "float32"),
b: R.Tensor([16], "float32"),
c: R.Tensor([16], "float32"),
) -> R.Tensor([16], "float32"):
R.func_attr({"num_input": 1})
expr = a
expr = R.add(expr, b)
expr = R.add(expr, c)
return expr

@tvm.script.ir_module
class Expected:
@R.function
def main(
a: R.Tensor([16], "float32"),
custom_tuple_name: R.Tuple(R.Tensor([16], "float32"), R.Tensor([16], "float32")),
) -> R.Tensor([16], "float32"):
R.func_attr({"num_input": 1})
expr = a
b = custom_tuple_name[0]
expr = R.add(expr, b)
c = custom_tuple_name[1]
expr = R.add(expr, c)
return expr

mod = Before
after = relax.transform.BundleModelParams("custom_tuple_name")(mod)
tvm.ir.assert_structural_equal(after, Expected)

for param, expected_param in zip(after["main"].params, Expected["main"].params):
assert param.name_hint == expected_param.name_hint


if __name__ == "__main__":
tvm.testing.main()