diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h index f770c4aba350..bf220b847c0e 100644 --- a/src/operator/nn/mkldnn/mkldnn_base-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h @@ -174,11 +174,13 @@ struct ActivationParam; struct ConvolutionParam; struct DeconvolutionParam; struct SoftmaxParam; +struct SoftmaxOutputParam; bool SupportMKLDNNAct(const ActivationParam& param); bool SupportMKLDNNAct(const ActivationParam& param, const NDArray &input); bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray &input); bool SupportMKLDNNDeconv(const DeconvolutionParam& params, const NDArray &input); bool SupportMKLDNNSoftmax(const SoftmaxParam& param); +bool SupportMKLDNNSoftmaxOutput(const SoftmaxOutputParam ¶m); } // namespace op static int GetTypeSize(int dtype) { diff --git a/src/operator/nn/mkldnn/mkldnn_ops-inl.h b/src/operator/nn/mkldnn/mkldnn_ops-inl.h index 50937706d934..39f26325b2a5 100644 --- a/src/operator/nn/mkldnn/mkldnn_ops-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_ops-inl.h @@ -76,6 +76,12 @@ void MKLDNNSoftmaxForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const NDArray &in_data, const OpReqType &req, const NDArray &out_data); +/* For softmax_output */ +void MKLDNNSoftmaxOutputForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data); + /* For sum */ void MKLDNNSumForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const std::vector &inputs, const OpReqType &req, @@ -83,8 +89,8 @@ void MKLDNNSumForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, /* For copy */ void MKLDNNCopy(const nnvm::NodeAttrs& attrs, const OpContext &ctx, - const NDArray &in_data, const OpReqType &req, - const NDArray &out_data); + const NDArray &in_data, const OpReqType &req, + const NDArray &out_data); /* For concat */ void MKLDNNConcatForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, diff --git a/src/operator/nn/mkldnn/mkldnn_softmax_output.cc b/src/operator/nn/mkldnn/mkldnn_softmax_output.cc new file mode 100644 index 000000000000..ae34fe633d6f --- /dev/null +++ b/src/operator/nn/mkldnn/mkldnn_softmax_output.cc @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file mkldnn_softmax_output.cc + * \brief integrate mkldnn softmax to softmax_output forward + * \author Zhang Rong A +*/ + +#if MXNET_USE_MKLDNN == 1 +#include "../../softmax_output-inl.h" +#include "./mkldnn_ops-inl.h" +#include "./mkldnn_base-inl.h" + +namespace mxnet { +namespace op { + +static mkldnn::softmax_forward::primitive_desc GetSoftmaxOutputFwdDescImpl( + const SoftmaxOutputParam& param, bool is_train, + const int axis, const mkldnn::memory &input_mem) { + mkldnn::memory::primitive_desc data_mpd = input_mem.get_primitive_desc(); + mkldnn::memory::desc data_md = data_mpd.desc(); + auto cpu_engine = CpuEngine::Get()->get_engine(); + auto prop = is_train ? mkldnn::prop_kind::forward_training + : mkldnn::prop_kind::forward_scoring; + auto desc = mkldnn::softmax_forward::desc(prop, data_md, axis); + return mkldnn::softmax_forward::primitive_desc(desc, cpu_engine); +} + +typedef ParamOpSign MKLDNNSoftmaxOuputSignature; + +class MKLDNNSoftmaxOutputFwd { + std::shared_ptr fwd_; + std::shared_ptr data_; + std::shared_ptr out_; + + public: + const mkldnn::softmax_forward::primitive_desc fwd_pd; + + MKLDNNSoftmaxOutputFwd(const SoftmaxOutputParam& param, bool is_train, + const int axis, const mkldnn::memory &mem): fwd_pd( + GetSoftmaxOutputFwdDescImpl(param, is_train, axis, mem)) { + } + + void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &output) { + if (this->data_ == nullptr) + this->data_ = std::shared_ptr(new mkldnn::memory( + data.get_primitive_desc(), data.get_data_handle())); + else + this->data_->set_data_handle(data.get_data_handle()); + + if (this->out_ == nullptr) + this->out_ = std::shared_ptr(new mkldnn::memory( + output.get_primitive_desc(), output.get_data_handle())); + else + this->out_->set_data_handle(output.get_data_handle()); + + if (this->fwd_ == nullptr) { + this->fwd_ = std::shared_ptr( + new mkldnn::softmax_forward(fwd_pd, mkldnn::primitive::at(*this->data_), + *this->out_)); + } + } + + const mkldnn::softmax_forward &GetFwd() const { + return *fwd_; + } +}; + +static MKLDNNSoftmaxOutputFwd &GetSoftmaxOutputForward(const SoftmaxOutputParam& param, + const OpContext &ctx, + const NDArray &in_data) { +#if DMLC_CXX11_THREAD_LOCAL + static thread_local + std::unordered_map fwds; +#else + static MX_THREAD_LOCAL + std::unordered_map fwds; +#endif + MKLDNNSoftmaxOuputSignature key(param); + key.AddSign(ctx.is_train); + key.AddSign(in_data); + + // softmax_output has no axis parameter, so use it as it original implement. + int axis = in_data.shape().ndim() - 1; + + auto it = fwds.find(key); + if (it == fwds.end()) { + auto in_mem = *(in_data.GetMKLDNNData()); + MKLDNNSoftmaxOutputFwd fwd(param, ctx.is_train, axis, in_mem); + it = AddToCache(&fwds, key, fwd); + } + return it->second; +} + +// This is only used for forward. For backward ,need double check compatibility +bool SupportMKLDNNSoftmaxOutput(const SoftmaxOutputParam ¶m) { + return param.multi_output ? false : true; +} + +void MKLDNNSoftmaxOutputForward(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data) { + const SoftmaxOutputParam ¶m = nnvm::get(attrs.parsed); + + NDArray idata = in_data[softmaxout_enum::kData]; + NDArray odata = out_data[softmaxout_enum::kOut]; + if (in_data[softmaxout_enum::kData].IsView() && in_data[softmaxout_enum::kData].IsMKLDNNData()) { + idata = in_data[softmaxout_enum::kData].Reorder2Default(); + } + + auto input_mem = idata.GetMKLDNNData(); + auto out_mem = CreateMKLDNNMem(out_data[softmaxout_enum::kOut], + input_mem->get_primitive_desc(), req[softmaxout_enum::kOut]); + + MKLDNNSoftmaxOutputFwd &fwd = GetSoftmaxOutputForward(param, ctx, idata); + fwd.SetNewMem(*input_mem, *out_mem.second); + + MKLDNNStream *stream = MKLDNNStream::Get(); + stream->RegisterPrim(fwd.GetFwd()); + + CommitOutput(out_data[softmaxout_enum::kOut], out_mem); + stream->Submit(); +} +} // namespace op +} // namespace mxnet +#endif diff --git a/src/operator/softmax_output-inl.h b/src/operator/softmax_output-inl.h index fec321b97e4c..5a01d3a73a95 100644 --- a/src/operator/softmax_output-inl.h +++ b/src/operator/softmax_output-inl.h @@ -88,6 +88,17 @@ struct SoftmaxOutputParam : public dmlc::Parameter { "one-hot encoding of the gold label and distributed uniformly to" "all other labels."); }; + + bool operator==(const SoftmaxOutputParam& other) const { + return this->grad_scale == other.grad_scale && + this->ignore_label == other.ignore_label && + this->multi_output == other.multi_output && + this->use_ignore == other.use_ignore && + this->preserve_shape == other.preserve_shape && + this->normalization == other.normalization && + this->out_grad == other.out_grad && + this->smooth_alpha == other.smooth_alpha; + } }; template @@ -267,9 +278,43 @@ class SoftmaxOutputOp : public Operator { SoftmaxOutputParam param_; }; // class SoftmaxOutputOp -// Decalre Factory function, used for dispatch specialization template -Operator* CreateOp(SoftmaxOutputParam param, int dtype); +void SoftmaxOutputCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + const SoftmaxOutputParam ¶m = nnvm::get(attrs.parsed); + const std::vector no_use_but_adapt_origin_api; + CHECK_EQ(inputs.size(), 2U); + + MSHADOW_REAL_TYPE_SWITCH(inputs[softmaxout_enum::kData].type_flag_, DType, { + SoftmaxOutputOp op(param); + op.Forward(ctx, inputs, req, outputs, no_use_but_adapt_origin_api); + }); +} + +template +void SoftmaxOutputGradCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + const SoftmaxOutputParam& param = nnvm::get(attrs.parsed); + const std::vector no_use_but_adapt_origin_api; + CHECK_EQ(inputs.size(), 2U); + + std::vector out_grad{inputs[0]}; + std::vector out_data{inputs[0]}; + std::vector in_data(inputs.begin(), inputs.end()); + int dtype = inputs[0].type_flag_; + const std::vector &in_grad = outputs; + + MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { + SoftmaxOutputOp op(param); + op.Backward(ctx, out_grad, in_data, out_data, req, in_grad, no_use_but_adapt_origin_api); + }); +} + #if DMLC_USE_CXX11 class SoftmaxOutputProp : public OperatorProperty { @@ -414,4 +459,23 @@ class DeprecatedSoftmaxProp : public SoftmaxOutputProp { } // namespace op } // namespace mxnet + +namespace std { +template<> +struct hash { + size_t operator()(const mxnet::op::SoftmaxOutputParam& val) { + size_t ret = 0; + ret = dmlc::HashCombine(ret, val.grad_scale); + ret = dmlc::HashCombine(ret, val.ignore_label); + ret = dmlc::HashCombine(ret, val.multi_output); + ret = dmlc::HashCombine(ret, val.use_ignore); + ret = dmlc::HashCombine(ret, val.preserve_shape); + ret = dmlc::HashCombine(ret, val.normalization); + ret = dmlc::HashCombine(ret, val.out_grad); + ret = dmlc::HashCombine(ret, val.smooth_alpha); + return ret; + } +}; +} // namespace std + #endif // MXNET_OPERATOR_SOFTMAX_OUTPUT_INL_H_ diff --git a/src/operator/softmax_output.cc b/src/operator/softmax_output.cc index 5ba421fd195b..322ac0b93426 100644 --- a/src/operator/softmax_output.cc +++ b/src/operator/softmax_output.cc @@ -21,30 +21,137 @@ * Copyright (c) 2015 by Contributors * \file softmax_output.cc * \brief - * \author Bing Xu + * \author Bing Xu, Zhang Rong A */ #include "./softmax_output-inl.h" - +#if MXNET_USE_MKLDNN == 1 +#include "./nn/mkldnn/mkldnn_ops-inl.h" +#endif namespace mxnet { namespace op { -template<> -Operator *CreateOp(SoftmaxOutputParam param, int dtype) { - Operator *op = nullptr; - MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - op = new SoftmaxOutputOp(param); - }) - return op; + +DMLC_REGISTER_PARAMETER(SoftmaxOutputParam); +struct SoftmaxOutputGrad { + const char *op_name; + std::vector operator()(const nnvm::NodePtr& n, + const std::vector& ograds) const { + std::vector out_data(n->num_outputs()); + for (uint32_t i = 0; i < out_data.size(); ++i) { + out_data[i] = nnvm::NodeEntry{n, i, 0}; + } + std::vector heads; + heads.push_back(out_data[softmaxout_enum::kOut]); + heads.push_back(n->inputs[softmaxout_enum::kLabel]); + + nnvm::NodePtr gnode = nnvm::Node::Create(); + gnode->inputs = std::move(heads); + gnode->control_deps.emplace_back(n); + gnode->attrs = n->attrs; + gnode->attrs.op = nnvm::Op::Get("_backward_SoftmaxOutput"); + gnode->attrs.name = n->attrs.name + "_backward"; + std::vector in_grad(2); + in_grad[0] = nnvm::NodeEntry{gnode, 0, 0}; + in_grad[1] = nnvm::NodeEntry{gnode, 1, 0}; + return in_grad; + } +}; + +static inline std::vector ListArguments() { + return {"data", "label"}; } -// DO_BIND_DISPATCH comes from operator_common.h -Operator *SoftmaxOutputProp::CreateOperatorEx(Context ctx, std::vector *in_shape, - std::vector *in_type) const { - DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]); +static bool SoftmaxOutputType(const nnvm::NodeAttrs& attrs, + std::vector *in_type, + std::vector *out_type) { + CHECK_EQ(in_type->size(), 2U); + int dtype = (*in_type)[0]; + CHECK_NE(dtype, -1) << "First input must have specified type"; + for (size_t i = 0; i < in_type->size(); ++i) { + if ((*in_type)[i] == -1) { + (*in_type)[i] = dtype; + } else { + UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments()[i]); + } + } + out_type->clear(); + out_type->push_back(dtype); + return true; } -DMLC_REGISTER_PARAMETER(SoftmaxOutputParam); +static bool SoftmaxOutputShape(const nnvm::NodeAttrs& attrs, + std::vector *in_shape, + std::vector *out_shape) { + using namespace mshadow; + const SoftmaxOutputParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(in_shape->size(), 2U) << "Input:[data, label]"; + const TShape &dshape = in_shape->at(0); + if (dshape.ndim() == 0) return false; -MXNET_REGISTER_OP_PROPERTY(SoftmaxOutput, SoftmaxOutputProp) + // label.shape == data.shape: use probability as label + if (dshape != (*in_shape)[softmaxout_enum::kLabel]) { + if (param.multi_output) { + TShape lshape1 = Shape2(dshape[0], dshape.Size()/dshape[0]/dshape[1]); + TShape lshape2(dshape.ndim() - 1); + lshape2[0] = dshape[0]; + for (index_t i = 2; i < dshape.ndim(); ++i) + lshape2[i-1] = dshape[i]; + TShape lshape3 = dshape; + lshape3[1] = 1; + if (in_shape->at(softmaxout_enum::kLabel).ndim() == 0) { + in_shape->at(softmaxout_enum::kLabel) = lshape1; + } else if (in_shape->at(softmaxout_enum::kLabel) == lshape1) { + } else if (in_shape->at(softmaxout_enum::kLabel) == lshape2) { + } else if (in_shape->at(softmaxout_enum::kLabel) == lshape3) { + } else { + std::ostringstream os; + os << "Expecting " << lshape1 << " or " << lshape2 + << ". But got " << in_shape->at(softmaxout_enum::kLabel); + throw InferShapeError(os.str(), softmaxout_enum::kLabel); + } + } else { + TShape label_shape(dshape.ndim() - 1); + for (index_t i = 0; i + 1 < dshape.ndim(); ++i) + label_shape[i] = dshape[i]; + SHAPE_ASSIGN_CHECK(*in_shape, softmaxout_enum::kLabel, label_shape); + } + } + + out_shape->clear(); + out_shape->push_back(dshape); + return true; +} + +#if MXNET_USE_MKLDNN == 1 +inline static bool SoftmaxOutputStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 2); + CHECK_EQ(out_attrs->size(), 1); + + return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, + out_attrs); +} + +void SoftmaxOutputComputeExCPU(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + CHECK_EQ(inputs.size(), 2U); + const SoftmaxOutputParam ¶m = nnvm::get(attrs.parsed); + if (SupportMKLDNN(inputs[0]) && !ctx.is_train && SupportMKLDNNSoftmaxOutput(param)) { + MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs); + MKLDNNSoftmaxOutputForward(attrs, ctx, inputs, req, outputs); + MKLDNN_OPCHECK_RUN(SoftmaxOutputCompute, attrs, ctx, inputs, req, outputs); + return; + } + FallBackCompute(SoftmaxOutputCompute, attrs, ctx, inputs, req, outputs); +} +#endif + +NNVM_REGISTER_OP(SoftmaxOutput) .describe(R"code(Computes the gradient of cross entropy loss with respect to softmax output. - This operator computes the gradient in two steps. @@ -121,23 +228,41 @@ MXNET_REGISTER_OP_PROPERTY(SoftmaxOutput, SoftmaxOutputProp) - ``'valid'``: divide the gradient by the number of instances which are not ignored. )code" ADD_FILELINE) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +#if MXNET_USE_MKLDNN == 1 +.set_attr("FInferStorageType", SoftmaxOutputStorageType) +.set_attr("TIsMKLDNN", true) +.set_attr("FComputeEx", SoftmaxOutputComputeExCPU) +#endif +.set_attr("FListInputNames", [](const NodeAttrs& attrs) { + return std::vector{"data", "label"}; +}) +.set_attr("FListOutputNames", [](const NodeAttrs& attrs) { + return std::vector{"output"}; +}) +.set_attr("FInferShape", SoftmaxOutputShape) +.set_attr("FInferType", SoftmaxOutputType) +.set_attr("FCompute", SoftmaxOutputCompute) +.set_attr("FGradient", SoftmaxOutputGrad{"_backward_SoftmaxOutput"}) +.set_attr("FInplaceOption", [](const NodeAttrs& attrs){ + return std::vector >{{0, 0}}; +}) .add_argument("data", "NDArray-or-Symbol", "Input array.") .add_argument("label", "NDArray-or-Symbol", "Ground truth label.") .add_arguments(SoftmaxOutputParam::__FIELDS__()); +// Softmax symbol is renamed to SoftmaxOutput and deprecated since Dec, 2015 +NNVM_REGISTER_OP(SoftmaxOutput).add_alias("Softmax"); -MXNET_REGISTER_OP_PROPERTY(Softmax, DeprecatedSoftmaxProp) -.describe(R"code(Please use `SoftmaxOutput`. - -.. note:: - - This operator has been renamed to `SoftmaxOutput`, which - computes the gradient of cross-entropy loss w.r.t softmax output. - To just compute softmax output, use the `softmax` operator. - -)code" ADD_FILELINE) -.add_argument("data", "NDArray-or-Symbol", "Input array.") -.add_arguments(SoftmaxOutputParam::__FIELDS__()); - +NNVM_REGISTER_OP(_backward_SoftmaxOutput) +.set_num_outputs(2) +.set_attr("TIsBackward", true) +.set_attr("FInplaceOption", [](const NodeAttrs& attrs){ + return std::vector >{{0, 0}}; +}) +.set_attr_parser(ParamParser) +.set_attr("FCompute", SoftmaxOutputGradCompute); } // namespace op } // namespace mxnet diff --git a/src/operator/softmax_output.cu b/src/operator/softmax_output.cu index afcc8f4fc6bd..b2a41672e92a 100644 --- a/src/operator/softmax_output.cu +++ b/src/operator/softmax_output.cu @@ -28,14 +28,12 @@ namespace mxnet { namespace op { -template<> -Operator *CreateOp(SoftmaxOutputParam param, int dtype) { - Operator *op = NULL; - MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - op = new SoftmaxOutputOp(param); - }) - return op; -} + +NNVM_REGISTER_OP(SoftmaxOutput) +.set_attr("FCompute", SoftmaxOutputCompute); + +NNVM_REGISTER_OP(_backward_SoftmaxOutput) +.set_attr("FCompute", SoftmaxOutputGradCompute); } // namespace op } // namespace mxnet