
Commit b0b8746

Revert "[Unity] Split DecomposeOpsForTraining into two steps" (#16442)
1 parent 8f2e820 commit b0b8746

File tree: 4 files changed (+121 additions, -162 deletions)

4 files changed

+121
-162
lines changed

include/tvm/ir/transform.h

Lines changed: 0 additions & 25 deletions
@@ -525,31 +525,6 @@ TVM_DLL Pass CreateModulePass(
     const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func, int opt_level,
     String name, Array<runtime::String> required, bool traceable = false);
 
-/*
- * \brief Utility to apply a pass to specific functions in an IRModule
- *
- * TVM uses IRModule to IRModule transformations at all stages of
- * lowering. These transformations may be useful when hand-writing an
- * optimized model, or to perform optimizations on specific kernels
- * within an IRModule. This utility allows a pass to be applied to a
- * specified function, without altering other functions in the module.
- *
- * \param pass The IRModule to IRModule pass to be applied.
- *
- * \param func_name_regex A regex used to select the functions to be
- * updated. The pass will be applied to all functions whose name
- * matches the regex.
- *
- * \param error_if_no_function_matches_regex Specifies the behavior if
- * an IRModule does not contain any function matching the provided
- * regex. If true, an error will be raised. If false (default),
- * the IRModule will be returned unmodified.
- *
- * \return The modified IRModule to IRModule pass.
- */
-TVM_DLL Pass ApplyPassToFunction(Pass pass, String func_name_regex,
-                                 bool error_if_no_function_matches_regex = false);
-
 /*!
  * \brief A special trace pass that prints the header and IR to LOG(INFO).
  * \param header The header to be attached to the output.

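The declaration removed above documented a general utility: wrap an IRModule-to-IRModule pass so that it only touches functions whose global name matches a regex, leaving the rest of the module unchanged. As a rough illustration of that pattern only (a hypothetical Python sketch, not the TVM API being removed; `scoped_by_regex` and its internals are invented for illustration):

```python
import re

import tvm


def scoped_by_regex(inner_pass, func_name_regex: str):
    """Hypothetical helper mirroring the removed ApplyPassToFunction:
    run `inner_pass` on a submodule of matching functions, then merge
    the result back into the original module."""
    pattern = re.compile(func_name_regex)

    @tvm.transform.module_pass(opt_level=0, name="ApplyPassTo" + func_name_regex)
    def _scoped(mod, _ctx):
        # Collect only the functions whose name matches the regex.
        subset = tvm.IRModule()
        for gvar, func in mod.functions.items():
            if pattern.fullmatch(gvar.name_hint):
                subset[gvar] = func
        # Apply the wrapped pass to the subset and merge the updated
        # functions back, leaving non-matching functions untouched.
        if len(subset.functions) > 0:
            mod.update(inner_pass(subset))
        return mod

    return _scoped
```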
src/ir/transform.cc

Lines changed: 0 additions & 31 deletions
@@ -31,7 +31,6 @@
 
 #include <chrono>
 #include <iomanip>
-#include <regex>
 #include <stack>
 #include <unordered_set>
 
@@ -532,36 +531,6 @@ Pass CreateModulePass(const runtime::TypedPackedFunc<IRModule(IRModule, PassCont
   return ModulePass(pass_func, pass_info);
 }
 
-Pass ApplyPassToFunction(Pass pass, String func_name_regex,
-                         bool error_if_no_function_matches_regex) {
-  auto pass_name =
-      static_cast<const std::stringstream&>(std::stringstream() << "ApplyPassTo" << func_name_regex)
-          .str();
-  std::regex regex(func_name_regex.operator std::string());
-
-  auto pass_func = [pass, regex](IRModule mod, PassContext) -> IRModule {
-    IRModule subset;
-
-    for (const auto& [gvar, func] : mod->functions) {
-      std::string name = gvar->name_hint;
-      if (std::regex_match(name, regex)) {
-        subset->Add(gvar, func);
-      }
-    }
-
-    if (subset->functions.size()) {
-      IRModule new_subset = pass(subset);
-      if (!new_subset.same_as(subset)) {
-        mod.CopyOnWrite()->Update(new_subset);
-      }
-    }
-
-    return mod;
-  };
-
-  return CreateModulePass(pass_func, 0, pass_name, {});
-}
-
 TVM_REGISTER_NODE_TYPE(PassInfoNode);
 
 TVM_REGISTER_GLOBAL("transform.PassInfo")

src/relax/transform/decompose_ops.cc

Lines changed: 83 additions & 73 deletions
@@ -48,7 +48,7 @@ Expr ExpandToMatchInput(Expr data, int ndim, Array<Integer> axes) {
   return expand_dims(data, expand_axes);
 }
 
-Tuple DecomposeBatchNorm(const Call& call) {
+Tuple SimplifyBatchNormInference(const Call& call) {
   auto attrs = call->attrs.as<BatchNormAttrs>();
   ICHECK_NOTNULL(attrs);
 
@@ -75,18 +75,14 @@ Tuple DecomposeBatchNorm(const Call& call) {
   return Tuple({out, call->args[3], call->args[4]});
 }
 
-Expr MutateBatchNormForTraining(Call call) {
+Tuple SimplifyBatchNormTraining(const Call& call) {
   auto attrs = call->attrs.as<BatchNormAttrs>();
   ICHECK_NOTNULL(attrs);
 
-  ICHECK_EQ(call->args.size(), 5);
   Expr data = call->args[0];
+  TensorStructInfo sinfo = MatchTensorStructInfo(data);
   Expr gamma = call->args[1];
   Expr beta = call->args[2];
-  Expr moving_mean = call->args[3];
-  Expr moving_var = call->args[4];
-
-  TensorStructInfo sinfo = MatchTensorStructInfo(data);
 
   Array<Integer> reduce_axes;
   for (int i = 0; i < sinfo->ndim; ++i) {
@@ -96,21 +92,35 @@ Expr MutateBatchNormForTraining(Call call) {
   }
 
   Expr data_mean = mean(data, reduce_axes, false);
+  Expr data_mean_rs = ExpandToMatchInput(data_mean, sinfo->ndim, {attrs->axis});
   Expr data_var = variance(data, reduce_axes, false);
+  Expr data_var_rs = ExpandToMatchInput(data_var, sinfo->ndim, {attrs->axis});
 
-  Expr momentum = MakeConstantScalar(attrs->momentum, sinfo->dtype);
-  Expr one_minus_mom = MakeConstantScalar(1 - attrs->momentum, sinfo->dtype);
+  // output = (x - mean) / sqrt(var + epsilon) * gamma + beta
+  Expr epsilon = MakeConstantScalar(attrs->epsilon, sinfo->dtype);
+  Expr sqrt_var = sqrt(add(data_var_rs, epsilon));
+  Expr out = divide(subtract(data, data_mean_rs), sqrt_var);
 
-  Expr new_moving_mean = add(multiply(one_minus_mom, moving_mean), multiply(momentum, data_mean));
-  Expr new_moving_var = add(multiply(one_minus_mom, moving_var), multiply(momentum, data_var));
+  if (attrs->scale) {
+    out = multiply(out, ExpandToMatchInput(gamma, sinfo->ndim, {attrs->axis}));
+  }
+  if (attrs->center) {
+    out = add(out, ExpandToMatchInput(beta, sinfo->ndim, {attrs->axis}));
+  }
 
-  call.CopyOnWrite()->args = {data, gamma, beta, data_mean, data_var};
-  // return call;
+  Expr moving_mean = call->args[3];
+  Expr moving_var = call->args[4];
+  Expr momentum = MakeConstantScalar(attrs->momentum, sinfo->dtype);
+  Expr one_minus_mom = MakeConstantScalar(1 - attrs->momentum, sinfo->dtype);
 
-  return relax::Tuple({TupleGetItem(call, 0), new_moving_mean, new_moving_var});
+  return Tuple({
+      out,
+      add(multiply(one_minus_mom, moving_mean), multiply(momentum, data_mean)),
+      add(multiply(one_minus_mom, moving_var), multiply(momentum, data_var)),
+  });
 }
 
-Expr DecomposeLayerNorm(const Call& call) {
+Expr SimplifyLayerNorm(const Call& call) {
   auto attrs = call->attrs.as<LayerNormAttrs>();
   ICHECK_NOTNULL(attrs);
 
@@ -162,92 +172,92 @@ Expr TensorToShape(const Call& call_node, const BlockBuilder& builder) {
   return ShapeExpr(shape_var);
 }
 
-/*! \brief Update operators that have a training-specific form
- *
- * Some operators, such as relax.op.batch_norm, need additional
- * processing when being run for training. This mutator applies any mutations required
- */
-class TrainingOperatorMutator : public ExprMutator {
- private:
-  using ExprMutator::VisitExpr_;
+class OpDecomposer : public ExprMutator {
+ public:
+  constexpr static const char* kModeInference = "inference";
+  constexpr static const char* kModeTraining = "training";
 
-  Expr VisitExpr_(const CallNode* call_node) final {
-    Call call = Downcast<Call>(VisitExprPostOrder_(call_node));
-    if (call->op == batch_norm_op_) {
-      return MutateBatchNormForTraining(call);
-    } else if (call->op == layer_norm_op_) {
-      // Here we only decompose LayerNorm in training because it is more efficient as a single op.
-      // In the future maybe we can also remove this decomposition during training.
-      return DecomposeLayerNorm(call);
-    } else {
-      return call;
-    }
+  explicit OpDecomposer(String mode) : ExprMutator(), mode_(mode) {
+    CHECK(mode == kModeInference || mode == kModeTraining)
+        << "The argument mode must be one of the following values: \"inference\", \"training\".";
   }
 
-  /* composite opeartor list */
-  const Op& batch_norm_op_ = Op::Get("relax.nn.batch_norm");
-  const Op& layer_norm_op_ = Op::Get("relax.nn.layer_norm");
-};
-
-class OpDecomposer : public ExprMutator {
  private:
   using ExprMutator::VisitExpr_;
 
   Expr VisitExpr_(const CallNode* call_node) final {
     Call call = Downcast<Call>(VisitExprPostOrder_(call_node));
     if (call->op == batch_norm_op_) {
-      return DecomposeBatchNorm(call);
+      if (mode_ == kModeInference) {
+        return SimplifyBatchNormInference(call);
+      } else {
+        ICHECK_EQ(mode_, kModeTraining);
+        return SimplifyBatchNormTraining(call);
+      }
+    } else if (call->op == layer_norm_op_ && mode_ == kModeTraining) {
+      // Here we only decompose LayerNorm in training because it is more efficient as a single op.
+      // In the future maybe we can also remove this decomposition during training.
+      return SimplifyLayerNorm(call);
     } else if (call->op == tensor_to_shape_op_) {
       return TensorToShape(call, builder_);
    }
     return call;
   }
 
+  const String mode_;
+
   /* composite opeartor list */
   const Op& batch_norm_op_ = Op::Get("relax.nn.batch_norm");
+  const Op& layer_norm_op_ = Op::Get("relax.nn.layer_norm");
   const Op& tensor_to_shape_op_ = Op::Get("relax.tensor_to_shape");
 };
 
-namespace transform {
+IRModule Decompose(IRModule mod, Optional<String> func_name, String mode) {
+  auto op_decomposer = OpDecomposer(mode);
 
-Pass MutateOpsForTraining() {
-  auto pass_func = [](Function func, IRModule, PassContext) -> Function {
-    TrainingOperatorMutator mutator;
-    return Downcast<Function>(mutator(func));
-  };
-  return CreateFunctionPass(/*pass_function=*/pass_func,
-                            /*opt_level=*/0,
-                            /*pass_name=*/"MutateOpsForTraining",
-                            /*required=*/{});
-}
+  IRModuleNode* new_module = mod.CopyOnWrite();
 
-Pass DecomposeOps() {
-  auto pass_func = [](Function func, IRModule, PassContext) -> Function {
-    OpDecomposer mutator;
-    return Downcast<Function>(mutator(func));
-  };
-  return CreateFunctionPass(/*pass_function=*/pass_func,
-                            /*opt_level=*/0,
-                            /*pass_name=*/"DecomposeOps",
-                            /*required=*/{});
+  if (!func_name.defined()) {  // simplify all functions
+    Map<GlobalVar, BaseFunc> functions = mod->functions;
+    for (const auto& func_pr : functions) {
+      if (const auto* relax_f = func_pr.second.as<FunctionNode>()) {
+        Function f = Downcast<Function>(op_decomposer(GetRef<Function>(relax_f)));
+        new_module->Update(func_pr.first, f);
+      }
+    }
+  } else {  // simplify specified function
+    auto* func_ptr = mod->Lookup(func_name.value()).as<FunctionNode>();
+    CHECK(func_ptr) << func_name.value() << "is not a Relax Function";
+    auto gvar = mod->GetGlobalVar(func_name.value());
+    auto func = GetRef<Function>(func_ptr);
+    func = Downcast<Function>(op_decomposer(func));
+    new_module->Update(gvar, func);
+  }
+
+  return GetRef<IRModule>(new_module);
 }
 
+namespace transform {
 Pass DecomposeOpsForInference(Optional<String> func_name) {
-  if (func_name) {
-    return ApplyPassToFunction(DecomposeOps(), func_name.value());
-  } else {
-    return DecomposeOps();
-  }
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule mod,
+                                                                            PassContext pc) {
+    return Decompose(mod, func_name, OpDecomposer::kModeInference);
+  };
+  return CreateModulePass(/*pass_function=*/pass_func,
+                          /*opt_level=*/0,
+                          /*pass_name=*/"DecomposeOpsForInference",
+                          /*required=*/{});
 }
 
 Pass DecomposeOpsForTraining(Optional<String> func_name) {
-  auto module_pass = tvm::transform::Sequential({MutateOpsForTraining(), DecomposeOps()},
-                                                "DecomposeOpsForTraining");
-  if (func_name) {
-    return ApplyPassToFunction(module_pass, func_name.value());
-  } else {
-    return module_pass;
-  }
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule mod,
+                                                                            PassContext pc) {
+    return Decompose(mod, func_name, OpDecomposer::kModeTraining);
+  };
  return CreateModulePass(/*pass_function=*/pass_func,
                          /*opt_level=*/0,
                          /*pass_name=*/"DecomposeOpsForTraining",
                          /*required=*/{});
 }
 
 TVM_REGISTER_GLOBAL("relax.transform.DecomposeOpsForInference")
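For reference, the arithmetic that `SimplifyBatchNormTraining` emits above (normalize with the batch statistics, then fold them into the running statistics) can be written out directly. A minimal NumPy sketch under simplifying assumptions: channel-last layout instead of the general attrs->axis handling, and scale/center both enabled; this is not TVM code.

```python
import numpy as np


def batch_norm_training(x, gamma, beta, moving_mean, moving_var,
                        momentum=0.1, epsilon=1e-5):
    """Channel-last mirror of the training-mode decomposition above."""
    # Batch statistics over every axis except the channel axis (last here).
    reduce_axes = tuple(range(x.ndim - 1))
    mean = x.mean(axis=reduce_axes)
    var = x.var(axis=reduce_axes)

    # output = (x - mean) / sqrt(var + epsilon) * gamma + beta
    out = (x - mean) / np.sqrt(var + epsilon) * gamma + beta

    # Running statistics: (1 - momentum) * moving + momentum * batch
    new_moving_mean = (1 - momentum) * moving_mean + momentum * mean
    new_moving_var = (1 - momentum) * moving_var + momentum * var
    return out, new_moving_mean, new_moving_var
```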

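Both single-step passes are registered as module passes above. A hedged sketch of invoking them from Python, assuming the wrappers under tvm.relax.transform mirror the registered globals and the optional func_name argument; the function name "main" is a placeholder:

```python
import tvm
from tvm import relax


def decompose_for_inference(mod: tvm.IRModule) -> tvm.IRModule:
    # Rewrite batch_norm (and tensor_to_shape) in every Relax function.
    return relax.transform.DecomposeOpsForInference()(mod)


def decompose_main_for_training(mod: tvm.IRModule) -> tvm.IRModule:
    # Passing a function name restricts the rewrite to that one function,
    # matching the func_name branch of Decompose() above.
    return relax.transform.DecomposeOpsForTraining("main")(mod)
```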
tests/python/relax/test_transform_decompose_ops.py

Lines changed: 38 additions & 33 deletions
@@ -137,39 +137,44 @@ def main(
         R.Tensor((64,), dtype="float32"),
     ):
         with R.dataflow():
-            # This portion is training-specific, computing the
-            # mean/variance of the dataset.
-            lv = R.mean(x, axis=[0, 2, 3], keepdims=False)
-            lv3 = R.variance(x, axis=[0, 2, 3], keepdims=False)
-
-            # This portion is identical to the batch_norm run during inference
-            lv1 = R.expand_dims(lv, axis=[0, 2, 3])
-            lv2 = R.subtract(x, lv1)
-            lv4 = R.expand_dims(lv3, axis=[0, 2, 3])
-            lv5 = R.add(lv4, R.const(9.9999997473787516e-06, "float32"))
-            lv6 = R.sqrt(lv5)
-            lv7 = R.divide(lv2, lv6)
-            lv8 = R.expand_dims(gamma, axis=[0, 2, 3])
-            lv9 = R.multiply(lv7, lv8)
-            lv10 = R.expand_dims(beta, axis=[0, 2, 3])
-            lv11 = R.add(lv9, lv10)
-            inner_tuple = (lv11, lv, lv3)
-            # This is the result that would be returned from a
-            # batch_norm at inference.
-
-            # However, at training we need to update the moving
-            # mean/variance, and to return those updated values.
-            inner_res = inner_tuple[0]
-            lv12 = R.multiply(R.const(0.89999997615814209, "float32"), moving_mean)
-            lv13 = R.multiply(R.const(0.10000000149011612, "float32"), lv)
-            lv14 = R.add(lv12, lv13)
-            lv15 = R.multiply(R.const(0.89999997615814209, "float32"), moving_var)
-            lv16 = R.multiply(R.const(0.10000000149011612, "float32"), lv3)
-            lv17 = R.add(lv15, lv16)
-            bn = (inner_res, lv14, lv17)
-            gv0 = bn[0]
-            gv1 = bn[1]
-            gv2 = bn[2]
+            lv: R.Tensor((64,), dtype="float32") = R.mean(x, axis=[0, 2, 3], keepdims=False)
+            lv1: R.Tensor((1, 64, 1, 1), dtype="float32") = R.expand_dims(lv, axis=[0, 2, 3])
+            lv2: R.Tensor((1, 64, 112, 112), dtype="float32") = R.subtract(x, lv1)
+            lv3: R.Tensor((64,), dtype="float32") = R.variance(
+                x, axis=[0, 2, 3], keepdims=False
+            )
+            lv4: R.Tensor((1, 64, 1, 1), dtype="float32") = R.expand_dims(lv3, axis=[0, 2, 3])
+            lv5: R.Tensor((1, 64, 1, 1), dtype="float32") = R.add(
+                lv4, R.const(9.9999997473787516e-06, "float32")
+            )
+            lv6: R.Tensor((1, 64, 1, 1), dtype="float32") = R.sqrt(lv5)
+            lv7: R.Tensor((1, 64, 112, 112), dtype="float32") = R.divide(lv2, lv6)
+            lv8: R.Tensor((1, 64, 1, 1), dtype="float32") = R.expand_dims(gamma, axis=[0, 2, 3])
+            lv9: R.Tensor((1, 64, 112, 112), dtype="float32") = R.multiply(lv7, lv8)
+            lv10: R.Tensor((1, 64, 1, 1), dtype="float32") = R.expand_dims(beta, axis=[0, 2, 3])
+            lv11: R.Tensor((1, 64, 112, 112), dtype="float32") = R.add(lv9, lv10)
+            lv12: R.Tensor((64,), dtype="float32") = R.multiply(
+                R.const(0.89999997615814209, "float32"), moving_mean
+            )
+            lv13: R.Tensor((64,), dtype="float32") = R.multiply(
+                R.const(0.10000000149011612, "float32"), lv
+            )
+            lv14: R.Tensor((64,), dtype="float32") = R.add(lv12, lv13)
+            lv15: R.Tensor((64,), dtype="float32") = R.multiply(
+                R.const(0.89999997615814209, "float32"), moving_var
+            )
+            lv16: R.Tensor((64,), dtype="float32") = R.multiply(
+                R.const(0.10000000149011612, "float32"), lv3
+            )
+            lv17: R.Tensor((64,), dtype="float32") = R.add(lv15, lv16)
+            bn: R.Tuple(
+                R.Tensor((1, 64, 112, 112), dtype="float32"),
+                R.Tensor((64,), dtype="float32"),
+                R.Tensor((64,), dtype="float32"),
+            ) = (lv11, lv14, lv17)
+            gv0: R.Tensor((1, 64, 112, 112), dtype="float32") = bn[0]
+            gv1: R.Tensor((64,), dtype="float32") = bn[1]
+            gv2: R.Tensor((64,), dtype="float32") = bn[2]
             R.output(gv0, gv1, gv2)
         return (gv0, gv1, gv2)
 

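The expected module above is the training-mode result. A test of this shape typically applies the pass to a hand-written input module and compares the output structurally; a minimal sketch of that harness, with `Before` and `Expected` as placeholder ir_module names rather than identifiers taken from this file:

```python
import tvm
from tvm import relax


def check_training_decomposition(Before: tvm.IRModule, Expected: tvm.IRModule):
    # Run the pass under test on the input module.
    After = relax.transform.DecomposeOpsForTraining()(Before)
    # The decomposed module should match the expectation up to
    # structural equality (variable names are not significant).
    tvm.ir.assert_structural_equal(After, Expected)
```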