Commit 9828698 (message: "wip")
1 parent: 5d01fd1

File tree: 2 files changed (+14 -4 lines)

python/tvm/contrib/cutlass/gen_tensor_op.py

Lines changed: 0 additions & 3 deletions
@@ -765,9 +765,6 @@ def get_batch_on_arg(arg_name, arg_shape):
             and int(annotations["arch"]) >= 80
         )
 
-        print(int(attrs["head_dim"]) <= 256, int(attrs["head_dim"]) % 8 == 0, int(attrs["head_dim"]) == int(attrs["head_dim_value"]),int(annotations["arch"]) >= 80, annotations["ret_dtype"] == "float16", "bias" not in attrs, int(annotations["arch"]) >= 80)
-
-
         if use_flash:
             headers.append("flash.h")
             attrs["is_causal"] = int(annotations["custom_mask_type"]) > 0

tests/python/relax/test_codegen_cutlass.py

Lines changed: 14 additions & 1 deletion
@@ -2005,10 +2005,23 @@ def main(
                 R.output(lv6_1)
             return lv6_1
 
+    q_np = np.random.randn(4, 16, 32, 8).astype("float16")
+    k_np = np.random.randn(4, 16, 1, 8).astype("float16")
+    v_np = np.random.randn(4, 16, 1, 8).astype("float16")
+    args = [q_np, k_np, v_np]
+    ref = build_and_run(Module, args, "llvm", legalize=True)
+    print(ref)
+
+    return
+
     Module["main"] = rewrite_attention(Module["main"])
     mod = partition_for_cutlass(Module)
     codegen_pass = relax.transform.RunCodegen({"cutlass": {"sm": 80, "find_first_valid": True}})
-    print(codegen_pass(mod))
+    mod = codegen_pass(mod)
+
+    out = build_and_run(Module, args, "cuda")
+    print(ref)
+
 
 if __name__ == "__main__":
     # tvm.testing.main()
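
The test hunk computes an LLVM reference for small fp16 q/k/v inputs and then, as WIP scaffolding, returns early; the CUDA branch below it currently ends with a second print(ref) rather than a comparison. A minimal sketch of the check the finished test would presumably perform once the early return is removed (the helper name and tolerances are assumptions, not from the diff):

import numpy as np

def check_against_reference(out, ref, rtol=1e-2, atol=1e-2):
    # fp16 attention kernels accumulate rounding error, so cross-backend
    # checks typically use loose tolerances.
    np.testing.assert_allclose(out, ref, rtol=rtol, atol=atol)

It would be used in place of the trailing print(ref), e.g. check_against_reference(out, ref) after out = build_and_run(Module, args, "cuda").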
