Skip to content

Commit 67e2e22

Browse files
committed
Fix
1 parent a494615 commit 67e2e22

File tree

3 files changed

+41
-0
lines changed

3 files changed

+41
-0
lines changed

paddle/phi/kernels/cpu/allclose_kernel.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ void AllCloseKernel(const Context& dev_ctx,
3131
const Scalar& atol,
3232
bool equal_nan,
3333
DenseTensor* out) {
34+
if (out && out->numel() == 0) {
35+
dev_ctx.template Alloc<bool>(out);
36+
return;
37+
}
3438
double rtol_v = NAN, atol_v = NAN;
3539
if (rtol.dtype() == DataType::FLOAT64) {
3640
rtol_v = rtol.to<double>();

paddle/phi/kernels/gpu/allclose_kernel.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ void AllCloseKernel(const Context& dev_ctx,
6666
const Scalar& atol,
6767
bool equal_nan,
6868
DenseTensor* out) {
69+
if (out && out->numel() == 0) {
70+
dev_ctx.template Alloc<bool>(out);
71+
return;
72+
}
6973
double rtol_v, atol_v;
7074
if (rtol.dtype() == DataType::FLOAT64) {
7175
rtol_v = rtol.to<double>();

test/legacy_test/test_allclose_op.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -463,5 +463,38 @@ def set_args(self):
463463
self.equal_nan = False
464464

465465

466+
def create_test_class(op_type, dtype, shape):
467+
class Cls(unittest.TestCase):
468+
def test_zero_size(self):
469+
paddle.disable_static()
470+
numpy_tensor_1 = np.random.rand(*shape).astype(dtype)
471+
numpy_tensor_2 = numpy_tensor_1.copy()
472+
paddle_x = paddle.to_tensor(numpy_tensor_1)
473+
paddle_x.stop_gradient = False
474+
paddle_y = paddle.to_tensor(numpy_tensor_2)
475+
paddle_y.stop_gradient = False
476+
477+
paddle_api = eval(f"paddle.{op_type}")
478+
paddle_out = paddle_api(paddle_x, paddle_y)
479+
numpy_api = eval(f"np.{op_type}")
480+
numpy_out = numpy_api(numpy_tensor_1, numpy_tensor_2)
481+
482+
np.testing.assert_allclose(
483+
paddle_out.numpy(),
484+
numpy_out,
485+
1e-2,
486+
1e-2,
487+
)
488+
489+
cls_name = f"{op_type}{dtype}_0SizeTest"
490+
Cls.__name__ = cls_name
491+
globals()[cls_name] = Cls
492+
493+
494+
create_test_class("allclose", "float32", [3, 4, 0])
495+
create_test_class("allclose", "float64", [3, 4, 0, 3, 4])
496+
create_test_class("allclose", "int32", [3, 4, 0])
497+
create_test_class("allclose", "int64", [3, 4, 0, 3, 4])
498+
466499
if __name__ == "__main__":
467500
unittest.main()

0 commit comments

Comments
 (0)