@@ -1345,11 +1345,15 @@ def forward(self, x_1, x_2, x_3):
13451345def test_split ():
13461346 """test graph builder for split"""
13471347
1348- class Split (Module ):
1348+ class Split1 (Module ):
13491349 def forward (self , data ):
13501350 return torch .split (data , 1 , dim = 1 )
13511351
1352- expected = {
1352+ class Split2 (Module ):
1353+ def forward (self , data ):
1354+ return torch .split (data , [1 , 2 ], dim = 1 )
1355+
1356+ expected1 = {
13531357 "inputs" : [
13541358 {"name" : "inp_0" , "shape" : [1 , 3 , 10 , 10 ], "dtype" : "float32" , "layout" : "ABCD" }
13551359 ],
@@ -1361,8 +1365,43 @@ def forward(self, data):
13611365 "nodes" : {"total" : 2 , "input" : 1 , "split" : 1 },
13621366 }
13631367
1368+ expected2 = {
1369+ "inputs" : [
1370+ {"name" : "inp_0" , "shape" : [1 , 3 , 10 , 10 ], "dtype" : "float32" , "layout" : "ABCD" }
1371+ ],
1372+ "outputs" : [
1373+ {"name" : "split_0" , "shape" : [1 , 1 , 10 , 10 ], "dtype" : "float32" , "layout" : "ABCD" },
1374+ {"name" : "split_1" , "shape" : [1 , 2 , 10 , 10 ], "dtype" : "float32" , "layout" : "ABCD" },
1375+ ],
1376+ "nodes" : {"total" : 2 , "input" : 1 , "split" : 1 },
1377+ }
1378+
1379+ input_info = [([1 , 3 , 10 , 10 ], "float32" )]
1380+ verify_model (Split1 (), input_info , expected1 )
1381+ verify_model (Split2 (), input_info , expected2 )
1382+
1383+
1384+ def test_unbind ():
1385+ """test graph builder for unbind"""
1386+
1387+ class Unbind (Module ):
1388+ def forward (self , data ):
1389+ return torch .unbind (data , dim = 1 )
1390+
1391+ expected = {
1392+ "inputs" : [
1393+ {"name" : "inp_0" , "shape" : [1 , 3 , 10 , 10 ], "dtype" : "float32" , "layout" : "ABCD" }
1394+ ],
1395+ "outputs" : [
1396+ {"name" : "tuple_0" , "shape" : [1 , 10 , 10 ], "dtype" : "float32" , "layout" : "ACD" },
1397+ {"name" : "tuple_1" , "shape" : [1 , 10 , 10 ], "dtype" : "float32" , "layout" : "ACD" },
1398+ {"name" : "tuple_2" , "shape" : [1 , 10 , 10 ], "dtype" : "float32" , "layout" : "ACD" },
1399+ ],
1400+ "nodes" : {"total" : 9 , "input" : 1 , "split" : 1 , "get_item" : 3 , "squeeze" : 3 , "tuple" : 1 },
1401+ }
1402+
13641403 input_info = [([1 , 3 , 10 , 10 ], "float32" )]
1365- verify_model (Split (), input_info , expected )
1404+ verify_model (Unbind (), input_info , expected )
13661405
13671406
13681407def test_cumsum ():
@@ -1547,10 +1586,14 @@ def forward(self, x):
15471586def test_expand ():
15481587 """test graph builder for expand"""
15491588
1550- class Expand (Module ):
1589+ class Expand1 (Module ):
15511590 def forward (self , x ):
15521591 return x .expand (4 , 2 , 3 , 4 )
15531592
1593+ class Expand2 (Module ):
1594+ def forward (self , x ):
1595+ return x .expand (4 , - 1 , - 1 , 4 )
1596+
15541597 expected = {
15551598 "inputs" : [{"name" : "inp_0" , "shape" : [1 , 2 , 3 , 4 ], "dtype" : "float32" , "layout" : "" }],
15561599 "outputs" : [
@@ -1560,7 +1603,8 @@ def forward(self, x):
15601603 }
15611604
15621605 input_info = [([1 , 2 , 3 , 4 ], "float32" )]
1563- verify_model (Expand (), input_info , expected )
1606+ verify_model (Expand1 (), input_info , expected )
1607+ verify_model (Expand2 (), input_info , expected )
15641608
15651609
15661610def test_reduce ():
0 commit comments