Added mul BF16/FP32 FWD/BWD oneDNN kernel (#38552)
* base changes for mul reimplementation

* empty commit

* tmp save

* full implementation of mul bf16/fp32 fwd bwd

* CI fix

* CI rerun

* changed unity build cmake to avoid gpu issues

* removed mul mkldnn from unity build

* added skipping tests if not cpu_bf16

* CI fix

* CI fix

* CI fix
jakpiase authored Jan 13, 2022
1 parent 281644c commit fc6eed5
Showing 10 changed files with 490 additions and 117 deletions.
109 changes: 1 addition & 108 deletions paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc
@@ -20,6 +20,7 @@ using dnnl::memory;
using dnnl::primitive;
using paddle::framework::DataLayout;
using paddle::framework::ExecutionContext;
using paddle::platform::MatMulV2MKLDNNHandler;
using paddle::platform::GetMKLDNNFormat;
using paddle::platform::MKLDNNDeviceContext;
using paddle::platform::MKLDNNGetDataType;
@@ -107,114 +108,6 @@ std::vector<int64_t> GetInputStrides(const ExecutionContext& ctx,
return strides;
}

template <typename T>
class MatMulV2MKLDNNHandler
: public paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::matmul> {
public:
MatMulV2MKLDNNHandler(const dnnl::engine engine,
paddle::platform::Place cpu_place,
const std::vector<int64_t>& x_org_dims, bool trans_x,
const std::vector<int64_t>& y_org_dims, bool trans_y,
bool is_output_fused,
const std::vector<int64_t>& x_strides_override,
const std::vector<int64_t>& y_strides_override)
: paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::matmul>(engine,
cpu_place) {
// M X K * K X N
std::vector<int64_t> x_dims(x_org_dims);
std::vector<int64_t> y_dims(y_org_dims);

const int MB_idx = x_dims.size() - 3;
const int H_idx = x_dims.size() - 2;
const int W_idx = x_dims.size() - 1;

if (trans_x) std::swap(x_dims[H_idx], x_dims[W_idx]);
if (trans_y) std::swap(y_dims[H_idx], y_dims[W_idx]);

const memory::dim M = x_dims[H_idx];
const memory::dim K = x_dims[W_idx];
const memory::dim N = y_dims[W_idx];

std::vector<int64_t> x_strides(x_dims.size() - 3, 1);
std::vector<int64_t> y_strides(x_dims.size() - 3, 1);
std::vector<int64_t> out_strides(x_dims.size() - 3, 1);
std::vector<int64_t> out_ddims(x_dims.size() - 3, 1);

x_strides.reserve(x_dims.size());
y_strides.reserve(x_dims.size());
out_strides.reserve(x_dims.size());

if (!x_strides_override.empty()) {
x_strides = x_strides_override;
} else {
if (!trans_x) {
x_strides.insert(x_strides.end(), {M * K, K, 1});
} else {
x_strides.insert(x_strides.end(), {M * K, 1, M});
}
}

if (!y_strides_override.empty()) {
y_strides = y_strides_override;
} else {
if (!trans_y) {
y_strides.insert(y_strides.end(), {N * K, N, 1});
} else {
y_strides.insert(y_strides.end(), {N * K, 1, K});
}
}

out_strides.insert(out_strides.end(), {M * N, N, 1});
out_ddims.insert(out_ddims.end(),
{std::max(x_dims[MB_idx], y_dims[MB_idx]), M, N});

for (int i = x_dims.size() - 4; i >= 0; --i) {
out_ddims[i] = std::max(x_dims[i], y_dims[i]);
if (x_strides_override.empty()) {
x_strides[i] = x_dims[i + 1] * x_strides[i + 1];
}
if (y_strides_override.empty()) {
y_strides[i] = y_dims[i + 1] * y_strides[i + 1];
}
out_strides[i] = out_ddims[i + 1] * out_strides[i + 1];
}

if (is_output_fused) {
out_strides = FakeTransposeStrides(out_ddims);
}

auto x_md = memory::desc(x_dims, MKLDNNGetDataType<T>(), x_strides);
auto y_md = memory::desc(y_dims, MKLDNNGetDataType<T>(), y_strides);
auto out_md = memory::desc(out_ddims, MKLDNNGetDataType<T>(), out_strides);

this->AcquireForwardPrimitiveDescriptor(x_md, y_md, out_md);
}

std::vector<int64_t> FakeTransposeStrides(
const std::vector<int64_t>& matmul_out_dims) const {
// fuse matmul_v2 + transpose + reshape guarantees that output is 4D and
// transpose axis are: {0, 2, 1, 3}
std::vector<int64_t> transpose_axis = {0, 2, 1, 3};
std::vector<int64_t> fake_strides(transpose_axis.size());
int ndims = static_cast<int>(transpose_axis.size());

int total_stride = 1;

for (int i = ndims - 1; i >= 0; --i) {
fake_strides[transpose_axis[i]] = total_stride;
total_stride *= matmul_out_dims[transpose_axis[i]];
}

return fake_strides;
}

std::shared_ptr<memory> AcquireWeightsMemory(const Tensor* input) {
const T* input_data = input->data<T>();
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc(),
to_void_cast<T>(input_data));
}
};

bool IsOutputFused(const ExecutionContext& ctx) {
auto& fused_reshape_Out = ctx.Attr<std::vector<int>>("fused_reshape_Out");
auto& fused_transpose_Out = ctx.Attr<std::vector<int>>("fused_transpose_Out");
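
The MatMulV2MKLDNNHandler class removed above is not deleted from the codebase: the hunk also adds "using paddle::platform::MatMulV2MKLDNNHandler;", and mul_mkldnn_op.cc below switches its include to mkldnn_reuse.h, so the handler has presumably been promoted to a shared platform header that the new mul kernels can reuse. The standalone sketch below (plain C++, sample shapes assumed, no oneDNN; not part of the patch) reproduces the stride trick FakeTransposeStrides() performs: for the fused matmul_v2 + transpose({0, 2, 1, 3}) + reshape case, the output buffer is described with the strides of the already-transposed layout.

// Standalone illustration of the FakeTransposeStrides() logic shown above.
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> FakeTransposeStrides(const std::vector<int64_t>& dims) {
  const std::vector<int64_t> transpose_axis = {0, 2, 1, 3};  // fixed fuse pattern
  std::vector<int64_t> strides(transpose_axis.size());
  int64_t total_stride = 1;
  for (int i = static_cast<int>(transpose_axis.size()) - 1; i >= 0; --i) {
    strides[transpose_axis[i]] = total_stride;
    total_stride *= dims[transpose_axis[i]];
  }
  return strides;
}

int main() {
  // A 4-D matmul output {batch, heads, M, N} = {2, 12, 128, 64}:
  for (int64_t s : FakeTransposeStrides({2, 12, 128, 64})) std::cout << s << ' ';
  std::cout << '\n';  // prints: 98304 64 768 1
}
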
176 changes: 169 additions & 7 deletions paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc
@@ -15,7 +15,7 @@ limitations under the License. */
#include <string>

#include "paddle/fluid/operators/mul_op.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"

namespace paddle {
namespace framework {
@@ -32,13 +32,17 @@ namespace operators {
using framework::DataLayout;
using framework::DDim;
using framework::ExecutionContext;
using framework::LoDTensor;
using framework::Tensor;

using platform::MatMulV2MKLDNNHandler;
using platform::MKLDNNDeviceContext;
using platform::to_void_cast;

using dnnl::inner_product_forward;
using dnnl::memory;
using dnnl::prop_kind;
using dnnl::stream;
using platform::MKLDNNDeviceContext;
using platform::to_void_cast;

template <typename XT, typename YT, typename OT>
class MulPrimitiveFactory {
@@ -345,7 +349,7 @@ inner_product_forward GetMulPrimitive(const MKLDNNDeviceContext &dev_ctx,

/* XT: input x data type, YT: input y data type */
template <typename XT, typename YT>
class MulMKLDNNKernel : public framework::OpKernel<XT> {
class MulMKLDNNINT8Kernel : public framework::OpKernel<XT> {
public:
void Compute(const ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
@@ -371,17 +375,175 @@ class MulMKLDNNKernel : public framework::OpKernel<XT> {
}
};

template <typename XT, typename YT>
class MulMKLDNNKernel : public framework::OpKernel<XT> {
public:
void Compute(const ExecutionContext &ctx) const override { RunKernel(ctx); }

protected:
void ExecuteMatMul(const ExecutionContext &ctx,
const MKLDNNDeviceContext &dev_ctx,
const dnnl::engine &onednn_engine,
const platform::Place &cpu_place, const Tensor *x,
const std::vector<int64_t> &x_dims, bool trans_x,
const Tensor *y, const std::vector<int64_t> &y_dims,
bool trans_y, Tensor *out) const {
static const std::vector<int64_t> vec_placeholder;
MatMulV2MKLDNNHandler<XT> handler(onednn_engine, ctx.GetPlace(), x_dims,
trans_x, y_dims, trans_y, false,
vec_placeholder, vec_placeholder);

const auto src_memory_p = handler.AcquireSrcMemory(x);
const auto weights_memory_p = handler.AcquireWeightsMemory(y);
const auto dst_memory_p = handler.AcquireDstMemory(out);

auto matmul_p = handler.AcquireForwardPrimitive();

std::unordered_map<int, dnnl::memory> matmul_args = {
{DNNL_ARG_SRC, *src_memory_p},
{DNNL_ARG_WEIGHTS, *weights_memory_p},
{DNNL_ARG_DST, *dst_memory_p}};

auto &astream = MKLDNNDeviceContext::tls().get_stream();
matmul_p->execute(astream, matmul_args);
astream.wait();

out->set_layout(framework::DataLayout::kMKLDNN);
// plain output formats are enforced inside handler
out->set_format(platform::MKLDNNFormatForSize(
out->dims().size(), dnnl::memory::format_tag::nchw));
}

private:
void RunKernel(const ExecutionContext &ctx) const {
const auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto &onednn_engine = dev_ctx.GetEngine();

const auto *x = ctx.Input<Tensor>("X");
const auto *y = ctx.Input<Tensor>("Y");
auto *out = ctx.Output<Tensor>("Out");

int x_num_col_dims = ctx.Attr<int>("x_num_col_dims");
int y_num_col_dims = ctx.Attr<int>("y_num_col_dims");

const Tensor x_matrix = x->dims().size() > 2
? framework::ReshapeToMatrix(*x, x_num_col_dims)
: *x;
const Tensor y_matrix = y->dims().size() > 2
? framework::ReshapeToMatrix(*y, y_num_col_dims)
: *y;

// adding mb dim because MatMulV2 handler needs it
std::vector<int64_t> y_dims(3, 1);
std::vector<int64_t> x_dims(3, 1);

y_dims[1] = y_matrix.dims()[0];
y_dims[2] = y_matrix.dims()[1];

x_dims[1] = x_matrix.dims()[0];
x_dims[2] = x_matrix.dims()[1];

ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), &x_matrix,
x_dims, false, &y_matrix, y_dims, false, out);
}
};

template <typename XT, typename YT>
class MulGradMKLDNNKernel : public MulMKLDNNKernel<XT, YT> {
public:
void Compute(const ExecutionContext &ctx) const override { RunKernel(ctx); }

private:
template <typename OT = XT>
void RunKernel(const ExecutionContext &ctx) const {
const auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto &onednn_engine = dev_ctx.GetEngine();

const auto *x = ctx.Input<LoDTensor>("X");
const auto *y = ctx.Input<LoDTensor>("Y");
const auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));

auto *dx = ctx.Output<LoDTensor>(framework::GradVarName("X"));
auto *dy = ctx.Output<LoDTensor>(framework::GradVarName("Y"));

int x_num_col_dims = ctx.Attr<int>("x_num_col_dims");
int y_num_col_dims = ctx.Attr<int>("y_num_col_dims");

const Tensor x_matrix = x->dims().size() > 2
? framework::ReshapeToMatrix(*x, x_num_col_dims)
: static_cast<const Tensor &>(*x);
const Tensor y_matrix = y->dims().size() > 2
? framework::ReshapeToMatrix(*y, y_num_col_dims)
: static_cast<const Tensor &>(*y);

Tensor dout_matrix = *dout;
dout_matrix.Resize(
{framework::flatten_to_2d(x->dims(), x_num_col_dims)[0],
framework::flatten_to_2d(y->dims(), y_num_col_dims)[1]});

// adding mb dim because MatMulV2 handler needs it
std::vector<int64_t> x_dims(3, 1);
std::vector<int64_t> y_dims(3, 1);
std::vector<int64_t> dout_dims(3, 1);

x_dims[1] = x_matrix.dims()[0];
x_dims[2] = x_matrix.dims()[1];

y_dims[1] = y_matrix.dims()[0];
y_dims[2] = y_matrix.dims()[1];

dout_dims[1] = dout_matrix.dims()[0];
dout_dims[2] = dout_matrix.dims()[1];

if (dx != nullptr) {
dx->set_lod(x->lod());
this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(),
&dout_matrix, dout_dims, false, &y_matrix, y_dims,
true, static_cast<Tensor *>(dx));
}
if (dy != nullptr) {
dy->set_lod(y->lod());
this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(),
&x_matrix, x_dims, true, &dout_matrix, dout_dims,
false, static_cast<Tensor *>(dy));
}
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(mul, MKLDNN, ::paddle::platform::CPUPlace,
U8, ops::kMULMKLDNNINT8,
ops::MulMKLDNNKernel<uint8_t, float>);
ops::MulMKLDNNINT8Kernel<uint8_t, float>);

REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(mul, MKLDNN, ::paddle::platform::CPUPlace,
S8, ops::kMULMKLDNNINT8,
ops::MulMKLDNNKernel<int8_t, float>);
ops::MulMKLDNNINT8Kernel<int8_t, float>);

REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(mul, MKLDNN, ::paddle::platform::CPUPlace,
FP32, ops::kMULMKLDNNFP32,
ops::MulMKLDNNKernel<float, float>);

REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
mul, MKLDNN, ::paddle::platform::CPUPlace, BF16, ops::kMULMKLDNNFP32,
ops::MulMKLDNNKernel<paddle::platform::bfloat16,
paddle::platform::bfloat16>);

REGISTER_OP_KERNEL(mul, MKLDNN, ::paddle::platform::CPUPlace,
ops::MulMKLDNNKernel<uint8_t, float>);
ops::MulMKLDNNINT8Kernel<uint8_t, float>,
ops::MulMKLDNNKernel<paddle::platform::bfloat16,
paddle::platform::bfloat16>,
ops::MulMKLDNNKernel<float, float>);

REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(mul_grad, MKLDNN,
::paddle::platform::CPUPlace, FP32,
ops::kMULMKLDNNFP32,
ops::MulGradMKLDNNKernel<float, float>);

REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
mul_grad, MKLDNN, ::paddle::platform::CPUPlace, BF16, ops::kMULMKLDNNFP32,
ops::MulGradMKLDNNKernel<paddle::platform::bfloat16,
paddle::platform::bfloat16>,
ops::MulGradMKLDNNKernel<float, float>);
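
The new MulMKLDNNKernel flattens X and Y to 2-D matrices (via x_num_col_dims / y_num_col_dims), prepends a batch dimension of 1 because the MatMulV2 handler expects 3-D shapes, and runs a single oneDNN matmul; MulGradMKLDNNKernel reuses the same ExecuteMatMul helper with transposed operands, i.e. dX = dOut * Y^T and dY = X^T * dOut. The short sketch below (plain C++, sample shapes assumed, no oneDNN; an illustration only) walks through that shape bookkeeping.

// Illustration of the forward/backward shape handling in the new kernels.
#include <cstdint>
#include <iostream>
#include <vector>

// Mimics framework::flatten_to_2d: dims [0, num_col_dims) become rows,
// the remaining dims become columns.
std::vector<int64_t> FlattenTo2D(const std::vector<int64_t>& dims,
                                 int num_col_dims) {
  int64_t rows = 1, cols = 1;
  for (int i = 0; i < static_cast<int>(dims.size()); ++i)
    (i < num_col_dims ? rows : cols) *= dims[i];
  return {rows, cols};
}

int main() {
  // Forward: X [2, 3, 4] with x_num_col_dims = 1, Y [12, 5] with y_num_col_dims = 1
  auto x2d = FlattenTo2D({2, 3, 4}, 1);  // {2, 12}
  auto y2d = FlattenTo2D({12, 5}, 1);    // {12, 5}
  // The MatMulV2 handler needs 3-D dims, so a leading batch of 1 is added:
  // x_dims = {1, 2, 12}, y_dims = {1, 12, 5}  ->  Out dims = {1, 2, 5}.
  std::cout << "Out: " << x2d[0] << " x " << y2d[1] << "\n";  // Out: 2 x 5
  // Backward: dOut has the 2-D shape {2, 5};
  // dX = dOut * Y^T -> {2, 12}, dY = X^T * dOut -> {12, 5},
  // matching the trans_y / trans_x flags passed to ExecuteMatMul above.
  std::cout << "dX: " << x2d[0] << " x " << y2d[0]
            << ", dY: " << y2d[0] << " x " << y2d[1] << "\n";
}
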
36 changes: 36 additions & 0 deletions paddle/fluid/operators/mul_op.cc
@@ -113,6 +113,12 @@ class MulOp : public framework::OperatorWithKernel {
if (input_data_type == framework::DataTypeTrait<int8_t>::DataType() ||
input_data_type == framework::DataTypeTrait<uint8_t>::DataType()) {
customized_type_value = kMULMKLDNNINT8;
} else if (input_data_type ==
framework::DataTypeTrait<
paddle::platform::bfloat16>::DataType() ||
input_data_type ==
framework::DataTypeTrait<float>::DataType()) {
customized_type_value = kMULMKLDNNFP32;
}
}
#endif
@@ -233,6 +239,36 @@ class MulGradOp : public framework::OperatorWithKernel {
ctx->SetOutputDim(y_grad_name, y_dims);
}
}

framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
int customized_type_value =
framework::OpKernelType::kDefaultCustomizedTypeValue;
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
this->CanMKLDNNBeUsed(ctx, input_data_type)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;

if (input_data_type == framework::DataTypeTrait<int8_t>::DataType() ||
input_data_type == framework::DataTypeTrait<uint8_t>::DataType()) {
customized_type_value = kMULMKLDNNINT8;
} else if (input_data_type ==
framework::DataTypeTrait<
paddle::platform::bfloat16>::DataType() ||
input_data_type ==
framework::DataTypeTrait<float>::DataType()) {
customized_type_value = kMULMKLDNNFP32;
}
}
#endif

return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
library, customized_type_value);
}
};

template <typename T>
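
MulOp::GetExpectedKernelType gains an FP32/BF16 branch and MulGradOp gains a GetExpectedKernelType of its own, so the dtype is mapped to a customized type value: the existing INT8 kernels stay registered under kMULMKLDNNINT8 while the new FP32/BF16 kernels register under kMULMKLDNNFP32, letting both kernel families coexist for the same op. A simplified sketch of that selection (illustration only; the real lookup goes through OpKernelType and REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE):

// Simplified dispatch illustration, not the framework's registry code.
#include <iostream>

enum class DType { INT8, UINT8, FP32, BF16, OTHER };
constexpr int kDefault = 0;
constexpr int kMULMKLDNNINT8 = 1;
constexpr int kMULMKLDNNFP32 = 2;

int CustomizedTypeValue(DType t) {
  if (t == DType::INT8 || t == DType::UINT8) return kMULMKLDNNINT8;
  if (t == DType::FP32 || t == DType::BF16) return kMULMKLDNNFP32;
  return kDefault;
}

int main() {
  std::cout << CustomizedTypeValue(DType::BF16) << "\n";  // 2 -> FP32/BF16 kernels
  std::cout << CustomizedTypeValue(DType::INT8) << "\n";  // 1 -> INT8 kernels
}
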
1 change: 1 addition & 0 deletions paddle/fluid/operators/mul_op.h
@@ -25,6 +25,7 @@ namespace operators {
using Tensor = framework::Tensor;

constexpr int kMULMKLDNNINT8 = 1;
constexpr int kMULMKLDNNFP32 = 2;

template <typename DeviceContext, typename T>
class MulKernel : public framework::OpKernel<T> {