Skip to content

Commit 9dd52b6

Browse files
committed
Fix
1 parent db92e25 commit 9dd52b6

File tree

4 files changed

+20
-0
lines changed

4 files changed

+20
-0
lines changed

paddle/phi/kernels/gpu/determinant_kernel.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,10 @@ template <typename T, typename Context>
218218
void DeterminantKernel(const Context& dev_ctx,
219219
const DenseTensor& x,
220220
DenseTensor* out) {
221+
if (out && out->numel() == 0) {
222+
dev_ctx.template Alloc<T>(out);
223+
return;
224+
}
221225
auto input_dim = common::vectorize(x.dims());
222226
auto input_dim_size = input_dim.size();
223227

paddle/phi/kernels/impl/determinant_grad_kernel_impl.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,10 @@ void DeterminantGradKernel(const Context& dev_ctx,
8282
const DenseTensor& out,
8383
const DenseTensor& out_grad,
8484
DenseTensor* x_grad) {
85+
if (x_grad && x_grad->numel() == 0) {
86+
dev_ctx.template Alloc<T>(x_grad);
87+
return;
88+
}
8589
auto input_dims_size = x.dims().size();
8690
if (input_dims_size > 2) {
8791
PADDLE_ENFORCE_EQ(

paddle/phi/kernels/impl/determinant_kernel_impl.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,10 @@ template <typename T, typename Context>
136136
void DeterminantKernel(const Context& dev_ctx,
137137
const DenseTensor& x,
138138
DenseTensor* out) {
139+
if (out && out->numel() == 0) {
140+
dev_ctx.template Alloc<T>(out);
141+
return;
142+
}
139143
auto input_dim = common::vectorize(x.dims());
140144
auto input_dim_size = input_dim.size();
141145

test/legacy_test/test_determinant_op.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,14 @@ def init_data(self):
136136
self.target = np.linalg.det(self.case)
137137

138138

139+
class TestDeterminantOp_ZeroSize(TestDeterminantOp):
140+
def init_data(self):
141+
np.random.seed(0)
142+
self.case = np.random.rand(0, 10, 10)
143+
self.inputs = {'Input': self.case}
144+
self.target = np.linalg.det(self.case)
145+
146+
139147
class TestDeterminantAPI(unittest.TestCase):
140148
def setUp(self):
141149
np.random.seed(0)

0 commit comments

Comments
 (0)