This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Integrate MKLDNN Conv1d and support 3d layout #13530

Merged: 14 commits, Jan 2, 2019
10 changes: 2 additions & 8 deletions src/ndarray/ndarray.cc
@@ -454,24 +454,18 @@ void NDArray::Chunk::SetMKLMem(const TShape &shape, int dtype) {

mkldnn::memory::dims dims;
// These are shapes supported by MKLDNN.
if (shape.ndim() == 1 || shape.ndim() == 2 || shape.ndim() == 4
|| shape.ndim() == 5) {
if (shape.ndim() >= 1 && shape.ndim() <= 5) {
dims.resize(shape.ndim());
for (size_t i = 0; i < dims.size(); i++)
dims[i] = shape[i];
} else if (shape.ndim() == 3) {
// If there are 3 dimensions, we'll force it to 4 dimensions.
dims.resize(shape.ndim() + 1);
dims[0] = 1;
for (size_t i = 0; i < shape.ndim(); i++)
dims[i + 1] = shape[i];
Reviewer comment (Contributor): Is there a performance difference between 3D and 4D implementation?

} else {
LOG(FATAL) << "MKLDNN doesn't support " << shape.ndim() << " dimensions";
}
mkldnn::memory::format layout = mkldnn::memory::format::format_undef;
switch (dims.size()) {
case 1: layout = mkldnn::memory::format::x; break;
case 2: layout = mkldnn::memory::format::nc; break;
case 3: layout = mkldnn::memory::format::ncw; break;
case 4: layout = mkldnn::memory::format::nchw; break;
// This isn't the right layout when the data has 5 dimensions in MXNet.
// MXNet interprets 5 dimensions as ncdhw, but MKLDNN doesn't have
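
A minimal, dependency-free sketch (not part of the PR) of the ndim-to-default-layout mapping that SetMKLMem applies after this change. The format names mirror the mkldnn::memory::format values in the switch above, and the 5-D caveat noted in the truncated comment still applies.

// Illustrative only: standalone sketch of the ndim -> default layout mapping.
#include <cstdio>
#include <string>

static std::string DefaultLayoutForNdim(size_t ndim) {
  switch (ndim) {
    case 1: return "x";
    case 2: return "nc";
    case 3: return "ncw";    // new in this PR: 3-D data is no longer padded to 4-D
    case 4: return "nchw";
    case 5: return "goihw";  // see the caveat above: MXNet means ncdhw here
    default: return "unsupported";
  }
}

int main() {
  for (size_t ndim = 1; ndim <= 5; ++ndim)
    std::printf("ndim=%zu -> %s\n", ndim, DefaultLayoutForNdim(ndim).c_str());
  return 0;
}

Running it prints the layout chosen for each rank, showing that 3-D inputs now map straight to ncw instead of being forced to 4-D as in the removed branch.
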
5 changes: 3 additions & 2 deletions src/operator/nn/activation.cc
@@ -97,9 +97,10 @@ static void ActivationComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
if (SupportMKLDNN(inputs[0])) {
if (SupportMKLDNNAct(param, inputs[0])) {
MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
MKLDNNActivationForward(attrs, ctx, inputs[0], req[0], outputs[0]);
MKLDNN_OPCHECK_RUN(ActivationCompute<cpu>, attrs, ctx, inputs, req, outputs);
@@ -115,7 +116,7 @@ void ActivationGradComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& outputs) {
const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
CHECK_EQ(inputs.size(), activation::GradNumInputs(param.act_type));
if (SupportMKLDNN(inputs[0])) {
if (SupportMKLDNNAct(param, inputs[0])) {
MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
// XXX: for y = relu(x), y is passed as "in_data" to Backward()
const bool relu = param.act_type == activation::kReLU;
11 changes: 11 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_act.cc
@@ -49,6 +49,17 @@ bool SupportMKLDNNAct(const ActivationParam& param) {
|| param.act_type == activation::kTanh;
}

bool SupportMKLDNNAct(const ActivationParam& param, const NDArray &input) {
if ((input.shape().ndim() < 1) ||
(input.shape().ndim() > 4) ||
(input.dtype() != mshadow::kFloat32))
return false;
return param.act_type == activation::kReLU
|| param.act_type == activation::kSigmoid
|| param.act_type == activation::kSoftReLU
|| param.act_type == activation::kTanh;
}

static inline mkldnn::algorithm GetMKLDNNActAlgo(const ActivationParam& param) {
switch (param.act_type) {
case activation::kReLU:
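
To make the new per-input gate concrete, here is a standalone sketch assuming stand-in enum values rather than the real mshadow/MXNet definitions; SupportsMKLDNNActivation is a hypothetical name used only for illustration.

// Illustrative only: dependency-free sketch of the gate added by
// SupportMKLDNNAct(param, input): 1-D..4-D float32 inputs with a supported
// activation type go through MKL-DNN, everything else falls back.
#include <cstdio>

enum ActType { kReLU, kSigmoid, kTanh, kSoftReLU, kSoftSign };
constexpr int kFloat32 = 0;  // stand-in for mshadow::kFloat32

bool SupportsMKLDNNActivation(ActType act, int ndim, int dtype) {
  if (ndim < 1 || ndim > 4 || dtype != kFloat32)
    return false;
  return act == kReLU || act == kSigmoid || act == kSoftReLU || act == kTanh;
}

int main() {
  std::printf("3-D fp32 relu:     %d\n", SupportsMKLDNNActivation(kReLU, 3, kFloat32));      // 1
  std::printf("5-D fp32 relu:     %d\n", SupportsMKLDNNActivation(kReLU, 5, kFloat32));      // 0
  std::printf("3-D fp32 softsign: %d\n", SupportsMKLDNNActivation(kSoftSign, 3, kFloat32));  // 0
  return 0;
}
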
28 changes: 19 additions & 9 deletions src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -175,10 +175,11 @@ struct ConvolutionParam;
struct DeconvolutionParam;
struct SoftmaxParam;
bool SupportMKLDNNAct(const ActivationParam& param);
bool SupportMKLDNNAct(const ActivationParam& param, const NDArray &input);
bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray &input);
bool SupportMKLDNNDeconv(const DeconvolutionParam& params, const NDArray &input);
bool SupportMKLDNNSoftmax(const SoftmaxParam& param);
}
} // namespace op

static int GetTypeSize(int dtype) {
int size = -1;
@@ -253,14 +254,23 @@ inline static mkldnn::memory::desc GetWeightDesc(const NDArray &arr,
if (num_groups == 1) {
return GetMemDesc(arr);
} else {
CHECK_EQ(arr.shape().ndim(), 4U);
mkldnn::memory::dims tz = mkldnn::memory::dims{ num_groups,
static_cast<int>(arr.shape()[0] / num_groups),
static_cast<int>(arr.shape()[1]),
static_cast<int>(arr.shape()[2]),
static_cast<int>(arr.shape()[3])};
return mkldnn::memory::desc{tz, get_mkldnn_type(arr.dtype()),
mkldnn::memory::format::any};
CHECK((arr.shape().ndim() == 3) || (arr.shape().ndim() == 4));
if (arr.shape().ndim() == 3) {
mkldnn::memory::dims tz = mkldnn::memory::dims{ num_groups,
static_cast<int>(arr.shape()[0] / num_groups),
static_cast<int>(arr.shape()[1]),
static_cast<int>(arr.shape()[2])};
return mkldnn::memory::desc{tz, get_mkldnn_type(arr.dtype()),
mkldnn::memory::format::any};
} else {
mkldnn::memory::dims tz = mkldnn::memory::dims{ num_groups,
static_cast<int>(arr.shape()[0] / num_groups),
static_cast<int>(arr.shape()[1]),
static_cast<int>(arr.shape()[2]),
static_cast<int>(arr.shape()[3])};
return mkldnn::memory::desc{tz, get_mkldnn_type(arr.dtype()),
mkldnn::memory::format::any};
Reviewer comment (Contributor): This line is the common part of dim3 and dim4, right?
}
}
}

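
The grouped-weight branch above can be summarized by a small standalone sketch (GroupedWeightDims is a hypothetical helper, not part of MXNet): a 3-D Conv1d weight (O, I, W) expands to (g, O/g, I, W), i.e. goiw, while a 4-D Conv2d weight (O, I, H, W) expands to (g, O/g, I, H, W), i.e. goihw.

// Illustrative only: how a grouped convolution weight shape is expanded into
// MKL-DNN dims after this change.
#include <cstdio>
#include <vector>

std::vector<int> GroupedWeightDims(const std::vector<int>& shape, int num_groups) {
  std::vector<int> dims;
  dims.push_back(num_groups);
  dims.push_back(shape[0] / num_groups);  // output channels per group
  for (size_t i = 1; i < shape.size(); ++i)
    dims.push_back(shape[i]);
  return dims;
}

int main() {
  // Conv1d weight: 32 output channels, 16 input channels, kernel width 3, groups=4.
  for (int d : GroupedWeightDims({32, 16, 3}, 4)) std::printf("%d ", d);     // 4 8 16 3
  std::printf("\n");
  // Conv2d weight: same channels, 3x3 kernel, groups=4.
  for (int d : GroupedWeightDims({32, 16, 3, 3}, 4)) std::printf("%d ", d);  // 4 8 16 3 3
  std::printf("\n");
  return 0;
}
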
84 changes: 63 additions & 21 deletions src/operator/nn/mkldnn/mkldnn_base.cc
@@ -238,39 +238,44 @@ const mkldnn::memory *GetWeights(const NDArray &arr,
return mem;

mkldnn::memory::data_type type = get_mkldnn_type(arr.dtype());
mkldnn::memory::dims tz = mkldnn::memory::dims{0};
mkldnn::memory::format format = mkldnn::memory::format::format_undef;
auto engine = CpuEngine::Get()->get_engine();
if (arr.shape().ndim() == 2) {
mkldnn::memory::dims tz = mkldnn::memory::dims{
tz = mkldnn::memory::dims{
static_cast<int>(arr.shape()[0]), static_cast<int>(arr.shape()[1])};
mkldnn::memory::desc md =
mkldnn::memory::desc{tz, type, mkldnn::memory::format::oi};
mkldnn::memory::primitive_desc pd =
mkldnn::memory::primitive_desc{md, engine};
mem = arr.GetMKLDNNData(pd);
} else if (arr.shape().ndim() == 4 && num_groups == 1) {
mkldnn::memory::dims tz = mkldnn::memory::dims{
static_cast<int>(arr.shape()[0]), static_cast<int>(arr.shape()[1]),
static_cast<int>(arr.shape()[2]), static_cast<int>(arr.shape()[3])};
mkldnn::memory::desc md =
mkldnn::memory::desc{tz, type, mkldnn::memory::format::oihw};
mkldnn::memory::primitive_desc pd =
mkldnn::memory::primitive_desc{md, engine};
mem = arr.GetMKLDNNData(pd);
format = mkldnn::memory::format::oi;
} else if (arr.shape().ndim() == 3) {
tz = num_groups > 1 ? mkldnn::memory::dims{ num_groups,
static_cast<int>(arr.shape()[0] / num_groups),
static_cast<int>(arr.shape()[1]),
static_cast<int>(arr.shape()[2])} :
mkldnn::memory::dims{
static_cast<int>(arr.shape()[0]),
static_cast<int>(arr.shape()[1]),
static_cast<int>(arr.shape()[2])};
format = num_groups > 1 ? mkldnn::memory::format::goiw : mkldnn::memory::format::oiw;
} else if (arr.shape().ndim() == 4) {
mkldnn::memory::dims tz = mkldnn::memory::dims{ num_groups,
tz = num_groups > 1 ? mkldnn::memory::dims{ num_groups,
static_cast<int>(arr.shape()[0] / num_groups),
static_cast<int>(arr.shape()[1]),
static_cast<int>(arr.shape()[2]),
static_cast<int>(arr.shape()[3])} :
mkldnn::memory::dims{
static_cast<int>(arr.shape()[0]),
static_cast<int>(arr.shape()[1]),
static_cast<int>(arr.shape()[2]),
static_cast<int>(arr.shape()[3])};
mkldnn::memory::desc md =
mkldnn::memory::desc{tz, type, mkldnn::memory::format::goihw};
mkldnn::memory::primitive_desc pd =
mkldnn::memory::primitive_desc{md, engine};
mem = arr.GetMKLDNNData(pd);
format = num_groups > 1 ? mkldnn::memory::format::goihw : mkldnn::memory::format::oihw;
} else {
LOG(FATAL) << "The weight array has an unsupported number of dimensions";
return nullptr;
}
mkldnn::memory::desc md =
mkldnn::memory::desc{tz, type, format};
mkldnn::memory::primitive_desc pd =
mkldnn::memory::primitive_desc{md, engine};
mem = arr.GetMKLDNNData(pd);
if (mem == nullptr)
mem = arr.GetMKLDNNDataReorder(target_pd);
if (mem->get_primitive_desc() == target_pd) return mem;
@@ -284,6 +289,7 @@ mkldnn_memory_format_t GetDefaultFormat(int num_dims) {
switch (num_dims) {
case 1: return mkldnn_x;
case 2: return mkldnn_nc;
case 3: return mkldnn_ncw;
case 4: return mkldnn_nchw;
case 5: return mkldnn_goihw;
default:
@@ -300,6 +306,30 @@ mkldnn_memory_format_t GetDefaultFormat(const mkldnn::memory::desc &desc) {
return mkldnn_oi;
else
return desc.data.format;
} else if (desc.data.ndims == 3) {
switch (desc.data.format) {
case mkldnn_ncw:
case mkldnn_nwc:
case mkldnn_nCw8c:
case mkldnn_nCw16c:
return mkldnn_ncw;
case mkldnn_oiw:
case mkldnn_wio:
case mkldnn_Owi8o:
case mkldnn_OIw8i8o:
case mkldnn_OIw8o8i:
case mkldnn_OIw16i16o:
case mkldnn_OIw16o16i:
case mkldnn_Oiw16o:
case mkldnn_Owi16o:
case mkldnn_OIw8i16o2i:
case mkldnn_OIw8o16i2o:
case mkldnn_IOw16o16i:
return mkldnn_oiw;
default:
LOG(FATAL) << "Unknown MKLDNN format for 3 dimensions: " << desc.data.format;
return mkldnn_format_undef;
}
} else if (desc.data.ndims == 4) {
switch (desc.data.format) {
case mkldnn_nchw:
@@ -328,6 +358,18 @@ mkldnn_memory_format_t GetDefaultFormat(const mkldnn::memory::desc &desc) {
case mkldnn_Ohwi16o:
case mkldnn_OhIw16o4i:
return mkldnn_oihw;
case mkldnn_goiw:
case mkldnn_gOwi8o:
case mkldnn_gOIw8o8i:
case mkldnn_gOIw8i8o:
case mkldnn_gOIw16i16o:
case mkldnn_gOIw16o16i:
case mkldnn_gOiw16o:
case mkldnn_gOwi16o:
case mkldnn_gOIw8i16o2i:
case mkldnn_gOIw8o16i2o:
case mkldnn_gIOw16o16i:
return mkldnn_goiw;
default:
LOG(FATAL) << "Unknown MKLDNN format for 4 dimensions: " << desc.data.format;
return mkldnn_format_undef;
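
A dependency-free sketch of the 3-D branch added to GetDefaultFormat, using format names as plain strings instead of the real mkldnn enum values; only a subset of the weight formats handled above is listed.

// Illustrative only: blocked or permuted 1-D-convolution layouts collapse to
// their plain counterparts (data -> ncw, weights -> oiw).
#include <cstdio>
#include <string>

std::string DefaultFormat3D(const std::string& fmt) {
  if (fmt == "ncw" || fmt == "nwc" || fmt == "nCw8c" || fmt == "nCw16c")
    return "ncw";   // activation/data layouts
  if (fmt == "oiw" || fmt == "wio" || fmt == "Owi8o" || fmt == "OIw16i16o")
    return "oiw";   // weight layouts (subset of the cases above)
  return "undef";   // the real code logs FATAL for unknown formats
}

int main() {
  std::printf("nCw16c    -> %s\n", DefaultFormat3D("nCw16c").c_str());
  std::printf("OIw16i16o -> %s\n", DefaultFormat3D("OIw16i16o").c_str());
  return 0;
}
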
115 changes: 77 additions & 38 deletions src/operator/nn/mkldnn/mkldnn_convolution.cc
@@ -37,9 +37,12 @@ namespace op {
DMLC_REGISTER_PARAMETER(MKLDNNConvParam);

bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray &input) {
if (params.kernel.ndim() != 2)
if ((params.kernel.ndim() != 1) &&
(params.kernel.ndim() != 2))
return false;
return SupportMKLDNNQuantize(input.dtype()) && input.shape().ndim() == 4;
return SupportMKLDNNQuantize(input.dtype()) &&
((input.shape().ndim() == 3) ||
(input.shape().ndim() == 4));
}
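
As a quick illustration of the relaxed gate, a standalone sketch with a hypothetical SupportsMKLDNNConv helper (the dtype check is collapsed to a boolean here): Conv1d with a 3-D input now qualifies alongside Conv2d with a 4-D input.

// Illustrative only: 1-D and 2-D kernels on 3-D and 4-D inputs are accepted.
#include <cstdio>

bool SupportsMKLDNNConv(int kernel_ndim, int data_ndim, bool dtype_ok) {
  if (kernel_ndim != 1 && kernel_ndim != 2)
    return false;
  return dtype_ok && (data_ndim == 3 || data_ndim == 4);
}

int main() {
  std::printf("conv1d on 3-D input: %d\n", SupportsMKLDNNConv(1, 3, true));  // 1
  std::printf("conv3d on 5-D input: %d\n", SupportsMKLDNNConv(3, 5, true));  // 0
  return 0;
}
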

mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(
@@ -51,15 +54,23 @@ mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(
auto weight_md = GetWeightDesc(weights, param.conv_param.num_group);
auto out_md = GetMemDesc(output);
auto engine = CpuEngine::Get()->get_engine();
CHECK_GE(param.conv_param.stride.ndim(), 2U);
CHECK_GE(param.conv_param.pad.ndim(), 2U);
CHECK_GE(param.conv_param.dilate.ndim(), 2U);
mkldnn::memory::dims strides{0, 0};
strides[0] = param.conv_param.stride[0];
strides[1] = param.conv_param.stride[1];
mkldnn::memory::dims padding{0, 0};
padding[0] = param.conv_param.pad[0];
padding[1] = param.conv_param.pad[1];
mkldnn::memory::dims strides(param.conv_param.kernel.ndim());
mkldnn::memory::dims padding(param.conv_param.kernel.ndim());
if (param.conv_param.kernel.ndim() == 1) {
CHECK_GE(param.conv_param.stride.ndim(), 1U);
CHECK_GE(param.conv_param.pad.ndim(), 1U);
CHECK_GE(param.conv_param.dilate.ndim(), 1U);
strides[0] = param.conv_param.stride[0];
padding[0] = param.conv_param.pad[0];
} else if (param.conv_param.kernel.ndim() == 2) {
CHECK_GE(param.conv_param.stride.ndim(), 2U);
CHECK_GE(param.conv_param.pad.ndim(), 2U);
CHECK_GE(param.conv_param.dilate.ndim(), 2U);
strides[0] = param.conv_param.stride[0];
strides[1] = param.conv_param.stride[1];
padding[0] = param.conv_param.pad[0];
padding[1] = param.conv_param.pad[1];
}
mkldnn::primitive_attr attr;
mkldnn::post_ops ops;
if (param.mkldnn_param.with_relu) {
@@ -113,9 +124,13 @@ mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(
}
return conv_pd;
} else {
mkldnn::memory::dims dilates{0, 0};
dilates[0] = param.conv_param.dilate[0] - 1;
dilates[1] = param.conv_param.dilate[1] - 1;
mkldnn::memory::dims dilates(param.conv_param.kernel.ndim());
if (param.conv_param.dilate.ndim() == 1) {
dilates[0] = param.conv_param.dilate[0] - 1;
} else if (param.conv_param.dilate.ndim() == 2) {
dilates[0] = param.conv_param.dilate[0] - 1;
dilates[1] = param.conv_param.dilate[1] - 1;
}
if (bias == nullptr) {
mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct,
data_md, weight_md, out_md, strides, dilates, padding, padding,
@@ -151,15 +166,23 @@ static mkldnn::convolution_backward_data::primitive_desc GetConvBwdData(
auto weight_md = GetWeightDesc(weights, param.num_group);
auto out_md = GetMemDesc(output);
auto engine = CpuEngine::Get()->get_engine();
CHECK_GE(param.stride.ndim(), 2U);
CHECK_GE(param.pad.ndim(), 2U);
CHECK_GE(param.dilate.ndim(), 2U);
mkldnn::memory::dims strides{0, 0};
strides[0] = param.stride[0];
strides[1] = param.stride[1];
mkldnn::memory::dims padding{0, 0};
padding[0] = param.pad[0];
padding[1] = param.pad[1];
mkldnn::memory::dims strides(param.kernel.ndim());
mkldnn::memory::dims padding(param.kernel.ndim());
if (param.kernel.ndim() == 1) {
CHECK_GE(param.stride.ndim(), 1U);
CHECK_GE(param.pad.ndim(), 1U);
CHECK_GE(param.dilate.ndim(), 1U);
strides[0] = param.stride[0];
padding[0] = param.pad[0];
} else if (param.kernel.ndim() == 2) {
CHECK_GE(param.stride.ndim(), 2U);
CHECK_GE(param.pad.ndim(), 2U);
CHECK_GE(param.dilate.ndim(), 2U);
strides[0] = param.stride[0];
strides[1] = param.stride[1];
padding[0] = param.pad[0];
padding[1] = param.pad[1];
}

// MKL-DNN introduced padded formats since 0.15 which require more memory
// for computation compared with the actual tensor size. Currently, MKL-DNN
@@ -177,9 +200,13 @@ static mkldnn::convolution_backward_data::primitive_desc GetConvBwdData(
}
return conv_pd;
} else {
mkldnn::memory::dims dilates{0, 0};
dilates[0] = param.dilate[0] - 1;
dilates[1] = param.dilate[1] - 1;
mkldnn::memory::dims dilates(param.kernel.ndim());
if (param.dilate.ndim() == 1) {
dilates[0] = param.dilate[0] - 1;
} else if (param.dilate.ndim() == 2) {
dilates[0] = param.dilate[0] - 1;
dilates[1] = param.dilate[1] - 1;
}
mkldnn::convolution_backward_data::desc desc(mkldnn::algorithm::convolution_direct,
data_md, weight_md, out_md, strides, dilates, padding, padding,
mkldnn::padding_kind::zero);
Expand All @@ -201,15 +228,23 @@ static mkldnn::convolution_backward_weights::primitive_desc GetConvBwdWeights(
auto weight_md = GetWeightDesc(weights, param.num_group);
auto out_md = GetMemDesc(output);
auto engine = CpuEngine::Get()->get_engine();
CHECK_GE(param.stride.ndim(), 2U);
CHECK_GE(param.pad.ndim(), 2U);
CHECK_GE(param.dilate.ndim(), 2U);
mkldnn::memory::dims strides{0, 0};
strides[0] = param.stride[0];
strides[1] = param.stride[1];
mkldnn::memory::dims padding{0, 0};
padding[0] = param.pad[0];
padding[1] = param.pad[1];
mkldnn::memory::dims strides(param.kernel.ndim());
mkldnn::memory::dims padding(param.kernel.ndim());
if (param.kernel.ndim() == 1) {
CHECK_GE(param.stride.ndim(), 1U);
CHECK_GE(param.pad.ndim(), 1U);
CHECK_GE(param.dilate.ndim(), 1U);
strides[0] = param.stride[0];
padding[0] = param.pad[0];
} else if (param.kernel.ndim() == 2) {
CHECK_GE(param.stride.ndim(), 2U);
CHECK_GE(param.pad.ndim(), 2U);
CHECK_GE(param.dilate.ndim(), 2U);
strides[0] = param.stride[0];
strides[1] = param.stride[1];
padding[0] = param.pad[0];
padding[1] = param.pad[1];
}

// MKL-DNN introduced padded formats since 0.15 which require more memory
// for computation compared with the actual tensor size. Currently, MKL-DNN
@@ -239,9 +274,13 @@ static mkldnn::convolution_backward_weights::primitive_desc GetConvBwdWeights(
}
return conv_pd;
} else {
mkldnn::memory::dims dilates{0, 0};
dilates[0] = param.dilate[0] - 1;
dilates[1] = param.dilate[1] - 1;
mkldnn::memory::dims dilates(param.kernel.ndim());
if (param.dilate.ndim() == 1) {
dilates[0] = param.dilate[0] - 1;
} else if (param.dilate.ndim() == 2) {
dilates[0] = param.dilate[0] - 1;
dilates[1] = param.dilate[1] - 1;
}
if (bias == nullptr) {
mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct,
data_md, weight_md, out_md, strides, dilates, padding, padding,
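
The same pattern is repeated in GetConvFwdImpl, GetConvBwdData, and GetConvBwdWeights above: the strides, padding, and dilates vectors are sized by kernel.ndim() and filled per case instead of being hard-coded to two elements. Below is a dependency-free sketch of that pattern; BuildGeom and ConvGeom are illustrative names, not MXNet code, and the "- 1" reflects MKL-DNN's convention of storing dilation as the number of inserted holes.

// Illustrative only: build per-dimension conv geometry for 1-D or 2-D kernels.
#include <cstdio>
#include <vector>

struct ConvGeom {
  std::vector<int> strides, padding, dilates;
};

ConvGeom BuildGeom(const std::vector<int>& stride,
                   const std::vector<int>& pad,
                   const std::vector<int>& dilate) {
  const size_t ndim = stride.size();  // 1 for Conv1d, 2 for Conv2d
  ConvGeom g{std::vector<int>(ndim), std::vector<int>(ndim), std::vector<int>(ndim)};
  for (size_t i = 0; i < ndim; ++i) {
    g.strides[i] = stride[i];
    g.padding[i] = pad[i];
    g.dilates[i] = dilate[i] - 1;     // MKL-DNN dilation convention
  }
  return g;
}

int main() {
  ConvGeom conv1d = BuildGeom({2}, {1}, {1});
  ConvGeom conv2d = BuildGeom({2, 2}, {1, 1}, {2, 2});
  std::printf("conv1d: stride=%d pad=%d dilate=%d\n",
              conv1d.strides[0], conv1d.padding[0], conv1d.dilates[0]);
  std::printf("conv2d: stride=%dx%d pad=%dx%d dilate=%dx%d\n",
              conv2d.strides[0], conv2d.strides[1],
              conv2d.padding[0], conv2d.padding[1],
              conv2d.dilates[0], conv2d.dilates[1]);
  return 0;
}
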