Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Integrate MKLDNN Conv1d and support 3d layout #13530

Merged
merged 14 commits into from
Jan 2, 2019
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions src/ndarray/ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -454,24 +454,18 @@ void NDArray::Chunk::SetMKLMem(const TShape &shape, int dtype) {

mkldnn::memory::dims dims;
// These are shapes supported by MKLDNN.
if (shape.ndim() == 1 || shape.ndim() == 2 || shape.ndim() == 4
|| shape.ndim() == 5) {
if (shape.ndim() >= 1 && shape.ndim() <= 5) {
dims.resize(shape.ndim());
for (size_t i = 0; i < dims.size(); i++)
dims[i] = shape[i];
} else if (shape.ndim() == 3) {
// If there are 3 dimensions, we'll force it to 4 dimensions.
dims.resize(shape.ndim() + 1);
dims[0] = 1;
for (size_t i = 0; i < shape.ndim(); i++)
dims[i + 1] = shape[i];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a performance difference between 3D and 4D implementation?

} else {
LOG(FATAL) << "MKLDNN doesn't support " << shape.ndim() << " dimensions";
}
mkldnn::memory::format layout = mkldnn::memory::format::format_undef;
switch (dims.size()) {
case 1: layout = mkldnn::memory::format::x; break;
case 2: layout = mkldnn::memory::format::nc; break;
case 3: layout = mkldnn::memory::format::ncw; break;
case 4: layout = mkldnn::memory::format::nchw; break;
// This isn't the right layout when the data has 5 dimensions in MXNet.
// MXNet interprets 5 dimensions as ncdhw, but MKLDNN doesn't have
Expand Down
5 changes: 3 additions & 2 deletions src/operator/nn/activation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,10 @@ static void ActivationComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
if (SupportMKLDNN(inputs[0])) {
if (SupportMKLDNNAct(param, inputs[0])) {
TaoLv marked this conversation as resolved.
Show resolved Hide resolved
MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
MKLDNNActivationForward(attrs, ctx, inputs[0], req[0], outputs[0]);
MKLDNN_OPCHECK_RUN(ActivationCompute<cpu>, attrs, ctx, inputs, req, outputs);
Expand All @@ -115,7 +116,7 @@ void ActivationGradComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& outputs) {
const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
CHECK_EQ(inputs.size(), activation::GradNumInputs(param.act_type));
if (SupportMKLDNN(inputs[0])) {
if (SupportMKLDNNAct(param, inputs[0])) {
MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
// XXX: for y = relu(x), y is passed as "in_data" to Backward()
const bool relu = param.act_type == activation::kReLU;
Expand Down
9 changes: 9 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_act.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,15 @@ bool SupportMKLDNNAct(const ActivationParam& param) {
|| param.act_type == activation::kTanh;
}

bool SupportMKLDNNAct(const ActivationParam& param, const NDArray &input) {
// MKL-DNN Activation supports 1d, 2d, 3d, 4d data layout
if ((input.shape().ndim() < 1) ||
(input.shape().ndim() > 4) ||
TaoLv marked this conversation as resolved.
Show resolved Hide resolved
(input.dtype() != mshadow::kFloat32))
return false;
return SupportMKLDNNAct(param);
}

static inline mkldnn::algorithm GetMKLDNNActAlgo(const ActivationParam& param) {
switch (param.act_type) {
case activation::kReLU:
Expand Down
28 changes: 19 additions & 9 deletions src/operator/nn/mkldnn/mkldnn_base-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,11 @@ struct ConvolutionParam;
struct DeconvolutionParam;
struct SoftmaxParam;
bool SupportMKLDNNAct(const ActivationParam& param);
bool SupportMKLDNNAct(const ActivationParam& param, const NDArray &input);
bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray &input);
bool SupportMKLDNNDeconv(const DeconvolutionParam& params, const NDArray &input);
bool SupportMKLDNNSoftmax(const SoftmaxParam& param);
}
} // namespace op

static int GetTypeSize(int dtype) {
int size = -1;
Expand Down Expand Up @@ -253,14 +254,23 @@ inline static mkldnn::memory::desc GetWeightDesc(const NDArray &arr,
if (num_groups == 1) {
return GetMemDesc(arr);
} else {
CHECK_EQ(arr.shape().ndim(), 4U);
mkldnn::memory::dims tz = mkldnn::memory::dims{ num_groups,
static_cast<int>(arr.shape()[0] / num_groups),
static_cast<int>(arr.shape()[1]),
static_cast<int>(arr.shape()[2]),
static_cast<int>(arr.shape()[3])};
return mkldnn::memory::desc{tz, get_mkldnn_type(arr.dtype()),
mkldnn::memory::format::any};
CHECK((arr.shape().ndim() == 3) || (arr.shape().ndim() == 4))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use a variable to save the value of arr.shape().ndim() to avoid multiple calls.

<< "MKL-DNN weight currectly supports 3d and 4d layout";
const int N = 0, H = 2, W = 3, C = 1;
if (arr.shape().ndim() == 3) {
mkldnn::memory::dims tz = mkldnn::memory::dims{
num_groups, static_cast<int>(arr.shape()[N] / num_groups),
static_cast<int>(arr.shape()[C]), static_cast<int>(arr.shape()[H])};
return mkldnn::memory::desc{tz, get_mkldnn_type(arr.dtype()),
mkldnn::memory::format::any};
} else {
mkldnn::memory::dims tz = mkldnn::memory::dims{
num_groups, static_cast<int>(arr.shape()[N] / num_groups),
static_cast<int>(arr.shape()[C]), static_cast<int>(arr.shape()[H]),
static_cast<int>(arr.shape()[W])};
return mkldnn::memory::desc{tz, get_mkldnn_type(arr.dtype()),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line is the common part of dim3 and dim4, right?

mkldnn::memory::format::any};
}
}
}

Expand Down
99 changes: 73 additions & 26 deletions src/operator/nn/mkldnn/mkldnn_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -239,39 +239,49 @@ const mkldnn::memory *GetWeights(const NDArray &arr,
return mem;

mkldnn::memory::data_type type = get_mkldnn_type(arr.dtype());
mkldnn::memory::dims tz = mkldnn::memory::dims{0};
mkldnn::memory::format format = mkldnn::memory::format::format_undef;
auto engine = CpuEngine::Get()->get_engine();
const int O = 0, I = 1, H = 2, W = 3;
if (arr.shape().ndim() == 2) {
mkldnn::memory::dims tz = mkldnn::memory::dims{
static_cast<int>(arr.shape()[0]), static_cast<int>(arr.shape()[1])};
mkldnn::memory::desc md =
mkldnn::memory::desc{tz, type, mkldnn::memory::format::oi};
mkldnn::memory::primitive_desc pd =
mkldnn::memory::primitive_desc{md, engine};
mem = arr.GetMKLDNNData(pd);
} else if (arr.shape().ndim() == 4 && num_groups == 1) {
mkldnn::memory::dims tz = mkldnn::memory::dims{
static_cast<int>(arr.shape()[0]), static_cast<int>(arr.shape()[1]),
static_cast<int>(arr.shape()[2]), static_cast<int>(arr.shape()[3])};
mkldnn::memory::desc md =
mkldnn::memory::desc{tz, type, mkldnn::memory::format::oihw};
mkldnn::memory::primitive_desc pd =
mkldnn::memory::primitive_desc{md, engine};
mem = arr.GetMKLDNNData(pd);
tz = mkldnn::memory::dims{static_cast<int>(arr.shape()[O]),
static_cast<int>(arr.shape()[I])};
format = mkldnn::memory::format::oi;
} else if (arr.shape().ndim() == 3) {
tz = num_groups > 1
? mkldnn::memory::dims{num_groups,
static_cast<int>(arr.shape()[O] /
num_groups),
static_cast<int>(arr.shape()[I]),
static_cast<int>(arr.shape()[H])}
: mkldnn::memory::dims{static_cast<int>(arr.shape()[O]),
static_cast<int>(arr.shape()[I]),
static_cast<int>(arr.shape()[H])};
format = num_groups > 1 ? mkldnn::memory::format::goiw
: mkldnn::memory::format::oiw;
} else if (arr.shape().ndim() == 4) {
mkldnn::memory::dims tz = mkldnn::memory::dims{ num_groups,
static_cast<int>(arr.shape()[0] / num_groups),
static_cast<int>(arr.shape()[1]),
static_cast<int>(arr.shape()[2]),
static_cast<int>(arr.shape()[3])};
mkldnn::memory::desc md =
mkldnn::memory::desc{tz, type, mkldnn::memory::format::goihw};
mkldnn::memory::primitive_desc pd =
mkldnn::memory::primitive_desc{md, engine};
mem = arr.GetMKLDNNData(pd);
tz = num_groups > 1
? mkldnn::memory::dims{num_groups,
static_cast<int>(arr.shape()[O] /
num_groups),
static_cast<int>(arr.shape()[I]),
static_cast<int>(arr.shape()[H]),
static_cast<int>(arr.shape()[W])}
: mkldnn::memory::dims{static_cast<int>(arr.shape()[O]),
static_cast<int>(arr.shape()[I]),
static_cast<int>(arr.shape()[H]),
static_cast<int>(arr.shape()[W])};
format = num_groups > 1 ? mkldnn::memory::format::goihw
: mkldnn::memory::format::oihw;
} else {
LOG(FATAL) << "The weight array has an unsupported number of dimensions";
return nullptr;
}
mkldnn::memory::desc md =
mkldnn::memory::desc{tz, type, format};
mkldnn::memory::primitive_desc pd =
mkldnn::memory::primitive_desc{md, engine};
mem = arr.GetMKLDNNData(pd);
if (mem == nullptr)
mem = arr.GetMKLDNNDataReorder(target_pd);
if (mem->get_primitive_desc() == target_pd) return mem;
Expand All @@ -285,6 +295,7 @@ mkldnn_memory_format_t GetDefaultFormat(int num_dims) {
switch (num_dims) {
case 1: return mkldnn_x;
case 2: return mkldnn_nc;
case 3: return mkldnn_ncw;
case 4: return mkldnn_nchw;
case 5: return mkldnn_goihw;
default:
Expand All @@ -301,6 +312,30 @@ mkldnn_memory_format_t GetDefaultFormat(const mkldnn::memory::desc &desc) {
return mkldnn_oi;
else
return desc.data.format;
} else if (desc.data.ndims == 3) {
switch (desc.data.format) {
case mkldnn_ncw:
case mkldnn_nwc:
case mkldnn_nCw8c:
case mkldnn_nCw16c:
return mkldnn_ncw;
case mkldnn_oiw:
case mkldnn_wio:
case mkldnn_Owi8o:
case mkldnn_OIw8i8o:
case mkldnn_OIw8o8i:
case mkldnn_OIw16i16o:
case mkldnn_OIw16o16i:
case mkldnn_Oiw16o:
case mkldnn_Owi16o:
case mkldnn_OIw8i16o2i:
case mkldnn_OIw8o16i2o:
case mkldnn_IOw16o16i:
return mkldnn_oiw;
default:
LOG(FATAL) << "Unknown MKLDNN format for 3 dimensions: " << desc.data.format;
return mkldnn_format_undef;
}
} else if (desc.data.ndims == 4) {
switch (desc.data.format) {
case mkldnn_nchw:
Expand Down Expand Up @@ -329,6 +364,18 @@ mkldnn_memory_format_t GetDefaultFormat(const mkldnn::memory::desc &desc) {
case mkldnn_Ohwi16o:
case mkldnn_OhIw16o4i:
return mkldnn_oihw;
case mkldnn_goiw:
case mkldnn_gOwi8o:
case mkldnn_gOIw8o8i:
case mkldnn_gOIw8i8o:
case mkldnn_gOIw16i16o:
case mkldnn_gOIw16o16i:
case mkldnn_gOiw16o:
case mkldnn_gOwi16o:
case mkldnn_gOIw8i16o2i:
case mkldnn_gOIw8o16i2o:
case mkldnn_gIOw16o16i:
return mkldnn_goiw;
default:
LOG(FATAL) << "Unknown MKLDNN format for 4 dimensions: " << desc.data.format;
return mkldnn_format_undef;
Expand Down
Loading