Skip to content

Commit 07478af

Browse files
authored
[Relax] Fix issue in fuse concat ops by pattern (#18163)
* [Relax] Fix issue in fuse concat ops by pattern * fix lint
1 parent 8a914e5 commit 07478af

File tree

2 files changed

+80
-1
lines changed

2 files changed

+80
-1
lines changed

src/relax/transform/fuse_ops.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -427,10 +427,16 @@ class FunctionCreator : public ExprMutator {
427427
}
428428

429429
for (const Expr& arg : call->args) {
430-
CheckDefAndUpdateParam(arg);
431430
if (GetStructInfoAs<TupleStructInfoNode>(arg) != nullptr) {
432431
// The argument is fully referenced. Thus we remove it from the mapping.
433432
partially_used_tuple_params_.erase(arg.get());
433+
const Tuple& tup_args = Downcast<Tuple>(arg);
434+
for (const Expr& tup_arg : tup_args->fields) {
435+
CheckDefAndUpdateParam(tup_arg);
436+
ICHECK(GetStructInfoAs<TupleStructInfoNode>(tup_arg) == nullptr);
437+
}
438+
} else {
439+
CheckDefAndUpdateParam(arg);
434440
}
435441
}
436442
}

tests/python/relax/test_transform_fuse_ops_by_pattern.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
is_tuple_get_item,
2727
make_fused_bias_activation_pattern,
2828
wildcard,
29+
is_tuple,
2930
)
3031
from tvm.relax.transform import PatternCheckContext
3132
from tvm.script import ir as I
@@ -1348,5 +1349,77 @@ def local_func(
13481349
tvm.ir.assert_structural_equal(Expected, After)
13491350

13501351

1352+
def test_concat():
1353+
@R.function
1354+
def func(x: R.Tensor((10,), "float32"), y: R.Tensor((10,), "float32")):
1355+
R.func_attr({"global_symbol": "main"})
1356+
with R.dataflow():
1357+
lv = R.abs(x)
1358+
lv1 = R.abs(y)
1359+
lv2 = R.concat([lv, lv1])
1360+
gv = R.nn.relu(lv2)
1361+
R.output(gv)
1362+
return gv
1363+
1364+
@I.ir_module
1365+
class Expected1:
1366+
@R.function(private=True)
1367+
def fused_relax_abs_relax_abs_relax_concat(
1368+
x: R.Tensor((10,), dtype="float32"), y: R.Tensor((10,), dtype="float32")
1369+
) -> R.Tensor((20,), dtype="float32"):
1370+
R.func_attr({"Composite": "x.concat_abs_abs", "Primitive": True})
1371+
with R.dataflow():
1372+
lv: R.Tensor((10,), dtype="float32") = R.abs(x)
1373+
lv1: R.Tensor((10,), dtype="float32") = R.abs(y)
1374+
gv: R.Tensor((20,), dtype="float32") = R.concat((lv, lv1), axis=0)
1375+
R.output(gv)
1376+
return gv
1377+
1378+
@R.function
1379+
def main(
1380+
x: R.Tensor((10,), dtype="float32"), y: R.Tensor((10,), dtype="float32")
1381+
) -> R.Tensor((20,), dtype="float32"):
1382+
with R.dataflow():
1383+
lv: R.Tensor(
1384+
(20,), dtype="float32"
1385+
) = Expected1.fused_relax_abs_relax_abs_relax_concat(x, y)
1386+
gv: R.Tensor((20,), dtype="float32") = R.nn.relu(lv)
1387+
R.output(gv)
1388+
return gv
1389+
1390+
mod = tvm.IRModule({"main": func})
1391+
inp = is_tuple([is_op("relax.abs")(wildcard()), is_op("relax.abs")(wildcard())])
1392+
pat_clip = is_op("relax.concat")(inp)
1393+
1394+
check(mod, [("x.concat_abs_abs", pat_clip)], Expected1)
1395+
1396+
@I.ir_module
1397+
class Expected2:
1398+
@R.function(private=True)
1399+
def fused_relax_concat(
1400+
lv: R.Tensor((10,), dtype="float32"), lv1: R.Tensor((10,), dtype="float32")
1401+
) -> R.Tensor((20,), dtype="float32"):
1402+
R.func_attr({"Composite": "x.concat", "Primitive": True})
1403+
with R.dataflow():
1404+
gv: R.Tensor((20,), dtype="float32") = R.concat((lv, lv1), axis=0)
1405+
R.output(gv)
1406+
return gv
1407+
1408+
@R.function
1409+
def main(
1410+
x: R.Tensor((10,), dtype="float32"), y: R.Tensor((10,), dtype="float32")
1411+
) -> R.Tensor((20,), dtype="float32"):
1412+
with R.dataflow():
1413+
lv: R.Tensor((10,), dtype="float32") = R.abs(x)
1414+
lv1: R.Tensor((10,), dtype="float32") = R.abs(y)
1415+
lv_1: R.Tensor((20,), dtype="float32") = Expected2.fused_relax_concat(lv, lv1)
1416+
gv: R.Tensor((20,), dtype="float32") = R.nn.relu(lv_1)
1417+
R.output(gv)
1418+
return gv
1419+
1420+
pat_clip = is_op("relax.concat")(wildcard())
1421+
check(mod, [("x.concat", pat_clip)], Expected2)
1422+
1423+
13511424
if __name__ == "__main__":
13521425
pytest.main([__file__])

0 commit comments

Comments
 (0)