Skip to content

Commit efe5cec

Browse files
committed
[TVMScript] Encourage using T.Buffer directly
Previously there are two equivalent ways of declaring a buffer in TVMScript: ```python buffer = T.buffer_decl(...) buffer = T.Buffer(...) ``` The two approaches are aliases to each other and are essentially the same in implementation. Therefore, this PR encourages to use `T.Buffer` as the recommended approach as it's a bit shorter. Meanwhile, `T.buffer_decl` will continue to be valid in TVMScript, but a deprecation warning will be emitted if its used.
1 parent dc626f3 commit efe5cec

File tree

5 files changed

+18
-17
lines changed

5 files changed

+18
-17
lines changed

python/tvm/script/ir_builder/tir/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,4 @@
1717
"""Package tvm.script.ir_builder.tir"""
1818
from .ir import * # pylint: disable=wildcard-import,redefined-builtin
1919
from .ir import boolean as bool # pylint: disable=redefined-builtin
20-
from .ir import buffer_decl as Buffer
20+
from .ir import buffer as Buffer

python/tvm/script/ir_builder/tir/ir.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@
8686
# pylint: enable=unused-import
8787

8888

89-
def buffer_decl(
89+
def buffer(
9090
shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral],
9191
dtype: str = "float32",
9292
data: Var = None,
@@ -138,7 +138,7 @@ def buffer_decl(
138138
The declared buffer.
139139
"""
140140
shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
141-
return _ffi_api.BufferDecl( # type: ignore[attr-defined] # pylint: disable=no-member
141+
return _ffi_api.Buffer( # type: ignore[attr-defined] # pylint: disable=no-member
142142
shape,
143143
dtype,
144144
"",
@@ -153,6 +153,11 @@ def buffer_decl(
153153
)
154154

155155

156+
@deprecated("T.buffer_decl(...)", "T.Buffer(...)")
157+
def buffer_decl(*args, **kwargs):
158+
return buffer(*args, **kwargs)
159+
160+
156161
def prim_func() -> frame.PrimFuncFrame:
157162
"""The primitive function statement.
158163
@@ -1431,7 +1436,7 @@ def ptr(dtype: str, storage_scope: str = "global") -> Var:
14311436
return _ffi_api.Ptr(dtype, storage_scope) # type: ignore[attr-defined] # pylint: disable=no-member
14321437

14331438

1434-
@deprecated("T.buffer_var", "T.Ptr")
1439+
@deprecated("T.buffer_var", "T.handle")
14351440
def buffer_var(dtype: str, storage_scope: str = "global") -> Var:
14361441
"""The pointer declaration function.
14371442
@@ -1814,6 +1819,7 @@ def wrapped(*args, **kwargs):
18141819
"float16x64",
18151820
"float32x64",
18161821
"float64x64",
1822+
"buffer",
18171823
"buffer_decl",
18181824
"prim_func",
18191825
"arg",

python/tvm/script/parser/tir/entry.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from tvm.ir.base import deprecated
2222
from tvm.tir import Buffer, PrimFunc
2323

24-
from ...ir_builder.tir import buffer_decl, ptr
24+
from ...ir_builder.tir import buffer, ptr
2525
from .._core import parse, utils
2626

2727

@@ -49,9 +49,7 @@ def prim_func(func: Callable) -> Union[PrimFunc, Callable]:
4949

5050

5151
class BufferProxy:
52-
"""Buffer proxy class for constructing tir buffer.
53-
Overload __call__ and __getitem__ to support syntax as T.Buffer() and T.Buffer().
54-
"""
52+
"""Buffer proxy class for constructing tir buffer."""
5553

5654
def __call__(
5755
self,
@@ -66,7 +64,7 @@ def __call__(
6664
buffer_type="",
6765
axis_separators=None,
6866
) -> Buffer:
69-
return buffer_decl(
67+
return buffer(
7068
shape,
7169
dtype=dtype,
7270
data=data,
@@ -89,9 +87,7 @@ def __getitem__(self, keys) -> Buffer:
8987

9088

9189
class PtrProxy:
92-
"""Ptr proxy class for constructing tir pointer.
93-
Overload __call__ and __getitem__ to support syntax as T.Ptr() and T.Ptr().
94-
"""
90+
"""Ptr proxy class for constructing tir pointer."""
9591

9692
@deprecated("T.Ptr(...)", "T.handle(...)")
9793
def __call__(self, dtype, storage_scope="global"):

src/script/ir_builder/tir/ir.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -593,8 +593,7 @@ TVM_STATIC_IR_FUNCTOR(Namer, vtable)
593593
Namer::Name(var->var, name);
594594
});
595595

596-
TVM_REGISTER_GLOBAL("script.ir_builder.tir.BufferDecl").set_body_typed(BufferDecl);
597-
596+
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Buffer").set_body_typed(BufferDecl);
598597
TVM_REGISTER_GLOBAL("script.ir_builder.tir.PrimFunc").set_body_typed(PrimFunc);
599598
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Arg")
600599
.set_body_typed([](String name, ObjectRef obj) -> ObjectRef {

tests/python/unittest/test_auto_scheduler_feature.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -209,9 +209,9 @@ def tir_matmul(
209209
) -> None:
210210
# function attr dict
211211
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
212-
A_flat = T.buffer_decl([16384], dtype="float32", data=A.data)
213-
B_flat = T.buffer_decl([16384], dtype="float32", data=B.data)
214-
C_flat = T.buffer_decl([16384], dtype="float32", data=C.data)
212+
A_flat = T.Buffer([16384], dtype="float32", data=A.data)
213+
B_flat = T.Buffer([16384], dtype="float32", data=B.data)
214+
C_flat = T.Buffer([16384], dtype="float32", data=C.data)
215215
# body
216216
for x, y in T.grid(128, 128):
217217
C_flat[x * 128 + y] = T.float32(0)

0 commit comments

Comments
 (0)