|
25 | 25 | #include "paddle/phi/core/kernel_registry.h" |
26 | 26 | #include "paddle/phi/core/tensor_utils.h" |
27 | 27 | #include "paddle/phi/kernels/determinant_kernel.h" |
| 28 | +#include "paddle/phi/kernels/full_kernel.h" |
28 | 29 | #include "paddle/phi/kernels/funcs/blas/blas.h" |
29 | 30 | #include "paddle/phi/kernels/impl/determinant_kernel_impl.h" |
30 | 31 | #include "paddle/phi/kernels/slogdeterminant_kernel.h" |
31 | | - |
32 | 32 | namespace phi { |
33 | 33 |
|
34 | 34 | // T is not complex |
@@ -215,23 +215,10 @@ void SlogDeterminantKernel(const Context& dev_ctx, |
215 | 215 | auto input_dim = common::vectorize(x.dims()); |
216 | 216 | auto input_dim_size = input_dim.size(); |
217 | 217 |
|
218 | | - // shape [*, M, M], check whether it contains 0 in '*'. |
219 | | - if (input_dim.size() > 2) { |
220 | | - bool size_0 = false; |
221 | | - std::vector<int> tmp_dim_vec(input_dim.begin(), input_dim.end() - 2); |
222 | | - for (size_t i = 0; i < tmp_dim_vec.size(); ++i) { |
223 | | - if (tmp_dim_vec[i] == 0) { |
224 | | - size_0 = true; |
225 | | - break; |
226 | | - } |
227 | | - } |
228 | | - if (size_0) { |
229 | | - tmp_dim_vec.insert(tmp_dim_vec.begin(), |
230 | | - 2); // make the output dims as same as numpy |
231 | | - out->Resize(common::make_ddim(tmp_dim_vec)); |
232 | | - dev_ctx.template Alloc<T>(out); |
233 | | - return; |
234 | | - } |
| 218 | + if (x.numel() == 0) { |
| 219 | + phi::Full<T, Context>( |
| 220 | + dev_ctx, phi::IntArray(common::vectorize(out->dims())), 0, out); |
| 221 | + return; |
235 | 222 | } |
236 | 223 |
|
237 | 224 | auto batch_count = detail::GetBatchCount(x.dims()); |
|
0 commit comments