Skip to content

Commit dc626f3

Browse files
authored
[TVMScript] Unify T.handle and T.Ptr (#13969)
While both represents a pointer type, `T.handle` was previously used to refer to tir variables whose `type_annotation` is `PrimType`, while `T.Ptr` instead specifically refers to `PointerType`. The divide is unnecessary if we extend `T.handle` slightly.
1 parent a5a6e7f commit dc626f3

File tree

14 files changed

+77
-64
lines changed

14 files changed

+77
-64
lines changed

include/tvm/script/ir_builder/tir/ir.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -415,12 +415,12 @@ void Prefetch(Buffer buffer, Array<Range> bounds);
415415
void Evaluate(PrimExpr value);
416416

417417
/*!
418-
* \brief The pointer declaration function.
418+
* \brief Create a TIR var that represents a pointer
419419
* \param dtype The data type of the pointer.
420420
* \param storage_scope The storage scope of the pointer.
421421
* \return The pointer.
422422
*/
423-
PrimExpr Ptr(runtime::DataType dtype, String storage_scope = "global");
423+
Var Handle(runtime::DataType dtype = runtime::DataType::Void(), String storage_scope = "global");
424424

425425
#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType) \
426426
inline PrimExpr FuncName(Optional<PrimExpr> expr = NullOpt) { \
@@ -455,7 +455,6 @@ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Float, DataType::Float);
455455
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(UInt, DataType::UInt);
456456
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Int, DataType::Int);
457457
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool());
458-
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Handle, DataType::Handle());
459458
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void());
460459

461460
#undef TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST

python/tvm/script/ir_builder/tir/ir.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1358,20 +1358,23 @@ def boolean(expr: Optional[PrimExpr] = None) -> PrimExpr:
13581358
return _ffi_api.Boolean(expr) # type: ignore[attr-defined] # pylint: disable=no-member
13591359

13601360

1361-
def handle(expr: Optional[PrimExpr] = None) -> PrimExpr:
1362-
"""Construct a new tir.Var with type handle or cast expression to type handle.
1361+
def handle(dtype: str = "void", storage_scope: str = "global") -> Var:
1362+
"""Create a TIR var that represents a pointer.
13631363
13641364
Parameters
13651365
----------
1366-
expr: PrimExpr
1367-
The expression to be cast.
1366+
dtype: str
1367+
The data type of the pointer.
1368+
1369+
storage_scope: str
1370+
The storage scope of the pointer.
13681371
13691372
Returns
13701373
-------
13711374
res : PrimExpr
13721375
The new tir.Var with type handle or casted expression with type handle.
13731376
"""
1374-
return _ffi_api.Handle(expr) # type: ignore[attr-defined] # pylint: disable=no-member
1377+
return _ffi_api.Handle(dtype, storage_scope) # type: ignore[attr-defined] # pylint: disable=no-member
13751378

13761379

13771380
def void(expr: Optional[PrimExpr] = None) -> PrimExpr:

python/tvm/script/parser/tir/entry.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def __call__(
7979
axis_separators=axis_separators,
8080
)
8181

82-
@deprecated("T.Buffer(...)", "T.Buffer(...)")
82+
@deprecated("T.Buffer[...]", "T.Buffer(...)")
8383
def __getitem__(self, keys) -> Buffer:
8484
if not isinstance(keys, tuple):
8585
return self(keys)
@@ -93,12 +93,13 @@ class PtrProxy:
9393
Overload __call__ and __getitem__ to support syntax as T.Ptr() and T.Ptr().
9494
"""
9595

96+
@deprecated("T.Ptr(...)", "T.handle(...)")
9697
def __call__(self, dtype, storage_scope="global"):
9798
if callable(dtype):
9899
dtype = dtype().dtype
99100
return ptr(dtype, storage_scope) # pylint: disable=no-member # type: ignore
100101

101-
@deprecated("T.Ptr(...)", "T.Ptr(...)")
102+
@deprecated("T.Ptr[...]", "T.handle(...)")
102103
def __getitem__(self, keys):
103104
if not isinstance(keys, tuple):
104105
return self(keys)

src/script/ir_builder/tir/ir.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,6 +545,16 @@ PrimExpr Ptr(runtime::DataType dtype, String storage_scope) {
545545
return tvm::tir::Var("", tvm::PointerType(PrimType(dtype), storage_scope));
546546
}
547547

548+
Var Handle(runtime::DataType dtype, String storage_scope) {
549+
Type type_annotation{nullptr};
550+
if (dtype.is_void() && storage_scope == "global") {
551+
type_annotation = PrimType(runtime::DataType::Handle());
552+
} else {
553+
type_annotation = PointerType(PrimType(dtype), storage_scope);
554+
}
555+
return tvm::tir::Var("", type_annotation);
556+
}
557+
548558
using tvm::script::ir_builder::details::Namer;
549559

