diff --git a/src/operator/batch_norm-inl.h b/src/operator/batch_norm-inl.h
index 01a478429ce9..89f5327c9722 100755
--- a/src/operator/batch_norm-inl.h
+++ b/src/operator/batch_norm-inl.h
@@ -31,6 +31,7 @@ struct BatchNormParam : public dmlc::Parameter<BatchNormParam> {
   float eps;
   float momentum;
   bool fix_gamma;
+  bool fix_linear_trans;
   bool use_global_stats;
   bool output_mean_var;
   DMLC_DECLARE_PARAMETER(BatchNormParam) {
@@ -40,6 +41,8 @@ struct BatchNormParam : public dmlc::Parameter<BatchNormParam> {
     .describe("Momentum for moving average");
     DMLC_DECLARE_FIELD(fix_gamma).set_default(true)
     .describe("Fix gamma while training");
+    DMLC_DECLARE_FIELD(fix_linear_trans).set_default(false)
+    .describe("Fix the linear (scale and shift) transformation while training");
     DMLC_DECLARE_FIELD(use_global_stats).set_default(false)
     .describe("Whether use global moving statistics instead of local batch-norm. "
               "This will force change batch-norm into a scale shift operator.");
@@ -180,7 +183,7 @@ class BatchNormOp : public Operator {
         tmp *= gvar;
         gmean += tmp;
         // assign
-        if (!param_.fix_gamma) {
+        if (!param_.fix_gamma && !param_.fix_linear_trans) {
           Assign(gslope, req[batchnorm::kGamma],
                  sumall_except_dim<1>(
                      grad * (data - broadcast<1>(mean, data.shape_)) /
@@ -197,7 +200,7 @@ class BatchNormOp : public Operator {
         Assign(gbias, req[batchnorm::kBeta], sumall_except_dim<1>(grad));
       } else {
         // use global statistics with freeze moving mean and var.
-        if (!param_.fix_gamma) {
+        if (!param_.fix_gamma && !param_.fix_linear_trans) {
           Assign(gslope, req[batchnorm::kGamma],
                  sumall_except_dim<1>(
                      grad * (data - broadcast<1>(moving_mean, data.shape_)) /
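
For reference, here is a minimal standalone sketch of how the two flags combine to gate the gamma gradient in the backward pass; it is not part of the patch, and the names Flags and gamma_gradient_enabled are hypothetical, introduced only to mirror the condition guarding Assign(gslope, ...) in BatchNormOp::Backward. The gradient with respect to gamma is computed only when neither fix_gamma nor the new fix_linear_trans freezes the scale.

#include <cassert>

// Hypothetical model of the flag logic above; Flags and
// gamma_gradient_enabled do not exist in MXNet. They only mirror the
// condition guarding the gamma-gradient Assign in the backward pass.
struct Flags {
  bool fix_gamma;         // default true in BatchNormParam
  bool fix_linear_trans;  // default false, added by this patch
};

// True iff the backward pass should compute the gradient w.r.t. gamma.
bool gamma_gradient_enabled(const Flags& f) {
  return !f.fix_gamma && !f.fix_linear_trans;
}

int main() {
  assert(!gamma_gradient_enabled({true, false}));   // defaults: gamma stays frozen
  assert( gamma_gradient_enabled({false, false}));  // both off: gamma is learned
  assert(!gamma_gradient_enabled({false, true}));   // fix_linear_trans freezes gamma too
  assert(!gamma_gradient_enabled({true, true}));    // both set: frozen
  return 0;
}

Gating with a logical AND of the two negated flags preserves the existing default behavior (fix_gamma=true alone still freezes gamma) while letting fix_linear_trans freeze the scale independently of fix_gamma.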