2626class Module :
2727 @T .prim_func
2828 def tvm_test_cpacked (
29- A : T .handle , B : T .handle , C : T .handle , device_context : T .handle
29+ A : T .Buffer [(1 ,), "float32" ],
30+ B : T .Buffer [(1 ,), "float32" ],
31+ C : T .Buffer [(1 ,), "float32" ],
32+ device_context : T .Buffer [(1 ,), "float32" ],
3033 ) -> T .handle :
31- A_0 = T .match_buffer (A , (1 ,), dtype = "float32" )
32- A_0pre = T .preflattened_buffer (A_0 , (1 ,), dtype = "float32" )
33- B_0 = T .match_buffer (B , (1 ,), dtype = "float32" )
34- B_0pre = T .preflattened_buffer (B_0 , (1 ,), dtype = "float32" )
35- C_0 = T .match_buffer (C , (1 ,), dtype = "float32" )
36- C_0pre = T .preflattened_buffer (C_0 , (1 ,), dtype = "float32" )
37- T .evaluate (C )
34+ T .evaluate (C .data )
3835
3936 @T .prim_func
4037 def tir_packed_call () -> None :
@@ -59,15 +56,12 @@ def tir_packed_call() -> None:
5956class Expected :
6057 @T .prim_func
6158 def tvm_test_cpacked (
62- A : T .handle , B : T .handle , C : T .handle , device_context : T .handle
59+ A : T .Buffer [(1 ,), "float32" ],
60+ B : T .Buffer [(1 ,), "float32" ],
61+ C : T .Buffer [(1 ,), "float32" ],
62+ device_context : T .handle ,
6363 ) -> T .handle :
64- A_0 = T .match_buffer (A , (1 ,), dtype = "float32" )
65- A_0pre = T .preflattened_buffer (A_0 , (1 ,), dtype = "float32" )
66- B_0 = T .match_buffer (B , (1 ,), dtype = "float32" )
67- B_0pre = T .preflattened_buffer (B_0 , (1 ,), dtype = "float32" )
68- C_0 = T .match_buffer (C , (1 ,), dtype = "float32" )
69- C_0pre = T .preflattened_buffer (C_0 , (1 ,), dtype = "float32" )
70- T .evaluate (C )
64+ T .evaluate (C .data )
7165
7266 @T .prim_func
7367 def tir_packed_call () -> None :
0 commit comments