Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 39 additions & 5 deletions src/relax/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1158,17 +1158,51 @@ class ExprPatternRewriter : ExprMutator {
Expr VisitExpr(const Expr& expr) override {
auto node = ExprMutator::VisitExpr(expr);

if (auto matches_opt = ExtractMatchedExpr(pattern_, node, bindings_)) {
Expr rewritten_expr = rewriter_func_(node, matches_opt.value());
if (!rewritten_expr.same_as(node)) {
return builder_->Normalize(rewritten_expr);
}
std::vector<DFPattern> matches_top_level;
if (auto rewritten = TryRewrite(node, pattern_, &matches_top_level)) {
return builder_->Normalize(rewritten.value());
}

return node;
}

private:
Optional<Expr> TryRewrite(const Expr& expr, const DFPattern& pattern,
std::vector<DFPattern>* matches_top_level) {
ICHECK(matches_top_level);

// Special handling if the user-supplied pattern is a `OrPattern`.
// While the `ExtractMatchedExpr` can handle matching the
// `OrPattern`, it will return on the first match, even if the
// `rewriter_func_` doesn't apply a replacement. Unpacking the
// `OrPattern` here allows the match to be resumed if
// `rewriter_func_` returns the original function unmodified.
// This is only valid for a top-level match.
if (auto or_pattern = pattern.as<OrPatternNode>()) {
matches_top_level->push_back(pattern);
Optional<Expr> output = TryRewrite(expr, or_pattern->left, matches_top_level);
if (!output.defined()) {
output = TryRewrite(expr, or_pattern->right, matches_top_level);
}
matches_top_level->pop_back();
return output;
}

if (auto opt_matches = ExtractMatchedExpr(pattern, expr, bindings_)) {
auto matches = opt_matches.value();
for (const auto& pat : *matches_top_level) {
matches.Set(pat, expr);
}

Expr rewritten_expr = rewriter_func_(expr, matches);
if (!rewritten_expr.same_as(expr)) {
return builder_->Normalize(rewritten_expr);
}
}

return NullOpt;
}

/*! \brief The pattern for rewriting call nodes */
DFPattern pattern_;
/*!
Expand Down
63 changes: 63 additions & 0 deletions tests/python/relax/test_dataflow_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -1889,5 +1889,68 @@ def expected():
tvm.ir.assert_structural_equal(expected, after)


def test_backtrack_if_rewriter_returns_no_op():
"""Rewriter participates in the pattern matching

Sometimes, the pattern-matching syntax is insufficient to check if
a replacement may be performed. In this case, the `rewriter`
function may perform additional validation. If this validation
fails, the `rewriter` function can return the original expression,
and no replacement is performed.

In addition, when the `rewriter` returns the original expression,
the pattern match should backtrack to determine if another branch
of the match may have produced a replacement.

This functionality allows pattern replacements to be composed.
"""

pat_match_no_rewrite = is_op("relax.add")(wildcard(), wildcard())

pat_arg = wildcard()
pat_zeros = is_op("relax.zeros")(wildcard())
pat_add = is_op("relax.add")(pat_arg, pat_zeros)

# OR conditions are checked in the order that they occur. Because
# `pat_match_no_rewrite` is a superset of `pat_add`, it will
# always match first.
pat = pat_match_no_rewrite | pat_add

def rewriter(expr, matches):
if pat_match_no_rewrite in matches:
# This branch simulates a rewrite whose precondition has
# failed. If the pattern-matching treats this as a
# successful match with no replacemen required, then no
# rewrite would be performed. On the other hand, if the
# pattern-matching treats this as an unsuccessful match,
# then it can backtrack and attempt `pat_add` instead.
return expr
elif pat_add in matches:
return matches[pat_arg]
else:
raise RuntimeError("Pattern matched, but neither branch matched")

@R.function(private=True)
def before():
with R.dataflow():
A = R.ones([64, 128], "int32")
B = R.zeros([64, 128], "int32")
C = R.add(A, B)

R.output(C)
return C

@R.function(private=True)
def expected():
with R.dataflow():
C = R.ones([64, 128], "int32")

R.output(C)
return C

after = rewrite_call(pat, rewriter, before)
tvm.ir.assert_structural_equal(expected, after)


if __name__ == "__main__":
tvm.testing.main()