Add conv2d_transpose to bf16 list and kernel refactor
wozna committed Feb 16, 2021
1 parent 4fa62a2 commit 2a8a832
Showing 2 changed files with 28 additions and 26 deletions.
6 changes: 3 additions & 3 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -2202,9 +2202,9 @@ PDNode *patterns::Bfloat16Placement::operator()(
const std::unordered_set<std::string> &bfloat16_enabled_op_types) {
std::unordered_set<std::string> supported_op_types =
std::unordered_set<std::string>(
{"concat", "conv2d", "elementwise_add", "elementwise_mul", "fc",
"fusion_gru", "gelu", "layer_norm", "matmul", "pool2d", "reshape2",
"softmax", "sum", "transpose2"});
{"concat", "conv2d", "conv2d_transpose", "elementwise_add",
"elementwise_mul", "fc", "fusion_gru", "gelu", "layer_norm",
"matmul", "pool2d", "reshape2", "softmax", "sum", "transpose2"});
if (!bfloat16_enabled_op_types.empty()) {
supported_op_types = bfloat16_enabled_op_types;
}
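An aside on the placement logic above: the built-in list is only a default, and a non-empty user-supplied set replaces it wholesale rather than extending it. A minimal standalone sketch of that override behavior (simplified names, not Paddle's actual API):

#include <iostream>
#include <string>
#include <unordered_set>

// Default bf16 allow-list, replaced entirely when the user provides one.
std::unordered_set<std::string> ResolveBf16Ops(
    const std::unordered_set<std::string>& user_ops) {
  std::unordered_set<std::string> supported = {
      "concat", "conv2d", "conv2d_transpose", "elementwise_add",
      "elementwise_mul", "fc", "fusion_gru", "gelu", "layer_norm",
      "matmul", "pool2d", "reshape2", "softmax", "sum", "transpose2"};
  if (!user_ops.empty()) supported = user_ops;  // override, not a union
  return supported;
}

int main() {
  // With an explicit list, only these ops stay eligible for bf16 placement.
  for (const auto& op : ResolveBf16Ops({"conv2d", "conv2d_transpose"}))
    std::cout << op << '\n';
}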
48 changes: 25 additions & 23 deletions paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc
@@ -125,13 +125,14 @@ class ConvTransposeMKLDNNHandlerT
platform::errors::Unimplemented(
"Now we only support 2d oneDNN convolution transpose op"));

- auto input_dims = input->dims();
- auto data_dims = framework::slice_ddim(input_dims, 2, input_dims.size());
- auto filter_dims = filter->dims();
- auto filter_data_dims =
+ const auto& input_dims = input->dims();
+ const auto data_dims =
+     framework::slice_ddim(input_dims, 2, input_dims.size());
+ const auto& filter_dims = filter->dims();
+ const auto filter_data_dims =
framework::slice_ddim(filter_dims, 2, filter_dims.size());

- auto ksize = framework::vectorize(filter_data_dims);
+ const auto ksize = framework::vectorize(filter_data_dims);

UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
data_dims, strides, ksize);
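For context on the dims handling just above: slicing from axis 2 drops the batch and channel dimensions of an NCHW tensor, leaving only the spatial extent that the padding/dilation update needs. A standalone toy (a plain std::vector stands in for Paddle's DDim; slice_ddim(dims, begin, end) is assumed to keep [begin, end)):

#include <cstdint>
#include <iostream>
#include <vector>

// Mimic of framework::slice_ddim on a shape vector.
std::vector<int64_t> SliceDims(const std::vector<int64_t>& dims,
                               int begin, int end) {
  return {dims.begin() + begin, dims.begin() + end};
}

int main() {
  const std::vector<int64_t> input_dims{8, 3, 224, 224};  // NCHW
  const auto data_dims =
      SliceDims(input_dims, 2, static_cast<int>(input_dims.size()));
  for (auto d : data_dims) std::cout << d << ' ';  // prints: 224 224
}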
@@ -140,31 +141,30 @@ class ConvTransposeMKLDNNHandlerT
[](int64_t i) { return i - 1; });

const auto src_tz = framework::vectorize(input->dims());
-
- auto weights_tz = GetWeightsTz(filter, groups);
-
- auto dst_tz = framework::vectorize(output->dims());
+ const auto weights_tz = GetWeightsTz(filter, groups);
+ const auto dst_tz = framework::vectorize(output->dims());
const auto mkldnn_paddings = platform::ToMkldnnPadding(paddings);

/* create memory descriptor for convolution without specified format
* ('any') which lets a primitive (convolution in this case) choose
* the memory format preferred for best performance
*/
- auto chosen_memory_format = MKLDNNMemoryFormat::any;
- std::string fuse_activation = ctx.Attr<std::string>("fuse_activation");
- float fuse_alpha = ctx.Attr<float>("fuse_alpha");
- float fuse_beta = ctx.Attr<float>("fuse_beta");
+ const auto chosen_memory_format = MKLDNNMemoryFormat::any;
+ const std::string fuse_activation =
+     ctx.Attr<std::string>("fuse_activation");
+ const float fuse_alpha = ctx.Attr<float>("fuse_alpha");
+ const float fuse_beta = ctx.Attr<float>("fuse_beta");

auto data_type = mkldnn::memory::data_type::f32;
if (ctx.Attr<std::string>("mkldnn_data_type") == "bfloat16" ||
std::is_same<T_out, platform::bfloat16>::value)
data_type = mkldnn::memory::data_type::bf16;

- auto src_md =
+ const auto src_md =
platform::MKLDNNMemDesc(src_tz, data_type, chosen_memory_format);
- auto weights_md =
+ const auto weights_md =
platform::MKLDNNMemDesc(weights_tz, data_type, chosen_memory_format);
- auto dst_md = platform::MKLDNNMemDesc(
+ const auto dst_md = platform::MKLDNNMemDesc(
dst_tz, platform::MKLDNNGetDataType<T_out>(), chosen_memory_format);

const mkldnn::primitive_attr conv_trans_attr =
@@ -173,7 +173,7 @@ class ConvTransposeMKLDNNHandlerT
: mkldnn::prop_kind::forward_training;
if (bias) {
std::vector<int64_t> bias_tz = framework::vectorize(bias->dims());
- auto bias_md =
+ const auto bias_md =
platform::MKLDNNMemDesc(bias_tz, data_type, MKLDNNMemoryFormat::x);
this->AcquireForwardPrimitiveDescriptor(
conv_trans_attr, fwd_prop_kind,
@@ -188,8 +188,9 @@ class ConvTransposeMKLDNNHandlerT
}
}

- mkldnn::primitive_attr CreatePostOps(std::string fuse_activation,
-     float fuse_alpha, float fuse_beta) {
+ mkldnn::primitive_attr CreatePostOps(const std::string& fuse_activation,
+     const float& fuse_alpha,
+     const float& fuse_beta) {
mkldnn::primitive_attr conv_attr;
mkldnn::post_ops post_operations;

@@ -237,7 +238,7 @@ class ConvTransposeMKLDNNHandlerT
}

std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryWithReorder(
-     const framework::Tensor* filter, const int groups, const bool is_test) {
+     const framework::Tensor* filter, const int& groups, const bool& is_test) {
// This is a workaround to make execution faster; delete this
// if statement once md is included inside Tensor
auto weights_mem_p = this->AcquireMemory("@weights_mem_p_target");
@@ -282,7 +283,7 @@ class ConvTransposeMKLDNNHandlerT
}

std::shared_ptr<mkldnn::memory> AcquireBiasMemoryWithReorder(
-     const framework::Tensor* bias, const bool is_test) {
+     const framework::Tensor* bias, const bool& is_test) {
auto bias_mem_p = this->AcquireMemory("@bias_mem_p_target");
if (is_test && bias_mem_p) {
return bias_mem_p;
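The truncated body above follows the handler's acquire-or-create idiom: at inference time a previously reordered memory object is looked up by a string key and reused, skipping the reorder. A standalone sketch of that caching pattern (Memory stands in for mkldnn::memory; the real handler keys into a per-device blob cache rather than a member map):

#include <memory>
#include <string>
#include <unordered_map>

class HandlerSketch {
 public:
  struct Memory { /* reordered bias payload */ };

  std::shared_ptr<Memory> AcquireMemory(const std::string& key) {
    auto it = cache_.find(key);
    return it == cache_.end() ? nullptr : it->second;
  }

  std::shared_ptr<Memory> AcquireBiasMemoryWithReorder(bool is_test) {
    auto bias_mem_p = AcquireMemory("@bias_mem_p_target");
    if (is_test && bias_mem_p) return bias_mem_p;  // cache hit: skip reorder
    bias_mem_p = std::make_shared<Memory>();       // create, reorder, cache
    cache_["@bias_mem_p_target"] = bias_mem_p;
    return bias_mem_p;
  }

 private:
  std::unordered_map<std::string, std::shared_ptr<Memory>> cache_;
};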
@@ -305,8 +306,9 @@ class ConvTransposeMKLDNNOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
platform::errors::PreconditionNotMet(
"Operator DNNL ConvTranspose must use CPUPlace"));
- bool is_bfloat16 = ctx.Attr<std::string>("mkldnn_data_type") == "bfloat16";
- bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
+ const bool is_bfloat16 =
+     ctx.Attr<std::string>("mkldnn_data_type") == "bfloat16";
+ const bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
if (is_bfloat16) {
if (force_fp32_output)
Execute<float>(ctx);
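The tail of Compute is collapsed above, but the visible branch shows the point of the refactor: the output type is chosen at compile time through the Execute<T_out> template, so bf16 compute can still emit fp32 when force_fp32_output is set. A self-contained toy of that dispatch (the bfloat16 stub and the fp32 else-branch are assumptions for illustration):

#include <cstdint>
#include <iostream>
#include <typeinfo>

struct bfloat16 { uint16_t bits; };  // stand-in for platform::bfloat16

template <typename T_out>
void Execute() {  // the real kernel also takes the ExecutionContext
  std::cout << "output type: " << typeid(T_out).name() << '\n';
}

void Compute(bool is_bfloat16, bool force_fp32_output) {
  if (is_bfloat16) {
    if (force_fp32_output)
      Execute<float>();     // bf16 compute, fp32 output
    else
      Execute<bfloat16>();  // bf16 compute, bf16 output
  } else {
    Execute<float>();       // plain fp32 path
  }
}

int main() { Compute(/*is_bfloat16=*/true, /*force_fp32_output=*/false); }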

1 comment on commit 2a8a832

@paddle-bot-old

Congratulations! Your pull request passed all required CI. You can ask the reviewer(s) to approve and merge. 🎉
