Skip to content

Commit 5d5fb53

Browse files
authored
[0-size Tensor Job2 No.14] Add 0-size Tensor support for gammaincc (#73207)
* Fix * Fix
1 parent 2fe5382 commit 5d5fb53

File tree

3 files changed

+24
-1
lines changed

3 files changed

+24
-1
lines changed

paddle/phi/kernels/impl/gammaincc_grad_kernel_impl.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@ void GammainccGradKernel(const Context& dev_ctx,
4949
const DenseTensor& y,
5050
const DenseTensor& d_out,
5151
DenseTensor* d_y) {
52+
if (d_y && d_y->numel() == 0) {
53+
dev_ctx.template Alloc<T>(d_y);
54+
return;
55+
}
5256
auto numel = d_out.numel();
5357
auto* dout_data = d_out.data<T>();
5458
auto* x_data = x.data<T>();

paddle/phi/kernels/impl/gammaincc_kernel_impl.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,10 @@ void GammainccKernel(const Context& dev_ctx,
132132
const DenseTensor& x,
133133
const DenseTensor& y,
134134
DenseTensor* out) {
135+
if (out && out->numel() == 0) {
136+
dev_ctx.template Alloc<T>(out);
137+
return;
138+
}
135139
auto numel = x.numel();
136140
auto* x_data = x.data<T>();
137141
auto* y_data = y.data<T>();

test/legacy_test/test_gammaincc_op.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,16 @@ def setUp(self):
3232
self.op_type = 'gammaincc'
3333
self.python_api = paddle.gammaincc
3434
self.init_dtype_type()
35-
self.shape = (3, 40)
35+
self.init_shape()
3636
self.x = np.random.random(self.shape).astype(self.dtype) + 1
3737
self.y = np.random.random(self.shape).astype(self.dtype) + 1
3838
self.inputs = {'x': self.x, 'y': self.y}
3939
out = ref_gammaincc(self.x, self.y)
4040
self.outputs = {'out': out}
4141

42+
def init_shape(self):
43+
self.shape = (3, 40)
44+
4245
def init_dtype_type(self):
4346
self.dtype = np.float64
4447

@@ -127,5 +130,17 @@ def init_dtype_type(self):
127130
self.dtype = "float32"
128131

129132

133+
class TestGammainccOp_ZeroSize(TestGammainccOp):
134+
135+
def init_shape(self):
136+
self.shape = (0, 40)
137+
138+
139+
class TestGammainccOp_ZeroSize2(TestGammainccOp):
140+
141+
def init_shape(self):
142+
self.shape = (0, 0)
143+
144+
130145
if __name__ == "__main__":
131146
unittest.main()

0 commit comments

Comments
 (0)