Skip to content

Commit e40ad6d

Browse files
authored
[0-size Tensor No.256] Add 0-size Tensor support for paddle.sgn API (#73606)
* fix sgn 0size * fix as_complex 0-size
1 parent 7f4afa7 commit e40ad6d

File tree

2 files changed

+46
-0
lines changed

2 files changed

+46
-0
lines changed

paddle/phi/kernels/stride/as_complex_kernel.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,21 @@ void AsComplexStridedKernel(const Context& dev_ctx,
3131
"FLAGS_use_stride_kernel is closed. Strided kernel "
3232
"be called, something wrong has happened!"));
3333
}
34+
if (out && out->numel() == 0) {
35+
if (x.dtype() == DataType::FLOAT32) {
36+
out->set_type(DataType::COMPLEX64);
37+
} else if (x.dtype() == DataType::FLOAT64) {
38+
out->set_type(DataType::COMPLEX128);
39+
} else {
40+
PADDLE_THROW(common::errors::Unimplemented(
41+
"as_complex is not supported data type (%s).",
42+
DataTypeToString(x.dtype())));
43+
}
44+
out->set_offset(x.offset());
45+
out->ResetHolder(x.Holder());
46+
out->ShareInplaceVersionCounterWith(x);
47+
return;
48+
}
3449

3550
PADDLE_ENFORCE_EQ(
3651
x.strides()[x.strides().size() - 1],

test/legacy_test/test_sgn.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,37 @@ def test_float_static_and_pir(self):
143143
z_expected = np_sgn(np_x)
144144
np.testing.assert_allclose(z, z_expected, rtol=1e-05)
145145

146+
def test_zero_size_complex_dynamic(self):
147+
for dtype in ['complex64', 'complex128']:
148+
np_x = np.empty((0, 4), dtype=dtype) # 空张量 shape=[0, 4]
149+
x = paddle.to_tensor(np_x)
150+
z = paddle.sgn(x)
151+
np_z = z.numpy()
152+
z_expected = np_sgn(np_x)
153+
np.testing.assert_allclose(np_z, z_expected, rtol=1e-05)
154+
np.testing.assert_equal(np_z.shape, (0, 4))
155+
156+
def test_zero_size_complex_static_and_pir(self):
157+
with static_guard():
158+
for dtype in ['complex64', 'complex128']:
159+
exe = paddle.static.Executor()
160+
train_program = paddle.static.Program()
161+
startup_program = paddle.static.Program()
162+
with paddle.static.program_guard(
163+
train_program, startup_program
164+
):
165+
x = paddle.static.data(name='X', shape=[0, 4], dtype=dtype)
166+
z = paddle.sgn(x)
167+
168+
exe.run(startup_program)
169+
x_np = np.empty((0, 4), dtype=dtype)
170+
(z_out,) = exe.run(
171+
train_program, feed={"X": x_np}, fetch_list=[z]
172+
)
173+
z_expected = np_sgn(x_np)
174+
np.testing.assert_allclose(z_out, z_expected, rtol=1e-05)
175+
np.testing.assert_equal(z_out.shape, (0, 4))
176+
146177

147178
if __name__ == "__main__":
148179
unittest.main()

0 commit comments

Comments
 (0)