8 changes: 8 additions & 0 deletions paddle/phi/kernels/cpu/prod_kernel.cc
@@ -17,6 +17,7 @@
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/reduce.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/reduce_functor.h"

namespace phi {
@@ -28,6 +29,13 @@ void ProdKernel(const Context& dev_ctx,
                bool keep_dim,
                bool reduce_all,
                DenseTensor* out) {
  if (x.numel() == 0) {
    // 0-size input: the empty product is 1, so fill the output with 1.
    phi::Full<T, Context>(
        dev_ctx, phi::IntArray(common::vectorize(out->dims())), 1, out);
    return;
  }

  reduce_all = recompute_reduce_all(x, dims, reduce_all);
  auto out_dtype = x.dtype();
  phi::Reduce<CPUContext, T, phi::funcs::ProdFunctor>(
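For context on what the new branch does: the product of zero elements is the multiplicative identity, so a 0-size input should produce an output filled with 1, matching NumPy's convention. A minimal sketch of the resulting behaviour (not part of the diff; assumes this patch is applied):

```python
import numpy as np
import paddle

x_np = np.random.random((10, 0, 5)).astype("float32")  # zero elements overall

# NumPy's convention: the product over an empty set is 1.
print(np.prod(x_np))  # 1.0

# With this patch, paddle.prod follows the same convention for 0-size inputs.
print(paddle.prod(paddle.to_tensor(x_np)))  # 0-D tensor holding 1.
```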
4 changes: 4 additions & 0 deletions paddle/phi/kernels/impl/prod_grad_kernel_impl.h
@@ -30,6 +30,10 @@ void ProdGradKernel(const Context& dev_ctx,
                    bool keep_dim,
                    bool reduce_all,
                    DenseTensor* x_grad) {
  if (x_grad && x_grad->numel() == 0) {
    dev_ctx.template Alloc<T>(x_grad);
    return;
  }
  reduce_all = recompute_reduce_all(x, dims, reduce_all);
  ReduceGradKernel<Context, T, funcs::ProdGradFunctor>(
      dev_ctx, x, out, out_grad, dims.GetData(), keep_dim, reduce_all, x_grad);
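The gradient path only needs to allocate `x_grad`: with zero elements there is nothing to fill, so the backward pass becomes a no-op. A hedged sketch of the behaviour this enables (assumes this patch; it mirrors `TestProdOp_ZeroSize` further down):

```python
import numpy as np
import paddle

x = paddle.to_tensor(np.random.random((10, 0, 5)).astype("float32"))
x.stop_gradient = False

out = paddle.prod(x)  # 0-D tensor holding 1. (empty product)
out.sum().backward()

# The gradient keeps the input's zero-size shape; no reduction work is done.
print(x.grad.shape)  # [10, 0, 5]
```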
7 changes: 7 additions & 0 deletions paddle/phi/kernels/kps/reduce_kernel.cu
@@ -45,6 +45,13 @@ void ProdKernel(const Context& dev_ctx,
                bool keep_dim,
                bool reduce_all,
                DenseTensor* out) {
  if (x.numel() == 0) {
    // 0-size input: the empty product is 1, so fill the output with 1.
    phi::Full<T, Context>(
        dev_ctx, phi::IntArray(common::vectorize(out->dims())), 1, out);
    return;
  }

  reduce_all = recompute_reduce_all(x, dims, reduce_all);
  auto out_dtype = x.dtype();
  phi::Reduce<T, kps::MulFunctor, kps::IdentityFunctor>(
7 changes: 7 additions & 0 deletions paddle/phi/kernels/xpu/prod_kernel.cc
@@ -17,6 +17,7 @@
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/xpu/reduce.h"

namespace phi {
@@ -28,6 +29,12 @@ void ProdKernel(const Context& dev_ctx,
                bool keep_dim,
                bool reduce_all,
                DenseTensor* out) {
  if (x.numel() == 0) {
    phi::Full<T, Context>(
        dev_ctx, phi::IntArray(common::vectorize(out->dims())), 1, out);
    return;
  }

  reduce_all = recompute_reduce_all(x, dims, reduce_all);
  using XPUType = typename XPUTypeTrait<T>::Type;

4 changes: 4 additions & 0 deletions python/paddle/tensor/math.py
@@ -5012,6 +5012,10 @@ def prod(
        if x.dtype != convert_np_dtype_to_dtype_(dtype):
            x = cast(x, dtype)

    # `axis` is a 0-size tensor: nothing to reduce over, so return x as-is.
    if paddle.is_tensor(axis) and axis.shape == [0]:
        return x

    reduce_all, axis = _get_reduce_axis_with_tensor(axis, x)
    if in_dynamic_or_pir_mode():
        return _C_ops.prod(x, axis, keepdim, reduce_all)
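A quick illustration of the new Python-side fast path (a sketch, not part of the diff; assumes this patch is applied): when `axis` is a 0-size tensor there are no axes to reduce over, so `prod` returns `x` unchanged instead of treating it as a reduce-all.

```python
import paddle

x = paddle.randn([10, 1, 5])
axis = paddle.randn([0]).astype(paddle.int32)  # 0-size axis tensor

out = paddle.prod(x, axis)
assert out.shape == x.shape  # identity; exercised by TestProdOp_ZeroSize2 below
```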
34 changes: 34 additions & 0 deletions test/legacy_test/test_prod_op.py
@@ -277,5 +277,39 @@ def init_data(self):
        ]


class TestProdOp_ZeroSize(unittest.TestCase):
    def setUp(self):
        self.input = np.random.random(size=(10, 0, 5)).astype(np.float32)

    def run_imperative(self, place):
        input = paddle.to_tensor(self.input, place=place)
        input.stop_gradient = False
        out = paddle.prod(input)
        expected_result = np.prod(self.input)
        np.testing.assert_allclose(out.numpy(), expected_result, rtol=1e-05)
        out.sum().backward()
        np.testing.assert_allclose(input.grad.shape, input.shape)

    def test_cpu(self):
        with dygraph_guard():
            self.run_imperative(place=paddle.CPUPlace())

    def test_gpu(self):
        if not paddle.base.core.is_compiled_with_cuda():
            return
        with dygraph_guard():
            self.run_imperative(place=paddle.CUDAPlace(0))


class TestProdOp_ZeroSize2(TestProdOp_ZeroSize):
    def setUp(self):
        self.input = np.random.random(size=(10, 1, 5)).astype(np.float32)

    def run_imperative(self, place):
        input = paddle.to_tensor(self.input, place=place)
        out = paddle.prod(input, paddle.randn([0]).astype(paddle.int32))
        np.testing.assert_allclose(out.numpy(), input.numpy())


if __name__ == "__main__":
    unittest.main()