Commit f62445c

[Relax] Disable fusion for fetching from the packed params in FuseOps (#17198)
* [Relax] Disable fusion for fetching from the packed params in FuseOps

  The order of bindings in the fusion result is determined by the first binding in each partition group. When the packed param tuple is used, the function usually begins with a number of `TupleGetItem` bindings that unpack the param tuple. Previously `TupleGetItem` was treated as `kInjective`, which caused any operation that relies purely on these params to be moved to the beginning of the function, increasing the memory usage of the intermediate results.

* lint
1 parent 4330c11 commit f62445c
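For context, the sketch below (not part of this commit) illustrates the packed-param calling convention the change targets: the weights arrive as a single tuple parameter, the `num_input` function attribute marks how many leading parameters are runtime inputs, and the function body starts by unpacking the tuple with `TupleGetItem`. The module, pass pipeline, and all names here are illustrative assumptions, not code from this commit.

    import tvm
    from tvm import relax
    from tvm.script import ir as I, relax as R


    @I.ir_module
    class Model:
        @R.function
        def main(
            x: R.Tensor((16, 16), dtype="float32"),
            packed_params: R.Tuple(
                R.Tensor((16, 16), dtype="float16"),
                R.Tensor((16, 16), dtype="float16"),
            ),
        ) -> R.Tensor((16, 16), dtype="float32"):
            # "num_input": 1 marks only x as a runtime input; packed_params holds the weights.
            R.func_attr({"num_input": 1})
            with R.dataflow():
                # The function begins by unpacking the weights. These TupleGetItem
                # bindings are the ones FuseOps now classifies as kOpaque.
                w0 = packed_params[0]
                w1 = packed_params[1]
                w0_f32 = R.astype(w0, "float32")  # weight-only compute
                lv0 = R.matmul(x, w0_f32)
                w1_f32 = R.astype(w1, "float32")  # weight-only compute
                lv1 = R.matmul(lv0, w1_f32)
                R.output(lv1)
            return lv1


    # FuseOps operates on call_tir-based modules, so a typical pipeline legalizes
    # the ops and annotates TIR op patterns before fusing.
    seq = tvm.transform.Sequential(
        [
            relax.transform.LegalizeOps(),
            relax.transform.AnnotateTIROpPattern(),
            relax.transform.FuseOps(),
        ]
    )
    fused = seq(Model)

With the previous `kInjective` classification, weight-only work such as the casts above could be pulled to the beginning of the fused function; treating the unpacking as opaque keeps the binding order, and hence the lifetime of intermediate results, as written.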

File tree

2 files changed: +65 -2 lines

  src/relax/transform/fuse_ops.cc
  tests/python/relax/test_transform_fuse_ops.py

src/relax/transform/fuse_ops.cc

Lines changed: 17 additions & 2 deletions
@@ -147,6 +147,12 @@ class GraphCreator : public ExprVisitor {
       SetNodePattern(param_node, OpPatternKind::kOpaque);
       AddToPostDFSOrder(param_node, param.get());
     }
+    if (auto opt_num_input = func->GetAttr<Integer>(attr::kNumInput)) {
+      for (int i = static_cast<int>(opt_num_input.value()->value);
+           i < static_cast<int>(func->params.size()); ++i) {
+        input_params_.insert(func->params[i].get());
+      }
+    }
     ExprVisitor::VisitExpr_(func);
   }

@@ -224,8 +230,15 @@ class GraphCreator : public ExprVisitor {
                          IndexedForwardGraph::Node* binding_var_node) {
     ICHECK_NOTNULL(binding_var_node);

-    SetNodePattern(binding_var_node, OpPatternKind::kInjective);
-    VisitLeaf(tuple_item->tuple, binding_var_node, OpPatternKind::kInjective);
+    auto pattern = OpPatternKind::kInjective;
+    if (input_params_.count(tuple_item->tuple.as<VarNode>())) {
+      // TupleGetItem that fetches a parameter from the packed param tuple is treated as opaque
+      // and won't be fused. This prevents the use of the packed param tuple from changing the
+      // order of the fusion result, as the function usually begins with fetching the parameters.
+      pattern = OpPatternKind::kOpaque;
+    }
+    SetNodePattern(binding_var_node, pattern);
+    VisitLeaf(tuple_item->tuple, binding_var_node, pattern);
   }

   void VisitUnsupportedNode(const Expr& expr, IndexedForwardGraph::Node* binding_var_node) {

@@ -354,6 +367,8 @@ class GraphCreator : public ExprVisitor {
   IndexedForwardGraph graph_;
   /*! \brief The graph nodes whose patterns are set */
   std::unordered_set<IndexedForwardGraph::Node*> initialized_nodes_;
+  /*! \brief The model params in the function input */
+  std::unordered_set<const VarNode*> input_params_;
 };

 /*!

tests/python/relax/test_transform_fuse_ops.py

Lines changed: 48 additions & 0 deletions
@@ -1642,5 +1642,53 @@ def main(
     _check(Module, Expected)


+def test_packed_params():
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def cast(lv: T.Buffer((T.int64(16), T.int64(16)), "float16"), compute: T.Buffer((T.int64(16), T.int64(16)), "float32")):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            for i0, i1 in T.grid(T.int64(16), T.int64(16)):
+                with T.block("compute"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(lv[v_i0, v_i1])
+                    T.writes(compute[v_i0, v_i1])
+                    compute[v_i0, v_i1] = T.Cast("float32", lv[v_i0, v_i1])
+
+        @T.prim_func(private=True)
+        def matmul(x: T.Buffer((T.int64(16), T.int64(16)), "float32"), lv2: T.Buffer((T.int64(16), T.int64(16)), "float32"), T_matmul: T.Buffer((T.int64(16), T.int64(16)), "float32")):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            for ax0, ax1, k in T.grid(T.int64(16), T.int64(16), T.int64(16)):
+                with T.block("T_matmul"):
+                    v_ax0, v_ax1, v_k = T.axis.remap("SSR", [ax0, ax1, k])
+                    T.reads(x[v_ax0, v_k], lv2[v_k, v_ax1])
+                    T.writes(T_matmul[v_ax0, v_ax1])
+                    with T.init():
+                        T_matmul[v_ax0, v_ax1] = T.float32(0)
+                    T_matmul[v_ax0, v_ax1] = T_matmul[v_ax0, v_ax1] + x[v_ax0, v_k] * lv2[v_k, v_ax1]
+
+        @R.function
+        def main(x: R.Tensor((16, 16), dtype="float32"), packed_params: R.Tuple(R.Tensor((16, 16), dtype="float16"), R.Tensor((16, 16), dtype="float16"))) -> R.Tensor((16, 16), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            cls = Before
+            with R.dataflow():
+                lv: R.Tensor((16, 16), dtype="float16") = packed_params[0]
+                lv1: R.Tensor((16, 16), dtype="float16") = packed_params[1]
+                lv2 = R.call_tir(cls.cast, (lv,), out_sinfo=R.Tensor((16, 16), dtype="float32"))
+                lv3 = R.call_tir(cls.matmul, (x, lv2), out_sinfo=R.Tensor((16, 16), dtype="float32"))
+                lv4 = R.call_tir(cls.cast, (lv1,), out_sinfo=R.Tensor((16, 16), dtype="float32"))
+                lv5 = R.call_tir(cls.matmul, (lv3, lv4), out_sinfo=R.Tensor((16, 16), dtype="float32"))
+                gv: R.Tensor((16, 16), dtype="float32") = lv5
+                R.output(gv)
+            return gv
+    # fmt: on
+
+    Expected = Before
+    _check(Before, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
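The `_check` helper referenced above is defined earlier in the test file and is not shown in this diff; presumably it applies `relax.transform.FuseOps()` and compares the result structurally against the expected module. Under that assumption, a standalone equivalent would look like:

    def check_fuse(before, expected):
        # Run operator fusion, then verify the module is structurally unchanged:
        # the TupleGetItem bindings on packed_params must stay where they are
        # rather than being pulled into fused groups.
        after = tvm.relax.transform.FuseOps()(before)
        tvm.ir.assert_structural_equal(after, expected)

Since the `TupleGetItem` bindings are now opaque and the two matmul chains each depend on the runtime input, no fusion applies here, which is why `Expected = Before`.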
