Skip to content
3 changes: 3 additions & 0 deletions paddle/phi/kernels/stride/as_strided_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ void AsStridedGradKernel(const Context& dev_ctx,
phi::StridedTensorFill<data_t>(
*input_grad, 0, input_grad);
}));
if (out_grad.numel() == 0) {
return;
}
DenseTensor tmp;
tmp.set_meta(out_grad.meta());
AsStridedKernel<Context>(dev_ctx, *input_grad, dims, stride, offset, &tmp);
Expand Down
54 changes: 54 additions & 0 deletions paddle/phi/kernels/stride/as_strided_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,48 @@
COMMON_DECLARE_bool(use_stride_kernel);

namespace phi {
void CheckInBoundsForMemory(const std::vector<int64_t>& dims,
const std::vector<int64_t>& strides,
const DDim& output_dims,
const DDim& output_strides,
int64_t offset,
const DenseTensor& input) {
PADDLE_ENFORCE_EQ(dims.size(),
strides.size(),
common::errors::InvalidArgument(
"The size of dims and strides should be equal."));
size_t size = 1;
phi::DataType dtype = input.dtype();
for (size_t i = 0; i < dims.size(); i++) {
if (dims[i] == 0) {
return;
}
size += strides[i] * (dims[i] - 1);
}
size_t size_bytes = size * phi::SizeOf(dtype) + offset;

size_t memory_size = 0;
if (input.numel() != 0) {
size = 1;
for (int i = 0; i < input.dims().size(); i++) {
size += input.strides()[i] * (input.dims()[i] - 1);
}
memory_size = size * phi::SizeOf(dtype) + input.offset();
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CheckInBoundsForMemory中 input.numel() != 0 这个判断为true的分支目前永远不会被执行到,覆盖率到不到要求。而且无法豁免coverage的ci,请优化下此处。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. coverage ci 已经通过


PADDLE_ENFORCE_LE(
size_bytes,
memory_size,
common::errors::InvalidArgument(
"Output tensor requires %d bytes memory (dims: [%s], strides: [%s], "
"offset: %d, dtype: %s), but input only has %d bytes available.",
size_bytes,
output_dims,
output_strides,
offset,
dtype,
memory_size));
}

template <typename Context>
void AsStridedKernel(const Context& dev_ctx,
Expand All @@ -36,6 +78,18 @@ void AsStridedKernel(const Context& dev_ctx,
meta.dims = DDim(dims.data(), static_cast<int>(dims.size()));
meta.strides = DDim(stride.data(), static_cast<int>(stride.size()));
meta.offset = offset;
// Note(ooooo): Now it's to check 0-size tensor.
// Because i see in test_inplace.py to use as_strided as a noinplace op
// implementation for paddle.set_.
if (input.numel() == 0) {
CheckInBoundsForMemory(
dims, stride, meta.dims, meta.strides, offset, input);
}
PADDLE_ENFORCE_GE(
offset,
0,
common::errors::InvalidArgument(
"The offset must be non-negative, but got %d.", offset));
out->set_meta(meta);
out->ResetHolder(input.Holder());
out->ShareInplaceVersionCounterWith(input);
Expand Down
30 changes: 30 additions & 0 deletions test/legacy_test/test_as_strided.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,5 +59,35 @@ def test_as_strided_backward(self):
self.assertEqual((b.grad.numpy() == 1).all().item(), True)


class TestAsStrided_ZeroSize(unittest.TestCase):
def setUp(self):
self.places = get_places()

def test_as_strided_forward(self):
for place in self.places:
with base.dygraph.guard(place):
a = paddle.to_tensor(
np.random.random([0, 32]).astype('float32')
)
a.stop_gradient = False
b = paddle.as_strided(a, shape=(0, 4), stride=(32, 1))
np.testing.assert_equal(b.shape, [0, 4])
b.backward(paddle.ones_like(b))
np.testing.assert_equal(a.grad.shape, [0, 32])

def test_as_strided_error(self):
for place in self.places:
with base.dygraph.guard(place):
self.assertRaises(
ValueError,
paddle.as_strided,
x=paddle.to_tensor(
np.random.random([0, 32]).astype('float32')
),
shape=[3, 4],
stride=[32, 1],
)


if __name__ == '__main__':
unittest.main()
25 changes: 12 additions & 13 deletions test/legacy_test/test_narrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,19 +267,18 @@ def setUp(self):
self.length = 1


# TODO(Difers) Address the 0-size issue in the as_strided operator.”
# class TestPaddleNarrowEmptyTensor(TestNarrowBase):
# def setUp(self):
# self.input_np = np.empty((0, 4), dtype='float32')
# self.input_shape = self.input_np.shape
# self.input_dtype = 'float32'
# self.op_static = lambda x: paddle.narrow(x, dim=0, start=0, length=0)
# self.op_dygraph = lambda x: paddle.narrow(x, dim=0, start=0, length=0)
# self.expected = lambda x: x[0:0, :]
# self.places = [None, paddle.CPUPlace()]
# self.dim = 0
# self.start = 0
# self.length = 0
class TestPaddleNarrowEmptyTensor(TestNarrowBase):
def setUp(self):
self.input_np = np.empty((0, 4), dtype='float32')
self.input_shape = self.input_np.shape
self.input_dtype = 'float32'
self.op_static = lambda x: paddle.narrow(x, dim=0, start=0, length=0)
self.op_dygraph = lambda x: paddle.narrow(x, dim=0, start=0, length=0)
self.expected = lambda x: x[0:0, :]
self.places = [None, paddle.CPUPlace()]
self.dim = 0
self.start = 0
self.length = 0


@unittest.skipIf(paddle.device.get_device().startswith("xpu"), "Skip on XPU")
Expand Down
Loading