Skip to content

Commit 7e9054b

Browse files
committed
[Relax] Allow composition of DFPattern replacements
The `rewrite_call` function accepts a `DFPattern`, and a function to rewrite expressions matching that pattern. Often, the rewriting function will perform additional validation that cannot be expressed within the `DFPattern` itself. If this additional validation fails, the rewriter function will return the matched expression unmodified. Prior to this commit, an `OrPattern` that matches on the first branch, but whose rewriter function does not apply a modification, would prevent the second branch from being checked. This commit updates the `ExprPatternRewriter` to check both branches of a `OrPattern`, if the rewriter function of the first branch does not modify the result.
1 parent 65e9808 commit 7e9054b

File tree

2 files changed

+102
-5
lines changed

2 files changed

+102
-5
lines changed

src/relax/ir/dataflow_matcher.cc

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1158,17 +1158,51 @@ class ExprPatternRewriter : ExprMutator {
11581158
Expr VisitExpr(const Expr& expr) override {
11591159
auto node = ExprMutator::VisitExpr(expr);
11601160

1161-
if (auto matches_opt = ExtractMatchedExpr(pattern_, node, bindings_)) {
1162-
Expr rewritten_expr = rewriter_func_(node, matches_opt.value());
1163-
if (!rewritten_expr.same_as(node)) {
1164-
return builder_->Normalize(rewritten_expr);
1165-
}
1161+
std::vector<DFPattern> matches_top_level;
1162+
if (auto rewritten = TryRewrite(node, pattern_, &matches_top_level)) {
1163+
return builder_->Normalize(rewritten.value());
11661164
}
11671165

11681166
return node;
11691167
}
11701168

11711169
private:
1170+
Optional<Expr> TryRewrite(const Expr& expr, const DFPattern& pattern,
1171+
std::vector<DFPattern>* matches_top_level) {
1172+
ICHECK(matches_top_level);
1173+
1174+
// Special handling if the user-supplied pattern is a `OrPattern`.
1175+
// While the `ExtractMatchedExpr` can handle match the
1176+
// `OrPattern`, it will return on the first match, even if the
1177+
// `rewriter_func_` doesn't apply a replacement. Unpacking the
1178+
// `OrPattern` here allows the match to be resumed if
1179+
// `rewriter_func_` returns the original function unmodified.
1180+
// This is only valid for a top-level match.
1181+
if (auto or_pattern = pattern.as<OrPatternNode>()) {
1182+
matches_top_level->push_back(pattern);
1183+
Optional<Expr> output = TryRewrite(expr, or_pattern->left, matches_top_level);
1184+
if (!output.defined()) {
1185+
output = TryRewrite(expr, or_pattern->right, matches_top_level);
1186+
}
1187+
matches_top_level->pop_back();
1188+
return output;
1189+
}
1190+
1191+
if (auto opt_matches = ExtractMatchedExpr(pattern, expr, bindings_)) {
1192+
auto matches = opt_matches.value();
1193+
for (const auto& pat : *matches_top_level) {
1194+
matches.Set(pat, expr);
1195+
}
1196+
1197+
Expr rewritten_expr = rewriter_func_(expr, matches);
1198+
if (!rewritten_expr.same_as(expr)) {
1199+
return builder_->Normalize(rewritten_expr);
1200+
}
1201+
}
1202+
1203+
return NullOpt;
1204+
}
1205+
11721206
/*! \brief The pattern for rewriting call nodes */
11731207
DFPattern pattern_;
11741208
/*!

tests/python/relax/test_dataflow_pattern.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1889,5 +1889,68 @@ def expected():
18891889
tvm.ir.assert_structural_equal(expected, after)
18901890

18911891

1892+
def test_backtrack_if_rewriter_returns_no_op():
1893+
"""Rewriter participates in the pattern matching
1894+
1895+
Sometimes, the pattern-matching syntax is insufficient to check if
1896+
a replacement may be performed. In this case, the `rewriter`
1897+
function may perform additional validation. If this validation
1898+
fails, the `rewriter` function can return the original expression,
1899+
and no replacement is performed.
1900+
1901+
In addition, when the `rewriter` returns the original expression,
1902+
the pattern match should backtrack to determine if another branch
1903+
of the match may have produced a replacement.
1904+
1905+
This functionality allows pattern replacements to be composed.
1906+
"""
1907+
1908+
pat_match_no_rewrite = is_op("relax.add")(wildcard(), wildcard())
1909+
1910+
pat_arg = wildcard()
1911+
pat_zeros = is_op("relax.zeros")(wildcard())
1912+
pat_add = is_op("relax.add")(pat_arg, pat_zeros)
1913+
1914+
# OR conditions are checked in the order that they occur. Because
1915+
# `pat_match_no_rewrite` is a superset of `pat_add`, it will
1916+
# always match first.
1917+
pat = pat_match_no_rewrite | pat_add
1918+
1919+
def rewriter(expr, matches):
1920+
if pat_match_no_rewrite in matches:
1921+
# This branch simulates a rewrite whose precondition has
1922+
# failed. If the pattern-matching treats this as a
1923+
# successful match with no replacemen required, then no
1924+
# rewrite would be performed. On the other hand, if the
1925+
# pattern-matching treats this as an unsuccessful match,
1926+
# then it can backtrack and attempt `pat_add` instead.
1927+
return expr
1928+
elif pat_add in matches:
1929+
return matches[pat_arg]
1930+
else:
1931+
raise RuntimeError("Pattern matched, but neither branch matched")
1932+
1933+
@R.function(private=True)
1934+
def before():
1935+
with R.dataflow():
1936+
A = R.ones([64, 128], "int32")
1937+
B = R.zeros([64, 128], "int32")
1938+
C = R.add(A, B)
1939+
1940+
R.output(C)
1941+
return C
1942+
1943+
@R.function(private=True)
1944+
def expected():
1945+
with R.dataflow():
1946+
C = R.ones([64, 128], "int32")
1947+
1948+
R.output(C)
1949+
return C
1950+
1951+
after = rewrite_call(pat, rewriter, before)
1952+
tvm.ir.assert_structural_equal(expected, after)
1953+
1954+
18921955
if __name__ == "__main__":
18931956
tvm.testing.main()

0 commit comments

Comments
 (0)