From 0753b198ce4d371b4fd7aa8012cebb1c304ebfc6 Mon Sep 17 00:00:00 2001
From: PatricZhao <patric.zhao@intel.com>
Date: Fri, 19 Jan 2018 14:04:23 +0800
Subject: [PATCH] LRN coding style changes (#21)

* LRN coding style change
* Add const for local variables
* Add req for LRN forward
* rebase code
* align API interface
* revert modification in test_executor.
---
 src/operator/nn/lrn.cc                  | 59 ++++++++---------
 src/operator/nn/mkldnn/mkldnn_lrn-inl.h | 84 +++++++++++++------------
 2 files changed, 75 insertions(+), 68 deletions(-)

diff --git a/src/operator/nn/lrn.cc b/src/operator/nn/lrn.cc
index 00cac28f2484..2605dfe28930 100644
--- a/src/operator/nn/lrn.cc
+++ b/src/operator/nn/lrn.cc
@@ -21,10 +21,11 @@
  * Copyright (c) 2015 by Contributors
  * \file lrn.cc
  * \brief
- * \author Bing Xu
+ * \author Bing Xu, Patric Zhao (patric.zhao@intel.com)
 */

 #include "./lrn-inl.h"
+#include "../operator_common.h"
 #if MXNET_USE_MKLDNN == 1
 #include "./mkldnn/mkldnn_lrn-inl.h"
 #endif
@@ -32,9 +33,9 @@ namespace mxnet {
 namespace op {

-static bool LRNShape(const nnvm::NodeAttrs& attrs,
-                     std::vector<TShape> *in_shape,
-                     std::vector<TShape> *out_shape) {
+bool LRNShape(const nnvm::NodeAttrs& attrs,
+              std::vector<TShape> *in_shape,
+              std::vector<TShape> *out_shape) {
   using namespace mshadow;
   CHECK_EQ(in_shape->size(), 1U) << "Input:[data]";
   const TShape &dshape = in_shape->at(0);
@@ -45,13 +46,13 @@ static bool LRNShape(const nnvm::NodeAttrs& attrs,
   return true;
 }

-static inline std::vector<std::string> ListArguments() {
+inline std::vector<std::string> ListArguments() {
   return {"data"};
 }

-static bool LRNType(const nnvm::NodeAttrs& attrs,
-                    std::vector<int> *in_type,
-                    std::vector<int> *out_type) {
+bool LRNType(const nnvm::NodeAttrs& attrs,
+             std::vector<int> *in_type,
+             std::vector<int> *out_type) {
   CHECK_GE(in_type->size(), 1U);
   int dtype = (*in_type)[0];
   CHECK_NE(dtype, -1) << "First input must have specified type";
@@ -80,37 +81,39 @@ struct LRNGrad {
   }
 };

-inline static bool LRNForwardInferStorageType(const nnvm::NodeAttrs& attrs,
-                                              const int dev_mask,
-                                              DispatchMode* dispatch_mode,
-                                              std::vector<int> *in_attrs,
-                                              std::vector<int> *out_attrs) {
-  *dispatch_mode = DispatchMode::kFCompute;
-#if MXNET_USE_MKLDNN == 1
+bool LRNForwardInferStorageType(const nnvm::NodeAttrs& attrs,
+                                const int dev_mask,
+                                DispatchMode* dispatch_mode,
+                                std::vector<int> *in_attrs,
+                                std::vector<int> *out_attrs) {
   CHECK(!in_attrs->empty());
+#if MXNET_USE_MKLDNN == 1
   if (dev_mask == mshadow::cpu::kDevMask) {
-    *dispatch_mode = DispatchMode::kFComputeEx;
+    storage_type_assign(out_attrs, mxnet::kDefaultStorage,
+                        dispatch_mode, DispatchMode::kFComputeEx);
+    return true;
   }
 #endif
-  for (size_t i = 0; i < out_attrs->size(); i++)
-    (*out_attrs)[i] = kDefaultStorage;
+  storage_type_assign(out_attrs, mxnet::kDefaultStorage,
+                      dispatch_mode, DispatchMode::kFCompute);
   return true;
 }

-inline static bool LRNBackwardInferStorageType(const nnvm::NodeAttrs& attrs,
-                                               const int dev_mask,
-                                               DispatchMode* dispatch_mode,
-                                               std::vector<int> *in_attrs,
-                                               std::vector<int> *out_attrs) {
-  *dispatch_mode = DispatchMode::kFCompute;
-#if MXNET_USE_MKLDNN == 1
+bool LRNBackwardInferStorageType(const nnvm::NodeAttrs& attrs,
+                                 const int dev_mask,
+                                 DispatchMode* dispatch_mode,
+                                 std::vector<int> *in_attrs,
+                                 std::vector<int> *out_attrs) {
   CHECK(!in_attrs->empty());
+#if MXNET_USE_MKLDNN == 1
   if (dev_mask == mshadow::cpu::kDevMask) {
-    *dispatch_mode = DispatchMode::kFComputeEx;
+    storage_type_assign(out_attrs, mxnet::kDefaultStorage,
+                        dispatch_mode, DispatchMode::kFComputeEx);
+    return true;
   }
 #endif
-  for (size_t i = 0; i < out_attrs->size(); i++)
-    (*out_attrs)[i] = kDefaultStorage;
+  storage_type_assign(out_attrs, mxnet::kDefaultStorage,
+                      dispatch_mode, DispatchMode::kFCompute);
   return true;
 }

diff --git a/src/operator/nn/mkldnn/mkldnn_lrn-inl.h b/src/operator/nn/mkldnn/mkldnn_lrn-inl.h
index e0ecc1873d96..40ab8466370e 100644
--- a/src/operator/nn/mkldnn/mkldnn_lrn-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_lrn-inl.h
@@ -33,54 +33,58 @@ namespace mxnet {
 namespace op {

-static inline algorithm GetMKLDNNLRNAlgo(const LRNParam &param) {
+inline algorithm GetMKLDNNLRNAlgo(const LRNParam &param) {
   // TODO(Patric): lrn_within_channel will cause core dump in MKLDNN backward
   //               Need to confirm with MKLDNN team and fix later
   return algorithm::lrn_across_channels;
 }

-inline static lrn_forward::primitive_desc GetLRNFwd(
-    const LRNParam &param, bool is_train, const memory::desc &src_md) {
-  auto engine = CpuEngine::Get()->get_engine();
-  auto alg_ = GetMKLDNNLRNAlgo(param);
-  auto alpha_ = param.alpha;
-  auto beta_ = param.beta;
-  auto nsize_ = param.nsize;
-  auto k_ = param.knorm;
-  auto kind_ = prop_kind::forward_training;
+inline lrn_forward::primitive_desc GetLRNFwd(const LRNParam &param,
+                                             const bool is_train,
+                                             const memory::desc &src_md) {
+  const auto engine = CpuEngine::Get()->get_engine();
+  const auto alg = GetMKLDNNLRNAlgo(param);
+  const float alpha = param.alpha;
+  const float beta = param.beta;
+  const int nsize = param.nsize;
+  const float k = param.knorm;
+  auto kind = prop_kind::forward_training;
   if (is_train) {
-    kind_ = prop_kind::forward_training;
+    kind = prop_kind::forward_training;
   } else {
-    kind_ = prop_kind::forward_scoring;
+    kind = prop_kind::forward_scoring;
   }
-  lrn_forward::desc fwd_desc_(kind_, alg_, src_md, nsize_, alpha_, beta_, k_);
-  return mkldnn::lrn_forward::primitive_desc(fwd_desc_, engine);
+  lrn_forward::desc fwd_desc(kind, alg, src_md, nsize, alpha, beta, k);
+  return mkldnn::lrn_forward::primitive_desc(fwd_desc, engine);
 }

-inline static mkldnn::lrn_backward::primitive_desc GetLRNBwd(
-    const LRNParam &param, const mkldnn::memory::desc &diff_in_md,
-    const mkldnn::memory::desc &diff_md,
-    const lrn_forward::primitive_desc &lrnFwd_desc) {
-  auto engine = CpuEngine::Get()->get_engine();
-  auto alg_ = GetMKLDNNLRNAlgo(param);
-  auto alpha_ = param.alpha;
-  auto beta_ = param.beta;
-  int nsize_ = param.nsize;
-  auto k_ = param.knorm;
+inline mkldnn::lrn_backward::primitive_desc
+GetLRNBwd(const LRNParam &param,
+          const mkldnn::memory::desc &diff_in_md,
+          const mkldnn::memory::desc &diff_md,
+          const lrn_forward::primitive_desc &lrnFwd_desc) {
+  const auto engine = CpuEngine::Get()->get_engine();
+  const auto alg = GetMKLDNNLRNAlgo(param);
+  const float alpha = param.alpha;
+  const float beta = param.beta;
+  const int nsize = param.nsize;
+  const float k = param.knorm;

-  lrn_backward::desc lrnBwd_desc(alg_, diff_in_md,
-                                 diff_md, nsize_, alpha_, beta_, k_);
+  lrn_backward::desc lrnBwd_desc(alg, diff_in_md,
+                                 diff_md, nsize, alpha, beta, k);
   return mkldnn::lrn_backward::primitive_desc(lrnBwd_desc,
                                               engine, lrnFwd_desc);
 }

-void MKLDNNLRN_Forward(const OpContext &ctx, const LRNParam &param,
-                       const NDArray &in_data, const OpReqType &req,
+void MKLDNNLRN_Forward(const OpContext &ctx,
+                       const LRNParam &param,
+                       const NDArray &in_data,
+                       const OpReqType req,
                        const NDArray &out_data) {
   auto src_mem = in_data.GetMKLDNNData();
-  auto src_md = src_mem->get_primitive_desc().desc();
-  auto pdesc = GetLRNFwd(param, ctx.is_train, src_md);
-  auto dst_mem = const_cast<NDArray &>(out_data).CreateMKLDNNData(
+  const auto src_md = src_mem->get_primitive_desc().desc();
+  const auto pdesc = GetLRNFwd(param, ctx.is_train, src_md);
+  auto dst_mem = const_cast<NDArray &>(out_data).CreateMKLDNNData(
       pdesc.dst_primitive_desc());
   if (ctx.is_train) {
     std::shared_ptr<const mkldnn::memory> ws_mem(
@@ -97,17 +101,17 @@ void MKLDNNLRN_Forward(const OpContext &ctx, const LRNParam &param,
 }

 void MKLDNNLRN_Backward(const OpContext &ctx, const LRNParam &param,
-          const NDArray &out_grad,
-          const NDArray &in_data,
-          const OpReqType &req,
-          const NDArray &in_grad) {
+                        const NDArray &out_grad,
+                        const NDArray &in_data,
+                        const OpReqType req,
+                        const NDArray &in_grad) {
   if (req == kNullOp) {
     return;
   }
   // Repeat FW for getting workspace
   auto data_mem = in_data.GetMKLDNNData();
-  auto data_md = data_mem->get_primitive_desc().desc();
-  auto pdesc_fwd = GetLRNFwd(param, ctx.is_train, data_md);
+  const auto data_md = data_mem->get_primitive_desc().desc();
+  const auto pdesc_fwd = GetLRNFwd(param, ctx.is_train, data_md);

   // TODO(Patric): To keep the function stateless, we can't pass workspace
   //               from LRN forward to backward. We have to re-compute
@@ -121,10 +125,10 @@ void MKLDNNLRN_Backward(const OpContext &ctx, const LRNParam &param,
                   lrn_forward(pdesc_fwd, mkldnn::primitive::at(*data_mem),
                   *ws_mem, *dst_temp));

-  auto data_in_md = pdesc_fwd.src_primitive_desc().desc();
+  const auto data_in_md = pdesc_fwd.src_primitive_desc().desc();
   auto diff_mem = out_grad.GetMKLDNNData();
-  auto diff_md = diff_mem->get_primitive_desc().desc();
-  auto pdesc_bwd = GetLRNBwd(param, data_in_md, diff_md, pdesc_fwd);
+  const auto diff_md = diff_mem->get_primitive_desc().desc();
+  const auto pdesc_bwd = GetLRNBwd(param, data_in_md, diff_md, pdesc_fwd);

   auto diff_src_mem = CreateMKLDNNMem(in_grad,
                       pdesc_bwd.diff_src_primitive_desc(), req);