
Commit

fix bn bug
DHCZ committed May 23, 2017
1 parent 0b780a2 commit bece19a
Showing 1 changed file with 3 additions and 80 deletions.
src/operator/batch_norm-inl.h (83 changes: 3 additions & 80 deletions)
@@ -146,86 +146,9 @@ class BatchNormOp : public Operator {
     CHECK_EQ(in_data.size(), 3U);
     CHECK_EQ(out_data.size(), 3U);
     CHECK_EQ(in_grad.size(), 3U);
-    Stream<xpu> *s = ctx.get_stream<xpu>();
-    Tensor<xpu, 4> data, grad, grad_in;
-    const real_t scale = static_cast<real_t>(out_grad[batchnorm::kOut].shape_[1]) /
-                         static_cast<real_t>(out_grad[batchnorm::kOut].shape_.Size());
-    if (in_data[batchnorm::kData].ndim() == 2) {
-      Shape<4> dshape = Shape4(out_grad[batchnorm::kOut].shape_[0],
-                               out_grad[batchnorm::kOut].shape_[1], 1, 1);
-      data = in_data[batchnorm::kData].get_with_shape<xpu, 4, real_t>(dshape, s);
-      grad = out_grad[batchnorm::kOut].get_with_shape<xpu, 4, real_t>(dshape, s);
-      grad_in = in_grad[batchnorm::kData].get_with_shape<xpu, 4, real_t>(dshape, s);
-    } else {
-      data = in_data[batchnorm::kData].get<xpu, 4, real_t>(s);
-      grad = out_grad[batchnorm::kOut].get<xpu, 4, real_t>(s);
-      grad_in = in_grad[batchnorm::kData].get<xpu, 4, real_t>(s);
-    }
-
-    Tensor<xpu, 1> mean = out_data[batchnorm::kMean].get<xpu, 1, real_t>(s);
-    Tensor<xpu, 1> var = out_data[batchnorm::kVar].get<xpu, 1, real_t>(s);
-    Tensor<xpu, 1> slope = in_data[batchnorm::kGamma].get<xpu, 1, real_t>(s);
-    // Tensor<xpu, 1> bias = in_data[kBeta].get<xpu, 1, real_t>(s);
-    Tensor<xpu, 1> gslope = in_grad[batchnorm::kGamma].get<xpu, 1, real_t>(s);
-    Tensor<xpu, 1> gbias = in_grad[batchnorm::kBeta].get<xpu, 1, real_t>(s);
-    // update moving avg
-    Tensor<xpu, 1> moving_mean = aux_states[batchnorm::kMovingMean].get<xpu, 1, real_t>(s);
-    Tensor<xpu, 1> moving_var = aux_states[batchnorm::kMovingVar].get<xpu, 1, real_t>(s);
-
-    if (param_.fix_gamma) slope = 1.f;
-
-    if (ctx.is_train && !param_.use_global_stats) {
-      // get requested temp space
-      Tensor<xpu, 2> workspace = ctx.requested[batchnorm::kTempSpace].get_space<xpu>(
-          mshadow::Shape2(3, mean.shape_[0]), s);
-      Tensor<xpu, 1> gmean = workspace[0];
-      Tensor<xpu, 1> gvar = workspace[1];
-      Tensor<xpu, 1> tmp = workspace[2];
-
-      moving_mean = moving_mean * param_.momentum + mean * (1 - param_.momentum);
-      moving_var = moving_var * param_.momentum + var * (1 - param_.momentum);
-      // cal
-      gvar = sumall_except_dim<1>((grad * broadcast<1>(slope, data.shape_)) *
-                                  (data - broadcast<1>(mean, data.shape_)) *
-                                  -0.5f *
-                                  F<mshadow_op::power>(broadcast<1>(var + param_.eps, data.shape_),
-                                                       -1.5f));
-      gmean = sumall_except_dim<1>(grad * broadcast<1>(slope, data.shape_));
-      gmean *= -1.0f / F<mshadow_op::square_root>(var + param_.eps);
-      tmp = scale * sumall_except_dim<1>(-2.0f * (data - broadcast<1>(mean, data.shape_)));
-      tmp *= gvar;
-      gmean += tmp;
-      // assign
-      if (!param_.fix_gamma || !param_.fix_linear_trans) {
-        Assign(gslope, req[batchnorm::kGamma],
-               sumall_except_dim<1>(
-                   grad * (data - broadcast<1>(mean, data.shape_)) /
-                   F<mshadow_op::square_root>(broadcast<1>(var + param_.eps, data.shape_))));
-      } else {
-        Assign(gslope, req[batchnorm::kGamma], 0.0f);
-      }
-      Assign(grad_in, req[batchnorm::kData],
-             (grad * broadcast<1>(slope, data.shape_)) *
-             broadcast<1>(1.0f / F<mshadow_op::square_root>(var + param_.eps), data.shape_) +
-             broadcast<1>(gvar, data.shape_) * scale * 2.0f * (data - broadcast<1>(mean,
-                                                                                   data.shape_)) +
-             broadcast<1>(gmean, data.shape_) * scale);
-      Assign(gbias, req[batchnorm::kBeta], sumall_except_dim<1>(grad));
-    } else {
-      // use global statistics with freeze moving mean and var.
-      if (!param_.fix_gamma || !param_.fix_linear_trans) {
-        Assign(gslope, req[batchnorm::kGamma],
-               sumall_except_dim<1>(
-                   grad * (data - broadcast<1>(moving_mean, data.shape_)) /
-                   F<mshadow_op::square_root>(broadcast<1>(moving_var + param_.eps, data.shape_))));
-      } else {
-        Assign(gslope, req[batchnorm::kGamma], 0.0f);
-      }
-      Assign(gbias, req[batchnorm::kBeta], sumall_except_dim<1>(grad));
-      Assign(grad_in, req[batchnorm::kData], (grad * broadcast<1>(slope, data.shape_)) *
-             broadcast<1>(
-                 1.0f / F<mshadow_op::square_root>(moving_var + param_.eps), data.shape_));
-    }
+    mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+    DoBackward(s, ctx, out_grad, in_data,
+               out_data, req, in_grad, aux_states);
   }

 private:
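For reference, the deleted block was an inline implementation of the per-channel batch-norm backward pass; the replacement simply forwards to the DoBackward helper. The removed expressions correspond to the standard gradients below (a sketch of my reading of the removed code, not part of the commit). Here m = N*H*W is the number of elements per channel, which matches the code's scale = C / total size = 1/m, and sums run over all elements of one channel.

% Gradients the removed block computed, assuming y = gamma * (x - mu) / sqrt(var + eps) + beta
% (labels on the right name the corresponding variables in the removed code)
\begin{align*}
\frac{\partial L}{\partial \sigma^2} &= \sum_i \frac{\partial L}{\partial y_i}\,\gamma\,(x_i-\mu)\left(-\tfrac{1}{2}\right)(\sigma^2+\epsilon)^{-3/2} && \text{(gvar)}\\
\frac{\partial L}{\partial \mu} &= -\frac{\gamma}{\sqrt{\sigma^2+\epsilon}}\sum_i \frac{\partial L}{\partial y_i} \;+\; \frac{\partial L}{\partial \sigma^2}\cdot\frac{1}{m}\sum_i -2\,(x_i-\mu) && \text{(gmean)}\\
\frac{\partial L}{\partial x_i} &= \frac{\partial L}{\partial y_i}\,\frac{\gamma}{\sqrt{\sigma^2+\epsilon}} \;+\; \frac{\partial L}{\partial \sigma^2}\,\frac{2\,(x_i-\mu)}{m} \;+\; \frac{\partial L}{\partial \mu}\,\frac{1}{m} && \text{(grad\_in)}\\
\frac{\partial L}{\partial \gamma} &= \sum_i \frac{\partial L}{\partial y_i}\,\frac{x_i-\mu}{\sqrt{\sigma^2+\epsilon}} && \text{(gslope)}\\
\frac{\partial L}{\partial \beta} &= \sum_i \frac{\partial L}{\partial y_i} && \text{(gbias)}
\end{align*}

In the training branch the code uses the batch mean/var and updates the moving averages in place; in the use_global_stats branch it substitutes moving_mean/moving_var and drops the mean and variance terms of grad_in, which is why that branch keeps only the first term.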