diff --git a/src/operator/nn/batch_norm-inl.h b/src/operator/nn/batch_norm-inl.h
index 4838570bda97..1220156f1056 100644
--- a/src/operator/nn/batch_norm-inl.h
+++ b/src/operator/nn/batch_norm-inl.h
@@ -227,8 +227,10 @@ void BatchNormCompute(const nnvm::NodeAttrs& attrs,
                       const std::vector<TBlob>& outputs) {
   const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
   CHECK_EQ(inputs.size(), 5U);
-  std::vector<TBlob> in_data(inputs.begin(), inputs.begin() + 3);
-  std::vector<TBlob> aux_states(inputs.begin() + 3, inputs.end());
+  std::vector<TBlob> in_data(inputs.begin(),
+                             inputs.begin() + (int) batchnorm::kInMovingMean);
+  std::vector<TBlob> aux_states(inputs.begin() + (int) batchnorm::kInMovingMean,
+                                inputs.end());
   MSHADOW_REAL_TYPE_SWITCH_EX(inputs[0].type_flag_, DType, AccReal, {
     GetBatchNormOp<xpu, DType, AccReal>(param).Forward(ctx, in_data, req, outputs, aux_states);
   });
@@ -242,11 +244,16 @@ void BatchNormGradCompute(const nnvm::NodeAttrs& attrs,
                           const std::vector<TBlob>& outputs) {
   CHECK_EQ(inputs.size(), 11U);
   const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
-  std::vector<TBlob> out_grad(inputs.begin(),
-      inputs.begin() + (param.output_mean_var ? 3U : 1U));
-  std::vector<TBlob> in_data(inputs.begin() + 3, inputs.begin() + 6);
-  std::vector<TBlob> aux_states(inputs.begin() + 6, inputs.begin() + 8);
-  std::vector<TBlob> out_data(inputs.begin() + 8, inputs.end());
+  int num_out_grads = param.output_mean_var ? 3U : 1U;
+  int in_data_start = 3;
+  int aux_states_start = in_data_start + (int) batchnorm::kInMovingMean;
+  int out_data_start = in_data_start + (int) batchnorm::kInMovingVar + 1;
+  std::vector<TBlob> out_grad(inputs.begin(), inputs.begin() + num_out_grads);
+  std::vector<TBlob> in_data(inputs.begin() + in_data_start,
+                             inputs.begin() + aux_states_start);
+  std::vector<TBlob> aux_states(inputs.begin() + aux_states_start,
+                                inputs.begin() + out_data_start);
+  std::vector<TBlob> out_data(inputs.begin() + out_data_start, inputs.end());
   std::vector<TBlob> in_grad(outputs.begin(), outputs.begin() + 3);
 
   MSHADOW_REAL_TYPE_SWITCH_EX(out_grad[0].type_flag_, DType, AccReal, {
diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index 298de204a53f..bbf4da9874c4 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -323,7 +323,7 @@ static bool BatchNormShape(const nnvm::NodeAttrs& attrs,
   const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
   using namespace mshadow;
   CHECK_EQ(in_shape->size(), 5U) << "Input:[data, gamma, beta, MovingMean, MovingVar]";
-  const TShape &dshape = in_shape->at(0);
+  const TShape &dshape = in_shape->at(batchnorm::kData);
 
   const size_t channelAxis = static_cast<size_t>(param.axis < 0
       ? static_cast<int>(dshape.ndim()) + param.axis
@@ -336,10 +336,10 @@ static bool BatchNormShape(const nnvm::NodeAttrs& attrs,
     return false;
   }
 
-  in_shape->at(1) = TShape(Shape1(channelCount));
-  in_shape->at(2) = TShape(Shape1(channelCount));
-  in_shape->at(3) = TShape(Shape1(channelCount));  // kMovingMean
-  in_shape->at(4) = TShape(Shape1(channelCount));  // kMovingVar
+  in_shape->at(batchnorm::kGamma) = TShape(Shape1(channelCount));
+  in_shape->at(batchnorm::kBeta) = TShape(Shape1(channelCount));
+  in_shape->at(batchnorm::kInMovingMean) = TShape(Shape1(channelCount));  // kMovingMean
+  in_shape->at(batchnorm::kInMovingVar) = TShape(Shape1(channelCount));  // kMovingVar
 
   out_shape->clear();
   out_shape->push_back(dshape);                // kOut
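
Note on the enum-based indices used above: the patch replaces raw offsets with
values from the batchnorm input enums declared in
src/operator/nn/batch_norm-inl.h. A minimal sketch of those declarations, assuming
the enumerator order is unchanged from that header (see the file itself for the
authoritative definition):

    namespace batchnorm {
    // Inputs: kData = 0, kGamma = 1, kBeta = 2, kInMovingMean = 3, kInMovingVar = 4
    enum BatchNormOpInputs { kData, kGamma, kBeta, kInMovingMean, kInMovingVar };
    // Outputs: kOut = 0, kMean = 1, kVar = 2
    enum BatchNormOpOutputs { kOut, kMean, kVar };
    // Auxiliary states: kMovingMean = 0, kMovingVar = 1
    enum BatchNormOpAuxiliary { kMovingMean, kMovingVar };
    }  // namespace batchnorm

With kInMovingMean == 3 and kInMovingVar == 4, the derived offsets in
BatchNormGradCompute reduce to the old magic numbers: aux_states_start = 3 + 3 = 6
and out_data_start = 3 + 4 + 1 = 8, so the slicing behavior is unchanged; the enum
names simply document what each boundary means.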