@@ -1345,11 +1345,15 @@ def forward(self, x_1, x_2, x_3):
13451345def test_split ():
13461346 """test graph builder for split"""
13471347
1348- class Split (Module ):
1348+ class Split1 (Module ):
13491349 def forward (self , data ):
13501350 return torch .split (data , 1 , dim = 1 )
13511351
1352- expected = {
1352+ class Split2 (Module ):
1353+ def forward (self , data ):
1354+ return torch .split (data , [1 , 2 ], dim = 1 )
1355+
1356+ expected1 = {
13531357 "inputs" : [
13541358 {"name" : "inp_0" , "shape" : [1 , 3 , 10 , 10 ], "dtype" : "float32" , "layout" : "ABCD" }
13551359 ],
@@ -1361,8 +1365,20 @@ def forward(self, data):
13611365 "nodes" : {"total" : 2 , "input" : 1 , "split" : 1 },
13621366 }
13631367
1368+ expected2 = {
1369+ "inputs" : [
1370+ {"name" : "inp_0" , "shape" : [1 , 3 , 10 , 10 ], "dtype" : "float32" , "layout" : "ABCD" }
1371+ ],
1372+ "outputs" : [
1373+ {"name" : "split_0" , "shape" : [1 , 1 , 10 , 10 ], "dtype" : "float32" , "layout" : "ABCD" },
1374+ {"name" : "split_1" , "shape" : [1 , 2 , 10 , 10 ], "dtype" : "float32" , "layout" : "ABCD" },
1375+ ],
1376+ "nodes" : {"total" : 2 , "input" : 1 , "split" : 1 },
1377+ }
1378+
13641379 input_info = [([1 , 3 , 10 , 10 ], "float32" )]
1365- verify_model (Split (), input_info , expected )
1380+ verify_model (Split1 (), input_info , expected1 )
1381+ verify_model (Split2 (), input_info , expected2 )
13661382
13671383
13681384def test_unbind ():
@@ -1570,10 +1586,14 @@ def forward(self, x):
15701586def test_expand ():
15711587 """test graph builder for expand"""
15721588
1573- class Expand (Module ):
1589+ class Expand1 (Module ):
15741590 def forward (self , x ):
15751591 return x .expand (4 , 2 , 3 , 4 )
15761592
1593+ class Expand2 (Module ):
1594+ def forward (self , x ):
1595+ return x .expand (4 , - 1 , - 1 , 4 )
1596+
15771597 expected = {
15781598 "inputs" : [{"name" : "inp_0" , "shape" : [1 , 2 , 3 , 4 ], "dtype" : "float32" , "layout" : "" }],
15791599 "outputs" : [
@@ -1583,7 +1603,8 @@ def forward(self, x):
15831603 }
15841604
15851605 input_info = [([1 , 2 , 3 , 4 ], "float32" )]
1586- verify_model (Expand (), input_info , expected )
1606+ verify_model (Expand1 (), input_info , expected )
1607+ verify_model (Expand2 (), input_info , expected )
15871608
15881609
15891610def test_reduce ():
0 commit comments