From 633e70718e328ea12fe516fa8786fc92d31e5447 Mon Sep 17 00:00:00 2001
From: PengfeiChen
Date: Fri, 31 Mar 2017 18:09:03 -0700
Subject: [PATCH] merge segmentation dependencies

---
 src/operator/batch_norm-inl.h | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/operator/batch_norm-inl.h b/src/operator/batch_norm-inl.h
index 31f6531dd49c..08570bb51022 100644
--- a/src/operator/batch_norm-inl.h
+++ b/src/operator/batch_norm-inl.h
@@ -31,6 +31,7 @@ struct BatchNormParam : public dmlc::Parameter<BatchNormParam> {
   float eps;
   float momentum;
   bool fix_gamma;
+  bool fix_linear_trans;
   bool use_global_stats;
   bool output_mean_var;
   DMLC_DECLARE_PARAMETER(BatchNormParam) {
@@ -40,6 +41,8 @@
     .describe("Momentum for moving average");
     DMLC_DECLARE_FIELD(fix_gamma).set_default(true)
     .describe("Fix gamma while training");
+    DMLC_DECLARE_FIELD(fix_linear_trans).set_default(true)
+    .describe("Fix linear transformation while training");
     DMLC_DECLARE_FIELD(use_global_stats).set_default(false)
     .describe("Whether use global moving statistics instead of local batch-norm. "
               "This will force change batch-norm into a scale shift operator.");
@@ -180,7 +183,7 @@ class BatchNormOp : public Operator {
       tmp *= gvar;
       gmean += tmp;
       // assign
-      if (!param_.fix_gamma) {
+      if (!param_.fix_gamma || !param_.fix_linear_trans) {
         Assign(gslope, req[batchnorm::kGamma],
                sumall_except_dim<1>(
                    grad * (data - broadcast<1>(mean, data.shape_)) /
@@ -197,7 +200,7 @@
       Assign(gbias, req[batchnorm::kBeta], sumall_except_dim<1>(grad));
     } else {
       // use global statistics with freeze moving mean and var.
-      if (!param_.fix_gamma) {
+      if (!param_.fix_gamma || !param_.fix_linear_trans) {
        Assign(gslope, req[batchnorm::kGamma],
               sumall_except_dim<1>(
                   grad * (data - broadcast<1>(moving_mean, data.shape_)) /
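
Note on the change (not part of the patch): the gamma gradient `gslope` is now computed whenever `!param_.fix_gamma || !param_.fix_linear_trans` holds. Because `fix_linear_trans` defaults to `true`, the default behaviour matches the old `!param_.fix_gamma` check; setting `fix_linear_trans=false` forces the gamma gradient to be produced even when `fix_gamma` stays `true`. The standalone C++ sketch below uses a hypothetical helper, `ComputesGammaGrad`, purely to illustrate that gating; it is not code from the patch or from MXNet.

// Minimal sketch of the gating introduced by the patch (hypothetical helper,
// not MXNet code). Both flags default to true in BatchNormParam.
#include <iostream>

bool ComputesGammaGrad(bool fix_gamma, bool fix_linear_trans) {
  // Mirrors the patched condition guarding the Assign(gslope, ...) call.
  return !fix_gamma || !fix_linear_trans;
}

int main() {
  std::cout << ComputesGammaGrad(true, true)  << "\n";  // 0: defaults, gamma grad skipped
  std::cout << ComputesGammaGrad(false, true) << "\n";  // 1: old fix_gamma=false path
  std::cout << ComputesGammaGrad(true, false) << "\n";  // 1: new flag re-enables gamma grad
  return 0;
}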