@@ -143,6 +143,37 @@ def test_float_static_and_pir(self):
143143 z_expected = np_sgn (np_x )
144144 np .testing .assert_allclose (z , z_expected , rtol = 1e-05 )
145145
146+ def test_zero_size_complex_dynamic (self ):
147+ for dtype in ['complex64' , 'complex128' ]:
148+ np_x = np .empty ((0 , 4 ), dtype = dtype ) # 空张量 shape=[0, 4]
149+ x = paddle .to_tensor (np_x )
150+ z = paddle .sgn (x )
151+ np_z = z .numpy ()
152+ z_expected = np_sgn (np_x )
153+ np .testing .assert_allclose (np_z , z_expected , rtol = 1e-05 )
154+ np .testing .assert_equal (np_z .shape , (0 , 4 ))
155+
156+ def test_zero_size_complex_static_and_pir (self ):
157+ with static_guard ():
158+ for dtype in ['complex64' , 'complex128' ]:
159+ exe = paddle .static .Executor ()
160+ train_program = paddle .static .Program ()
161+ startup_program = paddle .static .Program ()
162+ with paddle .static .program_guard (
163+ train_program , startup_program
164+ ):
165+ x = paddle .static .data (name = 'X' , shape = [0 , 4 ], dtype = dtype )
166+ z = paddle .sgn (x )
167+
168+ exe .run (startup_program )
169+ x_np = np .empty ((0 , 4 ), dtype = dtype )
170+ (z_out ,) = exe .run (
171+ train_program , feed = {"X" : x_np }, fetch_list = [z ]
172+ )
173+ z_expected = np_sgn (x_np )
174+ np .testing .assert_allclose (z_out , z_expected , rtol = 1e-05 )
175+ np .testing .assert_equal (z_out .shape , (0 , 4 ))
176+
146177
147178if __name__ == "__main__" :
148179 unittest .main ()
0 commit comments