23 changes: 0 additions & 23 deletions paddle/phi/infermeta/ternary.cc
@@ -98,29 +98,6 @@ void AddmmInferMeta(const MetaTensor& input,
<< " alpha=" << alpha << " ndim_input=" << ndim_input
<< " ndim_x=" << ndim_x << " ndim_y=" << ndim_y;

PADDLE_ENFORCE_NE(
product(input_dims),
0,
errors::PreconditionNotMet("The Input variable 'input' has not "
"been initialized. You may need to confirm "
"if you put exe.run(startup_program) "
"after optimizer.minimize function."));

PADDLE_ENFORCE_NE(
product(x_dims),
0,
errors::PreconditionNotMet("The Input variable 'x' has not "
"been initialized. You may need to confirm "
"if you put exe.run(startup_program) "
"after optimizer.minimize function."));

PADDLE_ENFORCE_NE(
product(y_dims),
0,
errors::PreconditionNotMet("The Input variable 'y' has not "
"been initialized. You may need to confirm "
"if you put exe.run(startup_program) "
"after optimizer.minimize function."));
// dim check
PADDLE_ENFORCE_EQ(ndim_input == 2 || ndim_input == 1,
true,
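Dropping the three `PADDLE_ENFORCE_NE(product(...), 0, ...)` checks is what lets zero-size operands reach the kernels at all: the old checks treated a zero-element `input`/`x`/`y` as an uninitialized variable, but a zero-size tensor is a legitimate, fully shape-defined input for addmm. A minimal NumPy sketch of the intended semantics (NumPy stands in for the kernel math; the shapes follow the new tests below):

```python
import numpy as np

# addmm(input, x, y, beta, alpha) = beta * input + alpha * (x @ y)
# A zero-size x/y pair is still shape-consistent: [20, 0] @ [0, 100] is an
# all-zero [20, 100] matrix, so the result is just beta * input broadcast
# to the output shape.
beta, alpha = 2.0, 0.5
inp = np.random.random(100)         # broadcast along the row dimension
x = np.random.random((20, 0))
y = np.random.random((0, 100))
out = beta * inp + alpha * (x @ y)  # the matmul term contributes nothing
assert out.shape == (20, 100)
```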
31 changes: 31 additions & 0 deletions paddle/phi/kernels/impl/addmm_grad_kernel_impl.h
@@ -20,6 +20,7 @@ limitations under the License. */

#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/kernels/addmm_grad_kernel.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
@@ -66,6 +67,24 @@ void AddmmGradKernel(const Context& dev_ctx,
DenseTensor* input_grad,
DenseTensor* x_grad,
DenseTensor* y_grad) {
if (out_grad.numel() == 0) {
if (input_grad) {
phi::Full<T, Context>(
dev_ctx,
phi::IntArray(common::vectorize(input_grad->dims())),
0,
input_grad);
}
if (x_grad) {
phi::Full<T, Context>(
dev_ctx, phi::IntArray(common::vectorize(x_grad->dims())), 0, x_grad);
}
if (y_grad) {
phi::Full<T, Context>(
dev_ctx, phi::IntArray(common::vectorize(y_grad->dims())), 0, y_grad);
}
return;
}
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
bool is_float16_or_bfloat16 = false;
if (std::is_same<T, phi::dtype::float16>::value ||
@@ -166,6 +185,18 @@ void AddmmGradKernel(const Context& dev_ctx,
input_grad->Resize(input.dims());
}
}
  if (x_grad && x_grad->numel() == 0) {
    // x is zero-size, so x_grad is an empty tensor and y_grad (if requested)
    // receives zeros; guard against a null y_grad before filling it.
    dev_ctx.template Alloc<T>(x_grad);
    if (y_grad) {
      phi::Full<T, Context>(
          dev_ctx, phi::IntArray(common::vectorize(y_grad->dims())), 0, y_grad);
    }
    return;
  }
  if (y_grad && y_grad->numel() == 0) {
    dev_ctx.template Alloc<T>(y_grad);
    if (x_grad) {
      phi::Full<T, Context>(
          dev_ctx, phi::IntArray(common::vectorize(x_grad->dims())), 0, x_grad);
    }
    return;
  }
if (x_grad) {
dev_ctx.template Alloc<T>(x_grad);
total_elems = x.dims()[0] * x.dims()[1];
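The two early-return blocks added above encode a simple convention for degenerate shapes: when `out_grad` is zero-size, every requested gradient is zero-filled in its own shape; when `x` (hence `x_grad`) or `y` (hence `y_grad`) is zero-size, the matrix gradients degenerate to empty tensors and the BLAS calls are skipped. Below is a NumPy reference of that convention, as a sketch only: `addmm_grads` is a hypothetical helper, not a Paddle API, and it assumes the standard addmm gradient formulas (2-D broadcasting of `input` is omitted for brevity):

```python
import numpy as np

def addmm_grads(inp, x, y, out_grad, alpha, beta):
    """Reference gradients of out = beta * input + alpha * (x @ y)."""
    if out_grad.size == 0:
        # Zero-size out_grad: every gradient is all zeros in its own shape.
        return np.zeros_like(inp), np.zeros_like(x), np.zeros_like(y)
    # A 1-D input is broadcast over rows, so its gradient reduces over the
    # broadcast dimension.
    d_input = beta * (out_grad.sum(axis=0) if inp.ndim == 1 else out_grad)
    # With a zero-size x or y, both matrix gradients come out zero-size,
    # matching the kernel's early returns.
    d_x = alpha * out_grad @ y.T
    d_y = alpha * x.T @ out_grad
    return d_input, d_x, d_y
```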
9 changes: 9 additions & 0 deletions paddle/phi/kernels/impl/addmm_kernel_impl.h
@@ -97,6 +97,8 @@ void AddmmKernel(const Context& dev_ctx,
y_dims[0]));

dev_ctx.template Alloc<T>(out);
if (out->numel() == 0) return;

auto blas = funcs::GetBlas<Context, T>(dev_ctx);

// calc broadcast dim
@@ -112,6 +114,13 @@
funcs::EigenBroadcast<std::decay_t<decltype(place)>, T, 2>::Eval(
place, eigen_out, eigen_input, bcast_dims);

// x or y is zero-size: the matmul term is empty, so the output is just beta * input (already broadcast above).
if (x.numel() == 0 || y.numel() == 0) {
auto eigen_out2 = phi::EigenVector<T>::Flatten(*out);
eigen_out2.device(place) = eigen_out2 * static_cast<T>(beta);
return;
}

T t_alpha = static_cast<T>(alpha);
T t_beta = static_cast<T>(beta);
blas.GEMM(false,
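At the Python level, the fast path above should make the following end-to-end call valid. This is a minimal usage sketch that assumes the zero-size support added in this PR; the shapes mirror `TestAddmmOp_ZeroSize` below:

```python
import paddle

inp = paddle.rand([100])
x = paddle.rand([20, 0])
y = paddle.rand([0, 100])
# The matmul term is empty, so the result is beta * input broadcast to [20, 100].
out = paddle.addmm(inp, x, y, beta=2.0, alpha=0.5)
print(out.shape)  # [20, 100]
```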
31 changes: 29 additions & 2 deletions paddle/phi/kernels/xpu/addmm_grad_kernel.cc
@@ -16,8 +16,8 @@
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/xpu/xpu_api_wrapper.h"

namespace phi {

template <typename T, typename Context>
@@ -32,6 +32,24 @@ void AddmmGradKernel(const Context& dev_ctx,
DenseTensor* x_grad,
DenseTensor* y_grad) {
using XPUType = typename XPUTypeTrait<T>::Type;
if (out_grad.numel() == 0) {
if (input_grad) {
phi::Full<T, Context>(
dev_ctx,
phi::IntArray(common::vectorize(input_grad->dims())),
0,
input_grad);
}
if (x_grad) {
phi::Full<T, Context>(
dev_ctx, phi::IntArray(common::vectorize(x_grad->dims())), 0, x_grad);
}
if (y_grad) {
phi::Full<T, Context>(
dev_ctx, phi::IntArray(common::vectorize(y_grad->dims())), 0, y_grad);
}
return;
}

xpu::Context* xpu_ctx = dev_ctx.x_context();
xpu::ctx_guard RAII_GUARD(xpu_ctx);
@@ -59,7 +77,16 @@ void AddmmGradKernel(const Context& dev_ctx,
if (y_grad) {
dev_ctx.template Alloc<T>(y_grad);
}

  if (x_grad && x_grad->numel() == 0) {
    // x is zero-size: x_grad is empty and y_grad (if requested) gets zeros;
    // guard against a null y_grad before filling it.
    if (y_grad) {
      phi::Full<T, Context>(
          dev_ctx, phi::IntArray(common::vectorize(y_grad->dims())), 0, y_grad);
    }
    return;
  }
  if (y_grad && y_grad->numel() == 0) {
    if (x_grad) {
      phi::Full<T, Context>(
          dev_ctx, phi::IntArray(common::vectorize(x_grad->dims())), 0, x_grad);
    }
    return;
  }
const XPUType* out_grad_ptr =
reinterpret_cast<const XPUType*>(out_grad.data<T>());
const XPUType* x_ptr = reinterpret_cast<const XPUType*>(x.data<T>());
5 changes: 4 additions & 1 deletion paddle/phi/kernels/xpu/addmm_kernel.cc
@@ -60,13 +60,16 @@ void AddmmKernel(const Context& dev_ctx,
input_dims));

dev_ctx.template Alloc<T>(out);
if (out->numel() == 0) return;

const XPUType* x_ptr = reinterpret_cast<const XPUType*>(x.data<T>());
const XPUType* y_ptr = reinterpret_cast<const XPUType*>(y.data<T>());
const XPUType* input_ptr = reinterpret_cast<const XPUType*>(input.data<T>());
XPUType* out_ptr = reinterpret_cast<XPUType*>(out->data<T>());

int r;
if (alpha == 0.f) {
// If x.numel() or y.numel() is 0, skip the GEMM; only the broadcast multiply by beta is needed.
if (alpha == 0.f || x.numel() == 0 || y.numel() == 0) {
if (beta == 0.f) {
r = xpu::constant(dev_ctx.x_context(),
out_ptr,
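The XPU kernel folds the zero-size case into the existing `alpha == 0` branch, which already avoids the GEMM: with `beta == 0` the output is filled with zeros, otherwise it is the broadcast input scaled by `beta`. A NumPy sketch of that no-GEMM path (the helper name `addmm_no_gemm` is hypothetical):

```python
import numpy as np

def addmm_no_gemm(inp, out_shape, beta):
    # alpha == 0, or x/y is zero-size: the GEMM is skipped entirely.
    if beta == 0.0:
        return np.zeros(out_shape, dtype=inp.dtype)   # xpu::constant path
    return np.broadcast_to(inp, out_shape) * beta     # broadcast + scale path

out = addmm_no_gemm(np.random.random(100), (20, 100), beta=2.0)
assert out.shape == (20, 100)
```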
54 changes: 54 additions & 0 deletions test/legacy_test/test_addmm_op.py
@@ -513,6 +513,60 @@ def test_api_normal_3(self):
paddle.enable_static()


class TestAddmmOp_ZeroSize(OpTest):
def setUp(self):
self.op_type = "addmm"
self.python_api = paddle.addmm
self.public_python_api = paddle.addmm
self.init_dtype_type()
self.init_input()
self.attrs = {
'Alpha': 0.5,
'Beta': 2.0,
}
self.outputs = {
'Out': self.attrs['Beta'] * self.inputs['Input']
+ self.attrs['Alpha'] * np.dot(self.inputs['X'], self.inputs['Y'])
}

def init_input(self):
# result shape: [20, 100]
self.inputs = {
'Input': np.random.random(100).astype(self.dtype),
'X': np.random.random((20, 0)).astype(self.dtype),
'Y': np.random.random((0, 100)).astype(self.dtype),
}

def init_dtype_type(self):
self.dtype = np.float64

def test_check_output(self):
self.check_output(check_pir=True)

def test_check_grad_normal(self):
self.check_grad(['Input', 'X', 'Y'], 'Out', check_pir=True)


class TestAddmmOp_ZeroSize2(TestAddmmOp_ZeroSize):
def init_input(self):
# result shape: [20, 0]
self.inputs = {
'Input': np.random.random(0).astype(self.dtype),
'X': np.random.random((20, 100)).astype(self.dtype),
'Y': np.random.random((100, 0)).astype(self.dtype),
}


class TestAddmmOp_ZeroSize3(TestAddmmOp_ZeroSize):
def init_input(self):
# result shape: [0, 0]
self.inputs = {
'Input': np.random.random(0).astype(self.dtype),
'X': np.random.random((0, 100)).astype(self.dtype),
'Y': np.random.random((100, 0)).astype(self.dtype),
}


if __name__ == "__main__":
paddle.enable_static()
unittest.main()