@@ -1945,5 +1945,40 @@ def main(
19451945 tvm .testing .assert_allclose (out , ref , rtol = 1e-2 , atol = 1e-2 )
19461946
19471947
def test_attention_rewrite_multi_query():
    """Run CUTLASS partitioning on a multi-query attention graph and print it.

    ``k_single`` and ``v_single`` have size 1 along axis 2 (presumably the head
    axis -- confirm) and are expanded with ``R.repeat`` to the query's size of 32
    before the matmul / scale / softmax / matmul sequence.  The partitioned module
    is only printed; nothing is asserted about it.
    """

    # NOTE: the variable names and statement order inside ``main`` define the
    # Relax IR (and therefore the printed output), so they are kept as-is.
    @I.ir_module
    class Module:
        @R.function
        def main(
            q: R.Tensor((4, 16, 32, 8), dtype="float16"),
            k_single: R.Tensor((4, 16, 1, 8), dtype="float16"),
            v_single: R.Tensor((4, 16, 1, 8), dtype="float16"),
        ) -> R.Tensor((4, 16, 32, 8), dtype="float16"):
            with R.dataflow():
                # Expand key/value along axis 2: (4, 16, 1, 8) -> (4, 16, 32, 8).
                k = R.repeat(k_single, 32, axis=2)
                v = R.repeat(v_single, 32, axis=2)

                # Swap axes 1 and 2, then fold the two leading axes together:
                # (4, 16, 32, 8) -> (4, 32, 16, 8) -> (128, 16, 8).
                lv = R.permute_dims(q, axes=[0, 2, 1, 3])
                lv1 = R.reshape(lv, R.shape([128, 16, 8]))
                lv2 = R.permute_dims(k, axes=[0, 2, 1, 3])
                lv3 = R.reshape(lv2, R.shape([128, 16, 8]))
                lv4 = R.permute_dims(v, axes=[0, 2, 1, 3])
                lv5 = R.reshape(lv4, R.shape([128, 16, 8]))

                # q @ k^T, scaled by 0.5, softmax over the last axis, then @ v.
                lv6 = R.permute_dims(lv3, axes=[0, 2, 1])
                lv7 = R.matmul(lv1, lv6, out_dtype="float16")
                lv3_1 = R.const(0.5, "float16")
                lv8 = R.multiply(lv7, lv3_1)
                lv11 = R.nn.softmax(lv8, axis=2)
                lv12 = R.matmul(lv11, lv5, out_dtype="float16")

                # Undo the fold and the axis swap: (128, 16, 8) -> (4, 16, 32, 8).
                lv13 = R.reshape(lv12, R.shape([4, 32, 16, 8]))
                lv6_1 = R.permute_dims(lv13, axes=[0, 2, 1, 3])
                R.output(lv6_1)
            return lv6_1

    partitioned = partition_for_cutlass(Module)
    print(partitioned)
1981+
if __name__ == "__main__":
    # Use the shared TVM test entry point so that running this file as a script
    # exercises every test in it; calling one test function directly here (a
    # debugging shortcut) would silently skip all the others.
    tvm.testing.main()
0 commit comments