Skip to content

Commit 18c529a

Browse files
committed
Merge branch 'b68' into b69
2 parents 9dd52b6 + cbf2425 commit 18c529a

File tree

4 files changed

+50
-0
lines changed

4 files changed

+50
-0
lines changed

paddle/phi/kernels/gpu/slogdeterminant_kernel.cu

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,25 @@ void SlogDeterminantKernel(const Context& dev_ctx,
215215
auto input_dim = common::vectorize(x.dims());
216216
auto input_dim_size = input_dim.size();
217217

218+
// shape [*, M, M], check whether it contains 0 in '*'.
219+
if (input_dim.size() > 2) {
220+
bool size_0 = false;
221+
std::vector<int> tmp_dim_vec(input_dim.begin(), input_dim.end() - 2);
222+
for (size_t i = 0; i < tmp_dim_vec.size(); ++i) {
223+
if (tmp_dim_vec[i] == 0) {
224+
size_0 = true;
225+
break;
226+
}
227+
}
228+
if (size_0) {
229+
tmp_dim_vec.insert(tmp_dim_vec.begin(),
230+
2); // make the output dims as same as numpy
231+
out->Resize(common::make_ddim(tmp_dim_vec));
232+
dev_ctx.template Alloc<T>(out);
233+
return;
234+
}
235+
}
236+
218237
auto batch_count = detail::GetBatchCount(x.dims());
219238
VLOG(2) << "input dim:" << x.dims();
220239
PADDLE_ENFORCE_GE(

paddle/phi/kernels/impl/slogdeterminant_grad_kernel_impl.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ void SlogDeterminantGradKernel(const Context& dev_ctx,
3535
const DenseTensor& out,
3636
const DenseTensor& out_grad,
3737
DenseTensor* x_grad) {
38+
if (x_grad && x_grad->numel() == 0) {
39+
dev_ctx.template Alloc<T>(x_grad);
40+
return;
41+
}
3842
PADDLE_ENFORCE_EQ(
3943
out_grad.dims()[0],
4044
2,

paddle/phi/kernels/impl/slogdeterminant_kernel_impl.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,25 @@ void SlogDeterminantKernel(const Context& dev_ctx,
127127
auto input_dim = common::vectorize(x.dims());
128128
auto input_dim_size = input_dim.size();
129129

130+
// shape [*, M, M], check whether it contains 0 in '*'.
131+
if (input_dim.size() > 2) {
132+
bool size_0 = false;
133+
std::vector<int> tmp_dim_vec(input_dim.begin(), input_dim.end() - 2);
134+
for (size_t i = 0; i < tmp_dim_vec.size(); ++i) {
135+
if (tmp_dim_vec[i] == 0) {
136+
size_0 = true;
137+
break;
138+
}
139+
}
140+
if (size_0) {
141+
tmp_dim_vec.insert(tmp_dim_vec.begin(),
142+
2); // make the output dims as same as numpy
143+
out->Resize(common::make_ddim(tmp_dim_vec));
144+
dev_ctx.template Alloc<T>(out);
145+
return;
146+
}
147+
}
148+
130149
auto batch_count = detail::GetBatchCount(x.dims());
131150
VLOG(2) << "input dim:" << x.dims();
132151
PADDLE_ENFORCE_GE(

test/legacy_test/test_determinant_op.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,14 @@ def init_data(self):
332332
self.target = np.array(np.linalg.slogdet(self.case))
333333

334334

335+
class TestSlogDeterminantOp_ZeroSize(TestSlogDeterminantOp):
336+
def init_data(self):
337+
np.random.seed(0)
338+
self.case = np.random.rand(0, 5, 5).astype('float64')
339+
self.inputs = {'Input': self.case}
340+
self.target = np.array(np.linalg.slogdet(self.case))
341+
342+
335343
class TestSlogDeterminantAPI(unittest.TestCase):
336344
def setUp(self):
337345
np.random.seed(0)

0 commit comments

Comments
 (0)