Skip to content

Commit

Permalink
Merge pull request #3 from wentingj/mkldnn-concat
Browse files Browse the repository at this point in the history
add MKLDNN support for concat
  • Loading branch information
zheng-da committed Dec 8, 2017
2 parents f57fd90 + b31823a commit 8758290
Show file tree
Hide file tree
Showing 3 changed files with 193 additions and 4 deletions.
97 changes: 95 additions & 2 deletions src/operator/nn/concat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
*/

#include "./concat-inl.h"
#include "./mkldnn/mkldnn_ops-inl.h"

namespace mxnet {
namespace op {
Expand Down Expand Up @@ -103,12 +104,100 @@ static bool ConcatType(const nnvm::NodeAttrs& attrs,
return true;
}

// Infer the storage type for Concat's forward pass. On CPU builds with
// MKLDNN enabled, dispatch to the FComputeEx path with MKLDNN storage;
// otherwise fall back to dense storage and the regular FCompute path.
inline static bool ConcatForwardInferStorageType(const nnvm::NodeAttrs& attrs,
                                                 const int dev_mask,
                                                 DispatchMode* dispatch_mode,
                                                 std::vector<int> *in_attrs,
                                                 std::vector<int> *out_attrs) {
  CHECK(!in_attrs->empty());
  CHECK_EQ(out_attrs->size(), 1U);
  // Default: dense storage via the plain FCompute kernel.
  DispatchMode mode = DispatchMode::kFCompute;
  int out_stype = kDefaultStorage;
#if MXNET_USE_MKLDNN == 1
  if (dev_mask == mshadow::cpu::kDevMask) {
    // MKLDNN path available on CPU.
    mode = DispatchMode::kFComputeEx;
    out_stype = kMKLDNNStorage;
  }
#endif
  *dispatch_mode = mode;
  (*out_attrs)[0] = out_stype;
  return true;
}

// Infer storage types for Concat's backward pass. Outputs are one gradient
// per forward input; inputs are the output gradient plus the forward inputs.
// On CPU builds with MKLDNN enabled, dispatch to FComputeEx with MKLDNN
// storage; otherwise use dense storage and the regular FCompute path.
inline static bool backward_ConcatStorageType(const nnvm::NodeAttrs& attrs,
                                              const int dev_mask,
                                              DispatchMode* dispatch_mode,
                                              std::vector<int> *in_attrs,
                                              std::vector<int> *out_attrs) {
  // The invariant (ograd + forward inputs in, one gradient per input out)
  // holds regardless of the MKLDNN build flag, so check it unconditionally
  // (the forward variant also checks sizes outside the #if).
  CHECK_EQ(out_attrs->size(), in_attrs->size() - 1);
#if MXNET_USE_MKLDNN == 1
  if (dev_mask == mshadow::cpu::kDevMask) {
    *dispatch_mode = DispatchMode::kFComputeEx;
    for (size_t i = 0; i < out_attrs->size(); i++)
      (*out_attrs)[i] = kMKLDNNStorage;
    return true;
  }
#endif
  *dispatch_mode = DispatchMode::kFCompute;
  for (size_t i = 0; i < out_attrs->size(); i++)
    (*out_attrs)[i] = kDefaultStorage;
  return true;
}

// FComputeEx entry for Concat on CPU. Uses the MKLDNN concat primitive when
// it applies (2D/4D float32 inputs) and otherwise falls back to the generic
// dense kernel.
//
// Bug fix: previously, a 2D/4D input with a non-float32 dtype fell into the
// MKLDNN branch but executed nothing (silent no-op output), and with
// MXNET_USE_MKLDNN == 0 the whole body compiled away. The dense fallback now
// runs in every case where the MKLDNN primitive is not used.
void ConcatComputeExCPU(const nnvm::NodeAttrs& attrs,
                        const OpContext& op_ctx,
                        const std::vector<NDArray>& inputs,
                        const std::vector<OpReqType>& req,
                        const std::vector<NDArray>& outputs) {
  CHECK(!inputs.empty());
  CHECK_EQ(outputs.size(), 1U);
  CHECK_EQ(req.size(), 1U);
  if (req[0] == kNullOp) return;
#if MXNET_USE_MKLDNN == 1
  // MKLDNN supports only 2D and 4D float32 concat.
  if ((inputs[0].shape().ndim() == 2 || inputs[0].shape().ndim() == 4) &&
      inputs[0].dtype() == mshadow::kFloat32) {
    MKLDNNConcat_Forward(attrs, op_ctx, inputs, req, outputs);
    return;
  }
#endif
  // Fallback: unwrap NDArrays into TBlobs and run the generic CPU kernel.
  std::vector<TBlob> in_blobs(inputs.size());
  for (size_t i = 0; i < in_blobs.size(); i++)
    in_blobs[i] = inputs[i].data();
  std::vector<TBlob> out_blobs(outputs.size());
  for (size_t i = 0; i < out_blobs.size(); i++)
    out_blobs[i] = outputs[i].data();
  ConcatCompute<cpu>(attrs, op_ctx, in_blobs, req, out_blobs);
}

// FComputeEx entry for the Concat gradient on CPU. Uses the MKLDNN backward
// path when it applies (2D/4D float32) and otherwise falls back to the
// generic dense gradient kernel.
//
// Bug fix: previously, a 2D/4D non-float32 gradient fell into the MKLDNN
// branch but executed nothing, and with MXNET_USE_MKLDNN == 0 the whole body
// compiled away. The dense fallback now runs whenever MKLDNN is not used.
static void ConcatGradComputeExCPU(const nnvm::NodeAttrs& attrs,
    const OpContext& ctx, const std::vector<NDArray>& inputs,
    const std::vector<OpReqType>& req, const std::vector<NDArray>& outputs) {
#if MXNET_USE_MKLDNN == 1
  if ((inputs[0].shape().ndim() == 2 || inputs[0].shape().ndim() == 4) &&
      inputs[0].dtype() == mshadow::kFloat32) {
    MKLDNNConcat_Backward(attrs, ctx, inputs, req, outputs);
    return;
  }
#endif
  // Fallback: the generic gradient kernel only needs the output gradient
  // (inputs[0]); unwrap it and the outputs into TBlobs.
  std::vector<TBlob> in_blobs(1);
  in_blobs[0] = inputs[0].data();
  std::vector<TBlob> out_blobs(outputs.size());
  for (size_t i = 0; i < out_blobs.size(); i++)
    out_blobs[i] = outputs[i].data();
  ConcatGradCompute<cpu>(attrs, ctx, in_blobs, req, out_blobs);
}

// Gradient-node builder for Concat: the backward node receives the output
// gradient followed by all of the forward node's inputs.
struct ConcatGrad {
  const char *op_name;
  std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n,
                                          const std::vector<nnvm::NodeEntry>& ograds) const {
    // Removed an unused local (`param` fetched via nnvm::get) that served no
    // purpose here.
    CHECK_EQ(ograds.size(), 1);
    std::vector<nnvm::NodeEntry> heads(ograds.begin(), ograds.end());
    heads.reserve(ograds.size() + n->inputs.size());
    for (size_t i = 0; i < n->inputs.size(); i++) {
      heads.push_back(n->inputs[i]);
    }
    return MakeGradNode(op_name, n, heads, n->attrs.dict);
  }
};
Expand Down Expand Up @@ -165,7 +254,9 @@ Example::
})
.set_attr<nnvm::FInferShape>("FInferShape", ConcatShape)
.set_attr<nnvm::FInferType>("FInferType", ConcatType)
.set_attr<FInferStorageType>("FInferStorageType", ConcatForwardInferStorageType)
.set_attr<FCompute>("FCompute<cpu>", ConcatCompute<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", ConcatComputeExCPU)
.set_attr<nnvm::FGradient>("FGradient", ConcatGrad{"_backward_Concat"})
.set_attr<std::string>("key_var_num_args", "num_args")
.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate")
Expand All @@ -180,7 +271,9 @@ NNVM_REGISTER_OP(_backward_Concat)
})
.set_attr_parser(ParamParser<ConcatParam>)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FCompute>("FCompute<cpu>", ConcatGradCompute<cpu>);
.set_attr<FInferStorageType>("FInferStorageType", backward_ConcatStorageType)
.set_attr<FCompute>("FCompute<cpu>", ConcatGradCompute<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", ConcatGradComputeExCPU);

} // namespace op
} // namespace mxnet
88 changes: 88 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_concat.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file mkldnn_concat.cc
* \brief
* \author Wenting Jiang
*/
#include "../concat-inl.h"
#include "./mkldnn_ops-inl.h"
#include "./mkldnn_base-inl.h"

