Skip to content

Commit e991019

Browse files
authored
[0-size Tensor No.43、131] Add 0-size Tensor support for paddle.diag/masked_fill (#73570)
1 parent 8e5a730 commit e991019

File tree

13 files changed

+120
-8
lines changed

13 files changed

+120
-8
lines changed

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -899,6 +899,9 @@ bool DiagOpInferSymbolicShape(pir::Operation *op,
899899
size_ = x_shape[1].dyn_cast<int64_t>();
900900
}
901901
}
902+
if (size_ < 0) {
903+
size_ = 0;
904+
}
902905
infer_context->SetShapeOrDataForValue(
903906
op->result(0), symbol::TensorShapeOrDataDimExprs({size_}));
904907
} else {

paddle/phi/infermeta/unary.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -955,6 +955,9 @@ void DiagInferMeta(const MetaTensor& x,
955955
size_ = x_dims[1];
956956
}
957957
}
958+
if (size_ < 0) {
959+
size_ = 0;
960+
}
958961
out->set_dims({size_});
959962
out->set_dtype(x.dtype());
960963
} else {

paddle/phi/kernels/cpu/diag_grad_kernel.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ void DiagGradKernel(const Context& dev_ctx,
2828
int offset,
2929
DenseTensor* x_grad) {
3030
T* dx_data = dev_ctx.template Alloc<T>(x_grad);
31+
if (x_grad && x_grad->numel() == 0) return;
3132
const T* dout_data = out_grad.data<T>();
3233
auto dx_dims = x_grad->dims();
3334
auto dout_dims = out_grad.dims();

paddle/phi/kernels/cpu/diag_kernel.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ void DiagKernel(const Context& dev_ctx,
3030
auto* x_data = x.data<T>();
3131
auto x_dims = x.dims();
3232
T* out_data = dev_ctx.template Alloc<T>(out);
33+
if (out && out->numel() == 0) return;
3334
auto out_dims = out->dims();
3435

3536
int64_t i = 0;

paddle/phi/kernels/cpu/masked_fill_grad_kernel.cc

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@
2020
#include "paddle/phi/kernels/empty_kernel.h"
2121
#include "paddle/phi/kernels/expand_grad_kernel.h"
2222
#include "paddle/phi/kernels/expand_kernel.h"
23+
#include "paddle/phi/kernels/full_kernel.h"
2324
#include "paddle/phi/kernels/funcs/common_infer_shape_functions.h"
2425
#include "paddle/phi/kernels/funcs/common_shape.h"
25-
2626
namespace phi {
2727

2828
template <typename T, typename Context>
@@ -33,6 +33,19 @@ void MaskedFillGradKernel(const Context& dev_ctx,
3333
const DenseTensor& out_grad,
3434
DenseTensor* x_grad,
3535
DenseTensor* v_grad) {
36+
if (out_grad.numel() == 0 || mask.numel() == 0) {
37+
// x shape [2, 1, 3], mask shape [2, 0, 3], x_grad shape [2, 1, 3]
38+
if (x_grad) {
39+
phi::Full<T, Context>(
40+
dev_ctx, phi::IntArray(common::vectorize(x_grad->dims())), 0, x_grad);
41+
}
42+
if (v_grad) {
43+
phi::Full<T, Context>(
44+
dev_ctx, phi::IntArray(common::vectorize(v_grad->dims())), 0, v_grad);
45+
}
46+
return;
47+
}
48+
3649
auto x_grad_dims = x_grad->dims();
3750
auto mask_dims = mask.dims();
3851
bool expand_x = false;

paddle/phi/kernels/cpu/masked_fill_kernel.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,11 @@ void MaskedFillKernel(const Context& dev_ctx,
2929
const DenseTensor& mask,
3030
const DenseTensor& value,
3131
DenseTensor* out) {
32+
if (x.numel() == 0 || mask.numel() == 0) {
33+
dev_ctx.template Alloc<T>(out);
34+
return;
35+
}
36+
3237
auto x_dims = x.dims();
3338
auto mask_dims = mask.dims();
3439

paddle/phi/kernels/gpu/diag_grad_kernel.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ void DiagGradKernel(const Context& dev_ctx,
5757
int offset,
5858
DenseTensor* x_grad) {
5959
T* dx_data = dev_ctx.template Alloc<T>(x_grad);
60+
if (x_grad && x_grad->numel() == 0) return;
6061
auto* dout_data = out_grad.data<T>();
6162
auto dx_dims = x_grad->dims();
6263
auto dout_dims = out_grad.dims();

paddle/phi/kernels/gpu/diag_kernel.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ void DiagKernel(const Context& dev_ctx,
6363
auto* x_data = x.data<T>();
6464
auto x_dims = x.dims();
6565
T* out_data = dev_ctx.template Alloc<T>(out);
66+
if (out && out->numel() == 0) return;
6667
auto out_dims = out->dims();
6768

6869
auto GetBlockGridSize = [&dev_ctx](int64_t size) {

paddle/phi/kernels/gpu/masked_fill_grad_kernel.cu

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -307,13 +307,14 @@ void MaskedFillGradKernel(const Context& dev_ctx,
307307
DenseTensor* x_grad,
308308
DenseTensor* v_grad) {
309309
if (out_grad.numel() == 0 || mask.numel() == 0) {
310-
if (x_grad != nullptr) {
311-
x_grad->Resize({0});
312-
dev_ctx.template Alloc<T>(x_grad);
310+
// x shape [2, 1, 3], mask shape [2, 0, 3], x_grad shape [2, 1, 3]
311+
if (x_grad) {
312+
phi::Full<T, Context>(
313+
dev_ctx, phi::IntArray(common::vectorize(x_grad->dims())), 0, x_grad);
313314
}
314-
if (v_grad != nullptr) {
315-
v_grad->Resize({0});
316-
dev_ctx.template Alloc<T>(v_grad);
315+
if (v_grad) {
316+
phi::Full<T, Context>(
317+
dev_ctx, phi::IntArray(common::vectorize(v_grad->dims())), 0, v_grad);
317318
}
318319
return;
319320
}

paddle/phi/kernels/gpu/masked_fill_kernel.cu

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,6 @@ void MaskedFillKernel(const Context& dev_ctx,
211211
const DenseTensor& value,
212212
DenseTensor* out) {
213213
if (x.numel() == 0 || mask.numel() == 0) {
214-
out->Resize({0});
215214
dev_ctx.template Alloc<T>(out);
216215
return;
217216
}

0 commit comments

Comments (0)