Skip to content

Commit 8844eaa

Browse files
committed
Fix
1 parent 0437ba9 commit 8844eaa

File tree

6 files changed

+52
-36
lines changed

6 files changed

+52
-36
lines changed

paddle/phi/infermeta/unary.cc

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6263,6 +6263,29 @@ void StraightThroughEstimatorInferMeta(const MetaTensor& out_grad,
62636263
x_grad->set_dtype(out_grad.dtype());
62646264
}
62656265

6266+
void SlogdetInferMeta(const MetaTensor& x, MetaTensor* out) {
6267+
out->set_dims(x.dims());
6268+
out->set_dtype(x.dtype());
6269+
6270+
auto input_dim = common::vectorize(x.dims());
6271+
// shape [*, M, M], check whether it contains 0 in '*'.
6272+
if (input_dim.size() > 2) {
6273+
bool size_0 = false;
6274+
std::vector<int> tmp_dim_vec(input_dim.begin(), input_dim.end() - 2);
6275+
for (size_t i = 0; i < tmp_dim_vec.size(); ++i) {
6276+
if (tmp_dim_vec[i] == 0) {
6277+
size_0 = true;
6278+
break;
6279+
}
6280+
}
6281+
if (size_0) {
6282+
tmp_dim_vec.insert(tmp_dim_vec.begin(),
6283+
2); // make the output dims as same as numpy
6284+
out->set_dims(common::make_ddim(tmp_dim_vec));
6285+
}
6286+
}
6287+
}
6288+
62666289
void NumberCountInferMeta(const MetaTensor& x,
62676290
int upper_range,
62686291
MetaTensor* out) {

paddle/phi/infermeta/unary.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1007,6 +1007,8 @@ void StridedUnChangedInferMeta(const MetaTensor& x, MetaTensor* out);
10071007
void StraightThroughEstimatorInferMeta(const MetaTensor& out_grad,
10081008
MetaTensor* x_grad);
10091009

1010+
void SlogdetInferMeta(const MetaTensor& x, MetaTensor* out);
1011+
10101012
void LrnInferMeta(const MetaTensor& x,
10111013
int n,
10121014
MetaTensor* out,

paddle/phi/kernels/gpu/slogdeterminant_kernel.cu

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@
2525
#include "paddle/phi/core/kernel_registry.h"
2626
#include "paddle/phi/core/tensor_utils.h"
2727
#include "paddle/phi/kernels/determinant_kernel.h"
28+
#include "paddle/phi/kernels/full_kernel.h"
2829
#include "paddle/phi/kernels/funcs/blas/blas.h"
2930
#include "paddle/phi/kernels/impl/determinant_kernel_impl.h"
3031
#include "paddle/phi/kernels/slogdeterminant_kernel.h"
31-
3232
namespace phi {
3333

3434
// T is not complex
@@ -215,23 +215,10 @@ void SlogDeterminantKernel(const Context& dev_ctx,
215215
auto input_dim = common::vectorize(x.dims());
216216
auto input_dim_size = input_dim.size();
217217

218-
// shape [*, M, M], check whether it contains 0 in '*'.
219-
if (input_dim.size() > 2) {
220-
bool size_0 = false;
221-
std::vector<int> tmp_dim_vec(input_dim.begin(), input_dim.end() - 2);
222-
for (size_t i = 0; i < tmp_dim_vec.size(); ++i) {
223-
if (tmp_dim_vec[i] == 0) {
224-
size_0 = true;
225-
break;
226-
}
227-
}
228-
if (size_0) {
229-
tmp_dim_vec.insert(tmp_dim_vec.begin(),
230-
2); // make the output dims as same as numpy
231-
out->Resize(common::make_ddim(tmp_dim_vec));
232-
dev_ctx.template Alloc<T>(out);
233-
return;
234-
}
218+
if (x.numel() == 0) {
219+
phi::Full<T, Context>(
220+
dev_ctx, phi::IntArray(common::vectorize(out->dims())), 0, out);
221+
return;
235222
}
236223

237224
auto batch_count = detail::GetBatchCount(x.dims());

paddle/phi/kernels/impl/slogdeterminant_kernel_impl.h

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
#include "paddle/phi/core/enforce.h"
2424
#include "paddle/phi/core/tensor_utils.h"
25+
#include "paddle/phi/kernels/full_kernel.h"
2526
#include "paddle/phi/kernels/impl/determinant_kernel_impl.h"
2627
#include "paddle/phi/kernels/slogdeterminant_kernel.h"
2728

@@ -127,23 +128,10 @@ void SlogDeterminantKernel(const Context& dev_ctx,
127128
auto input_dim = common::vectorize(x.dims());
128129
auto input_dim_size = input_dim.size();
129130

130-
// shape [*, M, M], check whether it contains 0 in '*'.
131-
if (input_dim.size() > 2) {
132-
bool size_0 = false;
133-
std::vector<int> tmp_dim_vec(input_dim.begin(), input_dim.end() - 2);
134-
for (size_t i = 0; i < tmp_dim_vec.size(); ++i) {
135-
if (tmp_dim_vec[i] == 0) {
136-
size_0 = true;
137-
break;
138-
}
139-
}
140-
if (size_0) {
141-
tmp_dim_vec.insert(tmp_dim_vec.begin(),
142-
2); // make the output dims as same as numpy
143-
out->Resize(common::make_ddim(tmp_dim_vec));
144-
dev_ctx.template Alloc<T>(out);
145-
return;
146-
}
131+
if (x.numel() == 0) {
132+
phi::Full<T, Context>(
133+
dev_ctx, phi::IntArray(common::vectorize(out->dims())), 0, out);
134+
return;
147135
}
148136

149137
auto batch_count = detail::GetBatchCount(x.dims());

paddle/phi/ops/yaml/ops.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4893,7 +4893,7 @@
48934893
args : (Tensor x)
48944894
output : Tensor
48954895
infer_meta :
4896-
func : UnchangedInferMeta
4896+
func : SlogdetInferMeta
48974897
kernel :
48984898
func : slogdet
48994899
backward : slogdet_grad

test/legacy_test/test_determinant_op.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,14 @@ def init_data(self):
144144
self.target = np.linalg.det(self.case)
145145

146146

147+
class TestDeterminantOp_ZeroSize2(TestDeterminantOp):
    """Determinant op on a fully zero-size input of shape (0, 0, 0)."""

    def init_data(self):
        # Seed kept consistent with the sibling cases for reproducibility.
        np.random.seed(0)
        sample = np.random.rand(0, 0, 0)
        self.case = sample
        self.inputs = {'Input': sample}
        self.target = np.linalg.det(sample)
153+
154+
147155
class TestDeterminantAPI(unittest.TestCase):
148156
def setUp(self):
149157
np.random.seed(0)
@@ -340,6 +348,14 @@ def init_data(self):
340348
self.target = np.array(np.linalg.slogdet(self.case))
341349

342350

351+
class TestSlogDeterminantOp_ZeroSize2(TestSlogDeterminantOp):
    """Slogdet op on a fully zero-size float64 input of shape (0, 0, 0)."""

    def init_data(self):
        # Seed kept consistent with the sibling cases for reproducibility.
        np.random.seed(0)
        sample = np.random.rand(0, 0, 0).astype('float64')
        self.case = sample
        self.inputs = {'Input': sample}
        self.target = np.array(np.linalg.slogdet(sample))
357+
358+
343359
class TestSlogDeterminantAPI(unittest.TestCase):
344360
def setUp(self):
345361
np.random.seed(0)

0 commit comments

Comments
 (0)