Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 54 additions & 6 deletions paddle/phi/kernels/cpu/elementwise_multiply_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/cpu/elementwise.h"
#include "paddle/phi/kernels/impl/elementwise_kernel_impl.h"

namespace phi {

template <typename T, typename Context>
Expand All @@ -38,12 +38,59 @@ void MultiplyKernel(const Context& dev_ctx,
} else {
auto x_dims = x.dims();
auto y_dims = y.dims();
if (x_dims.size() >= y_dims.size()) {
funcs::ElementwiseCompute<funcs::MultiplyFunctor<T>, T>(
dev_ctx, x, y, funcs::MultiplyFunctor<T>(), out, -1);
DenseTensor x_fp32 = phi::Cast<T, Context>(dev_ctx, x, DataType::FLOAT32);
DenseTensor y_fp32 = phi::Cast<T, Context>(dev_ctx, y, DataType::FLOAT32);
DataType final_out_dtype = out->dtype();
if (final_out_dtype == DataType::UNDEFINED) {
final_out_dtype = x.dtype();
}
if constexpr (std::is_same_v<T, phi::dtype::float16> ||
std::is_same_v<T, phi::dtype::bfloat16>) {
if (final_out_dtype == DataType::FLOAT32) {
if (x_dims.size() >= y_dims.size()) {
funcs::ElementwiseCompute<funcs::MultiplyFunctor<T>, T>(
dev_ctx, x_fp32, y_fp32, funcs::MultiplyFunctor<T>(), out, -1);
} else {
funcs::ElementwiseCompute<funcs::InverseMultiplyFunctor<T>, T>(
dev_ctx,
x_fp32,
y_fp32,
funcs::InverseMultiplyFunctor<T>(),
out,
-1);
}
} else {
DenseTensor intermediate_result;
intermediate_result.set_meta(out->meta());
if (x_dims.size() >= y_dims.size()) {
funcs::ElementwiseCompute<funcs::MultiplyFunctor<T>, T>(
dev_ctx,
x_fp32,
y_fp32,
funcs::MultiplyFunctor<T>(),
&intermediate_result,
-1);
} else {
funcs::ElementwiseCompute<funcs::InverseMultiplyFunctor<T>, T>(
dev_ctx,
x_fp32,
y_fp32,
funcs::InverseMultiplyFunctor<T>(),
&intermediate_result,
-1);
}

phi::CastKernel<float, Context>(
dev_ctx, intermediate_result, final_out_dtype, out);
}
} else {
funcs::ElementwiseCompute<funcs::InverseMultiplyFunctor<T>, T>(
dev_ctx, x, y, funcs::InverseMultiplyFunctor<T>(), out, -1);
if (x_dims.size() >= y_dims.size()) {
funcs::ElementwiseCompute<funcs::MultiplyFunctor<T>, T>(
dev_ctx, x, y, funcs::MultiplyFunctor<T>(), out, -1);
} else {
funcs::ElementwiseCompute<funcs::InverseMultiplyFunctor<T>, T>(
dev_ctx, x, y, funcs::InverseMultiplyFunctor<T>(), out, -1);
}
}
}
}
Expand All @@ -67,4 +114,5 @@ PD_REGISTER_KERNEL(multiply,
bool,
complex64,
complex128,
phi::dtype::float16,
phi::dtype::bfloat16) {}
9 changes: 9 additions & 0 deletions test/legacy_test/test_math_op_patch_var_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,15 @@ def test_mul(self):
res = a * b
np.testing.assert_array_equal(res.numpy(), a_np * b_np)

def test_mul_fp16(self):
a_np = np.random.random(self.shape).astype(np.float16)
b_np = np.random.random(self.shape).astype(np.float16)
with base.dygraph.guard():
a = paddle.to_tensor(a_np)
b = paddle.to_tensor(b_np)
res = a * b
np.testing.assert_array_equal(res.numpy(), a_np * b_np)

def test_type_promotion_mul_F2_F4(self):
a_np = np.random.random(self.shape).astype(np.float16)
b_np = np.random.random(self.shape).astype(np.float32)
Expand Down
Loading