Commit 32f2808

[Unity] Implement ExpandTupleArguments

Currently, the `FuseOps` and `FuseTIR` passes contain a large amount of added complexity to identify and handle partial use of tuple arguments. Handling partial use of tuples could be significantly simpler if performed in multiple steps:

1. Perform `FuseOps`. Any tuple variables that are used by the fused function are passed as-is.
2. Expand any parameters that are passed as a tuple. Any unused tensors that were included in a partially-used tuple become unused parameters.
3. Remove any unused parameters. Any unused tensors that were included in a partially-used tuple are removed in this step.
4. Perform `FuseTIR`. No checking for tuple arguments, either partial or full, is required at this step.

This PR implements `relax.transform.ExpandTupleArguments`, which is step (2) in this process; a rough sketch of the full pipeline follows below.
1 parent a6adaae commit 32f2808
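
For context, a minimal sketch of how the four steps above might be sequenced. Only `ExpandTupleArguments` (step 2) is part of this commit; `RemoveUnusedParameters` is an assumed name for the step-(3) pass, which is not included in this change.

    import tvm
    from tvm import relax

    # Hypothetical fusion pipeline following the four steps above.
    pipeline = tvm.transform.Sequential(
        [
            relax.transform.FuseOps(),  # step 1: fuse, passing tuples as-is
            relax.transform.ExpandTupleArguments(),  # step 2: this commit
            relax.transform.RemoveUnusedParameters(),  # step 3: assumed name, separate change
            relax.transform.FuseTIR(),  # step 4: no tuple handling needed
        ]
    )
    # mod = pipeline(mod)  # apply to an IRModule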

7 files changed: +288 -3 lines changed

include/tvm/relax/transform.h

Lines changed: 6 additions & 0 deletions
@@ -275,6 +275,12 @@ TVM_DLL Pass LiftTransformParams();
  */
 TVM_DLL Pass UpdateVDevice(VDevice new_vdevice, int64_t index);
 
+/*! \brief Expand tuple arguments to internal functions
+ *
+ * \return The Pass
+ */
+TVM_DLL Pass ExpandTupleArguments();
+
 /*!
  * \brief Annotate Op Pattern Kind for TIR functions, which is used in FuseOps.
  * \note It is an auto-detect pass for "unscheduled prim_funcs", the op_pattern will be

python/tvm/relax/transform/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -33,6 +33,7 @@
     DecomposeOpsForInference,
     DecomposeOpsForTraining,
     EliminateCommonSubexpr,
+    ExpandTupleArguments,
     FewShotTuning,
     FoldConstant,
     FunctionPass,

python/tvm/relax/transform/transform.py

Lines changed: 10 additions & 0 deletions
@@ -558,6 +558,16 @@ def FoldConstant() -> tvm.ir.transform.Pass:
     return _ffi_api.FoldConstant()  # type: ignore
 
 
+def ExpandTupleArguments() -> tvm.ir.transform.Pass:
+    """Expand tuple arguments to internal functions
+
+    Returns
+    -------
+    ret: tvm.ir.transform.Pass
+    """
+    return _ffi_api.ExpandTupleArguments()  # type: ignore
+
+
 def AnnotateTIROpPattern() -> tvm.ir.transform.Pass:
     """Annotate Op Pattern Kind for TIR functions
 
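
As a usage illustration, a minimal sketch mirroring the `TestSimple` case from the new unit tests below (`Module` and `After` are illustrative names):

    import tvm
    from tvm.script import ir as I, relax as R

    @I.ir_module
    class Module:
        @R.function
        def main(A: R.Tensor, B: R.Tensor):
            return Module.func((A, B))

        @R.function(private=True)
        def func(args: R.Tuple([R.Tensor, R.Tensor])) -> R.Tensor:
            return args[0]

    # After the pass, `func` receives the two tensors as separate parameters.
    After = tvm.relax.transform.ExpandTupleArguments()(Module)
    After.show()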

python/tvm/script/parser/relax/parser.py

Lines changed: 3 additions & 0 deletions
@@ -367,6 +367,9 @@ def visit_if(self: Parser, node: doc.If) -> None:
 @dispatch.register(token="relax", type_name="enter_token")
 def enter_token(self: Parser) -> Dict[str, Any]:
     def relax_call(self, *args) -> Expr:
+
+        args = [convert_to_expr(arg) if isinstance(arg, tuple) else arg for arg in args]
+
         if all(isinstance(x, Expr) for x in args):
             return relax.Call(self, args)
         arg_types = [type(x) for x in args]
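
This one-line addition is the parser-side support for the new tests: a bare Python tuple such as `(A, B)` appearing as a Relax call argument is now run through `convert_to_expr`, so it becomes a relax tuple expression rather than failing the subsequent `isinstance(x, Expr)` check.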

python/tvm/testing/utils.py

Lines changed: 2 additions & 3 deletions
@@ -1938,7 +1938,7 @@ def __init_subclass__(cls):
 
     @classmethod
     def _normalize_ir_module(cls, func):
-        if isinstance(func, tvm.tir.PrimFunc):
+        if isinstance(func, (tvm.tir.PrimFunc, tvm.IRModule)):
 
             def inner(self):
                 # pylint: disable=unused-argument
@@ -2042,8 +2042,7 @@ def inner(self):
 
     @staticmethod
     def _is_method(func):
-        sig = inspect.signature(func)
-        return "self" in sig.parameters
+        return callable(func) and "self" in inspect.signature(func).parameters
 
     def test_compare(self, before, expected, transform):
         """Unit test to compare the expected TIR PrimFunc to actual"""
Lines changed: 187 additions & 0 deletions
@@ -0,0 +1,187 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#include <tvm/relax/expr_functor.h>
#include <tvm/relax/transform.h>

#include <algorithm>
#include <tuple>

namespace tvm {
namespace relax {

namespace {

template <typename T, typename U>
using PMap = std::unordered_map<T, U, ObjectPtrHash, ObjectPtrEqual>;

Optional<Function> ExpandParams(Function func) {
  bool is_exposed = func->attrs.GetAttr<String>(tvm::attr::kGlobalSymbol).defined();
  if (is_exposed) return NullOpt;

  bool has_tuple_param = std::any_of(
      func->params.begin(), func->params.end(),
      [](const Var& param) -> bool { return param->struct_info_.as<TupleStructInfoNode>(); });

  if (!has_tuple_param) return NullOpt;

  Array<Var> params;
  Array<Binding> bindings;

  std::function<void(const Var&)> expand_param = [&](const Var& param) {
    if (auto sinfo = param->struct_info_.as<TupleStructInfoNode>()) {
      Array<Expr> internal_tuple;
      for (size_t i = 0; i < sinfo->fields.size(); i++) {
        auto name = static_cast<const std::stringstream&>(std::stringstream()
                                                          << param->name_hint() << "_" << i)
                        .str();
        Var new_param(name, sinfo->fields[i]);
        internal_tuple.push_back(new_param);
        expand_param(new_param);
      }
      bindings.push_back(VarBinding(param, Tuple(internal_tuple)));
    } else {
      params.push_back(param);
    }
  };

  for (const auto& param : func->params) {
    expand_param(param);
  }

  FuncStructInfo new_sinfo(params.Map([](const auto& var) { return GetStructInfo(var); }),
                           func->ret_struct_info,
                           Downcast<FuncStructInfo>(func->struct_info_)->purity);

  auto write_ptr = func.CopyOnWrite();
  write_ptr->params = params;
  write_ptr->body = SeqExpr({BindingBlock(bindings)}, func->body);
  write_ptr->struct_info_ = new_sinfo;

  return func;
}

class TupleExpander : public ExprMutator {
 public:
  explicit TupleExpander(PMap<GlobalVar, GlobalVar> callees) : replacements_(callees) {}

  using ExprMutator::VisitExpr_;

  Expr VisitExpr_(const CallNode* op) override {
    auto node = Downcast<Call>(ExprMutator::VisitExpr_(op));

    if (auto gvar = node->op.as<GlobalVar>()) {
      if (auto it = replacements_.find(gvar.value()); it != replacements_.end()) {
        Array<Expr> new_args;

        std::function<void(const Expr&)> expand_arg = [&](const Expr& arg) {
          if (auto sinfo = arg->struct_info_.as<TupleStructInfoNode>()) {
            for (size_t i = 0; i < sinfo->fields.size(); i++) {
              expand_arg(TupleGetItem(arg, i));
            }
          } else {
            new_args.push_back(arg);
          }
        };

        for (const auto& arg : node->args) {
          expand_arg(arg);
        }

        auto write_ptr = node.CopyOnWrite();
        write_ptr->op = it->second;
        write_ptr->args = new_args;
      }
    }

    return node;
  }

  PMap<GlobalVar, GlobalVar> replacements_;
};

}  // namespace

namespace transform {

Pass ExpandTupleArguments() {
  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
      [=](IRModule mod, PassContext pc) -> IRModule {
    PMap<GlobalVar, GlobalVar> gvar_replacements;

    {
      PMap<GlobalVar, Function> new_callees;

      for (const auto& [gvar, base_func] : mod->functions) {
        if (auto func = base_func.as<Function>()) {
          if (auto opt = ExpandParams(func.value())) {
            auto new_func = opt.value();
            GlobalVar new_gvar(gvar->name_hint, new_func->checked_type_);
            new_gvar->struct_info_ = new_func->struct_info_;
            gvar_replacements[gvar] = new_gvar;
            new_callees[new_gvar] = new_func;
          }
        }
      }

      if (gvar_replacements.empty()) {
        return mod;
      }
      auto write_ptr = mod.CopyOnWrite();
      for (auto [old_gvar, new_gvar] : gvar_replacements) {
        write_ptr->Remove(old_gvar);
        write_ptr->Add(new_gvar, new_callees.at(new_gvar));
      }
    }

    TupleExpander mutator(std::move(gvar_replacements));

    IRModule caller_updates;

    for (const auto& [gvar, base_func] : mod->functions) {
      if (auto func = base_func.as<Function>()) {
        auto mutated = Downcast<Function>(mutator.VisitExpr(func.value()));
        if (!mutated.same_as(base_func)) {
          caller_updates->Add(gvar, mutated);
        }
      }
    }

    if (caller_updates->functions.size()) {
      mod.CopyOnWrite()->Update(caller_updates);
    }
    return mod;
  };
  auto inner_pass = CreateModulePass(pass_func, 0, "ExpandTupleArgumentsInner", {});

  return tvm::transform::Sequential(
      {
          inner_pass,
          CanonicalizeBindings(),
          DeadCodeElimination({}),
      },
      "ExpandTupleArguments");
}

TVM_REGISTER_GLOBAL("relax.transform.ExpandTupleArguments").set_body_typed(ExpandTupleArguments);

}  // namespace transform

}  // namespace relax
}  // namespace tvm
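
The `TVM_REGISTER_GLOBAL` call above is what backs the `_ffi_api.ExpandTupleArguments` wrapper added in `transform.py`; a quick sketch of the correspondence:

    import tvm

    # The packed function registered above; calling it constructs the Pass,
    # exactly as the Python wrapper does via _ffi_api.
    make_pass = tvm.get_global_func("relax.transform.ExpandTupleArguments")
    expand = make_pass()
    assert isinstance(expand, tvm.ir.transform.Pass)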
Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import tvm
import tvm.testing
from tvm.script import ir as I, relax as R, tir as T


class BaseCompare(tvm.testing.CompareBeforeAfter):
    transform = tvm.relax.transform.ExpandTupleArguments()


class TestSimple(BaseCompare):
    @I.ir_module
    class Before:
        @R.function
        def main(A: R.Tensor, B: R.Tensor):
            return Before.func((A, B))

        @R.function(private=True)
        def func(args: R.Tuple([R.Tensor, R.Tensor])) -> R.Tensor:
            return args[0]

    @I.ir_module
    class Expected:
        @R.function
        def main(A: R.Tensor, B: R.Tensor):
            return Expected.func(A, B)

        @R.function(private=True)
        def func(A: R.Tensor, B: R.Tensor) -> R.Tensor:
            return A


class TestNested(BaseCompare):
    @I.ir_module
    class Before:
        @R.function
        def main(A: R.Tensor, B: R.Tensor, C: R.Tensor, D: R.Tensor) -> R.Tensor:
            return Before.func(((A, B), (C, D)))

        @R.function(private=True)
        def func(
            args: R.Tuple(
                [
                    R.Tuple([R.Tensor, R.Tensor]),
                    R.Tuple([R.Tensor, R.Tensor]),
                ]
            )
        ) -> R.Tensor:
            return args[0][1]

    @I.ir_module
    class Expected:
        @R.function
        def main(A: R.Tensor, B: R.Tensor, C: R.Tensor, D: R.Tensor) -> R.Tensor:
            return Expected.func(A, B, C, D)

        @R.function(private=True)
        def func(A: R.Tensor, B: R.Tensor, C: R.Tensor, D: R.Tensor) -> R.Tensor:
            return B


if __name__ == "__main__":
    tvm.testing.main()
