Skip to content

Commit

Permalink
use enums in batch norm.
Browse files Browse the repository at this point in the history
  • Loading branch information
zheng-da committed Nov 30, 2017
1 parent a7ed5e3 commit b161709
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 12 deletions.
21 changes: 14 additions & 7 deletions src/operator/nn/batch_norm-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,10 @@ void BatchNormCompute(const nnvm::NodeAttrs& attrs,
const std::vector<TBlob>& outputs) {
const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
CHECK_EQ(inputs.size(), 5U);
std::vector<TBlob> in_data(inputs.begin(), inputs.begin() + 3);
std::vector<TBlob> aux_states(inputs.begin() + 3, inputs.end());
std::vector<TBlob> in_data(inputs.begin(),
inputs.begin() + (int) batchnorm::kInMovingMean);
std::vector<TBlob> aux_states(inputs.begin() + (int) batchnorm::kInMovingMean,
inputs.end());
MSHADOW_REAL_TYPE_SWITCH_EX(inputs[0].type_flag_, DType, AccReal, {
GetBatchNormOp<xpu, DType, AccReal>(param).Forward(ctx, in_data,
req, outputs, aux_states);
Expand All @@ -242,11 +244,16 @@ void BatchNormGradCompute(const nnvm::NodeAttrs& attrs,
const std::vector<TBlob>& outputs) {
CHECK_EQ(inputs.size(), 11U);
const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
std::vector<TBlob> out_grad(inputs.begin(),
inputs.begin() + (param.output_mean_var ? 3U : 1U));
std::vector<TBlob> in_data(inputs.begin() + 3, inputs.begin() + 6);
std::vector<TBlob> aux_states(inputs.begin() + 6, inputs.begin() + 8);
std::vector<TBlob> out_data(inputs.begin() + 8, inputs.end());
int num_out_grads = param.output_mean_var ? 3U : 1U;
int in_data_start = 3;
int aux_states_start = in_data_start + (int) batchnorm::kInMovingMean;
int out_data_start = in_data_start + (int) batchnorm::kInMovingVar + 1;
std::vector<TBlob> out_grad(inputs.begin(), inputs.begin() + num_out_grads);
std::vector<TBlob> in_data(inputs.begin() + in_data_start,
inputs.begin() + aux_states_start);
std::vector<TBlob> aux_states(inputs.begin() + aux_states_start,
inputs.begin() + out_data_start);
std::vector<TBlob> out_data(inputs.begin() + out_data_start, inputs.end());
std::vector<TBlob> in_grad(outputs.begin(), outputs.begin() + 3);

MSHADOW_REAL_TYPE_SWITCH_EX(out_grad[0].type_flag_, DType, AccReal, {
Expand Down
10 changes: 5 additions & 5 deletions src/operator/nn/batch_norm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ static bool BatchNormShape(const nnvm::NodeAttrs& attrs,
const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
using namespace mshadow;
CHECK_EQ(in_shape->size(), 5U) << "Input:[data, gamma, beta, MovingMean, MovingVar]";
const TShape &dshape = in_shape->at(0);
const TShape &dshape = in_shape->at(batchnorm::kData);

const size_t channelAxis = static_cast<size_t>(param.axis < 0
? static_cast<int>(dshape.ndim()) + param.axis
Expand All @@ -336,10 +336,10 @@ static bool BatchNormShape(const nnvm::NodeAttrs& attrs,
return false;
}

in_shape->at(1) = TShape(Shape1(channelCount));
in_shape->at(2) = TShape(Shape1(channelCount));
in_shape->at(3) = TShape(Shape1(channelCount)); // kMovingMean
in_shape->at(4) = TShape(Shape1(channelCount)); // kMovingVar
in_shape->at(batchnorm::kGamma) = TShape(Shape1(channelCount));
in_shape->at(batchnorm::kBeta) = TShape(Shape1(channelCount));
in_shape->at(batchnorm::kInMovingMean) = TShape(Shape1(channelCount)); // kMovingMean
in_shape->at(batchnorm::kInMovingVar) = TShape(Shape1(channelCount)); // kMovingVar

out_shape->clear();
out_shape->push_back(dshape); // kOut
Expand Down

0 comments on commit b161709

Please sign in to comment.