Skip to content

Commit 2255259

Browse files
authored
[0-size Tensor No.31、293] Add 0-size Tensor test case for chunk (#73616)
* fix * fix * fix * fix * fix * fix * fix
1 parent 2ce32c9 commit 2255259

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

test/legacy_test/test_chunk_op.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,29 @@ def test_out(self):
120120
np.testing.assert_allclose(ex_x2, r2, rtol=1e-05)
121121

122122

123+
class API_TestChunkZeroSize1(unittest.TestCase):
124+
def test_out(self):
125+
with base.program_guard(base.Program(), base.Program()):
126+
data1 = paddle.static.data(
127+
'data1', shape=[0, 1, 1, 4], dtype='float32'
128+
)
129+
x0, x1, x2, x3 = paddle.chunk(data1, chunks=4, axis=-1)
130+
place = paddle.CPUPlace()
131+
exe = paddle.static.Executor(place)
132+
input1 = np.random.random([0, 1, 1, 4]).astype('float32')
133+
(
134+
r0,
135+
r1,
136+
r2,
137+
r3,
138+
) = exe.run(feed={"data1": input1}, fetch_list=[x0, x1, x2, x3])
139+
ex_x0, ex_x1, ex_x2, ex_x3 = np.array_split(input1, 4, axis=-1)
140+
np.testing.assert_allclose(ex_x0, r0, rtol=1e-05)
141+
np.testing.assert_allclose(ex_x1, r1, rtol=1e-05)
142+
np.testing.assert_allclose(ex_x2, r2, rtol=1e-05)
143+
np.testing.assert_allclose(ex_x3, r3, rtol=1e-05)
144+
145+
123146
class API_TestDygraphChunk(unittest.TestCase):
124147
def test_out1(self):
125148
with base.dygraph.guard():

0 commit comments

Comments
 (0)