Skip to content

Commit ed9ba73

Browse files
author
shingjan
committed
add test case and concrete type for buffer
1 parent e9b6845 commit ed9ba73

File tree

3 files changed

+67
-1
lines changed

3 files changed

+67
-1
lines changed

python/tvm/script/tir/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,6 @@
1818

1919
# Type system
2020
from .ty import int8, int16, int32, int64, float16, float32, float64
21-
from .ty import boolean, handle, Ptr, Tuple
21+
from .ty import boolean, handle, Ptr, Tuple, Buffer
2222

2323
from .prim_func import prim_func

python/tvm/script/tir/ty.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,19 @@ def evaluate(self):
4646
return tvm.ir.PrimType(self.type)
4747

4848

49+
class ConcreteBufferType(TypeGeneric): # pylint: disable=too-few-public-methods, abstract-method
50+
"""TVM script typing class for uniform Type objects"""
51+
52+
def __init__(self, vtype):
53+
self.type = vtype
54+
55+
def evaluate(self):
56+
return tvm.ir.PrimType(self.type)
57+
58+
def __call__(self, shape, dtype, elem_offset):
59+
pass
60+
61+
4962
class GenericPtrType(TypeGeneric): # pylint: disable=abstract-method
5063
"""TVM script typing class generator for PtrType
5164
@@ -78,3 +91,4 @@ def __getitem__(self, vtypes):
7891
handle = ConcreteType("handle")
7992
Ptr = GenericPtrType()
8093
Tuple = GenericTupleType()
94+
Buffer = ConcreteBufferType("Buffer")
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=missing-function-docstring,missing-module-docstring,invalid-name,pointless-string-statement
18+
import sys
19+
20+
import pytest
21+
from tvm.script import tir as T
22+
23+
# match buffer - use kwargs
24+
@T.prim_func
25+
def elementwise(
26+
a: T.Buffer(shape=(128, 128, 128, 128), dtype="float32", elem_offset=1),
27+
b: T.Buffer(shape=(128, 128, 128, 128), dtype="float32", elem_offset=2),
28+
) -> None:
29+
# A = T.match_buffer(a, (128, 128, 128, 128))
30+
# B = T.match_buffer(b, (128, 128, 128, 128))
31+
for i, j, k, l in T.grid(128, 128, 128, 128):
32+
with T.block("B"):
33+
vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l])
34+
b[vi, vj, vk, vl] = a[vi, vj, vk, vl] * 2.0
35+
36+
37+
# match buffer - no kwargs
38+
@T.prim_func
39+
def elementwise(
40+
a: T.Buffer[(128, 128, 128, 128), "float32"],
41+
b: T.Buffer[(128, 128, 128, 128), "float32"],
42+
) -> None:
43+
# A = T.match_buffer(a, (128, 128, 128, 128))
44+
# B = T.match_buffer(b, (128, 128, 128, 128))
45+
for i, j, k, l in T.grid(128, 128, 128, 128):
46+
with T.block("B"):
47+
vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l])
48+
b[vi, vj, vk, vl] = a[vi, vj, vk, vl] * 2.0
49+
50+
51+
if __name__ == "__main__":
52+
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 commit comments

Comments
 (0)