@@ -120,6 +120,29 @@ def test_out(self):
120120 np .testing .assert_allclose (ex_x2 , r2 , rtol = 1e-05 )
121121
122122
123+ class API_TestChunkZeroSize1 (unittest .TestCase ):
124+ def test_out (self ):
125+ with base .program_guard (base .Program (), base .Program ()):
126+ data1 = paddle .static .data (
127+ 'data1' , shape = [0 , 1 , 1 , 4 ], dtype = 'float32'
128+ )
129+ x0 , x1 , x2 , x3 = paddle .chunk (data1 , chunks = 4 , axis = - 1 )
130+ place = paddle .CPUPlace ()
131+ exe = paddle .static .Executor (place )
132+ input1 = np .random .random ([0 , 1 , 1 , 4 ]).astype ('float32' )
133+ (
134+ r0 ,
135+ r1 ,
136+ r2 ,
137+ r3 ,
138+ ) = exe .run (feed = {"data1" : input1 }, fetch_list = [x0 , x1 , x2 , x3 ])
139+ ex_x0 , ex_x1 , ex_x2 , ex_x3 = np .array_split (input1 , 4 , axis = - 1 )
140+ np .testing .assert_allclose (ex_x0 , r0 , rtol = 1e-05 )
141+ np .testing .assert_allclose (ex_x1 , r1 , rtol = 1e-05 )
142+ np .testing .assert_allclose (ex_x2 , r2 , rtol = 1e-05 )
143+ np .testing .assert_allclose (ex_x3 , r3 , rtol = 1e-05 )
144+
145+
123146class API_TestDygraphChunk (unittest .TestCase ):
124147 def test_out1 (self ):
125148 with base .dygraph .guard ():
0 commit comments