1818import sys
1919
2020import pytest
21+ import tvm
2122from tvm .script import tir as T
2223
23- # match buffer - use kwargs
24+ # match buffer - no syntax sugar
2425@T .prim_func
25- def elementwise (
26+ def elementwise_handle (
27+ a : T .handle ,
28+ b : T .handle ,
29+ ) -> None :
30+ A = T .match_buffer (a , (128 , 128 , 128 , 128 ))
31+ B = T .match_buffer (b , (128 , 128 , 128 , 128 ))
32+ for i , j , k , l in T .grid (128 , 128 , 128 , 128 ):
33+ with T .block ("B" ):
34+ vi , vj , vk , vl = T .axis .remap ("SSSS" , [i , j , k , l ])
35+ B [vi , vj , vk , vl ] = A [vi , vj , vk , vl ] * 2.0
36+
37+
38+ # match buffer - use buffer with kwargs
39+ @T .prim_func
40+ def elementwise_buffer_kwargs (
2641 a : T .Buffer (shape = (128 , 128 , 128 , 128 ), dtype = "float32" , elem_offset = 1 ),
2742 b : T .Buffer (shape = (128 , 128 , 128 , 128 ), dtype = "float32" , elem_offset = 2 ),
2843) -> None :
29- # A = T.match_buffer(a, (128, 128, 128, 128))
30- # B = T.match_buffer(b, (128, 128, 128, 128))
3144 for i , j , k , l in T .grid (128 , 128 , 128 , 128 ):
3245 with T .block ("B" ):
3346 vi , vj , vk , vl = T .axis .remap ("SSSS" , [i , j , k , l ])
3447 b [vi , vj , vk , vl ] = a [vi , vj , vk , vl ] * 2.0
3548
3649
37- # match buffer - no kwargs
50+ # match buffer - use buffer without kwargs
3851@T .prim_func
39- def elementwise (
52+ def elementwise_buffer_no_kwargs (
4053 a : T .Buffer [(128 , 128 , 128 , 128 ), "float32" ],
4154 b : T .Buffer [(128 , 128 , 128 , 128 ), "float32" ],
4255) -> None :
43- # A = T.match_buffer(a, (128, 128, 128, 128))
44- # B = T.match_buffer(b, (128, 128, 128, 128))
4556 for i , j , k , l in T .grid (128 , 128 , 128 , 128 ):
4657 with T .block ("B" ):
4758 vi , vj , vk , vl = T .axis .remap ("SSSS" , [i , j , k , l ])
4859 b [vi , vj , vk , vl ] = a [vi , vj , vk , vl ] * 2.0
4960
5061
62+ def test_match_buffer_syntax_sugar ():
63+ # with kwargs
64+ tvm .ir .assert_structural_equal (elementwise_handle , elementwise_buffer_kwargs )
65+ # without kwargs
66+ tvm .ir .assert_structural_equal (elementwise_handle , elementwise_buffer_no_kwargs )
67+
68+
5169if __name__ == "__main__" :
5270 sys .exit (pytest .main ([__file__ ] + sys .argv [1 :]))
0 commit comments