Skip to content

Commit 1138796

Browse files
committed
[TVMScript] Encourage using T.Buffer directly
Previously there are two equivalent ways of declaring a buffer in TVMScript: ```python buffer = T.buffer_decl(...) buffer = T.Buffer(...) ``` The two approaches are aliases to each other and are essentially the same in implementation. Therefore, this PR encourages to use `T.Buffer` as the recommended approach as it's a bit shorter. Meanwhile, `T.buffer_decl` will continue to be valid in TVMScript, but a deprecation warning will be emitted if its used.
1 parent dc626f3 commit 1138796

File tree

5 files changed

+19
-15
lines changed

5 files changed

+19
-15
lines changed

python/tvm/script/ir_builder/tir/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,4 @@
1717
"""Package tvm.script.ir_builder.tir"""
1818
from .ir import * # pylint: disable=wildcard-import,redefined-builtin
1919
from .ir import boolean as bool # pylint: disable=redefined-builtin
20-
from .ir import buffer_decl as Buffer
20+
from .ir import buffer as Buffer

python/tvm/script/ir_builder/tir/ir.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@
8686
# pylint: enable=unused-import
8787

8888

89-
def buffer_decl(
89+
def buffer(
9090
shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral],
9191
dtype: str = "float32",
9292
data: Var = None,
@@ -138,7 +138,7 @@ def buffer_decl(
138138
The declared buffer.
139139
"""
140140
shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
141-
return _ffi_api.BufferDecl( # type: ignore[attr-defined] # pylint: disable=no-member
141+
return _ffi_api.Buffer( # type: ignore[attr-defined] # pylint: disable=no-member
142142
shape,
143143
dtype,
144144
"",
@@ -153,6 +153,11 @@ def buffer_decl(
153153
)
154154

155155

156+
@deprecated("T.buffer_decl(...)", "T.Buffer(...)")
157+
def buffer_decl(*args, **kwargs):
158+
return buffer(*args, **kwargs)
159+
160+
156161
def prim_func() -> frame.PrimFuncFrame:
157162
"""The primitive function statement.
158163
@@ -1431,7 +1436,7 @@ def ptr(dtype: str, storage_scope: str = "global") -> Var:
14311436
return _ffi_api.Ptr(dtype, storage_scope) # type: ignore[attr-defined] # pylint: disable=no-member
14321437

14331438

1434-
@deprecated("T.buffer_var", "T.Ptr")
1439+
@deprecated("T.buffer_var", "T.handle")
14351440
def buffer_var(dtype: str, storage_scope: str = "global") -> Var:
14361441
"""The pointer declaration function.
14371442
@@ -1814,6 +1819,7 @@ def wrapped(*args, **kwargs):
18141819
"float16x64",
18151820
"float32x64",
18161821
"float64x64",
1822+
"buffer",
18171823
"buffer_decl",
18181824
"prim_func",
18191825
"arg",

python/tvm/script/parser/tir/entry.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from tvm.ir.base import deprecated
2222
from tvm.tir import Buffer, PrimFunc
2323

24-
from ...ir_builder.tir import buffer_decl, ptr
24+
from ...ir_builder.tir import buffer, ptr
2525
from .._core import parse, utils
2626

2727

@@ -66,7 +66,7 @@ def __call__(
6666
buffer_type="",
6767
axis_separators=None,
6868
) -> Buffer:
69-
return buffer_decl(
69+
return buffer(
7070
shape,
7171
dtype=dtype,
7272
data=data,

src/script/ir_builder/tir/ir.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -593,8 +593,7 @@ TVM_STATIC_IR_FUNCTOR(Namer, vtable)
593593
Namer::Name(var->var, name);
594594
});
595595

596-
TVM_REGISTER_GLOBAL("script.ir_builder.tir.BufferDecl").set_body_typed(BufferDecl);
597-
596+
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Buffer").set_body_typed(BufferDecl);
598597
TVM_REGISTER_GLOBAL("script.ir_builder.tir.PrimFunc").set_body_typed(PrimFunc);
599598
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Arg")
600599
.set_body_typed([](String name, ObjectRef obj) -> ObjectRef {

tests/python/unittest/test_auto_scheduler_feature.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,8 @@
2121
import tempfile
2222

2323
import tvm
24-
from tvm import te, auto_scheduler, relay
24+
from tvm import auto_scheduler, relay, te
2525
from tvm.script import tir as T
26-
2726
from tvm.testing.auto_scheduler import matmul_auto_scheduler_test
2827

2928

@@ -78,8 +77,8 @@ def test_cpu_matmul():
7877
"""
7978

8079
# check touched memory in bytes, touched unique memory in bytes, reuse distance, etc.
81-
assert fequal(fea_dict[c_name + ".bytes"], math.log2(512**3 * 4 + 1))
82-
assert fequal(fea_dict[b_name + ".unique_bytes"], math.log2(512**2 * 4 + 1))
80+
assert fequal(fea_dict[c_name + ".bytes"], math.log2(512 ** 3 * 4 + 1))
81+
assert fequal(fea_dict[b_name + ".unique_bytes"], math.log2(512 ** 2 * 4 + 1))
8382
assert fequal(fea_dict[c_name + ".reuse_dis_iter"], math.log2(8 * 16 + 1))
8483
assert fequal(fea_dict[c_name + ".reuse_dis_bytes"], math.log2((8 * 16 + 8 + 16) * 4 + 1))
8584
assert fequal(fea_dict[c_name + ".reuse_ct"], math.log2(512 + 1))
@@ -209,9 +208,9 @@ def tir_matmul(
209208
) -> None:
210209
# function attr dict
211210
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
212-
A_flat = T.buffer_decl([16384], dtype="float32", data=A.data)
213-
B_flat = T.buffer_decl([16384], dtype="float32", data=B.data)
214-
C_flat = T.buffer_decl([16384], dtype="float32", data=C.data)
211+
A_flat = T.Buffer([16384], dtype="float32", data=A.data)
212+
B_flat = T.Buffer([16384], dtype="float32", data=B.data)
213+
C_flat = T.Buffer([16384], dtype="float32", data=C.data)
215214
# body
216215
for x, y in T.grid(128, 128):
217216
C_flat[x * 128 + y] = T.float32(0)

0 commit comments

Comments
 (0)