
Commit f74615f

Fix (#73170)
1 parent 86f1498 commit f74615f

6 files changed: +127 additions, -26 deletions

paddle/phi/infermeta/ternary.cc

Lines changed: 0 additions & 23 deletions
@@ -98,29 +98,6 @@ void AddmmInferMeta(const MetaTensor& input,
           << " alpha=" << alpha << " ndim_input=" << ndim_input
           << " ndim_x=" << ndim_x << " ndim_y=" << ndim_y;

-  PADDLE_ENFORCE_NE(
-      product(input_dims),
-      0,
-      errors::PreconditionNotMet("The Input variable 'input' has not "
-                                 "been initialized. You may need to confirm "
-                                 "if you put exe.run(startup_program) "
-                                 "after optimizer.minimize function."));
-
-  PADDLE_ENFORCE_NE(
-      product(x_dims),
-      0,
-      errors::PreconditionNotMet("The Input variable 'x' has not "
-                                 "been initialized. You may need to confirm "
-                                 "if you put exe.run(startup_program) "
-                                 "after optimizer.minimize function."));
-
-  PADDLE_ENFORCE_NE(
-      product(y_dims),
-      0,
-      errors::PreconditionNotMet("The Input variable 'y' has not "
-                                 "been initialized. You may need to confirm "
-                                 "if you put exe.run(startup_program) "
-                                 "after optimizer.minimize function."));
   // dim check
   PADDLE_ENFORCE_EQ(ndim_input == 2 || ndim_input == 1,
                     true,
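
The removed checks treated a dims product of zero as a sign of an uninitialized variable, which also rejected legitimately empty (zero-size) tensors. A minimal NumPy sketch (illustrative only, not Paddle code) of why that condition misfires:

import numpy as np

# A (20, 0) tensor is a valid, fully initialized empty tensor, yet the
# product of its dims is 0 -- exactly the condition the removed
# PADDLE_ENFORCE_NE(product(dims), 0, ...) checks used to reject.
x = np.random.random((20, 0))
print(x.shape, x.size)        # (20, 0) 0
print(int(np.prod(x.shape)))  # 0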

paddle/phi/kernels/impl/addmm_grad_kernel_impl.h

Lines changed: 31 additions & 0 deletions
@@ -20,6 +20,7 @@ limitations under the License. */

 #include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/kernels/addmm_grad_kernel.h"
+#include "paddle/phi/kernels/full_kernel.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
@@ -66,6 +67,24 @@ void AddmmGradKernel(const Context& dev_ctx,
                      DenseTensor* input_grad,
                      DenseTensor* x_grad,
                      DenseTensor* y_grad) {
+  if (out_grad.numel() == 0) {
+    if (input_grad) {
+      phi::Full<T, Context>(
+          dev_ctx,
+          phi::IntArray(common::vectorize(input_grad->dims())),
+          0,
+          input_grad);
+    }
+    if (x_grad) {
+      phi::Full<T, Context>(
+          dev_ctx, phi::IntArray(common::vectorize(x_grad->dims())), 0, x_grad);
+    }
+    if (y_grad) {
+      phi::Full<T, Context>(
+          dev_ctx, phi::IntArray(common::vectorize(y_grad->dims())), 0, y_grad);
+    }
+    return;
+  }
   using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
   bool is_float16_or_bfloat16 = false;
   if (std::is_same<T, phi::dtype::float16>::value ||
@@ -166,6 +185,18 @@ void AddmmGradKernel(const Context& dev_ctx,
       input_grad->Resize(input.dims());
     }
   }
+  if (x_grad && x_grad->numel() == 0) {
+    dev_ctx.template Alloc<T>(x_grad);
+    phi::Full<T, Context>(
+        dev_ctx, phi::IntArray(common::vectorize(y_grad->dims())), 0, y_grad);
+    return;
+  }
+  if (y_grad && y_grad->numel() == 0) {
+    dev_ctx.template Alloc<T>(y_grad);
+    phi::Full<T, Context>(
+        dev_ctx, phi::IntArray(common::vectorize(x_grad->dims())), 0, x_grad);
+    return;
+  }
   if (x_grad) {
     dev_ctx.template Alloc<T>(x_grad);
     total_elems = x.dims()[0] * x.dims()[1];
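
The new early returns cover two zero-size cases: if out_grad has no elements, every requested gradient is mathematically zero, so it is materialized with phi::Full instead of going through the BLAS path; if only the inner dimension is zero, x_grad and y_grad are themselves empty and only need to be allocated and zero-filled. A reference sketch of the math for the first case (plain NumPy, not the kernel; shapes chosen purely for illustration):

import numpy as np

# addmm reference: out = beta * input + alpha * (x @ y).
# y has zero columns, so out (and out_grad) is empty while x_grad is not.
alpha, beta = 0.5, 2.0
x = np.random.random((20, 100))
y = np.random.random((100, 0))
out_grad = np.zeros((20, 0))  # same shape as out, zero elements

# The gradient w.r.t. x contracts over the zero columns of out_grad,
# so it is exactly zero -- what the early return writes via phi::Full.
x_grad = alpha * (out_grad @ y.T)
print(x_grad.shape, np.count_nonzero(x_grad))  # (20, 100) 0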

paddle/phi/kernels/impl/addmm_kernel_impl.h

Lines changed: 9 additions & 0 deletions
@@ -97,6 +97,8 @@ void AddmmKernel(const Context& dev_ctx,
                         y_dims[0]));

   dev_ctx.template Alloc<T>(out);
+  if (out->numel() == 0) return;
+
   auto blas = funcs::GetBlas<Context, T>(dev_ctx);

   // calc broadcast dim
@@ -112,6 +114,13 @@ void AddmmKernel(const Context& dev_ctx,
   funcs::EigenBroadcast<std::decay_t<decltype(place)>, T, 2>::Eval(
       place, eigen_out, eigen_input, bcast_dims);

+  // Just return input X beta
+  if (x.numel() == 0 || y.numel() == 0) {
+    auto eigen_out2 = phi::EigenVector<T>::Flatten(*out);
+    eigen_out2.device(place) = eigen_out2 * static_cast<T>(beta);
+    return;
+  }
+
   T t_alpha = static_cast<T>(alpha);
   T t_beta = static_cast<T>(beta);
   blas.GEMM(false,
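
The Eigen fast path above keeps only the beta * input term when x or y is zero-size but out is not, which matches the reference definition of addmm: a matmul over a zero-length inner dimension contributes an all-zero matrix. A NumPy check of that identity (shapes are illustrative):

import numpy as np

# out = beta * input + alpha * (x @ y); with an inner dim of 0 the
# x @ y term is an all-zero (20, 100) matrix, so the result is just
# beta times the broadcast bias.
alpha, beta = 0.5, 2.0
inp = np.random.random(100)          # bias, broadcast over rows
x = np.random.random((20, 0))
y = np.random.random((0, 100))
out = beta * inp + alpha * (x @ y)
print(out.shape)                                                  # (20, 100)
print(np.allclose(out, beta * np.broadcast_to(inp, (20, 100))))   # True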

paddle/phi/kernels/xpu/addmm_grad_kernel.cc

Lines changed: 29 additions & 2 deletions
@@ -16,8 +16,8 @@
 #include "paddle/phi/backends/xpu/enforce_xpu.h"
 #include "paddle/phi/backends/xpu/xpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/full_kernel.h"
 #include "paddle/phi/kernels/xpu/xpu_api_wrapper.h"
-
 namespace phi {

 template <typename T, typename Context>
@@ -32,6 +32,24 @@ void AddmmGradKernel(const Context& dev_ctx,
                      DenseTensor* x_grad,
                      DenseTensor* y_grad) {
   using XPUType = typename XPUTypeTrait<T>::Type;
+  if (out_grad.numel() == 0) {
+    if (input_grad) {
+      phi::Full<T, Context>(
+          dev_ctx,
+          phi::IntArray(common::vectorize(input_grad->dims())),
+          0,
+          input_grad);
+    }
+    if (x_grad) {
+      phi::Full<T, Context>(
+          dev_ctx, phi::IntArray(common::vectorize(x_grad->dims())), 0, x_grad);
+    }
+    if (y_grad) {
+      phi::Full<T, Context>(
+          dev_ctx, phi::IntArray(common::vectorize(y_grad->dims())), 0, y_grad);
+    }
+    return;
+  }

   xpu::Context* xpu_ctx = dev_ctx.x_context();
   xpu::ctx_guard RAII_GUARD(xpu_ctx);
@@ -59,7 +77,16 @@ void AddmmGradKernel(const Context& dev_ctx,
   if (y_grad) {
     dev_ctx.template Alloc<T>(y_grad);
   }
-
+  if (x_grad && x_grad->numel() == 0) {
+    phi::Full<T, Context>(
+        dev_ctx, phi::IntArray(common::vectorize(y_grad->dims())), 0, y_grad);
+    return;
+  }
+  if (y_grad && y_grad->numel() == 0) {
+    phi::Full<T, Context>(
+        dev_ctx, phi::IntArray(common::vectorize(x_grad->dims())), 0, x_grad);
+    return;
+  }
   const XPUType* out_grad_ptr =
       reinterpret_cast<const XPUType*>(out_grad.data<T>());
   const XPUType* x_ptr = reinterpret_cast<const XPUType*>(x.data<T>());

paddle/phi/kernels/xpu/addmm_kernel.cc

Lines changed: 4 additions & 1 deletion
@@ -60,13 +60,16 @@ void AddmmKernel(const Context& dev_ctx,
                         input_dims));

   dev_ctx.template Alloc<T>(out);
+  if (out->numel() == 0) return;
+
   const XPUType* x_ptr = reinterpret_cast<const XPUType*>(x.data<T>());
   const XPUType* y_ptr = reinterpret_cast<const XPUType*>(y.data<T>());
   const XPUType* input_ptr = reinterpret_cast<const XPUType*>(input.data<T>());
   XPUType* out_ptr = reinterpret_cast<XPUType*>(out->data<T>());

   int r;
-  if (alpha == 0.f) {
+  // If x.numel or y.numel is 0, we just need to do a broadcast mul.
+  if (alpha == 0.f || x.numel() == 0 || y.numel() == 0) {
     if (beta == 0.f) {
       r = xpu::constant(dev_ctx.x_context(),
                         out_ptr,

test/legacy_test/test_addmm_op.py

Lines changed: 54 additions & 0 deletions
@@ -513,6 +513,60 @@ def test_api_normal_3(self):
         paddle.enable_static()


+class TestAddmmOp_ZeroSize(OpTest):
+    def setUp(self):
+        self.op_type = "addmm"
+        self.python_api = paddle.addmm
+        self.public_python_api = paddle.addmm
+        self.init_dtype_type()
+        self.init_input()
+        self.attrs = {
+            'Alpha': 0.5,
+            'Beta': 2.0,
+        }
+        self.outputs = {
+            'Out': self.attrs['Beta'] * self.inputs['Input']
+            + self.attrs['Alpha'] * np.dot(self.inputs['X'], self.inputs['Y'])
+        }
+
+    def init_input(self):
+        # result shape: [20, 100]
+        self.inputs = {
+            'Input': np.random.random(100).astype(self.dtype),
+            'X': np.random.random((20, 0)).astype(self.dtype),
+            'Y': np.random.random((0, 100)).astype(self.dtype),
+        }
+
+    def init_dtype_type(self):
+        self.dtype = np.float64
+
+    def test_check_output(self):
+        self.check_output(check_pir=True)
+
+    def test_check_grad_normal(self):
+        self.check_grad(['Input', 'X', 'Y'], 'Out', check_pir=True)
+
+
+class TestAddmmOp_ZeroSize2(TestAddmmOp_ZeroSize):
+    def init_input(self):
+        # result shape: [20, 0]
+        self.inputs = {
+            'Input': np.random.random(0).astype(self.dtype),
+            'X': np.random.random((20, 100)).astype(self.dtype),
+            'Y': np.random.random((100, 0)).astype(self.dtype),
+        }
+
+
+class TestAddmmOp_ZeroSize3(TestAddmmOp_ZeroSize):
+    def init_input(self):
+        # result shape: [0, 0]
+        self.inputs = {
+            'Input': np.random.random(0).astype(self.dtype),
+            'X': np.random.random((0, 100)).astype(self.dtype),
+            'Y': np.random.random((100, 0)).astype(self.dtype),
+        }
+
+
 if __name__ == "__main__":
     paddle.enable_static()
     unittest.main()
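
For completeness, a minimal dynamic-graph reproduction of the scenario these tests exercise. This is a sketch, not part of the commit; it assumes the public paddle.addmm API with alpha/beta keyword arguments, and the shapes are illustrative:

import numpy as np
import paddle

inp = paddle.to_tensor(np.random.rand(100).astype('float32'))
x = paddle.to_tensor(np.zeros((20, 0), dtype='float32'))    # zero-size operand
y = paddle.to_tensor(np.zeros((0, 100), dtype='float32'))

# With a zero-length inner dimension, x @ y contributes nothing, so the
# result is beta * input broadcast to [20, 100].
out = paddle.addmm(inp, x, y, alpha=0.5, beta=2.0)
print(out.shape)                                    # [20, 100]
print(np.allclose(out.numpy(), 2.0 * inp.numpy()))  # True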
