@@ -48,7 +48,7 @@ Expr ExpandToMatchInput(Expr data, int ndim, Array<Integer> axes) {
   return expand_dims(data, expand_axes);
 }
 
-Tuple DecomposeBatchNorm(const Call& call) {
+Tuple SimplifyBatchNormInference(const Call& call) {
   auto attrs = call->attrs.as<BatchNormAttrs>();
   ICHECK_NOTNULL(attrs);
 
@@ -75,18 +75,14 @@ Tuple DecomposeBatchNorm(const Call& call) {
   return Tuple({out, call->args[3], call->args[4]});
 }
 
-Expr MutateBatchNormForTraining(Call call) {
+Tuple SimplifyBatchNormTraining(const Call& call) {
   auto attrs = call->attrs.as<BatchNormAttrs>();
   ICHECK_NOTNULL(attrs);
 
-  ICHECK_EQ(call->args.size(), 5);
   Expr data = call->args[0];
+  TensorStructInfo sinfo = MatchTensorStructInfo(data);
   Expr gamma = call->args[1];
   Expr beta = call->args[2];
-  Expr moving_mean = call->args[3];
-  Expr moving_var = call->args[4];
-
-  TensorStructInfo sinfo = MatchTensorStructInfo(data);
 
   Array<Integer> reduce_axes;
   for (int i = 0; i < sinfo->ndim; ++i) {
@@ -96,21 +92,35 @@ Expr MutateBatchNormForTraining(Call call) {
   }
 
   Expr data_mean = mean(data, reduce_axes, false);
+  Expr data_mean_rs = ExpandToMatchInput(data_mean, sinfo->ndim, {attrs->axis});
   Expr data_var = variance(data, reduce_axes, false);
+  Expr data_var_rs = ExpandToMatchInput(data_var, sinfo->ndim, {attrs->axis});
 
-  Expr momentum = MakeConstantScalar(attrs->momentum, sinfo->dtype);
-  Expr one_minus_mom = MakeConstantScalar(1 - attrs->momentum, sinfo->dtype);
+  // output = (x - mean) / sqrt(var + epsilon) * gamma + beta
+  Expr epsilon = MakeConstantScalar(attrs->epsilon, sinfo->dtype);
+  Expr sqrt_var = sqrt(add(data_var_rs, epsilon));
+  Expr out = divide(subtract(data, data_mean_rs), sqrt_var);
 
-  Expr new_moving_mean = add(multiply(one_minus_mom, moving_mean), multiply(momentum, data_mean));
-  Expr new_moving_var = add(multiply(one_minus_mom, moving_var), multiply(momentum, data_var));
+  if (attrs->scale) {
+    out = multiply(out, ExpandToMatchInput(gamma, sinfo->ndim, {attrs->axis}));
+  }
+  if (attrs->center) {
+    out = add(out, ExpandToMatchInput(beta, sinfo->ndim, {attrs->axis}));
+  }
 
-  call.CopyOnWrite()->args = {data, gamma, beta, data_mean, data_var};
-  // return call;
+  Expr moving_mean = call->args[3];
+  Expr moving_var = call->args[4];
+  Expr momentum = MakeConstantScalar(attrs->momentum, sinfo->dtype);
+  Expr one_minus_mom = MakeConstantScalar(1 - attrs->momentum, sinfo->dtype);
 
-  return relax::Tuple({TupleGetItem(call, 0), new_moving_mean, new_moving_var});
+  return Tuple({
+      out,
+      add(multiply(one_minus_mom, moving_mean), multiply(momentum, data_mean)),
+      add(multiply(one_minus_mom, moving_var), multiply(momentum, data_var)),
+  });
 }
 
-Expr DecomposeLayerNorm(const Call& call) {
+Expr SimplifyLayerNorm(const Call& call) {
   auto attrs = call->attrs.as<LayerNormAttrs>();
   ICHECK_NOTNULL(attrs);
 
@@ -162,92 +172,92 @@ Expr TensorToShape(const Call& call_node, const BlockBuilder& builder) {
   return ShapeExpr(shape_var);
 }
 
-/*! \brief Update operators that have a training-specific form
- *
- * Some operators, such as relax.op.batch_norm, need additional
- * processing when being run for training. This mutator applies any mutations required
- */
-class TrainingOperatorMutator : public ExprMutator {
- private:
-  using ExprMutator::VisitExpr_;
+class OpDecomposer : public ExprMutator {
+ public:
+  constexpr static const char* kModeInference = "inference";
+  constexpr static const char* kModeTraining = "training";
 
-  Expr VisitExpr_(const CallNode* call_node) final {
-    Call call = Downcast<Call>(VisitExprPostOrder_(call_node));
-    if (call->op == batch_norm_op_) {
-      return MutateBatchNormForTraining(call);
-    } else if (call->op == layer_norm_op_) {
-      // Here we only decompose LayerNorm in training because it is more efficient as a single op.
-      // In the future maybe we can also remove this decomposition during training.
-      return DecomposeLayerNorm(call);
-    } else {
-      return call;
-    }
+  explicit OpDecomposer(String mode) : ExprMutator(), mode_(mode) {
+    CHECK(mode == kModeInference || mode == kModeTraining)
+        << "The argument mode must be one of the following values: \"inference\", \"training\".";
   }
 
-  /* composite operator list */
-  const Op& batch_norm_op_ = Op::Get("relax.nn.batch_norm");
-  const Op& layer_norm_op_ = Op::Get("relax.nn.layer_norm");
-};
-
-class OpDecomposer : public ExprMutator {
  private:
   using ExprMutator::VisitExpr_;
 
   Expr VisitExpr_(const CallNode* call_node) final {
     Call call = Downcast<Call>(VisitExprPostOrder_(call_node));
     if (call->op == batch_norm_op_) {
-      return DecomposeBatchNorm(call);
+      if (mode_ == kModeInference) {
+        return SimplifyBatchNormInference(call);
+      } else {
+        ICHECK_EQ(mode_, kModeTraining);
+        return SimplifyBatchNormTraining(call);
+      }
+    } else if (call->op == layer_norm_op_ && mode_ == kModeTraining) {
+      // Here we only decompose LayerNorm in training because it is more efficient as a single op.
+      // In the future maybe we can also remove this decomposition during training.
+      return SimplifyLayerNorm(call);
     } else if (call->op == tensor_to_shape_op_) {
       return TensorToShape(call, builder_);
     }
     return call;
   }
 
+  const String mode_;
+
   /* composite operator list */
   const Op& batch_norm_op_ = Op::Get("relax.nn.batch_norm");
+  const Op& layer_norm_op_ = Op::Get("relax.nn.layer_norm");
   const Op& tensor_to_shape_op_ = Op::Get("relax.tensor_to_shape");
 };
 
-namespace transform {
+IRModule Decompose(IRModule mod, Optional<String> func_name, String mode) {
+  auto op_decomposer = OpDecomposer(mode);
 
-Pass MutateOpsForTraining() {
-  auto pass_func = [](Function func, IRModule, PassContext) -> Function {
-    TrainingOperatorMutator mutator;
-    return Downcast<Function>(mutator(func));
-  };
-  return CreateFunctionPass(/*pass_function=*/pass_func,
-                            /*opt_level=*/0,
-                            /*pass_name=*/"MutateOpsForTraining",
-                            /*required=*/{});
-}
+  IRModuleNode* new_module = mod.CopyOnWrite();
 
-Pass DecomposeOps() {
-  auto pass_func = [](Function func, IRModule, PassContext) -> Function {
-    OpDecomposer mutator;
-    return Downcast<Function>(mutator(func));
-  };
-  return CreateFunctionPass(/*pass_function=*/pass_func,
-                            /*opt_level=*/0,
-                            /*pass_name=*/"DecomposeOps",
-                            /*required=*/{});
+  if (!func_name.defined()) {  // simplify all functions
+    Map<GlobalVar, BaseFunc> functions = mod->functions;
+    for (const auto& func_pr : functions) {
+      if (const auto* relax_f = func_pr.second.as<FunctionNode>()) {
+        Function f = Downcast<Function>(op_decomposer(GetRef<Function>(relax_f)));
+        new_module->Update(func_pr.first, f);
+      }
+    }
+  } else {  // simplify specified function
+    auto* func_ptr = mod->Lookup(func_name.value()).as<FunctionNode>();
+    CHECK(func_ptr) << func_name.value() << " is not a Relax Function";
+    auto gvar = mod->GetGlobalVar(func_name.value());
+    auto func = GetRef<Function>(func_ptr);
+    func = Downcast<Function>(op_decomposer(func));
+    new_module->Update(gvar, func);
+  }
+
+  return GetRef<IRModule>(new_module);
 }
 
+namespace transform {
 Pass DecomposeOpsForInference(Optional<String> func_name) {
-  if (func_name) {
-    return ApplyPassToFunction(DecomposeOps(), func_name.value());
-  } else {
-    return DecomposeOps();
-  }
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule mod,
+                                                                            PassContext pc) {
+    return Decompose(mod, func_name, OpDecomposer::kModeInference);
+  };
+  return CreateModulePass(/*pass_function=*/pass_func,
+                          /*opt_level=*/0,
+                          /*pass_name=*/"DecomposeOpsForInference",
+                          /*required=*/{});
 }
 
 Pass DecomposeOpsForTraining(Optional<String> func_name) {
-  auto module_pass = tvm::transform::Sequential({MutateOpsForTraining(), DecomposeOps()},
-                                                "DecomposeOpsForTraining");
-  if (func_name) {
-    return ApplyPassToFunction(module_pass, func_name.value());
-  } else {
-    return module_pass;
-  }
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule mod,
+                                                                            PassContext pc) {
+    return Decompose(mod, func_name, OpDecomposer::kModeTraining);
+  };
+  return CreateModulePass(/*pass_function=*/pass_func,
+                          /*opt_level=*/0,
+                          /*pass_name=*/"DecomposeOpsForTraining",
+                          /*required=*/{});
 }
 
 TVM_REGISTER_GLOBAL("relax.transform.DecomposeOpsForInference")
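
Usage sketch (not part of the diff): a minimal example of invoking the two reworked passes from C++, assuming the declarations in tvm/relax/transform.h introduced above. The wrapper function names and the "main" function name are illustrative only.

  #include <tvm/ir/module.h>
  #include <tvm/relax/transform.h>

  // Inference mode: batch_norm is folded against the stored moving_mean /
  // moving_var statistics; layer_norm is kept as a single op. Passing NullOpt
  // applies the pass to every Relax function in the module.
  tvm::IRModule DecomposeForInference(tvm::IRModule mod) {
    return tvm::relax::transform::DecomposeOpsForInference(tvm::NullOpt)(mod);
  }

  // Training mode: batch_norm is rewritten to use per-batch statistics and to
  // return updated moving averages; layer_norm is decomposed as well. Passing
  // a function name restricts the rewrite to that one function.
  tvm::IRModule DecomposeForTraining(tvm::IRModule mod) {
    return tvm::relax::transform::DecomposeOpsForTraining(tvm::String("main"))(mod);
  }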