diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index 1a1cfa9d23e3..41647e261de6 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -818,16 +818,31 @@ class PatternRewriter : ExprMutator { } } - Expr VisitExpr_(const CallNode* call_node) final { - auto call = ExprMutator::VisitExpr_(call_node); - if (!pattern_) { - return call; - } else if (auto matches_opt = ExtractMatchedExpr(pattern_.value(), call, bindings_)) { - auto rewriten_expr = rewriter_func_(call, matches_opt.value()); - memo_[call_node] = rewriten_expr; - return rewriten_expr; - } - return call; + Expr VisitExpr(const Expr& expr) final { + auto node = ExprMutator::VisitExpr(expr); + if (pattern_) { + if (auto matches_opt = ExtractMatchedExpr(pattern_.value(), node, bindings_)) { + Expr rewritten_expr = rewriter_func_(node, matches_opt.value()); + if (!rewritten_expr.same_as(node)) { + rewritten_expr = builder_->Normalize(rewritten_expr); + + // If the rewriter returns a variable (e.g. when rewriting + // from `R.add(x, R.const(0.0))` to `x`), the variable + // should be dereferenced to avoid trivial `var_2 = var_1` + // bindings. This lookup is done using the builder_ instead + // of the bindings_, as the previous `builder_->Normalize` + // call may have introduced variable bindings. + if (auto opt_var = rewritten_expr.as()) { + if (auto binding = builder_->LookupBinding(opt_var.value())) { + rewritten_expr = binding.value(); + } + } + memo_[expr.get()] = rewritten_expr; + return rewritten_expr; + } + } + } + return node; } BindingBlock VisitBindingBlock_(const DataflowBlockNode* block_node) final { diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py index 4cd9c36cc8f8..3444eff79b82 100644 --- a/tests/python/relax/test_dataflow_pattern.py +++ b/tests/python/relax/test_dataflow_pattern.py @@ -1287,7 +1287,8 @@ def before( x: R.Tensor((1024,)), ): with R.dataflow(): - out = R.add(R.const(1.0), x) + y = R.add(x, x) + out = R.add(R.const(1.0), y) R.output(out) return out @@ -1296,8 +1297,10 @@ def expected( x: R.Tensor((1024,)), ): with R.dataflow(): - out = R.add(x, R.const(2.0)) + y = R.add(x, x) + out = R.add(y, R.const(2.0)) R.output(out) + return out pattern_add = is_op("relax.add") @@ -1308,10 +1311,14 @@ def expected( pattern = pattern_op(pattern_arg, pattern_const) - def rewriter(_expr, matches): + def rewriter(expr, matches): op = matches[pattern_op] arg = matches[pattern_arg] - return rx.Call(op, [arg, rx.const(2.0)]) + const = matches[pattern_const].data.numpy() + if const.shape == tuple() and const[()] == 1.0: + return rx.Call(op, [arg, rx.const(2.0)]) + else: + return expr after = rewrite_call(pattern, rewriter, before) tvm.ir.assert_structural_equal(after, expected) @@ -1365,5 +1372,40 @@ def rewriter(_expr, matches): tvm.ir.assert_structural_equal(after, expected) +def test_rewrite_without_trivial_binding(): + """rewrite_call should avoid producing trivial "y = x" bindings""" + + @R.function(private=True) + def before(x: R.Tensor((1024,))): + with R.dataflow(): + a = R.add(x, x) + b = R.reshape(a, (1024,)) + R.output(b) + return b + + @R.function(private=True) + def expected(x: R.Tensor((1024,))): + with R.dataflow(): + a = R.add(x, x) + R.output(a) + return a + + pattern_arg = wildcard() + pattern_shape_expr = wildcard() + pattern = is_op("relax.reshape")(pattern_arg, pattern_shape_expr) + + def rewriter(expr, matches): + arg = matches[pattern_arg] + shape_expr = matches[pattern_shape_expr] + + if tvm.ir.structural_equal(arg.struct_info.shape, shape_expr): + return arg + else: + return expr + + after = rewrite_call(pattern, rewriter, before) + tvm.ir.assert_structural_equal(after, expected) + + if __name__ == "__main__": tvm.testing.main()