@@ -104,7 +104,9 @@ def build_cutlass(mod, assert_all_bindings_fused=True, num_final_bindings=1):
     mod = partition_for_cutlass(mod)

     if assert_all_bindings_fused:
-        assert len(mod["main"].body.blocks[0].bindings) == num_final_bindings
+        assert (
+            len(mod["main"].body.blocks[0].bindings) == num_final_bindings
+        ), "Not all bindings are fused. " + str(mod["main"])

     codegen_pass = relax.transform.RunCodegen({"cutlass": {"sm": 80, "find_first_valid": True}})
     mod = codegen_pass(mod)
@@ -714,7 +716,7 @@ def test_attention_offload(attention_size, attention_dtype):
     v_shape = (b, s_kv, n, h_v)

     mod = get_relax_attention_module(q_shape, k_shape, v_shape, dtype=attention_dtype)
-    out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=3)
+    out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=2)

     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)

@@ -751,7 +753,7 @@ def test_attention_bias_offload(attention_bias_size):
     mod = get_relax_attention_module(
         q_shape, k_shape, v_shape, bias_shape=bias_shape, dtype="float32"
    )
-    out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias, num_final_bindings=3)
+    out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias, num_final_bindings=2)

     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)

@@ -786,9 +788,9 @@ def test_attention_scale_offload(attention_scale_size, attention_scale):
         q_shape, k_shape, v_shape, dtype="float32", bias_shape=bias_shape, qk_scale=attention_scale
     )
     if bias is None:
-        out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=3)
+        out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=2)
     else:
-        out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias, num_final_bindings=3)
+        out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias, num_final_bindings=2)
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)


@@ -829,9 +831,9 @@ def test_attention_causal_offload(attention_causal_size, attention_causal):
     )

     if bias is None:
-        out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=3)
+        out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=2)
     else:
-        out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias, num_final_bindings=3)
+        out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias, num_final_bindings=2)
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)


@@ -932,9 +934,9 @@ def test_stacked_attention_split_offload(stacked_attention_size):
     )

     if bias is None:
-        out = get_result_with_relax_cutlass_offload(mod, qkv, num_final_bindings=3)
+        out = get_result_with_relax_cutlass_offload(mod, qkv, num_final_bindings=2)
     else:
-        out = get_result_with_relax_cutlass_offload(mod, qkv, bias, num_final_bindings=3)
+        out = get_result_with_relax_cutlass_offload(mod, qkv, bias, num_final_bindings=2)
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)


@@ -950,9 +952,9 @@ def test_stacked_attention_strided_slice_offload(stacked_attention_size):
         qkv, b, s, n, h, h_v, "strided_slice", bias, scale, single_shape=single_shape
     )
     if bias is None:
-        out = get_result_with_relax_cutlass_offload(mod, qkv, num_final_bindings=3)
+        out = get_result_with_relax_cutlass_offload(mod, qkv, num_final_bindings=2)
     else:
-        out = get_result_with_relax_cutlass_offload(mod, qkv, bias, num_final_bindings=3)
+        out = get_result_with_relax_cutlass_offload(mod, qkv, bias, num_final_bindings=2)
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)


@@ -1311,9 +1313,8 @@ def main(
             R.func_attr({"num_input": 4})
             cls = Expected
             with R.dataflow():
-                lv = R.vm.alloc_storage(R.shape([65536]), R.prim_value(0), R.dtype("uint8"))
-                workspace_main = R.vm.alloc_tensor(
-                    lv, R.prim_value(0), R.shape([65536]), R.dtype("uint8")
+                workspace_main = R.builtin.alloc_tensor(
+                    R.shape([65536]), R.dtype("uint8"), R.prim_value(0)
                 )
                 lv_1 = R.reshape(bias, R.shape([128, 16, 8]))
                 lv1 = R.reshape(lv_1, R.shape([4, 32, 16, 8]))
@@ -2419,7 +2420,7 @@ def test_sliding_window():
         1, 64, 64, 16, 8, 8, "none", "none", causal, "float16", window_size=window_size
     )

-    out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=3)
+    out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=2)

     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)

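The hunks above tell one story: the attention workspace is now materialized by a single R.builtin.alloc_tensor binding instead of an R.vm.alloc_storage / R.vm.alloc_tensor pair, so the partitioned main function carries one binding fewer, presumably the workspace allocation plus the call into the fused CUTLASS function, hence num_final_bindings drops from 3 to 2 throughout the tests. A minimal sketch of the resulting dataflow block, assuming a TVM build with the Relax TVMScript frontend (the module and names below are illustrative, not part of the test file):

from tvm.script import ir as I
from tvm.script import relax as R

@I.ir_module
class WorkspaceSketch:  # hypothetical module, for illustration only
    @R.function
    def main() -> R.Tensor((65536,), "uint8"):
        with R.dataflow():
            # A single binding now creates the workspace tensor directly;
            # previously this took two bindings (a storage allocation, then
            # a tensor view over that storage), which is why the expected
            # binding counts in the tests above drop by one.
            workspace = R.builtin.alloc_tensor(
                R.shape([65536]), R.dtype("uint8"), R.prim_value(0)
            )
            R.output(workspace)
        return workspace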