Skip to content

Commit

Permalink
merge segmentation dependencies
Browse files Browse the repository at this point in the history
  • Loading branch information
GrassSunFlower committed Apr 1, 2017
1 parent e03dc65 commit 633e707
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions src/operator/batch_norm-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ struct BatchNormParam : public dmlc::Parameter<BatchNormParam> {
float eps;
float momentum;
bool fix_gamma;
bool fix_linear_trans;
bool use_global_stats;
bool output_mean_var;
DMLC_DECLARE_PARAMETER(BatchNormParam) {
Expand All @@ -40,6 +41,8 @@ struct BatchNormParam : public dmlc::Parameter<BatchNormParam> {
.describe("Momentum for moving average");
DMLC_DECLARE_FIELD(fix_gamma).set_default(true)
.describe("Fix gamma while training");
DMLC_DECLARE_FIELD(fix_linear_trans).set_default(true)
.describe("Fix linear transformation while training");
DMLC_DECLARE_FIELD(use_global_stats).set_default(false)
.describe("Whether use global moving statistics instead of local batch-norm. "
"This will force change batch-norm into a scale shift operator.");
Expand Down Expand Up @@ -180,7 +183,7 @@ class BatchNormOp : public Operator {
tmp *= gvar;
gmean += tmp;
// assign
if (!param_.fix_gamma) {
if (!param_.fix_gamma || !param_.fix_linear_trans) {
Assign(gslope, req[batchnorm::kGamma],
sumall_except_dim<1>(
grad * (data - broadcast<1>(mean, data.shape_)) /
Expand All @@ -197,7 +200,7 @@ class BatchNormOp : public Operator {
Assign(gbias, req[batchnorm::kBeta], sumall_except_dim<1>(grad));
} else {
// use global statistics with freeze moving mean and var.
if (!param_.fix_gamma) {
if (!param_.fix_gamma || !param_.fix_linear_trans) {
Assign(gslope, req[batchnorm::kGamma],
sumall_except_dim<1>(
grad * (data - broadcast<1>(moving_mean, data.shape_)) /
Expand Down

0 comments on commit 633e707

Please sign in to comment.