diff --git a/include/tvm/runtime/vm/executable.h b/include/tvm/runtime/vm/executable.h index 8d3f651758d1..95c6d6f4ab47 100644 --- a/include/tvm/runtime/vm/executable.h +++ b/include/tvm/runtime/vm/executable.h @@ -63,6 +63,19 @@ class Executable : public ModuleNode { */ PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; + /*! + * \brief Write the Executable to the binary stream in serialized form. + * \param stream The binary stream to save the executable to. + */ + void SaveToBinary(dmlc::Stream* stream) final; + + /*! + * \brief Write the Executable to the provided path as a file contianing its serialized content. + * \param path The path to write the serialized data to. + * \param format The format of the serialized blob. + */ + void SaveToFile(const std::string& path, const std::string& format) final; + /*! * \brief Serialize the executable into global section, constant section, and * code section. @@ -125,12 +138,24 @@ class Executable : public ModuleNode { * \brief Get the `lib` module in an executable. Users have the flexibility to call * `export_library` from the frontend to save the library to disk. * - * \return The runtime module that contains the hardwre dependent code. + * \return The runtime module that contains the hardware dependent code. + */ + runtime::Module GetLib() const; + + /*! + * \brief Set the `lib` module in an executable. + * + * This allows us to do partial initialization in the case of (de|ser)ialization cases. + * This method also ensures correct initialization of library ensuring we only Import a + * single library. + * + * NB: This also provides some abstraction over how libraries are stored as there are plans + * to iterate on the way runtime::Module works in the backend of the compiler. */ - runtime::Module GetLib() const { return lib; } + void SetLib(const runtime::Module& lib); /*! - * \brief Get the arity of the VM Fucntion. + * \brief Get the arity of the VMFunction. * \param func Function name. * \return The number of parameters. */ @@ -148,9 +173,6 @@ class Executable : public ModuleNode { const char* type_key() const final { return "VMExecutable"; } - /*! \brief The runtime module/library that contains both the host and also the device - * code when executing on non-CPU devices. */ - runtime::Module lib; /*! \brief The global constant pool. */ std::vector constants; /*! \brief A map from globals (as strings) to their index in the function map. */ diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py index 5165ae0854fa..d36554bbd516 100644 --- a/python/tvm/runtime/module.py +++ b/python/tvm/runtime/module.py @@ -269,10 +269,14 @@ def _collect_dso_modules(self): return self._collect_from_import_tree(is_dso_exportable) def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=None, **kwargs): - """Export the module and its imported device code one library. + """ + Export the module and all imported modules into a single device library. - This function only works on host llvm modules. - It will pack all the imported modules + This function only works on host LLVM modules, other runtime::Module + subclasses will work with this API but they must support implement + the save and load mechanisms of modules completely including saving + from streams and files. This will pack your non-shared library module + into a single shared library which can later be loaded by TVM. Parameters ---------- @@ -280,13 +284,20 @@ def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=No The name of the shared library. fcompile : function(target, file_list, kwargs), optional - Compilation function to use create dynamic library. + The compilation function to use create the final library object during + export. + + For example, when fcompile=_cc.create_shared, or when it is not supplied but + module is "llvm," this is used to link all produced artifacts + into a final dynamic library. + + This behavior is controlled by the type of object exported. If fcompile has attribute object_format, will compile host library to that format. Otherwise, will use default format "o". workspace_dir : str, optional - the path to a directory used to create intermediary - artifacts for the process exporting of the library. + The path of the directory used to create the intermediate + artifacts when exporting the module. If this is not provided a temporary dir will be created. kwargs : dict, optional diff --git a/python/tvm/runtime/vm.py b/python/tvm/runtime/vm.py index a503da53c465..d0de0520a674 100644 --- a/python/tvm/runtime/vm.py +++ b/python/tvm/runtime/vm.py @@ -23,6 +23,7 @@ import numpy as np import tvm +from tvm.runtime import Module from tvm._ffi.runtime_ctypes import TVMByteArray from tvm._ffi import base as _base from .object import Object @@ -299,12 +300,44 @@ class VirtualMachine(object): POOLED_ALLOCATOR = 2 def __init__(self, exe, device, memory_cfg=None): - if not isinstance(exe, Executable): + """ + Construct a VirtualMachine wrapper class which provides a simple + interface over the raw C++ Module based API. + + Parameters + ---------- + exe: Union[Executable, Module] + The executable either with the wrapper Python type or the raw runtime.Module. + + In most cases this will be the Python wrapper class tvm.runtime.vm.Executable but + if you instead get the underlying runtime.Module subclass (i.e `exe.mod`) you + can directly pass it to this method. + + This case can occur when doing things such as RPC where TVM's module APIs + return the raw modules, not the wrapped modules. This constructor will + handle this internally. + + device: Union[Device, List[Device]] + The device, or devices on which to execute the VM code. + + memory_cfg: Optional[str] + The allocator behavior to use for the VM. + + Returns + ------- + vm: VirtualMachine + A VM wrapper object. + """ + if not isinstance(exe, Executable) and not isinstance(exe, Module): raise TypeError( "exe is expected to be the type of Executable, " + "but received {}".format(type(exe)) ) - self.module = _ffi_api._VirtualMachine(exe.module) + + if not isinstance(exe, Executable): + exe = Executable(exe) + + self.module = exe.mod["vm_load_executable"]() self._exec = exe self._init = self.module["init"] self._invoke = self.module["invoke"] diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index dafaed111c03..906250c1bb0d 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -1155,18 +1155,20 @@ void VMCompiler::Codegen() { auto compile_engine = CompileEngine::Global(); auto ext_mods = compile_engine->LowerExternalFunctions(); + runtime::Module lib; if (funcs.size() > 0) { Map build_funcs; for (const auto& i : funcs) { build_funcs.Set(i.first, i.second); } - exec_->lib = tvm::build(build_funcs, target_host_); + lib = tvm::build(build_funcs, target_host_); } else { // There is no function handled by TVM. We create a virtual main module // to make sure a DSO module will be also available. - exec_->lib = codegen::CSourceModuleCreate(";", "", Array{}); + lib = codegen::CSourceModuleCreate(";", "", Array{}); } - exec_->lib = codegen::CreateMetadataModule(params_, exec_->lib, ext_mods, target_host_); + lib = codegen::CreateMetadataModule(params_, lib, ext_mods, target_host_); + exec_->SetLib(lib); } ExprDeviceMap VMCompiler::AnalyzeContext() const { diff --git a/src/runtime/library_module.cc b/src/runtime/library_module.cc index 30ef2141c508..370dc838839f 100644 --- a/src/runtime/library_module.cc +++ b/src/runtime/library_module.cc @@ -99,6 +99,28 @@ void InitContextFunctions(std::function fgetsymbol) { #undef TVM_INIT_CONTEXT_FUNC } +Module LoadModuleFromBinary(const std::string& type_key, dmlc::Stream* stream) { + std::string loadkey = "runtime.module.loadbinary_"; + std::string fkey = loadkey + type_key; + const PackedFunc* f = Registry::Get(fkey); + if (f == nullptr) { + std::string loaders = ""; + for (auto name : Registry::ListNames()) { + if (name.find(loadkey, 0) == 0) { + if (loaders.size() > 0) { + loaders += ", "; + } + loaders += name.substr(loadkey.size()); + } + } + LOG(FATAL) << "Binary was created using " << type_key + << " but a loader of that name is not registered. Available loaders are " << loaders + << ". Perhaps you need to recompile with this runtime enabled."; + } + + return (*f)(static_cast(stream)); +} + /*! * \brief Load and append module blob to module list * \param mblob The module blob. @@ -133,25 +155,7 @@ runtime::Module ProcessModuleBlob(const char* mblob, ObjectPtr lib) { ICHECK(stream->Read(&import_tree_row_ptr)); ICHECK(stream->Read(&import_tree_child_indices)); } else { - std::string loadkey = "runtime.module.loadbinary_"; - std::string fkey = loadkey + tkey; - const PackedFunc* f = Registry::Get(fkey); - if (f == nullptr) { - std::string loaders = ""; - for (auto name : Registry::ListNames()) { - if (name.rfind(loadkey, 0) == 0) { - if (loaders.size() > 0) { - loaders += ", "; - } - loaders += name.substr(loadkey.size()); - } - } - ICHECK(f != nullptr) - << "Binary was created using " << tkey - << " but a loader of that name is not registered. Available loaders are " << loaders - << ". Perhaps you need to recompile with this runtime enabled."; - } - Module m = (*f)(static_cast(stream)); + auto m = LoadModuleFromBinary(tkey, stream); modules.emplace_back(m); } } diff --git a/src/runtime/library_module.h b/src/runtime/library_module.h index 91918c1ccaa3..00c79e8248f4 100644 --- a/src/runtime/library_module.h +++ b/src/runtime/library_module.h @@ -29,9 +29,21 @@ #include #include +#include namespace tvm { namespace runtime { + +/*! \brief Load a module with the given type key directly from the stream. + * This function wraps the registry mechanism used to store type based deserializers + * for each runtime::Module sub-class. + * + * \param type_key The type key of the serialized module. + * \param stream A pointer to the stream containing the serialized module. + * \return module The deserialized module. + */ +Module LoadModuleFromBinary(const std::string& type_key, dmlc::Stream* stream); + /*! * \brief Library is the common interface * for storing data in the form of shared libaries. diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index 6992097e8d69..e8b948d3d2ae 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -37,6 +37,8 @@ #include #include +#include "../file_utils.h" +#include "../library_module.h" #include "serialize_utils.h" namespace tvm { @@ -74,6 +76,12 @@ PackedFunc Executable::GetFunction(const std::string& name, const ObjectPtrGetFunctionParameterName(func_name, index); }); + } else if (name == "vm_load_executable") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + auto vm = make_object(); + vm->LoadExecutable(this); + *rv = Module(vm); + }); } else { LOG(FATAL) << "Unknown packed function: " << name; return PackedFunc(nullptr); @@ -475,9 +483,37 @@ void LoadHeader(dmlc::Stream* strm) { STREAM_CHECK(version == TVM_VERSION, "version"); } +runtime::Module Executable::GetLib() const { + ICHECK_LE(this->imports_.size(), 1) + << "The kernel library must be imported as the only module in an Executable"; + + if (this->imports().size() == 0) { + return Module(nullptr); + } else { + return this->imports_[0]; + } +} + +void Executable::SetLib(const runtime::Module& lib) { + ICHECK(lib.defined()) << "the provided library can not be null"; + + ICHECK_EQ(this->imports_.size(), 0) + << "A VMExecutable should never have more than one import inside an the executable, \n" + << "the first import should *always* be the library containing" + << "the platform specific kernel code"; + + this->Import(lib); +} + runtime::Module Executable::Load(const std::string& code, const runtime::Module lib) { auto exec = make_object(); - exec->lib = lib; + + // Support null-initialization of lib, to enable initialization during + // deserialization before we have we have deserialized the imports. + if (lib.defined()) { + exec->SetLib(lib); + } + exec->code_ = code; dmlc::MemoryStringStream strm(&exec->code_); @@ -765,6 +801,44 @@ void Executable::LoadCodeSection(dmlc::Stream* strm) { } } +void Executable::SaveToBinary(dmlc::Stream* stream) { + auto code_bytes = this->Save(); + std::string code(code_bytes.data, code_bytes.size); + stream->Write(code); + + ICHECK(this->imports()[0].defined()) << "the library must be imported before serialization"; +} + +Module ExecutableLoadBinary(void* strm) { + dmlc::Stream* stream = static_cast(strm); + std::string code; + stream->Read(&code); + auto exec = Executable::Load(code, Module()); + return exec; +} + +void Executable::SaveToFile(const std::string& path, const std::string& format) { + std::string data; + dmlc::MemoryStringStream writer(&data); + dmlc::SeekStream* strm = &writer; + SaveToBinary(strm); + SaveBinaryToFile(path, data); +} + +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_VMExecutable").set_body_typed(ExecutableLoadBinary); + +// Load module from module. +Module ExecutableLoadFile(const std::string& file_name, const std::string& format) { + std::string data; + LoadBinaryFromFile(file_name, &data); + dmlc::MemoryStringStream reader(&data); + dmlc::Stream* strm = &reader; + auto exec = ExecutableLoadBinary(reinterpret_cast(strm)); + return exec; +} + +TVM_REGISTER_GLOBAL("runtime.module.loadfile_VMExecutable").set_body_typed(ExecutableLoadFile); + TVM_REGISTER_GLOBAL("runtime.GetNumOfGlobals").set_body([](TVMArgs args, TVMRetValue* rv) { runtime::Module mod = args[0]; const auto* exec = dynamic_cast(mod.operator->()); diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index ee06da83bd92..76ca009bc741 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -281,11 +281,12 @@ void VirtualMachine::LoadExecutable(const Executable* exec) { ICHECK(exec) << "The executable is not created yet."; exec_ = exec; - runtime::Module lib = exec_->lib; - // Get the list of packed functions. + runtime::Module lib = exec_->GetLib(); + ICHECK(exec->primitive_map.empty() || lib.operator->()) - << "runtime module should have been built for primitive functions" - << "\n"; + << "If the executable has declared primitive functions, the" + << "generated kernel library must non-be null."; + for (const auto& it : exec_->primitive_map) { const auto& packed_name = it.first; auto packed_index = static_cast(it.second); diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index 4ecd0d9189ea..c1bdc3ff9fd0 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -19,11 +19,14 @@ import tvm from tvm import runtime -from tvm import relay +from tvm import relay, IRModule +from tvm.relay.backend import vm from tvm.relay.scope_builder import ScopeBuilder from tvm.relay.prelude import Prelude from tvm.relay.loops import while_loop from tvm.relay import testing +from tvm.contrib import utils +from tvm import rpc import tvm.testing @@ -799,5 +802,46 @@ def test_constant_shape_with_external_codegen(): assert "shape_func" in opt_mod.astext(False) +def test_vm_rpc(): + """ + This test checks to make sure you can export a VMExecutable, + upload it to a remote machine using RPC and then execute it + on the other machine. + """ + target = "llvm" + target_host = "llvm" + + # Build a IRModule. + x = relay.var("x", shape=(10, 1)) + f = relay.Function([x], x + x) + mod = IRModule.from_expr(f) + + # Compile to VMExecutable. + vm_exec = vm.compile(mod, target=target, target_host=target_host) + + # Export to Disk + temp = utils.tempdir() + path = temp.relpath("vm_library.so") + vm_exec.mod.export_library(path) + + # Use LocalRPC for testing. + remote = rpc.LocalSession() + + # Upload the serialized Executable. + remote.upload(path) + # Get a handle to remote Executable. + rexec = remote.load_module("vm_library.so") + + ctx = remote.cpu() + # Build a VM out of the executable and context. + vm_factory = runtime.vm.VirtualMachine(rexec, ctx) + np_input = np.random.uniform(size=(10, 1)).astype("float32") + input_tensor = tvm.nd.array(np_input, ctx) + # Invoke its "main" function. + out = vm_factory.invoke("main", [input_tensor]) + # Check the result. + np.testing.assert_allclose(out.asnumpy(), np_input + np_input) + + if __name__ == "__main__": pytest.main([__file__])