|
26 | 26 | is_tuple_get_item, |
27 | 27 | make_fused_bias_activation_pattern, |
28 | 28 | wildcard, |
| 29 | + is_tuple, |
29 | 30 | ) |
30 | 31 | from tvm.relax.transform import PatternCheckContext |
31 | 32 | from tvm.script import ir as I |
@@ -1348,5 +1349,77 @@ def local_func( |
1348 | 1349 | tvm.ir.assert_structural_equal(Expected, After) |
1349 | 1350 |
|
1350 | 1351 |
|
| 1352 | +def test_concat(): |
| 1353 | + @R.function |
| 1354 | + def func(x: R.Tensor((10,), "float32"), y: R.Tensor((10,), "float32")): |
| 1355 | + R.func_attr({"global_symbol": "main"}) |
| 1356 | + with R.dataflow(): |
| 1357 | + lv = R.abs(x) |
| 1358 | + lv1 = R.abs(y) |
| 1359 | + lv2 = R.concat([lv, lv1]) |
| 1360 | + gv = R.nn.relu(lv2) |
| 1361 | + R.output(gv) |
| 1362 | + return gv |
| 1363 | + |
| 1364 | + @I.ir_module |
| 1365 | + class Expected1: |
| 1366 | + @R.function(private=True) |
| 1367 | + def fused_relax_abs_relax_abs_relax_concat( |
| 1368 | + x: R.Tensor((10,), dtype="float32"), y: R.Tensor((10,), dtype="float32") |
| 1369 | + ) -> R.Tensor((20,), dtype="float32"): |
| 1370 | + R.func_attr({"Composite": "x.concat_abs_abs", "Primitive": True}) |
| 1371 | + with R.dataflow(): |
| 1372 | + lv: R.Tensor((10,), dtype="float32") = R.abs(x) |
| 1373 | + lv1: R.Tensor((10,), dtype="float32") = R.abs(y) |
| 1374 | + gv: R.Tensor((20,), dtype="float32") = R.concat((lv, lv1), axis=0) |
| 1375 | + R.output(gv) |
| 1376 | + return gv |
| 1377 | + |
| 1378 | + @R.function |
| 1379 | + def main( |
| 1380 | + x: R.Tensor((10,), dtype="float32"), y: R.Tensor((10,), dtype="float32") |
| 1381 | + ) -> R.Tensor((20,), dtype="float32"): |
| 1382 | + with R.dataflow(): |
| 1383 | + lv: R.Tensor( |
| 1384 | + (20,), dtype="float32" |
| 1385 | + ) = Expected1.fused_relax_abs_relax_abs_relax_concat(x, y) |
| 1386 | + gv: R.Tensor((20,), dtype="float32") = R.nn.relu(lv) |
| 1387 | + R.output(gv) |
| 1388 | + return gv |
| 1389 | + |
| 1390 | + mod = tvm.IRModule({"main": func}) |
| 1391 | + inp = is_tuple([is_op("relax.abs")(wildcard()), is_op("relax.abs")(wildcard())]) |
| 1392 | + pat_clip = is_op("relax.concat")(inp) |
| 1393 | + |
| 1394 | + check(mod, [("x.concat_abs_abs", pat_clip)], Expected1) |
| 1395 | + |
| 1396 | + @I.ir_module |
| 1397 | + class Expected2: |
| 1398 | + @R.function(private=True) |
| 1399 | + def fused_relax_concat( |
| 1400 | + lv: R.Tensor((10,), dtype="float32"), lv1: R.Tensor((10,), dtype="float32") |
| 1401 | + ) -> R.Tensor((20,), dtype="float32"): |
| 1402 | + R.func_attr({"Composite": "x.concat", "Primitive": True}) |
| 1403 | + with R.dataflow(): |
| 1404 | + gv: R.Tensor((20,), dtype="float32") = R.concat((lv, lv1), axis=0) |
| 1405 | + R.output(gv) |
| 1406 | + return gv |
| 1407 | + |
| 1408 | + @R.function |
| 1409 | + def main( |
| 1410 | + x: R.Tensor((10,), dtype="float32"), y: R.Tensor((10,), dtype="float32") |
| 1411 | + ) -> R.Tensor((20,), dtype="float32"): |
| 1412 | + with R.dataflow(): |
| 1413 | + lv: R.Tensor((10,), dtype="float32") = R.abs(x) |
| 1414 | + lv1: R.Tensor((10,), dtype="float32") = R.abs(y) |
| 1415 | + lv_1: R.Tensor((20,), dtype="float32") = Expected2.fused_relax_concat(lv, lv1) |
| 1416 | + gv: R.Tensor((20,), dtype="float32") = R.nn.relu(lv_1) |
| 1417 | + R.output(gv) |
| 1418 | + return gv |
| 1419 | + |
| 1420 | + pat_clip = is_op("relax.concat")(wildcard()) |
| 1421 | + check(mod, [("x.concat", pat_clip)], Expected2) |
| 1422 | + |
| 1423 | + |
1351 | 1424 | if __name__ == "__main__": |
1352 | 1425 | pytest.main([__file__]) |
0 commit comments