Skip to content

Commit 1bec105

Browse files
ooooo-createco63oc
authored andcommitted
【Hackathon 9th No.8】Fix 0-size for as_strided grad and add bound check for as_strided (PaddlePaddle#74860)
* Fix 0-size for as_strided grad and add bound check for as_strided * fix bugs * fix bugs * fix bugs * refine code * refine code * refine code * refine code to pass coverage ci
1 parent 37b686b commit 1bec105

File tree

4 files changed

+71
-13
lines changed

4 files changed

+71
-13
lines changed

paddle/phi/kernels/stride/as_strided_grad_kernel.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ void AsStridedGradKernel(const Context& dev_ctx,
4242
phi::StridedTensorFill<data_t>(
4343
*input_grad, 0, input_grad);
4444
}));
45+
if (out_grad.numel() == 0) {
46+
return;
47+
}
4548
DenseTensor tmp;
4649
tmp.set_meta(out_grad.meta());
4750
AsStridedKernel<Context>(dev_ctx, *input_grad, dims, stride, offset, &tmp);

paddle/phi/kernels/stride/as_strided_kernel.cc

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,26 @@
1919
COMMON_DECLARE_bool(use_stride_kernel);
2020

2121
namespace phi {
22+
void ValidateZeroSizeTensorShape(const std::vector<int64_t>& dims,
23+
const std::vector<int64_t>& strides,
24+
const DenseTensor& input) {
25+
if (input.numel() != 0) {
26+
return;
27+
}
28+
PADDLE_ENFORCE_EQ(dims.size(),
29+
strides.size(),
30+
common::errors::InvalidArgument(
31+
"The size of dims and strides should be equal."));
32+
for (size_t i = 0; i < dims.size(); i++) {
33+
if (dims[i] == 0) {
34+
return;
35+
}
36+
}
37+
38+
PADDLE_THROW(common::errors::InvalidArgument(
39+
"When input is zero-size tensor, the shape attribute must also be "
40+
"zero-size."));
41+
}
2242

2343
template <typename Context>
2444
void AsStridedKernel(const Context& dev_ctx,
@@ -36,6 +56,12 @@ void AsStridedKernel(const Context& dev_ctx,
3656
meta.dims = DDim(dims.data(), static_cast<int>(dims.size()));
3757
meta.strides = DDim(stride.data(), static_cast<int>(stride.size()));
3858
meta.offset = offset;
59+
ValidateZeroSizeTensorShape(dims, stride, input);
60+
PADDLE_ENFORCE_GE(
61+
offset,
62+
0,
63+
common::errors::InvalidArgument(
64+
"The offset must be non-negative, but got %d.", offset));
3965
out->set_meta(meta);
4066
out->ResetHolder(input.Holder());
4167
out->ShareInplaceVersionCounterWith(input);

test/legacy_test/test_as_strided.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,5 +59,35 @@ def test_as_strided_backward(self):
5959
self.assertEqual((b.grad.numpy() == 1).all().item(), True)
6060

6161

62+
class TestAsStrided_ZeroSize(unittest.TestCase):
63+
def setUp(self):
64+
self.places = get_places()
65+
66+
def test_as_strided_forward(self):
67+
for place in self.places:
68+
with base.dygraph.guard(place):
69+
a = paddle.to_tensor(
70+
np.random.random([0, 32]).astype('float32')
71+
)
72+
a.stop_gradient = False
73+
b = paddle.as_strided(a, shape=(0, 4), stride=(32, 1))
74+
np.testing.assert_equal(b.shape, [0, 4])
75+
b.backward(paddle.ones_like(b))
76+
np.testing.assert_equal(a.grad.shape, [0, 32])
77+
78+
def test_as_strided_error(self):
79+
for place in self.places:
80+
with base.dygraph.guard(place):
81+
self.assertRaises(
82+
ValueError,
83+
paddle.as_strided,
84+
x=paddle.to_tensor(
85+
np.random.random([0, 32]).astype('float32')
86+
),
87+
shape=[3, 4],
88+
stride=[32, 1],
89+
)
90+
91+
6292
if __name__ == '__main__':
6393
unittest.main()

test/legacy_test/test_narrow.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -267,19 +267,18 @@ def setUp(self):
267267
self.length = 1
268268

269269

270-
# TODO(Difers) Address the 0-size issue in the as_strided operator.”
271-
# class TestPaddleNarrowEmptyTensor(TestNarrowBase):
272-
# def setUp(self):
273-
# self.input_np = np.empty((0, 4), dtype='float32')
274-
# self.input_shape = self.input_np.shape
275-
# self.input_dtype = 'float32'
276-
# self.op_static = lambda x: paddle.narrow(x, dim=0, start=0, length=0)
277-
# self.op_dygraph = lambda x: paddle.narrow(x, dim=0, start=0, length=0)
278-
# self.expected = lambda x: x[0:0, :]
279-
# self.places = [None, paddle.CPUPlace()]
280-
# self.dim = 0
281-
# self.start = 0
282-
# self.length = 0
270+
class TestPaddleNarrowEmptyTensor(TestNarrowBase):
271+
def setUp(self):
272+
self.input_np = np.empty((0, 4), dtype='float32')
273+
self.input_shape = self.input_np.shape
274+
self.input_dtype = 'float32'
275+
self.op_static = lambda x: paddle.narrow(x, dim=0, start=0, length=0)
276+
self.op_dygraph = lambda x: paddle.narrow(x, dim=0, start=0, length=0)
277+
self.expected = lambda x: x[0:0, :]
278+
self.places = [None, paddle.CPUPlace()]
279+
self.dim = 0
280+
self.start = 0
281+
self.length = 0
283282

284283

285284
@unittest.skipIf(paddle.device.get_device().startswith("xpu"), "Skip on XPU")

0 commit comments

Comments
 (0)