Skip to content

Commit 2d933a7

Browse files
committed
fix 0 size Tensor for expand kernel in onednn
1 parent 5883eef commit 2d933a7

File tree

2 files changed

+78
-1
lines changed

2 files changed

+78
-1
lines changed

paddle/phi/kernels/onednn/expand_kernel.cc

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,51 @@ void ExpandKernel(const Context& dev_ctx,
3939
auto x_vec_dims = common::vectorize(x.dims());
4040

4141
auto out_new_dims = shape.GetData();
42+
bool has_zero_size = false;
4243

4344
for (size_t i = 0; i < out_new_dims.size(); ++i) {
44-
out_new_dims[i] = out_new_dims[i] > 0 ? out_new_dims[i] : x_vec_dims[i];
45+
out_new_dims[i] = out_new_dims[i] >= 0 ? out_new_dims[i] : x_vec_dims[i];
4546
}
4647

4748
if (x_vec_dims.size() != out_new_dims.size()) {
4849
x_vec_dims = GetExtendedXDims(x_vec_dims, out_new_dims.size()); // NOLINT
4950
}
5051

52+
for (size_t i = 0; i < x_vec_dims.size(); ++i) {
53+
PADDLE_ENFORCE_GE(
54+
out_new_dims[i],
55+
0,
56+
common::errors::InvalidArgument(
57+
"The expanded size (%d) for non-existing dimensions must be "
58+
"positive for expand_v2 op.",
59+
out_new_dims[i]));
60+
61+
PADDLE_ENFORCE_GE(
62+
x_vec_dims[i],
63+
0,
64+
common::errors::InvalidArgument(
65+
"The expanded size (%d) for non-existing dimensions must be "
66+
"positive for expand_v2 op.",
67+
x_vec_dims[i]));
68+
69+
PADDLE_ENFORCE_EQ(
70+
x_vec_dims[i] == 1 || x_vec_dims[i] == out_new_dims[i],
71+
true,
72+
common::errors::InvalidArgument(
73+
"The value (%d) of the non-singleton dimension does not match"
74+
" the corresponding value (%d) in shape for expand_v2 op.",
75+
x_vec_dims[i],
76+
out_new_dims[i]));
77+
if (out_new_dims[i] == 0) {
78+
has_zero_size = true;
79+
}
80+
}
81+
5182
out->Resize(common::make_ddim(out_new_dims));
83+
if (has_zero_size) {
84+
dev_ctx.template Alloc<T>(out);
85+
return;
86+
}
5287
funcs::BroadcastDataOneDNNHandler<T> handler(dnnl::algorithm::binary_add,
5388
onednn_engine,
5489
dev_ctx.GetPlace(),

test/legacy_test/test_expand_v2_op.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -684,6 +684,48 @@ def test_value_list_shape2(self):
684684
x = paddle.expand(x, shape=[shape1, 1, -1, -1])
685685
np.testing.assert_equal(tuple(x.shape), (-1, 1, -1, -1))
686686

687+
class TestExpandV2OneDNNOp(OpTest):
688+
def setUp(self):
689+
self.op_type = "expand_v2"
690+
self.init_data()
691+
self.x = np.random.random(self.ori_shape).astype("float32")
692+
self.attrs = {'shape': self.shape, 'use_mkldnn': True}
693+
self.set_inputs()
694+
self.set_additional_inputs()
695+
output = np.zeros(self.expect_shape).astype("float32")
696+
self.outputs = {'Out': output}
697+
698+
def set_inputs(self):
699+
self.inputs = {'X': self.x}
700+
701+
def set_additional_inputs(self):
702+
pass
703+
704+
def init_data(self):
705+
self.ori_shape = [1, 1, 1, 140]
706+
self.shape = [2, 3, 0, 140]
707+
self.expect_shape = [2, 3, 0, 140]
708+
709+
def test_check_output(self):
710+
self.check_output_with_place(core.CPUPlace(), check_pir_onednn=True,check_dygraph=False)
711+
712+
# def test_check_grad(self):
713+
# self.check_grad_with_place(
714+
# core.CPUPlace(), ["X"], "Out", check_pir_onednn=True, check_dygraph=False
715+
# )
716+
class TestExpandV2ZeroSizeOneDNNOp(TestExpandV2OneDNNOp):
717+
718+
def init_data(self):
719+
self.ori_shape = (1, 3)
720+
self.shape = (0, 3)
721+
self.expect_shape = (0, 3)
722+
723+
class TestExpandV2ZeroSizeOneDNNOp2(TestExpandV2OneDNNOp):
724+
725+
def init_data(self):
726+
self.ori_shape = (1, 3)
727+
self.shape = (1, 0, 3)
728+
self.expect_shape = (1, 0, 3)
687729

688730
if __name__ == "__main__":
689731
paddle.enable_static()

0 commit comments

Comments
 (0)