1+ # Licensed to the Apache Software Foundation (ASF) under one
2+ # or more contributor license agreements. See the NOTICE file
3+ # distributed with this work for additional information
4+ # regarding copyright ownership. The ASF licenses this file
5+ # to you under the Apache License, Version 2.0 (the
6+ # "License"); you may not use this file except in compliance
7+ # with the License. You may obtain a copy of the License at
8+ #
9+ # http://www.apache.org/licenses/LICENSE-2.0
10+ #
11+ # Unless required by applicable law or agreed to in writing,
12+ # software distributed under the License is distributed on an
13+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+ # KIND, either express or implied. See the License for the
15+ # specific language governing permissions and limitations
16+ # under the License.
17+ # pylint: disable=missing-function-docstring,missing-module-docstring,invalid-name,pointless-string-statement
18+ import sys
19+
20+ import pytest
21+ from tvm .script import tir as T
22+
23+ # match buffer - use kwargs
24+ @T .prim_func
25+ def elementwise (
26+ a : T .Buffer (shape = (128 , 128 , 128 , 128 ), dtype = "float32" , elem_offset = 1 ),
27+ b : T .Buffer (shape = (128 , 128 , 128 , 128 ), dtype = "float32" , elem_offset = 2 ),
28+ ) -> None :
29+ # A = T.match_buffer(a, (128, 128, 128, 128))
30+ # B = T.match_buffer(b, (128, 128, 128, 128))
31+ for i , j , k , l in T .grid (128 , 128 , 128 , 128 ):
32+ with T .block ("B" ):
33+ vi , vj , vk , vl = T .axis .remap ("SSSS" , [i , j , k , l ])
34+ b [vi , vj , vk , vl ] = a [vi , vj , vk , vl ] * 2.0
35+
36+
37+ # match buffer - no kwargs
38+ @T .prim_func
39+ def elementwise (
40+ a : T .Buffer [(128 , 128 , 128 , 128 ), "float32" ],
41+ b : T .Buffer [(128 , 128 , 128 , 128 ), "float32" ],
42+ ) -> None :
43+ # A = T.match_buffer(a, (128, 128, 128, 128))
44+ # B = T.match_buffer(b, (128, 128, 128, 128))
45+ for i , j , k , l in T .grid (128 , 128 , 128 , 128 ):
46+ with T .block ("B" ):
47+ vi , vj , vk , vl = T .axis .remap ("SSSS" , [i , j , k , l ])
48+ b [vi , vj , vk , vl ] = a [vi , vj , vk , vl ] * 2.0
49+
50+
51+ if __name__ == "__main__" :
52+ sys .exit (pytest .main ([__file__ ] + sys .argv [1 :]))
0 commit comments