@@ -1995,5 +1995,61 @@ def test_partition_parallel_branch_with_same_input():
19951995 assert tvm .ir .structural_equal (partitioned , reference )
19961996
19971997
1998+ def test_rewrite_with_pattern_recursion ():
1999+ data = relay .var ("data" , relay .TensorType ((2 , 8 ), "float32" ))
2000+ dense_weight = relay .const (np .zeros ((4 , 8 )))
2001+ feat = relay .nn .dense (data , dense_weight )
2002+ feat = relay .cast (feat , "float32" )
2003+ feat = relay .cast (feat , "float32" )
2004+ feat = relay .cast (feat , "float32" )
2005+ feat = relay .cast (feat , "float32" )
2006+ feat = relay .cast (feat , "float32" )
2007+ oup = relay .cast (feat , "float32" )
2008+
2009+ expected = relay .nn .relu (oup )
2010+
2011+ class TheRewrite (DFPatternCallback ):
2012+ def __init__ (self , pattern ):
2013+ super (TheRewrite , self ).__init__ (rewrite_once = True )
2014+ self .pattern = pattern
2015+
2016+ def callback (self , pre , post , node_map ):
2017+ return relay .nn .relu (post )
2018+
2019+ def test_reset_call_args ():
2020+ dense_pattern = is_op ("nn.dense" )(wildcard (), wildcard ())
2021+ wildcard_redirect = wildcard ()
2022+ the_pattern = is_op ("cast" )(wildcard_redirect )
2023+ the_pattern2 = the_pattern | dense_pattern
2024+ wildcard_redirect .redirect_to (the_pattern2 )
2025+
2026+ actual = rewrite (TheRewrite (the_pattern ), oup )
2027+ tvm .ir .assert_structural_equal (actual , expected )
2028+
2029+ def test_reset_alt_left ():
2030+ dense_pattern = is_op ("nn.dense" )(wildcard (), wildcard ())
2031+ wildcard_redirect = wildcard ()
2032+ or_pattern = wildcard_redirect | dense_pattern
2033+ the_pattern = is_op ("cast" )(or_pattern )
2034+ wildcard_redirect .redirect_to (the_pattern )
2035+
2036+ actual = rewrite (TheRewrite (the_pattern ), oup )
2037+ tvm .ir .assert_structural_equal (actual , expected )
2038+
2039+ def test_reset_alt_right ():
2040+ dense_pattern = is_op ("nn.dense" )(wildcard (), wildcard ())
2041+ wildcard_redirect = wildcard ()
2042+ or_pattern = dense_pattern | wildcard_redirect
2043+ the_pattern = is_op ("cast" )(or_pattern )
2044+ wildcard_redirect .redirect_to (the_pattern )
2045+
2046+ actual = rewrite (TheRewrite (the_pattern ), oup )
2047+ tvm .ir .assert_structural_equal (actual , expected )
2048+
2049+ test_reset_call_args ()
2050+ test_reset_alt_left ()
2051+ test_reset_alt_right ()
2052+
2053+
19982054if __name__ == "__main__" :
19992055 tvm .testing .main ()
0 commit comments