#if MXNET_USE_MKLDNN == 1
namespace mxnet {
namespace op {

// Forward concat using the MKLDNN concat primitive.
// Collects the MKLDNN memory of every input, builds a concat primitive
// descriptor along `param.dim`, registers the primitive on the MKLDNN stream
// and submits it. Callers in concat.cc gate this on 2D/4D float32 inputs.
//
// Cleanup: removed an unused local (`engine`) and reserved the per-input
// vectors up front.
void MKLDNNConcat_Forward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
    const std::vector<NDArray> &in_data, const std::vector<OpReqType> &req,
    const std::vector<NDArray> &out_data) {
  const ConcatParam& param = nnvm::get<ConcatParam>(attrs.parsed);
  const int num_in_data = param.num_args;
  const int concat_dim = param.dim;
  std::vector<mkldnn::memory::primitive_desc> data_md;
  std::vector<mkldnn::primitive::at> data_mem;
  data_md.reserve(num_in_data);
  data_mem.reserve(num_in_data);
  for (int i = 0; i < num_in_data; i++) {
    std::shared_ptr<const mkldnn::memory> tmp_mem = in_data[i].GetMKLDNNData();
    data_md.push_back(tmp_mem->get_primitive_desc());
    data_mem.push_back(*tmp_mem);
  }
  mkldnn::concat::primitive_desc fwd_pd(concat_dim, data_md);
  // Output memory honoring the request type (e.g. write-to vs add-to).
  auto out_mem = CreateMKLDNNMem(out_data[concat_enum::kOut],
      fwd_pd.dst_primitive_desc(), req[concat_enum::kOut]);
  MKLDNNStream::Instance().RegisterPrim(mkldnn::concat(fwd_pd, data_mem, *out_mem.second));
  CommitOutput(out_data[concat_enum::kOut], out_mem);
  MKLDNNStream::Instance().Submit();
}

// Backward concat: splits the output gradient (inputs[0]) back into one
// gradient per forward input (inputs[1..num_args]) using MKLDNN view +
// reorder primitives.
//
// Bug fix: the offsets/dims were hard-coded to 4D ({0,0,0,0} and
// shape()[0..3]) while the dispatcher in concat.cc also routes 2D inputs
// here, which indexed shape() out of range. Dimensions are now sized from
// the actual ndim. Also removed an unused local (`engine`).
void MKLDNNConcat_Backward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
    const std::vector<NDArray>& inputs, const std::vector<OpReqType>& req,
    const std::vector<NDArray>& outputs) {
  const ConcatParam& param = nnvm::get<ConcatParam>(attrs.parsed);
  const int num_in_data = param.num_args;
  const int axis = param.dim;
  // inputs[0] is the gradient of the concatenated output.
  std::shared_ptr<const mkldnn::memory> gz_mem = inputs[0].GetMKLDNNData();
  mkldnn::memory::primitive_desc gz_pd = gz_mem->get_primitive_desc();
  const int ndim = inputs[0].shape().ndim();
  // Running offset of the current slice inside the concatenated gradient.
  mkldnn::memory::dims offsets(ndim, 0);
  for (int i = 0; i < num_in_data; i++) {
    const auto &shape = inputs[i+1].shape();
    mkldnn::memory::dims diff_src_tz(ndim);
    for (int d = 0; d < ndim; d++)
      diff_src_tz[d] = shape[d];
    auto diff_src_mpd = inputs[i+1].GetMKLDNNData()->get_primitive_desc();
    auto gradi_mem = CreateMKLDNNMem(outputs[i], diff_src_mpd, req[i]);
    // View selecting this input's slice of the output gradient.
    mkldnn::view::primitive_desc view_pd(gz_pd, diff_src_tz, offsets);
    // Reorder copies the viewed slice into this input's gradient buffer.
    mkldnn::reorder::primitive_desc reorder_pd(
        view_pd.dst_primitive_desc(), diff_src_mpd);
    offsets[axis] += diff_src_tz[axis];
    MKLDNNStream::Instance().RegisterPrim(mkldnn::reorder(
        reorder_pd, *gz_mem, *gradi_mem.second));
    CommitOutput(outputs[i], gradi_mem);
  }
  MKLDNNStream::Instance().Submit();
}

} // namespace op
} // namespace mxnet
#endif
12 changes: 10 additions & 2 deletions src/operator/nn/mkldnn/mkldnn_ops-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,16 @@ void MKLDNNSum_Forward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,

/* For copy */
void MKLDNNCopy(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const NDArray &in_data, const OpReqType &req,
const NDArray &out_data);
const NDArray &in_data, const OpReqType &req,
const NDArray &out_data);

/* For concat */
// Forward concat via the MKLDNN concat primitive. Callers in concat.cc gate
// this on 2D/4D float32 inputs — confirm before adding new call sites.
void MKLDNNConcat_Forward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const std::vector<NDArray> &in_data, const std::vector<OpReqType> &req,
const std::vector<NDArray> &out_data);
// Backward concat: inputs[0] is the output gradient, inputs[1..] are the
// forward inputs; produces one gradient per forward input.
void MKLDNNConcat_Backward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const std::vector<NDArray>& inputs, const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs);

} // namespace op
} // namespace mxnet
Expand Down

0 comments on commit 8758290

Please sign in to comment.