@@ -973,102 +973,33 @@ TVM_REGISTER_GLOBAL("relax.dpl.match_dfb")
973973 });
974974
975975/* !
976- * \brief Apply pattern matching to each call node and dataflow block, and replace matching ones
976+ * \brief Apply pattern matching to each dataflow block, replacing matches
977977 * with the output of a user-provided rewriter function.
978978 */
979- class PatternRewriter : ExprMutator {
979+ class BlockPatternRewriter : ExprMutator {
980980 public:
981981 using ExprMutator::VisitBindingBlock_;
982982 using ExprMutator::VisitExpr_;
983983
984- PatternRewriter (DFPattern pat, PackedFunc rewriter_func,
985- const std::unordered_set<const VarNode*>& params)
986- : pattern_(pat), rewriter_func_(rewriter_func), params_(params) {}
987-
988- PatternRewriter (const PatternContext& ctx, PackedFunc rewriter_func,
989- const std::unordered_set<const VarNode*>& params)
990- : ctx_(ctx), rewriter_func_(rewriter_func), params_(params) {}
984+ BlockPatternRewriter (
985+ const PatternContext& ctx,
986+ TypedPackedFunc<Map<Var, Expr>(Map<DFPattern, Var>, Map<Var, Expr>)> rewriter_func)
987+ : ctx_(ctx), rewriter_func_(rewriter_func) {}
991988
992989 template <typename PatternType>
993- static Function Run (PatternType pat, PackedFunc rewriter_func, Function f) {
994- std::unordered_set<const VarNode*> params;
995- for (const auto & p : f->params ) {
996- params.insert (p.get ());
997- }
998- PatternRewriter rewriter (pat, rewriter_func, params);
999- return Downcast<Function>(RemoveAllUnused (rewriter.VisitExpr (f)));
1000- }
1001-
1002- Expr VisitExpr_ (const SeqExprNode* seq) override {
1003- if (ctx_) {
1004- return ExprMutator::VisitExpr_ (seq);
1005- }
1006-
1007- auto cache = bindings_;
1008- SeqExpr prev = GetRef<SeqExpr>(seq);
1009-
1010- StructuralEqual struct_equal;
1011-
1012- while (true ) {
1013- SeqExpr next = Downcast<SeqExpr>(builder_->Normalize (ExprMutator::VisitExpr_ (prev.get ())));
1014- if (struct_equal (prev, next)) {
1015- return std::move (next);
1016- }
1017-
1018- // Canonicalization may result in two previously-different
1019- // expressions being recognized as identical. Elimination of
1020- // common subexpressions may result in trival var-to-var
1021- // bindings that can be canonicalized. Therefore, iterate the
1022- // simplification steps until converged.
1023- while (true ) {
1024- auto start_of_loop = next;
1025- next = Downcast<SeqExpr>(CanonicalizeBindings (next));
1026- next = Downcast<SeqExpr>(EliminateCommonSubexpr (next));
1027- next = Downcast<SeqExpr>(RemoveAllUnused (next));
1028- if (struct_equal (start_of_loop, next)) {
1029- break ;
1030- }
1031- }
1032-
1033- if (struct_equal (prev, next)) {
1034- return std::move (next);
1035- }
1036-
1037- // Reset all knowledge of bindings that were collected from
1038- // this DataflowBlock. The collected bindings are only after
1039- // the point where they were collected, and we are repeating
1040- // the mutation of this DataflowBlock.
1041- bindings_ = cache;
1042- prev = next;
1043- }
990+ static Function Run (
991+ PatternType pat,
992+ TypedPackedFunc<Map<Var, Expr>(Map<DFPattern, Var>, Map<Var, Expr>)> rewriter_func,
993+ Function func) {
994+ BlockPatternRewriter rewriter (pat, rewriter_func);
995+
996+ func = Downcast<Function>(rewriter (func));
997+ func = Downcast<Function>(RemoveAllUnused (func));
998+ return func;
1044999 }
10451000
10461001 BindingBlock VisitBindingBlock_ (const DataflowBlockNode* block_node) override {
1047- if (ctx_) {
1048- return RewriteDataflowBlockFixedPoint (GetRef<DataflowBlock>(block_node));
1049- } else {
1050- return ExprMutator::VisitBindingBlock_ (block_node);
1051- }
1052- }
1053-
1054- void VisitBinding_ (const VarBindingNode* binding) override {
1055- auto expr = VisitExpr (binding->value );
1056- bindings_.Set (binding->var , expr);
1057- ReEmitBinding (binding, expr);
1058- }
1059-
1060- Expr VisitExpr (const Expr& expr) override {
1061- auto node = ExprMutator::VisitExpr (expr);
1062-
1063- if (pattern_) {
1064- if (auto matches_opt = ExtractMatchedExpr (pattern_.value (), node, bindings_)) {
1065- Expr rewritten_expr = rewriter_func_ (node, matches_opt.value ());
1066- if (!rewritten_expr.same_as (node)) {
1067- return builder_->Normalize (rewritten_expr);
1068- }
1069- }
1070- }
1071- return node;
1002+ return RewriteDataflowBlockFixedPoint (GetRef<DataflowBlock>(block_node));
10721003 }
10731004
10741005 private:
@@ -1106,7 +1037,7 @@ class PatternRewriter : ExprMutator {
11061037 BindingBlock RewriteDataflowBlockFixedPoint (BindingBlock block) {
11071038 auto df_block = Downcast<DataflowBlock>(block);
11081039 Map<Var, Expr> bindings = AnalyzeVar2Value (df_block);
1109- if (auto matches = MatchGraph (ctx_. value () , df_block, bindings)) {
1040+ if (auto matches = MatchGraph (ctx_, df_block, bindings)) {
11101041 builder_->BeginDataflowBlock ();
11111042 Map<Var, Expr> replacements = rewriter_func_ (matches.value (), bindings);
11121043
@@ -1140,34 +1071,139 @@ class PatternRewriter : ExprMutator {
11401071 return block;
11411072 }
11421073
1143- /* ! \brief The pattern for rewriting call nodes */
1144- Optional<DFPattern> pattern_;
11451074 /* ! \brief The pattern constraint contexts for rewriting dataflow blocks */
1146- Optional< PatternContext> ctx_;
1075+ PatternContext ctx_;
11471076 /* !
11481077 * \brief The user-provided rewriter function. Its signature and semantics are:
1149- * - (Call, Map<DFPattern, Expr>) -> Call for call node rewriting. Given the matched
1150- * call node and the map of patterns and matched expressions, it should return a new call node
1151- * to replace the original one or the original matched call node as is.
1152- * - (Map<DFPattern, Var>, Map<Var, Expr>) -> Map<Var, Expr> for dataflow block rewriting.
1153- * Given the map of patterns and corresponding variables (bound variables or parameters),
1154- * it should return a map that specifies new values for matched bound variables. It can refer
1078+ *
1079+ * - (Map<DFPattern, Var>, Map<Var, Expr>) -> Map<Var, Expr>
1080+ *
1081+ * Given the map of patterns and corresponding variables (bound
1082+ * variables or parameters), it should return a map that
1083+ * specifies new values for matched bound variables. It can refer
11551084 * to the passed bindings to create the replacement expressions.
11561085 */
1157- PackedFunc rewriter_func_;
1158- std::unordered_set<const VarNode*> params_;
1086+ TypedPackedFunc<Map<Var, Expr>(Map<DFPattern, Var>, Map<Var, Expr>)> rewriter_func_;
1087+ };
1088+
1089+ /* !
1090+ * \brief Apply pattern matching to each expression, replacing
1091+ * matches with the output of a user-provided rewriter function.
1092+ */
1093+ class ExprPatternRewriter : ExprMutator {
1094+ public:
1095+ using ExprMutator::VisitBindingBlock_;
1096+ using ExprMutator::VisitExpr_;
1097+
1098+ ExprPatternRewriter (DFPattern pat,
1099+ TypedPackedFunc<Expr(Expr, Map<DFPattern, Expr>)> rewriter_func)
1100+ : pattern_(pat), rewriter_func_(rewriter_func) {}
1101+
1102+ template <typename PatternType>
1103+ static Function Run (PatternType pat,
1104+ TypedPackedFunc<Expr(Expr, Map<DFPattern, Expr>)> rewriter_func,
1105+ Function func) {
1106+ ExprPatternRewriter rewriter (pat, rewriter_func);
1107+ func = Downcast<Function>(rewriter (func));
1108+ func = Downcast<Function>(RemoveAllUnused (func));
1109+ return func;
1110+ }
1111+
1112+ Expr VisitExpr_ (const SeqExprNode* seq) override {
1113+ auto cache = bindings_;
1114+ SeqExpr prev = GetRef<SeqExpr>(seq);
1115+
1116+ StructuralEqual struct_equal;
1117+
1118+ while (true ) {
1119+ SeqExpr next = Downcast<SeqExpr>(builder_->Normalize (ExprMutator::VisitExpr_ (prev.get ())));
1120+ if (struct_equal (prev, next)) {
1121+ return std::move (next);
1122+ }
1123+
1124+ // Canonicalization may result in two previously-different
1125+ // expressions being recognized as identical. Elimination of
1126+ // common subexpressions may result in trival var-to-var
1127+ // bindings that can be canonicalized. Therefore, iterate the
1128+ // simplification steps until converged.
1129+ while (true ) {
1130+ auto start_of_loop = next;
1131+ next = Downcast<SeqExpr>(CanonicalizeBindings (next));
1132+ next = Downcast<SeqExpr>(EliminateCommonSubexpr (next));
1133+ next = Downcast<SeqExpr>(RemoveAllUnused (next));
1134+ if (struct_equal (start_of_loop, next)) {
1135+ break ;
1136+ }
1137+ }
1138+
1139+ if (struct_equal (prev, next)) {
1140+ return std::move (next);
1141+ }
1142+
1143+ // Reset all knowledge of bindings that were collected from
1144+ // this SeqExpr. The collected bindings are only after
1145+ // the point where they were collected, and we are repeating
1146+ // the mutation of this SeqExpr.
1147+ bindings_ = cache;
1148+ prev = next;
1149+ }
1150+ }
1151+
1152+ void VisitBinding_ (const VarBindingNode* binding) override {
1153+ auto expr = VisitExpr (binding->value );
1154+ bindings_.Set (binding->var , expr);
1155+ ReEmitBinding (binding, expr);
1156+ }
1157+
1158+ Expr VisitExpr (const Expr& expr) override {
1159+ auto node = ExprMutator::VisitExpr (expr);
1160+
1161+ if (auto matches_opt = ExtractMatchedExpr (pattern_, node, bindings_)) {
1162+ Expr rewritten_expr = rewriter_func_ (node, matches_opt.value ());
1163+ if (!rewritten_expr.same_as (node)) {
1164+ return builder_->Normalize (rewritten_expr);
1165+ }
1166+ }
1167+
1168+ return node;
1169+ }
1170+
1171+ private:
1172+ /* ! \brief The pattern for rewriting call nodes */
1173+ DFPattern pattern_;
1174+ /* !
1175+ * \brief The user-provided rewriter function. Its signature and semantics are:
1176+ *
1177+ * - (Call, Map<DFPattern, Expr>) -> Call
1178+ *
1179+ * Given the matched call node and the map of patterns and
1180+ * matched expressions, it should return a new call node to
1181+ * replace the original one or the original matched call node as
1182+ * is.
1183+ */
1184+ TypedPackedFunc<Expr(Expr, Map<DFPattern, Expr>)> rewriter_func_;
1185+
1186+ /* ! \brief The known variable bindings
1187+ *
1188+ * The variable bindings whose value is known. This must be tracked
1189+ * separately from the block builder, so that it can be reset after
1190+ * each iteration of the mutate-until-converged loop applied to
1191+ * `SeqExpr`.
1192+ */
11591193 Map<Var, Expr> bindings_;
11601194};
11611195
1162- Function RewriteBindings (const PatternContext& ctx, PackedFunc rewriter, Function f) {
1163- return PatternRewriter::Run (ctx, rewriter, f);
1196+ Function RewriteBindings (
1197+ const PatternContext& ctx,
1198+ TypedPackedFunc<Map<Var, Expr>(Map<DFPattern, Var>, Map<Var, Expr>)> rewriter, Function func) {
1199+ return BlockPatternRewriter::Run (ctx, rewriter, func);
11641200}
11651201
11661202TVM_REGISTER_GLOBAL (" relax.dpl.rewrite_bindings" ).set_body_typed(RewriteBindings);
11671203
11681204Function RewriteCall (const DFPattern& pat,
1169- TypedPackedFunc<Expr(Expr, Map<DFPattern, Expr>)> rewriter, Function f ) {
1170- return PatternRewriter ::Run (pat, rewriter, f );
1205+ TypedPackedFunc<Expr(Expr, Map<DFPattern, Expr>)> rewriter, Function func ) {
1206+ return ExprPatternRewriter ::Run (pat, rewriter, func );
11711207}
11721208
11731209TVM_REGISTER_GLOBAL (" relax.dpl.rewrite_call" ).set_body_typed(RewriteCall);
0 commit comments