@@ -1961,8 +1961,8 @@ def rewrite_attention(f):
19611961 K_BNSH_T = is_op ("relax.permute_dims" )(K_BNSH )
19621962
19631963 matmul1 = is_op ("relax.matmul" )(Q_BNSH , K_BNSH_T )
1964- multiply = is_op ("relax.multiply" )(matmul1 , is_const ())
1965- softmax = is_op ("relax.nn.softmax" )(multiply )
1964+ multiply = is_op ("relax.divide" )(matmul1 , is_const ())
1965+ softmax = is_op ("relax.astype" )( is_op ( "relax.nn.softmax" )(is_op ( "relax.astype" )( multiply )) )
19661966 matmul2 = is_op ("relax.matmul" )(softmax , V_BNSH )
19671967 pattern = is_op ("relax.permute_dims" )(is_op ("relax.reshape" )(matmul2 , is_shape ([4 , 32 , 16 , 8 ])))
19681968
@@ -1996,9 +1996,9 @@ def main(
19961996
19971997 lv6 = R .permute_dims (lv3 , axes = [0 , 2 , 1 ])
19981998 lv7 = R .matmul (lv1 , lv6 , out_dtype = "float16" )
1999- lv3_1 = R .const (0.5 , "float16" )
2000- lv8 = R .multiply (lv7 , lv3_1 )
2001- lv11 = R .nn .softmax (lv8 , axis = 2 )
1999+ lv3_1 = R .const (2.0 , "float16" )
2000+ lv8 = R .divide (lv7 , lv3_1 )
2001+ lv11 = R .astype ( R . nn .softmax (R . astype ( lv8 , "float32" ), axis = 2 ), "float16" )
20022002 lv12 = R .matmul (lv11 , lv5 , out_dtype = "float16" )
20032003 lv13 = R .reshape (lv12 , R .shape ([4 , 32 , 16 , 8 ]))
20042004 lv6_1 = R .permute_dims (lv13 , axes = [0 , 2 , 1 , 3 ])
0 commit comments