LRN coding style changes (apache#21)
* LRN coding style change

* Add const for local variables

* Add req for LRN forward

* rebase code

* align API interface

* revert modification in test_executor.
PatricZhao authored and Olivier committed Feb 6, 2018
1 parent e9fd871 commit 0753b19
Showing 2 changed files with 75 additions and 68 deletions.
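For readers outside this code base, the conventions referenced in the commit message (const-qualified locals, no trailing underscores on local variable names, explicit scalar types instead of auto, and small values such as OpReqType passed by value) amount to the pattern sketched below. This is a hypothetical snippet for illustration only, not code taken from the commit:

#include <cstdio>

struct Params { float alpha; int nsize; };

// Old style (illustrative):
//   auto alpha_ = p.alpha;                          // auto + trailing underscore
//   float Scale(const Params &p, const int &mode);  // scalar by const reference
//
// Style applied in this change set:
float Scale(const Params &p, const int mode) {  // scalar passed by value
  const float alpha = p.alpha;                  // const local, explicit type
  const int nsize = p.nsize;
  return alpha * static_cast<float>(nsize + mode);
}

int main() {
  const Params p{0.0001f, 5};
  std::printf("%f\n", Scale(p, 1));
  return 0;
}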
src/operator/nn/lrn.cc: 59 changes (31 additions, 28 deletions)
@@ -21,20 +21,21 @@
  * Copyright (c) 2015 by Contributors
  * \file lrn.cc
  * \brief
- * \author Bing Xu
+ * \author Bing Xu, Patric Zhao ([email protected])
  */

 #include "./lrn-inl.h"
+#include "../operator_common.h"
 #if MXNET_USE_MKLDNN == 1
 #include "./mkldnn/mkldnn_lrn-inl.h"
 #endif

 namespace mxnet {
 namespace op {

-static bool LRNShape(const nnvm::NodeAttrs& attrs,
-                     std::vector<TShape> *in_shape,
-                     std::vector<TShape> *out_shape) {
+bool LRNShape(const nnvm::NodeAttrs& attrs,
+              std::vector<TShape> *in_shape,
+              std::vector<TShape> *out_shape) {
   using namespace mshadow;
   CHECK_EQ(in_shape->size(), 1U) << "Input:[data]";
   const TShape &dshape = in_shape->at(0);
@@ -45,13 +46,13 @@ static bool LRNShape(const nnvm::NodeAttrs& attrs,
   return true;
 }

-static inline std::vector<std::string> ListArguments() {
+inline std::vector<std::string> ListArguments() {
   return {"data"};
 }

-static bool LRNType(const nnvm::NodeAttrs& attrs,
-                    std::vector<int> *in_type,
-                    std::vector<int> *out_type) {
+bool LRNType(const nnvm::NodeAttrs& attrs,
+             std::vector<int> *in_type,
+             std::vector<int> *out_type) {
   CHECK_GE(in_type->size(), 1U);
   int dtype = (*in_type)[0];
   CHECK_NE(dtype, -1) << "First input must have specified type";
@@ -80,37 +81,39 @@ struct LRNGrad {
   }
 };

-inline static bool LRNForwardInferStorageType(const nnvm::NodeAttrs& attrs,
-                                              const int dev_mask,
-                                              DispatchMode* dispatch_mode,
-                                              std::vector<int> *in_attrs,
-                                              std::vector<int> *out_attrs) {
-  *dispatch_mode = DispatchMode::kFCompute;
-#if MXNET_USE_MKLDNN == 1
+bool LRNForwardInferStorageType(const nnvm::NodeAttrs& attrs,
+                                const int dev_mask,
+                                DispatchMode* dispatch_mode,
+                                std::vector<int> *in_attrs,
+                                std::vector<int> *out_attrs) {
+  CHECK(!in_attrs->empty());
+#if MXNET_USE_MKLDNN == 1
   if (dev_mask == mshadow::cpu::kDevMask) {
-    *dispatch_mode = DispatchMode::kFComputeEx;
+    storage_type_assign(out_attrs, mxnet::kDefaultStorage,
+                        dispatch_mode, DispatchMode::kFComputeEx);
     return true;
   }
 #endif
-  for (size_t i = 0; i < out_attrs->size(); i++)
-    (*out_attrs)[i] = kDefaultStorage;
+  storage_type_assign(out_attrs, mxnet::kDefaultStorage,
+                      dispatch_mode, DispatchMode::kFCompute);
   return true;
 }

-inline static bool LRNBackwardInferStorageType(const nnvm::NodeAttrs& attrs,
-                                               const int dev_mask,
-                                               DispatchMode* dispatch_mode,
-                                               std::vector<int> *in_attrs,
-                                               std::vector<int> *out_attrs) {
-  *dispatch_mode = DispatchMode::kFCompute;
-#if MXNET_USE_MKLDNN == 1
+bool LRNBackwardInferStorageType(const nnvm::NodeAttrs& attrs,
+                                 const int dev_mask,
+                                 DispatchMode* dispatch_mode,
+                                 std::vector<int> *in_attrs,
+                                 std::vector<int> *out_attrs) {
+  CHECK(!in_attrs->empty());
+#if MXNET_USE_MKLDNN == 1
   if (dev_mask == mshadow::cpu::kDevMask) {
-    *dispatch_mode = DispatchMode::kFComputeEx;
+    storage_type_assign(out_attrs, mxnet::kDefaultStorage,
+                        dispatch_mode, DispatchMode::kFComputeEx);
     return true;
   }
 #endif
-  for (size_t i = 0; i < out_attrs->size(); i++)
-    (*out_attrs)[i] = kDefaultStorage;
+  storage_type_assign(out_attrs, mxnet::kDefaultStorage,
+                      dispatch_mode, DispatchMode::kFCompute);
   return true;
 }

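The two infer-storage functions above replace a hand-written loop over out_attrs plus a separate *dispatch_mode assignment with calls to storage_type_assign, which sets every output's storage type and records the dispatch mode in one step, returning early on the MKL-DNN CPU path. The sketch below shows that control flow in a self-contained form; StorageType, DispatchMode, and AssignStorage here are simplified stand-ins, not MXNet's real definitions:

#include <cassert>
#include <vector>

enum StorageType { kDefaultStorage = 0 };
enum class DispatchMode { kUndefined, kFCompute, kFComputeEx };

// Stand-in for storage_type_assign: give every output the same storage type
// and record how the operator should be dispatched.
bool AssignStorage(std::vector<int>* out_attrs, StorageType stype,
                   DispatchMode* dispatch_mode, DispatchMode mode) {
  for (int& attr : *out_attrs) attr = stype;
  *dispatch_mode = mode;
  return true;
}

// Mirrors the shape of LRNForwardInferStorageType after this change.
bool InferStorage(bool cpu_with_mkldnn, std::vector<int>* out_attrs,
                  DispatchMode* dispatch_mode) {
  if (cpu_with_mkldnn) {
    // MKL-DNN path: dispatch to the FComputeEx implementation.
    return AssignStorage(out_attrs, kDefaultStorage, dispatch_mode,
                         DispatchMode::kFComputeEx);
  }
  // Fallback: plain FCompute on dense (default) storage.
  return AssignStorage(out_attrs, kDefaultStorage, dispatch_mode,
                       DispatchMode::kFCompute);
}

int main() {
  std::vector<int> out_attrs(1, -1);
  DispatchMode mode = DispatchMode::kUndefined;
  assert(InferStorage(true, &out_attrs, &mode));
  assert(out_attrs[0] == kDefaultStorage);
  assert(mode == DispatchMode::kFComputeEx);
  return 0;
}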
src/operator/nn/mkldnn/mkldnn_lrn-inl.h: 84 changes (44 additions, 40 deletions)
@@ -33,54 +33,58 @@
 namespace mxnet {
 namespace op {

-static inline algorithm GetMKLDNNLRNAlgo(const LRNParam &param) {
+inline algorithm GetMKLDNNLRNAlgo(const LRNParam &param) {
   // TODO(Patric): lrn_within_channel will cause core dump in MKLDNN backward
   // Need to confirm with MKLDNN team and fix later
   return algorithm::lrn_across_channels;
 }

-inline static lrn_forward::primitive_desc GetLRNFwd(
-    const LRNParam &param, bool is_train, const memory::desc &src_md) {
-  auto engine = CpuEngine::Get()->get_engine();
-  auto alg_ = GetMKLDNNLRNAlgo(param);
-  auto alpha_ = param.alpha;
-  auto beta_ = param.beta;
-  auto nsize_ = param.nsize;
-  auto k_ = param.knorm;
-  auto kind_ = prop_kind::forward_training;
+inline lrn_forward::primitive_desc GetLRNFwd(const LRNParam &param,
+                                             const bool is_train,
+                                             const memory::desc &src_md) {
+  const auto engine = CpuEngine::Get()->get_engine();
+  const auto alg = GetMKLDNNLRNAlgo(param);
+  const float alpha = param.alpha;
+  const float beta = param.beta;
+  const int nsize = param.nsize;
+  const float k = param.knorm;
+  auto kind = prop_kind::forward_training;
   if (is_train) {
-    kind_ = prop_kind::forward_training;
+    kind = prop_kind::forward_training;
   } else {
-    kind_ = prop_kind::forward_scoring;
+    kind = prop_kind::forward_scoring;
   }
-  lrn_forward::desc fwd_desc_(kind_, alg_, src_md, nsize_, alpha_, beta_, k_);
-  return mkldnn::lrn_forward::primitive_desc(fwd_desc_, engine);
+  lrn_forward::desc fwd_desc(kind, alg, src_md, nsize, alpha, beta, k);
+  return mkldnn::lrn_forward::primitive_desc(fwd_desc, engine);
 }

-inline static mkldnn::lrn_backward::primitive_desc GetLRNBwd(
-    const LRNParam &param, const mkldnn::memory::desc &diff_in_md,
-    const mkldnn::memory::desc &diff_md,
-    const lrn_forward::primitive_desc &lrnFwd_desc) {
-  auto engine = CpuEngine::Get()->get_engine();
-  auto alg_ = GetMKLDNNLRNAlgo(param);
-  auto alpha_ = param.alpha;
-  auto beta_ = param.beta;
-  int nsize_ = param.nsize;
-  auto k_ = param.knorm;
+inline mkldnn::lrn_backward::primitive_desc
+GetLRNBwd(const LRNParam &param,
+          const mkldnn::memory::desc &diff_in_md,
+          const mkldnn::memory::desc &diff_md,
+          const lrn_forward::primitive_desc &lrnFwd_desc) {
+  const auto engine = CpuEngine::Get()->get_engine();
+  const auto alg = GetMKLDNNLRNAlgo(param);
+  const float alpha = param.alpha;
+  const float beta = param.beta;
+  const int nsize = param.nsize;
+  const float k = param.knorm;

-  lrn_backward::desc lrnBwd_desc(alg_, diff_in_md,
-                                 diff_md, nsize_, alpha_, beta_, k_);
+  lrn_backward::desc lrnBwd_desc(alg, diff_in_md,
+                                 diff_md, nsize, alpha, beta, k);
   return mkldnn::lrn_backward::primitive_desc(lrnBwd_desc,
                                               engine, lrnFwd_desc);
 }

-void MKLDNNLRN_Forward(const OpContext &ctx, const LRNParam &param,
-                       const NDArray &in_data, const OpReqType &req,
+void MKLDNNLRN_Forward(const OpContext &ctx,
+                       const LRNParam &param,
+                       const NDArray &in_data,
+                       const OpReqType req,
                        const NDArray &out_data) {
   auto src_mem = in_data.GetMKLDNNData();
-  auto src_md = src_mem->get_primitive_desc().desc();
-  auto pdesc = GetLRNFwd(param, ctx.is_train, src_md);
-  auto dst_mem = const_cast<NDArray &>(out_data).CreateMKLDNNData(
+  const auto src_md = src_mem->get_primitive_desc().desc();
+  const auto pdesc = GetLRNFwd(param, ctx.is_train, src_md);
+  auto dst_mem = const_cast<NDArray &>(out_data).CreateMKLDNNData(
       pdesc.dst_primitive_desc());
   if (ctx.is_train) {
     std::shared_ptr<const mkldnn::memory> ws_mem(
@@ -97,17 +101,17 @@ void MKLDNNLRN_Forward(const OpContext &ctx, const LRNParam &param,
 }

 void MKLDNNLRN_Backward(const OpContext &ctx, const LRNParam &param,
-                        const NDArray &out_grad,
-                        const NDArray &in_data,
-                        const OpReqType &req,
-                        const NDArray &in_grad) {
+                        const NDArray &out_grad,
+                        const NDArray &in_data,
+                        const OpReqType req,
+                        const NDArray &in_grad) {
   if (req == kNullOp) {
     return;
   }
   // Repeat FW for getting workspace
   auto data_mem = in_data.GetMKLDNNData();
-  auto data_md = data_mem->get_primitive_desc().desc();
-  auto pdesc_fwd = GetLRNFwd(param, ctx.is_train, data_md);
+  const auto data_md = data_mem->get_primitive_desc().desc();
+  const auto pdesc_fwd = GetLRNFwd(param, ctx.is_train, data_md);

   // TODO(Patric): To keep the function stateless, we can't pass workspace
   // from LRN forward to backward. We have to re-compute
@@ -121,10 +125,10 @@ void MKLDNNLRN_Backward(const OpContext &ctx, const LRNParam &param,
       lrn_forward(pdesc_fwd, mkldnn::primitive::at(*data_mem),
                   *ws_mem, *dst_temp));

-  auto data_in_md = pdesc_fwd.src_primitive_desc().desc();
+  const auto data_in_md = pdesc_fwd.src_primitive_desc().desc();
   auto diff_mem = out_grad.GetMKLDNNData();
-  auto diff_md = diff_mem->get_primitive_desc().desc();
-  auto pdesc_bwd = GetLRNBwd(param, data_in_md, diff_md, pdesc_fwd);
+  const auto diff_md = diff_mem->get_primitive_desc().desc();
+  const auto pdesc_bwd = GetLRNBwd(param, data_in_md, diff_md, pdesc_fwd);
   auto diff_src_mem = CreateMKLDNNMem(in_grad,
                                       pdesc_bwd.diff_src_primitive_desc(), req);

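MKLDNNLRN_Forward now takes an OpReqType argument, and MKLDNNLRN_Backward returns immediately when req == kNullOp. OpReqType tells an operator whether to skip its output, overwrite it, or accumulate into it. A minimal illustrative sketch of that contract follows; the enum and ApplyReq are simplified stand-ins rather than MXNet's actual definitions (the real enum also includes kWriteInplace):

#include <cassert>

enum OpReqType { kNullOp, kWriteTo, kAddTo };

// Write or accumulate a computed value into *out according to the request,
// the way an operator's forward/backward kernel is expected to behave.
void ApplyReq(const OpReqType req, const float value, float* out) {
  if (req == kNullOp) return;   // nothing requested: skip the work entirely
  if (req == kAddTo) {
    *out += value;              // accumulate, e.g. into an existing gradient
  } else {
    *out = value;               // kWriteTo: overwrite the output buffer
  }
}

int main() {
  float out = 1.0f;
  ApplyReq(kNullOp, 5.0f, &out);
  assert(out == 1.0f);
  ApplyReq(kWriteTo, 5.0f, &out);
  assert(out == 5.0f);
  ApplyReq(kAddTo, 2.0f, &out);
  assert(out == 7.0f);
  return 0;
}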
