
Commit ab3572d

wip
1 parent 690b88e commit ab3572d

File tree

1 file changed: +36 -1 lines changed


tests/python/relax/test_codegen_cutlass.py

Lines changed: 36 additions & 1 deletion
@@ -1945,5 +1945,40 @@ def main(
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
 
 
+def test_attention_rewrite_multi_query():
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(
+            q: R.Tensor((4, 16, 32, 8), dtype="float16"),
+            k_single: R.Tensor((4, 16, 1, 8), dtype="float16"),
+            v_single: R.Tensor((4, 16, 1, 8), dtype="float16"),
+        ) -> R.Tensor((4, 16, 32, 8), dtype="float16"):
+            with R.dataflow():
+                k = R.repeat(k_single, 32, axis=2)
+                v = R.repeat(v_single, 32, axis=2)
+
+                lv = R.permute_dims(q, axes=[0, 2, 1, 3])
+                lv1 = R.reshape(lv, R.shape([128, 16, 8]))
+                lv2 = R.permute_dims(k, axes=[0, 2, 1, 3])
+                lv3 = R.reshape(lv2, R.shape([128, 16, 8]))
+                lv4 = R.permute_dims(v, axes=[0, 2, 1, 3])
+                lv5 = R.reshape(lv4, R.shape([128, 16, 8]))
+
+                lv6 = R.permute_dims(lv3, axes=[0, 2, 1])
+                lv7 = R.matmul(lv1, lv6, out_dtype="float16")
+                lv3_1 = R.const(0.5, "float16")
+                lv8 = R.multiply(lv7, lv3_1)
+                lv11 = R.nn.softmax(lv8, axis=2)
+                lv12 = R.matmul(lv11, lv5, out_dtype="float16")
+                lv13 = R.reshape(lv12, R.shape([4, 32, 16, 8]))
+                lv6_1 = R.permute_dims(lv13, axes=[0, 2, 1, 3])
+                R.output(lv6_1)
+            return lv6_1
+
+    mod = partition_for_cutlass(Module)
+    print(mod)
+
 if __name__ == "__main__":
-    tvm.testing.main()
+    # tvm.testing.main()
+    test_attention_rewrite_multi_query()
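
Note: the new test currently only prints the partitioned module rather than checking its output numerically, as the surrounding tests do with tvm.testing.assert_allclose. Below is a minimal NumPy reference sketch, not part of this commit, of the multi-query attention that Module computes; the helper name numpy_multi_query_attention is hypothetical. The shapes (4, 16, 32, 8), the head count 32, and the 0.5 scale are taken from the Relax function above, and the math is done in float32 for a more stable reference before casting back to float16.

import numpy as np

def numpy_multi_query_attention(q, k_single, v_single, scale=0.5):
    # Broadcast the single KV head across all 32 query heads (multi-query attention).
    k = np.repeat(k_single, 32, axis=2)
    v = np.repeat(v_single, 32, axis=2)
    # (batch, seq, heads, dim) -> (batch * heads, seq, dim), mirroring the
    # permute_dims/reshape sequence in the Relax module.
    qt = q.transpose(0, 2, 1, 3).reshape(128, 16, 8).astype("float32")
    kt = k.transpose(0, 2, 1, 3).reshape(128, 16, 8).astype("float32")
    vt = v.transpose(0, 2, 1, 3).reshape(128, 16, 8).astype("float32")
    # Scaled dot-product attention per (batch, head) slice.
    score = (qt @ kt.transpose(0, 2, 1)) * scale
    score = np.exp(score - score.max(axis=2, keepdims=True))
    attn = score / score.sum(axis=2, keepdims=True)
    out = attn @ vt
    # Back to (batch, seq, heads, dim) = (4, 16, 32, 8).
    return out.reshape(4, 32, 16, 8).transpose(0, 2, 1, 3).astype("float16")

With inputs drawn from, e.g., np.random.randn(*shape).astype("float16"), the output of running the partitioned module could then be compared against this reference via tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2), matching the tolerance used elsewhere in this file.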
