Skip to content

Commit 48adf3c

Browse files
committed
[TVMScript] Encourage using T.Buffer directly
Previously there are two equivalent ways of declaring a buffer in TVMScript: ```python buffer = T.buffer_decl(...) buffer = T.Buffer(...) ``` The two approaches are aliases to each other and are essentially the same in implementation. Therefore, this PR encourages to use `T.Buffer` as the recommended approach as it's a bit shorter. Meanwhile, `T.buffer_decl` will continue to be valid in TVMScript, but a deprecation warning will be emitted if its used.
1 parent 82cf9f7 commit 48adf3c

File tree

5 files changed

+27
-19
lines changed

5 files changed

+27
-19
lines changed

python/tvm/script/ir_builder/tir/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,4 @@
1717
"""Package tvm.script.ir_builder.tir"""
1818
from .ir import * # pylint: disable=wildcard-import,redefined-builtin
1919
from .ir import boolean as bool # pylint: disable=redefined-builtin
20-
from .ir import buffer_decl as Buffer
20+
from .ir import buffer as Buffer

python/tvm/script/ir_builder/tir/ir.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@
8686
# pylint: enable=unused-import
8787

8888

89-
def buffer_decl(
89+
def buffer(
9090
shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral],
9191
dtype: str = "float32",
9292
data: Var = None,
@@ -138,7 +138,7 @@ def buffer_decl(
138138
The declared buffer.
139139
"""
140140
shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
141-
return _ffi_api.BufferDecl( # type: ignore[attr-defined] # pylint: disable=no-member
141+
return _ffi_api.Buffer( # type: ignore[attr-defined] # pylint: disable=no-member
142142
shape,
143143
dtype,
144144
"",
@@ -153,6 +153,11 @@ def buffer_decl(
153153
)
154154

155155

156+
@deprecated("T.buffer_decl(...)", "T.Buffer(...)")
157+
def buffer_decl(*args, **kwargs):
158+
return buffer(*args, **kwargs)
159+
160+
156161
def prim_func() -> frame.PrimFuncFrame:
157162
"""The primitive function statement.
158163
@@ -1177,7 +1182,11 @@ def env_thread(thread_tag: str) -> IterVar:
11771182
return _ffi_api.EnvThread(thread_tag) # type: ignore[attr-defined] # pylint: disable=no-member
11781183

11791184

1180-
def buffer_store(buffer: Buffer, value: PrimExpr, indices: List[Union[PrimExpr, slice]]) -> None:
1185+
def buffer_store(
1186+
buffer: Buffer, # pylint: disable=redefined-outer-name
1187+
value: PrimExpr,
1188+
indices: List[Union[PrimExpr, slice]],
1189+
) -> None:
11811190
"""Buffer store node.
11821191
11831192
Parameters
@@ -1211,7 +1220,10 @@ def buffer_store(buffer: Buffer, value: PrimExpr, indices: List[Union[PrimExpr,
12111220
)
12121221

12131222

1214-
def prefetch(buffer: Buffer, bounds: List[Range]) -> None:
1223+
def prefetch(
1224+
buffer: Buffer, # pylint: disable=redefined-outer-name
1225+
bounds: List[Range],
1226+
) -> None:
12151227
"""The prefetch hint for a buffer.
12161228
12171229
Parameters
@@ -1432,7 +1444,7 @@ def ptr(dtype: str, storage_scope: str = "global") -> Var:
14321444
return _ffi_api.Ptr(dtype, storage_scope) # type: ignore[attr-defined] # pylint: disable=no-member
14331445

14341446

1435-
@deprecated("T.buffer_var", "T.Ptr")
1447+
@deprecated("T.buffer_var", "T.handle")
14361448
def buffer_var(dtype: str, storage_scope: str = "global") -> Var:
14371449
"""The pointer declaration function.
14381450
@@ -1815,6 +1827,7 @@ def wrapped(*args, **kwargs):
18151827
"float16x64",
18161828
"float32x64",
18171829
"float64x64",
1830+
"buffer",
18181831
"buffer_decl",
18191832
"prim_func",
18201833
"arg",

python/tvm/script/parser/tir/entry.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from tvm.ir.base import deprecated
2222
from tvm.tir import Buffer, PrimFunc
2323

24-
from ...ir_builder.tir import buffer_decl, ptr
24+
from ...ir_builder.tir import buffer, ptr
2525
from .._core import parse, utils
2626

2727

@@ -49,9 +49,7 @@ def prim_func(func: Callable) -> Union[PrimFunc, Callable]:
4949

5050

5151
class BufferProxy:
52-
"""Buffer proxy class for constructing tir buffer.
53-
Overload __call__ and __getitem__ to support syntax as T.Buffer() and T.Buffer().
54-
"""
52+
"""Buffer proxy class for constructing tir buffer."""
5553

5654
def __call__(
5755
self,
@@ -66,7 +64,7 @@ def __call__(
6664
buffer_type="",
6765
axis_separators=None,
6866
) -> Buffer:
69-
return buffer_decl(
67+
return buffer(
7068
shape,
7169
dtype=dtype,
7270
data=data,
@@ -89,9 +87,7 @@ def __getitem__(self, keys) -> Buffer:
8987

9088

9189
class PtrProxy:
92-
"""Ptr proxy class for constructing tir pointer.
93-
Overload __call__ and __getitem__ to support syntax as T.Ptr() and T.Ptr().
94-
"""
90+
"""Ptr proxy class for constructing tir pointer."""
9591

9692
@deprecated("T.Ptr(...)", "T.handle(...)")
9793
def __call__(self, dtype, storage_scope="global"):

src/script/ir_builder/tir/ir.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -593,8 +593,7 @@ TVM_STATIC_IR_FUNCTOR(Namer, vtable)
593593
Namer::Name(var->var, name);
594594
});
595595

596-
TVM_REGISTER_GLOBAL("script.ir_builder.tir.BufferDecl").set_body_typed(BufferDecl);
597-
596+
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Buffer").set_body_typed(BufferDecl);
598597
TVM_REGISTER_GLOBAL("script.ir_builder.tir.PrimFunc").set_body_typed(PrimFunc);
599598
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Arg")
600599
.set_body_typed([](String name, ObjectRef obj) -> ObjectRef {

tests/python/unittest/test_auto_scheduler_feature.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -209,9 +209,9 @@ def tir_matmul(
209209
) -> None:
210210
# function attr dict
211211
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
212-
A_flat = T.buffer_decl([16384], dtype="float32", data=A.data)
213-
B_flat = T.buffer_decl([16384], dtype="float32", data=B.data)
214-
C_flat = T.buffer_decl([16384], dtype="float32", data=C.data)
212+
A_flat = T.Buffer([16384], dtype="float32", data=A.data)
213+
B_flat = T.Buffer([16384], dtype="float32", data=B.data)
214+
C_flat = T.Buffer([16384], dtype="float32", data=C.data)
215215
# body
216216
for x, y in T.grid(128, 128):
217217
C_flat[x * 128 + y] = T.float32(0)

0 commit comments

Comments
 (0)