Skip to content

Commit 6c5a435

Browse files
committed
wip
1 parent 7926cbc commit 6c5a435

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

tests/python/relax/test_codegen_cutlass.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1961,8 +1961,8 @@ def rewrite_attention(f):
19611961
K_BNSH_T = is_op("relax.permute_dims")(K_BNSH)
19621962

19631963
matmul1 = is_op("relax.matmul")(Q_BNSH, K_BNSH_T)
1964-
multiply = is_op("relax.multiply")(matmul1, is_const())
1965-
softmax = is_op("relax.nn.softmax")(multiply)
1964+
multiply = is_op("relax.divide")(matmul1, is_const())
1965+
softmax = is_op("relax.astype")(is_op("relax.nn.softmax")(is_op("relax.astype")(multiply)))
19661966
matmul2 = is_op("relax.matmul")(softmax, V_BNSH)
19671967
pattern = is_op("relax.permute_dims")(is_op("relax.reshape")(matmul2, is_shape([4, 32, 16, 8])))
19681968

@@ -1996,9 +1996,9 @@ def main(
19961996

19971997
lv6 = R.permute_dims(lv3, axes=[0, 2, 1])
19981998
lv7 = R.matmul(lv1, lv6, out_dtype="float16")
1999-
lv3_1 = R.const(0.5, "float16")
2000-
lv8 = R.multiply(lv7, lv3_1)
2001-
lv11 = R.nn.softmax(lv8, axis=2)
1999+
lv3_1 = R.const(2.0, "float16")
2000+
lv8 = R.divide(lv7, lv3_1)
2001+
lv11 = R.astype(R.nn.softmax(R.astype(lv8, "float32"), axis=2), "float16")
20022002
lv12 = R.matmul(lv11, lv5, out_dtype="float16")
20032003
lv13 = R.reshape(lv12, R.shape([4, 32, 16, 8]))
20042004
lv6_1 = R.permute_dims(lv13, axes=[0, 2, 1, 3])

0 commit comments

Comments
 (0)