Skip to content

Commit 384360f

Browse files
authored
[Relax][Bugfix] Support torch.unbind op and fix bugs for expand && split (#17292)
* support unbind * add unit test * format fix * ignore logging in ut
1 parent 47e964a commit 384360f

File tree

9 files changed

+336
-29
lines changed

9 files changed

+336
-29
lines changed

python/tvm/contrib/msc/core/frontend/translate.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def from_relax(
119119
)(mod)
120120
patterns = get_patterns_with_prefix("msc.")
121121
passes = [
122+
tvm.relax.transform.ExpandTupleArguments(),
122123
msc_transform.SetExprName(),
123124
msc_transform.SetExprLayout(trans_config.get("allow_layout_missing", True)),
124125
tvm.relax.transform.FuseOpsByPattern(
@@ -310,6 +311,7 @@ def byoc_partition(
310311
def _partition_mod(mod, as_msc=True):
311312
patterns = get_patterns_with_prefix(target)
312313
passes = [
314+
tvm.relax.transform.ExpandTupleArguments(),
313315
msc_transform.SetExprName(),
314316
msc_transform.SetExprLayout(trans_config.get("allow_layout_missing", True)),
315317
tvm.relax.transform.FuseOpsByPattern(patterns, bind_constants=not as_msc),

python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,22 @@ def _einsum(self, node: fx.node.Node) -> relax.Var:
526526
return self.block_builder.emit(relax.op.einsum(tuple(args[1]), args[0]))
527527
return self.block_builder.emit(relax.op.einsum(args[1:], args[0]))
528528

529+
def _unbind(self, node: fx.node.Node) -> relax.Var:
530+
if len(node.args) == 2:
531+
assert isinstance(node.args[1], int), "Expected 2nd argument of unbind as int"
532+
dim = node.args[1]
533+
elif "dim" in node.kwargs:
534+
dim = node.kwargs["dim"]
535+
else:
536+
dim = 0
537+
x = self.env[node.args[0]]
538+
selections = self.shape_of(x)[dim].value
539+
n_section = list(range(1, selections + 1))
540+
ret, split = [], self.block_builder.emit(relax.op.split(x, n_section, dim))
541+
for i in range(selections):
542+
ret.append(self.block_builder.emit(relax.op.squeeze(split[i], axis=dim)))
543+
return self.block_builder.emit(relax.Tuple(ret))
544+
529545
########## Manipulation ##########
530546

531547
def _cat(self, node: fx.node.Node) -> relax.Var:
@@ -535,7 +551,13 @@ def _cat(self, node: fx.node.Node) -> relax.Var:
535551

536552
def _expand(self, node: fx.node.Node) -> relax.Var:
537553
args = self.retrieve_args(node)
538-
return self.block_builder.emit(relax.op.broadcast_to(args[0], args[1:]))
554+
broadcast_shape, in_shape = [], self.shape_of(args[0])
555+
for idx, i in enumerate(args[1:]):
556+
if isinstance(i, int) and i == -1:
557+
broadcast_shape.append(in_shape[idx])
558+
else:
559+
broadcast_shape.append(i)
560+
return self.block_builder.emit(relax.op.broadcast_to(args[0], broadcast_shape))
539561

540562
def _flatten(self, node: fx.node.Node) -> relax.Var:
541563
x = self.env[node.args[0]]
@@ -580,7 +602,13 @@ def _split(self, node: fx.node.Node) -> relax.Var:
580602
dim = node.kwargs["dim"]
581603
else:
582604
dim = 0
583-
n_section = (self.shape_of(x)[dim].value + split_size - 1) // split_size
605+
if isinstance(split_size, (list, tuple)):
606+
n_section = []
607+
for s in split_size[:-1]:
608+
cum_sum = 0 if not n_section else n_section[-1]
609+
n_section.append(s + cum_sum)
610+
else:
611+
n_section = (self.shape_of(x)[dim].value + split_size - 1) // split_size
584612
return self.block_builder.emit(relax.op.split(x, n_section, dim))
585613

586614
def _chunk(self, node: fx.node.Node) -> relax.Var:
@@ -1501,6 +1529,7 @@ def create_convert_map(self):
15011529
"cross_entropy": self._cross_entropy,
15021530
"scaled_dot_product_attention": self._scaled_dot_product_attention,
15031531
"einsum": self._einsum,
1532+
"unbind": self._unbind,
15041533
}
15051534

15061535
def update_convert_map(self, custom_convert_map: dict):

tests/python/contrib/test_msc/test_graph_build.py

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1345,11 +1345,15 @@ def forward(self, x_1, x_2, x_3):
13451345
def test_split():
13461346
"""test graph builder for split"""
13471347

1348-
class Split(Module):
1348+
class Split1(Module):
13491349
def forward(self, data):
13501350
return torch.split(data, 1, dim=1)
13511351

1352-
expected = {
1352+
class Split2(Module):
1353+
def forward(self, data):
1354+
return torch.split(data, [1, 2], dim=1)
1355+
1356+
expected1 = {
13531357
"inputs": [
13541358
{"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
13551359
],
@@ -1361,8 +1365,43 @@ def forward(self, data):
13611365
"nodes": {"total": 2, "input": 1, "split": 1},
13621366
}
13631367

1368+
expected2 = {
1369+
"inputs": [
1370+
{"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
1371+
],
1372+
"outputs": [
1373+
{"name": "split_0", "shape": [1, 1, 10, 10], "dtype": "float32", "layout": "ABCD"},
1374+
{"name": "split_1", "shape": [1, 2, 10, 10], "dtype": "float32", "layout": "ABCD"},
1375+
],
1376+
"nodes": {"total": 2, "input": 1, "split": 1},
1377+
}
1378+
1379+
input_info = [([1, 3, 10, 10], "float32")]
1380+
verify_model(Split1(), input_info, expected1)
1381+
verify_model(Split2(), input_info, expected2)
1382+
1383+
1384+
def test_unbind():
1385+
"""test graph builder for unbind"""
1386+
1387+
class Unbind(Module):
1388+
def forward(self, data):
1389+
return torch.unbind(data, dim=1)
1390+
1391+
expected = {
1392+
"inputs": [
1393+
{"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
1394+
],
1395+
"outputs": [
1396+
{"name": "tuple_0", "shape": [1, 10, 10], "dtype": "float32", "layout": "ACD"},
1397+
{"name": "tuple_1", "shape": [1, 10, 10], "dtype": "float32", "layout": "ACD"},
1398+
{"name": "tuple_2", "shape": [1, 10, 10], "dtype": "float32", "layout": "ACD"},
1399+
],
1400+
"nodes": {"total": 9, "input": 1, "split": 1, "get_item": 3, "squeeze": 3, "tuple": 1},
1401+
}
1402+
13641403
input_info = [([1, 3, 10, 10], "float32")]
1365-
verify_model(Split(), input_info, expected)
1404+
verify_model(Unbind(), input_info, expected)
13661405

13671406

13681407
def test_cumsum():
@@ -1547,10 +1586,14 @@ def forward(self, x):
15471586
def test_expand():
15481587
"""test graph builder for expand"""
15491588

1550-
class Expand(Module):
1589+
class Expand1(Module):
15511590
def forward(self, x):
15521591
return x.expand(4, 2, 3, 4)
15531592

1593+
class Expand2(Module):
1594+
def forward(self, x):
1595+
return x.expand(4, -1, -1, 4)
1596+
15541597
expected = {
15551598
"inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": ""}],
15561599
"outputs": [
@@ -1560,7 +1603,8 @@ def forward(self, x):
15601603
}
15611604

15621605
input_info = [([1, 2, 3, 4], "float32")]
1563-
verify_model(Expand(), input_info, expected)
1606+
verify_model(Expand1(), input_info, expected)
1607+
verify_model(Expand2(), input_info, expected)
15641608

15651609

15661610
def test_reduce():

tests/python/contrib/test_msc/test_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def _get_config(model_type, compile_type, inputs, outputs, dynamic=False, atol=1
3838
path = "test_pipe_{}_{}_{}".format(model_type, compile_type, "dynamic" if dynamic else "static")
3939
return {
4040
"workspace": msc_utils.msc_dir(path),
41-
"verbose": "info",
41+
"verbose": "critical",
4242
"model_type": model_type,
4343
"inputs": inputs,
4444
"outputs": outputs,

tests/python/contrib/test_msc/test_translate_relax.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,12 @@ def _run_relax(relax_mod):
6767

6868
orig_output = _run_relax(orig_mod)
6969
rt_output = _run_relax(rt_mod)
70-
tvm.testing.assert_allclose(orig_output, rt_output)
70+
if not isinstance(orig_output, (list, tuple)):
71+
orig_output = [orig_output]
72+
if not isinstance(rt_output, (list, tuple)):
73+
rt_output = [rt_output]
74+
for o_out, r_out in zip(orig_output, rt_output):
75+
tvm.testing.assert_allclose(o_out, r_out)
7176

7277

7378
def test_conv1d():
@@ -750,12 +755,33 @@ def forward(self, x_1, x_2, x_3):
750755
def test_split():
751756
"""test relax translator for split"""
752757

753-
class Split(Module):
758+
class Split1(Module):
754759
def forward(self, data):
755760
return torch.split(data, 1, dim=1)
756761

762+
class Split2(Module):
763+
def forward(self, data):
764+
return torch.split(data, [1, 2], dim=1)
765+
757766
input_info = [([1, 3, 10, 10], "float32")]
758-
_verify_model(Split(), input_info)
767+
_verify_model(Split1(), input_info)
768+
_verify_model(Split2(), input_info)
769+
770+
771+
def test_unbind():
772+
"""test relax translator for unbind"""
773+
774+
class Unbind1(Module):
775+
def forward(self, data):
776+
return torch.unbind(data)
777+
778+
class Unbind2(Module):
779+
def forward(self, data):
780+
return torch.unbind(data, dim=1)
781+
782+
input_info = [([3, 3, 10, 10], "float32")]
783+
_verify_model(Unbind1(), input_info)
784+
_verify_model(Unbind2(), input_info)
759785

760786

761787
def test_cumsum():
@@ -874,12 +900,17 @@ def forward(self, x):
874900
def test_expand():
875901
"""test relax translator for expand"""
876902

877-
class Expand(Module):
903+
class Expand1(Module):
878904
def forward(self, x):
879905
return x.expand(4, 2, 3, 4)
880906

907+
class Expand2(Module):
908+
def forward(self, x):
909+
return x.expand(4, -1, -1, 4)
910+
881911
input_info = [([1, 2, 3, 4], "float32")]
882-
_verify_model(Expand(), input_info)
912+
_verify_model(Expand1(), input_info)
913+
_verify_model(Expand2(), input_info)
883914

884915

885916
def test_reduce():

tests/python/contrib/test_msc/test_translate_relay.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -731,12 +731,33 @@ def forward(self, x_1, x_2, x_3):
731731
def test_split():
732732
"""test relay to relax for split"""
733733

734-
class Split(Module):
734+
class Split1(Module):
735735
def forward(self, data):
736736
return torch.split(data, 1, dim=1)
737737

738+
class Split2(Module):
739+
def forward(self, data):
740+
return torch.split(data, [1, 2], dim=1)
741+
738742
input_info = [([1, 3, 10, 10], "float32")]
739-
verify_model(Split(), input_info, build_target="llvm")
743+
verify_model(Split1(), input_info, build_target="llvm")
744+
verify_model(Split2(), input_info, build_target="llvm")
745+
746+
747+
def test_unbind():
748+
"""test relay to relax for unbind"""
749+
750+
class Unbind1(Module):
751+
def forward(self, data):
752+
return torch.unbind(data)
753+
754+
class Unbind2(Module):
755+
def forward(self, data):
756+
return torch.unbind(data, dim=1)
757+
758+
input_info = [([3, 3, 10, 10], "float32")]
759+
verify_model(Unbind1(), input_info, build_target="llvm")
760+
verify_model(Unbind2(), input_info, build_target="llvm")
740761

741762

742763
def test_cumsum():
@@ -859,12 +880,17 @@ def forward(self, x):
859880
def test_expand():
860881
"""test relay to relax for expand"""
861882

862-
class Expand(Module):
883+
class Expand1(Module):
863884
def forward(self, x):
864885
return x.expand(4, 2, 3, 4)
865886

887+
class Expand2(Module):
888+
def forward(self, x):
889+
return x.expand(4, -1, -1, 4)
890+
866891
input_info = [([1, 2, 3, 4], "float32")]
867-
verify_model(Expand(), input_info, build_target="llvm")
892+
verify_model(Expand1(), input_info, build_target="llvm")
893+
verify_model(Expand2(), input_info, build_target="llvm")
868894

869895

870896
def test_reduce():

tests/python/contrib/test_msc/test_translate_tensorrt.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -673,12 +673,34 @@ def forward(self, x_1, x_2, x_3):
673673
def test_split():
674674
"""test tensorrt translator for split"""
675675

676-
class Split(Module):
676+
class Split1(Module):
677677
def forward(self, data):
678678
return torch.split(data, 1, dim=1)
679679

680+
class Split2(Module):
681+
def forward(self, data):
682+
return torch.split(data, [1, 2], dim=1)
683+
680684
input_info = [([1, 3, 10, 10], "float32")]
681-
verify_model(Split(), input_info)
685+
verify_model(Split1(), input_info)
686+
verify_model(Split2(), input_info)
687+
688+
689+
@requires_tensorrt
690+
def test_unbind():
691+
"""test tensorrt to relax for unbind"""
692+
693+
class Unbind1(Module):
694+
def forward(self, data):
695+
return torch.unbind(data)
696+
697+
class Unbind2(Module):
698+
def forward(self, data):
699+
return torch.unbind(data, dim=1)
700+
701+
input_info = [([3, 3, 10, 10], "float32")]
702+
verify_model(Unbind1(), input_info)
703+
verify_model(Unbind2(), input_info)
682704

683705

684706
@requires_tensorrt
@@ -697,13 +719,19 @@ def forward(self, data):
697719
def test_expand():
698720
"""test tensorrt translator for expand"""
699721

700-
class Expand(Module):
722+
class Expand1(Module):
701723
def forward(self, x):
702724
x = x + 1.0
703725
return x.expand(4, 2, 3, 4)
704726

727+
class Expand2(Module):
728+
def forward(self, x):
729+
x = x + 1.0
730+
return x.expand(4, -1, -1, 4)
731+
705732
input_info = [([1, 2, 3, 4], "float32")]
706-
verify_model(Expand(), input_info)
733+
verify_model(Expand1(), input_info)
734+
verify_model(Expand2(), input_info)
707735

708736

709737
@requires_tensorrt

0 commit comments

Comments
 (0)