550560
TVM_STATIC_IR_FUNCTOR(Namer, vtable)

src/script/printer/tir/ir.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,10 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
7373
element_type = d->AsDoc<ExprDoc>(ty->element_type, ty_p->Attr("element_type"));
7474
}
7575
if (ty->storage_scope == "") {
76-
return TIR(d, "Ptr")->Call({element_type});
76+
return TIR(d, "handle")->Call({element_type});
7777
} else {
78-
return TIR(d, "Ptr")->Call(
79-
{element_type, LiteralDoc::Str(ty->storage_scope, ty_p->Attr("storage_scope"))});
78+
return TIR(d, "handle")
79+
->Call({element_type, LiteralDoc::Str(ty->storage_scope, ty_p->Attr("storage_scope"))});
8080
}
8181
});
8282

tests/python/relay/aot/test_aot_create_executor_metadata.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def test_create_executor_metadata_single_func():
5353
class Module:
5454
@T.prim_func
5555
def __tvm_main__(
56-
a: T.handle, output: T.handle, workspace: T.Ptr(T.uint8), constants: T.Ptr(T.uint8)
56+
a: T.handle, output: T.handle, workspace: T.handle("uint8"), constants: T.handle("uint8")
5757
) -> None:
5858
# function attr dict
5959
T.func_attr({"global_symbol": "test_mod___tvm_main__", "runner_function": True, "target": T.target({"kind": "llvm", "tag": "", "keys": ["cpu"]}), "input_vars": [a], "output_vars": [output], "devices": ["test_device"]})

tests/python/relay/aot/test_pass_aot_lower_main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,13 +178,13 @@ def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] {
178178
def func(a: T.handle, output: T.handle) -> None:
179179
# function attr dict
180180
T.func_attr({"global_symbol": "test_mod___tvm_main__", "runner_function": True, "target": T.target({"kind":"llvm", "tag":"", "keys":["cpu"]}), "input_vars": [a], "output_vars": [output], "devices": []})
181-
tmp_read = T.Ptr("uint8", "")
181+
tmp_read = T.handle("uint8", "")
182182
# buffer definition
183183
tmp_read_1 = T.Buffer([T.uint64(140)], dtype="uint8", data=tmp_read)
184184
a_buffer = T.match_buffer(a, [5, 7], dtype="float32", align=16)
185185
output_buffer = T.match_buffer(output, [5, 7], dtype="float32", align=16)
186186
# body
187-
tmp_write: T.Ptr(T.uint8) = output_buffer.data
187+
tmp_write: T.handle("uint8") = output_buffer.data
188188
tmp_write_1 = T.Buffer([T.uint64(140)], dtype="uint8", data=tmp_write)
189189
for i in T.serial(140):
190190
tmp_write_1[i] = T.let(tmp_read, a_buffer.data, tmp_read_1[i])

tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,7 @@ def test_buffer_conditional_lowering():
424424
"""
425425

426426
@T.prim_func
427-
def before(A: T.Ptr("float32")):
427+
def before(A: T.handle("float32")):
428428
T.func_attr({"global_symbol": "main", "tir.noalias": True})
429429
for i in range(1):
430430
A_1 = T.Buffer((1,), data=A)

tests/python/unittest/test_tir_transform_storage_flatten.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def main():
139139
T.func_attr({"from_legacy_te_schedule": True})
140140

141141
# If a pointer defined using a LetStmt,
142-
A_data: T.Ptr("int32") = T.call_extern("dummy_extern_function", dtype="handle")
142+
A_data: T.handle("int32") = T.call_extern("dummy_extern_function", dtype="handle")
143143

144144
# and a buffer is backed by that pointer,
145145
A = T.decl_buffer([1], dtype="float32", data=A_data)

tests/python/unittest/test_tir_transform_storage_rewrite.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -689,12 +689,12 @@ class TestLetBufferRewrite(BaseCompare):
689689
"""
690690

691691
def before() -> None:
692-
A_data: T.Ptr("int32") = T.call_extern("dummy_func", dtype="handle")
692+
A_data: T.handle("int32") = T.call_extern("dummy_func", dtype="handle")
693693
A = T.Buffer([8], "int32", data=A_data)
694694
A[0:8] = T.broadcast(42, 8)
695695

696696
def expected() -> None:
697-
A_data: T.Ptr("int32x8") = T.call_extern("dummy_func", dtype="handle")
697+
A_data: T.handle("int32x8") = T.call_extern("dummy_func", dtype="handle")
698698
A = T.Buffer([1], "int32x8", data=A_data)
699699
A[0] = T.broadcast(42, 8)
700700

0 commit comments

Comments
 (0)