diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index 5cba87920580..d5cc1de5c675 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -415,12 +415,12 @@ void Prefetch(Buffer buffer, Array bounds); void Evaluate(PrimExpr value); /*! - * \brief The pointer declaration function. + * \brief Create a TIR var that represents a pointer * \param dtype The data type of the pointer. * \param storage_scope The storage scope of the pointer. * \return The pointer. */ -PrimExpr Ptr(runtime::DataType dtype, String storage_scope = "global"); +Var Handle(runtime::DataType dtype = runtime::DataType::Void(), String storage_scope = "global"); #define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType) \ inline PrimExpr FuncName(Optional expr = NullOpt) { \ @@ -455,7 +455,6 @@ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Float, DataType::Float); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(UInt, DataType::UInt); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Int, DataType::Int); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool()); -TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Handle, DataType::Handle()); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void()); #undef TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index fdb27df2a9d1..25d16b56dc62 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1358,20 +1358,23 @@ def boolean(expr: Optional[PrimExpr] = None) -> PrimExpr: return _ffi_api.Boolean(expr) # type: ignore[attr-defined] # pylint: disable=no-member -def handle(expr: Optional[PrimExpr] = None) -> PrimExpr: - """Construct a new tir.Var with type handle or cast expression to type handle. +def handle(dtype: str = "void", storage_scope: str = "global") -> Var: + """Create a TIR var that represents a pointer. Parameters ---------- - expr: PrimExpr - The expression to be cast. + dtype: str + The data type of the pointer. + + storage_scope: str + The storage scope of the pointer. Returns ------- res : PrimExpr The new tir.Var with type handle or casted expression with type handle. """ - return _ffi_api.Handle(expr) # type: ignore[attr-defined] # pylint: disable=no-member + return _ffi_api.Handle(dtype, storage_scope) # type: ignore[attr-defined] # pylint: disable=no-member def void(expr: Optional[PrimExpr] = None) -> PrimExpr: diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py index bacf92c14287..51743e6b507b 100644 --- a/python/tvm/script/parser/tir/entry.py +++ b/python/tvm/script/parser/tir/entry.py @@ -79,7 +79,7 @@ def __call__( axis_separators=axis_separators, ) - @deprecated("T.Buffer(...)", "T.Buffer(...)") + @deprecated("T.Buffer[...]", "T.Buffer(...)") def __getitem__(self, keys) -> Buffer: if not isinstance(keys, tuple): return self(keys) @@ -93,12 +93,13 @@ class PtrProxy: Overload __call__ and __getitem__ to support syntax as T.Ptr() and T.Ptr(). """ + @deprecated("T.Ptr(...)", "T.handle(...)") def __call__(self, dtype, storage_scope="global"): if callable(dtype): dtype = dtype().dtype return ptr(dtype, storage_scope) # pylint: disable=no-member # type: ignore - @deprecated("T.Ptr(...)", "T.Ptr(...)") + @deprecated("T.Ptr[...]", "T.handle(...)") def __getitem__(self, keys): if not isinstance(keys, tuple): return self(keys) diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index c586e81f1b9c..9ab19b2e28a5 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -545,6 +545,16 @@ PrimExpr Ptr(runtime::DataType dtype, String storage_scope) { return tvm::tir::Var("", tvm::PointerType(PrimType(dtype), storage_scope)); } +Var Handle(runtime::DataType dtype, String storage_scope) { + Type type_annotation{nullptr}; + if (dtype.is_void() && storage_scope == "global") { + type_annotation = PrimType(runtime::DataType::Handle()); + } else { + type_annotation = PointerType(PrimType(dtype), storage_scope); + } + return tvm::tir::Var("", type_annotation); +} + using tvm::script::ir_builder::details::Namer; TVM_STATIC_IR_FUNCTOR(Namer, vtable) diff --git a/src/script/printer/tir/ir.cc b/src/script/printer/tir/ir.cc index ce10ff6816d7..78e50a5eb5da 100644 --- a/src/script/printer/tir/ir.cc +++ b/src/script/printer/tir/ir.cc @@ -73,10 +73,10 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) element_type = d->AsDoc(ty->element_type, ty_p->Attr("element_type")); } if (ty->storage_scope == "") { - return TIR(d, "Ptr")->Call({element_type}); + return TIR(d, "handle")->Call({element_type}); } else { - return TIR(d, "Ptr")->Call( - {element_type, LiteralDoc::Str(ty->storage_scope, ty_p->Attr("storage_scope"))}); + return TIR(d, "handle") + ->Call({element_type, LiteralDoc::Str(ty->storage_scope, ty_p->Attr("storage_scope"))}); } }); diff --git a/tests/python/relay/aot/test_aot_create_executor_metadata.py b/tests/python/relay/aot/test_aot_create_executor_metadata.py index 1bc79fe2a607..804738a7866a 100644 --- a/tests/python/relay/aot/test_aot_create_executor_metadata.py +++ b/tests/python/relay/aot/test_aot_create_executor_metadata.py @@ -53,7 +53,7 @@ def test_create_executor_metadata_single_func(): class Module: @T.prim_func def __tvm_main__( - a: T.handle, output: T.handle, workspace: T.Ptr(T.uint8), constants: T.Ptr(T.uint8) + a: T.handle, output: T.handle, workspace: T.handle("uint8"), constants: T.handle("uint8") ) -> None: # function attr dict T.func_attr({"global_symbol": "test_mod___tvm_main__", "runner_function": True, "target": T.target({"kind": "llvm", "tag": "", "keys": ["cpu"]}), "input_vars": [a], "output_vars": [output], "devices": ["test_device"]}) diff --git a/tests/python/relay/aot/test_pass_aot_lower_main.py b/tests/python/relay/aot/test_pass_aot_lower_main.py index f2455e97a051..bc58812cd67c 100644 --- a/tests/python/relay/aot/test_pass_aot_lower_main.py +++ b/tests/python/relay/aot/test_pass_aot_lower_main.py @@ -178,13 +178,13 @@ def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] { def func(a: T.handle, output: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "test_mod___tvm_main__", "runner_function": True, "target": T.target({"kind":"llvm", "tag":"", "keys":["cpu"]}), "input_vars": [a], "output_vars": [output], "devices": []}) - tmp_read = T.Ptr("uint8", "") + tmp_read = T.handle("uint8", "") # buffer definition tmp_read_1 = T.Buffer([T.uint64(140)], dtype="uint8", data=tmp_read) a_buffer = T.match_buffer(a, [5, 7], dtype="float32", align=16) output_buffer = T.match_buffer(output, [5, 7], dtype="float32", align=16) # body - tmp_write: T.Ptr(T.uint8) = output_buffer.data + tmp_write: T.handle("uint8") = output_buffer.data tmp_write_1 = T.Buffer([T.uint64(140)], dtype="uint8", data=tmp_write) for i in T.serial(140): tmp_write_1[i] = T.let(tmp_read, a_buffer.data, tmp_read_1[i]) diff --git a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py index 05d71de5bca6..758a395da6d7 100644 --- a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py +++ b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py @@ -424,7 +424,7 @@ def test_buffer_conditional_lowering(): """ @T.prim_func - def before(A: T.Ptr("float32")): + def before(A: T.handle("float32")): T.func_attr({"global_symbol": "main", "tir.noalias": True}) for i in range(1): A_1 = T.Buffer((1,), data=A) diff --git a/tests/python/unittest/test_tir_transform_storage_flatten.py b/tests/python/unittest/test_tir_transform_storage_flatten.py index 29623b498f43..39009164e708 100644 --- a/tests/python/unittest/test_tir_transform_storage_flatten.py +++ b/tests/python/unittest/test_tir_transform_storage_flatten.py @@ -139,7 +139,7 @@ def main(): T.func_attr({"from_legacy_te_schedule": True}) # If a pointer defined using a LetStmt, - A_data: T.Ptr("int32") = T.call_extern("dummy_extern_function", dtype="handle") + A_data: T.handle("int32") = T.call_extern("dummy_extern_function", dtype="handle") # and a buffer is backed by that pointer, A = T.decl_buffer([1], dtype="float32", data=A_data) diff --git a/tests/python/unittest/test_tir_transform_storage_rewrite.py b/tests/python/unittest/test_tir_transform_storage_rewrite.py index 4766022121df..c46754fb1742 100644 --- a/tests/python/unittest/test_tir_transform_storage_rewrite.py +++ b/tests/python/unittest/test_tir_transform_storage_rewrite.py @@ -689,12 +689,12 @@ class TestLetBufferRewrite(BaseCompare): """ def before() -> None: - A_data: T.Ptr("int32") = T.call_extern("dummy_func", dtype="handle") + A_data: T.handle("int32") = T.call_extern("dummy_func", dtype="handle") A = T.Buffer([8], "int32", data=A_data) A[0:8] = T.broadcast(42, 8) def expected() -> None: - A_data: T.Ptr("int32x8") = T.call_extern("dummy_func", dtype="handle") + A_data: T.handle("int32x8") = T.call_extern("dummy_func", dtype="handle") A = T.Buffer([1], "int32x8", data=A_data) A[0] = T.broadcast(42, 8) diff --git a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py index 5bbedd349259..58f37f04967d 100644 --- a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py +++ b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py @@ -144,20 +144,20 @@ def __tvm_main__(input: T.handle, output: T.handle) -> None: @tvm.script.ir_module class LinearStructurePlanned: @T.prim_func - def __tvm_main__(input: T.handle, fast_memory_0_var: T.Ptr("uint8"), slow_memory_1_var: T.Ptr("uint8"), output: T.handle) -> None: + def __tvm_main__(input: T.handle, fast_memory_0_var: T.handle("uint8"), slow_memory_1_var: T.handle("uint8"), output: T.handle) -> None: fast_memory_0_buffer_var = T.match_buffer(fast_memory_0_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16) slow_memory_1_buffer_var = T.match_buffer(slow_memory_1_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) # body T.attr("default", "device_id", 0) T.attr("default", "device_type", 1) - sid_9_let: T.Ptr("int8") = T.address_of(slow_memory_1_buffer_var[1117472], dtype="handle") - sid_8_let: T.Ptr("int8") = T.address_of(slow_memory_1_buffer_var[0], dtype="handle") + sid_9_let: T.handle("int8") = T.address_of(slow_memory_1_buffer_var[1117472], dtype="handle") + sid_8_let: T.handle("int8") = T.address_of(slow_memory_1_buffer_var[0], dtype="handle") T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9_let, fast_memory_0_buffer_var.data, slow_memory_1_buffer_var.data, dtype="int32")) T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9_let, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8_let, fast_memory_0_buffer_var.data, slow_memory_1_buffer_var.data, dtype="int32")) T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8_let, output, fast_memory_0_buffer_var.data, slow_memory_1_buffer_var.data, dtype="int32")) @T.prim_func - def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle, fast_memory_6_var: T.Ptr("uint8"), slow_memory_7_var: T.Ptr("uint8")) -> None: + def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle, fast_memory_6_var: T.handle("uint8"), slow_memory_7_var: T.handle("uint8")) -> None: placeholder_29 = T.match_buffer(placeholder_28, [802816], dtype="uint8") T_cast_7 = T.match_buffer(T_cast_6, [177], dtype="int16") fast_memory_6_buffer_var = T.match_buffer(fast_memory_6_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16) @@ -174,7 +174,7 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T_cast_7[ax0_ax1_fused_5 * 3584 + ax2_5 * 64 + ax3_3] = T.cast(tensor_2_let[ax0_ax1_fused_5 * 3584 + ax2_5 * 64 + ax3_3], "int16") @T.prim_func - def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle, fast_memory_2_var: T.Ptr("uint8"), slow_memory_3_var: T.Ptr("uint8")) -> None: + def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle, fast_memory_2_var: T.handle("uint8"), slow_memory_3_var: T.handle("uint8")) -> None: placeholder_4 = T.match_buffer(placeholder_2, [150528], dtype="uint8") placeholder_5 = T.match_buffer(placeholder_3, [1], dtype="int16") T_subtract_1 = T.match_buffer(T_subtract, [452], dtype="int16") @@ -185,7 +185,7 @@ def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T T_subtract_1[ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1] = T.cast(placeholder_4[ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1], "int16") - placeholder_5[0] @T.prim_func - def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle, fast_memory_4_var: T.Ptr("uint8"), slow_memory_5_var: T.Ptr("uint8")) -> None: + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle, fast_memory_4_var: T.handle("uint8"), slow_memory_5_var: T.handle("uint8")) -> None: placeholder_65 = T.match_buffer(placeholder_62, [150528], dtype="int16") placeholder_66 = T.match_buffer(placeholder_63, [9408], dtype="int16") placeholder_67 = T.match_buffer(placeholder_64, [64], dtype="int32") @@ -380,7 +380,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place @tvm.script.ir_module class ResnetStructurePlanned: @T.prim_func - def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(placeholder: T.handle, placeholder_1: T.handle, T_cast: T.handle, global_workspace_1_var: T.Ptr("uint8")) -> None: + def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(placeholder: T.handle, placeholder_1: T.handle, T_cast: T.handle, global_workspace_1_var: T.handle("uint8")) -> None: placeholder_2 = T.match_buffer(placeholder, [360000], dtype="uint8") placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32") T_cast_1 = T.match_buffer(T_cast, [215], dtype="int16") @@ -390,7 +390,7 @@ def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(p T_cast_1[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(placeholder_2[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner], "int32") - 94, 1843157232, 31, 1, dtype="int32") + placeholder_3[ax3_outer * 16 + ax3_inner], 255), 0), "uint8"), "int16") @T.prim_func - def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, placeholder_25: T.handle, T_cast_6: T.handle, global_workspace_5_var: T.Ptr("uint8")) -> None: + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, placeholder_25: T.handle, T_cast_6: T.handle, global_workspace_5_var: T.handle("uint8")) -> None: placeholder_29 = T.match_buffer(placeholder_22, [360000], dtype="int16") placeholder_27 = T.match_buffer(placeholder_23, [16384], dtype="int16") placeholder_26 = T.match_buffer(placeholder_24, [256], dtype="int32") @@ -414,7 +414,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s T_cast_7[ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4] = T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_3_let[ax3_inner_4] + placeholder_26[ax3_outer_2 * 64 + ax3_inner_4], 1343014664, 31, -8, dtype="int32") + 136, 255), 0), "uint8"), "int32") - 136, 1073903788, 31, 1, dtype="int32") + placeholder_28[ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4], 255), 0), "uint8") @T.prim_func - def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_(placeholder_16: T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_add: T.handle, global_workspace_4_var: T.Ptr("uint8")) -> None: + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_(placeholder_16: T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_add: T.handle, global_workspace_4_var: T.handle("uint8")) -> None: placeholder_19 = T.match_buffer(placeholder_16, [360000], dtype="int16") placeholder_20 = T.match_buffer(placeholder_17, [16384], dtype="int16") placeholder_21 = T.match_buffer(placeholder_18, [256], dtype="int32") @@ -437,7 +437,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s T_add_1[ax0_ax1_fused_ax2_fused_2 * 256 + ax3_outer_1 * 64 + ax3_inner_3] = T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_2_let[ax3_inner_3] + placeholder_21[ax3_outer_1 * 64 + ax3_inner_3], 1711626602, 31, -8, dtype="int32") + 132, 255), 0), "uint8"), "int32") - 132, 2094289803, 31, -2, dtype="int32") + 136 @T.prim_func - def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle, global_workspace_2_var: T.Ptr("uint8")) -> None: + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle, global_workspace_2_var: T.handle("uint8")) -> None: placeholder_7 = T.match_buffer(placeholder_4, [360000], dtype="int16") placeholder_8 = T.match_buffer(placeholder_5, [4096], dtype="int16") placeholder_9 = T.match_buffer(placeholder_6, [64], dtype="int32") @@ -459,7 +459,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place T_cast_3[ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_let[ax3_inner_1] + placeholder_9[ax3_inner_1], 1843106743, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16") @T.prim_func - def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_10: T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4: T.handle, global_workspace_3_var: T.Ptr("uint8")) -> None: + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_10: T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4: T.handle, global_workspace_3_var: T.handle("uint8")) -> None: placeholder_13 = T.match_buffer(placeholder_10, [360000], dtype="int16") placeholder_14 = T.match_buffer(placeholder_11, [36864], dtype="int16") placeholder_15 = T.match_buffer(placeholder_12, [64], dtype="int32") @@ -481,15 +481,15 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(pla T_cast_5[ax0_ax1_fused_ax2_fused_1 * 64 + ax3_inner_2] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_1_let[ax3_inner_2] + placeholder_15[ax3_inner_2], 1608879842, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16") @T.prim_func - def __tvm_main__(input: T.handle, global_workspace_0_var: T.Ptr("uint8"), output: T.handle) -> None: + def __tvm_main__(input: T.handle, global_workspace_0_var: T.handle("uint8"), output: T.handle) -> None: global_workspace_0_buffer_var = T.match_buffer(global_workspace_0_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) # body T.attr("default", "device_id", 0) T.attr("default", "device_type", 1) - sid_2_let: T.Ptr("int8") = T.address_of(global_workspace_0_buffer_var[5760000], dtype="handle") - sid_6_let: T.Ptr("int8") = T.address_of(global_workspace_0_buffer_var[0], dtype="handle") - sid_7_let: T.Ptr("int8") = T.address_of(global_workspace_0_buffer_var[6480000], dtype="handle") - sid_8_let: T.Ptr("int8") = T.address_of(global_workspace_0_buffer_var[6480000], dtype="handle") + sid_2_let: T.handle("int8") = T.address_of(global_workspace_0_buffer_var[5760000], dtype="handle") + sid_6_let: T.handle("int8") = T.address_of(global_workspace_0_buffer_var[0], dtype="handle") + sid_7_let: T.handle("int8") = T.address_of(global_workspace_0_buffer_var[6480000], dtype="handle") + sid_8_let: T.handle("int8") = T.address_of(global_workspace_0_buffer_var[6480000], dtype="handle") T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", input, T.lookup_param("p0", dtype="handle"), sid_2_let, global_workspace_0_buffer_var.data, dtype="int32")) T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", sid_2_let, T.lookup_param("p3", dtype="handle"), T.lookup_param("p4", dtype="handle"), sid_8_let, global_workspace_0_buffer_var.data, dtype="int32")) T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", sid_8_let, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), sid_7_let, global_workspace_0_buffer_var.data, dtype="int32")) @@ -557,7 +557,7 @@ def __tvm_main__(input: T.handle, output: T.handle) -> None: @tvm.script.ir_module class TensorIntrinStructurePlanned: @T.prim_func - def tensor_intrin_primfunc(global_workspace_1_var: T.Ptr("uint8")) -> None: + def tensor_intrin_primfunc(global_workspace_1_var: T.handle("uint8")) -> None: global_workspace_1_buffer_var = T.match_buffer( global_workspace_1_var, [40], dtype="uint8", strides=[1], elem_offset=0, align=16 ) @@ -576,7 +576,7 @@ def tensor_intrin_primfunc(global_workspace_1_var: T.Ptr("uint8")) -> None: @T.prim_func def __tvm_main__( - input: T.handle, global_workspace_1_var: T.Ptr("uint8"), output: T.handle + input: T.handle, global_workspace_1_var: T.handle("uint8"), output: T.handle ) -> None: global_workspace_1_buffer_var = T.match_buffer( global_workspace_1_var, [40], dtype="uint8", strides=[1], elem_offset=0, align=16 diff --git a/tests/python/unittest/test_tvmscript_parser_tir.py b/tests/python/unittest/test_tvmscript_parser_tir.py index e96ae4da8c2e..20be6d149808 100644 --- a/tests/python/unittest/test_tvmscript_parser_tir.py +++ b/tests/python/unittest/test_tvmscript_parser_tir.py @@ -40,7 +40,7 @@ def test_tir_buffer_proxy(): def test_tir_ptr_proxy(): - ptr_0 = T.Ptr("int32", "global") + ptr_0 = T.handle("int32", "global") assert ( isinstance(ptr_0, tir.Var) and ptr_0.dtype == "handle" @@ -49,7 +49,7 @@ def test_tir_ptr_proxy(): and ptr_0.type_annotation.storage_scope == "global" ) - ptr_1 = T.Ptr("float32", "shared") + ptr_1 = T.handle("float32", "shared") assert ( isinstance(ptr_1, tir.Var) and ptr_1.dtype == "handle" diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py index 6f96b3a3dd31..a04544152ec9 100644 --- a/tests/python/unittest/test_tvmscript_printer_tir.py +++ b/tests/python/unittest/test_tvmscript_printer_tir.py @@ -674,7 +674,7 @@ def test_prim_type(): def test_pointer_type(): obj = ir.PointerType(ir.PrimType("int32"), "global") - _assert_print(obj, 'T.Ptr("int32", "global")') + _assert_print(obj, 'T.handle("int32", "global")') def test_tuple_type(): diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 1ec8f49b4bad..db2122336642 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -204,30 +204,30 @@ def mmult( arg2: T.handle = T.tvm_struct_get(args, 2, 12, dtype="handle") arg2_code: T.int32 = buf_type_ids[2] - A_data: T.Ptr("int32") = T.tvm_struct_get(arg0, 0, 1, dtype="handle") + A_data: T.handle("int32") = T.tvm_struct_get(arg0, 0, 1, dtype="handle") T.attr(A_data, "storage_alignment", 128) A = T.Buffer([1024 * 1024], dtype="int32", data=A_data) - buf0_shape_data: T.Ptr("int32") = T.tvm_struct_get(arg0, 0, 2, dtype="handle") + buf0_shape_data: T.handle("int32") = T.tvm_struct_get(arg0, 0, 2, dtype="handle") buf0_shape = T.Buffer([2], dtype="int32", data=buf0_shape_data) - buf0_strides_data: T.Ptr("int32") = T.tvm_struct_get(arg0, 0, 3, dtype="handle") + buf0_strides_data: T.handle("int32") = T.tvm_struct_get(arg0, 0, 3, dtype="handle") buf0_strides = T.Buffer([2], dtype="int32", data=buf0_strides_data) dev_id: T.int32 = T.tvm_struct_get(arg0, 0, 9, dtype="int32") - B_data: T.Ptr("int32") = T.tvm_struct_get(arg1, 0, 1, dtype="handle") + B_data: T.handle("int32") = T.tvm_struct_get(arg1, 0, 1, dtype="handle") T.attr(B_data, "storage_alignment", 128) B = T.Buffer([1024 * 1024], dtype="int32", data=B_data) - buf1_shape_data: T.Ptr("int32") = T.tvm_struct_get(arg1, 0, 2, dtype="handle") + buf1_shape_data: T.handle("int32") = T.tvm_struct_get(arg1, 0, 2, dtype="handle") buf1_shape = T.Buffer([2], dtype="int32", data=buf1_shape_data) - buf1_strides_data: T.Ptr("int32") = T.tvm_struct_get(arg1, 0, 3, dtype="handle") + buf1_strides_data: T.handle("int32") = T.tvm_struct_get(arg1, 0, 3, dtype="handle") buf1_strides = T.Buffer([2], dtype="int32", data=buf1_strides_data) - C_data: T.Ptr("int32") = T.tvm_struct_get(arg2, 0, 1, dtype="handle") + C_data: T.handle("int32") = T.tvm_struct_get(arg2, 0, 1, dtype="handle") T.attr(C_data, "storage_alignment", 128) C = T.Buffer([1024 * 1024], dtype="int32", data=C_data) - buf2_shape_data: T.Ptr("int32") = T.tvm_struct_get(arg2, 0, 2, dtype="handle") + buf2_shape_data: T.handle("int32") = T.tvm_struct_get(arg2, 0, 2, dtype="handle") buf2_shape = T.Buffer([2], dtype="int32", data=buf2_shape_data) - buf2_strides_data: T.Ptr("int32") = T.tvm_struct_get(arg2, 0, 3, dtype="handle") + buf2_strides_data: T.handle("int32") = T.tvm_struct_get(arg2, 0, 3, dtype="handle") buf2_strides = T.Buffer([2], dtype="int32", data=buf2_strides_data) assert (((arg0_code == 3) or (arg0_code == 13)) or (arg0_code == 7)) or ( @@ -2238,7 +2238,7 @@ def opt_conv_tensorcore_mod_host( } ) # body - stack_tcode_data: T.Ptr("int32") = T.tvm_stack_alloca("arg_tcode", 10, dtype="handle") + stack_tcode_data: T.handle("int32") = T.tvm_stack_alloca("arg_tcode", 10, dtype="handle") stack_tcode = T.Buffer([9], "int32", data=stack_tcode_data) stack_value: T.handle = T.tvm_stack_alloca("arg_value", 10, dtype="handle") assert num_args == 3, "default_function: num_args should be 3" @@ -2251,25 +2251,25 @@ def opt_conv_tensorcore_mod_host( A: T.handle = T.tvm_struct_get(arg0, 0, 1, dtype="handle") T.attr(A, "storage_alignment", 128) - arg0_shape_data: T.Ptr("int64") = T.tvm_struct_get(arg0, 0, 2, dtype="handle") + arg0_shape_data: T.handle("int64") = T.tvm_struct_get(arg0, 0, 2, dtype="handle") arg0_shape = T.Buffer([6], "int64", data=arg0_shape_data) - arg0_strides_data: T.Ptr("int64") = T.tvm_struct_get(arg0, 0, 3, dtype="handle") + arg0_strides_data: T.handle("int64") = T.tvm_struct_get(arg0, 0, 3, dtype="handle") arg0_strides = T.Buffer([6], "int64", data=arg0_strides_data) dev_id: T.int32 = T.tvm_struct_get(arg0, 0, 9, dtype="int32") W: T.handle = T.tvm_struct_get(arg1, 0, 1, dtype="handle") T.attr(W, "storage_alignment", 128) - arg1_shape_data: T.Ptr("int64") = T.tvm_struct_get(arg1, 0, 2, dtype="handle") + arg1_shape_data: T.handle("int64") = T.tvm_struct_get(arg1, 0, 2, dtype="handle") arg1_shape = T.Buffer([6], "int64", data=arg1_shape_data) - arg1_strides_data: T.Ptr("int64") = T.tvm_struct_get(arg1, 0, 3, dtype="handle") + arg1_strides_data: T.handle("int64") = T.tvm_struct_get(arg1, 0, 3, dtype="handle") arg1_strides = T.Buffer([6], "int64", data=arg1_strides_data) Conv: T.handle = T.tvm_struct_get(arg2, 0, 1, dtype="handle") T.attr(Conv, "storage_alignment", 128) - arg2_shape_data: T.Ptr("int64") = T.tvm_struct_get(arg2, 0, 2, dtype="handle") + arg2_shape_data: T.handle("int64") = T.tvm_struct_get(arg2, 0, 2, dtype="handle") arg2_shape = T.Buffer([6], "int64", data=arg2_shape_data) - arg2_strides_data: T.Ptr("int64") = T.tvm_struct_get(arg2, 0, 3, dtype="handle") + arg2_strides_data: T.handle("int64") = T.tvm_struct_get(arg2, 0, 3, dtype="handle") arg2_strides = T.Buffer([6], "int64", data=arg2_strides_data) assert (((arg0_code == 3) or (arg0_code == 13)) or (arg0_code == 7)) or ( @@ -3145,7 +3145,7 @@ def func(A: T.Buffer(1, "int32")): def func_T_ptr_let_statement(): @T.prim_func def func_T_ptr_let_statement( - args: T.handle, arg_type_ids_handle: T.Ptr("int32"), num_args: T.int32 + args: T.handle, arg_type_ids_handle: T.handle("int32"), num_args: T.int32 ) -> None: # The T.Ptr declaration in the parameter list should parse # correctly, and should be usable as the data pointer in a buffer. @@ -3157,14 +3157,14 @@ def func_T_ptr_let_statement( # Functions that return a "handle" can be assigned to a T.Ptr # variable. A variable annotated with T.Ptr still has dtype of # T.handle, but has type annotation as a pointer type. - A_data: T.Ptr("float32") = T.tvm_struct_get(arg0, 0, 1, dtype="handle") + A_data: T.handle("float32") = T.tvm_struct_get(arg0, 0, 1, dtype="handle") # The buffer declaration has a data pointer defined earlier in # this function. It should only be defined after the data pointer # has been defined, and should not be hoisted into the header of # the function as other buffer_decl statements can be. A = T.Buffer([1024], dtype="float32", data=A_data) - B_data: T.Ptr("float32") = T.tvm_struct_get(arg1, 0, 1, dtype="handle") + B_data: T.handle("float32") = T.tvm_struct_get(arg1, 0, 1, dtype="handle") B = T.Buffer([1024], dtype="float32", data=B_data) B[0] = A[0] @@ -3266,13 +3266,13 @@ def string_annotation_of_special_chars(): def pointer_type(): @T.prim_func - def func_with_ptr_type_annotations(x: T.Ptr("int32"), y: T.Ptr("int32", "shared")): + def func_with_ptr_type_annotations(x: T.handle("int32"), y: T.handle("int32", "shared")): xx_data = T.allocate([16], "int32", "global") xx = T.Buffer(shape=[16], dtype="int32", scope="global", data=xx_data) yy_data = T.allocate([16], "int32", "shared") yy = T.Buffer(shape=[16], dtype="int32", scope="shared", data=yy_data) - a: T.Ptr("int32") = T.address_of(xx[0], dtype="handle") - b: T.Ptr("int32", "shared") = T.address_of(yy[0], dtype="handle") + a: T.handle("int32") = T.address_of(xx[0], dtype="handle") + b: T.handle("int32", "shared") = T.address_of(yy[0], dtype="handle") T.evaluate(T.call_extern("copy", a, b, dtype="")) return func_with_ptr_type_annotations @@ -3324,7 +3324,7 @@ def func(): def void_ptr(): @T.prim_func - def func(out_ret_value: T.Ptr("void")): + def func(out_ret_value: T.handle("void")): T.evaluate(out_ret_value) return func