From 2f5d03603e149784d7fff476e6e0704c85f17fc1 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 9 Jul 2024 15:52:48 -0700 Subject: [PATCH 01/12] [Relax] Implement R.ensure_aligned and update memory planning for R.view --- python/tvm/relax/op/memory/__init__.py | 2 +- python/tvm/relax/op/memory/view.py | 17 +++ src/relax/backend/vm/vm_builtin_lower.cc | 20 ++++ src/relax/op/memory/view.cc | 34 +++++- src/relax/op/memory/view.h | 3 + .../transform/static_plan_block_memory.cc | 13 ++- src/runtime/relax_vm/builtin.cc | 13 +++ tests/python/relax/test_op_view.py | 105 +++++++----------- ...test_transform_static_plan_block_memory.py | 55 +++++++++ 9 files changed, 185 insertions(+), 77 deletions(-) diff --git a/python/tvm/relax/op/memory/__init__.py b/python/tvm/relax/op/memory/__init__.py index 422c5d2e1f53..2ae1b676e035 100644 --- a/python/tvm/relax/op/memory/__init__.py +++ b/python/tvm/relax/op/memory/__init__.py @@ -17,4 +17,4 @@ """Relax memory primitives.""" from .memory import alloc_storage, alloc_tensor, kill_storage, kill_tensor -from .view import view +from .view import view, ensure_aligned diff --git a/python/tvm/relax/op/memory/view.py b/python/tvm/relax/op/memory/view.py index 0c3d8a03b2dd..233d07f6c9b7 100644 --- a/python/tvm/relax/op/memory/view.py +++ b/python/tvm/relax/op/memory/view.py @@ -92,3 +92,20 @@ def _normalize(expr, relax_cls): relative_byte_offset = _normalize(relative_byte_offset, PrimValue) return _ffi_api.view(data, shape, dtype, relative_byte_offset) # type: ignore + + +def ensure_aligned(data: Expr) -> Expr: + """ + Ensure the tensor has elem_offset == 0. A copy will be made if necessary. + + Parameters + ---------- + data : relax.Expr + The input tensor + + Results + ------- + result : relax.Expr + The aligned tensor + """ + return _ffi_api.ensure_aligned(data) # type: ignore diff --git a/src/relax/backend/vm/vm_builtin_lower.cc b/src/relax/backend/vm/vm_builtin_lower.cc index 887998d004c7..961aa9b600f8 100644 --- a/src/relax/backend/vm/vm_builtin_lower.cc +++ b/src/relax/backend/vm/vm_builtin_lower.cc @@ -47,6 +47,10 @@ class VMBuiltinLowerMutator : public ExprMutator { return Reshape(call); } else if (call->op == shape_of_op_) { return ShapeOf(call); + } else if (call->op == view_op_) { + return View(call); + } else if (call->op == ensure_aligned_op_) { + return EnsureAligned(call); } else if (call->op == to_vdevice_op_) { return ToDevice(call); } else if (call->op == make_closure_op_) { @@ -124,6 +128,19 @@ class VMBuiltinLowerMutator : public ExprMutator { } } + Expr View(const Call& view_node) { + StructInfoDeriveFunc infer_sinfo_env_func; + infer_sinfo_env_func = EnvFunc::Get("tvm.relax.struct_info.infer_view_sinfo"); + auto runtime_view_sinfo = FuncStructInfo::OpaqueFunc(infer_sinfo_env_func, true); + ExternFunc runtime_view_func("runtime.TVMArrayCreateView", runtime_view_sinfo); + return Call(runtime_view_func, view_node->args, view_node->attrs, {runtime_view_sinfo}); + } + + Expr EnsureAligned(const Call& call_node) { + ICHECK(call_node->args.size() == 1); + return Call(builtin_ensure_aligned_, call_node->args, Attrs(), {GetStructInfo(call_node)}); + } + Expr ShapeOf(const Call& call_node) { ICHECK(call_node->args.size() == 1); ICHECK(call_node->struct_info_.defined()); @@ -188,6 +205,8 @@ class VMBuiltinLowerMutator : public ExprMutator { const Op& call_tir_dyn_op_ = Op::Get("relax.vm.call_tir_dyn"); const Op& reshape_op_ = Op::Get("relax.reshape"); const Op& shape_of_op_ = Op::Get("relax.shape_of"); + const Op& view_op_ = 
Op::Get("relax.memory.view"); + const Op& ensure_aligned_op_ = Op::Get("relax.memory.ensure_aligned"); const Op& to_vdevice_op_ = Op::Get("relax.to_vdevice"); const Op& make_closure_op_ = Op::Get("relax.make_closure"); const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure"); @@ -208,6 +227,7 @@ class VMBuiltinLowerMutator : public ExprMutator { const ExternFunc builtin_to_device_{"vm.builtin.to_device"}; const ExternFunc builtin_make_closure_{"vm.builtin.make_closure"}; const ExternFunc builtin_invoke_closure_{"vm.builtin.invoke_closure"}; + const ExternFunc builtin_ensure_aligned_{"vm.builtin.ensure_aligned"}; }; Expr VMBuiltinLower(const Expr& e) { return VMBuiltinLowerMutator().VisitExpr(e); } diff --git a/src/relax/op/memory/view.cc b/src/relax/op/memory/view.cc index e7634c7edfce..d43cc01838ae 100644 --- a/src/relax/op/memory/view.cc +++ b/src/relax/op/memory/view.cc @@ -334,13 +334,12 @@ Expr LegalizeView(const BlockBuilder& bb, const Call& call) { relative_byte_offset = relax::PrimValue::Int64(0); } - StructInfoDeriveFunc infer_sinfo_env_func; - infer_sinfo_env_func = EnvFunc::Get("tvm.relax.struct_info.infer_view_sinfo"); - auto runtime_view_sinfo = FuncStructInfo::OpaqueFunc(infer_sinfo_env_func, true); - - ExternFunc runtime_view_func("runtime.TVMArrayCreateView", runtime_view_sinfo); + if (shape.same_as(call->args[1]) && dtype.same_as(call->args[2]) && + relative_byte_offset.same_as(call->args[3])) { + return call; + } - return Call(runtime_view_func, {data, shape, dtype, relative_byte_offset}); + return Call(call->op, {data, shape, dtype, relative_byte_offset}); } TVM_REGISTER_OP("relax.memory.view") @@ -355,5 +354,28 @@ TVM_REGISTER_OP("relax.memory.view") .set_attr("FLegalize", LegalizeView) .set_attr("FPurity", Bool(true)); +Expr ensure_aligned(const Expr& x) { + static const Op& op = Op::Get("relax.memory.ensure_aligned"); + return Call(op, {x}); +} + +TVM_REGISTER_GLOBAL("relax.op.memory.ensure_aligned").set_body_typed(ensure_aligned); + +StructInfo InferStructInfoEnsureAligned(const Call& call, const BlockBuilder& ctx) { + if (call->args.size() != 1) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Operator " << call->op << " should receive 1 argument, " + << "but received " << call->args); + } + return GetStructInfo(call->args[0]); +} + +TVM_REGISTER_OP("relax.memory.ensure_aligned") + .set_num_inputs(1) + .add_argument("x", "Tensor", "The input tensor.") + .set_attr("RequiresArgumentShapes", Bool(false)) + .set_attr("FInferStructInfo", InferStructInfoEnsureAligned) + .set_attr("FPurity", Bool(true)); + } // namespace relax } // namespace tvm diff --git a/src/relax/op/memory/view.h b/src/relax/op/memory/view.h index bc8002fa5b69..77ec7e9833cc 100644 --- a/src/relax/op/memory/view.h +++ b/src/relax/op/memory/view.h @@ -32,6 +32,9 @@ namespace relax { /*! \brief View a tensor with different properties. */ Expr view(Expr x, Optional shape, Optional dtype, Optional relative_byte_offset); +/*! \brief Ensure the tensor has elem_offset == 0. A copy will be made if necessary. */ +Expr ensure_aligned(const Expr& x); + } // namespace relax } // namespace tvm diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc index 2b16d8650906..2922de6dcc7e 100644 --- a/src/relax/transform/static_plan_block_memory.cc +++ b/src/relax/transform/static_plan_block_memory.cc @@ -286,8 +286,13 @@ class TokenAllocator1D { std::vector full_pool_; }; -/*! \brief Check if the input op is "relax.reshape". 
*/
-bool IsReshape(const Expr& op) { return op.same_as(Op::Get("relax.reshape")); }
+/*! \brief Check if the input op is a memory op that returns the same buffer as the input buffer. */
+bool IsInplaceMemoryOp(const Expr& op) {
+  static const Op& reshape_op = Op::Get("relax.reshape");
+  static const Op& view_op = Op::Get("relax.memory.view");
+  static const Op& ensure_aligned_op = Op::Get("relax.memory.ensure_aligned");
+  return op.same_as(reshape_op) || op.same_as(view_op) || op.same_as(ensure_aligned_op);
+}
 
 /*! \brief The base class for the storage allocation visitor. */
 class StorageAllocatorBaseVisitor : public ExprVisitor {
@@ -498,7 +503,7 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
       // Create a storage token for builtin alloc_tensor.
       this->CreateToken(call);
       return;
-    } else if (IsReshape(call->op)) {
+    } else if (IsInplaceMemoryOp(call->op)) {
       // Reuse the input's token for builtin reshape.
       SetTokens(call, GetTokens(call->args[0]));
       return;
@@ -751,7 +756,7 @@ class StorageAllocator : public StorageAllocatorBaseVisitor {
       block_tokens.push_back(new_token.get());
     }
     return;
-  } else if (IsReshape(call->op)) {
+  } else if (IsInplaceMemoryOp(call->op)) {
    Tokens tokens = GetTokens(call->args[0]);
    ICHECK(!tokens.IsNested());
    if (tokens.IsLeaf()) {
diff --git a/src/runtime/relax_vm/builtin.cc b/src/runtime/relax_vm/builtin.cc
index 2af31f1d4021..83b016446548 100644
--- a/src/runtime/relax_vm/builtin.cc
+++ b/src/runtime/relax_vm/builtin.cc
@@ -545,6 +545,19 @@ TVM_REGISTER_GLOBAL("vm.builtin.tensor_to_shape").set_body_typed([](NDArray data
   return ShapeTuple(out_shape);
 });
 
+TVM_REGISTER_GLOBAL("vm.builtin.ensure_aligned").set_body_typed([](NDArray data) {
+  if (data->byte_offset == 0) {
+    return data;
+  }
+  DLManagedTensor* dl_tensor = data.ToDLPack();
+  dl_tensor->dl_tensor.data =
+      reinterpret_cast<char*>(dl_tensor->dl_tensor.data) + dl_tensor->dl_tensor.byte_offset;
+  dl_tensor->dl_tensor.byte_offset = 0;
+  // For platforms that do not support pointer arithmetic, we need to copy the data to a new
+  // buffer.
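+  // NOTE: NDArray::FromDLPack takes ownership of dl_tensor, so the returned
+  // array aliases the original buffer; only the pointer/offset bookkeeping
+  // changes, no data is copied here.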
+ return NDArray::FromDLPack(dl_tensor); +}); + } // namespace relax_vm } // namespace runtime } // namespace tvm diff --git a/tests/python/relax/test_op_view.py b/tests/python/relax/test_op_view.py index 2433821c2abd..1e21612f9fff 100644 --- a/tests/python/relax/test_op_view.py +++ b/tests/python/relax/test_op_view.py @@ -483,18 +483,7 @@ def main(A: R.Tensor([4096], "float32")): class Expected: @R.function def main(A: R.Tensor([4096], "float32")): - B = R.ExternFunc( - "runtime.TVMArrayCreateView", - R.Callable( - derive_func="tvm.relax.struct_info.infer_view_sinfo", - purity=True, - ), - )( - A, - R.shape([64, 64]), - R.dtype("float32"), - R.prim_value(0), - ) + B = R.memory.view(A, shape=R.shape([64, 64]), dtype="float32", relative_byte_offset=0) return B After = tvm.relax.transform.LegalizeOps()(Before) @@ -515,18 +504,7 @@ def main(A: R.Tensor(dtype="float32")): class Expected: @R.function def main(A: R.Tensor(dtype="float32")): - B = R.ExternFunc( - "runtime.TVMArrayCreateView", - R.Callable( - derive_func="tvm.relax.struct_info.infer_view_sinfo", - purity=True, - ), - )( - A, - R.shape([64, 64]), - R.dtype("float32"), - R.prim_value(0), - ) + B = R.memory.view(A, shape=R.shape([64, 64]), dtype="float32", relative_byte_offset=0) return B After = tvm.relax.transform.LegalizeOps()(Before) @@ -545,17 +523,8 @@ def main(A: R.Tensor([4096], "float32")): class Expected: @R.function def main(A: R.Tensor([4096], "float32")): - B = R.ExternFunc( - "runtime.TVMArrayCreateView", - R.Callable( - derive_func="tvm.relax.struct_info.infer_view_sinfo", - purity=True, - ), - )( - A, - R.shape([4096]), - R.dtype("int32"), - R.prim_value(0), + B = R.memory.view( + A, dtype=R.dtype("int32"), shape=R.shape([4096]), relative_byte_offset=0 ) return B @@ -575,17 +544,8 @@ def main(A: R.Tensor([4096], "float32")): class Expected: @R.function def main(A: R.Tensor([4096], "float32")): - B = R.ExternFunc( - "runtime.TVMArrayCreateView", - R.Callable( - derive_func="tvm.relax.struct_info.infer_view_sinfo", - purity=True, - ), - )( - A, - R.shape([4096]), - R.dtype("float32"), - R.prim_value(0), + B = R.memory.view( + A, relative_byte_offset=R.prim_value(0), shape=R.shape([4096]), dtype="float32" ) return B @@ -624,29 +584,17 @@ def main(A: R.Tensor([4096], "uint8")): class Expected: @R.function def main(A: R.Tensor([4096], "uint8")): - B = R.ExternFunc( - "runtime.TVMArrayCreateView", - R.Callable( - derive_func="tvm.relax.struct_info.infer_view_sinfo", - purity=True, - ), - )( + B = R.memory.view( A, - R.shape([512]), - R.dtype("int32"), - R.prim_value(0), + shape=R.shape([512]), + dtype=R.dtype("int32"), + relative_byte_offset=R.prim_value(0), ) - C = R.ExternFunc( - "runtime.TVMArrayCreateView", - R.Callable( - derive_func="tvm.relax.struct_info.infer_view_sinfo", - purity=True, - ), - )( + C = R.memory.view( A, - R.shape([16, 64]), - R.dtype("float16"), - R.prim_value(2048), + shape=R.shape([16, 64]), + dtype=R.dtype("float16"), + relative_byte_offset=R.prim_value(2048), ) return (B, C) @@ -772,5 +720,30 @@ def main(A: R.Tensor([4096], "uint8")): tvm.testing.assert_allclose(tvm_output[1].numpy(), np_expected[1]) +@tvm.testing.parametrize_targets("llvm", "cuda") +def test_execute_view_with_new_byte_offset_ensure_aligned(target, dev): + @I.ir_module + class Module: + @R.function + def main(A: R.Tensor([4096], "float32")): + B = R.memory.view( + A, + shape=R.shape([16, 64]), + relative_byte_offset=32 * 64 * 4, + ) + C = R.memory.ensure_aligned(B) + return C + + built = tvm.relax.build(Module, target=target) + 
vm = tvm.relax.VirtualMachine(built, device=dev) + + np_input = np.random.random([4096]).astype("float32") + tvm_input = tvm.nd.array(np_input, dev) + tvm_output = vm["main"](tvm_input) + np_expected = np_input.reshape(64, 64)[32:48, :] + + tvm.testing.assert_allclose(tvm_output.numpy(), np_expected) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py b/tests/python/relax/test_transform_static_plan_block_memory.py index 63f422d4cfbe..3ab468844b01 100644 --- a/tests/python/relax/test_transform_static_plan_block_memory.py +++ b/tests/python/relax/test_transform_static_plan_block_memory.py @@ -1449,5 +1449,60 @@ def main( tvm.ir.assert_structural_equal(mod, Expected) +def test_view(): + @I.ir_module + class Before: + @T.prim_func + def tir_exp(var_rxplaceholder: T.handle, var_compute: T.handle): + T.evaluate(0) + + @R.function + def main(): + cls = Before + x = R.builtin.alloc_tensor(R.shape([16, 16]), dtype="float32", runtime_device_index=0) + x1 = R.memory.view(x, [128], "float32", 0) + x2 = R.memory.ensure_aligned(x1) + y = R.builtin.alloc_tensor(R.shape([128]), dtype="float32", runtime_device_index=0) + cls.tir_exp(x2, y) + z = R.builtin.alloc_tensor(R.shape([128]), dtype="float32", runtime_device_index=0) + cls.tir_exp(y, z) + return z + + @I.ir_module + class Expected: + @T.prim_func + def tir_exp(var_rxplaceholder: T.handle, var_compute: T.handle): + T.evaluate(0) + + @R.function + def main() -> R.Tensor((128,), dtype="float32"): + cls = Expected + storage: R.Object = R.memory.alloc_storage( + R.shape([1024]), R.prim_value(0), R.str("global"), R.dtype("float32") + ) + x: R.Tensor((16, 16), dtype="float32") = R.memory.alloc_tensor( + storage, R.prim_value(0), R.shape([16, 16]), R.dtype("float32") + ) + x1: R.Tensor((128,), dtype="float32") = R.memory.view( + x, R.shape([128]), R.dtype("float32"), R.prim_value(0) + ) + x2: R.Tensor((128,), dtype="float32") = R.memory.ensure_aligned(x1) + storage1: R.Object = R.memory.alloc_storage( + R.shape([512]), R.prim_value(0), R.str("global"), R.dtype("float32") + ) + y: R.Tensor((128,), dtype="float32") = R.memory.alloc_tensor( + storage1, R.prim_value(0), R.shape([128]), R.dtype("float32") + ) + cls.tir_exp(x2, y) + z: R.Tensor((128,), dtype="float32") = R.builtin.alloc_tensor( + R.shape([128]), R.dtype("float32"), R.prim_value(0), R.str("global") + ) + cls.tir_exp(y, z) + return z + + after = relax.transform.StaticPlanBlockMemory()(Before) + tvm.ir.assert_structural_equal(after, Expected) + + if __name__ == "__main__": tvm.testing.main() From 59fada7cc1d8b836860b121e66f8b55a30b63d56 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 15 Jul 2024 15:29:21 -0700 Subject: [PATCH 02/12] Rename to ensure_zero_offset and LowerRuntimeBuiltin --- include/tvm/relax/backend.h | 2 +- include/tvm/relax/op_attr_types.h | 9 +++ include/tvm/runtime/device_api.h | 4 + python/tvm/relax/op/memory/__init__.py | 2 +- python/tvm/relax/op/memory/view.py | 6 +- python/tvm/relax/pipeline.py | 2 +- python/tvm/relax/transform/__init__.py | 25 +++--- python/tvm/relax/transform/transform.py | 4 +- ...ltin_lower.cc => lower_runtime_builtin.cc} | 45 ++++------- src/relax/op/memory/view.cc | 79 +++++++++++-------- .../transform/static_plan_block_memory.cc | 6 +- src/runtime/relax_vm/builtin.cc | 20 +++-- tests/python/relax/test_op_view.py | 2 +- ...test_transform_static_plan_block_memory.py | 4 +- 14 files changed, 111 insertions(+), 99 deletions(-) rename src/relax/backend/vm/{vm_builtin_lower.cc => 
lower_runtime_builtin.cc} (86%)

diff --git a/include/tvm/relax/backend.h b/include/tvm/relax/backend.h
index 2fb11f5a6f83..e7d13c47b2bd 100644
--- a/include/tvm/relax/backend.h
+++ b/include/tvm/relax/backend.h
@@ -35,7 +35,7 @@ namespace transform {
  *
  * \return The Pass.
  */
-TVM_DLL Pass VMBuiltinLower();
+TVM_DLL Pass LowerRuntimeBuiltin();
 
 /*!
  * \brief Lower the shape expression in relax to VM shape heap and TIR functions.
diff --git a/include/tvm/relax/op_attr_types.h b/include/tvm/relax/op_attr_types.h
index b44c4582d82d..c644e208f916 100644
--- a/include/tvm/relax/op_attr_types.h
+++ b/include/tvm/relax/op_attr_types.h
@@ -79,6 +79,15 @@ using FNormalize = runtime::TypedPackedFunc;
+/*! \brief The function type of a function to lower the runtime builtin.
+ *
+ * A builtin function may be lowered to its runtime form in `LowerRuntimeBuiltin`.
+ *
+* \param bb The BlockBuilder context.
+* \param call The call to be lowered.
+*/
+using FLowerBuiltin = runtime::TypedPackedFunc<Expr(const BlockBuilder& bb, const Call& call)>;
+
 /*!
  * \brief Gradient for a specific op.
 *
diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h
index 14b2b84b0d36..0072981be513 100644
--- a/include/tvm/runtime/device_api.h
+++ b/include/tvm/runtime/device_api.h
@@ -240,6 +240,10 @@ class TVM_DLL DeviceAPI {
     return device_type != kDLCPU && device_type != kDLMicroDev;
   }
 
+  static bool SupportsPointerArithmetics(int device_type) {
+    return device_type != kDLVulkan;
+  }
+
  protected:
   /*!
    * \brief copy data from one place to another
diff --git a/python/tvm/relax/op/memory/__init__.py b/python/tvm/relax/op/memory/__init__.py
index 2ae1b676e035..1191550085de 100644
--- a/python/tvm/relax/op/memory/__init__.py
+++ b/python/tvm/relax/op/memory/__init__.py
@@ -17,4 +17,4 @@
 
 """Relax memory primitives."""
 from .memory import alloc_storage, alloc_tensor, kill_storage, kill_tensor
-from .view import view, ensure_aligned
+from .view import view, ensure_zero_offset
diff --git a/python/tvm/relax/op/memory/view.py b/python/tvm/relax/op/memory/view.py
index 233d07f6c9b7..95adc782092f 100644
--- a/python/tvm/relax/op/memory/view.py
+++ b/python/tvm/relax/op/memory/view.py
@@ -94,7 +94,7 @@ def _normalize(expr, relax_cls):
     return _ffi_api.view(data, shape, dtype, relative_byte_offset)  # type: ignore
 
 
-def ensure_aligned(data: Expr) -> Expr:
+def ensure_zero_offset(data: Expr) -> Expr:
     """
     Ensure the tensor has elem_offset == 0. A copy will be made if necessary.
 
@@ -106,6 +106,6 @@ def ensure_aligned(data: Expr) -> Expr:
     Results
     -------
     result : relax.Expr
-        The aligned tensor
+        The tensor with elem_offset == 0
     """
-    return _ffi_api.ensure_aligned(data)  # type: ignore
+    return _ffi_api.ensure_zero_offset(data)  # type: ignore
diff --git a/python/tvm/relax/pipeline.py b/python/tvm/relax/pipeline.py
index d068f800d0e9..38242ff4d2d3 100644
--- a/python/tvm/relax/pipeline.py
+++ b/python/tvm/relax/pipeline.py
@@ -92,7 +92,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
             transform.RewriteCUDAGraph(),
             transform.LowerAllocTensor(),
             transform.KillAfterLastUse(),
-            transform.VMBuiltinLower(),
+            transform.LowerRuntimeBuiltin(),
             transform.ComputePrimValue(),
             transform.VMShapeLower(),
             transform.AttachGlobalSymbol(),
diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py
index 5789e2fcf235..eef6d331375c 100644
--- a/python/tvm/relax/transform/__init__.py
+++ b/python/tvm/relax/transform/__init__.py
@@ -16,6 +16,15 @@
 # under the License.
 """Relax transformations. 
""" +# Import to register the legalization functions. +from . import legalize_ops, tuning_api +from .attach_external_modules import AttachExternModules +from .fast_math import FastMathTransform +from .ipc_allreduce_rewrite import IPCAllReduceRewrite +from .lazy_transform_params import LazyTransformParams +from .lower_gpu_ipc_alloc_storage import LowerGPUIPCAllocStorage +from .optimize_layout_transform import OptimizeLayoutTransform +from .remove_redundant_reshape import RemoveRedundantReshape from .transform import ( AdjustMatmulOrder, AllocateWorkspace, @@ -55,6 +64,7 @@ LegalizeOps, LiftTransformParams, LowerAllocTensor, + LowerRuntimeBuiltin, MergeCompositeFunctions, MetaScheduleApplyDatabase, MetaScheduleTuneIRMod, @@ -64,8 +74,8 @@ PatternCheckContext, RealizeVDevice, RemovePurityChecking, - RemoveUnusedParameters, RemoveUnusedOutputs, + RemoveUnusedParameters, ReorderPermuteDimsAfterConcat, ReorderTakeAfterMatmul, RewriteCUDAGraph, @@ -78,20 +88,7 @@ TopologicalSort, UpdateParamStructInfo, UpdateVDevice, - VMBuiltinLower, VMShapeLower, dataflowblock_pass, function_pass, ) - -from .ipc_allreduce_rewrite import IPCAllReduceRewrite -from .lazy_transform_params import LazyTransformParams -from .lower_gpu_ipc_alloc_storage import LowerGPUIPCAllocStorage -from .optimize_layout_transform import OptimizeLayoutTransform -from .remove_redundant_reshape import RemoveRedundantReshape -from .fast_math import FastMathTransform -from .fuse_transpose_matmul import FuseTransposeMatmul -from .attach_external_modules import AttachExternModules - -# Import to register the legalization functions. -from . import legalize_ops, tuning_api diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 3528b4429e6f..e017bc113b2c 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -586,14 +586,14 @@ def ComputePrimValue() -> tvm.ir.transform.Pass: return _ffi_api.ComputePrimValue() # type: ignore -def VMBuiltinLower() -> tvm.ir.transform.Pass: +def LowerRuntimeBuiltin() -> tvm.ir.transform.Pass: """Lowering generic intrinsic to VM intrinsics. Returns ------- ret: tvm.ir.transform.Pass """ - return _ffi_api.VMBuiltinLower() # type: ignore + return _ffi_api.LowerRuntimeBuiltin() # type: ignore def VMShapeLower(*, emit_err_ctx: bool = True) -> tvm.ir.transform.Pass: diff --git a/src/relax/backend/vm/vm_builtin_lower.cc b/src/relax/backend/vm/lower_runtime_builtin.cc similarity index 86% rename from src/relax/backend/vm/vm_builtin_lower.cc rename to src/relax/backend/vm/lower_runtime_builtin.cc index 961aa9b600f8..7fff6c95329d 100644 --- a/src/relax/backend/vm/vm_builtin_lower.cc +++ b/src/relax/backend/vm/lower_runtime_builtin.cc @@ -17,13 +17,14 @@ * under the License. */ /*! - * \file src/relax/backend/vm/vm_builtin_lower.cc + * \file src/relax/backend/vm/lower_runtime_builtin.cc * \brief Lowers most builtin functions and packed calls. */ #include #include #include #include +#include #include #include #include @@ -33,11 +34,12 @@ namespace relax { // This pass lowers most ops to VM specific builtins. // TODO(relax-team): revisit after PrimValue. 
-class VMBuiltinLowerMutator : public ExprMutator { +class LowerRuntimeBuiltinMutator : public ExprMutator { public: using ExprMutator::VisitExpr_; Expr VisitExpr_(const CallNode* call_node) final { + static const auto& lower_builtin_fmap = Op::GetAttrMap("FLowerBuiltin"); // post-order mutation Call call = Downcast(VisitExprPostOrder_(call_node)); @@ -47,10 +49,6 @@ class VMBuiltinLowerMutator : public ExprMutator { return Reshape(call); } else if (call->op == shape_of_op_) { return ShapeOf(call); - } else if (call->op == view_op_) { - return View(call); - } else if (call->op == ensure_aligned_op_) { - return EnsureAligned(call); } else if (call->op == to_vdevice_op_) { return ToDevice(call); } else if (call->op == make_closure_op_) { @@ -68,9 +66,13 @@ class VMBuiltinLowerMutator : public ExprMutator { return MakeMemAllocTensor(call); } else if (call->op == mem_kill_storage_op_ || call->op == mem_kill_tensor_op_) { return MakeMemKillObject(call); - } else { - return call; + } else if (const auto* op_node = call->op.as()) { + Op op = GetRef(op_node); + if (lower_builtin_fmap.count(op)) { + return lower_builtin_fmap[op](builder_, call); + } } + return call; } Expr MakeMemAllocStorage(const Call& call) { @@ -128,19 +130,6 @@ class VMBuiltinLowerMutator : public ExprMutator { } } - Expr View(const Call& view_node) { - StructInfoDeriveFunc infer_sinfo_env_func; - infer_sinfo_env_func = EnvFunc::Get("tvm.relax.struct_info.infer_view_sinfo"); - auto runtime_view_sinfo = FuncStructInfo::OpaqueFunc(infer_sinfo_env_func, true); - ExternFunc runtime_view_func("runtime.TVMArrayCreateView", runtime_view_sinfo); - return Call(runtime_view_func, view_node->args, view_node->attrs, {runtime_view_sinfo}); - } - - Expr EnsureAligned(const Call& call_node) { - ICHECK(call_node->args.size() == 1); - return Call(builtin_ensure_aligned_, call_node->args, Attrs(), {GetStructInfo(call_node)}); - } - Expr ShapeOf(const Call& call_node) { ICHECK(call_node->args.size() == 1); ICHECK(call_node->struct_info_.defined()); @@ -205,8 +194,6 @@ class VMBuiltinLowerMutator : public ExprMutator { const Op& call_tir_dyn_op_ = Op::Get("relax.vm.call_tir_dyn"); const Op& reshape_op_ = Op::Get("relax.reshape"); const Op& shape_of_op_ = Op::Get("relax.shape_of"); - const Op& view_op_ = Op::Get("relax.memory.view"); - const Op& ensure_aligned_op_ = Op::Get("relax.memory.ensure_aligned"); const Op& to_vdevice_op_ = Op::Get("relax.to_vdevice"); const Op& make_closure_op_ = Op::Get("relax.make_closure"); const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure"); @@ -227,20 +214,20 @@ class VMBuiltinLowerMutator : public ExprMutator { const ExternFunc builtin_to_device_{"vm.builtin.to_device"}; const ExternFunc builtin_make_closure_{"vm.builtin.make_closure"}; const ExternFunc builtin_invoke_closure_{"vm.builtin.invoke_closure"}; - const ExternFunc builtin_ensure_aligned_{"vm.builtin.ensure_aligned"}; + }; -Expr VMBuiltinLower(const Expr& e) { return VMBuiltinLowerMutator().VisitExpr(e); } +Expr LowerRuntimeBuiltin(const Expr& e) { return LowerRuntimeBuiltinMutator().VisitExpr(e); } namespace transform { -Pass VMBuiltinLower() { +Pass LowerRuntimeBuiltin() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { return Downcast(VMBuiltinLower(f)); }; - return CreateFunctionPass(pass_func, 0, "VMBuiltinLower", {}); + [=](Function f, IRModule m, PassContext pc) { return Downcast(LowerRuntimeBuiltin(f)); }; + return CreateFunctionPass(pass_func, 0, "LowerRuntimeBuiltin", {}); } 
-TVM_REGISTER_GLOBAL("relax.transform.VMBuiltinLower").set_body_typed(VMBuiltinLower); +TVM_REGISTER_GLOBAL("relax.transform.LowerRuntimeBuiltin").set_body_typed(LowerRuntimeBuiltin); } // namespace transform } // namespace relax diff --git a/src/relax/op/memory/view.cc b/src/relax/op/memory/view.cc index d43cc01838ae..b582748e64b5 100644 --- a/src/relax/op/memory/view.cc +++ b/src/relax/op/memory/view.cc @@ -58,9 +58,9 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { if (auto opt = sinfo.as()) { return opt.value(); } else { - LOG(FATAL) << "TypeError: " - << "Operator " << call->op << " expects first argument to be a tensor, " - << "but received " << arg_data << " with type " << sinfo; + LOG(FATAL) << "TypeError: " << "Operator " << call->op + << " expects first argument to be a tensor, " << "but received " << arg_data + << " with type " << sinfo; } }(); auto view_shape_sinfo = [&]() -> const ShapeStructInfoNode* { @@ -73,10 +73,10 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { // The `R.view` operation returns a different shape. return ptr; } else { - LOG(FATAL) << "TypeError: " - << "Operator " << call->op << " expects second argument to be a ShapeExpr, " - << "or a void-type (empty relax tuple), " - << "but received " << arg_shape << " with type " << sinfo; + LOG(FATAL) << "TypeError: " << "Operator " << call->op + << " expects second argument to be a ShapeExpr, " + << "or a void-type (empty relax tuple), " << "but received " << arg_shape + << " with type " << sinfo; } }(); @@ -111,10 +111,9 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { // being changed into. return DataType::Void(); } else { - LOG(FATAL) << "TypeError: " - << "Operator " << call->op - << " expects the dtype argument to be a relax::DataTypeImm, " - << "but received " << arg_dtype << " with type " << sinfo; + LOG(FATAL) << "TypeError: " << "Operator " << call->op + << " expects the dtype argument to be a relax::DataTypeImm, " << "but received " + << arg_dtype << " with type " << sinfo; } }(); @@ -126,8 +125,7 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { return IntImm(DataType::Int(64), 0); } else if (auto prim_sinfo = sinfo.as()) { CHECK_EQ(prim_sinfo->dtype, DataType::Int(64)) - << "TypeError: " - << "Operator " << call->op + << "TypeError: " << "Operator " << call->op << " expects the relative_byte_offset to be a 64-bit integer, but received " << arg_relative_byte_offset << ", which has type " << sinfo; if (prim_sinfo->value.defined()) { @@ -139,9 +137,8 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { return NullOpt; } } else { - LOG(FATAL) << "TypeError: " - << "Operator " << call->op << " expects the relative_byte_offset argument " - << "to be a Relax PrimValue. " + LOG(FATAL) << "TypeError: " << "Operator " << call->op + << " expects the relative_byte_offset argument " << "to be a Relax PrimValue. " << "However, expression " << call << " provides relative_byte_offset of " << arg_relative_byte_offset << ", which has type " << sinfo; } @@ -246,8 +243,7 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { // view to be larger than the original array. CHECK_GE(input_element_size.value()->value, output_element_size.value()->value) - << "ValueError: " - << "Operator " << call->op + << "ValueError: " << "Operator " << call->op << " may not produce a view that exceeds the bounds of the original array. 
" << "In expression " << call << " the data type is changed from " << data_sinfo->dtype << " to " << view_dtype.value() << ", increasing the size per element from " @@ -313,9 +309,9 @@ Expr LegalizeView(const BlockBuilder& bb, const Call& call) { CHECK(data_shape.defined()) << "Legalization of " << call->op << " requires that either the output shape be explicitly specified, " - << "or the input shape is known. " - << "However, in expression " << call << ", no output shape is specified, " - << "and the input " << data << " of type " << data->struct_info_ << " has unknown shape."; + << "or the input shape is known. " << "However, in expression " << call + << ", no output shape is specified, " << "and the input " << data << " of type " + << data->struct_info_ << " has unknown shape."; shape = ShapeExpr(data_shape.value()); } @@ -324,9 +320,9 @@ Expr LegalizeView(const BlockBuilder& bb, const Call& call) { CHECK(!data_dtype.is_void()) << "Legalization of " << call->op << " requires that either the output dtype be explicitly specified, " - << "or the input dtype is known. " - << "However, in expression " << call << ", no output dtype is specified, " - << "and the input " << data << " of type " << data->struct_info_ << " has unknown dtype."; + << "or the input dtype is known. " << "However, in expression " << call + << ", no output dtype is specified, " << "and the input " << data << " of type " + << data->struct_info_ << " has unknown dtype."; dtype = relax::DataTypeImm(data_dtype); } @@ -342,6 +338,14 @@ Expr LegalizeView(const BlockBuilder& bb, const Call& call) { return Call(call->op, {data, shape, dtype, relative_byte_offset}); } +Expr LowerBuiltinView(const BlockBuilder& bb, const Call& call) { + StructInfoDeriveFunc infer_sinfo_env_func; + infer_sinfo_env_func= EnvFunc::Get("tvm.relax.struct_info.infer_view_sinfo"); + auto runtime_view_sinfo = FuncStructInfo::OpaqueFunc(infer_sinfo_env_func, true); + ExternFunc runtime_view_func("runtime.TVMArrayCreateView", runtime_view_sinfo); + return Call(runtime_view_func, call->args, call->attrs, {runtime_view_sinfo}); +} + TVM_REGISTER_OP("relax.memory.view") .set_num_inputs(4) .add_argument("x", "Tensor", "The input tensor.") @@ -352,30 +356,37 @@ TVM_REGISTER_OP("relax.memory.view") .set_attr("RequiresArgumentShapes", Bool(false)) .set_attr("FInferStructInfo", InferStructInfoView) .set_attr("FLegalize", LegalizeView) - .set_attr("FPurity", Bool(true)); + .set_attr("FPurity", Bool(true)) + .set_attr("FLowerBuiltin", LowerBuiltinView); -Expr ensure_aligned(const Expr& x) { - static const Op& op = Op::Get("relax.memory.ensure_aligned"); +Expr ensure_zero_offset(const Expr& x) { + static const Op& op = Op::Get("relax.memory.ensure_zero_offset"); return Call(op, {x}); } -TVM_REGISTER_GLOBAL("relax.op.memory.ensure_aligned").set_body_typed(ensure_aligned); +TVM_REGISTER_GLOBAL("relax.op.memory.ensure_zero_offset").set_body_typed(ensure_zero_offset); -StructInfo InferStructInfoEnsureAligned(const Call& call, const BlockBuilder& ctx) { +StructInfo InferStructInfoEnsureZeroOffset(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 1) { ctx->ReportFatal(Diagnostic::Error(call) - << "Operator " << call->op << " should receive 1 argument, " - << "but received " << call->args); + << "Operator " << call->op << " should receive 1 argument, " << "but received " + << call->args); } return GetStructInfo(call->args[0]); } -TVM_REGISTER_OP("relax.memory.ensure_aligned") +Expr LowerBuiltinEnsureZeroOffset(const BlockBuilder& bb, const Call& call) { 
+  const ExternFunc builtin_ensure_zero_offset_{"vm.builtin.ensure_zero_offset"};
+  return Call(builtin_ensure_zero_offset_, call->args, Attrs(), {GetStructInfo(call)});
+}
+
+TVM_REGISTER_OP("relax.memory.ensure_zero_offset")
     .set_num_inputs(1)
     .add_argument("x", "Tensor", "The input tensor.")
     .set_attr("RequiresArgumentShapes", Bool(false))
-    .set_attr("FInferStructInfo", InferStructInfoEnsureAligned)
-    .set_attr("FPurity", Bool(true));
+    .set_attr("FInferStructInfo", InferStructInfoEnsureZeroOffset)
+    .set_attr("FPurity", Bool(true))
+    .set_attr("FLowerBuiltin", LowerBuiltinEnsureZeroOffset);
 
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc
index 2922de6dcc7e..74200526b699 100644
--- a/src/relax/transform/static_plan_block_memory.cc
+++ b/src/relax/transform/static_plan_block_memory.cc
@@ -286,12 +286,12 @@ class TokenAllocator1D {
   std::vector full_pool_;
 };
 
-/*! \brief Check if the input op is a memory op that returns the same buffer as the input buffer. */
+/*! \brief Check if the input op is a memory op that may return the same buffer. */
 bool IsInplaceMemoryOp(const Expr& op) {
   static const Op& reshape_op = Op::Get("relax.reshape");
   static const Op& view_op = Op::Get("relax.memory.view");
-  static const Op& ensure_aligned_op = Op::Get("relax.memory.ensure_aligned");
-  return op.same_as(reshape_op) || op.same_as(view_op) || op.same_as(ensure_aligned_op);
+  static const Op& ensure_zero_offset_op = Op::Get("relax.memory.ensure_zero_offset");
+  return op.same_as(reshape_op) || op.same_as(view_op) || op.same_as(ensure_zero_offset_op);
 }
 
 /*! \brief The base class for the storage allocation visitor. */
diff --git a/src/runtime/relax_vm/builtin.cc b/src/runtime/relax_vm/builtin.cc
index 83b016446548..1227c5163c31 100644
--- a/src/runtime/relax_vm/builtin.cc
+++ b/src/runtime/relax_vm/builtin.cc
@@ -545,17 +545,21 @@ TVM_REGISTER_GLOBAL("vm.builtin.tensor_to_shape").set_body_typed([](NDArray data
   return ShapeTuple(out_shape);
 });
 
-TVM_REGISTER_GLOBAL("vm.builtin.ensure_aligned").set_body_typed([](NDArray data) {
+TVM_REGISTER_GLOBAL("vm.builtin.ensure_zero_offset").set_body_typed([](NDArray data) {
   if (data->byte_offset == 0) {
     return data;
   }
-  DLManagedTensor* dl_tensor = data.ToDLPack();
-  dl_tensor->dl_tensor.data =
-      reinterpret_cast<char*>(dl_tensor->dl_tensor.data) + dl_tensor->dl_tensor.byte_offset;
-  dl_tensor->dl_tensor.byte_offset = 0;
-  // For platforms that do not support pointer arithmetic, we need to copy the data to a new
-  // buffer.
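-  // NOTE: NDArray::FromDLPack takes ownership of dl_tensor, so the returned
-  // array aliases the original buffer; only the pointer/offset bookkeeping
-  // changes, no data is copied here.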
-  return NDArray::FromDLPack(dl_tensor);
+  if (DeviceAPI::SupportsPointerArithmetics(data->device.device_type)) {
+    DLManagedTensor* dl_tensor = data.ToDLPack();
+    dl_tensor->dl_tensor.data =
+        reinterpret_cast<char*>(dl_tensor->dl_tensor.data) + dl_tensor->dl_tensor.byte_offset;
+    dl_tensor->dl_tensor.byte_offset = 0;
+    return NDArray::FromDLPack(dl_tensor);
+  } else {
+    auto new_array = NDArray::Empty(data.Shape(), data->dtype, data->device);
+    new_array.CopyFrom(data);
+    return new_array;
+  }
 });
 
 }  // namespace relax_vm
diff --git a/tests/python/relax/test_op_view.py b/tests/python/relax/test_op_view.py
index 1e21612f9fff..033aee9882a4 100644
--- a/tests/python/relax/test_op_view.py
+++ b/tests/python/relax/test_op_view.py
@@ -731,7 +731,7 @@ def main(A: R.Tensor([4096], "float32")):
                 shape=R.shape([16, 64]),
                 relative_byte_offset=32 * 64 * 4,
             )
-            C = R.memory.ensure_aligned(B)
+            C = R.memory.ensure_zero_offset(B)
             return C
 
         built = tvm.relax.build(Module, target=target)
diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py b/tests/python/relax/test_transform_static_plan_block_memory.py
index 3ab468844b01..911a19d43592 100644
--- a/tests/python/relax/test_transform_static_plan_block_memory.py
+++ b/tests/python/relax/test_transform_static_plan_block_memory.py
@@ -1461,7 +1461,7 @@ def main():
             cls = Before
             x = R.builtin.alloc_tensor(R.shape([16, 16]), dtype="float32", runtime_device_index=0)
             x1 = R.memory.view(x, [128], "float32", 0)
-            x2 = R.memory.ensure_aligned(x1)
+            x2 = R.memory.ensure_zero_offset(x1)
             y = R.builtin.alloc_tensor(R.shape([128]), dtype="float32", runtime_device_index=0)
             cls.tir_exp(x2, y)
             z = R.builtin.alloc_tensor(R.shape([128]), dtype="float32", runtime_device_index=0)
@@ -1486,7 +1486,7 @@ def main() -> R.Tensor((128,), dtype="float32"):
             x1: R.Tensor((128,), dtype="float32") = R.memory.view(
                 x, R.shape([128]), R.dtype("float32"), R.prim_value(0)
             )
-            x2: R.Tensor((128,), dtype="float32") = R.memory.ensure_aligned(x1)
+            x2: R.Tensor((128,), dtype="float32") = R.memory.ensure_zero_offset(x1)
             storage1: R.Object = R.memory.alloc_storage(
                 R.shape([512]), R.prim_value(0), R.str("global"), R.dtype("float32")
             )

From c742a07390989cbb798c6268533be8b06207771c Mon Sep 17 00:00:00 2001
From: Wuwei Lin
Date: Mon, 15 Jul 2024 18:09:16 -0700
Subject: [PATCH 03/12] lint

---
 src/relax/backend/vm/lower_runtime_builtin.cc | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/relax/backend/vm/lower_runtime_builtin.cc b/src/relax/backend/vm/lower_runtime_builtin.cc
index 7fff6c95329d..a3867ae92448 100644
--- a/src/relax/backend/vm/lower_runtime_builtin.cc
+++ b/src/relax/backend/vm/lower_runtime_builtin.cc
@@ -39,7 +39,7 @@ class LowerRuntimeBuiltinMutator : public ExprMutator {
   using ExprMutator::VisitExpr_;
 
   Expr VisitExpr_(const CallNode* call_node) final {
-    static const auto& lower_builtin_fmap = Op::GetAttrMap("FLowerBuiltin");
+    static const auto& lower_builtin_fmap = Op::GetAttrMap("FLowerBuiltin");
     // post-order mutation
     Call call = Downcast(VisitExprPostOrder_(call_node));
@@ -214,7 +214,6 @@ class LowerRuntimeBuiltinMutator : public ExprMutator {
   const ExternFunc builtin_to_device_{"vm.builtin.to_device"};
   const ExternFunc builtin_make_closure_{"vm.builtin.make_closure"};
   const ExternFunc builtin_invoke_closure_{"vm.builtin.invoke_closure"};
-
 };
 
 Expr LowerRuntimeBuiltin(const Expr& e) { return LowerRuntimeBuiltinMutator().VisitExpr(e); }

namespace transform {

Pass LowerRuntimeBuiltin() {
  runtime::TypedPackedFunc pass_func =
-      
[=](Function f, IRModule m, PassContext pc) { return Downcast(LowerRuntimeBuiltin(f)); }; + [=](Function f, IRModule m, PassContext pc) { + return Downcast(LowerRuntimeBuiltin(f)); + }; return CreateFunctionPass(pass_func, 0, "LowerRuntimeBuiltin", {}); } From c4e1f294b8b64b34cffd412649f0547f908b53f6 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 16 Jul 2024 14:04:06 -0700 Subject: [PATCH 04/12] Add warnings --- python/tvm/relax/transform/transform.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index e017bc113b2c..2546284625e9 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -19,6 +19,7 @@ import functools import inspect import types +import warnings from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np # type: ignore @@ -596,6 +597,20 @@ def LowerRuntimeBuiltin() -> tvm.ir.transform.Pass: return _ffi_api.LowerRuntimeBuiltin() # type: ignore +def VMBuiltinLower() -> tvm.ir.transform.Pass: + """Lowering generic intrinsic to VM intrinsics. + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + warnings.warn( + "tvm.relax.transform.VMBuiltinLower has been renamed to 'LowerRuntimeBuiltin'. " + "This wrapper is for backwards compatibility, and will be removed in a later update." + ) + return _ffi_api.LowerRuntimeBuiltin() # type: ignore + + def VMShapeLower(*, emit_err_ctx: bool = True) -> tvm.ir.transform.Pass: """Lower the symbolic shape and argument and match-cast structinfo matching. From d1a3243aa0e68ca1b96efe566a86a80214b1a3de Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 16 Jul 2024 14:05:28 -0700 Subject: [PATCH 05/12] Apply format --- src/relax/op/memory/view.cc | 50 ++++++++++++++++++++----------------- 1 file changed, 27 insertions(+), 23 deletions(-) diff --git a/src/relax/op/memory/view.cc b/src/relax/op/memory/view.cc index b582748e64b5..e221b8ec5864 100644 --- a/src/relax/op/memory/view.cc +++ b/src/relax/op/memory/view.cc @@ -58,9 +58,9 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { if (auto opt = sinfo.as()) { return opt.value(); } else { - LOG(FATAL) << "TypeError: " << "Operator " << call->op - << " expects first argument to be a tensor, " << "but received " << arg_data - << " with type " << sinfo; + LOG(FATAL) << "TypeError: " + << "Operator " << call->op << " expects first argument to be a tensor, " + << "but received " << arg_data << " with type " << sinfo; } }(); auto view_shape_sinfo = [&]() -> const ShapeStructInfoNode* { @@ -73,10 +73,10 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { // The `R.view` operation returns a different shape. return ptr; } else { - LOG(FATAL) << "TypeError: " << "Operator " << call->op - << " expects second argument to be a ShapeExpr, " - << "or a void-type (empty relax tuple), " << "but received " << arg_shape - << " with type " << sinfo; + LOG(FATAL) << "TypeError: " + << "Operator " << call->op << " expects second argument to be a ShapeExpr, " + << "or a void-type (empty relax tuple), " + << "but received " << arg_shape << " with type " << sinfo; } }(); @@ -111,9 +111,10 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { // being changed into. 
return DataType::Void(); } else { - LOG(FATAL) << "TypeError: " << "Operator " << call->op - << " expects the dtype argument to be a relax::DataTypeImm, " << "but received " - << arg_dtype << " with type " << sinfo; + LOG(FATAL) << "TypeError: " + << "Operator " << call->op + << " expects the dtype argument to be a relax::DataTypeImm, " + << "but received " << arg_dtype << " with type " << sinfo; } }(); @@ -125,7 +126,8 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { return IntImm(DataType::Int(64), 0); } else if (auto prim_sinfo = sinfo.as()) { CHECK_EQ(prim_sinfo->dtype, DataType::Int(64)) - << "TypeError: " << "Operator " << call->op + << "TypeError: " + << "Operator " << call->op << " expects the relative_byte_offset to be a 64-bit integer, but received " << arg_relative_byte_offset << ", which has type " << sinfo; if (prim_sinfo->value.defined()) { @@ -137,8 +139,9 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { return NullOpt; } } else { - LOG(FATAL) << "TypeError: " << "Operator " << call->op - << " expects the relative_byte_offset argument " << "to be a Relax PrimValue. " + LOG(FATAL) << "TypeError: " + << "Operator " << call->op << " expects the relative_byte_offset argument " + << "to be a Relax PrimValue. " << "However, expression " << call << " provides relative_byte_offset of " << arg_relative_byte_offset << ", which has type " << sinfo; } @@ -243,7 +246,8 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { // view to be larger than the original array. CHECK_GE(input_element_size.value()->value, output_element_size.value()->value) - << "ValueError: " << "Operator " << call->op + << "ValueError: " + << "Operator " << call->op << " may not produce a view that exceeds the bounds of the original array. " << "In expression " << call << " the data type is changed from " << data_sinfo->dtype << " to " << view_dtype.value() << ", increasing the size per element from " @@ -309,9 +313,9 @@ Expr LegalizeView(const BlockBuilder& bb, const Call& call) { CHECK(data_shape.defined()) << "Legalization of " << call->op << " requires that either the output shape be explicitly specified, " - << "or the input shape is known. " << "However, in expression " << call - << ", no output shape is specified, " << "and the input " << data << " of type " - << data->struct_info_ << " has unknown shape."; + << "or the input shape is known. " + << "However, in expression " << call << ", no output shape is specified, " + << "and the input " << data << " of type " << data->struct_info_ << " has unknown shape."; shape = ShapeExpr(data_shape.value()); } @@ -320,9 +324,9 @@ Expr LegalizeView(const BlockBuilder& bb, const Call& call) { CHECK(!data_dtype.is_void()) << "Legalization of " << call->op << " requires that either the output dtype be explicitly specified, " - << "or the input dtype is known. " << "However, in expression " << call - << ", no output dtype is specified, " << "and the input " << data << " of type " - << data->struct_info_ << " has unknown dtype."; + << "or the input dtype is known. 
" + << "However, in expression " << call << ", no output dtype is specified, " + << "and the input " << data << " of type " << data->struct_info_ << " has unknown dtype."; dtype = relax::DataTypeImm(data_dtype); } @@ -340,7 +344,7 @@ Expr LegalizeView(const BlockBuilder& bb, const Call& call) { Expr LowerBuiltinView(const BlockBuilder& bb, const Call& call) { StructInfoDeriveFunc infer_sinfo_env_func; - infer_sinfo_env_func= EnvFunc::Get("tvm.relax.struct_info.infer_view_sinfo"); + infer_sinfo_env_func = EnvFunc::Get("tvm.relax.struct_info.infer_view_sinfo"); auto runtime_view_sinfo = FuncStructInfo::OpaqueFunc(infer_sinfo_env_func, true); ExternFunc runtime_view_func("runtime.TVMArrayCreateView", runtime_view_sinfo); return Call(runtime_view_func, call->args, call->attrs, {runtime_view_sinfo}); @@ -369,8 +373,8 @@ TVM_REGISTER_GLOBAL("relax.op.memory.ensure_zero_offset").set_body_typed(ensure_ StructInfo InferStructInfoEnsureZeroOffset(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 1) { ctx->ReportFatal(Diagnostic::Error(call) - << "Operator " << call->op << " should receive 1 argument, " << "but received " - << call->args); + << "Operator " << call->op << " should receive 1 argument, " + << "but received " << call->args); } return GetStructInfo(call->args[0]); } From c931b4548955d22035b1734c4bcc198a632edb29 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 16 Jul 2024 14:11:27 -0700 Subject: [PATCH 06/12] Check kAllocAlignment --- src/runtime/relax_vm/builtin.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/runtime/relax_vm/builtin.cc b/src/runtime/relax_vm/builtin.cc index 1227c5163c31..e77edd9184f1 100644 --- a/src/runtime/relax_vm/builtin.cc +++ b/src/runtime/relax_vm/builtin.cc @@ -549,7 +549,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.ensure_zero_offset").set_body_typed([](NDArray d if (data->byte_offset == 0) { return data; } - if (DeviceAPI::SupportsPointerArithmetics(data->device.device_type)) { + if (DeviceAPI::SupportsPointerArithmetics(data->device.device_type) && + data->byte_offset % tvm::runtime::kAllocAlignment == 0) { DLManagedTensor* dl_tensor = data.ToDLPack(); dl_tensor->dl_tensor.data = reinterpret_cast(dl_tensor->dl_tensor.data) + dl_tensor->dl_tensor.byte_offset; From 25cc462ed3648e5c2319bd11d4e84b747512e11b Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 16 Jul 2024 14:29:49 -0700 Subject: [PATCH 07/12] update LegalizeView to no op --- src/relax/op/memory/view.cc | 18 ++-- tests/python/relax/test_op_view.py | 136 +++++++++++++++++------------ 2 files changed, 87 insertions(+), 67 deletions(-) diff --git a/src/relax/op/memory/view.cc b/src/relax/op/memory/view.cc index e221b8ec5864..94d0b5b01fe0 100644 --- a/src/relax/op/memory/view.cc +++ b/src/relax/op/memory/view.cc @@ -292,6 +292,11 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_GLOBAL("tvm.relax.struct_info.infer_view_sinfo").set_body_typed(InferStructInfoView); Expr LegalizeView(const BlockBuilder& bb, const Call& call) { + // No-op. View is lowered during the LowerBuiltinView pass. 
+ return call; +} + +Expr LowerBuiltinView(const BlockBuilder& bb, const Call& call) { Expr data = call->args[0]; Expr shape = call->args[1]; Expr dtype = call->args[2]; @@ -334,20 +339,13 @@ Expr LegalizeView(const BlockBuilder& bb, const Call& call) { relative_byte_offset = relax::PrimValue::Int64(0); } - if (shape.same_as(call->args[1]) && dtype.same_as(call->args[2]) && - relative_byte_offset.same_as(call->args[3])) { - return call; - } - - return Call(call->op, {data, shape, dtype, relative_byte_offset}); -} - -Expr LowerBuiltinView(const BlockBuilder& bb, const Call& call) { StructInfoDeriveFunc infer_sinfo_env_func; infer_sinfo_env_func = EnvFunc::Get("tvm.relax.struct_info.infer_view_sinfo"); auto runtime_view_sinfo = FuncStructInfo::OpaqueFunc(infer_sinfo_env_func, true); + ExternFunc runtime_view_func("runtime.TVMArrayCreateView", runtime_view_sinfo); - return Call(runtime_view_func, call->args, call->attrs, {runtime_view_sinfo}); + + return Call(runtime_view_func, {data, shape, dtype, relative_byte_offset}); } TVM_REGISTER_OP("relax.memory.view") diff --git a/tests/python/relax/test_op_view.py b/tests/python/relax/test_op_view.py index 033aee9882a4..c97e8d1c8880 100644 --- a/tests/python/relax/test_op_view.py +++ b/tests/python/relax/test_op_view.py @@ -452,7 +452,7 @@ def inferred_sinfo(A: R.Tensor, relative_byte_offset: R.Prim("int64")): tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) -def test_legalize_without_any_changes_is_no_op(): +def test_legalize_is_no_op(): @I.ir_module class Before: @R.function @@ -460,18 +460,13 @@ def main(A: R.Tensor([4096], "float32")): B = R.memory.view(A) return B - @I.ir_module - class Expected: - @R.function - def main(A: R.Tensor([4096], "float32")): - B = A - return B + Expected = Before - After = tvm.relax.transform.LegalizeOps()(Before) + After = tvm.relax.transform.LowerRuntimeBuiltin()(Before) tvm.ir.assert_structural_equal(Expected, After) -def test_legalize_shape_change(): +def test_lower_runtime_builtin_shape_change(): @I.ir_module class Before: @R.function @@ -483,14 +478,25 @@ def main(A: R.Tensor([4096], "float32")): class Expected: @R.function def main(A: R.Tensor([4096], "float32")): - B = R.memory.view(A, shape=R.shape([64, 64]), dtype="float32", relative_byte_offset=0) + B = R.ExternFunc( + "runtime.TVMArrayCreateView", + R.Callable( + derive_func="tvm.relax.struct_info.infer_view_sinfo", + purity=True, + ), + )( + A, + R.shape([64, 64]), + R.dtype("float32"), + R.prim_value(0), + ) return B - After = tvm.relax.transform.LegalizeOps()(Before) + After = tvm.relax.transform.LowerRuntimeBuiltin()(Before) tvm.ir.assert_structural_equal(Expected, After) -def test_legalize_view_shape_from_unknown(): +def test_lower_runtime_builtin_view_shape_from_unknown(): """R.memory.view does not require the input tensor to have a known shape""" @I.ir_module @@ -504,14 +510,25 @@ def main(A: R.Tensor(dtype="float32")): class Expected: @R.function def main(A: R.Tensor(dtype="float32")): - B = R.memory.view(A, shape=R.shape([64, 64]), dtype="float32", relative_byte_offset=0) + B = R.ExternFunc( + "runtime.TVMArrayCreateView", + R.Callable( + derive_func="tvm.relax.struct_info.infer_view_sinfo", + purity=True, + ), + )( + A, + R.shape([64, 64]), + R.dtype("float32"), + R.prim_value(0), + ) return B - After = tvm.relax.transform.LegalizeOps()(Before) + After = tvm.relax.transform.LowerRuntimeBuiltin()(Before) tvm.ir.assert_structural_equal(Expected, After) -def test_legalize_dtype_change(): +def 
test_lower_runtime_builtin_dtype_change(): @I.ir_module class Before: @R.function @@ -523,16 +540,25 @@ def main(A: R.Tensor([4096], "float32")): class Expected: @R.function def main(A: R.Tensor([4096], "float32")): - B = R.memory.view( - A, dtype=R.dtype("int32"), shape=R.shape([4096]), relative_byte_offset=0 + B = R.ExternFunc( + "runtime.TVMArrayCreateView", + R.Callable( + derive_func="tvm.relax.struct_info.infer_view_sinfo", + purity=True, + ), + )( + A, + R.shape([4096]), + R.dtype("int32"), + R.prim_value(0), ) return B - After = tvm.relax.transform.LegalizeOps()(Before) + After = tvm.relax.transform.LowerRuntimeBuiltin()(Before) tvm.ir.assert_structural_equal(Expected, After) -def test_legalize_byte_offset(): +def test_lower_runtime_builtin_byte_offset(): @I.ir_module class Before: @R.function @@ -544,16 +570,25 @@ def main(A: R.Tensor([4096], "float32")): class Expected: @R.function def main(A: R.Tensor([4096], "float32")): - B = R.memory.view( - A, relative_byte_offset=R.prim_value(0), shape=R.shape([4096]), dtype="float32" + B = R.ExternFunc( + "runtime.TVMArrayCreateView", + R.Callable( + derive_func="tvm.relax.struct_info.infer_view_sinfo", + purity=True, + ), + )( + A, + R.shape([4096]), + R.dtype("float32"), + R.prim_value(0), ) return B - After = tvm.relax.transform.LegalizeOps()(Before) + After = tvm.relax.transform.LowerRuntimeBuiltin()(Before) tvm.ir.assert_structural_equal(Expected, After) -def test_legalize_view_with_multiple_updated_fields(): +def test_lower_runtime_builtin_view_with_multiple_updated_fields(): """R.memory.view may update more than one field in the view In this test case, a 4-kilobyte buffer is provided. The first @@ -584,21 +619,33 @@ def main(A: R.Tensor([4096], "uint8")): class Expected: @R.function def main(A: R.Tensor([4096], "uint8")): - B = R.memory.view( + B = R.ExternFunc( + "runtime.TVMArrayCreateView", + R.Callable( + derive_func="tvm.relax.struct_info.infer_view_sinfo", + purity=True, + ), + )( A, - shape=R.shape([512]), - dtype=R.dtype("int32"), - relative_byte_offset=R.prim_value(0), + R.shape([512]), + R.dtype("int32"), + R.prim_value(0), ) - C = R.memory.view( + C = R.ExternFunc( + "runtime.TVMArrayCreateView", + R.Callable( + derive_func="tvm.relax.struct_info.infer_view_sinfo", + purity=True, + ), + )( A, - shape=R.shape([16, 64]), - dtype=R.dtype("float16"), - relative_byte_offset=R.prim_value(2048), + R.shape([16, 64]), + R.dtype("float16"), + R.prim_value(2048), ) return (B, C) - After = tvm.relax.transform.LegalizeOps()(Before) + After = tvm.relax.transform.LowerRuntimeBuiltin()(Before) tvm.ir.assert_structural_equal(Expected, After) @@ -720,30 +767,5 @@ def main(A: R.Tensor([4096], "uint8")): tvm.testing.assert_allclose(tvm_output[1].numpy(), np_expected[1]) -@tvm.testing.parametrize_targets("llvm", "cuda") -def test_execute_view_with_new_byte_offset_ensure_aligned(target, dev): - @I.ir_module - class Module: - @R.function - def main(A: R.Tensor([4096], "float32")): - B = R.memory.view( - A, - shape=R.shape([16, 64]), - relative_byte_offset=32 * 64 * 4, - ) - C = R.memory.ensure_zero_offset(B) - return C - - built = tvm.relax.build(Module, target=target) - vm = tvm.relax.VirtualMachine(built, device=dev) - - np_input = np.random.random([4096]).astype("float32") - tvm_input = tvm.nd.array(np_input, dev) - tvm_output = vm["main"](tvm_input) - np_expected = np_input.reshape(64, 64)[32:48, :] - - tvm.testing.assert_allclose(tvm_output.numpy(), np_expected) - - if __name__ == "__main__": tvm.testing.main() From 
7ed6b38fb495d4d1649699e0ff92f882a358d414 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 16 Jul 2024 16:45:16 -0700 Subject: [PATCH 08/12] Make SupportsDevicePointerArithmetics virutal --- include/tvm/runtime/device_api.h | 7 ++++--- src/runtime/cpu_device_api.cc | 2 ++ src/runtime/cuda/cuda_device_api.cc | 2 ++ src/runtime/relax_vm/builtin.cc | 3 ++- 4 files changed, 10 insertions(+), 4 deletions(-) diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h index 0072981be513..c33606d98ed3 100644 --- a/include/tvm/runtime/device_api.h +++ b/include/tvm/runtime/device_api.h @@ -240,9 +240,10 @@ class TVM_DLL DeviceAPI { return device_type != kDLCPU && device_type != kDLMicroDev; } - static bool SupportsPointerArithmetics(int device_type) { - return device_type != kDLVulkan; - } + /*! + * \brief Whether pointer arithmetics on a device owned pointer may be performed on the host. + */ + virtual bool SupportsDevicePointerArithmeticsOnHost() { return false; } protected: /*! diff --git a/src/runtime/cpu_device_api.cc b/src/runtime/cpu_device_api.cc index 774335f5660b..ccd726a6ece6 100644 --- a/src/runtime/cpu_device_api.cc +++ b/src/runtime/cpu_device_api.cc @@ -73,6 +73,8 @@ class CPUDeviceAPI final : public DeviceAPI { void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final; void FreeWorkspace(Device dev, void* data) final; + bool SupportsDevicePointerArithmeticsOnHost() final { return true; } + static CPUDeviceAPI* Global() { // NOTE: explicitly use new to avoid exit-time destruction of global state // Global state will be recycled by OS as the process exits. diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc index 66357a191541..33908d750d6d 100644 --- a/src/runtime/cuda/cuda_device_api.cc +++ b/src/runtime/cuda/cuda_device_api.cc @@ -262,6 +262,8 @@ class CUDADeviceAPI final : public DeviceAPI { CUDAThreadEntry::ThreadLocal()->pool.FreeWorkspace(dev, data); } + bool SupportsDevicePointerArithmeticsOnHost() final { return true; } + static CUDADeviceAPI* Global() { // NOTE: explicitly use new to avoid exit-time destruction of global state // Global state will be recycled by OS as the process exits. diff --git a/src/runtime/relax_vm/builtin.cc b/src/runtime/relax_vm/builtin.cc index e77edd9184f1..3908ad1112a0 100644 --- a/src/runtime/relax_vm/builtin.cc +++ b/src/runtime/relax_vm/builtin.cc @@ -549,7 +549,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.ensure_zero_offset").set_body_typed([](NDArray d if (data->byte_offset == 0) { return data; } - if (DeviceAPI::SupportsPointerArithmetics(data->device.device_type) && + auto* device_api = DeviceAPI::Get(data->device); + if (device_api->SupportsDevicePointerArithmeticsOnHost() && data->byte_offset % tvm::runtime::kAllocAlignment == 0) { DLManagedTensor* dl_tensor = data.ToDLPack(); dl_tensor->dl_tensor.data = From cd553b6a3b981a208f1252b515849ba3a8d8dc9b Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 17 Jul 2024 16:18:46 -0700 Subject: [PATCH 09/12] fix --- python/tvm/relax/transform/__init__.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py index eef6d331375c..05e1376fb697 100644 --- a/python/tvm/relax/transform/__init__.py +++ b/python/tvm/relax/transform/__init__.py @@ -16,15 +16,6 @@ # under the License. """Relax transformations. """ -# Import to register the legalization functions. -from . 
From cd553b6a3b981a208f1252b515849ba3a8d8dc9b Mon Sep 17 00:00:00 2001
From: Wuwei Lin
Date: Wed, 17 Jul 2024 16:18:46 -0700
Subject: [PATCH 09/12] fix

---
 python/tvm/relax/transform/__init__.py | 20 +++++++++++---------
 1 file changed, 11 insertions(+), 9 deletions(-)

diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py
index eef6d331375c..05e1376fb697 100644
--- a/python/tvm/relax/transform/__init__.py
+++ b/python/tvm/relax/transform/__init__.py
@@ -16,15 +16,6 @@
 # under the License.
 """Relax transformations. """

-# Import to register the legalization functions.
-from . import legalize_ops, tuning_api
-from .attach_external_modules import AttachExternModules
-from .fast_math import FastMathTransform
-from .ipc_allreduce_rewrite import IPCAllReduceRewrite
-from .lazy_transform_params import LazyTransformParams
-from .lower_gpu_ipc_alloc_storage import LowerGPUIPCAllocStorage
-from .optimize_layout_transform import OptimizeLayoutTransform
-from .remove_redundant_reshape import RemoveRedundantReshape
 from .transform import (
     AdjustMatmulOrder,
     AllocateWorkspace,
@@ -92,3 +83,14 @@
     dataflowblock_pass,
     function_pass,
 )
+
+from .attach_external_modules import AttachExternModules
+from .fast_math import FastMathTransform
+from .ipc_allreduce_rewrite import IPCAllReduceRewrite
+from .lazy_transform_params import LazyTransformParams
+from .lower_gpu_ipc_alloc_storage import LowerGPUIPCAllocStorage
+from .optimize_layout_transform import OptimizeLayoutTransform
+from .remove_redundant_reshape import RemoveRedundantReshape
+
+# Import to register the legalization functions.
+from . import legalize_ops, tuning_api

From 0abe71d55475c3115474c75c37208763963e518d Mon Sep 17 00:00:00 2001
From: Wuwei Lin
Date: Wed, 17 Jul 2024 16:39:42 -0700
Subject: [PATCH 10/12] lint

---
 include/tvm/relax/op_attr_types.h | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/include/tvm/relax/op_attr_types.h b/include/tvm/relax/op_attr_types.h
index c644e208f916..291bee597c03 100644
--- a/include/tvm/relax/op_attr_types.h
+++ b/include/tvm/relax/op_attr_types.h
@@ -83,9 +83,9 @@
 using FLegalize = runtime::TypedPackedFunc<Expr(const BlockBuilder& bb, const Call& call)>;

 /*!
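The reordering in PATCH 09 matters because several of these submodules import names back out of tvm.relax.transform while its __init__ is still executing; binding the .transform exports first keeps the package importable. A self-contained toy reproduction of that failure mode (hypothetical module names, standard library only):

    import sys
    import types

    pkg = types.ModuleType("toypkg")
    sys.modules["toypkg"] = pkg

    def load_submodule_needing_core():
        # Stands in for a submodule whose body runs
        # `from toypkg import Core` while toypkg.__init__ is still executing.
        return pkg.Core

    try:
        load_submodule_needing_core()  # Core not bound yet -> AttributeError
    except AttributeError:
        pass

    pkg.Core = type("Core", (), {})  # bind the core names first...
    assert load_submodule_needing_core() is pkg.Core  # ...and the lookup works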
From a07aa4f7614adbbb98cb22866e4940c8beb1e514 Mon Sep 17 00:00:00 2001
From: Wuwei Lin
Date: Thu, 1 Aug 2024 16:17:15 -0700
Subject: [PATCH 11/12] Apply Eric's patch

---
 python/tvm/relax/transform/__init__.py               | 1 +
 src/relax/op/memory/view.cc                          | 6 ------
 tests/python/relax/test_op_view.py                   | 4 +++-
 .../relax/test_transform_static_plan_block_memory.py | 2 +-
 tests/python/relax/test_vm_builtin_lower.py          | 4 ++--
 5 files changed, 7 insertions(+), 10 deletions(-)

diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py
index 05e1376fb697..72406adca9ff 100644
--- a/python/tvm/relax/transform/__init__.py
+++ b/python/tvm/relax/transform/__init__.py
@@ -79,6 +79,7 @@
     TopologicalSort,
     UpdateParamStructInfo,
     UpdateVDevice,
+    VMBuiltinLower,
     VMShapeLower,
     dataflowblock_pass,
     function_pass,
diff --git a/src/relax/op/memory/view.cc b/src/relax/op/memory/view.cc
index 94d0b5b01fe0..21a72f6200b0 100644
--- a/src/relax/op/memory/view.cc
+++ b/src/relax/op/memory/view.cc
@@ -291,11 +291,6 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) {

 TVM_REGISTER_GLOBAL("tvm.relax.struct_info.infer_view_sinfo").set_body_typed(InferStructInfoView);

-Expr LegalizeView(const BlockBuilder& bb, const Call& call) {
-  // No-op. View is lowered during the LowerBuiltinView pass.
-  return call;
-}
-
 Expr LowerBuiltinView(const BlockBuilder& bb, const Call& call) {
   Expr data = call->args[0];
   Expr shape = call->args[1];
@@ -357,7 +352,6 @@ TVM_REGISTER_OP("relax.memory.view")
                   "The view's byte offset, relative to the input tensor's byte offset.")
     .set_attr<Bool>("RequiresArgumentShapes", Bool(false))
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoView)
-    .set_attr<FLegalize>("FLegalize", LegalizeView)
     .set_attr<Bool>("FPurity", Bool(true))
     .set_attr<FLowerBuiltin>("FLowerBuiltin", LowerBuiltinView);
diff --git a/tests/python/relax/test_op_view.py b/tests/python/relax/test_op_view.py
index c97e8d1c8880..0900e1be306b 100644
--- a/tests/python/relax/test_op_view.py
+++ b/tests/python/relax/test_op_view.py
@@ -453,6 +453,8 @@ def inferred_sinfo(A: R.Tensor, relative_byte_offset: R.Prim("int64")):


 def test_legalize_is_no_op():
+    """R.memory.view is not legalized until LowerRuntimeBuiltin"""
+
     @I.ir_module
     class Before:
         @R.function
@@ -462,7 +464,7 @@ def main(A: R.Tensor([4096], "float32")):

     Expected = Before

-    After = tvm.relax.transform.LowerRuntimeBuiltin()(Before)
+    After = tvm.relax.transform.LegalizeOps()(Before)
     tvm.ir.assert_structural_equal(Expected, After)

diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py b/tests/python/relax/test_transform_static_plan_block_memory.py
index 911a19d43592..f9e632d34897 100644
--- a/tests/python/relax/test_transform_static_plan_block_memory.py
+++ b/tests/python/relax/test_transform_static_plan_block_memory.py
@@ -185,7 +185,7 @@ def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,), dtype="float32
     tvm.ir.assert_structural_equal(mod, Expected)
     mod = relax.transform.LowerAllocTensor()(mod)
     mod = relax.transform.KillAfterLastUse()(mod)
-    mod = relax.transform.VMBuiltinLower()(mod)
+    mod = relax.transform.LowerRuntimeBuiltin()(mod)
     tvm.ir.assert_structural_equal(mod, ExpectedLowered)

diff --git a/tests/python/relax/test_vm_builtin_lower.py b/tests/python/relax/test_vm_builtin_lower.py
index df28db4d46d2..984f9f958ca2 100644
--- a/tests/python/relax/test_vm_builtin_lower.py
+++ b/tests/python/relax/test_vm_builtin_lower.py
@@ -57,7 +57,7 @@ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor:
             gv0 = alloc
         return gv0

-    After = relax.transform.VMBuiltinLower()(Before)
+    After = relax.transform.LowerRuntimeBuiltin()(Before)
     tvm.ir.assert_structural_equal(Expected, After)


@@ -79,7 +79,7 @@ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor:
         return gv0

     with pytest.raises(tvm.TVMError):
-        relax.transform.VMBuiltinLower()(Before)
+        relax.transform.LowerRuntimeBuiltin()(Before)


 if __name__ == "__main__":

From d5b1588d785f706f7d03f4974055cb389c779b84 Mon Sep 17 00:00:00 2001
From: Wuwei Lin
Date: Mon, 5 Aug 2024 13:58:58 -0700
Subject: [PATCH 12/12] fix

---
 python/tvm/relax/transform/__init__.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py
index 72406adca9ff..1ce864651cd9 100644
--- a/python/tvm/relax/transform/__init__.py
+++ b/python/tvm/relax/transform/__init__.py
@@ -87,6 +87,7 @@

 from .attach_external_modules import AttachExternModules
 from .fast_math import FastMathTransform
+from .fuse_transpose_matmul import FuseTransposeMatmul
 from .ipc_allreduce_rewrite import IPCAllReduceRewrite
 from .lazy_transform_params import LazyTransformParams
 from .lower_gpu_ipc_alloc_storage import LowerGPUIPCAllocStorage
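Taken together, the series leaves R.memory.view untouched by LegalizeOps and defers it to the runtime-builtin lowering stage. A usage sketch of the resulting pass ordering, mirroring the updated tests (pass names as exported above; a real compilation pipeline may interleave additional passes):

    import tvm
    from tvm import relax

    def lower_memory_ops(mod: tvm.IRModule) -> tvm.IRModule:
        mod = relax.transform.LegalizeOps()(mod)            # no-op for R.memory.view
        mod = relax.transform.StaticPlanBlockMemory()(mod)  # view-aware memory planning
        mod = relax.transform.LowerAllocTensor()(mod)
        mod = relax.transform.KillAfterLastUse()(mod)
        mod = relax.transform.LowerRuntimeBuiltin()(mod)    # view -> TVMArrayCreateView,
        return mod                                          # ensure_zero_offset -> VM builtin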