diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc index fb7a76f1ea7a..780b86f2ff01 100644 --- a/src/relay/transforms/simplify_expr.cc +++ b/src/relay/transforms/simplify_expr.cc @@ -78,11 +78,11 @@ class SimplifyReshape : public DFPatternRewrite { }; /*! - * \brief SimplifyCast matches the pattern of cast data to the same dtype. + * \brief SimplifySameCast matches the pattern of cast data to the same dtype. */ -class SimplifyCast : public DFPatternRewrite { +class SimplifySameCast : public DFPatternRewrite { public: - SimplifyCast() { + SimplifySameCast() { data_pat_ = IsWildcard(); like_pat_ = IsWildcard(); pattern_ = IsOp("cast_like")({data_pat_, like_pat_}) || IsOp("cast")({data_pat_}); @@ -104,6 +104,69 @@ class SimplifyCast : public DFPatternRewrite { DFPattern like_pat_; }; +/*! + * \brief SimplifyConsecutiveCast matches the pattern of consecutive cast/cast_like ops + */ +class SimplifyConsecutiveCast : public DFPatternRewrite { + public: + SimplifyConsecutiveCast() { + data_ = IsWildcard(); + cast1_ = IsOp("cast_like")({data_, IsWildcard()}) || IsOp("cast")({data_}); + pattern_ = IsOp("cast_like")({cast1_, IsWildcard()}) || IsOp("cast")({cast1_}); + } + + Expr Callback(const Expr& pre, const Expr& post, + const Map>& node_map) const override { + static const Op& cast_op = Op::Get("cast"); + auto data = node_map[data_][0]; + auto cast1 = Downcast(node_map[cast1_][0]); + auto data_type = Downcast(data->checked_type()); + DataType cast1_dtype; + if (cast1->op == cast_op) { + auto attr = cast1->attrs.as(); + CHECK(attr); + cast1_dtype = attr->dtype; + } else { // cast_like + cast1_dtype = Downcast(cast1->args[1]->checked_type())->dtype; + } + if (!IsWidenCast(data_type->dtype, cast1_dtype)) { + // Cannot remove the narrow cast + return post; + } + const CallNode* cast2 = post.as(); + DataType cast2_dtype; + if (cast2->op == cast_op) { + auto attr = cast2->attrs.as(); + CHECK(attr); + cast2_dtype = attr->dtype; + } else { // cast_like + cast2_dtype = Downcast(cast2->args[1]->checked_type())->dtype; + } + auto expr = MakeCast(data, cast2_dtype); + // We need to set the checked type as it may be needed in the next callback + expr->checked_type_ = TensorType(data_type->shape, cast2_dtype); + return expr; + } + + bool IsWidenCast(DataType origin, DataType cast) const { + if (origin.code() == cast.code() && origin.bits() <= cast.bits()) { + return true; + } + if (origin.code() == DataType::kBFloat || cast.code() == DataType::kBFloat) { + // BFloat cast cannot be omitted + return false; + } + if (origin.code() < cast.code()) { + return true; + } + return false; + } + + protected: + DFPattern data_; + DFPattern cast1_; +}; + /*! * \brief SimplifyTranspose matches the pattern of consecutive transpose op, * and merges or cancels them. @@ -597,7 +660,8 @@ Expr SimplifyExpr(const Expr& expr, const IRModule& mod) { composer.AddRewrite(); composer.AddRewrite(); composer.AddRewrite(); - composer.AddRewrite(); + composer.AddRewrite(); + composer.AddRewrite(); composer.AddRewrite(); return RewritePatterns(composer.MakeCallbacks(), expr, mod); } diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py index 9f11d3827064..1734af1b5518 100644 --- a/tests/python/relay/test_pass_simplify_expr.py +++ b/tests/python/relay/test_pass_simplify_expr.py @@ -402,7 +402,7 @@ def check(x, y=None, do_nothing=False): check(id_op(const, x), id_op(op_like(x), x)) -def test_simplify_cast(): +def test_simplify_same_cast(): dtype = "int32" data = relay.var("data", shape=(3, 4, 5), dtype=dtype) expr1 = relay.cast(data, dtype) @@ -416,6 +416,35 @@ def test_simplify_cast(): assert tvm.ir.structural_equal(actual2, expected) +def test_simplify_consecutive_cast(): + x = relay.var("x", shape=(3, 4, 5), dtype="int8") + y = relay.var("y", shape=(3, 4), dtype="int64") + z = relay.var("z", shape=(3,), dtype="float32") + expr1 = relay.cast(x, "int16") + expr2 = relay.cast(expr1, "int32") + expr3 = relay.cast_like(expr2, y) + expr4 = relay.cast_like(expr3, z) + + actual1 = run_opt_pass(expr2, relay.transform.SimplifyExpr()) + expected = run_infer_type(relay.cast(x, "int32")) + assert tvm.ir.structural_equal(actual1, expected) + actual2 = run_opt_pass(expr3, relay.transform.SimplifyExpr()) + expected = run_infer_type(relay.cast(x, "int64")) + assert tvm.ir.structural_equal(actual2, expected) + actual3 = run_opt_pass(expr4, relay.transform.SimplifyExpr()) + expected = run_infer_type(relay.cast(x, "float32")) + assert tvm.ir.structural_equal(actual3, expected) + + # cannot simplify the narrow cast + x = relay.var("x", shape=(3, 4, 5), dtype="float32") + y = relay.var("y", shape=(3, 4), dtype="float32") + expr1 = relay.cast(x, "int32") + expr2 = relay.cast_like(expr1, y) + actual = run_opt_pass(expr2, relay.transform.SimplifyExpr()) + expected = run_infer_type(expr2) + assert tvm.ir.structural_equal(actual, expected) + + def test_concretize_reshape_like(): data = relay.var("data", shape=(2, 3, 4), dtype="float32") shape_like = relay.var("shape_like", shape=(6, 2, 2), dtype="float32")