Skip to content

Commit 2fcecfb

Browse files
committed
expand_grad: support 0-size Tensor
1 parent fd73310 commit 2fcecfb

File tree

4 files changed

+104
-17
lines changed

4 files changed

+104
-17
lines changed

paddle/phi/kernels/gpu/expand_grad_kernel.cu

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "paddle/phi/backends/gpu/gpu_context.h"
1818
#include "paddle/phi/core/dense_tensor.h"
1919
#include "paddle/phi/core/kernel_registry.h"
20+
#include "paddle/phi/kernels/full_kernel.h"
2021
#include "paddle/phi/kernels/funcs/reduce_function.h"
2122
#include "paddle/phi/kernels/reduce_sum_kernel.h"
2223

@@ -29,6 +30,11 @@ void ExpandGradKernel(const Context& ctx,
2930
const IntArray& shape,
3031
DenseTensor* x_grad) {
3132
ctx.template Alloc<T>(x_grad);
33+
if ((x_grad && x_grad->numel() == 0) || out_grad.numel() == 0) {
34+
phi::Full<T, Context>(
35+
ctx, phi::IntArray(common::vectorize(x_grad->dims())), 0, x_grad);
36+
return;
37+
}
3238
if (x_grad->dims() == out_grad.dims()) {
3339
phi::Copy(ctx, out_grad, ctx.GetPlace(), false, x_grad);
3440
} else {

paddle/phi/kernels/impl/expand_grad_kernel_impl.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#pragma once
1616

1717
#include "paddle/phi/core/tensor_utils.h"
18+
#include "paddle/phi/kernels/full_kernel.h"
1819
#include "paddle/phi/kernels/funcs/eigen/common.h"
1920
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
2021
#include "paddle/phi/kernels/impl/expand_kernel_impl.h"
@@ -54,6 +55,12 @@ void ExpandGradKernel(const Context& ctx,
5455
DenseTensor* in_grad) {
5556
auto expand_shape = shape.GetData();
5657
auto x_dims = x.dims();
58+
if (out_grad.numel() == 0 || (in_grad && in_grad->numel() == 0)) {
59+
ctx.template Alloc<T>(in_grad);
60+
phi::Full<T, Context>(
61+
ctx, phi::IntArray(common::vectorize(in_grad->dims())), 0, in_grad);
62+
return;
63+
}
5764

5865
if (in_grad->dims() == out_grad.dims()) {
5966
phi::Copy(ctx, out_grad, ctx.GetPlace(), false, in_grad);

paddle/phi/kernels/onednn/expand_grad_kernel.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include "paddle/phi/backends/onednn/onednn_reuse.h"
1818
#include "paddle/phi/core/kernel_registry.h"
19+
#include "paddle/phi/kernels/full_kernel.h"
1920

2021
namespace phi {
2122
template <typename T, typename Context>
@@ -26,6 +27,13 @@ void ExpandGradKernel(const Context& dev_ctx,
2627
DenseTensor* in_grad) {
2728
const auto& onednn_engine = dev_ctx.GetEngine();
2829

30+
if ((in_grad && in_grad->numel() == 0) || out_grad.numel() == 0) {
31+
dev_ctx.template Alloc<T>(in_grad);
32+
phi::Full<T, Context>(
33+
dev_ctx, phi::IntArray(common::vectorize(in_grad->dims())), 0, in_grad);
34+
return;
35+
}
36+
2937
auto in_grad_vec_dims = common::vectorize(in_grad->dims());
3038
auto out_grad_vec_dims = common::vectorize(out_grad.dims());
3139

test/legacy_test/test_expand_v2_op.py

Lines changed: 83 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -684,12 +684,18 @@ def test_value_list_shape2(self):
684684
x = paddle.expand(x, shape=[shape1, 1, -1, -1])
685685
np.testing.assert_equal(tuple(x.shape), (-1, 1, -1, -1))
686686

687+
687688
class TestExpandV2OneDNNOp(OpTest):
688689
def setUp(self):
689690
self.op_type = "expand_v2"
690691
self.init_data()
691-
self.x = np.random.random(self.ori_shape).astype("float32")
692-
self.attrs = {'shape': self.shape, 'use_mkldnn': True}
692+
self.python_api = paddle.expand
693+
self.x = np.zeros(self.ori_shape).astype("float32")
694+
self.attrs = {
695+
'shape': self.shape,
696+
'use_mkldnn': True,
697+
'dtype': int(paddle.float32),
698+
}
693699
self.set_inputs()
694700
self.set_additional_inputs()
695701
output = np.zeros(self.expect_shape).astype("float32")
@@ -702,30 +708,90 @@ def set_additional_inputs(self):
702708
pass
703709

704710
def init_data(self):
705-
self.ori_shape = [1, 1, 1, 140]
706-
self.shape = [2, 3, 0, 140]
707-
self.expect_shape = [2, 3, 0, 140]
711+
self.ori_shape = [1, 0, 1, 140]
712+
self.shape = [1, 0, 1, 140]
713+
self.expect_shape = [1, 0, 1, 140]
708714

709715
def test_check_output(self):
710-
self.check_output_with_place(core.CPUPlace(), check_pir_onednn=True,check_dygraph=False)
711-
712-
# def test_check_grad(self):
713-
# self.check_grad_with_place(
714-
# core.CPUPlace(), ["X"], "Out", check_pir_onednn=True, check_dygraph=False
715-
# )
716+
self.check_output_with_place(
717+
core.CPUPlace(), check_pir_onednn=True, check_dygraph=False
718+
)
719+
720+
def test_check_grad(self):
721+
self.check_grad_with_place(
722+
core.CPUPlace(),
723+
["X"],
724+
"Out",
725+
check_pir_onednn=True,
726+
check_dygraph=False,
727+
)
728+
729+
716730
class TestExpandV2ZeroSizeOneDNNOp(TestExpandV2OneDNNOp):
717731

718732
def init_data(self):
719-
self.ori_shape = (1, 3)
720-
self.shape = (0, 3)
721-
self.expect_shape = (0, 3)
733+
self.ori_shape = (0, 130)
734+
self.shape = (4, 0, 130)
735+
self.expect_shape = (4, 0, 130)
736+
722737

723738
class TestExpandV2ZeroSizeOneDNNOp2(TestExpandV2OneDNNOp):
724739

725740
def init_data(self):
726-
self.ori_shape = (1, 3)
727-
self.shape = (1, 0, 3)
728-
self.expect_shape = (1, 0, 3)
741+
self.ori_shape = (0, 1, 8)
742+
self.shape = (0, 8, 8)
743+
self.expect_shape = (0, 8, 8)
744+
745+
746+
class TestExpandV2GPUOp(TestExpandV2OneDNNOp):
747+
def test_check_output(self):
748+
self.check_output_with_place(core.CUDAPlace(0), check_dygraph=True)
749+
750+
def test_check_grad(self):
751+
if core.is_compiled_with_cuda():
752+
self.check_grad_with_place(
753+
core.CUDAPlace(0), ["X"], "Out", check_dygraph=True
754+
)
755+
756+
757+
class TestExpandV2ZeroSizeGPUOp(TestExpandV2GPUOp):
758+
def init_data(self):
759+
self.ori_shape = (0, 130)
760+
self.shape = (4, 0, 130)
761+
self.expect_shape = (4, 0, 130)
762+
763+
764+
class TestExpandV2ZeroSizeGPUOp2(TestExpandV2GPUOp):
765+
def init_data(self):
766+
self.ori_shape = (0, 1)
767+
self.shape = (0, 8)
768+
self.expect_shape = (0, 8)
769+
770+
771+
class TestExpandV2CPUOp(TestExpandV2OneDNNOp):
772+
def test_check_output(self):
773+
self.check_output_with_place(core.CPUPlace(), check_dygraph=True)
774+
775+
def test_check_grad(self):
776+
if core.is_compiled_with_cuda():
777+
self.check_grad_with_place(
778+
core.CPUPlace(), ["X"], "Out", check_dygraph=True
779+
)
780+
781+
782+
class TestExpandV2CPUOp1(TestExpandV2CPUOp):
783+
def init_data(self):
784+
self.ori_shape = (0, 1)
785+
self.shape = (0, 8)
786+
self.expect_shape = (0, 8)
787+
788+
789+
class TestExpandV2CPUOp2(TestExpandV2CPUOp):
790+
def init_data(self):
791+
self.ori_shape = (0, 130)
792+
self.shape = (4, 0, 130)
793+
self.expect_shape = (4, 0, 130)
794+
729795

730796
if __name__ == "__main__":
731797
paddle.enable_static()

0 commit comments

Comments
 (0)