add mkldnn softmax_output (apache#13699)
* add mkldnn softmax_output

* fix GPU op unit test error

* fix ci/jenkins/mxnet-validation/unix-gpu compiler error

* fix coding style

* address Tao's comments

* remove blank line, fix indent

* modify according to Sandeep's comments

* change the CPU engine getter method and a private variable

* move macro MXNET_USE_MKLDNN to the head

* modify according to Tao's comments

* make output layout as input

* change API of GetSoftmaxOutputForward

* add CommitOutput for mkldnn_softmax_output

* trigger Jenkins re-test

* add alias Softmax symbol for SoftmaxOutput OP

* indent and remove blank line
rongzha1 authored and vdantu committed Mar 31, 2019
1 parent f76b2fb commit 483a2c2
Showing 6 changed files with 380 additions and 40 deletions.
2 changes: 2 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -174,11 +174,13 @@ struct ActivationParam;
struct ConvolutionParam;
struct DeconvolutionParam;
struct SoftmaxParam;
struct SoftmaxOutputParam;
bool SupportMKLDNNAct(const ActivationParam& param);
bool SupportMKLDNNAct(const ActivationParam& param, const NDArray &input);
bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray &input);
bool SupportMKLDNNDeconv(const DeconvolutionParam& params, const NDArray &input);
bool SupportMKLDNNSoftmax(const SoftmaxParam& param);
bool SupportMKLDNNSoftmaxOutput(const SoftmaxOutputParam &param);
} // namespace op

static int GetTypeSize(int dtype) {
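The new SupportMKLDNNSoftmaxOutput declaration joins the existing SupportMKLDNN* family of gates that decide, per call, whether the MKLDNN path may be taken. A minimal sketch of how such a gate is typically consulted, assuming a hypothetical dispatch wrapper name SoftmaxOutputComputeExCPU (the real wiring lives in src/operator/softmax_output.cc, which is not expanded on this page); SupportMKLDNN, FallBackCompute and SoftmaxOutputCompute are existing helpers:

#if MXNET_USE_MKLDNN == 1
static void SoftmaxOutputComputeExCPU(const nnvm::NodeAttrs &attrs,
                                      const OpContext &ctx,
                                      const std::vector<NDArray> &inputs,
                                      const std::vector<OpReqType> &req,
                                      const std::vector<NDArray> &outputs) {
  const SoftmaxOutputParam &param = nnvm::get<SoftmaxOutputParam>(attrs.parsed);
  // Take the MKLDNN path only when the input layout/dtype and the op
  // parameters are supported; multi_output falls back to the default kernel.
  if (SupportMKLDNN(inputs[0]) && SupportMKLDNNSoftmaxOutput(param)) {
    MKLDNNSoftmaxOutputForward(attrs, ctx, inputs, req, outputs);
    return;
  }
  FallBackCompute(SoftmaxOutputCompute<cpu>, attrs, ctx, inputs, req, outputs);
}
#endif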
10 changes: 8 additions & 2 deletions src/operator/nn/mkldnn/mkldnn_ops-inl.h
@@ -76,15 +76,21 @@ void MKLDNNSoftmaxForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const NDArray &in_data, const OpReqType &req,
const NDArray &out_data);

/* For softmax_output */
void MKLDNNSoftmaxOutputForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const std::vector<NDArray> &in_data,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &out_data);

/* For sum */
void MKLDNNSumForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const std::vector<NDArray> &inputs, const OpReqType &req,
const NDArray &out_data);

/* For copy */
void MKLDNNCopy(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const NDArray &in_data, const OpReqType &req,
const NDArray &out_data);
const NDArray &in_data, const OpReqType &req,
const NDArray &out_data);

/* For concat */
void MKLDNNConcatForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
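MKLDNNSoftmaxOutputForward is declared with nnvm's FComputeEx-style signature (vectors of NDArray rather than TBlob), so it can be attached to the SoftmaxOutput operator through the usual attribute registration. A hedged sketch of what that attachment might look like, reusing the hypothetical SoftmaxOutputComputeExCPU wrapper from the note above and a hypothetical SoftmaxOutputStorageType helper; the actual registration sits in src/operator/softmax_output.cc, which is among the changed files but not shown here:

#if MXNET_USE_MKLDNN == 1
NNVM_REGISTER_OP(SoftmaxOutput)
// Hypothetical storage-type function that requests MKLDNN dispatch on CPU.
.set_attr<FInferStorageType>("FInferStorageType", SoftmaxOutputStorageType)
// Wrapper that gates on SupportMKLDNNSoftmaxOutput and falls back otherwise.
.set_attr<FComputeEx>("FComputeEx<cpu>", SoftmaxOutputComputeExCPU);
#endif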
145 changes: 145 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_softmax_output.cc
@@ -0,0 +1,145 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file mkldnn_softmax_output.cc
* \brief integrate mkldnn softmax into the softmax_output forward pass
* \author Zhang Rong A
*/

#if MXNET_USE_MKLDNN == 1
#include "../../softmax_output-inl.h"
#include "./mkldnn_ops-inl.h"
#include "./mkldnn_base-inl.h"

namespace mxnet {
namespace op {

static mkldnn::softmax_forward::primitive_desc GetSoftmaxOutputFwdDescImpl(
const SoftmaxOutputParam& param, bool is_train,
const int axis, const mkldnn::memory &input_mem) {
mkldnn::memory::primitive_desc data_mpd = input_mem.get_primitive_desc();
mkldnn::memory::desc data_md = data_mpd.desc();
auto cpu_engine = CpuEngine::Get()->get_engine();
auto prop = is_train ? mkldnn::prop_kind::forward_training
: mkldnn::prop_kind::forward_scoring;
auto desc = mkldnn::softmax_forward::desc(prop, data_md, axis);
return mkldnn::softmax_forward::primitive_desc(desc, cpu_engine);
}

typedef ParamOpSign<SoftmaxOutputParam> MKLDNNSoftmaxOutputSignature;

class MKLDNNSoftmaxOutputFwd {
std::shared_ptr<mkldnn::softmax_forward> fwd_;
std::shared_ptr<mkldnn::memory> data_;
std::shared_ptr<mkldnn::memory> out_;

public:
const mkldnn::softmax_forward::primitive_desc fwd_pd;

MKLDNNSoftmaxOutputFwd(const SoftmaxOutputParam& param, bool is_train,
const int axis, const mkldnn::memory &mem): fwd_pd(
GetSoftmaxOutputFwdDescImpl(param, is_train, axis, mem)) {
}

void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &output) {
if (this->data_ == nullptr)
this->data_ = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
data.get_primitive_desc(), data.get_data_handle()));
else
this->data_->set_data_handle(data.get_data_handle());

if (this->out_ == nullptr)
this->out_ = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
output.get_primitive_desc(), output.get_data_handle()));
else
this->out_->set_data_handle(output.get_data_handle());

if (this->fwd_ == nullptr) {
this->fwd_ = std::shared_ptr<mkldnn::softmax_forward>(
new mkldnn::softmax_forward(fwd_pd, mkldnn::primitive::at(*this->data_),
*this->out_));
}
}

const mkldnn::softmax_forward &GetFwd() const {
return *fwd_;
}
};

static MKLDNNSoftmaxOutputFwd &GetSoftmaxOutputForward(const SoftmaxOutputParam& param,
const OpContext &ctx,
const NDArray &in_data) {
#if DMLC_CXX11_THREAD_LOCAL
static thread_local
std::unordered_map<MKLDNNSoftmaxOutputSignature, MKLDNNSoftmaxOutputFwd, OpHash> fwds;
#else
static MX_THREAD_LOCAL
std::unordered_map<MKLDNNSoftmaxOutputSignature, MKLDNNSoftmaxOutputFwd, OpHash> fwds;
#endif
MKLDNNSoftmaxOutputSignature key(param);
key.AddSign(ctx.is_train);
key.AddSign(in_data);

// softmax_output has no axis parameter, so softmax is applied over the last dimension, as in the original implementation.
int axis = in_data.shape().ndim() - 1;

auto it = fwds.find(key);
if (it == fwds.end()) {
auto in_mem = *(in_data.GetMKLDNNData());
MKLDNNSoftmaxOutputFwd fwd(param, ctx.is_train, axis, in_mem);
it = AddToCache(&fwds, key, fwd);
}
return it->second;
}

// This is only used for the forward pass. For backward, compatibility still needs to be double-checked.
bool SupportMKLDNNSoftmaxOutput(const SoftmaxOutputParam &param) {
return !param.multi_output;
}

void MKLDNNSoftmaxOutputForward(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<NDArray> &in_data,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &out_data) {
const SoftmaxOutputParam &param = nnvm::get<SoftmaxOutputParam>(attrs.parsed);

NDArray idata = in_data[softmaxout_enum::kData];
NDArray odata = out_data[softmaxout_enum::kOut];
if (in_data[softmaxout_enum::kData].IsView() && in_data[softmaxout_enum::kData].IsMKLDNNData()) {
idata = in_data[softmaxout_enum::kData].Reorder2Default();
}

auto input_mem = idata.GetMKLDNNData();
auto out_mem = CreateMKLDNNMem(out_data[softmaxout_enum::kOut],
input_mem->get_primitive_desc(), req[softmaxout_enum::kOut]);

MKLDNNSoftmaxOutputFwd &fwd = GetSoftmaxOutputForward(param, ctx, idata);
fwd.SetNewMem(*input_mem, *out_mem.second);

MKLDNNStream *stream = MKLDNNStream::Get();
stream->RegisterPrim(fwd.GetFwd());

CommitOutput(out_data[softmaxout_enum::kOut], out_mem);
stream->Submit();
}
} // namespace op
} // namespace mxnet
#endif
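The implementation above follows the usual MKLDNN caching idiom: the softmax primitive is built once per (parameters, is_train, input signature) key and kept in a thread-local map, and each call only rebinds the data handles through SetNewMem before resubmitting the primitive. A rough sketch of the resulting call pattern, where batch_a, batch_b, out_a and out_b stand in for suitably shaped NDArrays (illustrative only, not code from the commit):

// First call for this signature builds and caches the primitive.
MKLDNNSoftmaxOutputFwd &fwd = GetSoftmaxOutputForward(param, ctx, batch_a);
fwd.SetNewMem(*batch_a.GetMKLDNNData(), *out_a.GetMKLDNNData());
MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd());
MKLDNNStream::Get()->Submit();

// A later call with the same parameters and a same-shaped input should hit the
// cache; only the memory handles are swapped before submission.
MKLDNNSoftmaxOutputFwd &fwd2 = GetSoftmaxOutputForward(param, ctx, batch_b);
fwd2.SetNewMem(*batch_b.GetMKLDNNData(), *out_b.GetMKLDNNData());
MKLDNNStream::Get()->RegisterPrim(fwd2.GetFwd());
MKLDNNStream::Get()->Submit();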
68 changes: 66 additions & 2 deletions src/operator/softmax_output-inl.h
@@ -88,6 +88,17 @@ struct SoftmaxOutputParam : public dmlc::Parameter<SoftmaxOutputParam> {
"one-hot encoding of the gold label and distributed uniformly to"
"all other labels.");
};

bool operator==(const SoftmaxOutputParam& other) const {
return this->grad_scale == other.grad_scale &&
this->ignore_label == other.ignore_label &&
this->multi_output == other.multi_output &&
this->use_ignore == other.use_ignore &&
this->preserve_shape == other.preserve_shape &&
this->normalization == other.normalization &&
this->out_grad == other.out_grad &&
this->smooth_alpha == other.smooth_alpha;
}
};

template<typename xpu, typename DType>
@@ -267,9 +278,43 @@ class SoftmaxOutputOp : public Operator {
SoftmaxOutputParam param_;
}; // class SoftmaxOutputOp

// Declare factory function, used for dispatch specialization
template<typename xpu>
Operator* CreateOp(SoftmaxOutputParam param, int dtype);
template<typename xpu>
void SoftmaxOutputCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx, const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
const SoftmaxOutputParam &param = nnvm::get<SoftmaxOutputParam>(attrs.parsed);
const std::vector<TBlob> no_use_but_adapt_origin_api;
CHECK_EQ(inputs.size(), 2U);

MSHADOW_REAL_TYPE_SWITCH(inputs[softmaxout_enum::kData].type_flag_, DType, {
SoftmaxOutputOp<xpu, DType> op(param);
op.Forward(ctx, inputs, req, outputs, no_use_but_adapt_origin_api);
});
}

template<typename xpu>
void SoftmaxOutputGradCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
const SoftmaxOutputParam& param = nnvm::get<SoftmaxOutputParam>(attrs.parsed);
const std::vector<TBlob> no_use_but_adapt_origin_api;
CHECK_EQ(inputs.size(), 2U);

std::vector<TBlob> out_grad{inputs[0]};
std::vector<TBlob> out_data{inputs[0]};
std::vector<TBlob> in_data(inputs.begin(), inputs.end());
int dtype = inputs[0].type_flag_;
const std::vector<TBlob> &in_grad = outputs;

MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
SoftmaxOutputOp<xpu, DType> op(param);
op.Backward(ctx, out_grad, in_data, out_data, req, in_grad, no_use_but_adapt_origin_api);
});
}


#if DMLC_USE_CXX11
class SoftmaxOutputProp : public OperatorProperty {
@@ -414,4 +459,23 @@ class DeprecatedSoftmaxProp : public SoftmaxOutputProp {

} // namespace op
} // namespace mxnet

namespace std {
template<>
struct hash<mxnet::op::SoftmaxOutputParam> {
size_t operator()(const mxnet::op::SoftmaxOutputParam& val) {
size_t ret = 0;
ret = dmlc::HashCombine(ret, val.grad_scale);
ret = dmlc::HashCombine(ret, val.ignore_label);
ret = dmlc::HashCombine(ret, val.multi_output);
ret = dmlc::HashCombine(ret, val.use_ignore);
ret = dmlc::HashCombine(ret, val.preserve_shape);
ret = dmlc::HashCombine(ret, val.normalization);
ret = dmlc::HashCombine(ret, val.out_grad);
ret = dmlc::HashCombine(ret, val.smooth_alpha);
return ret;
}
};
} // namespace std

#endif // MXNET_OPERATOR_SOFTMAX_OUTPUT_INL_H_
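The operator== added to SoftmaxOutputParam and the std::hash specialization above exist so the parameter struct can serve as part of the hashed signature (ParamOpSign/OpHash) keying the thread-local primitive cache in mkldnn_softmax_output.cc. A tiny sketch of the property they provide, assuming two params built from identical kwargs:

mxnet::op::SoftmaxOutputParam a, b;
// ... both initialized from the same kwargs ...
size_t ha = std::hash<mxnet::op::SoftmaxOutputParam>()(a);
size_t hb = std::hash<mxnet::op::SoftmaxOutputParam>()(b);
// Equal params compare equal and hash identically, so a repeated configuration
// reuses the cached MKLDNN primitive instead of building a new one.
CHECK(a == b);
CHECK_EQ(ha, hb);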
