Skip to content

Commit 65e9808

Browse files
committed
[Relax] Refactor PatternRewriter into separate Block/Expr mutators
Prior to this commit, the `PatternRewriter` mutator handled pattern rewriting at either the expression level (`rewrite_call`) or the dataflow block level (`rewrite_bindings`). These two functionalities had different external APIs, defined diffierent member variables, and visited different IR nodes. In effect, it had two entirely independent implementations, which just happened to be implemented within the same class. This commit refactors the single `PatternRewriter` mutator into separate `BlockPatternRewriter` and `ExprPatternRewriter` mutators.
1 parent feb1043 commit 65e9808

File tree

2 files changed

+140
-102
lines changed

2 files changed

+140
-102
lines changed

include/tvm/relax/dataflow_matcher.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,9 @@ TVM_DLL Optional<Map<DFPattern, Var>> MatchGraph(const PatternContext& ctx,
6767
* \param f The function to rewrite
6868
* \return The rewritten or the input function, depending on the pattern matching result.
6969
*/
70-
TVM_DLL Function RewriteBindings(const PatternContext& ctx, PackedFunc rewriter, Function f);
70+
TVM_DLL Function RewriteBindings(
71+
const PatternContext& ctx,
72+
TypedPackedFunc<Map<Var, Expr>(Map<DFPattern, Var>, Map<Var, Expr>)> rewriter, Function f);
7173

7274
/**
7375
* \brief Rewrite a function with the given pattern and the rewriter function.

src/relax/ir/dataflow_matcher.cc

Lines changed: 137 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -973,102 +973,33 @@ TVM_REGISTER_GLOBAL("relax.dpl.match_dfb")
973973
});
974974

975975
/*!
976-
* \brief Apply pattern matching to each call node and dataflow block, and replace matching ones
976+
* \brief Apply pattern matching to each dataflow block, replacing matches
977977
* with the output of a user-provided rewriter function.
978978
*/
979-
class PatternRewriter : ExprMutator {
979+
class BlockPatternRewriter : ExprMutator {
980980
public:
981981
using ExprMutator::VisitBindingBlock_;
982982
using ExprMutator::VisitExpr_;
983983

984-
PatternRewriter(DFPattern pat, PackedFunc rewriter_func,
985-
const std::unordered_set<const VarNode*>& params)
986-
: pattern_(pat), rewriter_func_(rewriter_func), params_(params) {}
987-
988-
PatternRewriter(const PatternContext& ctx, PackedFunc rewriter_func,
989-
const std::unordered_set<const VarNode*>& params)
990-
: ctx_(ctx), rewriter_func_(rewriter_func), params_(params) {}
984+
BlockPatternRewriter(
985+
const PatternContext& ctx,
986+
TypedPackedFunc<Map<Var, Expr>(Map<DFPattern, Var>, Map<Var, Expr>)> rewriter_func)
987+
: ctx_(ctx), rewriter_func_(rewriter_func) {}
991988

992989
template <typename PatternType>
993-
static Function Run(PatternType pat, PackedFunc rewriter_func, Function f) {
994-
std::unordered_set<const VarNode*> params;
995-
for (const auto& p : f->params) {
996-
params.insert(p.get());
997-
}
998-
PatternRewriter rewriter(pat, rewriter_func, params);
999-
return Downcast<Function>(RemoveAllUnused(rewriter.VisitExpr(f)));
1000-
}
1001-
1002-
Expr VisitExpr_(const SeqExprNode* seq) override {
1003-
if (ctx_) {
1004-
return ExprMutator::VisitExpr_(seq);
1005-
}
1006-
1007-
auto cache = bindings_;
1008-
SeqExpr prev = GetRef<SeqExpr>(seq);
1009-
1010-
StructuralEqual struct_equal;
1011-
1012-
while (true) {
1013-
SeqExpr next = Downcast<SeqExpr>(builder_->Normalize(ExprMutator::VisitExpr_(prev.get())));
1014-
if (struct_equal(prev, next)) {
1015-
return std::move(next);
1016-
}
1017-
1018-
// Canonicalization may result in two previously-different
1019-
// expressions being recognized as identical. Elimination of
1020-
// common subexpressions may result in trival var-to-var
1021-
// bindings that can be canonicalized. Therefore, iterate the
1022-
// simplification steps until converged.
1023-
while (true) {
1024-
auto start_of_loop = next;
1025-
next = Downcast<SeqExpr>(CanonicalizeBindings(next));
1026-
next = Downcast<SeqExpr>(EliminateCommonSubexpr(next));
1027-
next = Downcast<SeqExpr>(RemoveAllUnused(next));
1028-
if (struct_equal(start_of_loop, next)) {
1029-
break;
1030-
}
1031-
}
1032-
1033-
if (struct_equal(prev, next)) {
1034-
return std::move(next);
1035-
}
1036-
1037-
// Reset all knowledge of bindings that were collected from
1038-
// this DataflowBlock. The collected bindings are only after
1039-
// the point where they were collected, and we are repeating
1040-
// the mutation of this DataflowBlock.
1041-
bindings_ = cache;
1042-
prev = next;
1043-
}
990+
static Function Run(
991+
PatternType pat,
992+
TypedPackedFunc<Map<Var, Expr>(Map<DFPattern, Var>, Map<Var, Expr>)> rewriter_func,
993+
Function func) {
994+
BlockPatternRewriter rewriter(pat, rewriter_func);
995+
996+
func = Downcast<Function>(rewriter(func));
997+
func = Downcast<Function>(RemoveAllUnused(func));
998+
return func;
1044999
}
10451000

10461001
BindingBlock VisitBindingBlock_(const DataflowBlockNode* block_node) override {
1047-
if (ctx_) {
1048-
return RewriteDataflowBlockFixedPoint(GetRef<DataflowBlock>(block_node));
1049-
} else {
1050-
return ExprMutator::VisitBindingBlock_(block_node);
1051-
}
1052-
}
1053-
1054-
void VisitBinding_(const VarBindingNode* binding) override {
1055-
auto expr = VisitExpr(binding->value);
1056-
bindings_.Set(binding->var, expr);
1057-
ReEmitBinding(binding, expr);
1058-
}
1059-
1060-
Expr VisitExpr(const Expr& expr) override {
1061-
auto node = ExprMutator::VisitExpr(expr);
1062-
1063-
if (pattern_) {
1064-
if (auto matches_opt = ExtractMatchedExpr(pattern_.value(), node, bindings_)) {
1065-
Expr rewritten_expr = rewriter_func_(node, matches_opt.value());
1066-
if (!rewritten_expr.same_as(node)) {
1067-
return builder_->Normalize(rewritten_expr);
1068-
}
1069-
}
1070-
}
1071-
return node;
1002+
return RewriteDataflowBlockFixedPoint(GetRef<DataflowBlock>(block_node));
10721003
}
10731004

10741005
private:
@@ -1106,7 +1037,7 @@ class PatternRewriter : ExprMutator {
11061037
BindingBlock RewriteDataflowBlockFixedPoint(BindingBlock block) {
11071038
auto df_block = Downcast<DataflowBlock>(block);
11081039
Map<Var, Expr> bindings = AnalyzeVar2Value(df_block);
1109-
if (auto matches = MatchGraph(ctx_.value(), df_block, bindings)) {
1040+
if (auto matches = MatchGraph(ctx_, df_block, bindings)) {
11101041
builder_->BeginDataflowBlock();
11111042
Map<Var, Expr> replacements = rewriter_func_(matches.value(), bindings);
11121043

@@ -1140,34 +1071,139 @@ class PatternRewriter : ExprMutator {
11401071
return block;
11411072
}
11421073

1143-
/*! \brief The pattern for rewriting call nodes */
1144-
Optional<DFPattern> pattern_;
11451074
/*! \brief The pattern constraint contexts for rewriting dataflow blocks */
1146-
Optional<PatternContext> ctx_;
1075+
PatternContext ctx_;
11471076
/*!
11481077
* \brief The user-provided rewriter function. Its signature and semantics are:
1149-
* - (Call, Map<DFPattern, Expr>) -> Call for call node rewriting. Given the matched
1150-
* call node and the map of patterns and matched expressions, it should return a new call node
1151-
* to replace the original one or the original matched call node as is.
1152-
* - (Map<DFPattern, Var>, Map<Var, Expr>) -> Map<Var, Expr> for dataflow block rewriting.
1153-
* Given the map of patterns and corresponding variables (bound variables or parameters),
1154-
* it should return a map that specifies new values for matched bound variables. It can refer
1078+
*
1079+
* - (Map<DFPattern, Var>, Map<Var, Expr>) -> Map<Var, Expr>
1080+
*
1081+
* Given the map of patterns and corresponding variables (bound
1082+
* variables or parameters), it should return a map that
1083+
* specifies new values for matched bound variables. It can refer
11551084
* to the passed bindings to create the replacement expressions.
11561085
*/
1157-
PackedFunc rewriter_func_;
1158-
std::unordered_set<const VarNode*> params_;
1086+
TypedPackedFunc<Map<Var, Expr>(Map<DFPattern, Var>, Map<Var, Expr>)> rewriter_func_;
1087+
};
1088+
1089+
/*!
1090+
* \brief Apply pattern matching to each expression, replacing
1091+
* matches with the output of a user-provided rewriter function.
1092+
*/
1093+
class ExprPatternRewriter : ExprMutator {
1094+
public:
1095+
using ExprMutator::VisitBindingBlock_;
1096+
using ExprMutator::VisitExpr_;
1097+
1098+
ExprPatternRewriter(DFPattern pat,
1099+
TypedPackedFunc<Expr(Expr, Map<DFPattern, Expr>)> rewriter_func)
1100+
: pattern_(pat), rewriter_func_(rewriter_func) {}
1101+
1102+
template <typename PatternType>
1103+
static Function Run(PatternType pat,
1104+
TypedPackedFunc<Expr(Expr, Map<DFPattern, Expr>)> rewriter_func,
1105+
Function func) {
1106+
ExprPatternRewriter rewriter(pat, rewriter_func);
1107+
func = Downcast<Function>(rewriter(func));
1108+
func = Downcast<Function>(RemoveAllUnused(func));
1109+
return func;
1110+
}
1111+
1112+
Expr VisitExpr_(const SeqExprNode* seq) override {
1113+
auto cache = bindings_;
1114+
SeqExpr prev = GetRef<SeqExpr>(seq);
1115+
1116+
StructuralEqual struct_equal;
1117+
1118+
while (true) {
1119+
SeqExpr next = Downcast<SeqExpr>(builder_->Normalize(ExprMutator::VisitExpr_(prev.get())));
1120+
if (struct_equal(prev, next)) {
1121+
return std::move(next);
1122+
}
1123+
1124+
// Canonicalization may result in two previously-different
1125+
// expressions being recognized as identical. Elimination of
1126+
// common subexpressions may result in trival var-to-var
1127+
// bindings that can be canonicalized. Therefore, iterate the
1128+
// simplification steps until converged.
1129+
while (true) {
1130+
auto start_of_loop = next;
1131+
next = Downcast<SeqExpr>(CanonicalizeBindings(next));
1132+
next = Downcast<SeqExpr>(EliminateCommonSubexpr(next));
1133+
next = Downcast<SeqExpr>(RemoveAllUnused(next));
1134+
if (struct_equal(start_of_loop, next)) {
1135+
break;
1136+
}
1137+
}
1138+
1139+
if (struct_equal(prev, next)) {
1140+
return std::move(next);
1141+
}
1142+
1143+
// Reset all knowledge of bindings that were collected from
1144+
// this SeqExpr. The collected bindings are only after
1145+
// the point where they were collected, and we are repeating
1146+
// the mutation of this SeqExpr.
1147+
bindings_ = cache;
1148+
prev = next;
1149+
}
1150+
}
1151+
1152+
void VisitBinding_(const VarBindingNode* binding) override {
1153+
auto expr = VisitExpr(binding->value);
1154+
bindings_.Set(binding->var, expr);
1155+
ReEmitBinding(binding, expr);
1156+
}
1157+
1158+
Expr VisitExpr(const Expr& expr) override {
1159+
auto node = ExprMutator::VisitExpr(expr);
1160+
1161+
if (auto matches_opt = ExtractMatchedExpr(pattern_, node, bindings_)) {
1162+
Expr rewritten_expr = rewriter_func_(node, matches_opt.value());
1163+
if (!rewritten_expr.same_as(node)) {
1164+
return builder_->Normalize(rewritten_expr);
1165+
}
1166+
}
1167+
1168+
return node;
1169+
}
1170+
1171+
private:
1172+
/*! \brief The pattern for rewriting call nodes */
1173+
DFPattern pattern_;
1174+
/*!
1175+
* \brief The user-provided rewriter function. Its signature and semantics are:
1176+
*
1177+
* - (Call, Map<DFPattern, Expr>) -> Call
1178+
*
1179+
* Given the matched call node and the map of patterns and
1180+
* matched expressions, it should return a new call node to
1181+
* replace the original one or the original matched call node as
1182+
* is.
1183+
*/
1184+
TypedPackedFunc<Expr(Expr, Map<DFPattern, Expr>)> rewriter_func_;
1185+
1186+
/*! \brief The known variable bindings
1187+
*
1188+
* The variable bindings whose value is known. This must be tracked
1189+
* separately from the block builder, so that it can be reset after
1190+
* each iteration of the mutate-until-converged loop applied to
1191+
* `SeqExpr`.
1192+
*/
11591193
Map<Var, Expr> bindings_;
11601194
};
11611195

1162-
Function RewriteBindings(const PatternContext& ctx, PackedFunc rewriter, Function f) {
1163-
return PatternRewriter::Run(ctx, rewriter, f);
1196+
Function RewriteBindings(
1197+
const PatternContext& ctx,
1198+
TypedPackedFunc<Map<Var, Expr>(Map<DFPattern, Var>, Map<Var, Expr>)> rewriter, Function func) {
1199+
return BlockPatternRewriter::Run(ctx, rewriter, func);
11641200
}
11651201

11661202
TVM_REGISTER_GLOBAL("relax.dpl.rewrite_bindings").set_body_typed(RewriteBindings);
11671203

11681204
Function RewriteCall(const DFPattern& pat,
1169-
TypedPackedFunc<Expr(Expr, Map<DFPattern, Expr>)> rewriter, Function f) {
1170-
return PatternRewriter::Run(pat, rewriter, f);
1205+
TypedPackedFunc<Expr(Expr, Map<DFPattern, Expr>)> rewriter, Function func) {
1206+
return ExprPatternRewriter::Run(pat, rewriter, func);
11711207
}
11721208

11731209
TVM_REGISTER_GLOBAL("relax.dpl.rewrite_call").set_body_typed(RewriteCall);

0 commit comments

Comments
 (0)