From 0c81bf2c0e432d23fcb88fdf232efbf58fca8b69 Mon Sep 17 00:00:00 2001
From: rongzha1
Date: Thu, 20 Dec 2018 16:40:38 +0800
Subject: [PATCH 01/16] add mkldnn softmax_output

---
 src/operator/nn/mkldnn/mkldnn_base-inl.h        |   2 +
 src/operator/nn/mkldnn/mkldnn_ops-inl.h         |   4 +
 .../nn/mkldnn/mkldnn_softmax_output.cc          | 149 ++++++++++
 src/operator/softmax_output-inl.h               |  68 +++++
 src/operator/softmax_output.cc                  | 266 ++++++++++++++----
 5 files changed, 428 insertions(+), 61 deletions(-)
 create mode 100644 src/operator/nn/mkldnn/mkldnn_softmax_output.cc

diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index f770c4aba350..bf220b847c0e 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -174,11 +174,13 @@ struct ActivationParam;
 struct ConvolutionParam;
 struct DeconvolutionParam;
 struct SoftmaxParam;
+struct SoftmaxOutputParam;
 bool SupportMKLDNNAct(const ActivationParam& param);
 bool SupportMKLDNNAct(const ActivationParam& param, const NDArray &input);
 bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray &input);
 bool SupportMKLDNNDeconv(const DeconvolutionParam& params, const NDArray &input);
 bool SupportMKLDNNSoftmax(const SoftmaxParam& param);
+bool SupportMKLDNNSoftmaxOutput(const SoftmaxOutputParam &param);
 }  // namespace op

 static int GetTypeSize(int dtype) {
diff --git a/src/operator/nn/mkldnn/mkldnn_ops-inl.h b/src/operator/nn/mkldnn/mkldnn_ops-inl.h
index 50937706d934..6640fd49b52f 100644
--- a/src/operator/nn/mkldnn/mkldnn_ops-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_ops-inl.h
@@ -75,6 +75,10 @@ void MKLDNNDeconvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &
 void MKLDNNSoftmaxForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                           const NDArray &in_data, const OpReqType &req,
                           const NDArray &out_data);
+/* For softmax_output */
+void MKLDNNSoftmaxOutputForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
+                                const std::vector<NDArray> &in_data, const std::vector<OpReqType> &req,
+                                const std::vector<NDArray> &out_data);

 /* For sum */
 void MKLDNNSumForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
diff --git a/src/operator/nn/mkldnn/mkldnn_softmax_output.cc b/src/operator/nn/mkldnn/mkldnn_softmax_output.cc
new file mode 100644
index 000000000000..a896bbfb13e8
--- /dev/null
+++ b/src/operator/nn/mkldnn/mkldnn_softmax_output.cc
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file mkldnn_softmax_output.cc
+ * \brief
+ * \author Zhang Rong A
+*/
+
+#include "../../softmax_output-inl.h"
+#include "./mkldnn_ops-inl.h"
+#include "./mkldnn_base-inl.h"
+
+#if MXNET_USE_MKLDNN == 1
+namespace mxnet {
+namespace op {
+
+static mkldnn::softmax_forward::primitive_desc GetSoftmaxOutputFwdDescImpl(
+               const SoftmaxOutputParam& param, bool is_train,
+               const NDArray &data, const mkldnn::memory &input_mem) {
+  mkldnn::memory::primitive_desc data_mpd = input_mem.get_primitive_desc();
+  mkldnn::memory::desc data_md = data_mpd.desc();
+  auto cpu_engine = data_mpd.get_engine();
+  int axis = data.shape().ndim() - 1;
+  mkldnn::softmax_forward::desc desc = is_train
+      ? mkldnn::softmax_forward::desc(mkldnn::prop_kind::forward_training,
+          data_md, axis)
+      : mkldnn::softmax_forward::desc(mkldnn::prop_kind::forward_scoring,
+          data_md, axis);
+  return mkldnn::softmax_forward::primitive_desc(desc, cpu_engine);
+}
+
+typedef ParamOpSign<SoftmaxOutputParam> MKLDNNSoftmaxOuputSignature;
+
+class MKLDNNSoftmaxOutputFwd {
+  std::shared_ptr<mkldnn::softmax_forward> fwd;
+  std::shared_ptr<mkldnn::memory> data;
+  std::shared_ptr<mkldnn::memory> out;
+
+ public:
+  const mkldnn::softmax_forward::primitive_desc fwd_pd;
+
+  MKLDNNSoftmaxOutputFwd(const SoftmaxOutputParam& param, bool is_train,
+                         const NDArray &data, const mkldnn::memory &mem): fwd_pd(
+                         GetSoftmaxOutputFwdDescImpl(param, is_train, data, mem)) {
+  }
+
+  void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &output) {
+    if (this->data == nullptr)
+      this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+              data.get_primitive_desc(), data.get_data_handle()));
+    else
+      this->data->set_data_handle(data.get_data_handle());
+
+    if (this->out == nullptr)
+      this->out = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+              output.get_primitive_desc(), output.get_data_handle()));
+    else
+      this->out->set_data_handle(output.get_data_handle());
+
+    if (this->fwd == nullptr) {
+      this->fwd = std::shared_ptr<mkldnn::softmax_forward>(
+          new mkldnn::softmax_forward(fwd_pd, mkldnn::primitive::at(*this->data),
+                                      *this->out));
+    }
+  }
+
+  const mkldnn::softmax_forward &GetFwd() const {
+    return *fwd;
+  }
+};
+
+static MKLDNNSoftmaxOutputFwd &GetSoftmaxOutputForward(const SoftmaxOutputParam& param,
+                                                       const OpContext &ctx, const NDArray &in_data,
+                                                       const mkldnn::memory &in_mem) {
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local
+         std::unordered_map<MKLDNNSoftmaxOuputSignature, MKLDNNSoftmaxOutputFwd, OpHash> fwds;
+#else
+  static MX_THREAD_LOCAL
+         std::unordered_map<MKLDNNSoftmaxOuputSignature, MKLDNNSoftmaxOutputFwd, OpHash> fwds;
+#endif
+  MKLDNNSoftmaxOuputSignature key(param);
+  key.AddSign(ctx.is_train);
+  key.AddSign(in_data);
+
+  auto it = fwds.find(key);
+  if (it == fwds.end()) {
+    MKLDNNSoftmaxOutputFwd fwd(param, ctx.is_train, in_data, in_mem);
+    auto ins_ret = fwds.insert(std::pair<MKLDNNSoftmaxOuputSignature, MKLDNNSoftmaxOutputFwd>(
+        key, fwd));
+    CHECK(ins_ret.second);
+    it = ins_ret.first;
+  }
+  return it->second;
+}
+
+
+bool SupportMKLDNNSoftmaxOutput(const SoftmaxOutputParam &param) {
+  return param.multi_output ? false : true;
+}
+
+void MKLDNNSoftmaxOutputForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
+                                const std::vector<NDArray> &in_data, const std::vector<OpReqType> &req,
+                                const std::vector<NDArray> &out_data) {
+  const SoftmaxOutputParam &param = nnvm::get<SoftmaxOutputParam>(attrs.parsed);
+
+  NDArray idata = in_data[softmaxout_enum::kData];
+  NDArray odata = out_data[softmaxout_enum::kOut];
+  if (in_data[softmaxout_enum::kData].IsView() && in_data[softmaxout_enum::kData].IsMKLDNNData()) {
+    idata = in_data[softmaxout_enum::kData].Reorder2Default();
+  }
+
+  auto input_mem = idata.GetMKLDNNData();
+  auto output_mem = odata.GetMKLDNNData();
+
+  MKLDNNSoftmaxOutputFwd &fwd = GetSoftmaxOutputForward(param, ctx, idata, *input_mem);
+  fwd.SetNewMem(*input_mem, *output_mem);
+  MKLDNNStream *stream = MKLDNNStream::Get();
+  stream->RegisterPrim(fwd.GetFwd());
+  stream->Submit();
+}
+
+}   // namespace op
+}   // namespace mxnet
+
+
+#endif
+
+
+
+
+
diff --git a/src/operator/softmax_output-inl.h b/src/operator/softmax_output-inl.h
index fec321b97e4c..e64b011daf93 100644
--- a/src/operator/softmax_output-inl.h
+++ b/src/operator/softmax_output-inl.h
@@ -88,6 +88,16 @@ struct SoftmaxOutputParam : public dmlc::Parameter<SoftmaxOutputParam> {
               "one-hot encoding of the gold label and distributed uniformly to"
               "all other labels.");
   };
+  bool operator==(const SoftmaxOutputParam& other) const {
+    return this->grad_scale == other.grad_scale &&
+           this->ignore_label == other.ignore_label &&
+           this->multi_output == other.multi_output &&
+           this->use_ignore == other.use_ignore &&
+           this->preserve_shape == other.preserve_shape &&
+           this->normalization == other.normalization &&
+           this->out_grad == other.out_grad &&
+           this->smooth_alpha == other.smooth_alpha;
+  }
 };

 template<typename xpu, typename DType>
@@ -271,6 +281,45 @@ class SoftmaxOutputOp : public Operator {
 template<typename xpu>
 Operator* CreateOp(SoftmaxOutputParam param, int dtype);

+template<typename xpu>
+void SoftmaxOutputCompute(const nnvm::NodeAttrs& attrs,
+                          const OpContext& ctx, const std::vector<TBlob>& inputs,
+                          const std::vector<OpReqType>& req,
+                          const std::vector<TBlob>& outputs) {
+  const SoftmaxOutputParam &param = nnvm::get<SoftmaxOutputParam>(attrs.parsed);
+  const std::vector<TBlob> no_use_but_adapt_origin_api;
+  CHECK_EQ(inputs.size(), 2U);
+  std::vector<TBlob> in_data(inputs.begin(),
+                             inputs.begin() + softmaxout_enum::kLabel);
+  MSHADOW_REAL_TYPE_SWITCH(inputs[softmaxout_enum::kData].type_flag_, DType, {
+    SoftmaxOutputOp<xpu, DType> op(param);
+    op.Forward(ctx, inputs, req, outputs, no_use_but_adapt_origin_api);
+  });
+}
+
+template<typename xpu>
+void SoftmaxOutputGradCompute(const nnvm::NodeAttrs& attrs,
+                              const OpContext& ctx,
+                              const std::vector<TBlob>& inputs,
+                              const std::vector<OpReqType>& req,
+                              const std::vector<TBlob>& outputs) {
+  const SoftmaxOutputParam& param = nnvm::get<SoftmaxOutputParam>(attrs.parsed);
+  const std::vector<TBlob> no_use_but_adapt_origin_api;
+  CHECK_EQ(inputs.size(), 2U);
+
+  std::vector<TBlob> out_grad{inputs[0]};
+  std::vector<TBlob> out_data{inputs[0]};
+  std::vector<TBlob> in_data(inputs.begin(), inputs.end());
+  int dtype = inputs[0].type_flag_;
+  const std::vector<TBlob> &in_grad = outputs;
+
+  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
+    SoftmaxOutputOp<xpu, DType> op(param);
+    op.Backward(ctx, out_grad, in_data, out_data, req, in_grad, no_use_but_adapt_origin_api);
+  });
+}
+
+
 #if DMLC_USE_CXX11
 class SoftmaxOutputProp : public OperatorProperty {
  public:
@@ -414,4 +463,23 @@ class DeprecatedSoftmaxProp : public SoftmaxOutputProp {
 }  // namespace op
 }  // namespace mxnet
+
+namespace std {
+template<>
+struct hash<mxnet::op::SoftmaxOutputParam> {
+  size_t operator()(const mxnet::op::SoftmaxOutputParam& val) {
+    size_t ret = 0;
+    ret = dmlc::HashCombine(ret, val.grad_scale);
+    ret = dmlc::HashCombine(ret, val.ignore_label);
+    ret
= dmlc::HashCombine(ret, val.multi_output); + ret = dmlc::HashCombine(ret, val.use_ignore); + ret = dmlc::HashCombine(ret, val.preserve_shape); + ret = dmlc::HashCombine(ret, val.normalization); + ret = dmlc::HashCombine(ret, val.out_grad); + ret = dmlc::HashCombine(ret, val.smooth_alpha); + return ret; + } +}; +} // namespace std + #endif // MXNET_OPERATOR_SOFTMAX_OUTPUT_INL_H_ diff --git a/src/operator/softmax_output.cc b/src/operator/softmax_output.cc index 5ba421fd195b..2b2ef6152bc1 100644 --- a/src/operator/softmax_output.cc +++ b/src/operator/softmax_output.cc @@ -21,30 +21,142 @@ * Copyright (c) 2015 by Contributors * \file softmax_output.cc * \brief - * \author Bing Xu + * \author Bing Xu, Zhang Rong A */ #include "./softmax_output-inl.h" - +#if MXNET_USE_MKLDNN == 1 +#include "./mkldnn_ops-inl.h" +#endif namespace mxnet { namespace op { -template<> -Operator *CreateOp(SoftmaxOutputParam param, int dtype) { - Operator *op = nullptr; - MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - op = new SoftmaxOutputOp(param); - }) - return op; + +DMLC_REGISTER_PARAMETER(SoftmaxOutputParam); +struct SoftmaxOutputGrad { + const char *op_name; + std::vector operator()(const nnvm::NodePtr& n, + const std::vector& ograds) const { + std::vector out_data(n->num_outputs()); + for (uint32_t i = 0; i < out_data.size(); ++i) { + out_data[i] = nnvm::NodeEntry{n, i, 0}; + } + std::vector heads; + heads.push_back(out_data[softmaxout_enum::kOut]); + heads.push_back(n->inputs[softmaxout_enum::kLabel]); + + nnvm::NodePtr gnode = nnvm::Node::Create(); + gnode->inputs = std::move(heads); + gnode->control_deps.emplace_back(n); + gnode->attrs = n->attrs; + gnode->attrs.op = nnvm::Op::Get("_backward_SoftmaxOutput"); + gnode->attrs.name = n->attrs.name + "_backward"; + std::vector in_grad(2); + for (uint32_t i = 0; i < 2; ++i) { + in_grad[i] = nnvm::NodeEntry{gnode, i, 0}; + } + return in_grad; + } +}; + +static inline std::vector ListArguments() { + return {"data", "label"}; +} + + +static bool SoftmaxOutputType(const nnvm::NodeAttrs& attrs, + std::vector *in_type, std::vector *out_type) { + CHECK_GE(in_type->size(), 2U); + int dtype = (*in_type)[0]; + CHECK_NE(dtype, -1) << "First input must have specified type"; + for (size_t i = 0; i < in_type->size(); ++i) { + if ((*in_type)[i] == -1) { + (*in_type)[i] = dtype; + } else { + UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments()[i]); + } + } + out_type->clear(); + out_type->push_back(dtype); + return true; } -// DO_BIND_DISPATCH comes from operator_common.h -Operator *SoftmaxOutputProp::CreateOperatorEx(Context ctx, std::vector *in_shape, - std::vector *in_type) const { - DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]); +static bool SoftmaxOutputShape(const nnvm::NodeAttrs& attrs, + std::vector *in_shape, + std::vector *out_shape) { + using namespace mshadow; + const SoftmaxOutputParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(in_shape->size(), 2U) << "Input:[data, label]"; + const TShape &dshape = in_shape->at(0); + if (dshape.ndim() == 0) return false; + + // label.shape == data.shape: use probability as label + if (dshape != (*in_shape)[softmaxout_enum::kLabel]) { + if (param.multi_output) { + TShape lshape1 = Shape2(dshape[0], dshape.Size()/dshape[0]/dshape[1]); + TShape lshape2(dshape.ndim() - 1); + lshape2[0] = dshape[0]; + for (index_t i = 2; i < dshape.ndim(); ++i) + lshape2[i-1] = dshape[i]; + TShape lshape3 = dshape; + lshape3[1] = 1; + if (in_shape->at(softmaxout_enum::kLabel).ndim() == 0) { + in_shape->at(softmaxout_enum::kLabel) = lshape1; + } 
else if (in_shape->at(softmaxout_enum::kLabel) == lshape1) { + } else if (in_shape->at(softmaxout_enum::kLabel) == lshape2) { + } else if (in_shape->at(softmaxout_enum::kLabel) == lshape3) { + } else { + std::ostringstream os; + os << "Expecting " << lshape1 << " or " << lshape2 + << ". But got " << in_shape->at(softmaxout_enum::kLabel); + throw InferShapeError(os.str(), softmaxout_enum::kLabel); + } + } else { + TShape label_shape(dshape.ndim() - 1); + for (index_t i = 0; i + 1 < dshape.ndim(); ++i) + label_shape[i] = dshape[i]; + SHAPE_ASSIGN_CHECK(*in_shape, softmaxout_enum::kLabel, label_shape); + } + } + + out_shape->clear(); + out_shape->push_back(dshape); + return true; } -DMLC_REGISTER_PARAMETER(SoftmaxOutputParam); +#if MXNET_USE_MKLDNN == 1 +inline static bool SoftmaxOutputStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 2); + CHECK_EQ(out_attrs->size(), 1); + + return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, + out_attrs); +} + +void SoftmaxOutputComputeExCPU(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + CHECK_EQ(inputs.size(), 2U); + const SoftmaxOutputParam ¶m = nnvm::get(attrs.parsed); + // MKLDNN softmaxOutput only works well on the special MKLDNN layout. + if (SupportMKLDNN(inputs[0]) && !ctx.is_train && SupportMKLDNNSoftmaxOutput(param)) { + MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs); + MKLDNNSoftmaxOutputForward(attrs, ctx, inputs, req, outputs); + auto fn = SoftmaxOutputCompute; + MKLDNN_OPCHECK_RUN(fn, attrs, ctx, inputs, req, outputs); + return; + } + FallBackCompute(SoftmaxOutputCompute, attrs, ctx, inputs, req, outputs); +} -MXNET_REGISTER_OP_PROPERTY(SoftmaxOutput, SoftmaxOutputProp) +#endif + +NNVM_REGISTER_OP(SoftmaxOutput) +MXNET_ADD_SPARSE_OP_ALIAS(SoftmaxOutput) // to be check .describe(R"code(Computes the gradient of cross entropy loss with respect to softmax output. - This operator computes the gradient in two steps. @@ -57,15 +169,15 @@ MXNET_REGISTER_OP_PROPERTY(SoftmaxOutput, SoftmaxOutputProp) - Softmax Function: - .. math:: \text{softmax}(x)_i = \frac{exp(x_i)}{\sum_j exp(x_j)} + .. math:: \text{softmax}(x)_i = \frac{exp(x_i)}{\sum_j exp(x_j)} - Cross Entropy Function: - .. math:: \text{CE(label, output)} = - \sum_i \text{label}_i \log(\text{output}_i) + .. math:: \text{CE(label, output)} = - \sum_i \text{label}_i \log(\text{output}_i) - The gradient of cross entropy loss w.r.t softmax output: - .. math:: \text{gradient} = \text{output} - \text{label} + .. math:: \text{gradient} = \text{output} - \text{label} - During forward propagation, the softmax function is computed for each instance in the input array. @@ -74,70 +186,102 @@ MXNET_REGISTER_OP_PROPERTY(SoftmaxOutput, SoftmaxOutputProp) and `multi_output` to specify the way to compute softmax: - By default, `preserve_shape` is ``false``. This operator will reshape the input array - into a 2-D array with shape :math:`(d_1, \frac{s}{d_1})` and then compute the softmax function for - each row in the reshaped array, and afterwards reshape it back to the original shape - :math:`(d_1, d_2, ..., d_n)`. + into a 2-D array with shape :math:`(d_1, \frac{s}{d_1})` and then compute the softmax function for + each row in the reshaped array, and afterwards reshape it back to the original shape + :math:`(d_1, d_2, ..., d_n)`. 
- If `preserve_shape` is ``true``, the softmax function will be computed along - the last axis (`axis` = ``-1``). + the last axis (`axis` = ``-1``). - If `multi_output` is ``true``, the softmax function will be computed along - the second axis (`axis` = ``1``). + the second axis (`axis` = ``1``). - During backward propagation, the gradient of cross-entropy loss w.r.t softmax output array is computed. The provided label can be a one-hot label array or a probability label array. - If the parameter `use_ignore` is ``true``, `ignore_label` can specify input instances - with a particular label to be ignored during backward propagation. **This has no effect when - softmax `output` has same shape as `label`**. - - Example:: - - data = [[1,2,3,4],[2,2,2,2],[3,3,3,3],[4,4,4,4]] - label = [1,0,2,3] - ignore_label = 1 - SoftmaxOutput(data=data, label = label,\ - multi_output=true, use_ignore=true,\ - ignore_label=ignore_label) - ## forward softmax output - [[ 0.0320586 0.08714432 0.23688284 0.64391428] - [ 0.25 0.25 0.25 0.25 ] - [ 0.25 0.25 0.25 0.25 ] - [ 0.25 0.25 0.25 0.25 ]] - ## backward gradient output - [[ 0. 0. 0. 0. ] - [-0.75 0.25 0.25 0.25] - [ 0.25 0.25 -0.75 0.25] - [ 0.25 0.25 0.25 -0.75]] - ## notice that the first row is all 0 because label[0] is 1, which is equal to ignore_label. + with a particular label to be ignored during backward propagation. **This has no effect when + softmax `output` has same shape as `label`**. + + Example:: + + data = [[1,2,3,4],[2,2,2,2],[3,3,3,3],[4,4,4,4]] + label = [1,0,2,3] + ignore_label = 1 + SoftmaxOutput(data=data, label = label,\ + multi_output=true, use_ignore=true,\ + ignore_label=ignore_label) + ## forward softmax output + [[ 0.0320586 0.08714432 0.23688284 0.64391428] + [ 0.25 0.25 0.25 0.25 ] + [ 0.25 0.25 0.25 0.25 ] + [ 0.25 0.25 0.25 0.25 ]] + ## backward gradient output + [[ 0. 0. 0. 0. ] + [-0.75 0.25 0.25 0.25] + [ 0.25 0.25 -0.75 0.25] + [ 0.25 0.25 0.25 -0.75]] + ## notice that the first row is all 0 because label[0] is 1, which is equal to ignore_label. - The parameter `grad_scale` can be used to rescale the gradient, which is often used to - give each loss function different weights. + give each loss function different weights. - This operator also supports various ways to normalize the gradient by `normalization`, - The `normalization` is applied if softmax output has different shape than the labels. - The `normalization` mode can be set to the followings: + The `normalization` is applied if softmax output has different shape than the labels. + The `normalization` mode can be set to the followings: - - ``'null'``: do nothing. - - ``'batch'``: divide the gradient by the batch size. - - ``'valid'``: divide the gradient by the number of instances which are not ignored. + - ``'null'``: do nothing. + - ``'batch'``: divide the gradient by the batch size. + - ``'valid'``: divide the gradient by the number of instances which are not ignored. 
)code" ADD_FILELINE) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +#if MXNET_USE_MKLDNN == 1 +.set_attr("FInferStorageType", SoftmaxOutputStorageType) +#endif +.set_attr("FListInputNames", [](const NodeAttrs& attrs) { + return std::vector{"data", "label"}; +}) +.set_attr("FListOutputNames", + [](const NodeAttrs& attrs) { + return std::vector{"output"}; +}) +#if MXNET_USE_MKLDNN == 1 +.set_attr("TIsMKLDNN", true) +.set_attr("FResourceRequest", [](const NodeAttrs& n) { + return std::vector{ResourceRequest::kTempSpace}; +}) +#endif +.set_attr("FInferShape", SoftmaxOutputShape) +.set_attr("FInferType", SoftmaxOutputType) +.set_attr("FCompute", SoftmaxOutputCompute) +#if MXNET_USE_MKLDNN == 1 +.set_attr("FComputeEx", SoftmaxOutputComputeExCPU) +#endif +.set_attr("FGradient", SoftmaxOutputGrad{"_backward_SoftmaxOutput"}) +.set_attr("FInplaceOption", [](const NodeAttrs& attrs){ + return std::vector >{{0, 0}}; +}) .add_argument("data", "NDArray-or-Symbol", "Input array.") .add_argument("label", "NDArray-or-Symbol", "Ground truth label.") .add_arguments(SoftmaxOutputParam::__FIELDS__()); -MXNET_REGISTER_OP_PROPERTY(Softmax, DeprecatedSoftmaxProp) -.describe(R"code(Please use `SoftmaxOutput`. -.. note:: +NNVM_REGISTER_OP(_backward_SoftmaxOutput) +.set_num_outputs(2) +.set_attr("FResourceRequest", [](const NodeAttrs& n) { + return std::vector{ResourceRequest::kTempSpace}; +}) +.set_attr("TIsBackward", true) +.set_attr("FInplaceOption", [](const NodeAttrs& attrs){ + return std::vector >{{0, 0}}; +}) +.set_attr_parser(ParamParser) +.set_attr("FCompute", SoftmaxOutputGradCompute); - This operator has been renamed to `SoftmaxOutput`, which - computes the gradient of cross-entropy loss w.r.t softmax output. - To just compute softmax output, use the `softmax` operator. 
- -)code" ADD_FILELINE) -.add_argument("data", "NDArray-or-Symbol", "Input array.") -.add_arguments(SoftmaxOutputParam::__FIELDS__()); } // namespace op } // namespace mxnet + + From 8cc2eec106e1153536f0bf604737abcaace800da Mon Sep 17 00:00:00 2001 From: rongzha1 Date: Thu, 20 Dec 2018 22:10:40 +0800 Subject: [PATCH 02/16] fix gpu OP unittest error --- src/operator/softmax_output.cu | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/operator/softmax_output.cu b/src/operator/softmax_output.cu index afcc8f4fc6bd..b2a41672e92a 100644 --- a/src/operator/softmax_output.cu +++ b/src/operator/softmax_output.cu @@ -28,14 +28,12 @@ namespace mxnet { namespace op { -template<> -Operator *CreateOp(SoftmaxOutputParam param, int dtype) { - Operator *op = NULL; - MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - op = new SoftmaxOutputOp(param); - }) - return op; -} + +NNVM_REGISTER_OP(SoftmaxOutput) +.set_attr("FCompute", SoftmaxOutputCompute); + +NNVM_REGISTER_OP(_backward_SoftmaxOutput) +.set_attr("FCompute", SoftmaxOutputGradCompute); } // namespace op } // namespace mxnet From ea81bbdec86d0de6f11fa36c787e6369e921b018 Mon Sep 17 00:00:00 2001 From: rongzha1 Date: Fri, 21 Dec 2018 09:46:24 +0800 Subject: [PATCH 03/16] fix ci/jenkins/mxnet-validation/unix-gpu compiler error --- src/operator/softmax_output.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/softmax_output.cc b/src/operator/softmax_output.cc index 2b2ef6152bc1..34ad673d2d0d 100644 --- a/src/operator/softmax_output.cc +++ b/src/operator/softmax_output.cc @@ -25,7 +25,7 @@ */ #include "./softmax_output-inl.h" #if MXNET_USE_MKLDNN == 1 -#include "./mkldnn_ops-inl.h" +#include "./nn/mkldnn/mkldnn_ops-inl.h" #endif namespace mxnet { namespace op { From 403dc37ffc86dcd16bf5752d2d37837a43eba48e Mon Sep 17 00:00:00 2001 From: rongzha1 Date: Tue, 25 Dec 2018 11:07:54 +0800 Subject: [PATCH 04/16] fix coding style --- src/operator/nn/mkldnn/mkldnn_ops-inl.h | 10 +- .../nn/mkldnn/mkldnn_softmax_output.cc | 67 +++--- src/operator/softmax_output-inl.h | 27 +-- src/operator/softmax_output.cc | 206 +++++++++--------- 4 files changed, 147 insertions(+), 163 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_ops-inl.h b/src/operator/nn/mkldnn/mkldnn_ops-inl.h index 6640fd49b52f..39f26325b2a5 100644 --- a/src/operator/nn/mkldnn/mkldnn_ops-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_ops-inl.h @@ -75,10 +75,12 @@ void MKLDNNDeconvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext & void MKLDNNSoftmaxForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const NDArray &in_data, const OpReqType &req, const NDArray &out_data); + /* For softmax_output */ void MKLDNNSoftmaxOutputForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, - const std::vector &in_data, const std::vector &req, - const std::vector &out_data); + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data); /* For sum */ void MKLDNNSumForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, @@ -87,8 +89,8 @@ void MKLDNNSumForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, /* For copy */ void MKLDNNCopy(const nnvm::NodeAttrs& attrs, const OpContext &ctx, - const NDArray &in_data, const OpReqType &req, - const NDArray &out_data); + const NDArray &in_data, const OpReqType &req, + const NDArray &out_data); /* For concat */ void MKLDNNConcatForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, diff --git a/src/operator/nn/mkldnn/mkldnn_softmax_output.cc 
b/src/operator/nn/mkldnn/mkldnn_softmax_output.cc index a896bbfb13e8..20a576a4e234 100644 --- a/src/operator/nn/mkldnn/mkldnn_softmax_output.cc +++ b/src/operator/nn/mkldnn/mkldnn_softmax_output.cc @@ -32,8 +32,8 @@ namespace mxnet { namespace op { static mkldnn::softmax_forward::primitive_desc GetSoftmaxOutputFwdDescImpl( - const SoftmaxOutputParam& param, bool is_train, - const NDArray &data, const mkldnn::memory &input_mem) { + const SoftmaxOutputParam& param, bool is_train, + const NDArray &data, const mkldnn::memory &input_mem) { mkldnn::memory::primitive_desc data_mpd = input_mem.get_primitive_desc(); mkldnn::memory::desc data_md = data_mpd.desc(); auto cpu_engine = data_mpd.get_engine(); @@ -57,38 +57,39 @@ class MKLDNNSoftmaxOutputFwd { const mkldnn::softmax_forward::primitive_desc fwd_pd; MKLDNNSoftmaxOutputFwd(const SoftmaxOutputParam& param, bool is_train, - const NDArray &data, const mkldnn::memory &mem): fwd_pd( - GetSoftmaxOutputFwdDescImpl(param, is_train, data, mem)) { + const NDArray &data, const mkldnn::memory &mem): fwd_pd( + GetSoftmaxOutputFwdDescImpl(param, is_train, data, mem)) { } void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &output) { - if (this->data == nullptr) - this->data = std::shared_ptr(new mkldnn::memory( + if (this->data == nullptr) + this->data = std::shared_ptr(new mkldnn::memory( data.get_primitive_desc(), data.get_data_handle())); - else - this->data->set_data_handle(data.get_data_handle()); - - if (this->out == nullptr) - this->out = std::shared_ptr(new mkldnn::memory( - output.get_primitive_desc(), output.get_data_handle())); - else - this->out->set_data_handle(output.get_data_handle()); - - if (this->fwd == nullptr) { - this->fwd = std::shared_ptr( - new mkldnn::softmax_forward(fwd_pd, mkldnn::primitive::at(*this->data), - *this->out)); - } + else + this->data->set_data_handle(data.get_data_handle()); + + if (this->out == nullptr) + this->out = std::shared_ptr(new mkldnn::memory( + output.get_primitive_desc(), output.get_data_handle())); + else + this->out->set_data_handle(output.get_data_handle()); + + if (this->fwd == nullptr) { + this->fwd = std::shared_ptr( + new mkldnn::softmax_forward(fwd_pd, mkldnn::primitive::at(*this->data), + *this->out)); + } } const mkldnn::softmax_forward &GetFwd() const { - return *fwd; + return *fwd; } }; static MKLDNNSoftmaxOutputFwd &GetSoftmaxOutputForward(const SoftmaxOutputParam& param, - const OpContext &ctx, const NDArray &in_data, - const mkldnn::memory &in_mem) { + const OpContext &ctx, + const NDArray &in_data, + const mkldnn::memory &in_mem) { #if DMLC_CXX11_THREAD_LOCAL static thread_local std::unordered_map fwds; @@ -102,11 +103,8 @@ static MKLDNNSoftmaxOutputFwd &GetSoftmaxOutputForward(const SoftmaxOutputParam& auto it = fwds.find(key); if (it == fwds.end()) { - MKLDNNSoftmaxOutputFwd fwd(param, ctx.is_train, in_data, in_mem); - auto ins_ret = fwds.insert(std::pair( - key, fwd)); - CHECK(ins_ret.second); - it = ins_ret.first; + MKLDNNSoftmaxOutputFwd fwd(param, ctx.is_train, in_data, in_mem); + it = AddToCache(&fwds, key, fwd); } return it->second; } @@ -116,9 +114,11 @@ bool SupportMKLDNNSoftmaxOutput(const SoftmaxOutputParam ¶m) { return param.multi_output ? 
false : true; } -void MKLDNNSoftmaxOutputForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, - const std::vector &in_data, const std::vector &req, - const std::vector &out_data) { +void MKLDNNSoftmaxOutputForward(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data) { const SoftmaxOutputParam ¶m = nnvm::get(attrs.parsed); NDArray idata = in_data[softmaxout_enum::kData]; @@ -142,8 +142,3 @@ void MKLDNNSoftmaxOutputForward(const nnvm::NodeAttrs& attrs, const OpContext &c #endif - - - - - diff --git a/src/operator/softmax_output-inl.h b/src/operator/softmax_output-inl.h index e64b011daf93..c38fc018afee 100644 --- a/src/operator/softmax_output-inl.h +++ b/src/operator/softmax_output-inl.h @@ -88,6 +88,7 @@ struct SoftmaxOutputParam : public dmlc::Parameter { "one-hot encoding of the gold label and distributed uniformly to" "all other labels."); }; + bool operator==(const SoftmaxOutputParam& other) const { return this->grad_scale == other.grad_scale && this->ignore_label == other.ignore_label && @@ -277,32 +278,28 @@ class SoftmaxOutputOp : public Operator { SoftmaxOutputParam param_; }; // class SoftmaxOutputOp -// Decalre Factory function, used for dispatch specialization -template -Operator* CreateOp(SoftmaxOutputParam param, int dtype); - template void SoftmaxOutputCompute(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { + const OpContext& ctx, const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { const SoftmaxOutputParam ¶m = nnvm::get(attrs.parsed); const std::vector no_use_but_adapt_origin_api; CHECK_EQ(inputs.size(), 2U); std::vector in_data(inputs.begin(), inputs.begin() + softmaxout_enum::kLabel); MSHADOW_REAL_TYPE_SWITCH(inputs[softmaxout_enum::kData].type_flag_, DType, { - SoftmaxOutputOpop(param); - op.Forward(ctx, inputs, req, outputs, no_use_but_adapt_origin_api); + SoftmaxOutputOp op(param); + op.Forward(ctx, inputs, req, outputs, no_use_but_adapt_origin_api); }); } template void SoftmaxOutputGradCompute(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { const SoftmaxOutputParam& param = nnvm::get(attrs.parsed); const std::vector no_use_but_adapt_origin_api; CHECK_EQ(inputs.size(), 2U); @@ -314,8 +311,8 @@ void SoftmaxOutputGradCompute(const nnvm::NodeAttrs& attrs, const std::vector &in_grad = outputs; MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - SoftmaxOutputOpop(param); - op.Backward(ctx, out_grad, in_data, out_data, req, in_grad, no_use_but_adapt_origin_api); + SoftmaxOutputOp op(param); + op.Backward(ctx, out_grad, in_data, out_data, req, in_grad, no_use_but_adapt_origin_api); }); } diff --git a/src/operator/softmax_output.cc b/src/operator/softmax_output.cc index 34ad673d2d0d..6109951af91f 100644 --- a/src/operator/softmax_output.cc +++ b/src/operator/softmax_output.cc @@ -58,76 +58,77 @@ struct SoftmaxOutputGrad { }; static inline std::vector ListArguments() { - return {"data", "label"}; + return {"data", "label"}; } static bool SoftmaxOutputType(const nnvm::NodeAttrs& attrs, - std::vector *in_type, std::vector *out_type) { - CHECK_GE(in_type->size(), 2U); - int dtype = (*in_type)[0]; - CHECK_NE(dtype, -1) << "First input must have specified type"; - for (size_t i = 
0; i < in_type->size(); ++i) { - if ((*in_type)[i] == -1) { - (*in_type)[i] = dtype; - } else { - UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments()[i]); - } + std::vector *in_type, + std::vector *out_type) { + CHECK_GE(in_type->size(), 2U); + int dtype = (*in_type)[0]; + CHECK_NE(dtype, -1) << "First input must have specified type"; + for (size_t i = 0; i < in_type->size(); ++i) { + if ((*in_type)[i] == -1) { + (*in_type)[i] = dtype; + } else { + UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments()[i]); } - out_type->clear(); - out_type->push_back(dtype); - return true; + } + out_type->clear(); + out_type->push_back(dtype); + return true; } static bool SoftmaxOutputShape(const nnvm::NodeAttrs& attrs, - std::vector *in_shape, - std::vector *out_shape) { - using namespace mshadow; + std::vector *in_shape, + std::vector *out_shape) { + using namespace mshadow; const SoftmaxOutputParam& param = nnvm::get(attrs.parsed); - CHECK_EQ(in_shape->size(), 2U) << "Input:[data, label]"; - const TShape &dshape = in_shape->at(0); - if (dshape.ndim() == 0) return false; - - // label.shape == data.shape: use probability as label - if (dshape != (*in_shape)[softmaxout_enum::kLabel]) { - if (param.multi_output) { - TShape lshape1 = Shape2(dshape[0], dshape.Size()/dshape[0]/dshape[1]); - TShape lshape2(dshape.ndim() - 1); - lshape2[0] = dshape[0]; - for (index_t i = 2; i < dshape.ndim(); ++i) - lshape2[i-1] = dshape[i]; - TShape lshape3 = dshape; - lshape3[1] = 1; - if (in_shape->at(softmaxout_enum::kLabel).ndim() == 0) { - in_shape->at(softmaxout_enum::kLabel) = lshape1; - } else if (in_shape->at(softmaxout_enum::kLabel) == lshape1) { - } else if (in_shape->at(softmaxout_enum::kLabel) == lshape2) { - } else if (in_shape->at(softmaxout_enum::kLabel) == lshape3) { - } else { - std::ostringstream os; - os << "Expecting " << lshape1 << " or " << lshape2 - << ". But got " << in_shape->at(softmaxout_enum::kLabel); - throw InferShapeError(os.str(), softmaxout_enum::kLabel); - } + CHECK_EQ(in_shape->size(), 2U) << "Input:[data, label]"; + const TShape &dshape = in_shape->at(0); + if (dshape.ndim() == 0) return false; + + // label.shape == data.shape: use probability as label + if (dshape != (*in_shape)[softmaxout_enum::kLabel]) { + if (param.multi_output) { + TShape lshape1 = Shape2(dshape[0], dshape.Size()/dshape[0]/dshape[1]); + TShape lshape2(dshape.ndim() - 1); + lshape2[0] = dshape[0]; + for (index_t i = 2; i < dshape.ndim(); ++i) + lshape2[i-1] = dshape[i]; + TShape lshape3 = dshape; + lshape3[1] = 1; + if (in_shape->at(softmaxout_enum::kLabel).ndim() == 0) { + in_shape->at(softmaxout_enum::kLabel) = lshape1; + } else if (in_shape->at(softmaxout_enum::kLabel) == lshape1) { + } else if (in_shape->at(softmaxout_enum::kLabel) == lshape2) { + } else if (in_shape->at(softmaxout_enum::kLabel) == lshape3) { } else { - TShape label_shape(dshape.ndim() - 1); - for (index_t i = 0; i + 1 < dshape.ndim(); ++i) - label_shape[i] = dshape[i]; - SHAPE_ASSIGN_CHECK(*in_shape, softmaxout_enum::kLabel, label_shape); + std::ostringstream os; + os << "Expecting " << lshape1 << " or " << lshape2 + << ". 
But got " << in_shape->at(softmaxout_enum::kLabel); + throw InferShapeError(os.str(), softmaxout_enum::kLabel); } + } else { + TShape label_shape(dshape.ndim() - 1); + for (index_t i = 0; i + 1 < dshape.ndim(); ++i) + label_shape[i] = dshape[i]; + SHAPE_ASSIGN_CHECK(*in_shape, softmaxout_enum::kLabel, label_shape); } + } - out_shape->clear(); - out_shape->push_back(dshape); - return true; + out_shape->clear(); + out_shape->push_back(dshape); + return true; } #if MXNET_USE_MKLDNN == 1 inline static bool SoftmaxOutputStorageType(const nnvm::NodeAttrs& attrs, - const int dev_mask, - DispatchMode* dispatch_mode, - std::vector* in_attrs, - std::vector* out_attrs) { + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector* in_attrs, + std::vector* out_attrs) { CHECK_EQ(in_attrs->size(), 2); CHECK_EQ(out_attrs->size(), 1); @@ -136,10 +137,10 @@ inline static bool SoftmaxOutputStorageType(const nnvm::NodeAttrs& attrs, } void SoftmaxOutputComputeExCPU(const nnvm::NodeAttrs &attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { CHECK_EQ(inputs.size(), 2U); const SoftmaxOutputParam ¶m = nnvm::get(attrs.parsed); // MKLDNN softmaxOutput only works well on the special MKLDNN layout. @@ -156,7 +157,6 @@ void SoftmaxOutputComputeExCPU(const nnvm::NodeAttrs &attrs, #endif NNVM_REGISTER_OP(SoftmaxOutput) -MXNET_ADD_SPARSE_OP_ALIAS(SoftmaxOutput) // to be check .describe(R"code(Computes the gradient of cross entropy loss with respect to softmax output. - This operator computes the gradient in two steps. @@ -169,15 +169,15 @@ MXNET_ADD_SPARSE_OP_ALIAS(SoftmaxOutput) // to be check - Softmax Function: - .. math:: \text{softmax}(x)_i = \frac{exp(x_i)}{\sum_j exp(x_j)} + .. math:: \text{softmax}(x)_i = \frac{exp(x_i)}{\sum_j exp(x_j)} - Cross Entropy Function: - .. math:: \text{CE(label, output)} = - \sum_i \text{label}_i \log(\text{output}_i) + .. math:: \text{CE(label, output)} = - \sum_i \text{label}_i \log(\text{output}_i) - The gradient of cross entropy loss w.r.t softmax output: - .. math:: \text{gradient} = \text{output} - \text{label} + .. math:: \text{gradient} = \text{output} - \text{label} - During forward propagation, the softmax function is computed for each instance in the input array. @@ -186,51 +186,51 @@ MXNET_ADD_SPARSE_OP_ALIAS(SoftmaxOutput) // to be check and `multi_output` to specify the way to compute softmax: - By default, `preserve_shape` is ``false``. This operator will reshape the input array - into a 2-D array with shape :math:`(d_1, \frac{s}{d_1})` and then compute the softmax function for - each row in the reshaped array, and afterwards reshape it back to the original shape - :math:`(d_1, d_2, ..., d_n)`. + into a 2-D array with shape :math:`(d_1, \frac{s}{d_1})` and then compute the softmax function for + each row in the reshaped array, and afterwards reshape it back to the original shape + :math:`(d_1, d_2, ..., d_n)`. - If `preserve_shape` is ``true``, the softmax function will be computed along - the last axis (`axis` = ``-1``). + the last axis (`axis` = ``-1``). - If `multi_output` is ``true``, the softmax function will be computed along - the second axis (`axis` = ``1``). + the second axis (`axis` = ``1``). - During backward propagation, the gradient of cross-entropy loss w.r.t softmax output array is computed. The provided label can be a one-hot label array or a probability label array. 
- If the parameter `use_ignore` is ``true``, `ignore_label` can specify input instances - with a particular label to be ignored during backward propagation. **This has no effect when - softmax `output` has same shape as `label`**. - - Example:: - - data = [[1,2,3,4],[2,2,2,2],[3,3,3,3],[4,4,4,4]] - label = [1,0,2,3] - ignore_label = 1 - SoftmaxOutput(data=data, label = label,\ - multi_output=true, use_ignore=true,\ - ignore_label=ignore_label) - ## forward softmax output - [[ 0.0320586 0.08714432 0.23688284 0.64391428] - [ 0.25 0.25 0.25 0.25 ] - [ 0.25 0.25 0.25 0.25 ] - [ 0.25 0.25 0.25 0.25 ]] - ## backward gradient output - [[ 0. 0. 0. 0. ] - [-0.75 0.25 0.25 0.25] - [ 0.25 0.25 -0.75 0.25] - [ 0.25 0.25 0.25 -0.75]] - ## notice that the first row is all 0 because label[0] is 1, which is equal to ignore_label. + with a particular label to be ignored during backward propagation. **This has no effect when + softmax `output` has same shape as `label`**. + + Example:: + + data = [[1,2,3,4],[2,2,2,2],[3,3,3,3],[4,4,4,4]] + label = [1,0,2,3] + ignore_label = 1 + SoftmaxOutput(data=data, label = label,\ + multi_output=true, use_ignore=true,\ + ignore_label=ignore_label) + ## forward softmax output + [[ 0.0320586 0.08714432 0.23688284 0.64391428] + [ 0.25 0.25 0.25 0.25 ] + [ 0.25 0.25 0.25 0.25 ] + [ 0.25 0.25 0.25 0.25 ]] + ## backward gradient output + [[ 0. 0. 0. 0. ] + [-0.75 0.25 0.25 0.25] + [ 0.25 0.25 -0.75 0.25] + [ 0.25 0.25 0.25 -0.75]] + ## notice that the first row is all 0 because label[0] is 1, which is equal to ignore_label. - The parameter `grad_scale` can be used to rescale the gradient, which is often used to - give each loss function different weights. + give each loss function different weights. - This operator also supports various ways to normalize the gradient by `normalization`, - The `normalization` is applied if softmax output has different shape than the labels. - The `normalization` mode can be set to the followings: + The `normalization` is applied if softmax output has different shape than the labels. + The `normalization` mode can be set to the followings: - - ``'null'``: do nothing. - - ``'batch'``: divide the gradient by the batch size. - - ``'valid'``: divide the gradient by the number of instances which are not ignored. + - ``'null'``: do nothing. + - ``'batch'``: divide the gradient by the batch size. + - ``'valid'``: divide the gradient by the number of instances which are not ignored. 
)code" ADD_FILELINE) .set_num_inputs(2) @@ -240,24 +240,18 @@ MXNET_ADD_SPARSE_OP_ALIAS(SoftmaxOutput) // to be check .set_attr("FInferStorageType", SoftmaxOutputStorageType) #endif .set_attr("FListInputNames", [](const NodeAttrs& attrs) { - return std::vector{"data", "label"}; + return std::vector{"data", "label"}; }) -.set_attr("FListOutputNames", - [](const NodeAttrs& attrs) { - return std::vector{"output"}; +.set_attr("FListOutputNames", [](const NodeAttrs& attrs) { + return std::vector{"output"}; }) #if MXNET_USE_MKLDNN == 1 .set_attr("TIsMKLDNN", true) -.set_attr("FResourceRequest", [](const NodeAttrs& n) { - return std::vector{ResourceRequest::kTempSpace}; -}) +.set_attr("FComputeEx", SoftmaxOutputComputeExCPU) #endif .set_attr("FInferShape", SoftmaxOutputShape) .set_attr("FInferType", SoftmaxOutputType) .set_attr("FCompute", SoftmaxOutputCompute) -#if MXNET_USE_MKLDNN == 1 -.set_attr("FComputeEx", SoftmaxOutputComputeExCPU) -#endif .set_attr("FGradient", SoftmaxOutputGrad{"_backward_SoftmaxOutput"}) .set_attr("FInplaceOption", [](const NodeAttrs& attrs){ return std::vector >{{0, 0}}; @@ -270,9 +264,6 @@ MXNET_ADD_SPARSE_OP_ALIAS(SoftmaxOutput) // to be check NNVM_REGISTER_OP(_backward_SoftmaxOutput) .set_num_outputs(2) -.set_attr("FResourceRequest", [](const NodeAttrs& n) { - return std::vector{ResourceRequest::kTempSpace}; -}) .set_attr("TIsBackward", true) .set_attr("FInplaceOption", [](const NodeAttrs& attrs){ return std::vector >{{0, 0}}; @@ -280,7 +271,6 @@ NNVM_REGISTER_OP(_backward_SoftmaxOutput) .set_attr_parser(ParamParser) .set_attr("FCompute", SoftmaxOutputGradCompute); - } // namespace op } // namespace mxnet From dbfff62273eec29f42fe5de48920496b139c2a44 Mon Sep 17 00:00:00 2001 From: rongzha1 Date: Tue, 25 Dec 2018 14:34:50 +0800 Subject: [PATCH 05/16] fix Tao comments --- src/operator/softmax_output-inl.h | 17 ++++++++--------- src/operator/softmax_output.cc | 14 +++----------- 2 files changed, 11 insertions(+), 20 deletions(-) diff --git a/src/operator/softmax_output-inl.h b/src/operator/softmax_output-inl.h index c38fc018afee..37dd917c0e03 100644 --- a/src/operator/softmax_output-inl.h +++ b/src/operator/softmax_output-inl.h @@ -91,13 +91,13 @@ struct SoftmaxOutputParam : public dmlc::Parameter { bool operator==(const SoftmaxOutputParam& other) const { return this->grad_scale == other.grad_scale && - this->ignore_label == other.ignore_label && - this->multi_output == other.multi_output && - this->use_ignore == other.use_ignore && - this->preserve_shape == other.preserve_shape && - this->normalization == other.normalization && - this->out_grad == other.out_grad && - this->smooth_alpha == other.smooth_alpha; + this->ignore_label == other.ignore_label && + this->multi_output == other.multi_output && + this->use_ignore == other.use_ignore && + this->preserve_shape == other.preserve_shape && + this->normalization == other.normalization && + this->out_grad == other.out_grad && + this->smooth_alpha == other.smooth_alpha; } }; @@ -286,8 +286,7 @@ void SoftmaxOutputCompute(const nnvm::NodeAttrs& attrs, const SoftmaxOutputParam ¶m = nnvm::get(attrs.parsed); const std::vector no_use_but_adapt_origin_api; CHECK_EQ(inputs.size(), 2U); - std::vector in_data(inputs.begin(), - inputs.begin() + softmaxout_enum::kLabel); + MSHADOW_REAL_TYPE_SWITCH(inputs[softmaxout_enum::kData].type_flag_, DType, { SoftmaxOutputOp op(param); op.Forward(ctx, inputs, req, outputs, no_use_but_adapt_origin_api); diff --git a/src/operator/softmax_output.cc b/src/operator/softmax_output.cc index 
6109951af91f..76e9ba163729 100644 --- a/src/operator/softmax_output.cc +++ b/src/operator/softmax_output.cc @@ -147,8 +147,7 @@ void SoftmaxOutputComputeExCPU(const nnvm::NodeAttrs &attrs, if (SupportMKLDNN(inputs[0]) && !ctx.is_train && SupportMKLDNNSoftmaxOutput(param)) { MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs); MKLDNNSoftmaxOutputForward(attrs, ctx, inputs, req, outputs); - auto fn = SoftmaxOutputCompute; - MKLDNN_OPCHECK_RUN(fn, attrs, ctx, inputs, req, outputs); + MKLDNN_OPCHECK_RUN(SoftmaxOutputCompute, attrs, ctx, inputs, req, outputs); return; } FallBackCompute(SoftmaxOutputCompute, attrs, ctx, inputs, req, outputs); @@ -238,6 +237,8 @@ NNVM_REGISTER_OP(SoftmaxOutput) .set_attr_parser(ParamParser) #if MXNET_USE_MKLDNN == 1 .set_attr("FInferStorageType", SoftmaxOutputStorageType) +.set_attr("TIsMKLDNN", true) +.set_attr("FComputeEx", SoftmaxOutputComputeExCPU) #endif .set_attr("FListInputNames", [](const NodeAttrs& attrs) { return std::vector{"data", "label"}; @@ -245,10 +246,6 @@ NNVM_REGISTER_OP(SoftmaxOutput) .set_attr("FListOutputNames", [](const NodeAttrs& attrs) { return std::vector{"output"}; }) -#if MXNET_USE_MKLDNN == 1 -.set_attr("TIsMKLDNN", true) -.set_attr("FComputeEx", SoftmaxOutputComputeExCPU) -#endif .set_attr("FInferShape", SoftmaxOutputShape) .set_attr("FInferType", SoftmaxOutputType) .set_attr("FCompute", SoftmaxOutputCompute) @@ -260,8 +257,6 @@ NNVM_REGISTER_OP(SoftmaxOutput) .add_argument("label", "NDArray-or-Symbol", "Ground truth label.") .add_arguments(SoftmaxOutputParam::__FIELDS__()); - - NNVM_REGISTER_OP(_backward_SoftmaxOutput) .set_num_outputs(2) .set_attr("TIsBackward", true) @@ -270,8 +265,5 @@ NNVM_REGISTER_OP(_backward_SoftmaxOutput) }) .set_attr_parser(ParamParser) .set_attr("FCompute", SoftmaxOutputGradCompute); - } // namespace op } // namespace mxnet - - From 6e4e94626a50cb5daadf133734b2990aa121f4be Mon Sep 17 00:00:00 2001 From: rongzha1 Date: Wed, 26 Dec 2018 10:20:04 +0800 Subject: [PATCH 06/16] remove blank line, fix indentx --- src/operator/nn/mkldnn/mkldnn_softmax_output.cc | 4 ---- src/operator/softmax_output-inl.h | 14 +++++++------- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_softmax_output.cc b/src/operator/nn/mkldnn/mkldnn_softmax_output.cc index 20a576a4e234..f2ffeb958845 100644 --- a/src/operator/nn/mkldnn/mkldnn_softmax_output.cc +++ b/src/operator/nn/mkldnn/mkldnn_softmax_output.cc @@ -109,7 +109,6 @@ static MKLDNNSoftmaxOutputFwd &GetSoftmaxOutputForward(const SoftmaxOutputParam& return it->second; } - bool SupportMKLDNNSoftmaxOutput(const SoftmaxOutputParam ¶m) { return param.multi_output ? 
false : true; } @@ -136,9 +135,6 @@ void MKLDNNSoftmaxOutputForward(const nnvm::NodeAttrs& attrs, stream->RegisterPrim(fwd.GetFwd()); stream->Submit(); } - } // namespace op } // namespace mxnet - - #endif diff --git a/src/operator/softmax_output-inl.h b/src/operator/softmax_output-inl.h index 37dd917c0e03..5a01d3a73a95 100644 --- a/src/operator/softmax_output-inl.h +++ b/src/operator/softmax_output-inl.h @@ -91,13 +91,13 @@ struct SoftmaxOutputParam : public dmlc::Parameter { bool operator==(const SoftmaxOutputParam& other) const { return this->grad_scale == other.grad_scale && - this->ignore_label == other.ignore_label && - this->multi_output == other.multi_output && - this->use_ignore == other.use_ignore && - this->preserve_shape == other.preserve_shape && - this->normalization == other.normalization && - this->out_grad == other.out_grad && - this->smooth_alpha == other.smooth_alpha; + this->ignore_label == other.ignore_label && + this->multi_output == other.multi_output && + this->use_ignore == other.use_ignore && + this->preserve_shape == other.preserve_shape && + this->normalization == other.normalization && + this->out_grad == other.out_grad && + this->smooth_alpha == other.smooth_alpha; } }; From 38ec7088d513195b35b15889a3416f89b13c8b2f Mon Sep 17 00:00:00 2001 From: rongzha1 Date: Thu, 27 Dec 2018 13:51:45 +0800 Subject: [PATCH 07/16] modify according to sandeep's comments --- src/operator/nn/mkldnn/mkldnn_softmax_output.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/operator/nn/mkldnn/mkldnn_softmax_output.cc b/src/operator/nn/mkldnn/mkldnn_softmax_output.cc index f2ffeb958845..13f26b9060fe 100644 --- a/src/operator/nn/mkldnn/mkldnn_softmax_output.cc +++ b/src/operator/nn/mkldnn/mkldnn_softmax_output.cc @@ -19,7 +19,7 @@ /*! * \file mkldnn_softmax_output.cc - * \brief + * \brief integrate mkldnn softmax to softmax_output forward * \author Zhang Rong A */ @@ -37,6 +37,7 @@ static mkldnn::softmax_forward::primitive_desc GetSoftmaxOutputFwdDescImpl( mkldnn::memory::primitive_desc data_mpd = input_mem.get_primitive_desc(); mkldnn::memory::desc data_md = data_mpd.desc(); auto cpu_engine = data_mpd.get_engine(); + // softmax_output has no axis parameter, so use it as it original implement. int axis = data.shape().ndim() - 1; mkldnn::softmax_forward::desc desc = is_train ? mkldnn::softmax_forward::desc(mkldnn::prop_kind::forward_training, From 37e9c11a8f15f607f5cb5f39101fea810dcdb4ad Mon Sep 17 00:00:00 2001 From: rongzha1 Date: Fri, 28 Dec 2018 11:03:50 +0800 Subject: [PATCH 08/16] change get CPU engine method, and pravate variable --- .../nn/mkldnn/mkldnn_softmax_output.cc | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_softmax_output.cc b/src/operator/nn/mkldnn/mkldnn_softmax_output.cc index 13f26b9060fe..ed42addeb63e 100644 --- a/src/operator/nn/mkldnn/mkldnn_softmax_output.cc +++ b/src/operator/nn/mkldnn/mkldnn_softmax_output.cc @@ -36,7 +36,7 @@ static mkldnn::softmax_forward::primitive_desc GetSoftmaxOutputFwdDescImpl( const NDArray &data, const mkldnn::memory &input_mem) { mkldnn::memory::primitive_desc data_mpd = input_mem.get_primitive_desc(); mkldnn::memory::desc data_md = data_mpd.desc(); - auto cpu_engine = data_mpd.get_engine(); + auto cpu_engine = CpuEngine::Get()->get_engine(); // softmax_output has no axis parameter, so use it as it original implement. 
int axis = data.shape().ndim() - 1; mkldnn::softmax_forward::desc desc = is_train @@ -50,9 +50,9 @@ static mkldnn::softmax_forward::primitive_desc GetSoftmaxOutputFwdDescImpl( typedef ParamOpSign MKLDNNSoftmaxOuputSignature; class MKLDNNSoftmaxOutputFwd { - std::shared_ptr fwd; - std::shared_ptr data; - std::shared_ptr out; + std::shared_ptr fwd_; + std::shared_ptr data_; + std::shared_ptr out_; public: const mkldnn::softmax_forward::primitive_desc fwd_pd; @@ -63,27 +63,27 @@ class MKLDNNSoftmaxOutputFwd { } void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &output) { - if (this->data == nullptr) - this->data = std::shared_ptr(new mkldnn::memory( + if (this->data_ == nullptr) + this->data_ = std::shared_ptr(new mkldnn::memory( data.get_primitive_desc(), data.get_data_handle())); else - this->data->set_data_handle(data.get_data_handle()); + this->data_->set_data_handle(data.get_data_handle()); - if (this->out == nullptr) - this->out = std::shared_ptr(new mkldnn::memory( + if (this->out_ == nullptr) + this->out_ = std::shared_ptr(new mkldnn::memory( output.get_primitive_desc(), output.get_data_handle())); else - this->out->set_data_handle(output.get_data_handle()); + this->out_->set_data_handle(output.get_data_handle()); - if (this->fwd == nullptr) { - this->fwd = std::shared_ptr( - new mkldnn::softmax_forward(fwd_pd, mkldnn::primitive::at(*this->data), - *this->out)); + if (this->fwd_ == nullptr) { + this->fwd_ = std::shared_ptr( + new mkldnn::softmax_forward(fwd_pd, mkldnn::primitive::at(*this->data_), + *this->out_)); } } const mkldnn::softmax_forward &GetFwd() const { - return *fwd; + return *fwd_; } }; From d7af1118b229b653c629be4f65e176c51ffc8962 Mon Sep 17 00:00:00 2001 From: rongzha1 Date: Tue, 8 Jan 2019 09:29:34 +0800 Subject: [PATCH 09/16] move macro MXNET_USE_MKLDNN to the head --- src/operator/nn/mkldnn/mkldnn_softmax_output.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/nn/mkldnn/mkldnn_softmax_output.cc b/src/operator/nn/mkldnn/mkldnn_softmax_output.cc index ed42addeb63e..0df65910656e 100644 --- a/src/operator/nn/mkldnn/mkldnn_softmax_output.cc +++ b/src/operator/nn/mkldnn/mkldnn_softmax_output.cc @@ -23,11 +23,11 @@ * \author Zhang Rong A */ +#if MXNET_USE_MKLDNN == 1 #include "../../softmax_output-inl.h" #include "./mkldnn_ops-inl.h" #include "./mkldnn_base-inl.h" -#if MXNET_USE_MKLDNN == 1 namespace mxnet { namespace op { From 58a23bda3ee402e3d52d40a50539e43e18e565ca Mon Sep 17 00:00:00 2001 From: rongzha1 Date: Fri, 11 Jan 2019 14:09:45 +0800 Subject: [PATCH 10/16] modify according to Tao's comments --- .../nn/mkldnn/mkldnn_softmax_output.cc | 27 +++++++++---------- src/operator/softmax_output.cc | 8 +++--- 2 files changed, 16 insertions(+), 19 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_softmax_output.cc b/src/operator/nn/mkldnn/mkldnn_softmax_output.cc index 0df65910656e..b7955391f945 100644 --- a/src/operator/nn/mkldnn/mkldnn_softmax_output.cc +++ b/src/operator/nn/mkldnn/mkldnn_softmax_output.cc @@ -33,17 +33,13 @@ namespace op { static mkldnn::softmax_forward::primitive_desc GetSoftmaxOutputFwdDescImpl( const SoftmaxOutputParam& param, bool is_train, - const NDArray &data, const mkldnn::memory &input_mem) { + const int axis, const mkldnn::memory &input_mem) { mkldnn::memory::primitive_desc data_mpd = input_mem.get_primitive_desc(); mkldnn::memory::desc data_md = data_mpd.desc(); auto cpu_engine = CpuEngine::Get()->get_engine(); - // softmax_output has no axis parameter, so use it as it original implement. 
- int axis = data.shape().ndim() - 1; - mkldnn::softmax_forward::desc desc = is_train - ? mkldnn::softmax_forward::desc(mkldnn::prop_kind::forward_training, - data_md, axis) - : mkldnn::softmax_forward::desc(mkldnn::prop_kind::forward_scoring, - data_md, axis); + auto prop = is_train ? mkldnn::prop_kind::forward_training + : mkldnn::prop_kind::forward_scoring; + auto desc = mkldnn::softmax_forward::desc(prop, data_md, axis); return mkldnn::softmax_forward::primitive_desc(desc, cpu_engine); } @@ -58,8 +54,8 @@ class MKLDNNSoftmaxOutputFwd { const mkldnn::softmax_forward::primitive_desc fwd_pd; MKLDNNSoftmaxOutputFwd(const SoftmaxOutputParam& param, bool is_train, - const NDArray &data, const mkldnn::memory &mem): fwd_pd( - GetSoftmaxOutputFwdDescImpl(param, is_train, data, mem)) { + const int axis, const mkldnn::memory &mem): fwd_pd( + GetSoftmaxOutputFwdDescImpl(param, is_train, axis, mem)) { } void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &output) { @@ -89,7 +85,7 @@ class MKLDNNSoftmaxOutputFwd { static MKLDNNSoftmaxOutputFwd &GetSoftmaxOutputForward(const SoftmaxOutputParam& param, const OpContext &ctx, - const NDArray &in_data, + const int axis, const mkldnn::memory &in_mem) { #if DMLC_CXX11_THREAD_LOCAL static thread_local @@ -100,16 +96,17 @@ static MKLDNNSoftmaxOutputFwd &GetSoftmaxOutputForward(const SoftmaxOutputParam& #endif MKLDNNSoftmaxOuputSignature key(param); key.AddSign(ctx.is_train); - key.AddSign(in_data); + key.AddSign(axis); auto it = fwds.find(key); if (it == fwds.end()) { - MKLDNNSoftmaxOutputFwd fwd(param, ctx.is_train, in_data, in_mem); + MKLDNNSoftmaxOutputFwd fwd(param, ctx.is_train, axis, in_mem); it = AddToCache(&fwds, key, fwd); } return it->second; } +// This is only used for forward. For backward ,need double check compatibility bool SupportMKLDNNSoftmaxOutput(const SoftmaxOutputParam ¶m) { return param.multi_output ? false : true; } @@ -123,6 +120,8 @@ void MKLDNNSoftmaxOutputForward(const nnvm::NodeAttrs& attrs, NDArray idata = in_data[softmaxout_enum::kData]; NDArray odata = out_data[softmaxout_enum::kOut]; + // softmax_output has no axis parameter, so use it as it original implement. 
+ int axis = idata.shape().ndim() - 1; if (in_data[softmaxout_enum::kData].IsView() && in_data[softmaxout_enum::kData].IsMKLDNNData()) { idata = in_data[softmaxout_enum::kData].Reorder2Default(); } @@ -130,7 +129,7 @@ void MKLDNNSoftmaxOutputForward(const nnvm::NodeAttrs& attrs, auto input_mem = idata.GetMKLDNNData(); auto output_mem = odata.GetMKLDNNData(); - MKLDNNSoftmaxOutputFwd &fwd = GetSoftmaxOutputForward(param, ctx, idata, *input_mem); + MKLDNNSoftmaxOutputFwd &fwd = GetSoftmaxOutputForward(param, ctx, axis, *input_mem); fwd.SetNewMem(*input_mem, *output_mem); MKLDNNStream *stream = MKLDNNStream::Get(); stream->RegisterPrim(fwd.GetFwd()); diff --git a/src/operator/softmax_output.cc b/src/operator/softmax_output.cc index 76e9ba163729..1b7de2d14033 100644 --- a/src/operator/softmax_output.cc +++ b/src/operator/softmax_output.cc @@ -50,9 +50,8 @@ struct SoftmaxOutputGrad { gnode->attrs.op = nnvm::Op::Get("_backward_SoftmaxOutput"); gnode->attrs.name = n->attrs.name + "_backward"; std::vector in_grad(2); - for (uint32_t i = 0; i < 2; ++i) { - in_grad[i] = nnvm::NodeEntry{gnode, i, 0}; - } + in_grad[0] = nnvm::NodeEntry{gnode, 0, 0}; + in_grad[1] = nnvm::NodeEntry{gnode, 1, 0}; return in_grad; } }; @@ -65,7 +64,7 @@ static inline std::vector ListArguments() { static bool SoftmaxOutputType(const nnvm::NodeAttrs& attrs, std::vector *in_type, std::vector *out_type) { - CHECK_GE(in_type->size(), 2U); + CHECK_EQ(in_type->size(), 2U); int dtype = (*in_type)[0]; CHECK_NE(dtype, -1) << "First input must have specified type"; for (size_t i = 0; i < in_type->size(); ++i) { @@ -143,7 +142,6 @@ void SoftmaxOutputComputeExCPU(const nnvm::NodeAttrs &attrs, const std::vector &outputs) { CHECK_EQ(inputs.size(), 2U); const SoftmaxOutputParam ¶m = nnvm::get(attrs.parsed); - // MKLDNN softmaxOutput only works well on the special MKLDNN layout. 
From bbf11d79895bc7ec2290776cec43ee29a918630c Mon Sep 17 00:00:00 2001
From: rongzha1
Date: Mon, 14 Jan 2019 16:08:55 +0800
Subject: [PATCH 11/16] make output layout as input

---
 src/operator/nn/mkldnn/mkldnn_softmax_output.cc | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/operator/nn/mkldnn/mkldnn_softmax_output.cc b/src/operator/nn/mkldnn/mkldnn_softmax_output.cc
index b7955391f945..4f8e1d99e3e3 100644
--- a/src/operator/nn/mkldnn/mkldnn_softmax_output.cc
+++ b/src/operator/nn/mkldnn/mkldnn_softmax_output.cc
@@ -96,7 +96,7 @@ static MKLDNNSoftmaxOutputFwd &GetSoftmaxOutputForward(const SoftmaxOutputParam&
 #endif
   MKLDNNSoftmaxOuputSignature key(param);
   key.AddSign(ctx.is_train);
-  key.AddSign(axis);
+  key.AddSign(in_mem);
 
   auto it = fwds.find(key);
   if (it == fwds.end()) {
@@ -127,10 +127,12 @@ void MKLDNNSoftmaxOutputForward(const nnvm::NodeAttrs& attrs,
   }
 
   auto input_mem = idata.GetMKLDNNData();
-  auto output_mem = odata.GetMKLDNNData();
+  auto out_mem = CreateMKLDNNMem(out_data[softmaxout_enum::kOut],
+      input_mem->get_primitive_desc(), req[softmaxout_enum::kOut]);
 
   MKLDNNSoftmaxOutputFwd &fwd = GetSoftmaxOutputForward(param, ctx, axis, *input_mem);
-  fwd.SetNewMem(*input_mem, *output_mem);
+  fwd.SetNewMem(*input_mem, *out_mem.second);
+
   MKLDNNStream *stream = MKLDNNStream::Get();
   stream->RegisterPrim(fwd.GetFwd());
   stream->Submit();
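Patch 11 stops asking the output NDArray for its own MKL-DNN memory and instead requests an output buffer described by the input's primitive descriptor, so the result keeps the (possibly blocked) layout of the input. Roughly, a helper like CreateMKLDNNMem hands back the destination's own buffer when the request and layout allow writing in place, and a temporary buffer otherwise. The sketch below illustrates that allocation decision with invented Buffer, MKLDNNOutput, and OpReq types; it is an assumption-level sketch, not MXNet's actual CreateMKLDNNMem.

    #include <memory>
    #include <vector>

    // Invented, simplified stand-ins for an NDArray buffer and a write request.
    enum OpReq { kWriteTo, kWriteInplace, kAddTo };

    struct Buffer {
      std::vector<float> data;
      int layout_tag;  // e.g. plain NCHW vs. a blocked layout
    };

    struct MKLDNNOutput {
      bool needs_commit;            // true when a temporary buffer was handed out
      std::shared_ptr<Buffer> mem;  // where the primitive should write its result
    };

    // Hand out the destination buffer directly when the request allows overwriting it
    // and the requested layout already matches; otherwise allocate a scratch buffer
    // that a later commit step copies or accumulates back into the destination.
    MKLDNNOutput CreateOutput(Buffer *dst, int requested_layout, OpReq req) {
      if (req != kAddTo && dst->layout_tag == requested_layout) {
        return {false, std::shared_ptr<Buffer>(dst, [](Buffer *) {})};  // non-owning view
      }
      auto tmp = std::make_shared<Buffer>();
      tmp->data.resize(dst->data.size());
      tmp->layout_tag = requested_layout;
      return {true, tmp};
    }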
From 3420b61146129aec85c91bcdbc20a9c3ce511c74 Mon Sep 17 00:00:00 2001
From: rongzha1
Date: Tue, 15 Jan 2019 10:20:11 +0800
Subject: [PATCH 12/16] change API of GetSoftmaxOutputForward

---
 src/operator/nn/mkldnn/mkldnn_softmax_output.cc | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/src/operator/nn/mkldnn/mkldnn_softmax_output.cc b/src/operator/nn/mkldnn/mkldnn_softmax_output.cc
index 4f8e1d99e3e3..17918dabbad3 100644
--- a/src/operator/nn/mkldnn/mkldnn_softmax_output.cc
+++ b/src/operator/nn/mkldnn/mkldnn_softmax_output.cc
@@ -85,8 +85,7 @@ class MKLDNNSoftmaxOutputFwd {
 
 static MKLDNNSoftmaxOutputFwd &GetSoftmaxOutputForward(const SoftmaxOutputParam& param,
                                                        const OpContext &ctx,
-                                                       const int axis,
-                                                       const mkldnn::memory &in_mem) {
+                                                       const NDArray &in_data) {
 #if DMLC_CXX11_THREAD_LOCAL
   static thread_local
       std::unordered_map<MKLDNNSoftmaxOuputSignature, MKLDNNSoftmaxOutputFwd, OpHash> fwds;
@@ -96,10 +95,14 @@ static MKLDNNSoftmaxOutputFwd &GetSoftmaxOutputForward(const SoftmaxOutputParam&
 #endif
   MKLDNNSoftmaxOuputSignature key(param);
   key.AddSign(ctx.is_train);
-  key.AddSign(in_mem);
+  key.AddSign(in_data);
+
+  // softmax_output has no axis parameter, so use the last axis, as in the original implementation.
+  int axis = in_data.shape().ndim() - 1;
 
   auto it = fwds.find(key);
   if (it == fwds.end()) {
+    auto in_mem = *(in_data.GetMKLDNNData());
     MKLDNNSoftmaxOutputFwd fwd(param, ctx.is_train, axis, in_mem);
     it = AddToCache(&fwds, key, fwd);
   }
@@ -120,8 +123,6 @@ void MKLDNNSoftmaxOutputForward(const nnvm::NodeAttrs& attrs,
   NDArray idata = in_data[softmaxout_enum::kData];
   NDArray odata = out_data[softmaxout_enum::kOut];
-  // softmax_output has no axis parameter, so use the last axis, as in the original implementation.
-  int axis = idata.shape().ndim() - 1;
   if (in_data[softmaxout_enum::kData].IsView() && in_data[softmaxout_enum::kData].IsMKLDNNData()) {
     idata = in_data[softmaxout_enum::kData].Reorder2Default();
   }
@@ -130,7 +131,7 @@ void MKLDNNSoftmaxOutputForward(const nnvm::NodeAttrs& attrs,
   auto input_mem = idata.GetMKLDNNData();
   auto out_mem = CreateMKLDNNMem(out_data[softmaxout_enum::kOut],
      input_mem->get_primitive_desc(), req[softmaxout_enum::kOut]);
 
-  MKLDNNSoftmaxOutputFwd &fwd = GetSoftmaxOutputForward(param, ctx, axis, *input_mem);
+  MKLDNNSoftmaxOutputFwd &fwd = GetSoftmaxOutputForward(param, ctx, idata);
   fwd.SetNewMem(*input_mem, *out_mem.second);
 
   MKLDNNStream *stream = MKLDNNStream::Get();

From 25162cb64b0e56fb6411dc18456ec059d64ae9cf Mon Sep 17 00:00:00 2001
From: rongzha1
Date: Tue, 15 Jan 2019 14:26:04 +0800
Subject: [PATCH 13/16] add CommitOutput for mkldnn_softmax_output

---
 src/operator/nn/mkldnn/mkldnn_softmax_output.cc | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/operator/nn/mkldnn/mkldnn_softmax_output.cc b/src/operator/nn/mkldnn/mkldnn_softmax_output.cc
index 17918dabbad3..1beef572b6b6 100644
--- a/src/operator/nn/mkldnn/mkldnn_softmax_output.cc
+++ b/src/operator/nn/mkldnn/mkldnn_softmax_output.cc
@@ -136,6 +136,7 @@ void MKLDNNSoftmaxOutputForward(const nnvm::NodeAttrs& attrs,
 
   MKLDNNStream *stream = MKLDNNStream::Get();
   stream->RegisterPrim(fwd.GetFwd());
+  CommitOutput(out_data[softmaxout_enum::kOut], out_mem);
   stream->Submit();
 }
 }  // namespace op
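CommitOutput, added in patch 13, is the second half of that pairing: after the primitive has written into whatever buffer CreateMKLDNNMem returned, the result is flushed back into the real output array, accumulated for add-to requests or copied when a temporary was used. The sketch below continues the one after patch 11, reusing its invented Buffer, MKLDNNOutput, and OpReq types; the helper name and signature here are illustrative, not MXNet's CommitOutput.

    #include <cstddef>

    // Flush the computed result back into the destination buffer according to the
    // original request; a no-op when the primitive already wrote into dst directly.
    void CommitResult(Buffer *dst, const MKLDNNOutput &out, OpReq req) {
      if (!out.needs_commit)
        return;
      if (req == kAddTo) {
        for (std::size_t i = 0; i < dst->data.size(); ++i)
          dst->data[i] += out.mem->data[i];  // accumulate into the existing values
      } else {
        dst->data = out.mem->data;           // copy-back (a reorder in MKL-DNN terms)
        dst->layout_tag = out.mem->layout_tag;
      }
    }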
From 171da54eb75eb04ab2c6dd290428f576cb0b4a26 Mon Sep 17 00:00:00 2001
From: rongzha1
Date: Tue, 15 Jan 2019 22:07:47 +0800
Subject: [PATCH 14/16] trigger Jenkins re-test

---
 src/operator/nn/mkldnn/mkldnn_softmax_output.cc | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/operator/nn/mkldnn/mkldnn_softmax_output.cc b/src/operator/nn/mkldnn/mkldnn_softmax_output.cc
index 1beef572b6b6..a5b4a7903439 100644
--- a/src/operator/nn/mkldnn/mkldnn_softmax_output.cc
+++ b/src/operator/nn/mkldnn/mkldnn_softmax_output.cc
@@ -136,6 +136,7 @@ void MKLDNNSoftmaxOutputForward(const nnvm::NodeAttrs& attrs,
 
   MKLDNNStream *stream = MKLDNNStream::Get();
   stream->RegisterPrim(fwd.GetFwd());
+  CommitOutput(out_data[softmaxout_enum::kOut], out_mem);
   stream->Submit();
 }

From c031f6b5f858f8a04a551c34d0e740b8a76ba81c Mon Sep 17 00:00:00 2001
From: rongzha1
Date: Tue, 22 Jan 2019 20:10:20 +0800
Subject: [PATCH 15/16] add alias Softmax symbol for SoftmaxOutput OP

---
 src/operator/softmax_output.cc | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/operator/softmax_output.cc b/src/operator/softmax_output.cc
index 1b7de2d14033..f60707f6e2d5 100644
--- a/src/operator/softmax_output.cc
+++ b/src/operator/softmax_output.cc
@@ -255,6 +255,9 @@ NNVM_REGISTER_OP(SoftmaxOutput)
 .add_argument("label", "NDArray-or-Symbol", "Ground truth label.")
 .add_arguments(SoftmaxOutputParam::__FIELDS__());
 
+// The Softmax symbol was renamed to SoftmaxOutput and has been deprecated since Dec 2015.
+NNVM_REGISTER_OP(SoftmaxOutput).add_alias("Softmax");
+
 NNVM_REGISTER_OP(_backward_SoftmaxOutput)
 .set_num_outputs(2)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)

From ee845f7866e87be48d750041815a87b1190c1adf Mon Sep 17 00:00:00 2001
From: rongzha1
Date: Tue, 12 Feb 2019 21:17:58 +0800
Subject: [PATCH 16/16] indent and remove blank line

---
 src/operator/nn/mkldnn/mkldnn_softmax_output.cc | 2 +-
 src/operator/softmax_output.cc                  | 2 --
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/src/operator/nn/mkldnn/mkldnn_softmax_output.cc b/src/operator/nn/mkldnn/mkldnn_softmax_output.cc
index a5b4a7903439..ae34fe633d6f 100644
--- a/src/operator/nn/mkldnn/mkldnn_softmax_output.cc
+++ b/src/operator/nn/mkldnn/mkldnn_softmax_output.cc
@@ -129,7 +129,7 @@ void MKLDNNSoftmaxOutputForward(const nnvm::NodeAttrs& attrs,
   auto input_mem = idata.GetMKLDNNData();
   auto out_mem = CreateMKLDNNMem(out_data[softmaxout_enum::kOut],
-      input_mem->get_primitive_desc(), req[softmaxout_enum::kOut]);
+                                 input_mem->get_primitive_desc(), req[softmaxout_enum::kOut]);
 
   MKLDNNSoftmaxOutputFwd &fwd = GetSoftmaxOutputForward(param, ctx, idata);
   fwd.SetNewMem(*input_mem, *out_mem.second);
 
diff --git a/src/operator/softmax_output.cc b/src/operator/softmax_output.cc
index f60707f6e2d5..322ac0b93426 100644
--- a/src/operator/softmax_output.cc
+++ b/src/operator/softmax_output.cc
@@ -60,7 +60,6 @@ static inline std::vector<std::string> ListArguments() {
   return {"data", "label"};
 }
 
-
 static bool SoftmaxOutputType(const nnvm::NodeAttrs& attrs,
                               std::vector<int> *in_type,
                               std::vector<int> *out_type) {
@@ -150,7 +149,6 @@ void SoftmaxOutputComputeExCPU(const nnvm::NodeAttrs &attrs,
   }
   FallBackCompute(SoftmaxOutputCompute<cpu>, attrs, ctx, inputs, req, outputs);
 }
-
 #endif
 
 NNVM_REGISTER_OP(SoftmaxOutput)
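Taken together with the SoftmaxOutputComputeExCPU hunk earlier in the series, the dispatch logic follows the usual MKL-DNN shape: run the optimized kernel only when the input supports it, the call is inference-only, and the parameters qualify (here, multi_output must be off), and otherwise fall back to the default CPU path. A generic sketch of that control flow, using illustrative names rather than the MXNet APIs:

    #include <functional>

    // Illustrative stand-ins; not the MXNet OpContext / SoftmaxOutputParam / NDArray types.
    struct Ctx { bool is_train; };
    struct Param { bool multi_output; };
    struct Array { bool mkldnn_supported; };

    // Run the optimized kernel only when every precondition holds; otherwise fall back,
    // so the numerical result never depends on whether MKL-DNN could be used.
    void SoftmaxOutputDispatch(const Ctx &ctx, const Param &param, Array *data,
                               const std::function<void(Array *)> &mkldnn_kernel,
                               const std::function<void(Array *)> &fallback_kernel) {
      const bool use_mkldnn = data->mkldnn_supported &&
                              !ctx.is_train &&      // forward-only support, as noted above
                              !param.multi_output;  // what SupportMKLDNNSoftmaxOutput checks
      if (use_mkldnn) {
        mkldnn_kernel(data);
        return;
      }
      fallback_kernel(data);
    }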