diff --git a/apps/relax_examples/mlp.py b/apps/relax_examples/mlp.py index a182c73dc0..fa69524a80 100644 --- a/apps/relax_examples/mlp.py +++ b/apps/relax_examples/mlp.py @@ -48,8 +48,8 @@ def build_mlp(data, weight): # build and create vm executor target = tvm.target.Target("llvm", host="llvm") - ex, lib = relax.vm.build(mod, target) - vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib) + ex = relax.vm.build(mod, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) # run the mlp model on relax vm data = tvm.nd.array(np.random.rand(16, 32).astype(np.float32)) diff --git a/apps/relax_examples/nn_module.py b/apps/relax_examples/nn_module.py index db79ce22eb..45405ae398 100644 --- a/apps/relax_examples/nn_module.py +++ b/apps/relax_examples/nn_module.py @@ -48,7 +48,7 @@ data = nn.Placeholder((n, input_size), name="data") output = model(data) params = [data] + model.parameters() - builder.emit_func_output(output, params=params) + builder.emit_func_output(output, params=params) # get and print the IRmodule being built mod = builder.get() @@ -56,8 +56,8 @@ # build the IRModule and create relax vm target = tvm.target.Target("llvm", host="llvm") - ex, lib = relax.vm.build(mod, target) - vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib) + ex = relax.vm.build(mod, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) # init parameters params = nn.init_params(mod) diff --git a/apps/relax_examples/resnet.py b/apps/relax_examples/resnet.py index 988948d33c..6b6434a69a 100644 --- a/apps/relax_examples/resnet.py +++ b/apps/relax_examples/resnet.py @@ -36,8 +36,8 @@ # build the IRModule and create relax vm target = tvm.target.Target("llvm", host="llvm") - ex, lib = relax.vm.build(relax_mod, target) - vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib) + ex = relax.vm.build(relax_mod, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) # init weights and run the model on relax vm shape = (1, 3, 224, 224) diff --git a/include/tvm/relax/vm/exec_builder.h b/include/tvm/relax/vm/exec_builder.h index e9a6ebdbea..1b222bf671 100644 --- a/include/tvm/relax/vm/exec_builder.h +++ b/include/tvm/relax/vm/exec_builder.h @@ -19,7 +19,6 @@ /*! * \file tvm/relax/vm/exec_builder.h - * \brief */ #ifndef TVM_RELAX_VM_EXEC_BUILDER_H_ #define TVM_RELAX_VM_EXEC_BUILDER_H_ @@ -48,8 +47,8 @@ class ExecBuilder; */ class ExecBuilderNode : public Object { public: - /*! \brief The mutable internal executable node. */ - ObjectPtr exec; // mutable + /*! \brief The mutable internal executable. */ + ObjectPtr exec; // mutable /*! * \brief To annotate the start of a vm function. * \param func The function name. @@ -81,6 +80,7 @@ class ExecBuilderNode : public Object { void EmitIf(vm::RegName cond, vm::Index false_offset); /*! * \brief Emit a constant value to the constant pool. + * \param obj The constant value to be emitted * \return The index that represents the constant. */ vm::Index EmitConstant(TVMRetValue obj); @@ -88,9 +88,9 @@ class ExecBuilderNode : public Object { * \brief Get the built executable. * \return The built executable. */ - vm::Executable Get(); + ObjectPtr Get(); /*! - * \brief Create a ExecBuilder. + * \brief Create an ExecBuilder. * \return The ExecBuilder. */ TVM_DLL static ExecBuilder Create(); @@ -102,6 +102,11 @@ class ExecBuilderNode : public Object { TVM_DECLARE_FINAL_OBJECT_INFO(ExecBuilderNode, Object); private: + /*! + * \brief A helper function to check if an executable is legal by checking if registers are used + * properly + */ + void CheckExecutable(); /*! * \brief Formalize the executable. 
*/ diff --git a/include/tvm/relax/vm/executable.h b/include/tvm/relax/vm/executable.h index ff3512ab32..c3b4476b86 100644 --- a/include/tvm/relax/vm/executable.h +++ b/include/tvm/relax/vm/executable.h @@ -38,8 +38,6 @@ namespace tvm { namespace runtime { namespace relax_vm { -class Executable; - /*! * \brief A representation of a Relax function in the VM. * @@ -63,50 +61,68 @@ struct VMFunction { * The executable contains information (e.g. data in different memory regions) * to run in a virtual machine. */ -class ExecutableNode : public Object { +class Executable : public runtime::ModuleNode { public: + /*! + * \brief Get a PackedFunc from the executable module. + * \param name the name of the function. + * \param sptr_to_self The shared_ptr that points to this module node. + * \return PackedFunc or nullptr when it is not available. + */ + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; /*! * \brief Print the detailed statistics of the given code, i.e. number of - * globls and constants, etc. + * globals and constants, etc. + * \return The statistics represented by a string. */ std::string Stats() const; /*! * \brief Get the i-th instruction from the executable. + * \param i The index of the instruction to be fetched. * \return The instruction. */ Instruction GetInstruction(Index i) const; /*! * \brief Set j-th byte data of i-th instruction to val. + * \param i The index of the instruction to be updated. + * \param j The index of the byte data of the instruction to be updated. + * \param val The value to be set */ void SetInstructionData(Index i, Index j, ExecWord val); /*! * \brief Print the instructions as text format. + * \return The text format of the instructions. */ String AsText() const; /*! * \brief Print the instructions as python program. + * \return The python program of the instructions, represented by a string. */ String AsPython() const; /*! * \brief Write the Executable to the binary stream in serialized form. * \param stream The binary stream to save the executable to. */ - void SaveToBinary(dmlc::Stream* stream); + void SaveToBinary(dmlc::Stream* stream) final; /*! * \brief Load Executable from the binary stream in serialized form. * \param stream The binary stream that load the executable from. + * \return The loaded executable, in the form of a `runtime::Module`. */ - static Executable LoadFromBinary(void* stream); + static Module LoadFromBinary(void* stream); /*! - * \brief Write the Executable to the provided path as a file contianing its serialized content. - * \param path The path to write the serialized data to. + * \brief Write the Executable to the provided path as a file containing its serialized content. + * \param file_name The name of the file to write the serialized data to. + * \param format The target format of the saved file. */ - void SaveToFile(const std::string& path); + void SaveToFile(const std::string& file_name, const std::string& format) final; /*! * \brief Load Executable from the file. - * \param file_name The file that load the executable from. + * \param file_name The path of the file that load the executable from. + * \return The loaded executable, in the form of a `runtime::Module`. */ - static Executable LoadFromFile(const std::string& file_name); + static Module LoadFromFile(const std::string& file_name); + /*! \brief The virtual machine's function table. */ std::vector global_funcs; /*! \brief A map from globals (as strings) to their index in the function map. 
*/ @@ -115,7 +131,8 @@ class ExecutableNode : public Object { std::vector constants; /*! \brief The name of packed functions. */ std::vector func_names; - /*! \brief A mapping from the packed function (as string) to the index that + /*! + * \brief A mapping from the packed function (as string) to the index that * corresponds to the position of the `packed_funcs` list in a `VirtualMachine` object. */ std::unordered_map func2idx; @@ -124,9 +141,9 @@ class ExecutableNode : public Object { /*! \brief The byte data of instruction. */ std::vector instr_data; - static constexpr const uint32_t _type_index = TypeIndex::kDynamic; - static constexpr const char* _type_key = "relax.Executable"; - TVM_DECLARE_FINAL_OBJECT_INFO(ExecutableNode, Object); + virtual ~Executable() {} + + const char* type_key() const final { return "relax.Executable"; } private: /*! @@ -171,12 +188,6 @@ class ExecutableNode : public Object { void LoadPackedFuncNames(dmlc::Stream* strm); }; -/*! \brief Reference to Executable. */ -class Executable : public ObjectRef { - public: - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Executable, ObjectRef, ExecutableNode); -}; - } // namespace relax_vm } // namespace runtime } // namespace tvm diff --git a/include/tvm/relax/vm/vm.h b/include/tvm/relax/vm/vm.h index dafaab722f..48b56bfd65 100644 --- a/include/tvm/relax/vm/vm.h +++ b/include/tvm/relax/vm/vm.h @@ -66,8 +66,8 @@ struct VMFrame { struct VMState { /*! \brief The memory allocators. */ std::vector allocators; - /*! \brief The loaded module. */ - runtime::Module mod_; + /*! \brief The kernel library. */ + Optional lib; }; /*! @@ -90,11 +90,10 @@ class VirtualMachine : public runtime::ModuleNode { */ void Init(const std::vector& devices, const std::vector& alloc_types); /*! - * \brief load the executable and module for the virtual machine. + * \brief Load the executable for the virtual machine. * \param exec The executable. - * \param mod The library module. */ - void Load(Executable exec, runtime::Module mod); + void LoadExecutable(ObjectPtr exec); /*! * \brief Get a PackedFunc from module. * @@ -117,17 +116,19 @@ class VirtualMachine : public runtime::ModuleNode { ~VirtualMachine() final {} const char* type_key() const final { return "relax.VirtualMachine"; } - /*! \brief The state of the virtual machine, which can be referred by - * instructions. - */ + + /*! \brief The state of the virtual machine, which can be referred by instructions. */ VMState state; protected: - /*! \brief Push a call frame on to the call stack. */ + /*! + * \brief Push a call frame on to the call stack. + * \param ret_pc The program counter to return to. + * \param vm_func The function to be pushed to the call stack. + */ void PushFrame(Index ret_pc, const VMFunction& vm_func); /*! * \brief Pop a frame off the call stack. - * \return The number of frames left. */ void PopFrame(); /*! @@ -139,7 +140,7 @@ class VirtualMachine : public runtime::ModuleNode { /*! * \brief Read a VM register. * \param reg The register to read from. - * \return The read object. + * \return The value of the register. */ inline RegType ReadRegister(RegName reg) const; /*! @@ -154,7 +155,7 @@ class VirtualMachine : public runtime::ModuleNode { private: /*! \brief The loaded executable. */ - Executable exec_; + ObjectPtr exec_; /*! \brief The current stack of call frames. */ std::vector frames_; /*! \brief The virtual machine PC. 
*/ diff --git a/python/tvm/relax/exec_builder.py b/python/tvm/relax/exec_builder.py index f4e7ed71c7..ec759560bb 100644 --- a/python/tvm/relax/exec_builder.py +++ b/python/tvm/relax/exec_builder.py @@ -124,4 +124,4 @@ def emit_if(self, cond, false_offset): def get(self) -> Executable: """return the executable""" - return _ffi_api.ExecBuilderGet(self) + return Executable(_ffi_api.ExecBuilderGet(self)) diff --git a/python/tvm/relax/vm.py b/python/tvm/relax/vm.py index 2a9b9323ec..cef7530264 100644 --- a/python/tvm/relax/vm.py +++ b/python/tvm/relax/vm.py @@ -17,41 +17,45 @@ # pylint: disable=invalid-name, redefined-builtin """The Relax virtual machine""" from typing import List, Optional, Union, Dict, Tuple + import tvm from tvm import relax from tvm.ir.module import IRModule -from tvm.runtime import Object, Device, Module, PackedFunc +from tvm.runtime import Device, Module, PackedFunc from tvm.tir.function import PrimFunc from . import _ffi_api from ..rpc.base import RPC_SESS_MASK -@tvm._ffi.register_object("relax.Executable") -class Executable(Object): +class Executable(object): """The executable object emitted by the VM compiler or the ExecBuilder.""" - def __init__(self): - self.__init_handle_by_constructor__(_ffi_api.Executable) + def __init__(self, mod: Module): + self.mod = mod + self._stats = self.mod["stats"] + self._save_to_file = self.mod["save_to_file"] + self._as_text = self.mod["as_text"] + self._as_python = self.mod["as_python"] def stats(self) -> str: """print the detailed statistics of the executable.""" - return _ffi_api.ExecutableStats(self) + return self._stats() - def save_to_file(self, file_name: str) -> None: + def save_to_file(self, path: str) -> None: """serialize and write the executable to a file.""" - _ffi_api.ExecutableSaveToFile(self, file_name) + self._save_to_file(path) - def astext(self) -> str: + def as_text(self) -> str: """print the instructions as text format.""" - return _ffi_api.ExecutableAsText(self) + return self._as_text() - def aspython(self) -> str: + def as_python(self) -> str: """print the instructions as python program.""" - return _ffi_api.ExecutableAsPython(self) + return self._as_python() -def load_exec_from_file(file_name: str) -> Executable: - return _ffi_api.ExecutableLoadFromFile(file_name) +def load_exec_from_file(path: str) -> Executable: + return Executable(_ffi_api.ExecutableLoadFromFile(path)) class VirtualMachine(object): @@ -65,7 +69,6 @@ def __init__( exec: Executable, device: Union[Device, List[Device]], memory_cfg: Optional[Union[str, Dict[Device, str]]] = None, - mod: Optional[Module] = None, ) -> None: """ Construct a VirtualMachine wrapper object. @@ -75,26 +78,18 @@ def __init__( exec: Executable The VM executable. - device : tvm.runtime.Device or List[tvm.runtime.Device] + device : Union[Device, List[Device]] The device to deploy the module. - memory_cfg : str or Dict[tvm.runtime.Device, str], optional + memory_cfg : Optional[Union[str, Dict[Device, str]]] Config the type of memory allocator. The allocator type can be ["naive", "pooled"]. If memory_cfg is None, all devices will use pooled allocator by default. If memory_cfg is string, all devices will use the specified allocator type. If memory_cfg is a dict, each device uses the allocator type specified in the dict, or pooled allocator if not specified in the dict. - - mod : tvm.runtime.Module, optional - Optional runtime module to load to the VM. - - Returns - ------- - vm: VirtualMachine - A VM wrapper object. 
""" - self.module = _ffi_api.VirtualMachine(exec, mod) + self.module = exec.mod["vm_load_executable"]() self._setup_device(device, memory_cfg) def _setup_device(self, dev: Device, memory_cfg: Union[str, Dict[Device, str]]) -> None: @@ -137,7 +132,7 @@ def __getitem__(self, key: str) -> PackedFunc: return self.module[key] -def build(mod: tvm.IRModule, target: tvm.target.Target) -> Tuple[Executable, Module]: +def build(mod: tvm.IRModule, target: tvm.target.Target) -> Executable: """ Build an IRModule to VM executable. @@ -154,14 +149,12 @@ def build(mod: tvm.IRModule, target: tvm.target.Target) -> Tuple[Executable, Mod to setup the dimensions and parameters correctly. host is used to specify the host side codegen target. By default, llvm is used if it is enabled, - otherwise a stackvm intepreter is used. + otherwise a stackvm interpreter is used. Returns ------- - ex: tvm.relax.vm.Exectuable + ex: tvm.relax.vm.Executable An executable that can be loaded by virtual machine. - lib: tvm.runtime.Module - A runtime module that contains generated code. Example ------- @@ -175,7 +168,7 @@ def foo(x: Tensor[(3, 4), "float32"], y: Tensor[(3, 4), "float32"]): mod = InputModule target = tvm.target.Target("llvm", host="llvm") - ex, lib = relax.vm.build(mod, target) + ex = relax.vm.build(mod, target) """ passes = [relax.transform.ToNonDataflow()] passes.append(relax.transform.CallTIRRewrite()) @@ -187,8 +180,7 @@ def foo(x: Tensor[(3, 4), "float32"], y: Tensor[(3, 4), "float32"]): # split primfunc and relax function rx_mod, tir_mod = _split_tir_relax(new_mod) lib = tvm.build(tir_mod, target=target) - ex = _ffi_api.VMCodeGen(rx_mod) - return ex, lib + return Executable(_ffi_api.VMCodeGen(rx_mod, lib)) def _split_tir_relax(mod: tvm.IRModule) -> Tuple[tvm.IRModule, tvm.IRModule]: diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc index ccb402df57..081c0f8450 100644 --- a/src/relax/backend/vm/codegen_vm.cc +++ b/src/relax/backend/vm/codegen_vm.cc @@ -126,7 +126,7 @@ class CodeGenVM : public ExprFunctor { Instruction::Arg VisitExpr_(const IfNode* op) { const If& ife = GetRef(op); // Get the executable from exec_builder - Executable exec_ = builder_->Get(); + ObjectPtr exec_ = builder_->Get(); // Visit the condition expression Instruction::Arg cond_reg = this->VisitExpr(ife->cond); @@ -393,13 +393,23 @@ void VMCodeGen::CodeGen(IRModule rx_mod) { } } -Executable VMCodeGen::GetExec() { return builder_->Get(); } +ObjectPtr VMCodeGen::GetExec() { return builder_->Get(); } -Executable CodeGen(IRModule mod) { - auto codegen = make_object(); - codegen->CodeGen(mod); - Executable exec = codegen->GetExec(); - return exec; +/*! + * \brief Create the Relax VM executable from an IRModule of Relax function(s) and, possibly, a + * kernel library. + * \param mod The IRModule containing Relax function(s). + * \param lib The kernel library. + * \return The constructed Relax VM executable. + */ +Module CodeGen(IRModule mod, Optional lib) { + VMCodeGen codegen; + codegen.CodeGen(mod); + ObjectPtr executable = codegen.GetExec(); + if (lib.defined()) { + executable->Import(lib.value()); + } + return Module(executable); } TVM_REGISTER_GLOBAL("relax.VMCodeGen").set_body_typed(CodeGen); diff --git a/src/relax/backend/vm/codegen_vm.h b/src/relax/backend/vm/codegen_vm.h index 1019f70d94..e2a7a046d3 100644 --- a/src/relax/backend/vm/codegen_vm.h +++ b/src/relax/backend/vm/codegen_vm.h @@ -51,7 +51,7 @@ class VMCodeGen : public Object { * \brief Get the compiled executable. 
* \return The compiled executable. */ - Executable GetExec(); + ObjectPtr GetExec(); static constexpr const uint32_t _type_index = TypeIndex::kDynamic; static constexpr const char* _type_key = "relax.VMCodeGen"; diff --git a/src/relax/vm/builtin.cc b/src/relax/vm/builtin.cc index 2d47edf5c8..68f0e6d1b4 100644 --- a/src/relax/vm/builtin.cc +++ b/src/relax/vm/builtin.cc @@ -18,7 +18,6 @@ */ /*! * \file src/relax/vm/builtin.cc - * \brief */ #include #include @@ -117,13 +116,16 @@ TVM_REGISTER_GLOBAL("vm.binary_broadcast_shape_infer") TVM_REGISTER_GLOBAL("vm.call_tir_dyn").set_body([](TVMArgs args, TVMRetValue* rv) { void* vm_state_ptr = args[0]; VMState* vm_state = static_cast(vm_state_ptr); - runtime::Module mod_ = vm_state->mod_; - runtime::String func_name = args[1]; - PackedFunc func = mod_->GetFunction(func_name, true); - if (func == nullptr) { - func = *(mod_->GetFuncFromEnv(func_name)); + PackedFunc func{nullptr}; + if (vm_state->lib.defined()) { + func = vm_state->lib.value()->GetFunction(func_name, true); + } + if (!func.defined()) { + const PackedFunc* p_func = Registry::Get(func_name); + CHECK(p_func != nullptr); + func = *(p_func); } ShapeTuple to_unpack = args[args.size() - 1]; @@ -144,12 +146,12 @@ TVM_REGISTER_GLOBAL("vm.call_tir_dyn").set_body([](TVMArgs args, TVMRetValue* rv }); TVM_REGISTER_GLOBAL("vm.runtime.TupleGetItem") -.set_body_typed([](runtime::ADT adt, ShapeTuple index) { - ICHECK_EQ(index.size(), 1); - int idx = index[0]; - ICHECK_LT(idx, adt.size()); - return adt[idx]; -}); + .set_body_typed([](runtime::ADT adt, ShapeTuple index) { + ICHECK_EQ(index.size(), 1); + int idx = index[0]; + ICHECK_LT(idx, adt.size()); + return adt[idx]; + }); } // namespace relax_vm } // namespace runtime diff --git a/src/relax/vm/exec_builder.cc b/src/relax/vm/exec_builder.cc index 728884cf00..e35eb1a8aa 100644 --- a/src/relax/vm/exec_builder.cc +++ b/src/relax/vm/exec_builder.cc @@ -33,7 +33,7 @@ TVM_REGISTER_NODE_TYPE(ExecBuilderNode); ExecBuilder ExecBuilderNode::Create() { ExecBuilder ret(make_object()); - ret->exec = make_object(); + ret->exec = make_object(); return ret; } @@ -91,8 +91,7 @@ void ExecBuilderNode::EmitIf(vm::RegName cond, vm::Index false_offset) { exec->instr_data.push_back(false_offset); } -// helper function to check if an executable is legal by checking if registers are used properly -bool CheckExecutable(Executable exec) { +void ExecBuilderNode::CheckExecutable() { for (auto it = exec->global_funcs.cbegin(); it != exec->global_funcs.cend(); ++it) { Index num_inputs = it->num_args; std::unordered_set dst_registers; @@ -111,10 +110,9 @@ bool CheckExecutable(Executable exec) { if (instr.args[i].kind() == Instruction::kRegister && instr.args[i].value() >= num_inputs && dst_registers.find(instr.args[i].value()) == dst_registers.end()) { - LOG(ERROR) << "register r(" << instr.args[i].value() << ") in VM function \"" + LOG(FATAL) << "register r(" << instr.args[i].value() << ") in VM function \"" << it->name << "\" is used as input while the number of inputs is only " << num_inputs << ".\n"; - return false; } arg_registers.emplace(instr.args[i].value()); } @@ -148,13 +146,12 @@ bool CheckExecutable(Executable exec) { } } } - return true; } -Executable ExecBuilderNode::Get() { - CheckExecutable(Executable(this->exec)); +ObjectPtr ExecBuilderNode::Get() { + this->CheckExecutable(); this->Formalize(); - return Executable(this->exec); + return this->exec; } void ExecBuilderNode::Formalize() { @@ -249,7 +246,10 @@ 
TVM_REGISTER_GLOBAL("relax.ExecBuilderC").set_body_typed([](ExecBuilder builder, return Instruction::Arg(Instruction::kConstIdx, value).data; }); -TVM_REGISTER_GLOBAL("relax.ExecBuilderGet").set_body_method(&ExecBuilderNode::Get); +TVM_REGISTER_GLOBAL("relax.ExecBuilderGet").set_body_typed([](ExecBuilder builder) { + ObjectPtr p_exec = builder->Get(); + return runtime::Module(p_exec); +}); } // namespace relax } // namespace tvm diff --git a/src/relax/vm/executable.cc b/src/relax/vm/executable.cc index 8adbe521eb..7aeb605d2f 100644 --- a/src/relax/vm/executable.cc +++ b/src/relax/vm/executable.cc @@ -24,6 +24,7 @@ #include #include +#include #include #include @@ -50,9 +51,33 @@ enum ConstantType : int { ICHECK(val) << "Invalid VM file format in the " << section << " section." \ << "\n"; -TVM_REGISTER_OBJECT_TYPE(ExecutableNode); +PackedFunc Executable::GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { + if (name == "stats") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->Stats(); }); + } else if (name == "save_to_file") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + CHECK_EQ(args.size(), 1); + std::string path = args[0]; + this->SaveToFile(path, ""); + }); + } else if (name == "as_text") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->AsText(); }); + } else if (name == "as_python") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->AsPython(); }); + } else if (name == "vm_load_executable") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + ObjectPtr vm = make_object(); + ICHECK(sptr_to_self.get() == this); + vm->LoadExecutable(GetObjectPtr(this)); + *rv = Module(vm); + }); + } + return nullptr; +} -std::string ExecutableNode::Stats() const { +std::string Executable::Stats() const { std::ostringstream oss; oss << "Relax VM executable statistics:" << std::endl; @@ -123,12 +148,12 @@ std::string ExecutableNode::Stats() const { return oss.str(); } -void ExecutableNode::SetInstructionData(Index i, Index j, ExecWord val) { +void Executable::SetInstructionData(Index i, Index j, ExecWord val) { Index instr_idx = instr_offset[i]; instr_data[instr_idx + j] = val; } -Instruction ExecutableNode::GetInstruction(Index i) const { +Instruction Executable::GetInstruction(Index i) const { size_t offset = instr_offset[i]; Opcode op = static_cast(instr_data[offset]); switch (op) { @@ -178,7 +203,7 @@ void LoadHeader(dmlc::Stream* strm) { STREAM_CHECK(version == TVM_VERSION, "version"); } -void ExecutableNode::SaveToBinary(dmlc::Stream* stream) { +void Executable::SaveToBinary(dmlc::Stream* stream) { std::string code; // Initialize the stream object. 
dmlc::MemoryStringStream strm(&code); @@ -201,20 +226,20 @@ void ExecutableNode::SaveToBinary(dmlc::Stream* stream) { stream->Write(code); } -void ExecutableNode::SaveToFile(const std::string& path) { +void Executable::SaveToFile(const std::string& file_name, const std::string& format) { std::string data; dmlc::MemoryStringStream writer(&data); dmlc::SeekStream* strm = &writer; - ExecutableNode::SaveToBinary(strm); - runtime::SaveBinaryToFile(path, data); + Executable::SaveToBinary(strm); + runtime::SaveBinaryToFile(file_name, data); } -Executable ExecutableNode::LoadFromBinary(void* stream) { +Module Executable::LoadFromBinary(void* stream) { std::string code; static_cast(stream)->Read(&code); dmlc::MemoryStringStream strm(&code); - auto exec = make_object(); + ObjectPtr exec = make_object(); // Load header. LoadHeader(&strm); @@ -231,18 +256,23 @@ Executable ExecutableNode::LoadFromBinary(void* stream) { // Code section. exec->LoadCodeSection(&strm); - return Executable(exec); + return Module(exec); } -Executable ExecutableNode::LoadFromFile(const std::string& file_name) { +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_relax.Executable") + .set_body_typed(Executable::LoadFromBinary); + +Module Executable::LoadFromFile(const std::string& file_name) { std::string data; runtime::LoadBinaryFromFile(file_name, &data); dmlc::MemoryStringStream reader(&data); dmlc::Stream* strm = &reader; - auto exec = ExecutableNode::LoadFromBinary(reinterpret_cast(strm)); - return exec; + return Executable::LoadFromBinary(reinterpret_cast(strm)); } +TVM_REGISTER_GLOBAL("runtime.module.loadfile_relax.Executable") + .set_body_typed(Executable::LoadFromFile); + void SerializeVMFunc(const VMFunction& func, dmlc::Stream* strm) { strm->Write(func.name); strm->Write(func.start_instr); @@ -259,14 +289,14 @@ VMFunction DeserializeVMFunc(dmlc::Stream* strm) { return func; } -void ExecutableNode::SaveGlobalSection(dmlc::Stream* strm) { +void Executable::SaveGlobalSection(dmlc::Stream* strm) { strm->Write(static_cast(this->global_funcs.size())); for (const auto& func : this->global_funcs) { SerializeVMFunc(func, strm); } } -void ExecutableNode::SaveConstantSection(dmlc::Stream* strm) { +void Executable::SaveConstantSection(dmlc::Stream* strm) { strm->Write(static_cast(this->constants.size())); for (const auto& it : this->constants) { if (it.IsObjectRef()) { @@ -298,14 +328,14 @@ void ExecutableNode::SaveConstantSection(dmlc::Stream* strm) { } } -void ExecutableNode::SavePackedFuncNames(dmlc::Stream* strm) { strm->Write(func_names); } +void Executable::SavePackedFuncNames(dmlc::Stream* strm) { strm->Write(func_names); } -void ExecutableNode::SaveCodeSection(dmlc::Stream* strm) { +void Executable::SaveCodeSection(dmlc::Stream* strm) { strm->Write(instr_offset); strm->Write(instr_data); } -void ExecutableNode::LoadGlobalSection(dmlc::Stream* strm) { +void Executable::LoadGlobalSection(dmlc::Stream* strm) { uint64_t sz; STREAM_CHECK(strm->Read(&sz, sizeof(sz)), "constant"); size_t size = static_cast(sz); @@ -318,7 +348,7 @@ void ExecutableNode::LoadGlobalSection(dmlc::Stream* strm) { } } -void ExecutableNode::LoadConstantSection(dmlc::Stream* strm) { +void Executable::LoadConstantSection(dmlc::Stream* strm) { uint64_t sz; // Load the number of constants. 
STREAM_CHECK(strm->Read(&sz, sizeof(sz)), "constant"); @@ -367,14 +397,14 @@ void ExecutableNode::LoadConstantSection(dmlc::Stream* strm) { } } -void ExecutableNode::LoadPackedFuncNames(dmlc::Stream* strm) { +void Executable::LoadPackedFuncNames(dmlc::Stream* strm) { STREAM_CHECK(strm->Read(&(this->func_names)), "packed func names"); for (size_t i = 0; i < func_names.size(); ++i) { this->func2idx[func_names[i]] = i; } } -void ExecutableNode::LoadCodeSection(dmlc::Stream* strm) { +void Executable::LoadCodeSection(dmlc::Stream* strm) { STREAM_CHECK(strm->Read(&(this->instr_offset)), "instr offset"); STREAM_CHECK(strm->Read(&(this->instr_data)), "instr data"); } @@ -435,7 +465,7 @@ std::string InstrArgToPyStr(Instruction::Arg arg) { } } -String ExecutableNode::AsText() const { +String Executable::AsText() const { // print the text format std::ostringstream os; for (size_t fidx = 0; fidx < this->global_funcs.size(); ++fidx) { @@ -482,7 +512,7 @@ String ExecutableNode::AsText() const { return String(os.str()); } -String ExecutableNode::AsPython() const { +String Executable::AsPython() const { // print the python format std::ostringstream os; os << "ib = rx.Builder()\n"; @@ -526,19 +556,7 @@ String ExecutableNode::AsPython() const { return String(os.str()); } -TVM_REGISTER_GLOBAL("relax.Executable").set_body_typed([]() { return Executable(); }); - -TVM_REGISTER_GLOBAL("relax.ExecutableStats").set_body_method(&ExecutableNode::Stats); - -TVM_REGISTER_GLOBAL("relax.ExecutableAsText").set_body_method(&ExecutableNode::AsText); - -TVM_REGISTER_GLOBAL("relax.ExecutableAsPython") - .set_body_method(&ExecutableNode::AsPython); - -TVM_REGISTER_GLOBAL("relax.ExecutableSaveToFile") - .set_body_method(&ExecutableNode::SaveToFile); - -TVM_REGISTER_GLOBAL("relax.ExecutableLoadFromFile").set_body_typed(ExecutableNode::LoadFromFile); +TVM_REGISTER_GLOBAL("relax.ExecutableLoadFromFile").set_body_typed(Executable::LoadFromFile); } // namespace relax_vm } // namespace runtime diff --git a/src/relax/vm/vm.cc b/src/relax/vm/vm.cc index 5c59274db2..925a9a4754 100644 --- a/src/relax/vm/vm.cc +++ b/src/relax/vm/vm.cc @@ -28,15 +28,6 @@ namespace tvm { namespace runtime { namespace relax_vm { -class DummyModule : public runtime::ModuleNode { - public: - PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { - return nullptr; - } - - const char* type_key() const final { return "relax.DummyModule"; } -}; - PackedFunc VirtualMachine::GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { const auto& m = exec_->global_map; @@ -55,9 +46,10 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, } } -void VirtualMachine::Load(Executable exec, runtime::Module mod) { +void VirtualMachine::LoadExecutable(ObjectPtr exec) { this->exec_ = exec; - this->state.mod_ = mod; + CHECK_LE(exec_->imports().size(), 1); + this->state.lib = exec_->imports().empty() ? 
Optional(NullOpt) : exec_->imports()[0]; } RegType VirtualMachine::Invoke(Index gf_idx, const std::vector& args) { @@ -100,9 +92,15 @@ void VirtualMachine::RunLoop() { case Opcode::Call: { std::string func_name = exec_->func_names[instr.func_idx]; DLOG(INFO) << "\n pc = " << pc_ << ", execute: " << func_name; - PackedFunc func = state.mod_->GetFunction(func_name, true); - if (func == nullptr) { - func = *(state.mod_->GetFuncFromEnv(func_name)); + + PackedFunc func{nullptr}; + if (state.lib.defined()) { + func = state.lib.value()->GetFunction(func_name, true); + } + if (!func.defined()) { + const PackedFunc* p_func = Registry::Get(func_name); + CHECK(p_func != nullptr); + func = *(p_func); } std::vector values(instr.num_args); @@ -194,23 +192,6 @@ inline RegType VirtualMachine::ReadRegister(Index r) const { return frames_.back().register_file[r]; } -runtime::Module CreateVirtualMachine(Executable exec, Optional mod) { - runtime::Module mod_; - if (!mod) { - mod_ = runtime::Module(make_object()); - } else { - mod_ = mod.value(); - } - auto vm = make_object(); - vm->Load(exec, mod_); - return runtime::Module(vm); -} - -TVM_REGISTER_GLOBAL("relax.VirtualMachine") - .set_body_typed([](Executable exec, Optional mod) { - return CreateVirtualMachine(exec, mod); - }); - // initialize the VirtualMachine, takes variable-length arguments // first argument is a runtime::Module, followed by one or more device_type, device_id, // and the AllocatorType associated with the device. diff --git a/tests/python/relax/test_autotir_integration.py b/tests/python/relax/test_autotir_integration.py index 49a1ae1615..c75f126d66 100644 --- a/tests/python/relax/test_autotir_integration.py +++ b/tests/python/relax/test_autotir_integration.py @@ -184,14 +184,14 @@ def main(x: Tensor[(32, 32), "float32"], w: Tensor[(32, 32), "float32"]) -> Tens ) with transform.PassContext(opt_level=3): - ex0, lib0 = relax.vm.build(mod, target) + ex0 = relax.vm.build(mod, target) with transform.PassContext(opt_level=3): mod = relax.transform.MetaScheduleApplyHistoryBest(database, target)(mod) - ex1, lib1 = relax.vm.build(mod, target) + ex1 = relax.vm.build(mod, target) - vm0 = relax.VirtualMachine(ex0, dev, mod=lib0) - vm1 = relax.VirtualMachine(ex1, dev, mod=lib1) + vm0 = relax.VirtualMachine(ex0, dev) + vm1 = relax.VirtualMachine(ex1, dev) data = tvm.nd.array(np.random.rand(32, 32).astype(np.float32)) weight = tvm.nd.array(np.random.rand(32, 32).astype(np.float32)) diff --git a/tests/python/relax/test_vm.py b/tests/python/relax/test_vm.py index 8f1e3178da..e1f3adbacb 100644 --- a/tests/python/relax/test_vm.py +++ b/tests/python/relax/test_vm.py @@ -15,18 +15,16 @@ # specific language governing permissions and limitations # under the License. 
from __future__ import annotations # must import to defer parsing of annotations -import pytest import os + import numpy as np +import pytest import tvm -from tvm import relax, tir, te -from tvm.runtime import container -import numpy as np - -from tvm.ir.base import assert_structural_equal import tvm.script -from tvm.script import tir as T, relax as R +import tvm.testing +from tvm import relax, te, tir, TVMError from tvm.relax.testing import nn +from tvm.script import relax as R, tir as T @tvm.register_func("test.vm.move") @@ -74,7 +72,7 @@ def test_vm_execute(): ) ) add_res = vm["func0"](a, b) - np.testing.assert_allclose(add_res.numpy(), a.numpy() + b.numpy()) + tvm.testing.assert_allclose(add_res.numpy(), a.numpy() + b.numpy(), rtol=1e-7, atol=1e-7) def test_vm_multiple_func(): @@ -99,8 +97,8 @@ def test_vm_multiple_func(): ) mul_res = vm["func1"](a, b) add_res = vm["func0"](a, b) - np.testing.assert_allclose(add_res.numpy(), a.numpy() + b.numpy()) - np.testing.assert_allclose(mul_res.numpy(), a.numpy() * b.numpy()) + tvm.testing.assert_allclose(add_res.numpy(), a.numpy() + b.numpy(), rtol=1e-7, atol=1e-7) + tvm.testing.assert_allclose(mul_res.numpy(), a.numpy() * b.numpy(), rtol=1e-7, atol=1e-7) def test_vm_serialize(): @@ -119,10 +117,32 @@ def test_vm_serialize(): exec0 = ib.get() exec0.save_to_file("exec.tmp") exec1 = relax.load_exec_from_file("exec.tmp") - assert exec0.astext() == exec1.astext() + assert exec0.as_text() == exec1.as_text() os.remove("exec.tmp") +def test_vm_exec_serialize_export_library(): + @tvm.script.ir_module + class TestVMMove: + @R.function + def foo(x: Tensor[(3, 4), "float32"]): + z = R.call_packed("vm.builtin.copy", x) + return z + + mod = TestVMMove + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target) + + from tvm.contrib import utils + + temp_dir = utils.tempdir() + path_exec = temp_dir.relpath("exec.so") + ex.mod.export_library(path_exec) + + loaded_exec = relax.vm.Executable(tvm.runtime.load_module(path_exec)) + assert ex.as_text() == loaded_exec.as_text() + + def test_vm_constant_serialize(): dtype = tvm.DataType("float32") shape = (4, 6) @@ -140,22 +160,20 @@ def test_vm_constant_serialize(): exec0 = ib.get() exec0.save_to_file("exec.tmp") exec1 = relax.load_exec_from_file("exec.tmp") - assert exec0.astext() == exec1.astext() + assert exec0.as_text() == exec1.as_text() vm = relax.VirtualMachine(exec0, tvm.cpu()) res = vm["main"](inp) - np.testing.assert_allclose(inp.numpy(), res.numpy()) + tvm.testing.assert_allclose(res.numpy(), inp.numpy(), rtol=1e-7, atol=1e-7) os.remove("exec.tmp") def test_vm_checker(): ib = relax.ExecBuilder() - try: + with pytest.raises(TVMError): with ib.function("func0", num_inputs=2): ib.emit_call("test.vm.add", args=[ib.r(0), ib.r(2)], dst=ib.r(2)) ib.emit_ret(ib.r(2)) ib.get() - except ValueError as ex: - assert True def test_vm_formalize(): @@ -171,7 +189,7 @@ def test_vm_formalize(): ib1.emit_ret(ib1.r(3)) exec0 = ib0.get() exec1 = ib1.get() - assert exec0.astext() == exec1.astext() + assert exec0.as_text() == exec1.as_text() @tvm.register_func("test.vm.add_scalar") @@ -247,11 +265,11 @@ def foo(x: Tensor[(3, 4), "float32"]): mod = TestVMMove target = tvm.target.Target("llvm", host="llvm") - ex, lib = relax.vm.build(mod, target) + ex = relax.vm.build(mod, target) inp = tvm.nd.array(np.random.rand(3, 4).astype(np.float32)) - vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib) + vm = relax.VirtualMachine(ex, tvm.cpu()) res = vm["foo"](inp) - np.testing.assert_allclose(res.numpy(), inp.numpy()) 
+ tvm.testing.assert_allclose(res.numpy(), inp.numpy(), rtol=1e-7, atol=1e-7) def test_vm_goto(): @@ -274,7 +292,7 @@ def test_vm_goto(): ) ) res = vm["main"](a, b) - np.testing.assert_allclose(res.numpy(), a.numpy() + b.numpy()) + tvm.testing.assert_allclose(res.numpy(), a.numpy() + b.numpy(), rtol=1e-7, atol=1e-7) def test_vm_if(): @@ -298,9 +316,9 @@ def test_vm_if(): ) ) res = vm["main"](False, a, b) - np.testing.assert_allclose(res.numpy(), a.numpy() * b.numpy()) + tvm.testing.assert_allclose(res.numpy(), a.numpy() * b.numpy(), rtol=1e-7, atol=1e-7) res = vm["main"](1, a, b) - np.testing.assert_allclose(res.numpy(), a.numpy() + b.numpy()) + tvm.testing.assert_allclose(res.numpy(), a.numpy() + b.numpy(), rtol=1e-7, atol=1e-7) def test_vm_compile_if(): @@ -316,13 +334,13 @@ def ife(cond: Tensor[(), "bool"], x: Tensor[(3, 4), "float32"]): mod = TestVMCompileIf target = tvm.target.Target("llvm", host="llvm") - ex, lib = relax.vm.build(mod, target) - vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib) + ex = relax.vm.build(mod, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) inp = tvm.nd.array(np.random.rand(3, 4)) res = vm["ife"](True, inp) - np.testing.assert_allclose(res.numpy(), inp.numpy() + inp.numpy()) + tvm.testing.assert_allclose(res.numpy(), inp.numpy() + inp.numpy(), rtol=1e-7, atol=1e-7) res = vm["ife"](0, inp) - np.testing.assert_allclose(res.numpy(), inp.numpy() * inp.numpy()) + tvm.testing.assert_allclose(res.numpy(), inp.numpy() * inp.numpy(), rtol=1e-7, atol=1e-7) def test_vm_compile_stage0(): @@ -335,12 +353,12 @@ def foo(x: Tensor[(3, 4), "float32"], y: Tensor[(3, 4), "float32"]): mod = TestVMCompileStage0 target = tvm.target.Target("llvm", host="llvm") - ex, lib = relax.vm.build(mod, target) + ex = relax.vm.build(mod, target) inp1 = tvm.nd.array(np.random.rand(3, 4).astype(np.float32)) inp2 = tvm.nd.array(np.random.rand(3, 4).astype(np.float32)) - vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib) + vm = relax.VirtualMachine(ex, tvm.cpu()) vm["foo"](inp1, inp2) - np.testing.assert_allclose(inp2.numpy(), inp1.numpy()) + tvm.testing.assert_allclose(inp2.numpy(), inp1.numpy(), rtol=1e-7, atol=1e-7) def test_vm_compile_stage1(): @@ -375,8 +393,8 @@ def foo(x: Tensor[_, "float32"]) -> Shape: mod = TestVMCompileStage1 target = tvm.target.Target("llvm", host="llvm") - ex, lib = relax.vm.build(mod, target) - vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib) + ex = relax.vm.build(mod, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) shape = (32, 16) arr = tvm.nd.array(np.random.rand(*shape)) @@ -395,8 +413,8 @@ def foo(x: Tensor[_, "float32"]) -> Shape: mod = TestVMCompileStage2 target = tvm.target.Target("llvm", host="llvm") - ex, lib = relax.vm.build(mod, target) - vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib) + ex = relax.vm.build(mod, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) shape = (32, 16) arr = tvm.nd.array(np.random.rand(*shape)) @@ -417,13 +435,13 @@ def foo(x: Tensor[(32, 16), "float32"]) -> Tensor: mod = TestVMCompileStage3 target = tvm.target.Target("llvm", host="llvm") - ex, lib = relax.vm.build(mod, target) - vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib) + ex = relax.vm.build(mod, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) shape = (32, 16) inp = tvm.nd.array(np.random.rand(*shape).astype(np.float32)) res = vm["foo"](inp) - np.testing.assert_allclose(inp.numpy(), res.numpy()) + tvm.testing.assert_allclose(res.numpy(), inp.numpy(), rtol=1e-7, atol=1e-7) def test_vm_compile_e2e(): @@ -440,13 +458,13 @@ def foo(x: Tensor[_, "float32"]) -> 
Tensor: mod = TestVMCompileE2E target = tvm.target.Target("llvm", host="llvm") - ex, lib = relax.vm.build(mod, target) - vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib) + ex = relax.vm.build(mod, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) shape = (32, 16) inp = tvm.nd.array(np.random.rand(*shape).astype(np.float32)) res = vm["foo"](inp) - np.testing.assert_allclose(np.tile(inp.numpy(), (1, 2)), res.numpy()) + tvm.testing.assert_allclose(res.numpy(), np.tile(inp.numpy(), (1, 2)), rtol=1e-7, atol=1e-7) def test_vm_compile_e2e_func_param_with_shape(): @@ -477,14 +495,14 @@ def func(x: Tensor[(m, n), "float32"], w: Tensor[(n, k), "float32"]) -> Tensor: mod = TestVMCompileE2E2 target = tvm.target.Target("llvm", host="llvm") - ex, lib = relax.vm.build(mod, target) - vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib) + ex = relax.vm.build(mod, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) data = tvm.nd.array(np.random.rand(32, 16).astype(np.float32)) weight = tvm.nd.array(np.random.rand(16, 32).astype(np.float32)) res = vm["func"](data, weight) expected = np.dot(data.numpy(), weight.numpy()) - np.testing.assert_allclose(expected, res.numpy(), rtol=1e-4, atol=1e-4) + tvm.testing.assert_allclose(res.numpy(), expected, rtol=1e-6, atol=1e-6) def test_vm_emit_te_extern(): @@ -504,14 +522,14 @@ def test_vm_emit_te_extern(): mod = bb.get() target = tvm.target.Target("llvm", host="llvm") - ex, lib = relax.vm.build(mod, target) - vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib) + ex = relax.vm.build(mod, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) data = tvm.nd.array(np.random.rand(16, 32).astype(np.float32)) weight = tvm.nd.array(np.random.rand(32, 16).astype(np.float32)) res = vm["rx_cblas_matmul"](data, weight) expected = np.dot(data.numpy(), weight.numpy()) - np.testing.assert_allclose(expected, res.numpy(), rtol=1e-4, atol=1e-4) + tvm.testing.assert_allclose(res.numpy(), expected, rtol=1e-7, atol=1e-7) def test_vm_emit_te_concat(): @@ -533,9 +551,9 @@ def te_func(A, B): mod = bb.get() target = tvm.target.Target("llvm", host="llvm") - ex, lib = relax.vm.build(mod, target) + ex = relax.vm.build(mod, target) - vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib) + vm = relax.VirtualMachine(ex, tvm.cpu()) inp = tvm.nd.array( np.random.rand( 1, @@ -547,8 +565,9 @@ def te_func(A, B): ).astype(np.float32) ) res = vm["rx_func"](inp, inp2) - - np.testing.assert_allclose(res.numpy(), np.append(inp.numpy(), inp2.numpy())) + tvm.testing.assert_allclose( + res.numpy(), np.append(inp.numpy(), inp2.numpy()), rtol=1e-7, atol=1e-7 + ) def test_vm_emit_te_dtype_change(): @@ -572,9 +591,9 @@ def te_func(A): assert new_mod["rx_func"].body.blocks[0].bindings[0].value.attrs.dtype == "int16" target = tvm.target.Target("llvm", host="llvm") - ex, lib = relax.vm.build(mod, target) + ex = relax.vm.build(mod, target) - vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib) + vm = relax.VirtualMachine(ex, tvm.cpu()) inp = tvm.nd.array( np.random.rand( 1, @@ -601,9 +620,9 @@ def te_func(A): mod = bb.get() target = tvm.target.Target("llvm", host="llvm") - ex, lib = relax.vm.build(mod, target) + ex = relax.vm.build(mod, target) - vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib) + vm = relax.VirtualMachine(ex, tvm.cpu()) shape = (9,) inp = tvm.nd.array(np.random.rand(*shape).astype(np.float32)) res = vm["rx_func"](inp) @@ -612,7 +631,7 @@ def expected_output(): output_shape = (shape[0] // 2,) return inp.numpy()[: output_shape[0]] + 1 - np.testing.assert_allclose(res.numpy(), expected_output()) + 
tvm.testing.assert_allclose(res.numpy(), expected_output(), rtol=1e-7, atol=1e-7) def test_vm_relax_symbolic_shape(): @@ -633,9 +652,9 @@ def te_func(A, B): mod = bb.get() target = tvm.target.Target("llvm", host="llvm") - ex, lib = relax.vm.build(mod, target) + ex = relax.vm.build(mod, target) - vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib) + vm = relax.VirtualMachine(ex, tvm.cpu()) shape1 = (5,) shape2 = (3,) inp = tvm.nd.array(np.random.rand(*shape1).astype(np.float32)) @@ -645,7 +664,7 @@ def te_func(A, B): def expected_output(): return inp.numpy() + np.repeat(inp2.numpy(), 2)[:5] - np.testing.assert_allclose(res.numpy(), expected_output()) + tvm.testing.assert_allclose(res.numpy(), expected_output(), rtol=1e-7, atol=1e-7) def test_vm_relax_dyn_tir_shape(): @@ -667,19 +686,19 @@ def te_func(A): mod = bb.get() target = tvm.target.Target("llvm", host="llvm") - ex, lib = relax.vm.build(mod, target) + ex = relax.vm.build(mod, target) ex.save_to_file("exec.tmp") exec1 = relax.load_exec_from_file("exec.tmp") - assert ex.astext() == exec1.astext() + assert ex.as_text() == exec1.as_text() - vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib) + vm = relax.VirtualMachine(ex, tvm.cpu()) inp = tvm.nd.array(np.random.rand(2).astype(np.float32)) inp2 = tvm.nd.array(np.random.rand(3).astype(np.float32)) res = vm["rx_func"](inp, inp2) - np.testing.assert_allclose(res.numpy(), inp2.numpy()) + tvm.testing.assert_allclose(res.numpy(), inp2.numpy(), rtol=1e-7, atol=1e-7) os.remove("exec.tmp") @@ -697,17 +716,17 @@ def test_vm_tuple(): mod = bb.get() target = tvm.target.Target("llvm", host="llvm") - ex, lib = relax.vm.build(mod, target) + ex = relax.vm.build(mod, target) - vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib) + vm = relax.VirtualMachine(ex, tvm.cpu()) shape = (5, 5) inp = tvm.nd.array(np.random.rand(*shape).astype(np.float32)) inp2 = tvm.nd.array(np.random.rand(*shape).astype(np.float32)) (res1, res2), res3 = vm["rx_func"](inp, inp2) - np.testing.assert_allclose(res1.numpy(), inp.numpy()) - np.testing.assert_allclose(res2.numpy(), inp2.numpy()) - np.testing.assert_allclose(res3.numpy(), inp.numpy()) + tvm.testing.assert_allclose(res1.numpy(), inp.numpy(), rtol=1e-7, atol=1e-7) + tvm.testing.assert_allclose(res2.numpy(), inp2.numpy(), rtol=1e-7, atol=1e-7) + tvm.testing.assert_allclose(res3.numpy(), inp.numpy(), rtol=1e-7, atol=1e-7) def test_vm_tuplegetitem(): @@ -723,12 +742,12 @@ def tuple_get_item(x: Tensor[(_, _), "float32"], y: Tensor[(_, _), "float32"]): mod = TestVMTupleGetItem target = tvm.target.Target("llvm", host="llvm") - ex, lib = relax.vm.build(mod, target) - vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib) + ex = relax.vm.build(mod, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) x_inp = tvm.nd.array(np.random.rand(2, 3)) y_inp = tvm.nd.array(np.random.rand(2, 3)) res = vm["tuple_get_item"](x_inp, y_inp) - np.testing.assert_allclose(res.numpy(), x_inp.numpy() + y_inp.numpy()) + tvm.testing.assert_allclose(res.numpy(), x_inp.numpy() + y_inp.numpy(), rtol=1e-7, atol=1e-7) if __name__ == "__main__":
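# ---------------------------------------------------------------------------
# A minimal end-to-end usage sketch of the API after this change, assuming an
# LLVM target on CPU. `MyModule` and its `main` function are placeholder names
# mirroring the "vm.builtin.copy" pattern used in the tests above; this is an
# illustrative sketch of the new single-return build flow, not part of the
# test suite.
from __future__ import annotations

import numpy as np
import tvm
import tvm.script
import tvm.testing
from tvm import relax
from tvm.contrib import utils
from tvm.script import relax as R


@tvm.script.ir_module
class MyModule:
    @R.function
    def main(x: Tensor[(3, 4), "float32"]):
        # copy the input tensor and return it unchanged
        z = R.call_packed("vm.builtin.copy", x)
        return z


target = tvm.target.Target("llvm", host="llvm")

# build() now returns a single Executable; the generated TIR kernels are
# imported into it rather than handed back as a separate runtime.Module.
ex = relax.vm.build(MyModule, target)

# The VirtualMachine is constructed from the executable alone (no `mod=` arg).
vm = relax.VirtualMachine(ex, tvm.cpu())
data = tvm.nd.array(np.random.rand(3, 4).astype(np.float32))
res = vm["main"](data)
tvm.testing.assert_allclose(res.numpy(), data.numpy(), rtol=1e-7, atol=1e-7)

# Because Executable wraps a runtime.Module, it round-trips through the
# standard export_library / load_module path.
temp_dir = utils.tempdir()
path_exec = temp_dir.relpath("exec.so")
ex.mod.export_library(path_exec)
loaded_exec = relax.vm.Executable(tvm.runtime.load_module(path_exec))
assert ex.as_text() == loaded_exec.as_text()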