Skip to content

Commit 3eed4e3

Browse files
committed
add unit test
1 parent 857e214 commit 3eed4e3

File tree

7 files changed

+212
-33
lines changed

7 files changed

+212
-33
lines changed

python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -604,7 +604,7 @@ def _split(self, node: fx.node.Node) -> relax.Var:
604604
dim = 0
605605
if isinstance(split_size, (list, tuple)):
606606
n_section = []
607-
for s in split_size:
607+
for s in split_size[:-1]:
608608
cum_sum = 0 if not n_section else n_section[-1]
609609
n_section.append(s + cum_sum)
610610
else:

tests/python/contrib/test_msc/test_graph_build.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1345,11 +1345,15 @@ def forward(self, x_1, x_2, x_3):
13451345
def test_split():
13461346
"""test graph builder for split"""
13471347

1348-
class Split(Module):
1348+
class Split1(Module):
13491349
def forward(self, data):
13501350
return torch.split(data, 1, dim=1)
13511351

1352-
expected = {
1352+
class Split2(Module):
1353+
def forward(self, data):
1354+
return torch.split(data, [1, 2], dim=1)
1355+
1356+
expected1 = {
13531357
"inputs": [
13541358
{"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
13551359
],
@@ -1361,8 +1365,20 @@ def forward(self, data):
13611365
"nodes": {"total": 2, "input": 1, "split": 1},
13621366
}
13631367

1368+
expected2 = {
1369+
"inputs": [
1370+
{"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
1371+
],
1372+
"outputs": [
1373+
{"name": "split_0", "shape": [1, 1, 10, 10], "dtype": "float32", "layout": "ABCD"},
1374+
{"name": "split_1", "shape": [1, 2, 10, 10], "dtype": "float32", "layout": "ABCD"},
1375+
],
1376+
"nodes": {"total": 2, "input": 1, "split": 1},
1377+
}
1378+
13641379
input_info = [([1, 3, 10, 10], "float32")]
1365-
verify_model(Split(), input_info, expected)
1380+
verify_model(Split1(), input_info, expected1)
1381+
verify_model(Split2(), input_info, expected2)
13661382

13671383

13681384
def test_unbind():
@@ -1570,10 +1586,14 @@ def forward(self, x):
15701586
def test_expand():
15711587
"""test graph builder for expand"""
15721588

1573-
class Expand(Module):
1589+
class Expand1(Module):
15741590
def forward(self, x):
15751591
return x.expand(4, 2, 3, 4)
15761592

1593+
class Expand2(Module):
1594+
def forward(self, x):
1595+
return x.expand(4, -1, -1, 4)
1596+
15771597
expected = {
15781598
"inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": ""}],
15791599
"outputs": [
@@ -1583,7 +1603,8 @@ def forward(self, x):
15831603
}
15841604

15851605
input_info = [([1, 2, 3, 4], "float32")]
1586-
verify_model(Expand(), input_info, expected)
1606+
verify_model(Expand1(), input_info, expected)
1607+
verify_model(Expand2(), input_info, expected)
15871608

15881609

15891610
def test_reduce():

tests/python/contrib/test_msc/test_translate_relax.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,12 @@ def _run_relax(relax_mod):
6767

6868
orig_output = _run_relax(orig_mod)
6969
rt_output = _run_relax(rt_mod)
70-
tvm.testing.assert_allclose(orig_output, rt_output)
70+
if not isinstance(orig_output, (list, tuple)):
71+
orig_output = [orig_output]
72+
if not isinstance(rt_output, (list, tuple)):
73+
rt_output = [rt_output]
74+
for o_out, r_out in zip(orig_output, rt_output):
75+
tvm.testing.assert_allclose(o_out, r_out)
7176

7277

7378
def test_conv1d():
@@ -750,12 +755,17 @@ def forward(self, x_1, x_2, x_3):
750755
def test_split():
751756
"""test relax translator for split"""
752757

753-
class Split(Module):
758+
class Split1(Module):
754759
def forward(self, data):
755760
return torch.split(data, 1, dim=1)
756761

762+
class Split2(Module):
763+
def forward(self, data):
764+
return torch.split(data, [1, 2], dim=1)
765+
757766
input_info = [([1, 3, 10, 10], "float32")]
758-
_verify_model(Split(), input_info)
767+
_verify_model(Split1(), input_info)
768+
_verify_model(Split2(), input_info)
759769

760770

761771
def test_unbind():
@@ -890,12 +900,17 @@ def forward(self, x):
890900
def test_expand():
891901
"""test relax translator for expand"""
892902

893-
class Expand(Module):
903+
class Expand1(Module):
894904
def forward(self, x):
895905
return x.expand(4, 2, 3, 4)
896906

907+
class Expand2(Module):
908+
def forward(self, x):
909+
return x.expand(4, -1, -1, 4)
910+
897911
input_info = [([1, 2, 3, 4], "float32")]
898-
_verify_model(Expand(), input_info)
912+
_verify_model(Expand1(), input_info)
913+
_verify_model(Expand2(), input_info)
899914

900915

901916
def test_reduce():

tests/python/contrib/test_msc/test_translate_relay.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -731,12 +731,17 @@ def forward(self, x_1, x_2, x_3):
731731
def test_split():
732732
"""test relay to relax for split"""
733733

734-
class Split(Module):
734+
class Split1(Module):
735735
def forward(self, data):
736736
return torch.split(data, 1, dim=1)
737737

738+
class Split2(Module):
739+
def forward(self, data):
740+
return torch.split(data, [1, 2], dim=1)
741+
738742
input_info = [([1, 3, 10, 10], "float32")]
739-
verify_model(Split(), input_info, build_target="llvm")
743+
verify_model(Split1(), input_info, build_target="llvm")
744+
verify_model(Split2(), input_info, build_target="llvm")
740745

741746

742747
def test_unbind():
@@ -875,12 +880,17 @@ def forward(self, x):
875880
def test_expand():
876881
"""test relay to relax for expand"""
877882

878-
class Expand(Module):
883+
class Expand1(Module):
879884
def forward(self, x):
880885
return x.expand(4, 2, 3, 4)
881886

887+
class Expand2(Module):
888+
def forward(self, x):
889+
return x.expand(4, -1, -1, 4)
890+
882891
input_info = [([1, 2, 3, 4], "float32")]
883-
verify_model(Expand(), input_info, build_target="llvm")
892+
verify_model(Expand1(), input_info, build_target="llvm")
893+
verify_model(Expand2(), input_info, build_target="llvm")
884894

885895

886896
def test_reduce():

tests/python/contrib/test_msc/test_translate_tensorrt.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -673,12 +673,17 @@ def forward(self, x_1, x_2, x_3):
673673
def test_split():
674674
"""test tensorrt translator for split"""
675675

676-
class Split(Module):
676+
class Split1(Module):
677677
def forward(self, data):
678678
return torch.split(data, 1, dim=1)
679679

680+
class Split2(Module):
681+
def forward(self, data):
682+
return torch.split(data, [1, 2], dim=1)
683+
680684
input_info = [([1, 3, 10, 10], "float32")]
681-
verify_model(Split(), input_info)
685+
verify_model(Split1(), input_info)
686+
verify_model(Split2(), input_info)
682687

683688

684689
@requires_tensorrt
@@ -714,13 +719,19 @@ def forward(self, data):
714719
def test_expand():
715720
"""test tensorrt translator for expand"""
716721

717-
class Expand(Module):
722+
class Expand1(Module):
718723
def forward(self, x):
719724
x = x + 1.0
720725
return x.expand(4, 2, 3, 4)
721726

727+
class Expand2(Module):
728+
def forward(self, x):
729+
x = x + 1.0
730+
return x.expand(4, -1, -1, 4)
731+
722732
input_info = [([1, 2, 3, 4], "float32")]
723-
verify_model(Expand(), input_info)
733+
verify_model(Expand1(), input_info)
734+
verify_model(Expand2(), input_info)
724735

725736

726737
@requires_tensorrt

tests/python/contrib/test_msc/test_translate_torch.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -728,13 +728,18 @@ def forward(self, x_1, x_2, x_3):
728728
def test_split():
729729
"""test torch translator for split"""
730730

731-
class Split(Module):
731+
class Split1(Module):
732732
def forward(self, data):
733733
return torch.split(data, 1, dim=1)
734734

735+
class Split2(Module):
736+
def forward(self, data):
737+
return torch.split(data, [1, 2], dim=1)
738+
735739
input_info = [([1, 3, 10, 10], "float32")]
736740
for via_relax in [True, False]:
737-
verify_model(Split(), input_info, via_relax)
741+
verify_model(Split1(), input_info, via_relax)
742+
verify_model(Split2(), input_info, via_relax)
738743

739744

740745
def test_unbind():
@@ -852,13 +857,18 @@ def forward(self, x):
852857
def test_expand():
853858
"""test torch translator for expand"""
854859

855-
class Expand(Module):
860+
class Expand1(Module):
856861
def forward(self, x):
857862
return x.expand(4, 2, 3, 4)
858863

864+
class Expand2(Module):
865+
def forward(self, x):
866+
return x.expand(4, -1, -1, 4)
867+
859868
input_info = [([1, 2, 3, 4], "float32")]
860869
for via_relax in [True, False]:
861-
verify_model(Expand(), input_info, via_relax)
870+
verify_model(Expand1(), input_info, via_relax)
871+
verify_model(Expand2(), input_info, via_relax)
862872

863873

864874
def test_reduce():

0 commit comments

Comments
 (0)