From 4be15afce94c09d55f88918d4c7ad646f88ca7b9 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Wed, 9 Oct 2019 23:58:16 +0000 Subject: [PATCH 1/6] [relay][vm] Separate VM runtime with executable --- include/tvm/runtime/vm.h | 130 +++++++--- python/tvm/relay/backend/deserializer.py | 16 +- python/tvm/relay/backend/profiler_vm.py | 11 +- python/tvm/relay/backend/serializer.py | 107 ++------ python/tvm/relay/backend/vm.py | 164 +++++++++--- src/relay/backend/vm/compiler.cc | 20 +- src/relay/backend/vm/compiler.h | 12 +- src/relay/backend/vm/profiler/compiler.cc | 1 - .../backend => runtime}/vm/deserializer.cc | 47 ++-- .../backend => runtime}/vm/deserializer.h | 37 +-- src/runtime/vm/executable.cc | 237 ++++++++++++++++++ src/runtime/vm/profiler/vm.cc | 38 +-- src/runtime/vm/profiler/vm.h | 4 +- .../backend => runtime}/vm/serialize_util.h | 12 +- .../backend => runtime}/vm/serializer.cc | 197 ++++----------- .../backend => runtime}/vm/serializer.h | 96 ++----- src/runtime/vm/vm.cc | 117 ++++----- tests/python/relay/test_vm.py | 30 +-- tests/python/relay/test_vm_serialization.py | 91 +++---- .../unittest/test_runtime_vm_profiler.py | 6 +- 20 files changed, 762 insertions(+), 611 deletions(-) rename src/{relay/backend => runtime}/vm/deserializer.cc (90%) rename src/{relay/backend => runtime}/vm/deserializer.h (72%) create mode 100644 src/runtime/vm/executable.cc rename src/{relay/backend => runtime}/vm/serialize_util.h (95%) rename src/{relay/backend => runtime}/vm/serializer.cc (65%) rename src/{relay/backend => runtime}/vm/serializer.h (63%) diff --git a/include/tvm/runtime/vm.h b/include/tvm/runtime/vm.h index aa8543d569af..14816e84c6ab 100644 --- a/include/tvm/runtime/vm.h +++ b/include/tvm/runtime/vm.h @@ -26,6 +26,7 @@ #include #include +#include #include #include #include @@ -430,15 +431,101 @@ struct VMFrame { caller_return_register(0) {} }; +/*! \brief The executable emitted by the VM compiler. + * + * The executable contains information (e.g. 
data in different memory regions) + * to create a virtual machine. + */ +class Executable : public ModuleNode { + public: + /*! + * \brief Get a PackedFunc from an executable module. + * + * \param name the name of the function. + * \param sptr_to_self The shared_ptr that points to this module node. + * + * \return PackedFunc or nullptr when it is not available. + */ + PackedFunc GetFunction(const std::string& name, + const std::shared_ptr& sptr_to_self) final; + + /*! + * \brief Get the serialized form of the `functions` in `vm_`. This is + * essentially bytecode serialization. + * + * \return The serialized vm bytecode. + * + * \note The bytecode is in the following format: + * func_name reg_file_size num_instructions + * param1 param2 ... paramM + * instruction1 + * instruction2 + * ... + * instructionN + * + * Each instruction is printed in the following format: + * opcode num_fields field1 ... fieldX # The text format. + * + * The field starting from # is only used for debugging. The serialized code + * doesn't contain it, therefore the deserializer doesn't need to handle it. + */ + std::string GetBytecode() const; + +/*! + * \brief Print the detailed statistics of the given code, i.e. number of + * globals and constants, etc. + */ + std::string Stats() const; + + /*! \brief Get the `lib` module in an executable. Users have the flexibility to call + * `export_library` from the frontend to save the library to disk. + * + * \return The runtime module that contains the hardware dependent code. + */ + runtime::Module GetLib() const { return lib; } + + /*! + * \brief Set the execution context for the executable. + * + * \param ctxs The list of TVMContext. + */ + void SetContext(const std::vector& ctxs); + + /*! \brief Get device context for params. + */ + TVMContext GetParamsContext() const; + + virtual ~Executable() {} + + const char* type_key() const final { + return "VMExecutable"; + } + + /*!
\brief The runtime module/library that contains hardware dependent code. */ + runtime::Module lib; + /*! \brief The global constant pool. */ + std::vector constants; + /*! \brief A map from globals (as strings) to their index in the function map. */ + std::unordered_map global_map; + /*! \brief A mapping from the packed function (as string) to the index that + * corresponds to the position of the `packed_funcs` list in a `VirtualMachine` object. + */ + std::unordered_map primitive_map; + /*! \brief The virtual machine's function table. */ + std::vector functions; + + /*! \brief The set of TVM contexts the VM is currently executing on. */ + std::vector ctxs; +}; + /*! \brief The virtual machine. * * The virtual machine contains all the current execution state, - * as well as the global view of functions, the global constant - * table, the compiled operators. + * as well as the executable. * * The goal is to have a single self-contained object, * enabling one to easily pass around VMs, execute them on - * multiple threads, or serialized them to disk or over the + * multiple threads, or serialize them to disk or over the * wire. */ class VirtualMachine : public runtime::ModuleNode { @@ -486,16 +573,10 @@ class VirtualMachine : public runtime::ModuleNode { return "VirtualMachine"; } - /*! \brief The runtime module/library that contains generated code. */ - runtime::Module lib; /*! \brief The virtual machine's packed function table. */ std::vector packed_funcs; - /*! \brief The virtual machine's function table. */ - std::vector functions; /*! \brief The current stack of call frames. */ std::vector frames; - /*! \brief The global constant pool. */ - std::vector constants; /*! \brief The fuction table index of the current function. */ Index func_index; /*! \brief The current pointer to the code section. */ @@ -506,8 +587,8 @@ class VirtualMachine : public runtime::ModuleNode { /*! \brief The special return register. */ ObjectRef return_register; - /*! 
\brief The set of TVM contexts the VM is currently executing on. */ - std::vector ctxs; + /*! \brief The executable the VM will operate on. */ + const Executable* exec; /*! \brief Push a call frame on to the call stack. */ void PushFrame(Index arg_count, Index ret_pc, const VMFunction& vm_func); @@ -550,36 +631,17 @@ class VirtualMachine : public runtime::ModuleNode { */ ObjectRef Invoke(const std::string& name, const std::vector& args); - VirtualMachine() : functions(), frames(), func_index(0), code(nullptr), pc(0) {} + VirtualMachine() : frames(), func_index(0), code(nullptr), pc(0), exec(nullptr) {} - /*! \brief Initialize the virtual machine for a set of contexts. - * \param contexts The set of TVM contexts. + /*! \brief Initialize the virtual machine using an executable. + * \param exec The executable. */ - void Init(const std::vector& contexts); + void Init(const Executable* exec); /*! \brief Run VM dispatch loop. */ void RunLoop(); - /*! \brief Get device context for params. - */ - TVMContext GetParamsContext() const; - - /*! - * \brief Load parameters from the parameter bytearray. - * \param params The binary file that contains parameters. - */ - void LoadParams(const std::string& params); - - /*! \brief A map from globals (as strings) to their index in the function map. - */ - std::unordered_map global_map; - - /*! \brief A mapping from the packed function (as string) to the index that - * corresponds to the position of the `packed_funcs` list. - */ - std::unordered_map primitive_map; - private: /*! \brief Invoke a global setting up the VM state to execute. * diff --git a/python/tvm/relay/backend/deserializer.py b/python/tvm/relay/backend/deserializer.py index fde702b1cd04..b5e2353be501 100644 --- a/python/tvm/relay/backend/deserializer.py +++ b/python/tvm/relay/backend/deserializer.py @@ -31,7 +31,7 @@ def _create_deserializer(code, lib): Parameters ---------- code : bytearray - The serialized virtual machine code. 
+ The serialized virtual machine bytecode. lib : :py:class:`~tvm.module.Module` The serialized runtime module/library that contains the hardware @@ -40,7 +40,7 @@ def _create_deserializer(code, lib): Returns ------- ret : Deserializer - The created virtual machine deserializer. + The created virtual machine executable deserializer. """ if isinstance(code, (bytes, str)): code = bytearray(code) @@ -55,12 +55,12 @@ def _create_deserializer(code, lib): class Deserializer: - """Relay VM deserializer. + """Relay VM executable deserializer. Parameters ---------- code : bytearray - The serialized virtual machine code. + The serialized virtual machine bytecode. lib : :py:class:`~tvm.module.Module` The serialized runtime module/library that contains the hardware @@ -71,11 +71,11 @@ def __init__(self, code, lib): self._deserialize = self.mod["deserialize"] def deserialize(self): - """Deserialize the serialized bytecode into a Relay VM. + """Deserialize the serialized bytecode into a Relay VM executable. Returns ------- - ret : VirtualMachine - The deserialized Relay VM. + ret : Executable + The deserialized Relay VM executable. """ - return rly_vm.VirtualMachine(self._deserialize()) + return rly_vm.Executable(self._deserialize()) diff --git a/python/tvm/relay/backend/profiler_vm.py b/python/tvm/relay/backend/profiler_vm.py index 8ae3161e0b83..30c99611b7c9 100644 --- a/python/tvm/relay/backend/profiler_vm.py +++ b/python/tvm/relay/backend/profiler_vm.py @@ -49,8 +49,8 @@ def compile(mod, target=None, target_host=None, params=None): Returns ------- - vm : VirtualMachineProfiler - The profile VM runtime. + exec : Executable + The executable with profiling code. 
""" compiler = VMCompilerProfiler() target = compiler.update_target(target) @@ -60,7 +60,7 @@ def compile(mod, target=None, target_host=None, params=None): tophub_context = compiler.tophub_context(target) with tophub_context: compiler._compile(mod, target, target_host) - return VirtualMachineProfiler(compiler._get_vm()) + return vm.Executable(compiler._get_exec()) class VMCompilerProfiler(vm.VMCompiler): """Build Relay module to run on VM runtime.""" @@ -68,13 +68,16 @@ def __init__(self): super().__init__() self.mod = _vm._VMCompilerProfiler() self._compile = self.mod["compile"] - self._get_vm = self.mod["get_vm"] + self._get_exec = self.mod["get_executable"] self._set_params_func = self.mod["set_params"] class VirtualMachineProfiler(vm.VirtualMachine): """Relay profile VM runtime.""" def __init__(self, mod): super().__init__(mod) + m = mod.module if isinstance(mod, vm.Executable) else mod + self.mod = _vm._VirtualMachineDebug(m) + self._invoke = self.mod["invoke"] self._get_stat = self.mod["get_stat"] def get_stat(self): diff --git a/python/tvm/relay/backend/serializer.py b/python/tvm/relay/backend/serializer.py index b45ba9116a15..4680ada2eaac 100644 --- a/python/tvm/relay/backend/serializer.py +++ b/python/tvm/relay/backend/serializer.py @@ -24,100 +24,35 @@ from . import _vm from . import vm as rly_vm -def _create_serializer(vm): +def _create_serializer(executable): """Create a VM serializer. Parameters ---------- - vm : Union[VirtualMachine, :py:class:`~tvm.module.Module`] - The virtual machine to be serialized. + executable : Union[Executable, :py:class:`~tvm.module.Module`] + The virtual machine executable to be serialized. Returns ------- ret : Serializer - The created virtual machine serializer. + The created virtual machine executable serializer. 
""" - if isinstance(vm, rly_vm.VirtualMachine): - vm = vm.module - elif not isinstance(vm, tvm.module.Module): - raise TypeError("vm is expected to be the type of VirtualMachine or " + - "tvm.Module, but received {}".format(type(vm))) + if isinstance(executable, rly_vm.Executable): + executable = executable.module + elif not isinstance(executable, tvm.module.Module): + raise TypeError("executable is expected to be an Executable or " + + "tvm.Module, but received {}".format(type(executable))) - return _vm._Serializer(vm) + return _vm._Serializer(executable) class Serializer: """Relay VM serializer.""" - def __init__(self, vm): - self.mod = _create_serializer(vm) + def __init__(self, executable): + self.mod = _create_serializer(executable) self._get_lib = self.mod["get_lib"] - self._get_bytecode = self.mod["get_bytecode"] - self._get_globals = self.mod["get_globals"] - self._get_stats = self.mod["get_stats"] - self._get_primitive_ops = self.mod["get_primitive_ops"] self._serialize = self.mod["serialize"] - @property - def stats(self): - """Get the statistics of the Relay VM. - - Returns - ------- - ret : String - The serialized statistic information. - """ - return self._get_stats() - - @property - def primitive_ops(self): - """Get the name of the primitive ops that are executed in the VM. - - Returns - ------- - ret : List[:py:class:`~tvm.expr.StringImm`] - The list of primitive ops. - """ - return [prim_op.value for prim_op in self._get_primitive_ops()] - - @property - def bytecode(self): - """Get the bytecode of the Relay VM. - - Returns - ------- - ret : String - The serialized bytecode. - - Notes - ----- - The bytecode is in the following format: - func_name reg_file_size num_instructions - param1 param2 ... paramM - instruction1 - instruction2 - ... - instructionN - - Each instruction is printed in the following format: - hash opcode field1 ... fieldX # The text format. - - The part starting from # is only used for visualization and debugging. 
- The real serialized code doesn't contain it, therefore the deserializer - doesn't need to deal with it as well. - """ - return self._get_bytecode() - - @property - def globals(self): - """Get the globals used by the Relay VM. - - Returns - ------- - ret : List[:py:class:`~tvm.expr.StringImm`] - The serialized globals. - """ - return [glb.value for glb in self._get_globals()] - def serialize(self): """Serialize the Relay VM. @@ -160,31 +95,31 @@ def serialize(self): # create a Relay VM. ctx = tvm.cpu() target = "llvm" - compiler = relay.vm.VMCompiler() - vm = compiler.compile(mod, target) - vm.init(ctx) + executable = relay.vm.compile(mod, target) + executable.set_context(ctx) # serialize. - ser = relay.serializer.Serializer(vm) + ser = relay.serializer.Serializer(executable) code, lib = ser.serialize() # save and load the code and lib file. tmp = tvm.contrib.util.tempdir() path_lib = tmp.relpath("lib.so") lib.export_library(path_lib) - with open(tmp.relpath("code.bc"), "wb") as fo: + with open(tmp.relpath("code.ro"), "wb") as fo: fo.write(code) loaded_lib = tvm.module.load(path_lib) - loaded_code = bytearray(open(tmp.relpath("code.bc"), "rb").read()) + loaded_code = bytearray(open(tmp.relpath("code.ro"), "rb").read()) # deserialize. deser = relay.deserializer.Deserializer(loaded_code, loaded_lib) - des_vm = deser.deserialize() + des_exec = deser.deserialize() - # execute the deserialized vm. - des_vm.init(ctx) + # execute the deserialized executable.
+ des_exec.set_context(ctx) x_data = np.random.rand(10, 10).astype('float32') + des_vm = relay.vm.VirtualMachine(des_exec) res = des_vm.run(x_data) print(res.asnumpy()) """ diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index c24b16ca6437..142e2b7d32e6 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -24,7 +24,7 @@ import tvm from tvm import autotvm -from tvm._ffi.runtime_ctypes import TVMByteArray +from tvm import TVMContext from tvm.relay import expr as _expr from . import _vm from . import vmobj as _obj @@ -44,6 +44,7 @@ def _convert(arg, cargs): else: raise "unsupported type" + def convert(args): cargs = [] for arg in args: @@ -52,41 +53,136 @@ def convert(args): return cargs -class VirtualMachine(object): - """Relay VM runtime.""" +class Executable(object): + """Relay VM executable""" def __init__(self, mod): self.mod = mod - self._init = self.mod["init"] - self._load_params = self.mod["load_params"] - self._invoke = self.mod["invoke"] + self._set_context = self.mod["set_context"] + self._get_lib = self.mod["get_lib"] + self._get_bytecode = self.mod["get_bytecode"] + self._get_stats = self.mod["get_stats"] - def init(self, ctx): - """Initialize the context in the VM. + def set_context(self, ctx): + """Initialize the context of the VM executable. Parameters ---------- - ctx : :py:class:`TVMContext` + ctx : Union[:py:class:`tvm.TVMContext`, List[:py:class:`tvm.TVMContext`]] The runtime context to run the code on. """ - args = [ctx.device_type, ctx.device_id] - self._init(*args) - def load_params(self, params): - """Load parameters for the VM. + if isinstance(ctx, TVMContext): + ctx = [ctx] + elif not isinstance(ctx, (list, tuple)): + raise ValueError("ctx has to be the type of TVMContext or a list of " + "TVMContext") + # args[0], args[1] are used as the primary/fallback context type and id + # for heterogeneous execution.
+ args = [] + for cur_ctx in ctx: + if not isinstance(cur_ctx, TVMContext): + raise ValueError("ctx has to be the type of TVMContext or a list " + "of TVMContext") + args.append(cur_ctx.device_type) + args.append(cur_ctx.device_id) + + self._set_context(*args) - Parameters - ---------- - params : Union[bytearray, Dict] - The dictionary that contains serialized parameters. + @property + def lib(self): + """Get the library that contains hardware dependent code. + + Returns + ------- + ret : :py:class:`~tvm.Module` + The runtime module that contains hardware dependent code. + """ + return self._get_lib() + + @property + def stats(self): + """Get the statistics of the Relay VM executable. + + Returns + ------- + ret : String + The statistic information of the VM executable. + """ + return self._get_stats() + + @property + def primitive_ops(self): + """Get the name of the primitive ops contained in the executable. + + Returns + ------- + ret : List[String] + The list of primitive ops. """ - if isinstance(params, dict): - params = tvm.relay.save_param_dict(params) - elif isinstance(params, (bytes, str)): - params = bytearray(params) - if not isinstance(params, (bytearray, TVMByteArray)): - raise TypeError("params must be a bytearray") + ret = [] + num_primitives = _vm.GetNumOfPrimitives(self.module) + for i in range(num_primitives): + ret.append(_vm.GetPrimitiveFields(self.module, i)) + return ret - self._load_params(bytearray(params)) + @property + def bytecode(self): + """Get the bytecode of the Relay VM executable. + + Returns + ------- + ret : String + The bytecode of the executable. + + Notes + ----- + The bytecode is in the following format: + func_name reg_file_size num_instructions + param1 param2 ... paramM + instruction1 + instruction2 + ... + instructionN + + Each instruction is printed in the following format: + hash opcode field1 ... fieldX # The text format. + + The part starting from # is only used for visualization and debugging. 
+ The real serialized code doesn't contain it, therefore the deserializer + doesn't need to deal with it as well. + """ + return self._get_bytecode() + + @property + def globals(self): + """Get the globals used by the Relay VM executable. + + Returns + ------- + ret : List[String] + The globals contained in the executable. + """ + ret = [] + num_globals = _vm.GetNumOfGlobals(self.module) + for i in range(num_globals): + ret.append(_vm.GetGlobalFields(self.module, i)) + return ret + + @property + def module(self): + """Return the runtime module contained in a virtual machine executable.""" + return self.mod + + +class VirtualMachine(object): + """Relay VM runtime.""" + def __init__(self, mod): + if not isinstance(mod, (Executable, tvm.module.Module)): + raise TypeError("mod is expected to be the type of Executable or " + + "tvm.Module, but received {}".format(type(mod))) + m = mod.module if isinstance(mod, Executable) else mod + self.mod = _vm._VirtualMachine(m) + self._invoke = self.mod["invoke"] def invoke(self, func_name, *args): """Invoke a function. @@ -122,11 +218,6 @@ def run(self, *args): """ return self.invoke("main", *args) - @property - def module(self): - """Return the runtime module contained in a virtual machine.""" - return self.mod - def compile(mod, target=None, target_host=None, params=None): """ @@ -155,8 +246,8 @@ def compile(mod, target=None, target_host=None, params=None): Returns ------- - vm : VirtualMachine - The VM runtime. + exec : Executable + The VM executable that contains both library code and bytecode. 
""" compiler = VMCompiler() @@ -167,14 +258,14 @@ def compile(mod, target=None, target_host=None, params=None): tophub_context = compiler.tophub_context(target) with tophub_context: compiler._compile(mod, target, target_host) - return VirtualMachine(compiler._get_vm()) + return Executable(compiler._get_exec()) class VMCompiler(object): """Build Relay module to run on VM runtime.""" def __init__(self): self.mod = _vm._VMCompiler() self._compile = self.mod["compile"] - self._get_vm = self.mod["get_vm"] + self._get_exec = self.mod["get_executable"] self._set_params_func = self.mod["set_params"] def set_params(self, params): @@ -240,7 +331,7 @@ class VMExecutor(Executor): mod : :py:class:`~tvm.relay.module.Module` The module to support the execution. - ctx : :py:class:`TVMContext` + ctx : :py:class:`~tvm.TVMContext` The runtime context to run the code on. target : :py:class:`Target` @@ -252,8 +343,9 @@ def __init__(self, mod, ctx, target): self.mod = mod self.ctx = ctx self.target = target - self.vm = compile(mod, target) - self.vm.init(ctx) + self.executable = compile(mod, target) + self.executable.set_context(ctx) + self.vm = VirtualMachine(self.executable) def _make_executor(self, expr=None): main = self.mod["main"] diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 0cfae374ab2c..f295ccd7a555 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -783,9 +783,9 @@ PackedFunc VMCompiler::GetFunction(const std::string& name, Module mod = args[0]; this->Compile(mod, args[1], args[2]); }); - } else if (name == "get_vm") { + } else if (name == "get_executable") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = runtime::Module(vm_); + *rv = runtime::Module(exec_); }); } else if (name == "set_params") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { @@ -864,7 +864,7 @@ void VMCompiler::Compile(Module mod, // Next we get ready by allocating space for 
// the global state. - vm_->functions.resize(context_.module->functions.size()); + exec_->functions.resize(context_.module->functions.size()); for (auto named_func : context_.module->functions) { auto gvar = named_func.first; @@ -873,25 +873,25 @@ void VMCompiler::Compile(Module mod, auto vm_func = func_compiler.Compile(gvar, func); size_t func_index = context_.global_map.at(gvar); - CHECK(func_index < vm_->functions.size()); - vm_->functions[func_index] = vm_func; + CHECK(func_index < exec_->functions.size()); + exec_->functions[func_index] = vm_func; } #if USE_RELAY_DEBUG - for (auto vm_func : vm_->functions) { + for (auto vm_func : exec_->functions) { DLOG(INFO) << vm_func << "-------------"; } #endif // USE_RELAY_DEBUG // populate constants for (auto data : context_.constants) { - vm_->constants.push_back(runtime::vm::Tensor(data)); + exec_->constants.push_back(runtime::vm::Tensor(data)); } LibraryCodegen(); for (auto gv : context_.global_map) { - vm_->global_map.insert({gv.first->name_hint, gv.second}); + exec_->global_map.insert({gv.first->name_hint, gv.second}); } } @@ -987,13 +987,13 @@ void VMCompiler::LibraryCodegen() { // therefore target won't be used in the build function runtime::Module mod = (*f)(funcs, Target(), target_host_); CHECK(mod.operator->()); - vm_->lib = mod; + exec_->lib = mod; } else { LOG(FATAL) << "relay.backend.build is not registered"; } size_t primitive_index = 0; for (auto cfunc : cached_funcs) { - vm_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++}); + exec_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++}); } } diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h index dff1ef7f4569..215cc12c4cdb 100644 --- a/src/relay/backend/vm/compiler.h +++ b/src/relay/backend/vm/compiler.h @@ -92,12 +92,8 @@ class VMCompiler : public runtime::ModuleNode { return "VMCompiler"; } - std::shared_ptr GetVirtualMachine() const { - return vm_; - } - - virtual void InitVM() { - vm_ = 
std::make_shared(); + void InitVM() { + exec_ = std::make_shared(); } /*! @@ -144,8 +140,8 @@ class VMCompiler : public runtime::ModuleNode { tvm::Target target_host_; /*! \brief Global shared meta data */ VMCompilerContext context_; - /*! \brief Compiled virtual machine. */ - std::shared_ptr vm_; + /*! \brief Compiled executable. */ + std::shared_ptr exec_; /*! \brief parameters */ std::unordered_map params_; }; diff --git a/src/relay/backend/vm/profiler/compiler.cc b/src/relay/backend/vm/profiler/compiler.cc index 9fd28e8c7f46..60c441a60cf0 100644 --- a/src/relay/backend/vm/profiler/compiler.cc +++ b/src/relay/backend/vm/profiler/compiler.cc @@ -33,7 +33,6 @@ namespace vm { class VMCompilerDebug : public VMCompiler { public: VMCompilerDebug() {} - void InitVM() override { vm_ = std::make_shared(); } virtual ~VMCompilerDebug() {} }; diff --git a/src/relay/backend/vm/deserializer.cc b/src/runtime/vm/deserializer.cc similarity index 90% rename from src/relay/backend/vm/deserializer.cc rename to src/runtime/vm/deserializer.cc index 777282782e99..eb191545bcc6 100644 --- a/src/relay/backend/vm/deserializer.cc +++ b/src/runtime/vm/deserializer.cc @@ -19,8 +19,8 @@ /*! * Copyright (c) 2019 by Contributors - * \file src/relay/backend/vm/deserializer.cc - * \brief Implementation of APIs to deserialize the serialized VM bytecode. + * \file src/runtime/vm/deserializer.cc + * \brief Implementation of APIs to deserialize the serialized VM executable. */ #include "deserializer.h" @@ -32,17 +32,17 @@ #include "serialize_util.h" namespace tvm { -namespace relay { +namespace runtime { namespace vm { #define STREAM_CHECK(val, section) \ CHECK(val) << "Invalid VM file format in the " << section << " section." 
\ << "\n"; -void Deserializer::Init(const std::string& code, const runtime::Module& lib) { +inline void Deserializer::Init(const std::string& code, const runtime::Module& lib) { code_ = code; - vm_ = std::make_shared(); - vm_->lib = lib; + exec_ = std::make_shared(); + exec_->lib = lib; strm_ = new dmlc::MemoryStringStream(&code_); } @@ -52,7 +52,7 @@ runtime::PackedFunc Deserializer::GetFunction( if (name == "deserialize") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->Deserialize(); - *rv = runtime::Module(vm_); + *rv = runtime::Module(exec_); }); } else { LOG(FATAL) << "Unknown packed function: " << name; @@ -82,13 +82,16 @@ void Deserializer::Deserialize() { // Code section. DeserializeCodeSection(); + + // Context section. + DeserializeContextSection(); } void Deserializer::DeserializeGlobalSection() { std::vector globals; STREAM_CHECK(strm_->Read(&globals), "global"); for (size_t i = 0; i < globals.size(); i++) { - vm_->global_map.insert({globals[i], i}); + exec_->global_map.insert({globals[i], i}); } } @@ -103,7 +106,7 @@ void Deserializer::DeserializeConstantSection() { runtime::NDArray constant; STREAM_CHECK(constant.Load(strm_), "constant"); runtime::ObjectRef obj = runtime::vm::Tensor(constant); - vm_->constants.push_back(obj); + exec_->constants.push_back(obj); } } @@ -111,7 +114,7 @@ void Deserializer::DeserializePrimitiveOpNames() { std::vector primitive_names; STREAM_CHECK(strm_->Read(&primitive_names), "primitive name"); for (size_t i = 0; i < primitive_names.size(); i++) { - vm_->primitive_map.insert({primitive_names[i], i}); + exec_->primitive_map.insert({primitive_names[i], i}); } } @@ -283,7 +286,7 @@ void Deserializer::DeserializeCodeSection() { STREAM_CHECK(strm_->Read(&sz, sizeof(sz)), "code"); size_t num_funcs = static_cast(sz); - vm_->functions.resize(num_funcs); + exec_->functions.resize(num_funcs); for (size_t i = 0; i < num_funcs; i++) { // Load the function info. 
VMFunctionSerializer loaded_func; @@ -303,10 +306,22 @@ void Deserializer::DeserializeCodeSection() { loaded_func.params, instructions, loaded_func.register_file_size); - auto it = vm_->global_map.find(loaded_func.name); - CHECK(it != vm_->global_map.end()); - CHECK_LE(it->second, vm_->global_map.size()); - vm_->functions[it->second] = vm_func; + auto it = exec_->global_map.find(loaded_func.name); + CHECK(it != exec_->global_map.end()); + CHECK_LE(it->second, exec_->global_map.size()); + exec_->functions[it->second] = vm_func; + } +} + +void Deserializer::DeserializeContextSection() { + std::vector ctxs; + STREAM_CHECK(strm_->Read(&ctxs), "context"); + CHECK_EQ(ctxs.size() % 2, 0U); + for (size_t i = 0; i < ctxs.size(); i += 2) { + TVMContext ctx; + ctx.device_type = DLDeviceType(ctxs[i]); + ctx.device_id = static_cast(ctxs[i + 1]); + exec_->ctxs.push_back(ctx); } } @@ -320,5 +335,5 @@ TVM_REGISTER_GLOBAL("relay._vm._Deserializer") .set_body_typed(CreateDeserializer); } // namespace vm -} // namespace relay +} // namespace runtime } // namespace tvm diff --git a/src/relay/backend/vm/deserializer.h b/src/runtime/vm/deserializer.h similarity index 72% rename from src/relay/backend/vm/deserializer.h rename to src/runtime/vm/deserializer.h index 0caf72bee92c..54eb02075e8d 100644 --- a/src/relay/backend/vm/deserializer.h +++ b/src/runtime/vm/deserializer.h @@ -19,15 +19,14 @@ /*! * Copyright (c) 2019 by Contributors - * \file src/relay/backend/vm/deserializer.h - * \brief Define a deserializer for the serialized Relay VM. + * \file src/runtime/vm/deserializer.h + * \brief Define a deserializer for the serialized Relay VM executable. 
*/ -#ifndef TVM_RELAY_BACKEND_VM_DESERIALIZER_H_ -#define TVM_RELAY_BACKEND_VM_DESERIALIZER_H_ +#ifndef TVM_RUNTIME_VM_DESERIALIZER_H_ +#define TVM_RUNTIME_VM_DESERIALIZER_H_ #include -#include #include #include @@ -37,7 +36,7 @@ #include namespace tvm { -namespace relay { +namespace runtime { namespace vm { using namespace tvm::runtime::vm; @@ -46,13 +45,14 @@ namespace runtime = tvm::runtime; class Deserializer : public runtime::ModuleNode { public: /*! - * \brief Initialize the deserializer for creating a virtual machine object. + * \brief Initialize the deserializer for creating a virtual machine + * executable object. * * \param code The serialized code. * \param lib The serialized runtime module/library that contains the * hardware dependent code. */ - inline void Init(const std::string& code, const runtime::Module& lib); + void Init(const std::string& code, const runtime::Module& lib); /*! * \brief Return the member function to the frontend. @@ -67,36 +67,39 @@ class Deserializer : public runtime::ModuleNode { const char* type_key() const final { return "Deserializer"; } - /*! \brief Deserialize the serialized VM. */ + /*! \brief Deserialize the serialized VM executable. */ void Deserialize(); virtual ~Deserializer() { delete strm_; } private: - /*! \brief Deserialize the globals in `vm_`. */ + /*! \brief Deserialize the globals in `exec_`. */ void DeserializeGlobalSection(); - /*! \brief Deserialize the constant pool in `vm_`. */ + /*! \brief Deserialize the constant pool in `exec_`. */ void DeserializeConstantSection(); - /*! \brief Deserialize primitive op names in `vm_`. */ + /*! \brief Deserialize primitive op names in `exec_`. */ void DeserializePrimitiveOpNames(); - /*! \brief Deserialize the vm functions in `vm_`. */ + /*! \brief Deserialize the vm functions in `exec_`. */ void DeserializeCodeSection(); + /*! \brief Deserialize the context in `exec_`. */ + void DeserializeContextSection(); + /*! \brief The code to be serialized. 
*/ std::string code_; /*! \brief The stream used for serialization. */ dmlc::Stream* strm_; - /*! \brief The VM to be created. */ - std::shared_ptr vm_; + /*! \brief The VM executable to be created. */ + std::shared_ptr exec_; }; } // namespace vm -} // namespace relay +} // namespace runtime } // namespace tvm -#endif // TVM_RELAY_BACKEND_VM_DESERIALIZER_H_ +#endif // TVM_RUNTIME_VM_DESERIALIZER_H_ diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc new file mode 100644 index 000000000000..9c85922926c4 --- /dev/null +++ b/src/runtime/vm/executable.cc @@ -0,0 +1,237 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file tvm/runtime/vm/executable.cc + * \brief The implementation of a virtual machine executable APIs. 
+ */ + +#include +#include + +#include +#include +#include +#include + +#include "serializer.h" + +namespace tvm { +namespace runtime { +namespace vm { + +PackedFunc Executable::GetFunction(const std::string& name, + const std::shared_ptr& sptr_to_self) { + if (name == "set_context") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + CHECK_EQ(args.size() % 2, 0); + std::vector contexts; + for (int i = 0; i < args.size() / 2; ++i) { + TVMContext ctx; + int device_type = args[i * 2]; + ctx.device_type = DLDeviceType(device_type); + ctx.device_id = args[i * 2 + 1]; + contexts.push_back(ctx); + } + this->SetContext(contexts); + }); + } else if (name == "get_lib") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->GetLib(); + }); + } else if (name == "get_bytecode") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->GetBytecode(); + }); + } else if (name == "get_stats") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->Stats(); + }); + } else { + LOG(FATAL) << "Unknown packed function: " << name; + return PackedFunc(nullptr); + } +} + +void Executable::SetContext(const std::vector& ctxs) { + this->ctxs = ctxs; +} + +std::string Executable::GetBytecode() const { + std::ostringstream oss; + + for (const auto& func : functions) { + // Print the header of the function format. + oss << "# func name, reg file size, param count, inst count:" + << std::endl; + oss << func.name << " " + << func.register_file_size << " " + << func.params.size() << " " + << func.instructions.size() << std::endl; + + // Print the parameters of a `VMFunction`. + oss << "# Parameters:"<< std::endl; + for (const auto& param : func.params) { + oss << param << " "; + } + oss << std::endl; + + // Print the instructions of a `VMFunction`. + // The part after "#" is the instruction in text format.
+ oss << "hash, opcode, fields # inst(text):"<< std::endl; + for (const auto& instr : func.instructions) { + const auto& serialized_instr = SerializeInstruction(instr); + oss << std::hex << "0x" << serialized_instr.Hash() << " " + << std::dec << serialized_instr.opcode << " "; + for (auto it : serialized_instr.fields) { + oss << it << " "; + } + oss << " # " << instr; + if (oss.str().back() != '\n') oss << std::endl; + } + } + + return oss.str(); +} + +std::string Executable::Stats() const { + std::ostringstream oss; + oss << "Relay VM executable statistics:" << std::endl; + + // Get the number of constants and the shape of each of them. + oss << " Constant shapes (# " << constants.size() << "): ["; + for (const auto& it : constants) { + const auto* cell = it.as(); + CHECK(cell); + runtime::NDArray data = cell->data; + const auto& shape = data.Shape(); + + // Scalar + if (shape.empty()) { + oss << "scalar, "; + continue; + } + + oss << "["; + for (auto s : shape) { + oss << s << ", "; + } + oss.seekp(-2, oss.cur); + oss << "], " << std::endl; + } + if (!constants.empty()) oss.seekp(-2, oss.cur); + oss << "]" << std::endl; + + // Get the number of globals and the name of each of them. + oss << " Globals (#" << global_map.size() << "): ["; + for (const auto& it : global_map) { + oss << "(\"" << it.first << "\", " << it.second << ")" << ", "; + } + if (!global_map.empty()) oss.seekp(-2, oss.cur); + oss << "]" << std::endl; + + // Get the number of primitive ops and the name of each of them. 
+ oss << " Primitive ops (#" << primitive_map.size() << "): ["; + std::vector prim_ops; + for (const auto& it : primitive_map) { + auto packed_index = static_cast(it.second); + if (prim_ops.size() <= packed_index) { + prim_ops.resize(packed_index + 1); + } + prim_ops[packed_index] = it.first; + } + for (const auto& it : prim_ops) { + oss << it << ", "; + } + if (!prim_ops.empty()) oss.seekp(-2, oss.cur); + oss << "]" << std::endl; + + return oss.str(); +} + +TVMContext Executable::GetParamsContext() const { + CHECK(!ctxs.empty()) << "context has not been set yet."; + + // Use the fallback device if no device index is available. + int fallback_device_type = static_cast(ctxs[0].device_type); + // TODO(wweic): For heterogeneous execution, get device information from byte + + const auto& cit = + std::find_if(ctxs.begin(), ctxs.end(), [&fallback_device_type](const TVMContext& c) { + return fallback_device_type == static_cast(c.device_type); + }); + return (cit == ctxs.end() ? ctxs[0] : *cit); +} + +TVM_REGISTER_GLOBAL("relay._vm.GetNumOfGlobals") +.set_body([](TVMArgs args, TVMRetValue* rv) { + runtime::Module mod = args[0]; + const auto* exec = dynamic_cast(mod.operator->()); + CHECK(exec); + *rv = static_cast(exec->global_map.size()); +}); + + +TVM_REGISTER_GLOBAL("relay._vm.GetGlobalFields") +.set_body([](TVMArgs args, TVMRetValue* rv) { + runtime::Module mod = args[0]; + const auto* exec = dynamic_cast(mod.operator->()); + CHECK(exec); + int idx = args[1]; + std::vector > globals(exec->global_map.begin(), + exec->global_map.end()); + auto comp = [](const std::pair& a, + const std::pair& b) { + return a.second < b.second; + }; + std::sort(globals.begin(), globals.end(), comp); + CHECK_LT(idx, globals.size()); + *rv = globals[idx].first; +}); + +TVM_REGISTER_GLOBAL("relay._vm.GetNumOfPrimitives") +.set_body([](TVMArgs args, TVMRetValue* rv) { + runtime::Module mod = args[0]; + const auto* exec = dynamic_cast(mod.operator->()); + CHECK(exec); + *rv = 
static_cast(exec->primitive_map.size()); +}); + + +TVM_REGISTER_GLOBAL("relay._vm.GetPrimitiveFields") +.set_body([](TVMArgs args, TVMRetValue* rv) { + runtime::Module mod = args[0]; + const auto* exec = dynamic_cast(mod.operator->()); + CHECK(exec); + int idx = args[1]; + CHECK_GE(idx, 0); + CHECK_LT(idx, exec->primitive_map.size()); + + for (const auto& it : exec->primitive_map) { + if (idx == static_cast(it.second)) { + *rv = it.first; + break; + } + } +}); + +} // namespace vm +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/vm/profiler/vm.cc b/src/runtime/vm/profiler/vm.cc index 80e0ce57a8ae..dc5ec9943b3b 100644 --- a/src/runtime/vm/profiler/vm.cc +++ b/src/runtime/vm/profiler/vm.cc @@ -67,27 +67,15 @@ PackedFunc VirtualMachineDebug::GetFunction( os << "Total Duration " << total_duration << " us" << std::endl; *rv = os.str(); }); - } else if (name == "init") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - CHECK_EQ(args.size() % 2, 0); - std::vector contexts; - for (int i = 0; i < args.size() / 2; ++i) { - TVMContext ctx; - int device_type = args[i * 2]; - ctx.device_type = DLDeviceType(device_type); - ctx.device_id = args[i * 2 + 1]; - contexts.push_back(ctx); - } - this->Init(contexts); - }); } else { return VirtualMachine::GetFunction(name, sptr_to_self); } } -void VirtualMachineDebug::Init(const std::vector& ctxs) { - VirtualMachine::Init(ctxs); - for (auto kv : primitive_map) { +void VirtualMachineDebug::Init(const Executable* exec) { + VirtualMachine::Init(exec); + CHECK(this->exec); + for (auto kv : this->exec->primitive_map) { packed_index_map[kv.second] = kv.first; op_invokes[kv.second] = 0; } @@ -97,7 +85,8 @@ void VirtualMachineDebug::InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count, Index output_size, const std::vector& args) { - auto ctx = VirtualMachine::GetParamsContext(); + CHECK(this->exec); + auto ctx = this->exec->GetParamsContext(); // warmup 
VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size, args); @@ -117,6 +106,21 @@ void VirtualMachineDebug::InvokePacked(Index packed_index, op_invokes[packed_index] += 1; } +runtime::Module CreateVirtualMachineDebug(const Executable* exec) { + std::shared_ptr vm = std::make_shared(); + vm->Init(exec); + return runtime::Module(vm); +} + +TVM_REGISTER_GLOBAL("relay._vm._VirtualMachineDebug") +.set_body([](TVMArgs args, TVMRetValue* rv) { + runtime::Module mod = args[0]; + const auto* exec = dynamic_cast(mod.operator->()); + CHECK(exec) << "Virtual machine has not been defined yet." + << "\n"; + *rv = CreateVirtualMachineDebug(exec); +}); + } // namespace vm } // namespace runtime } // namespace tvm diff --git a/src/runtime/vm/profiler/vm.h b/src/runtime/vm/profiler/vm.h index 447967cafeb0..d2e71597e80d 100644 --- a/src/runtime/vm/profiler/vm.h +++ b/src/runtime/vm/profiler/vm.h @@ -47,11 +47,11 @@ class VirtualMachineDebug : public VirtualMachine { void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count, Index output_size, const std::vector& args) final; + void Init(const Executable* exec); + ~VirtualMachineDebug() {} private: - void Init(const std::vector& ctxs); - std::unordered_map packed_index_map; std::unordered_map> op_durations; std::unordered_map op_invokes; diff --git a/src/relay/backend/vm/serialize_util.h b/src/runtime/vm/serialize_util.h similarity index 95% rename from src/relay/backend/vm/serialize_util.h rename to src/runtime/vm/serialize_util.h index 3e7508ebee9b..3931f2f0e023 100644 --- a/src/relay/backend/vm/serialize_util.h +++ b/src/runtime/vm/serialize_util.h @@ -19,11 +19,11 @@ /*! * Copyright (c) 2019 by Contributors - * \file src/relay/backend/vm/serialize_util.h + * \file src/runtime/vm/serialize_util.h * \brief Definitions of helpers for serializing and deserializing a Relay VM. 
*/ -#ifndef TVM_RELAY_BACKEND_VM_SERIALIZE_UTIL_H_ -#define TVM_RELAY_BACKEND_VM_SERIALIZE_UTIL_H_ +#ifndef TVM_RUNTIME_VM_SERIALIZE_UTIL_H_ +#define TVM_RUNTIME_VM_SERIALIZE_UTIL_H_ #include #include @@ -34,7 +34,7 @@ #include namespace tvm { -namespace relay { +namespace runtime { namespace vm { /*! \brief The magic number for the serialized VM bytecode file */ @@ -158,7 +158,7 @@ struct VMInstructionSerializer { }; } // namespace vm -} // namespace relay +} // namespace runtime } // namespace tvm -#endif // TVM_RELAY_BACKEND_VM_SERIALIZE_UTIL_H_ +#endif // TVM_RUNTIME_VM_SERIALIZE_UTIL_H_ diff --git a/src/relay/backend/vm/serializer.cc b/src/runtime/vm/serializer.cc similarity index 65% rename from src/relay/backend/vm/serializer.cc rename to src/runtime/vm/serializer.cc index 0040ef9db470..3d52abdc965b 100644 --- a/src/relay/backend/vm/serializer.cc +++ b/src/runtime/vm/serializer.cc @@ -19,8 +19,8 @@ /*! * Copyright (c) 2019 by Contributors - * \file src/relay/backend/vm/serializer.cc - * \brief Implementation of serializing APIs for the Relay VM. + * \file src/runtime/vm/serializer.cc + * \brief Implementation of serializing APIs for the Relay VM executable. */ #include "serializer.h" @@ -36,11 +36,12 @@ #include "serialize_util.h" namespace tvm { -namespace relay { +namespace runtime { namespace vm { -void Serializer::Init(const VirtualMachine* vm) { - vm_ = vm; +void Serializer::Init(const Executable* exec) { + CHECK(exec); + exec_ = exec; // Initialize the stream object.
strm_ = new dmlc::MemoryStringStream(&code_); } @@ -50,23 +51,7 @@ runtime::PackedFunc Serializer::GetFunction( const std::shared_ptr& sptr_to_self) { if (name == "get_lib") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetLib(); - }); - } else if (name == "get_primitive_ops") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetPrimitiveOps(); - }); - } else if (name == "get_bytecode") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetBytecode(); - }); - } else if (name == "get_globals") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetGlobals(); - }); - } else if (name == "get_stats") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->Stats(); + *rv = this->exec_->GetLib(); }); } else if (name == "serialize") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { @@ -78,67 +63,6 @@ runtime::PackedFunc Serializer::GetFunction( } } -tvm::Array Serializer::GetPrimitiveOps() const { - std::vector ret; - for (const auto& it : vm_->primitive_map) { - auto packed_name = tvm::ir::StringImm::make(it.first); - auto packed_index = static_cast(it.second); - if (ret.size() <= packed_index) { - ret.resize(packed_index + 1); - } - ret[packed_index] = packed_name; - } - return ret; -} - -std::string Serializer::Stats() const { - std::ostringstream oss; - oss << "Relay VM statistics:" << std::endl; - - // Get the number of constants and the shape of each of them. 
- oss << " Constant shapes (# " << vm_->constants.size() << "): ["; - for (const auto& it : vm_->constants) { - auto* cell = it.as(); - CHECK(cell != nullptr); - runtime::NDArray data = cell->data; - const auto& shape = data.Shape(); - - // Scalar - if (shape.empty()) { - oss << "scalar, "; - continue; - } - - oss << "["; - for (auto s : shape) { - oss << s << ", "; - } - oss.seekp(-2, oss.cur); - oss << "], " << std::endl; - } - if (!vm_->constants.empty()) oss.seekp(-2, oss.cur); - oss << "]" << std::endl; - - // Get the number of globals and the name of each of them. - oss << " Globals (#" << vm_->global_map.size() << "): ["; - for (const auto& it : vm_->global_map) { - oss << "(\"" << it.first << "\", " << it.second << ")" << ", "; - } - if (!vm_->global_map.empty()) oss.seekp(-2, oss.cur); - oss << "]" << std::endl; - - // Get the number of primitive ops and the name of each of them. - oss << " Primitive ops (#" << vm_->primitive_map.size() << "): ["; - const auto& prim_ops = GetPrimitiveOps(); - for (const auto& it : prim_ops) { - oss << it << ", "; - } - if (!prim_ops.empty()) oss.seekp(-2, oss.cur); - oss << "]" << std::endl; - - return oss.str(); -} - TVMByteArray Serializer::Serialize() { uint64_t header = kTVMVMBytecodeMagic; strm_->Write(header); @@ -157,6 +81,9 @@ TVMByteArray Serializer::Serialize() { // Code section. SerializeCodeSection(); + // Context section. 
+ SerializeContextSection(); + TVMByteArray arr; arr.data = code_.c_str(); arr.size = code_.length(); @@ -164,33 +91,43 @@ TVMByteArray Serializer::Serialize() { } void Serializer::SerializeGlobalSection() { - auto globals = GetGlobals(); + std::vector > globals(exec_->global_map.begin(), + exec_->global_map.end()); + auto comp = [](const std::pair& a, + const std::pair& b) { + return a.second < b.second; + }; + std::sort(globals.begin(), globals.end(), comp); + std::vector glbs; for (const auto& it : globals) { - glbs.push_back(it.as()->value); + glbs.push_back(it.first); } strm_->Write(glbs); } void Serializer::SerializeConstantSection() { std::vector arrays; - for (const auto& obj : vm_->constants) { + for (const auto& obj : exec_->constants) { const auto* cell = obj.as(); CHECK(cell != nullptr); runtime::NDArray data = cell->data; arrays.push_back(const_cast(data.operator->())); } - strm_->Write(static_cast(vm_->constants.size())); + strm_->Write(static_cast(exec_->constants.size())); for (const auto& it : arrays) { runtime::SaveDLTensor(strm_, it); } } void Serializer::SerializePrimitiveOpNames() { - auto names = GetPrimitiveOps(); std::vector primitive_names; - for (const auto& it : names) { - primitive_names.push_back(it.as()->value); + for (const auto& it : exec_->primitive_map) { + auto packed_index = static_cast(it.second); + if (primitive_names.size() <= packed_index) { + primitive_names.resize(packed_index + 1); + } + primitive_names[packed_index] = it.first; } strm_->Write(primitive_names); } @@ -346,8 +283,8 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) { void Serializer::SerializeCodeSection() { // Save the number of functions. - strm_->Write(static_cast(vm_->functions.size())); - for (const auto& func : vm_->functions) { + strm_->Write(static_cast(exec_->functions.size())); + for (const auto& func : exec_->functions) { // Serialize the function info. 
VMFunctionSerializer func_format(func.name, func.register_file_size, @@ -363,77 +300,31 @@ void Serializer::SerializeCodeSection() { } } -tvm::Array Serializer::GetGlobals() const { - tvm::Array ret; - std::vector > globals(vm_->global_map.begin(), - vm_->global_map.end()); - auto comp = [](const std::pair& a, - const std::pair& b) { - return a.second < b.second; - }; - std::sort(globals.begin(), globals.end(), comp); - for (const auto& it : globals) { - ret.push_back(tvm::ir::StringImm::make(it.first)); - } - return ret; -} - -std::string Serializer::GetBytecode() const { - std::ostringstream oss; - - for (const auto& func : vm_->functions) { - // Print the header of the function format. - oss << "# func name, reg file size, param count, inst count:" - << std::endl; - oss << func.name << " " - << func.register_file_size << " " - << func.params.size() << " " - << func.instructions.size() << std::endl; - - // Print pramams of a `VMFunction`. - oss << "# Parameters:"<< std::endl; - for (const auto& param : func.params) { - oss << param << " "; - } - oss << std::endl; - - // Print the instructions of a `VMFunction`. - // The part after ";" is the instruction in text format. 
- oss << "hash, opcode, fields # inst(text):"<< std::endl; - for (const auto& instr : func.instructions) { - const auto& serialized_instr = SerializeInstruction(instr); - oss << std::hex << "0x" << serialized_instr.Hash() << " " - << std::dec << serialized_instr.opcode << " "; - for (auto it : serialized_instr.fields) { - oss << it << " "; - } - oss << " # " << instr; - if (oss.str().back() != '\n') oss << std::endl; - } +void Serializer::SerializeContextSection() { + CHECK(!exec_->ctxs.empty()); + std::vector serialized_ctx; + for (const auto& ctx : exec_->ctxs) { + serialized_ctx.push_back(static_cast(ctx.device_type)); + serialized_ctx.push_back(static_cast(ctx.device_id)); } - - return oss.str(); -} - -runtime::Module Serializer::GetLib() const { - return vm_->lib; + strm_->Write(serialized_ctx); } -runtime::Module CreateSerializer(const VirtualMachine* vm) { - std::shared_ptr exec = std::make_shared(); - exec->Init(vm); - return runtime::Module(exec); +runtime::Module CreateSerializer(const Executable* exec) { + std::shared_ptr serializer = std::make_shared(); + serializer->Init(exec); + return runtime::Module(serializer); } TVM_REGISTER_GLOBAL("relay._vm._Serializer") .set_body([](TVMArgs args, TVMRetValue* rv) { runtime::Module mod = args[0]; - const auto* vm = dynamic_cast(mod.operator->()); - CHECK(vm) << "Virtual machine has not been defined yet." + const auto* exec = dynamic_cast(mod.operator->()); + CHECK(exec) << "Virtual machine has not been defined yet." << "\n"; - *rv = CreateSerializer(vm); + *rv = CreateSerializer(exec); }); } // namespace vm -} // namespace relay +} // namespace runtime } // namespace tvm diff --git a/src/relay/backend/vm/serializer.h b/src/runtime/vm/serializer.h similarity index 63% rename from src/relay/backend/vm/serializer.h rename to src/runtime/vm/serializer.h index 2371bb4c94f5..33e64de5c959 100644 --- a/src/relay/backend/vm/serializer.h +++ b/src/runtime/vm/serializer.h @@ -19,7 +19,7 @@ /*! 
* Copyright (c) 2019 by Contributors - * \file src/relay/backend/vm/serializer.h + * \file src/runtime/vm/serializer.h * \brief Define a serializer for the Relay VM. * * The following components of a Relay VM will be serialized: @@ -32,6 +32,8 @@ * - The `primitive_map` that contains the name of individual primitive operators. * - The `functions`, e.g., the `VMFunction`. Each `VMFunction` is composed of * a list of instructions/bytecode. + * - The `ctxs` that contains the device context used to execute the hardware + * dependent code. * * Note that only the library is returned as a separate module. All othere parts * are stored in a single serialized code that is organized with the following @@ -41,6 +43,7 @@ * - Primitive name section, containing the function name of the primitive ops * used by the virtual machine. * - Code section, handling the VM functions and bytecode. + * - Context section, saving the context information. * * The code section is again organized as follows for each VM function: * func_name, register_file_size, num_instructions (N) @@ -63,14 +66,11 @@ * the shape of a tensor, the args used by an `InvokPacked` instruction, etc. */ -#ifndef TVM_RELAY_BACKEND_VM_SERIALIZER_H_ -#define TVM_RELAY_BACKEND_VM_SERIALIZER_H_ +#ifndef TVM_RUNTIME_VM_SERIALIZER_H_ +#define TVM_RUNTIME_VM_SERIALIZER_H_ #include #include -#include -#include -#include #include #include @@ -79,8 +79,10 @@ #include #include +#include "serialize_util.h" + namespace tvm { -namespace relay { +namespace runtime { namespace vm { using namespace tvm::runtime; @@ -92,11 +94,11 @@ using namespace tvm::runtime::vm; class Serializer : public runtime::ModuleNode { public: /*! - * \brief Initialize the serializer for a virtual machine. + * \brief Initialize the serializer for an executable. * - * \param vm The Relay virtual machine. + * \param exec The Relay virtual machine executable. */ - inline void Init(const VirtualMachine* vm); + void Init(const Executable* exec); /*!
* \brief Return the member function to the frontend. @@ -112,81 +114,33 @@ class Serializer : public runtime::ModuleNode { const char* type_key() const final { return "Serializer"; } /*! - * \brief Print the detailed statistics of the given code, i.e. number of - * globls and constants, etc. - */ - std::string Stats() const; - - /*! - * \brief Serialize the `vm_` into global section, constant section, and code + * \brief Serialize the `exec_` into global section, constant section, and code * section. * * \return The binary representation of the VM. */ TVMByteArray Serialize(); - /*! - * \brief Get a list of the globals used by the `_vm`. - * - * \return The global map in the form a list. - */ - tvm::Array GetGlobals() const; - - /*! - * \brief Get the primitive operators that are contained in the Relay VM. - * - * \return The list of primitve operators. - */ - tvm::Array GetPrimitiveOps() const; - - /*! - * \brief Get the serialized form of the `functions` in `vm_`. This is - * essentially bytecode serialization. - * - * \return The serialized vm bytecode. - * - * \note The bytecode is in the following format: - * func_name reg_file_size num_instructions - * param1 param2 ... paramM - * instruction1 - * instruction2 - * ... - * instructionN - * - * Each instruction is printed in the following format: - * opcode num_fields field1 ... fieldX # The text format. - * - * The field starting from # is only used for debugging. The serialized code - * doesn't contain it, therefore the deserializer doens't need to handle it. - */ - std::string GetBytecode() const; - - /*! \brief Get the `lib` module in vm_. Serialization of `runtime::module` - * has already been supported by TVM. Therefore, we only return the runtime - * module and let users have the flexibility to call `export_library` from - * the frontend to save the library to disk. - * - * \return The runtime module that contains the hardwre dependent code. 
- */ - inline runtime::Module GetLib() const; - virtual ~Serializer() { delete strm_; } private: - /*! \brief Serialize the globals in vm_. */ + /*! \brief Serialize the globals in exec_. */ void SerializeGlobalSection(); - /*! \brief Serialize the constant pool in vm_. */ + /*! \brief Serialize the constant pool in exec_. */ void SerializeConstantSection(); - /*! \brief Serialize primitive op names in vm_. */ + /*! \brief Serialize primitive op names in exec_. */ void SerializePrimitiveOpNames(); - /*! \brief Serialize the vm functions in vm_. */ + /*! \brief Serialize the vm functions in exec_. */ void SerializeCodeSection(); - /*! \brief The Relay virtual machine for to be serialized. */ - const VirtualMachine* vm_; + /*! \brief Serialize the context in exec_. */ + void SerializeContextSection(); + + /*! \brief The Relay virtual machine executable to be serialized. */ + const Executable* exec_; /*! \brief The stream used for serialization. */ dmlc::Stream* strm_; @@ -195,8 +149,10 @@ class Serializer : public runtime::ModuleNode { std::string code_; }; +VMInstructionSerializer SerializeInstruction(const Instruction& instr); + } // namespace vm -} // namespace relay +} // namespace runtime } // namespace tvm -#endif // TVM_RELAY_BACKEND_VM_SERIALIZER_H_ +#endif // TVM_RUNTIME_VM_SERIALIZER_H_ diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 7dea9bdb95ea..d4dd8340fbf2 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -575,13 +575,14 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, const std::shared_ptr& sptr_to_self) { if (name == "invoke") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + CHECK(exec) << "The executable is not created yet."; std::string func_name = args[0]; - auto gvit = this->global_map.find(func_name); - CHECK(gvit != this->global_map.end()) << "Cannot find function " << func_name; + auto gvit = exec->global_map.find(func_name); + CHECK(gvit != exec->global_map.end()) << 
"Cannot find function " << func_name; auto func_index = gvit->second; - const auto& vm_func = this->functions[func_index]; + const auto& vm_func = exec->functions[func_index]; const auto& param_names = vm_func.params; - auto ctx = this->GetParamsContext(); + auto ctx = exec->GetParamsContext(); // Prepare the func args std::vector func_args(param_names.size()); @@ -604,67 +605,12 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, *rv = this->Invoke(vm_func, func_args); }); - } else if (name == "init") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - CHECK_EQ(args.size() % 2, 0); - std::vector contexts; - for (int i = 0; i < args.size() / 2; ++i) { - TVMContext ctx; - int device_type = args[i * 2]; - ctx.device_type = DLDeviceType(device_type); - ctx.device_id = args[i * 2 + 1]; - contexts.push_back(ctx); - } - this->Init(contexts); - }); - } else if (name == "load_params") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - this->LoadParams(args[0].operator std::string()); - }); } else { LOG(FATAL) << "Unknown packed function: " << name; return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {}); } } -TVMContext VirtualMachine::GetParamsContext() const { - // Use the fallback device if no device index is available. - int fallback_device_type = static_cast(ctxs[0].device_type); - // TODO(wweic): For heterogeneous execution, get device information from byte - - const auto& cit = - std::find_if(ctxs.begin(), ctxs.end(), [&fallback_device_type](const TVMContext& c) { - return fallback_device_type == static_cast(c.device_type); - }); - return (cit == ctxs.end() ? 
ctxs[0] : *cit); -} - -void VirtualMachine::LoadParams(const std::string& params) { - dmlc::MemoryStringStream mss(const_cast(¶ms)); - dmlc::Stream* strm = &mss; - uint64_t header, reserved; - CHECK(strm->Read(&header)) << "Invalid parameter file"; - CHECK(header == kTVMNDArrayListMagic) << "Invalid parameter file"; - CHECK(strm->Read(&reserved)) << "Invalid parameter file"; - - std::vector names; - CHECK(strm->Read(&names)) << "Invalid parameter file"; - - uint64_t sz; - strm->Read(&sz); - size_t size = static_cast(sz); - CHECK(size == names.size()) << "Invalid parameter file"; - - auto ctx = GetParamsContext(); - for (size_t i = 0; i < size; i++) { - NDArray arr; - CHECK(arr.Load(strm)) << "Invalid parameter file"; - ObjectRef obj = Tensor(arr); - auto copy = CopyTo(obj, ctx); - params_.emplace(std::make_pair(names[i], copy)); - } -} - void VirtualMachine::PushFrame(Index arg_count, Index ret_pc, const VMFunction& vm_func) { auto frame = VMFrame(ret_pc, func_index, arg_count, code, vm_func.register_file_size); frames.push_back(frame); @@ -699,15 +645,17 @@ ObjectRef VirtualMachine::Invoke(const VMFunction& func, const std::vectorGetAllocator(ctxs[0]); + // TODO(wweic) ctx could be obtained from the ctxs list. 
+ auto alloc = MemoryManager::Global()->GetAllocator(exec->ctxs[0]); DLOG(INFO) << "Memory used: " << alloc->UsedMemory() << " B"; return return_register; } ObjectRef VirtualMachine::Invoke(const std::string& name, const std::vector& args) { - auto func_index = this->global_map[name]; + CHECK(exec) << "The executable has not been created yet."; + auto func_index = exec->global_map.at(name); DLOG(INFO) << "Invoke Global " << name << " at index " << func_index; - return Invoke(this->functions[func_index], args); + return Invoke(exec->functions[func_index], args); } void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, @@ -744,14 +692,16 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv); } -void VirtualMachine::Init(const std::vector& ctxs) { - this->ctxs = ctxs; +void VirtualMachine::Init(const Executable* exec) { + CHECK(exec) << "The executable is not created yet."; + this->exec = exec; + runtime::Module lib = this->exec->lib; // Get the list of packed functions. 
- CHECK(primitive_map.empty() || lib.operator->()) + CHECK(exec->primitive_map.empty() || lib.operator->()) << "runtime module should have been built for primitive functions" << "\n"; - for (const auto& it : primitive_map) { + for (const auto& it : this->exec->primitive_map) { const auto& packed_name = it.first; auto packed_index = static_cast(it.second); if (packed_funcs.size() <= packed_index) { @@ -788,6 +738,7 @@ inline int32_t VirtualMachine::LoadScalarInt(Index r) const { void VirtualMachine::RunLoop() { CHECK(this->code); + CHECK(this->exec); this->pc = 0; Index frame_start = frames.size(); while (true) { @@ -810,8 +761,9 @@ void VirtualMachine::RunLoop() { throw std::runtime_error("VM encountered fatal error"); } case Opcode::LoadConst: { - auto constant_obj = this->constants[instr.const_index]; - auto device_obj = CopyTo(constant_obj, ctxs[0]); + auto constant_obj = exec->constants[instr.const_index]; + // TODO(wweic) ctx could be obtained from the ctxs list. + auto device_obj = CopyTo(constant_obj, exec->ctxs[0]); WriteRegister(instr.dst, device_obj); pc++; goto main_loop; @@ -828,7 +780,7 @@ void VirtualMachine::RunLoop() { for (Index i = 0; i < instr.num_args; ++i) { args.push_back(ReadRegister(instr.invoke_args_registers[i])); } - InvokeGlobal(this->functions[instr.func_index], args); + InvokeGlobal(exec->functions[instr.func_index], args); frames.back().caller_return_register = instr.dst; goto main_loop; } @@ -858,7 +810,7 @@ void VirtualMachine::RunLoop() { for (Index i = 0; i < instr.num_closure_args; ++i) { args.push_back(ReadRegister(instr.closure_args[i])); } - InvokeGlobal(this->functions[closure->func_index], args); + InvokeGlobal(exec->functions[closure->func_index], args); frames.back().caller_return_register = instr.dst; goto main_loop; } @@ -910,8 +862,9 @@ void VirtualMachine::RunLoop() { for (uint32_t i = 0; i < instr.alloc_tensor.ndim; ++i) { shape[i] = instr.alloc_tensor.shape[i]; } - auto allocator = 
MemoryManager::Global()->GetAllocator(ctxs[0]); - auto data = allocator->Empty(shape, instr.alloc_tensor.dtype, ctxs[0]); + // TODO(wweic) ctx could be obtained from the ctxs list. + auto allocator = MemoryManager::Global()->GetAllocator(exec->ctxs[0]); + auto data = allocator->Empty(shape, instr.alloc_tensor.dtype, exec->ctxs[0]); auto obj = Tensor(data); WriteRegister(instr.dst, obj); pc++; @@ -931,8 +884,9 @@ void VirtualMachine::RunLoop() { auto num_dims = shape_tensor->shape[0]; auto shape = std::vector(shape_tensor->shape[0]); shape.assign(dims, dims + num_dims); - auto allocator = MemoryManager::Global()->GetAllocator(ctxs[0]); - auto data = allocator->Empty(shape, instr.alloc_tensor_reg.dtype, ctxs[0]); + // TODO(wweic) ctx could be obtained from the ctxs list. + auto allocator = MemoryManager::Global()->GetAllocator(exec->ctxs[0]); + auto data = allocator->Empty(shape, instr.alloc_tensor_reg.dtype, exec->ctxs[0]); auto obj = Tensor(data); WriteRegister(instr.dst, obj); pc++; @@ -976,6 +930,21 @@ void VirtualMachine::RunLoop() { } } +runtime::Module CreateVirtualMachine(const Executable* exec) { + std::shared_ptr vm = std::make_shared(); + vm->Init(exec); + return runtime::Module(vm); +} + +TVM_REGISTER_GLOBAL("relay._vm._VirtualMachine") +.set_body([](TVMArgs args, TVMRetValue* rv) { + runtime::Module mod = args[0]; + const auto* exec = dynamic_cast(mod.operator->()); + CHECK(exec) << "The virtual machine executable has not been defined yet." 
+ << "\n"; + *rv = CreateVirtualMachine(exec); +}); + } // namespace vm } // namespace runtime } // namespace tvm diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index cedbc4f71859..63450d60a18e 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -47,14 +47,16 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm"): if isinstance(f, relay.Expr): mod = relay.Module() mod["main"] = f - vm = relay.vm.compile(mod, target) - vm.init(tvm.cpu()) + exe = relay.vm.compile(mod, target) + exe.set_context(ctx) + vm = relay.vm.VirtualMachine(exe) return vm.invoke("main", *args) else: assert isinstance(f, relay.Module), "expected expression or module" mod = f - vm = relay.vm.compile(mod, target) - vm.init(tvm.cpu()) + exe = relay.vm.compile(mod, target) + exe.set_context(ctx) + vm = relay.vm.VirtualMachine(exe) ret = vm.invoke("main", *args) return ret @@ -573,25 +575,6 @@ def test_add_op_broadcast(): mod["main"] = func check_result([x_data, y_data], x_data + y_data, mod=mod) -def test_set_params(): - mod = relay.Module() - x = relay.var('x', shape=(10, 5)) - w = relay.var('w', shape=(6, 5)) - b = relay.var('b', shape=(6,)) - y = relay.nn.bias_add(relay.nn.dense(x, w), b) - mod["main"] = relay.Function([x, w, b], y) - vm = relay.vm.compile(mod, 'llvm') - vm.init(tvm.cpu()) - - x_np = np.random.uniform(size=(10, 5)).astype('float32') - w_np = np.random.uniform(size=(6, 5)).astype('float32') - b_np = np.random.uniform(size=(6,)).astype('float32') - ref_np = np.dot(x_np, w_np.T) + b_np - params = {'w': w_np} - vm.load_params(params) - out = vm.run(x_np, b_np) - tvm.testing.assert_allclose(out.asnumpy(), ref_np) - if __name__ == "__main__": test_id() @@ -626,4 +609,3 @@ def test_set_params(): test_add_op_scalar() test_add_op_tensor() test_add_op_broadcast() - test_set_params() diff --git a/tests/python/relay/test_vm_serialization.py b/tests/python/relay/test_vm_serialization.py index 3a317fc2d111..72b82db21ea1 100644 --- 
a/tests/python/relay/test_vm_serialization.py +++ b/tests/python/relay/test_vm_serialization.py @@ -28,23 +28,22 @@ from tvm.contrib import util from tvm.relay import testing -def create_vm(f, ctx=tvm.cpu(), target="llvm", params=None): +def create_exec(f, ctx=tvm.cpu(), target="llvm", params=None): if isinstance(f, relay.Expr): mod = relay.Module() mod["main"] = f - vm = _vm.compile(mod, target=target, params=params) - vm.init(ctx) - return vm + executable = _vm.compile(mod, target=target, params=params) + executable.set_context(ctx) + return executable else: assert isinstance(f, relay.Module), "expected mod as relay.Module" - vm = _vm.compile(f, target=target, params=params) - vm.init(ctx) - return vm + executable = _vm.compile(f, target=target, params=params) + executable.set_context(ctx) + return executable def veval(vm, *args, ctx=tvm.cpu()): assert isinstance(vm, _vm.VirtualMachine), "expected VirtualMachine" - vm.init(ctx) ret = vm.run(*args) return ret @@ -59,13 +58,13 @@ def get_vm_output(mod, data, params, target, ctx, dtype='float32'): return result.asnumpy().astype(dtype) def get_serialized_output(mod, data, params, target, ctx, dtype='float32'): - vm = create_vm(mod, ctx, target, params=params) + vm = create_exec(mod, ctx, target, params=params) ser = serializer.Serializer(vm) code, lib = ser.serialize() deser = deserializer.Deserializer(code, lib) - des_vm = deser.deserialize() - des_vm.init(ctx) - des_vm.load_params(params) + des_exec = deser.deserialize() + des_exec.set_context(ctx) + des_vm = _vm.VirtualMachine(des_exec) result = des_vm.run(data) return result.asnumpy().astype(dtype) @@ -99,25 +98,25 @@ def test_serializer(): main = relay.Function([x1, y1], glb_f1(x1) * glb_f2(y1)) mod["main"] = main - vm = create_vm(mod) - ser = serializer.Serializer(vm) + exe = create_exec(mod) - glbs = ser.globals + glbs = exe.globals assert len(glbs) == 3 assert "f1" in glbs assert "f2" in glbs assert "main" in glbs - prim_ops = ser.primitive_ops + prim_ops = 
exe.primitive_ops assert any(item.startswith('fused_add') for item in prim_ops) assert any(item.startswith('fused_subtract') for item in prim_ops) assert any(item.startswith('fused_multiply') for item in prim_ops) - code = ser.bytecode + code = exe.bytecode assert "main 5 2 5" in code assert "f1 2 1 3" in code assert "f2 2 1 3" in code + ser = serializer.Serializer(exe) code, lib = ser.serialize() assert isinstance(code, bytearray) assert isinstance(lib, tvm.module.Module) @@ -129,7 +128,7 @@ def test_save_load(): x_data = np.random.rand(10, 10).astype('float32') # serialize. - vm = create_vm(f) + vm = create_exec(f) ser = serializer.Serializer(vm) code, lib = ser.serialize() assert isinstance(code, bytearray) @@ -138,15 +137,16 @@ def test_save_load(): tmp = util.tempdir() path_lib = tmp.relpath("lib.so") lib.export_library(path_lib) - with open(tmp.relpath("code.bc"), "wb") as fo: + with open(tmp.relpath("code.ro"), "wb") as fo: fo.write(code) loaded_lib = tvm.module.load(path_lib) - loaded_code = bytearray(open(tmp.relpath("code.bc"), "rb").read()) + loaded_code = bytearray(open(tmp.relpath("code.ro"), "rb").read()) # deserialize. 
deser = deserializer.Deserializer(loaded_code, loaded_lib) - des_vm = deser.deserialize() + des_exec = deser.deserialize() + des_vm = _vm.VirtualMachine(des_exec) res = veval(des_vm, x_data) tvm.testing.assert_allclose(res.asnumpy(), x_data + x_data) @@ -156,12 +156,13 @@ def test_const(): c = relay.const(1.0, "float32") x = relay.var('x', shape=(10, 10), dtype='float32') f = relay.Function([x], x + c) - vm = create_vm(f) - ser = serializer.Serializer(vm) + exe = create_exec(f) + ser = serializer.Serializer(exe) code, lib = ser.serialize() assert isinstance(code, bytearray) deser = deserializer.Deserializer(code, lib) - des_vm = deser.deserialize() + des_exec = deser.deserialize() + des_vm = _vm.VirtualMachine(des_exec) x_data = np.random.rand(10, 10).astype('float32') res = veval(des_vm, x_data) tvm.testing.assert_allclose(res.asnumpy(), x_data + 1) @@ -177,11 +178,12 @@ def test_if(): x_data = np.random.rand(10, 10).astype('float32') y_data = np.random.rand(10, 10).astype('float32') - vm = create_vm(f) - ser = serializer.Serializer(vm) + exe = create_exec(f) + ser = serializer.Serializer(exe) code, lib = ser.serialize() deser = deserializer.Deserializer(code, lib) - des_vm = deser.deserialize() + des_exec = deser.deserialize() + des_vm = _vm.VirtualMachine(des_exec) # same res = veval(des_vm, x_data, x_data) @@ -213,11 +215,12 @@ def test_loop(): aarg = relay.var('accum', shape=[], dtype='int32') mod["main"] = relay.Function([iarg, aarg], sum_up(iarg, aarg)) - vm = create_vm(mod) - ser = serializer.Serializer(vm) + exe = create_exec(mod) + ser = serializer.Serializer(exe) code, lib = ser.serialize() deser = deserializer.Deserializer(code, lib) - des_vm = deser.deserialize() + des_exec = deser.deserialize() + des_vm = _vm.VirtualMachine(des_exec) result = veval(des_vm, i_data, accum_data) tvm.testing.assert_allclose(result.asnumpy(), sum(range(1, loop_bound + 1))) @@ -230,11 +233,12 @@ def test_tuple(): i_data = np.random.rand(41).astype('float32') j_data = 
np.random.rand(10).astype('float32') - vm = create_vm(f) - ser = serializer.Serializer(vm) + exe = create_exec(f) + ser = serializer.Serializer(exe) code, lib = ser.serialize() deser = deserializer.Deserializer(code, lib) - des_vm = deser.deserialize() + des_exec = deser.deserialize() + des_vm = _vm.VirtualMachine(des_exec) result = veval(des_vm, (i_data, j_data)) tvm.testing.assert_allclose(result.asnumpy(), j_data) @@ -251,11 +255,12 @@ def test_adt_list(): f = relay.Function([], l321) mod["main"] = f - vm = create_vm(mod) - ser = serializer.Serializer(vm) + exe = create_exec(mod) + ser = serializer.Serializer(exe) code, lib = ser.serialize() deser = deserializer.Deserializer(code, lib) - des_vm = deser.deserialize() + des_exec = deser.deserialize() + des_vm = _vm.VirtualMachine(des_exec) result = veval(des_vm) assert len(result) == 2 @@ -297,11 +302,12 @@ def test_adt_compose(): f = relay.Function([y], add_two_body) mod["main"] = f - vm = create_vm(mod) - ser = serializer.Serializer(vm) + exe = create_exec(mod) + ser = serializer.Serializer(exe) code, lib = ser.serialize() deser = deserializer.Deserializer(code, lib) - des_vm = deser.deserialize() + des_exec = deser.deserialize() + des_vm = _vm.VirtualMachine(des_exec) x_data = np.array(np.random.rand()).astype('float32') result = veval(des_vm, x_data) @@ -317,11 +323,12 @@ def test_closure(): clo = ff(relay.const(1.0)) main = clo(relay.const(2.0)) - vm = create_vm(main) - ser = serializer.Serializer(vm) + exe = create_exec(main) + ser = serializer.Serializer(exe) code, lib = ser.serialize() deser = deserializer.Deserializer(code, lib) - des_vm = deser.deserialize() + des_exec = deser.deserialize() + des_vm = _vm.VirtualMachine(des_exec) res = veval(des_vm) tvm.testing.assert_allclose(res.asnumpy(), 3.0) diff --git a/tests/python/unittest/test_runtime_vm_profiler.py b/tests/python/unittest/test_runtime_vm_profiler.py index b5ce0ec70e51..531dd28dc41b 100644 --- a/tests/python/unittest/test_runtime_vm_profiler.py 
+++ b/tests/python/unittest/test_runtime_vm_profiler.py @@ -26,9 +26,9 @@ def test_basic(): mod, params = resnet.get_workload() target = 'llvm' ctx = tvm.cpu() - vm = relay.profiler_vm.compile(mod, target) - vm.init(ctx) - vm.load_params(params) + exe = relay.profiler_vm.compile(mod, target, params=params) + exe.set_context(ctx) + vm = relay.profiler_vm.VirtualMachineProfiler(exe) data = np.random.rand(1, 3, 224, 224).astype('float32') res = vm.invoke("main", [data]) From e6ecc0749c6aa4f2d20e6695ed23842687acf89c Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Mon, 14 Oct 2019 00:43:34 +0000 Subject: [PATCH 2/6] Address comments --- include/tvm/runtime/vm.h | 7 ++++--- src/runtime/vm/executable.cc | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/include/tvm/runtime/vm.h b/include/tvm/runtime/vm.h index 14816e84c6ab..05dc4691e442 100644 --- a/include/tvm/runtime/vm.h +++ b/include/tvm/runtime/vm.h @@ -434,7 +434,7 @@ struct VMFrame { /*! \brief The executable emitted by the VM compiler. * * The executable contains information (e.g. data in different memory regions) - * to create a virtual machine. + * to run in a virtual machine. */ class Executable : public ModuleNode { public: @@ -450,7 +450,7 @@ class Executable : public ModuleNode { const std::shared_ptr& sptr_to_self) final; /*! - * \brief Get the serialized form of the `functions` in `vm_`. This is + * \brief Get the serialized form of the `functions`. This is * essentially bytecode serialization. * * \return The serialized vm bytecode. @@ -501,7 +501,8 @@ class Executable : public ModuleNode { return "VMExecutable"; } - /*! \brief The runtime module/library that contains hardware dependent code. */ + /*! \brief The runtime module/library that contains both the host and also the device + * code when executing on non-CPU devices. */ runtime::Module lib; /*! \brief The global constant pool. 
*/ std::vector constants; diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index 9c85922926c4..5cccbcef4988 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -87,7 +87,7 @@ std::string Executable::GetBytecode() const { << func.instructions.size() << std::endl; // Print pramams of a `VMFunction`. - oss << "# Parameters:"<< std::endl; + oss << "# Parameters: "<< std::endl; for (const auto& param : func.params) { oss << param << " "; } From e9c22386e66e00b6d2861ade8036bd3d680fddfc Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Wed, 16 Oct 2019 18:41:41 +0000 Subject: [PATCH 3/6] move ctx back to vm --- include/tvm/runtime/vm.h | 30 +++++------ python/tvm/relay/backend/profiler_vm.py | 1 + python/tvm/relay/backend/serializer.py | 3 +- python/tvm/relay/backend/vm.py | 42 +++++---------- src/runtime/vm/deserializer.cc | 15 ------ src/runtime/vm/executable.cc | 34 +------------ src/runtime/vm/profiler/vm.cc | 25 +++++++-- src/runtime/vm/profiler/vm.h | 4 +- src/runtime/vm/serializer.cc | 13 ----- src/runtime/vm/serializer.h | 6 --- src/runtime/vm/vm.cc | 51 +++++++++++++++---- tests/python/relay/test_vm.py | 4 +- tests/python/relay/test_vm_serialization.py | 16 ++++-- .../unittest/test_runtime_vm_profiler.py | 2 +- 14 files changed, 110 insertions(+), 136 deletions(-) diff --git a/include/tvm/runtime/vm.h b/include/tvm/runtime/vm.h index 05dc4691e442..39ac86ae5c0e 100644 --- a/include/tvm/runtime/vm.h +++ b/include/tvm/runtime/vm.h @@ -484,17 +484,6 @@ class Executable : public ModuleNode { */ runtime::Module GetLib() const { return lib; } - /*! - * \brief Set the execution context for the executable. - * - * \param ctxs The list of TVMContext. - */ - void SetContext(const std::vector& ctxs); - - /*! \brief Get device context for params. 
- */ - TVMContext GetParamsContext() const; - virtual ~Executable() {} const char* type_key() const final { @@ -514,9 +503,6 @@ class Executable : public ModuleNode { std::unordered_map primitive_map; /*! \brief The virtual machine's function table. */ std::vector functions; - - /*! \brief The set of TVM contexts the VM is currently executing on. */ - std::vector ctxs; }; /*! \brief The virtual machine. @@ -591,6 +577,9 @@ class VirtualMachine : public runtime::ModuleNode { /*! \brief The executable the VM will operate on. */ const Executable* exec; + /*! \brief The set of TVM contexts the VM is currently executing on. */ + std::vector ctxs; + /*! \brief Push a call frame on to the call stack. */ void PushFrame(Index arg_count, Index ret_pc, const VMFunction& vm_func); /*! \brief Pop a frame off the call stack. @@ -634,15 +623,24 @@ class VirtualMachine : public runtime::ModuleNode { VirtualMachine() : frames(), func_index(0), code(nullptr), pc(0), exec(nullptr) {} - /*! \brief Initialize the virtual machine using an executable. + /*! \brief load the executable for the virtual machine. * \param exec The executable. */ - void Init(const Executable* exec); + void LoadExecutable(const Executable* exec); + + /*! \brief Initialize the virtual machine for a set of contexts. + * \param contexts The set of TVM contexts. + */ + void Init(const std::vector& contexts); /*! \brief Run VM dispatch loop. */ void RunLoop(); + /*! \brief Get device context for params. + */ + TVMContext GetParamsContext() const; + private: /*! \brief Invoke a global setting up the VM state to execute. 
* diff --git a/python/tvm/relay/backend/profiler_vm.py b/python/tvm/relay/backend/profiler_vm.py index 30c99611b7c9..b36715249f0a 100644 --- a/python/tvm/relay/backend/profiler_vm.py +++ b/python/tvm/relay/backend/profiler_vm.py @@ -77,6 +77,7 @@ def __init__(self, mod): super().__init__(mod) m = mod.module if isinstance(mod, vm.Executable) else mod self.mod = _vm._VirtualMachineDebug(m) + self._init = self.mod["init"] self._invoke = self.mod["invoke"] self._get_stat = self.mod["get_stat"] diff --git a/python/tvm/relay/backend/serializer.py b/python/tvm/relay/backend/serializer.py index 4680ada2eaac..7bee54c2d34d 100644 --- a/python/tvm/relay/backend/serializer.py +++ b/python/tvm/relay/backend/serializer.py @@ -96,7 +96,6 @@ def serialize(self): ctx = tvm.cpu() target = "llvm" executable = relay.vm..compile(mod, target) - executable.set_context(ctx) # serialize. ser = relay.serializer.Serializer(executable) @@ -117,7 +116,7 @@ def serialize(self): des_exec = deser.deserialize() # execute the deserialized executable. - des_exec.set_context(ctx) x_data = np.random.rand(10, 10).astype('float32') des_vm = relay.vm.VirtualMachine(des_exec) + des_vm.init(ctx) res = des_vm.run(x_data) diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index 142e2b7d32e6..7c14965d37d9 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -24,7 +24,6 @@ import tvm from tvm import autotvm -from tvm import TVMContext from tvm.relay import expr as _expr from . import _vm from . import vmobj as _obj @@ -57,37 +56,10 @@ class Executable(object): """Relay VM executable""" def __init__(self, mod): self.mod = mod - self._set_context = self.mod["set_context"] self._get_lib = self.mod["get_lib"] self._get_bytecode = self.mod["get_bytecode"] self._get_stats = self.mod["get_stats"] - def set_context(self, ctx): - """Initialize the context of the VM executable. 
- - Parameters - ---------- - ctx : Union[:py:class:`tvm.TVMContext`, List[py:class:`tvm.TVMContext`]] - The runtime context to run the code on. - """ - - if isinstance(ctx, TVMContext): - ctx = [ctx] - elif not isinstance(ctx, (list, tuple)): - raise ValueError("ctx has to be the type of TVMContext or a list of " - "TVMContext") - # args[0], args[1] are used as the primary/fallback context type and id - # for heterogeneous execution. - args = [] - for cur_ctx in ctx: - if not isinstance(cur_ctx, TVMContext): - raise ValueError("ctx has to be the type of TVMContext or a list " - "of TVMContext") - args.append(cur_ctx.device_type) - args.append(cur_ctx.device_id) - - self._set_context(*args) - @property def lib(self): """Get the library that contains hardware dependent code. @@ -182,8 +154,20 @@ def __init__(self, mod): "tvm.Module, but received {}".format(type(mod))) m = mod.module if isinstance(mod, Executable) else mod self.mod = _vm._VirtualMachine(m) + self._init = self.mod["init"] self._invoke = self.mod["invoke"] + def init(self, ctx): + """Initialize the context in the VM. + + Parameters + ---------- + ctx : :py:class:`TVMContext` + The runtime context to run the code on. + """ + args = [ctx.device_type, ctx.device_id] + self._init(*args) + def invoke(self, func_name, *args): """Invoke a function. @@ -344,8 +328,8 @@ def __init__(self, mod, ctx, target): self.ctx = ctx self.target = target self.executable = compile(mod, target) - self.executable.set_context(ctx) self.vm = VirtualMachine(self.executable) + self.vm.init(ctx) def _make_executor(self, expr=None): main = self.mod["main"] diff --git a/src/runtime/vm/deserializer.cc b/src/runtime/vm/deserializer.cc index eb191545bcc6..1a748d2523e7 100644 --- a/src/runtime/vm/deserializer.cc +++ b/src/runtime/vm/deserializer.cc @@ -82,9 +82,6 @@ void Deserializer::Deserialize() { // Code section. DeserializeCodeSection(); - - // Context section. 
- DeserializeContextSection(); } void Deserializer::DeserializeGlobalSection() { @@ -313,18 +310,6 @@ void Deserializer::DeserializeCodeSection() { } } -void Deserializer::DeserializeContextSection() { - std::vector ctxs; - STREAM_CHECK(strm_->Read(&ctxs), "context"); - CHECK_EQ(ctxs.size() % 2, 0U); - for (size_t i = 0; i < ctxs.size(); i += 2) { - TVMContext ctx; - ctx.device_type = DLDeviceType(ctxs[i]); - ctx.device_id = static_cast(ctxs[i + 1]); - exec_->ctxs.push_back(ctx); - } -} - runtime::Module CreateDeserializer(const std::string& code, const runtime::Module lib) { std::shared_ptr exec = std::make_shared(); exec->Init(code, lib); diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index 5cccbcef4988..f413908d3007 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -39,20 +39,7 @@ namespace vm { PackedFunc Executable::GetFunction(const std::string& name, const std::shared_ptr& sptr_to_self) { - if (name == "set_context") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - CHECK_EQ(args.size() % 2, 0); - std::vector contexts; - for (int i = 0; i < args.size() / 2; ++i) { - TVMContext ctx; - int device_type = args[i * 2]; - ctx.device_type = DLDeviceType(device_type); - ctx.device_id = args[i * 2 + 1]; - contexts.push_back(ctx); - } - this->SetContext(contexts); - }); - } else if (name == "get_lib") { + if (name == "get_lib") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetLib(); }); @@ -70,10 +57,6 @@ PackedFunc Executable::GetFunction(const std::string& name, } } -inline void Executable::SetContext(const std::vector& ctxs) { - this->ctxs = ctxs; -} - std::string Executable::GetBytecode() const { std::ostringstream oss; @@ -166,20 +149,6 @@ std::string Executable::Stats() const { return oss.str(); } -TVMContext Executable::GetParamsContext() const { - CHECK(!ctxs.empty()) << "context has not been set yet."; - - // Use the fallback device if 
no device index is available. - int fallback_device_type = static_cast(ctxs[0].device_type); - // TODO(wweic): For heterogeneous execution, get device information from byte - - const auto& cit = - std::find_if(ctxs.begin(), ctxs.end(), [&fallback_device_type](const TVMContext& c) { - return fallback_device_type == static_cast(c.device_type); - }); - return (cit == ctxs.end() ? ctxs[0] : *cit); -} - TVM_REGISTER_GLOBAL("relay._vm.GetNumOfGlobals") .set_body([](TVMArgs args, TVMRetValue* rv) { runtime::Module mod = args[0]; @@ -188,7 +157,6 @@ TVM_REGISTER_GLOBAL("relay._vm.GetNumOfGlobals") *rv = static_cast(exec->global_map.size()); }); - TVM_REGISTER_GLOBAL("relay._vm.GetGlobalFields") .set_body([](TVMArgs args, TVMRetValue* rv) { runtime::Module mod = args[0]; diff --git a/src/runtime/vm/profiler/vm.cc b/src/runtime/vm/profiler/vm.cc index dc5ec9943b3b..821de0bda245 100644 --- a/src/runtime/vm/profiler/vm.cc +++ b/src/runtime/vm/profiler/vm.cc @@ -67,13 +67,26 @@ PackedFunc VirtualMachineDebug::GetFunction( os << "Total Duration " << total_duration << " us" << std::endl; *rv = os.str(); }); + } else if (name == "init") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + CHECK_EQ(args.size() % 2, 0); + std::vector contexts; + for (int i = 0; i < args.size() / 2; ++i) { + TVMContext ctx; + int device_type = args[i * 2]; + ctx.device_type = DLDeviceType(device_type); + ctx.device_id = args[i * 2 + 1]; + contexts.push_back(ctx); + } + this->Init(contexts); + }); } else { return VirtualMachine::GetFunction(name, sptr_to_self); } } -void VirtualMachineDebug::Init(const Executable* exec) { - VirtualMachine::Init(exec); +void VirtualMachineDebug::LoadExecutable(const Executable* exec) { + VirtualMachine::LoadExecutable(exec); CHECK(this->exec); for (auto kv : this->exec->primitive_map) { packed_index_map[kv.second] = kv.first; @@ -81,12 +94,16 @@ void VirtualMachineDebug::Init(const Executable* exec) { } } +void VirtualMachineDebug::Init(const 
std::vector& ctxs) { + VirtualMachine::Init(ctxs); +} + void VirtualMachineDebug::InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count, Index output_size, const std::vector& args) { CHECK(this->exec); - auto ctx = this->exec->GetParamsContext(); + auto ctx = this->GetParamsContext(); // warmup VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size, args); @@ -108,7 +125,7 @@ void VirtualMachineDebug::InvokePacked(Index packed_index, runtime::Module CreateVirtualMachineDebug(const Executable* exec) { std::shared_ptr vm = std::make_shared(); - vm->Init(exec); + vm->LoadExecutable(exec); return runtime::Module(vm); } diff --git a/src/runtime/vm/profiler/vm.h b/src/runtime/vm/profiler/vm.h index d2e71597e80d..ff3296cb6c16 100644 --- a/src/runtime/vm/profiler/vm.h +++ b/src/runtime/vm/profiler/vm.h @@ -47,11 +47,13 @@ class VirtualMachineDebug : public VirtualMachine { void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count, Index output_size, const std::vector& args) final; - void Init(const Executable* exec); + void LoadExecutable(const Executable* exec); ~VirtualMachineDebug() {} private: + void Init(const std::vector& ctxs); + std::unordered_map packed_index_map; std::unordered_map> op_durations; std::unordered_map op_invokes; diff --git a/src/runtime/vm/serializer.cc b/src/runtime/vm/serializer.cc index 3d52abdc965b..0d7fb2b2f2e7 100644 --- a/src/runtime/vm/serializer.cc +++ b/src/runtime/vm/serializer.cc @@ -81,9 +81,6 @@ TVMByteArray Serializer::Serialize() { // Code section. SerializeCodeSection(); - // Context section. 
- SerializeContextSection(); - TVMByteArray arr; arr.data = code_.c_str(); arr.size = code_.length(); @@ -300,16 +297,6 @@ void Serializer::SerializeCodeSection() { } } -void Serializer::SerializeContextSection() { - CHECK(!exec_->ctxs.empty()); - std::vector serialized_ctx; - for (const auto& ctx : exec_->ctxs) { - serialized_ctx.push_back(static_cast(ctx.device_type)); - serialized_ctx.push_back(static_cast(ctx.device_id)); - } - strm_->Write(serialized_ctx); -} - runtime::Module CreateSerializer(const Executable* exec) { std::shared_ptr serializer = std::make_shared(); serializer->Init(exec); diff --git a/src/runtime/vm/serializer.h b/src/runtime/vm/serializer.h index 33e64de5c959..b3c893878bf6 100644 --- a/src/runtime/vm/serializer.h +++ b/src/runtime/vm/serializer.h @@ -32,8 +32,6 @@ * - The `primitive_map` that contains the name of individual primitive operators. * - The `functions`, e.g., the `VMFunction`. Each `VMFunction` is composed of * a list of instructions/bytecode. - * - The `ctxs` that contains the device context used to execute the hardware - * dependent code. * * Note that only the library is returned as a separate module. All othere parts * are stored in a single serialized code that is organized with the following @@ -43,7 +41,6 @@ * - Primitive name section, containing the function name of the primitive ops * used by the virtual machine. * - Code section, handling the VM functions and bytecode. - * - Context section, saving the context information. * * The code section is again organized as follows for each VM function: * func_name, register_file_size, num_instructions (N) @@ -136,9 +133,6 @@ class Serializer : public runtime::ModuleNode { /*! \brief Serialize the vm functions in exec_. */ void SerializeCodeSection(); - /*! \brief Serialize the context in exec_. */ - void SerializeContextSection(); - /*! \brief The Relay virtual machine executable to be serialized. 
*/ const Executable* exec_; diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index d4dd8340fbf2..78b74768b930 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -582,7 +582,7 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, auto func_index = gvit->second; const auto& vm_func = exec->functions[func_index]; const auto& param_names = vm_func.params; - auto ctx = exec->GetParamsContext(); + auto ctx = this->GetParamsContext(); // Prepare the func args std::vector func_args(param_names.size()); @@ -605,12 +605,40 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, *rv = this->Invoke(vm_func, func_args); }); + } else if (name == "init") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + CHECK_EQ(args.size() % 2, 0); + std::vector contexts; + for (int i = 0; i < args.size() / 2; ++i) { + TVMContext ctx; + int device_type = args[i * 2]; + ctx.device_type = DLDeviceType(device_type); + ctx.device_id = args[i * 2 + 1]; + contexts.push_back(ctx); + } + this->Init(contexts); + }); } else { LOG(FATAL) << "Unknown packed function: " << name; return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {}); } } +TVMContext VirtualMachine::GetParamsContext() const { + CHECK(!ctxs.empty()) << "Context has not been initialized yet." + << "\n"; + + // Use the fallback device if no device index is available. + int fallback_device_type = static_cast(ctxs[0].device_type); + // TODO(wweic): For heterogeneous execution, get device information from byte + + const auto& cit = + std::find_if(ctxs.begin(), ctxs.end(), [&fallback_device_type](const TVMContext& c) { + return fallback_device_type == static_cast(c.device_type); + }); + return (cit == ctxs.end() ? 
ctxs[0] : *cit); +} + void VirtualMachine::PushFrame(Index arg_count, Index ret_pc, const VMFunction& vm_func) { auto frame = VMFrame(ret_pc, func_index, arg_count, code, vm_func.register_file_size); frames.push_back(frame); @@ -646,7 +674,7 @@ ObjectRef VirtualMachine::Invoke(const VMFunction& func, const std::vectorGetAllocator(exec->ctxs[0]); + auto alloc = MemoryManager::Global()->GetAllocator(ctxs[0]); DLOG(INFO) << "Memory used: " << alloc->UsedMemory() << " B"; return return_register; } @@ -692,7 +720,7 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv); } -void VirtualMachine::Init(const Executable* exec) { +void VirtualMachine::LoadExecutable(const Executable* exec) { CHECK(exec) << "The executable is not created yet."; this->exec = exec; @@ -711,6 +739,11 @@ void VirtualMachine::Init(const Executable* exec) { } } + +void VirtualMachine::Init(const std::vector& ctxs) { + this->ctxs = ctxs; +} + inline void VirtualMachine::WriteRegister(Index r, const ObjectRef& val) { frames.back().register_file[r] = val; } @@ -763,7 +796,7 @@ void VirtualMachine::RunLoop() { case Opcode::LoadConst: { auto constant_obj = exec->constants[instr.const_index]; // TODO(wweic) ctx could be obtained from the ctxs list. - auto device_obj = CopyTo(constant_obj, exec->ctxs[0]); + auto device_obj = CopyTo(constant_obj, ctxs[0]); WriteRegister(instr.dst, device_obj); pc++; goto main_loop; @@ -863,8 +896,8 @@ void VirtualMachine::RunLoop() { shape[i] = instr.alloc_tensor.shape[i]; } // TODO(wweic) ctx could be obtained from the ctxs list. 
- auto allocator = MemoryManager::Global()->GetAllocator(exec->ctxs[0]); - auto data = allocator->Empty(shape, instr.alloc_tensor.dtype, exec->ctxs[0]); + auto allocator = MemoryManager::Global()->GetAllocator(ctxs[0]); + auto data = allocator->Empty(shape, instr.alloc_tensor.dtype, ctxs[0]); auto obj = Tensor(data); WriteRegister(instr.dst, obj); pc++; @@ -885,8 +918,8 @@ void VirtualMachine::RunLoop() { auto shape = std::vector(shape_tensor->shape[0]); shape.assign(dims, dims + num_dims); // TODO(wweic) ctx could be obtained from the ctxs list. - auto allocator = MemoryManager::Global()->GetAllocator(exec->ctxs[0]); - auto data = allocator->Empty(shape, instr.alloc_tensor_reg.dtype, exec->ctxs[0]); + auto allocator = MemoryManager::Global()->GetAllocator(ctxs[0]); + auto data = allocator->Empty(shape, instr.alloc_tensor_reg.dtype, ctxs[0]); auto obj = Tensor(data); WriteRegister(instr.dst, obj); pc++; @@ -932,7 +965,7 @@ void VirtualMachine::RunLoop() { runtime::Module CreateVirtualMachine(const Executable* exec) { std::shared_ptr vm = std::make_shared(); - vm->Init(exec); + vm->LoadExecutable(exec); return runtime::Module(vm); } diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index 63450d60a18e..1b40f894db08 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -48,15 +48,15 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm"): mod = relay.Module() mod["main"] = f exe = relay.vm.compile(mod, target) - exe.set_context(ctx) vm = relay.vm.VirtualMachine(exe) + vm.init(ctx) return vm.invoke("main", *args) else: assert isinstance(f, relay.Module), "expected expression or module" mod = f exe = relay.vm.compile(mod, target) - exe.set_context(ctx) vm = relay.vm.VirtualMachine(exe) + vm.init(ctx) ret = vm.invoke("main", *args) return ret diff --git a/tests/python/relay/test_vm_serialization.py b/tests/python/relay/test_vm_serialization.py index 72b82db21ea1..d7effade913c 100644 --- 
a/tests/python/relay/test_vm_serialization.py +++ b/tests/python/relay/test_vm_serialization.py @@ -28,17 +28,15 @@ from tvm.contrib import util from tvm.relay import testing -def create_exec(f, ctx=tvm.cpu(), target="llvm", params=None): +def create_exec(f, target="llvm", params=None): if isinstance(f, relay.Expr): mod = relay.Module() mod["main"] = f executable = _vm.compile(mod, target=target, params=params) - executable.set_context(ctx) return executable else: assert isinstance(f, relay.Module), "expected mod as relay.Module" executable = _vm.compile(f, target=target, params=params) - executable.set_context(ctx) return executable @@ -58,13 +56,13 @@ def get_vm_output(mod, data, params, target, ctx, dtype='float32'): return result.asnumpy().astype(dtype) def get_serialized_output(mod, data, params, target, ctx, dtype='float32'): - vm = create_exec(mod, ctx, target, params=params) + vm = create_exec(mod, target, params=params) ser = serializer.Serializer(vm) code, lib = ser.serialize() deser = deserializer.Deserializer(code, lib) des_exec = deser.deserialize() - des_exec.set_context(ctx) des_vm = _vm.VirtualMachine(des_exec) + des_vm.init(ctx) result = des_vm.run(data) return result.asnumpy().astype(dtype) @@ -147,6 +145,7 @@ def test_save_load(): deser = deserializer.Deserializer(loaded_code, loaded_lib) des_exec = deser.deserialize() des_vm = _vm.VirtualMachine(des_exec) + des_vm.init(tvm.cpu()) res = veval(des_vm, x_data) tvm.testing.assert_allclose(res.asnumpy(), x_data + x_data) @@ -163,6 +162,7 @@ def test_const(): deser = deserializer.Deserializer(code, lib) des_exec = deser.deserialize() des_vm = _vm.VirtualMachine(des_exec) + des_vm.init(tvm.cpu()) x_data = np.random.rand(10, 10).astype('float32') res = veval(des_vm, x_data) tvm.testing.assert_allclose(res.asnumpy(), x_data + 1) @@ -184,6 +184,7 @@ def test_if(): deser = deserializer.Deserializer(code, lib) des_exec = deser.deserialize() des_vm = _vm.VirtualMachine(des_exec) + des_vm.init(tvm.cpu()) # 
same res = veval(des_vm, x_data, x_data) @@ -221,6 +222,7 @@ def test_loop(): deser = deserializer.Deserializer(code, lib) des_exec = deser.deserialize() des_vm = _vm.VirtualMachine(des_exec) + des_vm.init(tvm.cpu()) result = veval(des_vm, i_data, accum_data) tvm.testing.assert_allclose(result.asnumpy(), sum(range(1, loop_bound + 1))) @@ -239,6 +241,7 @@ def test_tuple(): deser = deserializer.Deserializer(code, lib) des_exec = deser.deserialize() des_vm = _vm.VirtualMachine(des_exec) + des_vm.init(tvm.cpu()) result = veval(des_vm, (i_data, j_data)) tvm.testing.assert_allclose(result.asnumpy(), j_data) @@ -261,6 +264,7 @@ def test_adt_list(): deser = deserializer.Deserializer(code, lib) des_exec = deser.deserialize() des_vm = _vm.VirtualMachine(des_exec) + des_vm.init(tvm.cpu()) result = veval(des_vm) assert len(result) == 2 @@ -308,6 +312,7 @@ def test_adt_compose(): deser = deserializer.Deserializer(code, lib) des_exec = deser.deserialize() des_vm = _vm.VirtualMachine(des_exec) + des_vm.init(tvm.cpu()) x_data = np.array(np.random.rand()).astype('float32') result = veval(des_vm, x_data) @@ -329,6 +334,7 @@ def test_closure(): deser = deserializer.Deserializer(code, lib) des_exec = deser.deserialize() des_vm = _vm.VirtualMachine(des_exec) + des_vm.init(tvm.cpu()) res = veval(des_vm) tvm.testing.assert_allclose(res.asnumpy(), 3.0) diff --git a/tests/python/unittest/test_runtime_vm_profiler.py b/tests/python/unittest/test_runtime_vm_profiler.py index 531dd28dc41b..53f573730576 100644 --- a/tests/python/unittest/test_runtime_vm_profiler.py +++ b/tests/python/unittest/test_runtime_vm_profiler.py @@ -27,8 +27,8 @@ def test_basic(): target = 'llvm' ctx = tvm.cpu() exe = relay.profiler_vm.compile(mod, target, params=params) - exe.set_context(ctx) vm = relay.profiler_vm.VirtualMachineProfiler(exe) + vm.init(ctx) data = np.random.rand(1, 3, 224, 224).astype('float32') res = vm.invoke("main", [data]) From b8a8fcdf361aefbf024767f3219efc24ef189ada Mon Sep 17 00:00:00 2001 From: 
Zhi Chen Date: Wed, 16 Oct 2019 23:15:51 +0000 Subject: [PATCH 4/6] make only vm related fields and methods protected --- include/tvm/runtime/vm.h | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/include/tvm/runtime/vm.h b/include/tvm/runtime/vm.h index 39ac86ae5c0e..0f32d0a07859 100644 --- a/include/tvm/runtime/vm.h +++ b/include/tvm/runtime/vm.h @@ -560,6 +560,14 @@ class VirtualMachine : public runtime::ModuleNode { return "VirtualMachine"; } + VirtualMachine() : frames(), func_index(0), code(nullptr), pc(0), exec(nullptr) {} + + /*! \brief load the executable for the virtual machine. + * \param exec The executable. + */ + void LoadExecutable(const Executable* exec); + + protected: /*! \brief The virtual machine's packed function table. */ std::vector packed_funcs; /*! \brief The current stack of call frames. */ @@ -621,13 +629,6 @@ class VirtualMachine : public runtime::ModuleNode { */ ObjectRef Invoke(const std::string& name, const std::vector& args); - VirtualMachine() : frames(), func_index(0), code(nullptr), pc(0), exec(nullptr) {} - - /*! \brief load the executable for the virtual machine. - * \param exec The executable. - */ - void LoadExecutable(const Executable* exec); - /*! \brief Initialize the virtual machine for a set of contexts. * \param contexts The set of TVM contexts. 
*/ From 764b34c2b51ae87dfee6892c7828b9cbff5f1c38 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Thu, 17 Oct 2019 03:05:27 +0000 Subject: [PATCH 5/6] integrate serialization/deserialization to executable --- include/tvm/runtime/vm.h | 67 +++ python/tvm/relay/__init__.py | 2 - python/tvm/relay/backend/deserializer.py | 81 --- python/tvm/relay/backend/serializer.py | 125 ----- python/tvm/relay/backend/vm.py | 94 ++++ src/runtime/vm/deserializer.cc | 324 ------------ src/runtime/vm/deserializer.h | 105 ---- src/runtime/vm/executable.cc | 523 +++++++++++++++++++- src/runtime/vm/serializer.cc | 317 ------------ src/runtime/vm/serializer.h | 152 ------ tests/python/relay/test_vm_serialization.py | 60 +-- 11 files changed, 703 insertions(+), 1147 deletions(-) delete mode 100644 python/tvm/relay/backend/deserializer.py delete mode 100644 python/tvm/relay/backend/serializer.py delete mode 100644 src/runtime/vm/deserializer.cc delete mode 100644 src/runtime/vm/deserializer.h delete mode 100644 src/runtime/vm/serializer.cc delete mode 100644 src/runtime/vm/serializer.h diff --git a/include/tvm/runtime/vm.h b/include/tvm/runtime/vm.h index 0f32d0a07859..011826e6f1a1 100644 --- a/include/tvm/runtime/vm.h +++ b/include/tvm/runtime/vm.h @@ -435,6 +435,12 @@ struct VMFrame { * * The executable contains information (e.g. data in different memory regions) * to run in a virtual machine. + * + * - Global section, containing all globals. + * - Constant section, storing the constant pool. + * - Primitive name section, containing the function name of the primitive ops + * used by the virtual machine. + * - Code section, handling the VM functions and bytecode. */ class Executable : public ModuleNode { public: @@ -449,6 +455,24 @@ class Executable : public ModuleNode { PackedFunc GetFunction(const std::string& name, const std::shared_ptr& sptr_to_self) final; + /*! + * \brief Serialize the executable into global section, constant section, and + * code section. 
+ * + * \return The binary representation of the VM. + */ + TVMByteArray Save(); + + /*! + * \brief Load the saved VM executable. + * + * \param code The bytecode in string. + * \param lib The compiled runtime library. + * + * \return exe The constructed executable. + */ + static runtime::Module Load(const std::string& code, const runtime::Module lib); + /*! * \brief Get the serialized form of the `functions`. This is * essentially bytecode serialization. @@ -466,6 +490,18 @@ class Executable : public ModuleNode { * Each instruction is printed in the following format: * opcode num_fields field1 ... fieldX # The text format. * + * Serializing an `Instruction` requires us to deal with the bytecode. Each line + * of the instructions could be serialized as the following format: + * hash, opcode, f1, f2, ..., fX, field with variable length + * 1. hash: the hash of the instruction. This number will be used to help us + * validate if an instruction is well-formed during deserialization. + * 2. opcode: the opcode of the instruction. + * 3. f1, f2, ..., fX. These fields together represent the fixed fields in + * an instruction, e.g., `from` and `dst` fields of a `Move` instruction. For + * example, `DLDataType` will be unpacked into three fields (code, bits, lanes). + * 4. The rest of the line indicates the field with variable length, e.g., + * the shape of a tensor, the args used by an `InvokePacked` instruction, etc. + * The field starting from # is only used for debugging. The serialized code * doesn't contain it, therefore the deserializer doens't need to handle it. */ @@ -503,6 +539,37 @@ class Executable : public ModuleNode { std::unordered_map primitive_map; /*! \brief The virtual machine's function table. */ std::vector functions; + + private: + /*! \brief Save the globals. */ + void SaveGlobalSection(); + + /*! \brief Save the constant pool. */ + void SaveConstantSection(); + + /*! \brief Save primitive op names. */ + void SavePrimitiveOpNames(); + + /*! 
\brief Save the vm functions. */ + void SaveCodeSection(); + + /*! \brief Load the globals. */ + void LoadGlobalSection(); + + /*! \brief Load the constant pool. */ + void LoadConstantSection(); + + /*! \brief Load primitive op names. */ + void LoadPrimitiveOpNames(); + + /*! \brief Load the vm functions.*/ + void LoadCodeSection(); + + /*! \brief The stream used for serialization. */ + dmlc::Stream* strm_; + + /*! \brief The serialized code. */ + std::string code_; }; /*! \brief The virtual machine. diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index ceb98c4d251e..fff9c99e5007 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -37,8 +37,6 @@ from . import feature from .backend import vm from .backend import profiler_vm -from .backend import serializer -from .backend import deserializer from .backend import vmobj # Root operators diff --git a/python/tvm/relay/backend/deserializer.py b/python/tvm/relay/backend/deserializer.py deleted file mode 100644 index b5e2353be501..000000000000 --- a/python/tvm/relay/backend/deserializer.py +++ /dev/null @@ -1,81 +0,0 @@ -# License .to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# pylint: disable=invalid-name -""" -The Relay Virtual Machine deserializer. - -Python interface for deserializing a Relay VM. -""" -from tvm import module -from tvm._ffi.runtime_ctypes import TVMByteArray -from . import _vm -from . import vm as rly_vm - -def _create_deserializer(code, lib): - """Create a deserializer object. - - Parameters - ---------- - code : bytearray - The serialized virtual machine bytecode. - - lib : :py:class:`~tvm.module.Module` - The serialized runtime module/library that contains the hardware - dependent binary code. - - Returns - ------- - ret : Deserializer - The created virtual machine executable deserializer. - """ - if isinstance(code, (bytes, str)): - code = bytearray(code) - elif not isinstance(code, (bytearray, TVMByteArray)): - raise TypeError("vm is expected to be the type of bytearray or " + - "TVMByteArray, but received {}".format(type(code))) - - if not isinstance(lib, module.Module): - raise TypeError("lib is expected to be the type of tvm.module.Module" + - ", but received {}".format(type(lib))) - return _vm._Deserializer(code, lib) - - -class Deserializer: - """Relay VM executable deserializer. - - Parameters - ---------- - code : bytearray - The serialized virtual machine bytecode. - - lib : :py:class:`~tvm.module.Module` - The serialized runtime module/library that contains the hardware - dependent binary code. - """ - def __init__(self, code, lib): - self.mod = _create_deserializer(code, lib) - self._deserialize = self.mod["deserialize"] - - def deserialize(self): - """Deserialize the serialized bytecode into a Relay VM executable. - - Returns - ------- - ret : Executable - The deserialized Relay VM executable. 
- """ - return rly_vm.Executable(self._deserialize()) diff --git a/python/tvm/relay/backend/serializer.py b/python/tvm/relay/backend/serializer.py deleted file mode 100644 index 7bee54c2d34d..000000000000 --- a/python/tvm/relay/backend/serializer.py +++ /dev/null @@ -1,125 +0,0 @@ -# License .to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=invalid-name -""" -The Relay Virtual Machine serializer. - -Python interface for serializing a Relay VM. -""" -import tvm -from . import _vm -from . import vm as rly_vm - -def _create_serializer(executable): - """Create a VM serializer. - - Parameters - ---------- - executable : Union[Executable, :py:class:`~tvm.module.Module`] - The virtual machine executable to be serialized. - - Returns - ------- - ret : Serializer - The created virtual machine executable serializer. 
- """ - if isinstance(executable, rly_vm.Executable): - executable = executable.module - elif not isinstance(executable, tvm.module.Module): - raise TypeError("executable is expected to be an Executable or " + - "tvm.Module, but received {}".format(type(executable))) - - return _vm._Serializer(executable) - - -class Serializer: - """Relay VM serializer.""" - def __init__(self, executable): - self.mod = _create_serializer(executable) - self._get_lib = self.mod["get_lib"] - self._serialize = self.mod["serialize"] - - def serialize(self): - """Serialize the Relay VM. - - Returns - ------- - code : bytearray - The binary blob representing a serialized Relay VM. It can then be - saved to disk and later deserialized into a new VM. - - lib : :py:class:`~tvm.module.Module` - The runtime module that contains the generated code. It is - basically a library that is composed of hardware dependent code. - - Notes - ----- - The returned code is organized with the following sections in order. - - Global section. This section contains the globals used by the - virtual machine. - - Constant section. This section is used to store the constant pool of - a virtual machine. - - Primitive name section. This section is introduced to accommodate - the list of primitive operator names that will be invoked by the - virtual machine. - - Code section. The VM functions, including bytecode, are sitting in - this section. - - Examples - -------- - .. code-block:: python - - import numpy as np - import tvm - from tvm import relay - - # define a simple network. - x = relay.var('x', shape=(10, 10)) - f = relay.Function([x], x + x) - mod = relay.Module({"main": f}) - - # create a Relay VM. - ctx = tvm.cpu() - target = "llvm" - executable = relay.vm..compile(mod, target) - - # serialize. - ser = relay.serializer.Serializer(executable) - code, lib = ser.serialize() - - # save and load the code and lib file. 
- tmp = tvm.contrib.util.tempdir() - path_lib = tmp.relpath("lib.so") - lib.export_library(path_lib) - with open(tmp.relpath("code.ro"), "wb") as fo: - fo.write(code) - - loaded_lib = tvm.module.load(path_lib) - loaded_code = bytearray(open(tmp.relpath("code.ro"), "rb").read()) - - # deserialize. - deser = relay.deserializer.Deserializer(loaded_code, loaded_lib) - des_exec = deser.deserialize() - - # execute the deserialized executable. - des_vm.init(ctx) - x_data = np.random.rand(10, 10).astype('float32') - des_vm = relay.vm.VirtualMachine(des_exec) - res = des_vm.run(x_data) - print(res.asnumpy()) - """ - return self._serialize(), self._get_lib() diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index 7c14965d37d9..942c93b866f4 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -25,6 +25,7 @@ import tvm from tvm import autotvm from tvm.relay import expr as _expr +from tvm._ffi.runtime_ctypes import TVMByteArray from . import _vm from . import vmobj as _obj from .interpreter import Executor @@ -56,10 +57,103 @@ class Executable(object): """Relay VM executable""" def __init__(self, mod): self.mod = mod + self._save = self.mod["save"] self._get_lib = self.mod["get_lib"] self._get_bytecode = self.mod["get_bytecode"] self._get_stats = self.mod["get_stats"] + def save(self): + """Save the Relay VM Executable. + + Returns + ------- + code : bytearray + The binary blob representing a serialized Relay VM executable. It + can then be saved to disk and later deserialized into a new + Executable. + + lib : :py:class:`~tvm.module.Module` + The runtime module that contains the generated code. It is + basically a library that is composed of hardware dependent code. + + Notes + ----- + The returned code is organized with the following sections in order. + - Global section. This section contains the globals used by the + virtual machine. + - Constant section. 
This section is used to store the constant pool of + a virtual machine. + - Primitive name section. This section is introduced to accommodate + the list of primitive operator names that will be invoked by the + virtual machine. + - Code section. The VM functions, including bytecode, are sitting in + this section. + + Examples + -------- + + .. code-block:: python + + import numpy as np + import tvm + from tvm import relay + # define a simple network. + x = relay.var('x', shape=(10, 10)) + f = relay.Function([x], x + x) + mod = relay.Module({"main": f}) + # create a Relay VM. + ctx = tvm.cpu() + target = "llvm" + executable = relay.vm.compile(mod, target) + code, lib = executable.save() + # save and load the code and lib file. + tmp = tvm.contrib.util.tempdir() + path_lib = tmp.relpath("lib.so") + lib.export_library(path_lib) + with open(tmp.relpath("code.ro"), "wb") as fo: + fo.write(code) + loaded_lib = tvm.module.load(path_lib) + loaded_code = bytearray(open(tmp.relpath("code.ro"), "rb").read()) + # deserialize. + des_exec = relay.vm.Executable.load_exec(loaded_code, loaded_lib) + # execute the deserialized executable. + x_data = np.random.rand(10, 10).astype('float32') + des_vm = relay.vm.VirtualMachine(des_exec) + des_vm.init(ctx) + res = des_vm.run(x_data) + print(res.asnumpy()) + """ + return self._save(), self._get_lib() + + @staticmethod + def load_exec(bytecode, lib): + """Construct an executable from saved artifacts. + + Parameters + ---------- + bytecode : bytearray + The binary blob representing the Relay VM bytecode. + + lib : :py:class:`~tvm.module.Module` + The runtime module that contains the generated code. + + Returns + ------- + exec: Executable + An executable constructed using the provided artifacts. 
+ """ + if isinstance(bytecode, (bytes, str)): + code = bytearray(bytecode) + elif not isinstance(bytecode, (bytearray, TVMByteArray)): + raise TypeError("bytecode is expected to be the type of bytearray " + + "or TVMByteArray, but received {}".format(type(code))) + + if not isinstance(lib, tvm.module.Module): + raise TypeError("lib is expected to be the type of tvm.module.Module" + + ", but received {}".format(type(lib))) + + return Executable(_vm.Load_Executable(bytecode, lib)) + @property def lib(self): """Get the library that contains hardware dependent code. diff --git a/src/runtime/vm/deserializer.cc b/src/runtime/vm/deserializer.cc deleted file mode 100644 index 1a748d2523e7..000000000000 --- a/src/runtime/vm/deserializer.cc +++ /dev/null @@ -1,324 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2019 by Contributors - * \file src/runtime/vm/deserializer.cc - * \brief Implementation of APIs to deserialize the serialized VM executable. 
- */ - -#include "deserializer.h" - -#include -#include -#include - -#include "serialize_util.h" - -namespace tvm { -namespace runtime { -namespace vm { - -#define STREAM_CHECK(val, section) \ - CHECK(val) << "Invalid VM file format in the " << section << " section." \ - << "\n"; - -inline void Deserializer::Init(const std::string& code, const runtime::Module& lib) { - code_ = code; - exec_ = std::make_shared(); - exec_->lib = lib; - strm_ = new dmlc::MemoryStringStream(&code_); -} - -runtime::PackedFunc Deserializer::GetFunction( - const std::string& name, - const std::shared_ptr& sptr_to_self) { - if (name == "deserialize") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - this->Deserialize(); - *rv = runtime::Module(exec_); - }); - } else { - LOG(FATAL) << "Unknown packed function: " << name; - return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {}); - } -} - -void Deserializer::Deserialize() { - // Check header. - uint64_t header; - STREAM_CHECK(strm_->Read(&header), "header"); - STREAM_CHECK(header == kTVMVMBytecodeMagic, "header"); - - // Check version. - std::string version; - STREAM_CHECK(strm_->Read(&version), "version"); - STREAM_CHECK(version == TVM_VERSION, "version"); - - // Global section. - DeserializeGlobalSection(); - - // Constant section. - DeserializeConstantSection(); - - // Primitive names that will be invoked by `InvokePacked` instructions. - DeserializePrimitiveOpNames(); - - // Code section. - DeserializeCodeSection(); -} - -void Deserializer::DeserializeGlobalSection() { - std::vector globals; - STREAM_CHECK(strm_->Read(&globals), "global"); - for (size_t i = 0; i < globals.size(); i++) { - exec_->global_map.insert({globals[i], i}); - } -} - -void Deserializer::DeserializeConstantSection() { - uint64_t sz; - // Load the number of constants. - STREAM_CHECK(strm_->Read(&sz, sizeof(sz)), "constant"); - - size_t size = static_cast(sz); - // Load each of the constants. 
- for (size_t i = 0; i < size; i++) { - runtime::NDArray constant; - STREAM_CHECK(constant.Load(strm_), "constant"); - runtime::ObjectRef obj = runtime::vm::Tensor(constant); - exec_->constants.push_back(obj); - } -} - -void Deserializer::DeserializePrimitiveOpNames() { - std::vector primitive_names; - STREAM_CHECK(strm_->Read(&primitive_names), "primitive name"); - for (size_t i = 0; i < primitive_names.size(); i++) { - exec_->primitive_map.insert({primitive_names[i], i}); - } -} - -// Extract the `cnt` number of fields started at `start` from the list -// `instr_fields`. -inline std::vector ExtractFields(const std::vector& instr_fields, - Index start, - Index cnt) { - CHECK_LE(static_cast(start + cnt), instr_fields.size()); - std::vector ret; - for (auto i = start; i < start + cnt; i++) { - ret.push_back(instr_fields[i]); - } - return ret; -} - -Instruction DeserializeInstruction(const VMInstructionSerializer& instr) { - Opcode opcode = static_cast(instr.opcode); - switch (opcode) { - case Opcode::Move: { - // Number of fields = 2 - DCHECK_EQ(instr.fields.size(), 2U); - return Instruction::Move(instr.fields[0], instr.fields[1]); - } - case Opcode::Ret: { - // Number of fields = 1 - DCHECK_EQ(instr.fields.size(), 1U); - return Instruction::Ret(instr.fields[0]); - } - case Opcode::Fatal: { - // Number of fields = 0 - DCHECK(instr.fields.empty()); - return Instruction::Fatal(); - } - case Opcode::InvokePacked: { - // Number of fields = 3 + instr.arity - DCHECK_GE(instr.fields.size(), 3U); - DCHECK_EQ(instr.fields.size(), 3U + static_cast(instr.fields[1])); - - Index packed_index = instr.fields[0]; - Index arity = instr.fields[1]; - Index output_size = instr.fields[2]; - std::vector args = ExtractFields(instr.fields, 3, arity); - return Instruction::InvokePacked(packed_index, arity, output_size, args); - } - case Opcode::AllocTensor: { - // Number of fields = 5 + instr.alloc_tensor.ndim - DCHECK_GE(instr.fields.size(), 5U); - DCHECK_EQ(instr.fields.size(), 5U + 
static_cast(instr.fields[3])); - - DLDataType dtype; - dtype.code = instr.fields[0]; - dtype.bits = instr.fields[1]; - dtype.lanes = instr.fields[2]; - - Index ndim = instr.fields[3]; - RegName dst = instr.fields[4]; - - std::vector shape = ExtractFields(instr.fields, 5, ndim); - - return Instruction::AllocTensor(shape, dtype, dst); - } - case Opcode::AllocTensorReg: { - // Number of fields = 5 - DCHECK_EQ(instr.fields.size(), 5U); - Index shape_register = instr.fields[0]; - - DLDataType dtype; - dtype.code = instr.fields[1]; - dtype.bits = instr.fields[2]; - dtype.lanes = instr.fields[3]; - - RegName dst = instr.fields[4]; - - return Instruction::AllocTensorReg(shape_register, dtype, dst); - } - case Opcode::AllocDatatype: { - // Number of fields = 3 + instr.num_fields - DCHECK_GE(instr.fields.size(), 3U); - DCHECK_EQ(instr.fields.size(), 3U + static_cast(instr.fields[1])); - - Index constructor_tag = instr.fields[0]; - Index num_fields = instr.fields[1]; - RegName dst = instr.fields[2]; - std::vector fields = ExtractFields(instr.fields, 3, num_fields); - - return Instruction::AllocDatatype(constructor_tag, num_fields, fields, dst); - } - case Opcode::AllocClosure: { - // Number of fields = 3 + instr.num_freevar - DCHECK_GE(instr.fields.size(), 3U); - DCHECK_EQ(instr.fields.size(), 3U + static_cast(instr.fields[1])); - - Index clo_index = instr.fields[0]; - Index num_freevar = instr.fields[1]; - RegName dst = instr.fields[2]; - std::vector free_vars = ExtractFields(instr.fields, 3, num_freevar); - - return Instruction::AllocClosure(clo_index, num_freevar, free_vars, dst); - } - case Opcode::If: { - // Number of fields = 4 - DCHECK_EQ(instr.fields.size(), 4U); - Index test = instr.fields[0]; - Index target = instr.fields[1]; - Index true_offset = instr.fields[2]; - Index false_offset = instr.fields[3]; - - return Instruction::If(test, target, true_offset, false_offset); - } - case Opcode::Invoke: { - // Number of fields = 3 + instr.num_args - 
DCHECK_GE(instr.fields.size(), 3U); - DCHECK_EQ(instr.fields.size(), 3U + static_cast(instr.fields[1])); - - Index func_index = instr.fields[0]; - Index num_args = instr.fields[1]; - RegName dst = instr.fields[2]; - std::vector args = ExtractFields(instr.fields, 3, num_args); - - return Instruction::Invoke(func_index, args, dst); - } - case Opcode::InvokeClosure: { - // Number of fields = 3 + instr.num_closure_args - DCHECK_GE(instr.fields.size(), 3U); - DCHECK_EQ(instr.fields.size(), 3U + static_cast(instr.fields[1])); - - Index closure = instr.fields[0]; - Index num_closure_args = instr.fields[1]; - RegName dst = instr.fields[2]; - std::vector args = ExtractFields(instr.fields, 3, num_closure_args); - - return Instruction::InvokeClosure(closure, args, dst); - } - case Opcode::LoadConst: { - // Number of fields = 2 - DCHECK_EQ(instr.fields.size(), 2U); - return Instruction::LoadConst(instr.fields[0], instr.fields[1]); - } - case Opcode::LoadConsti: { - // Number of fields = 2 - DCHECK_EQ(instr.fields.size(), 2U); - return Instruction::LoadConsti(instr.fields[0], instr.fields[1]); - } - case Opcode::GetField: { - // Number of fields = 3 - DCHECK_EQ(instr.fields.size(), 3U); - return Instruction::GetField(instr.fields[0], instr.fields[1], instr.fields[2]); - } - case Opcode::GetTag: { - // Number of fields = 2 - DCHECK_EQ(instr.fields.size(), 2U); - return Instruction::GetTag(instr.fields[0], instr.fields[1]); - } - case Opcode::Goto: { - // Number of fields = 1 - DCHECK_EQ(instr.fields.size(), 1U); - return Instruction::Goto(instr.fields[0]); - } - default: - LOG(FATAL) << "Invalid opcode" << instr.opcode; - return Instruction(); - } -} - -void Deserializer::DeserializeCodeSection() { - // Load the number of functions. - uint64_t sz; - STREAM_CHECK(strm_->Read(&sz, sizeof(sz)), "code"); - - size_t num_funcs = static_cast(sz); - exec_->functions.resize(num_funcs); - for (size_t i = 0; i < num_funcs; i++) { - // Load the function info. 
- VMFunctionSerializer loaded_func; - STREAM_CHECK(loaded_func.Load(strm_), "code/function"); - - // Load the instructions. - std::vector instructions; - for (size_t j = 0; j < loaded_func.num_instructions; j++) { - VMInstructionSerializer instr; - std::vector instr_fields; - STREAM_CHECK(instr.Load(strm_), "code/instruction"); - instructions.push_back(DeserializeInstruction(instr)); - } - - // Create the VM function. - VMFunction vm_func = VMFunction(loaded_func.name, - loaded_func.params, - instructions, - loaded_func.register_file_size); - auto it = exec_->global_map.find(loaded_func.name); - CHECK(it != exec_->global_map.end()); - CHECK_LE(it->second, exec_->global_map.size()); - exec_->functions[it->second] = vm_func; - } -} - -runtime::Module CreateDeserializer(const std::string& code, const runtime::Module lib) { - std::shared_ptr exec = std::make_shared(); - exec->Init(code, lib); - return runtime::Module(exec); -} - -TVM_REGISTER_GLOBAL("relay._vm._Deserializer") -.set_body_typed(CreateDeserializer); - -} // namespace vm -} // namespace runtime -} // namespace tvm diff --git a/src/runtime/vm/deserializer.h b/src/runtime/vm/deserializer.h deleted file mode 100644 index 54eb02075e8d..000000000000 --- a/src/runtime/vm/deserializer.h +++ /dev/null @@ -1,105 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. 
See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2019 by Contributors - * \file src/runtime/vm/deserializer.h - * \brief Define a deserializer for the serialized Relay VM executable. - */ - -#ifndef TVM_RUNTIME_VM_DESERIALIZER_H_ -#define TVM_RUNTIME_VM_DESERIALIZER_H_ - -#include -#include -#include - -#include -#include -#include -#include - -namespace tvm { -namespace runtime { -namespace vm { - -using namespace tvm::runtime::vm; -namespace runtime = tvm::runtime; - -class Deserializer : public runtime::ModuleNode { - public: - /*! - * \brief Initialize the deserializer for creating a virtual machine - * executable object. - * - * \param code The serialized code. - * \param lib The serialized runtime module/library that contains the - * hardware dependent code. - */ - void Init(const std::string& code, const runtime::Module& lib); - - /*! - * \brief Return the member function to the frontend. - * - * \param name The name of the function. - * \param sptr_to_self The pointer to the module node. - * - * \return The corresponding member function. - */ - PackedFunc GetFunction(const std::string& name, - const std::shared_ptr& sptr_to_self) final; - - const char* type_key() const final { return "Deserializer"; } - - /*! \brief Deserialize the serialized VM executable. */ - void Deserialize(); - - virtual ~Deserializer() { delete strm_; } - - private: - /*! \brief Deserialize the globals in `exec_`. */ - void DeserializeGlobalSection(); - - /*! \brief Deserialize the constant pool in `exec_`. */ - void DeserializeConstantSection(); - - /*! \brief Deserialize primitive op names in `exec_`. */ - void DeserializePrimitiveOpNames(); - - /*! \brief Deserialize the vm functions in `exec_`. */ - void DeserializeCodeSection(); - - /*! \brief Deserialize the context in `exec_`. */ - void DeserializeContextSection(); - - /*! \brief The code to be serialized. */ - std::string code_; - - /*! 
\brief The stream used for serialization. */ - dmlc::Stream* strm_; - - /*! \brief The VM executable to be created. */ - std::shared_ptr exec_; -}; - -} // namespace vm -} // namespace runtime -} // namespace tvm - -#endif // TVM_RUNTIME_VM_DESERIALIZER_H_ diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index f413908d3007..8768ddf53232 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -24,19 +24,32 @@ */ #include +#include +#include #include #include +#include #include #include +#include #include -#include "serializer.h" +#include "serialize_util.h" namespace tvm { namespace runtime { namespace vm { +#define STREAM_CHECK(val, section) \ + CHECK(val) << "Invalid VM file format in the " << section << " section." \ + << "\n"; + +// Helper to serialize a vm instruction. +VMInstructionSerializer SerializeInstruction(const Instruction& instr); +// Helper to deserialize a serialized vm instruction. +Instruction DeserializeInstruction(const VMInstructionSerializer& instr); + PackedFunc Executable::GetFunction(const std::string& name, const std::shared_ptr& sptr_to_self) { if (name == "get_lib") { @@ -51,6 +64,10 @@ PackedFunc Executable::GetFunction(const std::string& name, return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->Stats(); }); + } else if (name == "save") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->Save(); + }); } else { LOG(FATAL) << "Unknown packed function: " << name; return PackedFunc(nullptr); @@ -149,6 +166,503 @@ std::string Executable::Stats() const { return oss.str(); } +TVMByteArray Executable::Save() { + // Initialize the stream object. + strm_ = new dmlc::MemoryStringStream(&code_); + + uint64_t header = kTVMVMBytecodeMagic; + strm_->Write(header); + std::string version = TVM_VERSION; + strm_->Write(version); + + // Global section. + SaveGlobalSection(); + + // Constant section. 
+ SaveConstantSection(); + + // Primitive names. + SavePrimitiveOpNames(); + + // Code section. + SaveCodeSection(); + + TVMByteArray arr; + arr.data = code_.c_str(); + arr.size = code_.length(); + return arr; +} + +void Executable::SaveGlobalSection() { + std::vector > globals(this->global_map.begin(), + this->global_map.end()); + auto comp = [](const std::pair& a, + const std::pair& b) { + return a.second < b.second; + }; + std::sort(globals.begin(), globals.end(), comp); + + std::vector glbs; + for (const auto& it : globals) { + glbs.push_back(it.first); + } + strm_->Write(glbs); +} + +void Executable::SaveConstantSection() { + std::vector arrays; + for (const auto& obj : this->constants) { + const auto* cell = obj.as(); + CHECK(cell != nullptr); + runtime::NDArray data = cell->data; + arrays.push_back(const_cast(data.operator->())); + } + strm_->Write(static_cast(this->constants.size())); + for (const auto& it : arrays) { + runtime::SaveDLTensor(strm_, it); + } +} + +void Executable::SavePrimitiveOpNames() { + std::vector primitive_names; + for (const auto& it : this->primitive_map) { + auto packed_index = static_cast(it.second); + if (primitive_names.size() <= packed_index) { + primitive_names.resize(packed_index + 1); + } + primitive_names[packed_index] = it.first; + } + strm_->Write(primitive_names); +} + +// Serialize a virtual machine instruction. It creates a list that contains the +// hash, opcode, and all fields of an instruction. +// +// For example, the function signature used to create an `AllocTensor` +// instruction is: +// Instruction AllocTensor(std::vector shape, DLDataType dtype, RegName dst) +// +// The serialized form will be: +// `hash 5 dtype.code dtype.bits dtype.lanes ndim dst_register val1 val2 ... valn` +// +// where hash is the hash of serialized instruction that is computed internally +// by the `VMInstructionExecutable`. It is used for sanity check before decoding. 
+// 5 shows opcode of `AllocTensor`, `(dtype.code dtype.bits dtype.lanes)` +// represents a `DLDataType`, `ndim` is the number of dimensions, `dst_register` +// is the destination register, and the rest of it together indicates the shape +// of the tensor to be allocated. +VMInstructionSerializer SerializeInstruction(const Instruction& instr) { + std::vector fields; + // Save the opcode. + DLOG(INFO) << "Serializing: " << instr << std::endl; + switch (instr.op) { + case Opcode::Move: { + // Number of fields = 2 + fields.assign({instr.from, instr.dst}); + break; + } + case Opcode::Ret: { + // Number of fields = 1 + fields.push_back(instr.result); + break; + } + case Opcode::Fatal: { + // Number of fields = 0 + break; + } + case Opcode::InvokePacked: { + // Number of fields = 3 + instr.arity + // Note that arity includes both input arguments and outputs. We will + // put all the `arity` number of fields in the end for serialization. + fields.assign({instr.packed_index, instr.arity, instr.output_size}); + // Save the args. + fields.insert(fields.end(), instr.packed_args, instr.packed_args + instr.arity); + break; + } + case Opcode::AllocTensor: { + // Number of fields = 5 + instr.alloc_tensor.ndim + // Save `DLDataType` and the dst register. + const auto& dtype = instr.alloc_tensor.dtype; + fields.assign({dtype.code, dtype.bits, dtype.lanes}); + + // The number of dimensions is not needed for constructing an + // `AllocTensor` instruction as it equals to the length of the `shape` + // vector. However, we save it to conveniently deserialize the instruction + // because we will know how many fields are needed by the `shape` argument. + fields.push_back(instr.alloc_tensor.ndim); + fields.push_back(instr.dst); + + // Save the shape of the tensor. + // Note that this field is rotated to the end of the list. 
+ fields.insert(fields.end(), instr.alloc_tensor.shape, + instr.alloc_tensor.shape + instr.alloc_tensor.ndim); + break; + } + case Opcode::AllocTensorReg: { + // Number of fields = 5 + fields.push_back(instr.alloc_tensor_reg.shape_register); + // Save `DLDataType` and the dst register. + const auto& dtype = instr.alloc_tensor.dtype; + fields.assign({dtype.code, dtype.bits, dtype.lanes}); + fields.push_back(instr.dst); + break; + } + case Opcode::AllocDatatype: { + // Number of fields = 3 + instr.num_fields + fields.assign({instr.constructor_tag, instr.num_fields, instr.dst}); + + // Save the fields. + fields.insert(fields.end(), instr.datatype_fields, + instr.datatype_fields + instr.num_fields); + break; + } + case Opcode::AllocClosure: { + // Number of fields = 3 + instr.num_freevar + fields.assign({instr.clo_index, instr.num_freevar, instr.dst}); + + // Save the free vars. + fields.insert(fields.end(), instr.free_vars, + instr.free_vars + instr.num_freevar); + break; + } + case Opcode::If: { + // Number of fields = 4 + fields.assign({instr.if_op.test, + instr.if_op.target, + instr.if_op.true_offset, + instr.if_op.false_offset}); + break; + } + case Opcode::Invoke: { + // Number of fields = 3 + instr.num_args + fields.assign({instr.func_index, instr.num_args, instr.dst}); + + // Save the args. + fields.insert(fields.end(), instr.invoke_args_registers, + instr.invoke_args_registers + instr.num_args); + break; + } + case Opcode::InvokeClosure: { + // Number of fields = 3 + instr.num_closure_args + fields.assign({instr.closure, instr.num_closure_args, instr.dst}); + + // Save the args. 
+ fields.insert(fields.end(), instr.closure_args, + instr.closure_args + instr.num_closure_args); + break; + } + case Opcode::LoadConst: { + // Number of fields = 2 + fields.assign({instr.const_index, instr.dst}); + break; + } + case Opcode::LoadConsti: { + // Number of fields = 2 + fields.assign({instr.load_consti.val, instr.dst}); + break; + } + case Opcode::GetField: { + // Number of fields = 3 + fields.assign({instr.object, instr.field_index, instr.dst}); + break; + } + case Opcode::GetTag: { + // Number of fields = 2 + fields.assign({instr.get_tag.object, instr.dst}); + break; + } + case Opcode::Goto: { + // Number of fields = 1 + fields.push_back(instr.pc_offset); + break; + } + default: + LOG(FATAL) << "Invalid opcode" << static_cast(instr.op); + break; + } + + return VMInstructionSerializer(static_cast(instr.op), fields); +} + +void Executable::SaveCodeSection() { + // Save the number of functions. + strm_->Write(static_cast(this->functions.size())); + for (const auto& func : this->functions) { + // Save the function info. + VMFunctionSerializer func_format(func.name, + func.register_file_size, + func.instructions.size(), + func.params); + func_format.Save(strm_); + + // Serialize each instruction. + for (const auto& instr : func.instructions) { + const auto& serialized_instr = SerializeInstruction(instr); + serialized_instr.Save(strm_); + } + } +} + +runtime::Module Executable::Load(const std::string& code, const runtime::Module lib) { + std::shared_ptr exec = std::make_shared(); + exec->code_ = code; + exec->lib = lib; + // Initialize the stream object. + if (exec->strm_ == nullptr) { + exec->strm_ = new dmlc::MemoryStringStream(&exec->code_); + } + + // Check header. + uint64_t header; + STREAM_CHECK(exec->strm_->Read(&header), "header"); + STREAM_CHECK(header == kTVMVMBytecodeMagic, "header"); + + // Check version. 
+ std::string version; + STREAM_CHECK(exec->strm_->Read(&version), "version"); + STREAM_CHECK(version == TVM_VERSION, "version"); + + // Global section. + exec->LoadGlobalSection(); + + // Constant section. + exec->LoadConstantSection(); + + // Primitive names that will be invoked by `InvokePacked` instructions. + exec->LoadPrimitiveOpNames(); + + // Code section. + exec->LoadCodeSection(); + + return runtime::Module(exec); +} + +void Executable::LoadGlobalSection() { + std::vector globals; + STREAM_CHECK(strm_->Read(&globals), "global"); + for (size_t i = 0; i < globals.size(); i++) { + this->global_map.insert({globals[i], i}); + } +} + +void Executable::LoadConstantSection() { + uint64_t sz; + // Load the number of constants. + STREAM_CHECK(strm_->Read(&sz, sizeof(sz)), "constant"); + + size_t size = static_cast(sz); + // Load each of the constants. + for (size_t i = 0; i < size; i++) { + runtime::NDArray constant; + STREAM_CHECK(constant.Load(strm_), "constant"); + runtime::ObjectRef obj = runtime::vm::Tensor(constant); + this->constants.push_back(obj); + } +} + +void Executable::LoadPrimitiveOpNames() { + std::vector primitive_names; + STREAM_CHECK(strm_->Read(&primitive_names), "primitive name"); + for (size_t i = 0; i < primitive_names.size(); i++) { + this->primitive_map.insert({primitive_names[i], i}); + } +} + +// Extract the `cnt` number of fields started at `start` from the list +// `instr_fields`. 
+inline std::vector ExtractFields(const std::vector& instr_fields, + Index start, + Index cnt) { + CHECK_LE(static_cast(start + cnt), instr_fields.size()); + std::vector ret; + for (auto i = start; i < start + cnt; i++) { + ret.push_back(instr_fields[i]); + } + return ret; +} + +Instruction DeserializeInstruction(const VMInstructionSerializer& instr) { + Opcode opcode = static_cast(instr.opcode); + switch (opcode) { + case Opcode::Move: { + // Number of fields = 2 + DCHECK_EQ(instr.fields.size(), 2U); + return Instruction::Move(instr.fields[0], instr.fields[1]); + } + case Opcode::Ret: { + // Number of fields = 1 + DCHECK_EQ(instr.fields.size(), 1U); + return Instruction::Ret(instr.fields[0]); + } + case Opcode::Fatal: { + // Number of fields = 0 + DCHECK(instr.fields.empty()); + return Instruction::Fatal(); + } + case Opcode::InvokePacked: { + // Number of fields = 3 + instr.arity + DCHECK_GE(instr.fields.size(), 3U); + DCHECK_EQ(instr.fields.size(), 3U + static_cast(instr.fields[1])); + + Index packed_index = instr.fields[0]; + Index arity = instr.fields[1]; + Index output_size = instr.fields[2]; + std::vector args = ExtractFields(instr.fields, 3, arity); + return Instruction::InvokePacked(packed_index, arity, output_size, args); + } + case Opcode::AllocTensor: { + // Number of fields = 5 + instr.alloc_tensor.ndim + DCHECK_GE(instr.fields.size(), 5U); + DCHECK_EQ(instr.fields.size(), 5U + static_cast(instr.fields[3])); + + DLDataType dtype; + dtype.code = instr.fields[0]; + dtype.bits = instr.fields[1]; + dtype.lanes = instr.fields[2]; + + Index ndim = instr.fields[3]; + RegName dst = instr.fields[4]; + + std::vector shape = ExtractFields(instr.fields, 5, ndim); + + return Instruction::AllocTensor(shape, dtype, dst); + } + case Opcode::AllocTensorReg: { + // Number of fields = 5 + DCHECK_EQ(instr.fields.size(), 5U); + Index shape_register = instr.fields[0]; + + DLDataType dtype; + dtype.code = instr.fields[1]; + dtype.bits = instr.fields[2]; + dtype.lanes = 
instr.fields[3]; + + RegName dst = instr.fields[4]; + + return Instruction::AllocTensorReg(shape_register, dtype, dst); + } + case Opcode::AllocDatatype: { + // Number of fields = 3 + instr.num_fields + DCHECK_GE(instr.fields.size(), 3U); + DCHECK_EQ(instr.fields.size(), 3U + static_cast(instr.fields[1])); + + Index constructor_tag = instr.fields[0]; + Index num_fields = instr.fields[1]; + RegName dst = instr.fields[2]; + std::vector fields = ExtractFields(instr.fields, 3, num_fields); + + return Instruction::AllocDatatype(constructor_tag, num_fields, fields, dst); + } + case Opcode::AllocClosure: { + // Number of fields = 3 + instr.num_freevar + DCHECK_GE(instr.fields.size(), 3U); + DCHECK_EQ(instr.fields.size(), 3U + static_cast(instr.fields[1])); + + Index clo_index = instr.fields[0]; + Index num_freevar = instr.fields[1]; + RegName dst = instr.fields[2]; + std::vector free_vars = ExtractFields(instr.fields, 3, num_freevar); + + return Instruction::AllocClosure(clo_index, num_freevar, free_vars, dst); + } + case Opcode::If: { + // Number of fields = 4 + DCHECK_EQ(instr.fields.size(), 4U); + Index test = instr.fields[0]; + Index target = instr.fields[1]; + Index true_offset = instr.fields[2]; + Index false_offset = instr.fields[3]; + + return Instruction::If(test, target, true_offset, false_offset); + } + case Opcode::Invoke: { + // Number of fields = 3 + instr.num_args + DCHECK_GE(instr.fields.size(), 3U); + DCHECK_EQ(instr.fields.size(), 3U + static_cast(instr.fields[1])); + + Index func_index = instr.fields[0]; + Index num_args = instr.fields[1]; + RegName dst = instr.fields[2]; + std::vector args = ExtractFields(instr.fields, 3, num_args); + + return Instruction::Invoke(func_index, args, dst); + } + case Opcode::InvokeClosure: { + // Number of fields = 3 + instr.num_closure_args + DCHECK_GE(instr.fields.size(), 3U); + DCHECK_EQ(instr.fields.size(), 3U + static_cast(instr.fields[1])); + + Index closure = instr.fields[0]; + Index num_closure_args = 
instr.fields[1]; + RegName dst = instr.fields[2]; + std::vector args = ExtractFields(instr.fields, 3, num_closure_args); + + return Instruction::InvokeClosure(closure, args, dst); + } + case Opcode::LoadConst: { + // Number of fields = 2 + DCHECK_EQ(instr.fields.size(), 2U); + return Instruction::LoadConst(instr.fields[0], instr.fields[1]); + } + case Opcode::LoadConsti: { + // Number of fields = 2 + DCHECK_EQ(instr.fields.size(), 2U); + return Instruction::LoadConsti(instr.fields[0], instr.fields[1]); + } + case Opcode::GetField: { + // Number of fields = 3 + DCHECK_EQ(instr.fields.size(), 3U); + return Instruction::GetField(instr.fields[0], instr.fields[1], instr.fields[2]); + } + case Opcode::GetTag: { + // Number of fields = 2 + DCHECK_EQ(instr.fields.size(), 2U); + return Instruction::GetTag(instr.fields[0], instr.fields[1]); + } + case Opcode::Goto: { + // Number of fields = 1 + DCHECK_EQ(instr.fields.size(), 1U); + return Instruction::Goto(instr.fields[0]); + } + default: + LOG(FATAL) << "Invalid opcode" << instr.opcode; + return Instruction(); + } +} + +void Executable::LoadCodeSection() { + // Load the number of functions. + uint64_t sz; + STREAM_CHECK(strm_->Read(&sz, sizeof(sz)), "code"); + + size_t num_funcs = static_cast(sz); + this->functions.resize(num_funcs); + for (size_t i = 0; i < num_funcs; i++) { + // Load the function info. + VMFunctionSerializer loaded_func; + STREAM_CHECK(loaded_func.Load(strm_), "code/function"); + + // Load the instructions. + std::vector instructions; + for (size_t j = 0; j < loaded_func.num_instructions; j++) { + VMInstructionSerializer instr; + std::vector instr_fields; + STREAM_CHECK(instr.Load(strm_), "code/instruction"); + instructions.push_back(DeserializeInstruction(instr)); + } + + // Create the VM function. 
+ VMFunction vm_func = VMFunction(loaded_func.name, + loaded_func.params, + instructions, + loaded_func.register_file_size); + auto it = this->global_map.find(loaded_func.name); + CHECK(it != this->global_map.end()); + CHECK_LE(it->second, this->global_map.size()); + this->functions[it->second] = vm_func; + } +} + TVM_REGISTER_GLOBAL("relay._vm.GetNumOfGlobals") .set_body([](TVMArgs args, TVMRetValue* rv) { runtime::Module mod = args[0]; @@ -200,6 +714,13 @@ TVM_REGISTER_GLOBAL("relay._vm.GetPrimitiveFields") } }); +TVM_REGISTER_GLOBAL("relay._vm.Load_Executable") +.set_body_typed([]( + std::string code, + runtime::Module lib) { + return Executable::Load(code, lib); +}); + } // namespace vm } // namespace runtime } // namespace tvm diff --git a/src/runtime/vm/serializer.cc b/src/runtime/vm/serializer.cc deleted file mode 100644 index 0d7fb2b2f2e7..000000000000 --- a/src/runtime/vm/serializer.cc +++ /dev/null @@ -1,317 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2019 by Contributors - * \file src/runtime/vm/serializer.cc - * \brief Implementation of serializing APIs for the Relay VM executable. 
- */ -#include "serializer.h" - -#include -#include - -#include -#include -#include -#include -#include - -#include "serialize_util.h" - -namespace tvm { -namespace runtime { -namespace vm { - -inline void Serializer::Init(const Executable* exec) { - CHECK(exec); - exec_ = exec; - // Initialize the stream object. - strm_ = new dmlc::MemoryStringStream(&code_); -} - -runtime::PackedFunc Serializer::GetFunction( - const std::string& name, - const std::shared_ptr& sptr_to_self) { - if (name == "get_lib") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->exec_->GetLib(); - }); - } else if (name == "serialize") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->Serialize(); - }); - } else { - LOG(FATAL) << "Unknown packed function: " << name; - return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {}); - } -} - -TVMByteArray Serializer::Serialize() { - uint64_t header = kTVMVMBytecodeMagic; - strm_->Write(header); - std::string version = TVM_VERSION; - strm_->Write(version); - - // Global section. - SerializeGlobalSection(); - - // Constant section. - SerializeConstantSection(); - - // Primitive names. - SerializePrimitiveOpNames(); - - // Code section. 
- SerializeCodeSection(); - - TVMByteArray arr; - arr.data = code_.c_str(); - arr.size = code_.length(); - return arr; -} - -void Serializer::SerializeGlobalSection() { - std::vector > globals(exec_->global_map.begin(), - exec_->global_map.end()); - auto comp = [](const std::pair& a, - const std::pair& b) { - return a.second < b.second; - }; - std::sort(globals.begin(), globals.end(), comp); - - std::vector glbs; - for (const auto& it : globals) { - glbs.push_back(it.first); - } - strm_->Write(glbs); -} - -void Serializer::SerializeConstantSection() { - std::vector arrays; - for (const auto& obj : exec_->constants) { - const auto* cell = obj.as(); - CHECK(cell != nullptr); - runtime::NDArray data = cell->data; - arrays.push_back(const_cast(data.operator->())); - } - strm_->Write(static_cast(exec_->constants.size())); - for (const auto& it : arrays) { - runtime::SaveDLTensor(strm_, it); - } -} - -void Serializer::SerializePrimitiveOpNames() { - std::vector primitive_names; - for (const auto& it : exec_->primitive_map) { - auto packed_index = static_cast(it.second); - if (primitive_names.size() <= packed_index) { - primitive_names.resize(packed_index + 1); - } - primitive_names[packed_index] = it.first; - } - strm_->Write(primitive_names); -} - -// Serialize a virtual machine instruction. It creates a list that contains the -// hash, opcode, and all fields of an instruction. -// -// For example, the function signature used to create an `AllocTensor` -// instruction is: -// Instruction AllocTensor(std::vector shape, DLDataType dtype, RegName dst) -// -// The serialized form will be: -// `hash 5 dtype.code dtype.bits dtype.lanes ndim dst_register val1 val2 ... valn` -// -// where hash is the hash of serialized instruction that is computed internally -// by the `VMInstructionSerializer`. It is used for sanity check before decoding. 
-// 5 shows opcode of `AllocTensor`, `(dtype.code dtype.bits dtype.lanes)` -// represents a `DLDataType`, `ndim` is the number of dimensions, `dst_register` -// is the destination register, and the rest of it together indicates the shape -// of the tensor to be allocated. -VMInstructionSerializer SerializeInstruction(const Instruction& instr) { - std::vector fields; - // Save the opcode. - DLOG(INFO) << "Serializing: " << instr << std::endl; - switch (instr.op) { - case Opcode::Move: { - // Number of fields = 2 - fields.assign({instr.from, instr.dst}); - break; - } - case Opcode::Ret: { - // Number of fields = 1 - fields.push_back(instr.result); - break; - } - case Opcode::Fatal: { - // Number of fields = 0 - break; - } - case Opcode::InvokePacked: { - // Number of fields = 3 + instr.arity - // Note that arity includes both input arguments and outputs. We will - // put all the `arity` number of fields in the end for serialization. - fields.assign({instr.packed_index, instr.arity, instr.output_size}); - // Save the args. - fields.insert(fields.end(), instr.packed_args, instr.packed_args + instr.arity); - break; - } - case Opcode::AllocTensor: { - // Number of fields = 5 + instr.alloc_tensor.ndim - // Save `DLDataType` and the dst register. - const auto& dtype = instr.alloc_tensor.dtype; - fields.assign({dtype.code, dtype.bits, dtype.lanes}); - - // The number of dimensions is not needed for constructing an - // `AllocTensor` instruction as it equals to the length of the `shape` - // vector. However, we save it to conveniently deserialize the instruction - // because we will know how many fields are needed by the `shape` argument. - fields.push_back(instr.alloc_tensor.ndim); - fields.push_back(instr.dst); - - // Save the shape of the tensor. - // Note that this field is rotated to the end of the list. 
- fields.insert(fields.end(), instr.alloc_tensor.shape, - instr.alloc_tensor.shape + instr.alloc_tensor.ndim); - break; - } - case Opcode::AllocTensorReg: { - // Number of fields = 5 - fields.push_back(instr.alloc_tensor_reg.shape_register); - // Save `DLDataType` and the dst register. - const auto& dtype = instr.alloc_tensor.dtype; - fields.assign({dtype.code, dtype.bits, dtype.lanes}); - fields.push_back(instr.dst); - break; - } - case Opcode::AllocDatatype: { - // Number of fields = 3 + instr.num_fields - fields.assign({instr.constructor_tag, instr.num_fields, instr.dst}); - - // Save the fields. - fields.insert(fields.end(), instr.datatype_fields, - instr.datatype_fields + instr.num_fields); - break; - } - case Opcode::AllocClosure: { - // Number of fields = 3 + instr.num_freevar - fields.assign({instr.clo_index, instr.num_freevar, instr.dst}); - - // Save the free vars. - fields.insert(fields.end(), instr.free_vars, - instr.free_vars + instr.num_freevar); - break; - } - case Opcode::If: { - // Number of fields = 4 - fields.assign({instr.if_op.test, - instr.if_op.target, - instr.if_op.true_offset, - instr.if_op.false_offset}); - break; - } - case Opcode::Invoke: { - // Number of fields = 3 + instr.num_args - fields.assign({instr.func_index, instr.num_args, instr.dst}); - - // Save the args. - fields.insert(fields.end(), instr.invoke_args_registers, - instr.invoke_args_registers + instr.num_args); - break; - } - case Opcode::InvokeClosure: { - // Number of fields = 3 + instr.num_closure_args - fields.assign({instr.closure, instr.num_closure_args, instr.dst}); - - // Save the args. 
- fields.insert(fields.end(), instr.closure_args, - instr.closure_args + instr.num_closure_args); - break; - } - case Opcode::LoadConst: { - // Number of fields = 2 - fields.assign({instr.const_index, instr.dst}); - break; - } - case Opcode::LoadConsti: { - // Number of fields = 2 - fields.assign({instr.load_consti.val, instr.dst}); - break; - } - case Opcode::GetField: { - // Number of fields = 3 - fields.assign({instr.object, instr.field_index, instr.dst}); - break; - } - case Opcode::GetTag: { - // Number of fields = 2 - fields.assign({instr.get_tag.object, instr.dst}); - break; - } - case Opcode::Goto: { - // Number of fields = 1 - fields.push_back(instr.pc_offset); - break; - } - default: - LOG(FATAL) << "Invalid opcode" << static_cast(instr.op); - break; - } - - return VMInstructionSerializer(static_cast(instr.op), fields); -} - -void Serializer::SerializeCodeSection() { - // Save the number of functions. - strm_->Write(static_cast(exec_->functions.size())); - for (const auto& func : exec_->functions) { - // Serialize the function info. - VMFunctionSerializer func_format(func.name, - func.register_file_size, - func.instructions.size(), - func.params); - func_format.Save(strm_); - - // Serialize each instruction. - for (const auto& instr : func.instructions) { - const auto& serialized_instr = SerializeInstruction(instr); - serialized_instr.Save(strm_); - } - } -} - -runtime::Module CreateSerializer(const Executable* exec) { - std::shared_ptr serializer = std::make_shared(); - serializer->Init(exec); - return runtime::Module(serializer); -} - -TVM_REGISTER_GLOBAL("relay._vm._Serializer") -.set_body([](TVMArgs args, TVMRetValue* rv) { - runtime::Module mod = args[0]; - const auto* exec = dynamic_cast(mod.operator->()); - CHECK(exec) << "Virtual machine has not been defined yet." 
- << "\n"; - *rv = CreateSerializer(exec); -}); - -} // namespace vm -} // namespace runtime -} // namespace tvm diff --git a/src/runtime/vm/serializer.h b/src/runtime/vm/serializer.h deleted file mode 100644 index b3c893878bf6..000000000000 --- a/src/runtime/vm/serializer.h +++ /dev/null @@ -1,152 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2019 by Contributors - * \file src/runtime/vm/serializer.h - * \brief Define a serializer for the Relay VM. - * - * The following components of a Relay VM will be serialized: - * - The `constants`, e.g., the constant pool, that contains the - * constants used in a Relay program. - * - The `packed_funcs` that essentially contains the generated code for - * a specific target. We return it as a runtime module that can be exported as - * a library file (e.g., .so, .o, or .tar). - * - The `global_map` that contains the globals. - * - The `primitive_map` that contains the name of individual primitive operators. - * - The `functions`, e.g., the `VMFunction`. Each `VMFunction` is composed of - * a list of instructions/bytecode. - * - * Note that only the library is returned as a separate module. 
All othere parts - * are stored in a single serialized code that is organized with the following - * sections in order. - * - Global section, containing all globals. - * - Constant section, storing the constant pool. - * - Primitive name section, containing the function name of the primitive ops - * used by the virtual machine. - * - Code section, handling the VM functions and bytecode. - * - * The code section is again organized as follows for each VM function: - * func_name, register_file_size, num_instructions (N) - * param1, param2, ..., paramM - * instruction1 - * instruction2 - * ... - * instructionN - * - * Serializing an `Instruction` requires us to deal with the bytecode. Each line - * of the instructions could be serialized as the following format: - * hash, opcode, f1, f2, ..., fX, field with variable length - * 1. hash: the hash of the instruction. This number will be used to help us - * validate if an instruction is well-formed during deserialization. - * 2. opcode: the opcode code of the instruction. - * 3. f1, f2, ..., fX. These fields together represent the fixed fields in - * an instruction, e.g., `from` and `dst` fields of a `Move` instruction. For - * example, `DLDataType` will be unpacked into three fields (code, bits, lanes). - * 4. The rest of the line indicates the field with variable length, e.g., - * the shape of a tensor, the args used by an `InvokPacked` instruction, etc. - */ - -#ifndef TVM_RUNTIME_VM_SERIALIZER_H_ -#define TVM_RUNTIME_VM_SERIALIZER_H_ - -#include -#include -#include -#include - -#include -#include -#include -#include - -#include "serialize_util.h" - -namespace tvm { -namespace runtime { -namespace vm { - -using namespace tvm::runtime; -using namespace tvm::runtime::vm; - -/*! - * \brief The Relay VM serializer. - */ -class Serializer : public runtime::ModuleNode { - public: - /*! - * \brief Initialize the serializer for an executable. - * - * \param vm The Relay virtual machine executable. 
- */ - void Init(const Executable* exec); - - /*! - * \brief Return the member function to the frontend. - * - * \param name The name of the function. - * \param sptr_to_self The pointer to the module node. - * - * \return The corresponding member function. - */ - PackedFunc GetFunction(const std::string& name, - const std::shared_ptr& sptr_to_self) final; - - const char* type_key() const final { return "Serializer"; } - - /*! - * \brief Serialize the `exec_` into global section, constant section, and code - * section. - * - * \return The binary representation of the VM. - */ - TVMByteArray Serialize(); - - virtual ~Serializer() { delete strm_; } - - private: - /*! \brief Serialize the globals in exec_. */ - void SerializeGlobalSection(); - - /*! \brief Serialize the constant pool in exec_. */ - void SerializeConstantSection(); - - /*! \brief Serialize primitive op names in exec_. */ - void SerializePrimitiveOpNames(); - - /*! \brief Serialize the vm functions in exec_. */ - void SerializeCodeSection(); - - /*! \brief The Relay virtual machine executable to be serialized. */ - const Executable* exec_; - - /*! \brief The stream used for serialization. */ - dmlc::Stream* strm_; - - /*! \brief The serialized code. 
*/ - std::string code_; -}; - -VMInstructionSerializer SerializeInstruction(const Instruction& instr); - -} // namespace vm -} // namespace runtime -} // namespace tvm - -#endif // TVM_RUNTIME_VM_SERIALIZER_H_ diff --git a/tests/python/relay/test_vm_serialization.py b/tests/python/relay/test_vm_serialization.py index d7effade913c..014648099aeb 100644 --- a/tests/python/relay/test_vm_serialization.py +++ b/tests/python/relay/test_vm_serialization.py @@ -22,7 +22,6 @@ from tvm import relay from tvm.relay.module import Module as rly_module from tvm.relay import vm as _vm -from tvm.relay import serializer, deserializer from tvm.relay.scope_builder import ScopeBuilder from tvm.relay.prelude import Prelude from tvm.contrib import util @@ -56,11 +55,9 @@ def get_vm_output(mod, data, params, target, ctx, dtype='float32'): return result.asnumpy().astype(dtype) def get_serialized_output(mod, data, params, target, ctx, dtype='float32'): - vm = create_exec(mod, target, params=params) - ser = serializer.Serializer(vm) - code, lib = ser.serialize() - deser = deserializer.Deserializer(code, lib) - des_exec = deser.deserialize() + exe = create_exec(mod, target, params=params) + code, lib = exe.save() + des_exec = _vm.Executable.load_exec(code, lib) des_vm = _vm.VirtualMachine(des_exec) des_vm.init(ctx) result = des_vm.run(data) @@ -114,8 +111,7 @@ def test_serializer(): assert "f1 2 1 3" in code assert "f2 2 1 3" in code - ser = serializer.Serializer(exe) - code, lib = ser.serialize() + code, lib = exe.save() assert isinstance(code, bytearray) assert isinstance(lib, tvm.module.Module) @@ -127,8 +123,7 @@ def test_save_load(): # serialize. vm = create_exec(f) - ser = serializer.Serializer(vm) - code, lib = ser.serialize() + code, lib = vm.save() assert isinstance(code, bytearray) # save and load the code and lib file. @@ -142,8 +137,7 @@ def test_save_load(): loaded_code = bytearray(open(tmp.relpath("code.ro"), "rb").read()) # deserialize. 
- deser = deserializer.Deserializer(loaded_code, loaded_lib) - des_exec = deser.deserialize() + des_exec = _vm.Executable.load_exec(loaded_code, loaded_lib) des_vm = _vm.VirtualMachine(des_exec) des_vm.init(tvm.cpu()) @@ -156,11 +150,9 @@ def test_const(): x = relay.var('x', shape=(10, 10), dtype='float32') f = relay.Function([x], x + c) exe = create_exec(f) - ser = serializer.Serializer(exe) - code, lib = ser.serialize() + code, lib = exe.save() assert isinstance(code, bytearray) - deser = deserializer.Deserializer(code, lib) - des_exec = deser.deserialize() + des_exec = _vm.Executable.load_exec(code, lib) des_vm = _vm.VirtualMachine(des_exec) des_vm.init(tvm.cpu()) x_data = np.random.rand(10, 10).astype('float32') @@ -179,10 +171,8 @@ def test_if(): y_data = np.random.rand(10, 10).astype('float32') exe = create_exec(f) - ser = serializer.Serializer(exe) - code, lib = ser.serialize() - deser = deserializer.Deserializer(code, lib) - des_exec = deser.deserialize() + code, lib = exe.save() + des_exec = _vm.Executable.load_exec(code, lib) des_vm = _vm.VirtualMachine(des_exec) des_vm.init(tvm.cpu()) @@ -217,10 +207,8 @@ def test_loop(): mod["main"] = relay.Function([iarg, aarg], sum_up(iarg, aarg)) exe = create_exec(mod) - ser = serializer.Serializer(exe) - code, lib = ser.serialize() - deser = deserializer.Deserializer(code, lib) - des_exec = deser.deserialize() + code, lib = exe.save() + des_exec = _vm.Executable.load_exec(code, lib) des_vm = _vm.VirtualMachine(des_exec) des_vm.init(tvm.cpu()) @@ -236,10 +224,8 @@ def test_tuple(): j_data = np.random.rand(10).astype('float32') exe = create_exec(f) - ser = serializer.Serializer(exe) - code, lib = ser.serialize() - deser = deserializer.Deserializer(code, lib) - des_exec = deser.deserialize() + code, lib = exe.save() + des_exec = _vm.Executable.load_exec(code, lib) des_vm = _vm.VirtualMachine(des_exec) des_vm.init(tvm.cpu()) @@ -259,10 +245,8 @@ def test_adt_list(): mod["main"] = f exe = create_exec(mod) - ser = 
serializer.Serializer(exe) - code, lib = ser.serialize() - deser = deserializer.Deserializer(code, lib) - des_exec = deser.deserialize() + code, lib = exe.save() + des_exec = _vm.Executable.load_exec(code, lib) des_vm = _vm.VirtualMachine(des_exec) des_vm.init(tvm.cpu()) @@ -307,10 +291,8 @@ def test_adt_compose(): mod["main"] = f exe = create_exec(mod) - ser = serializer.Serializer(exe) - code, lib = ser.serialize() - deser = deserializer.Deserializer(code, lib) - des_exec = deser.deserialize() + code, lib = exe.save() + des_exec = _vm.Executable.load_exec(code, lib) des_vm = _vm.VirtualMachine(des_exec) des_vm.init(tvm.cpu()) @@ -329,10 +311,8 @@ def test_closure(): main = clo(relay.const(2.0)) exe = create_exec(main) - ser = serializer.Serializer(exe) - code, lib = ser.serialize() - deser = deserializer.Deserializer(code, lib) - des_exec = deser.deserialize() + code, lib = exe.save() + des_exec = _vm.Executable.load_exec(code, lib) des_vm = _vm.VirtualMachine(des_exec) des_vm.init(tvm.cpu()) From b4180c20cfcdc12e02189d853615e78fb1a9614b Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Thu, 17 Oct 2019 04:09:25 +0000 Subject: [PATCH 6/6] create stream --- include/tvm/runtime/vm.h | 69 +++++++++++++++++------- src/runtime/vm/executable.cc | 100 +++++++++++++++++++---------------- 2 files changed, 103 insertions(+), 66 deletions(-) diff --git a/include/tvm/runtime/vm.h b/include/tvm/runtime/vm.h index 011826e6f1a1..a276c658c496 100644 --- a/include/tvm/runtime/vm.h +++ b/include/tvm/runtime/vm.h @@ -541,34 +541,63 @@ class Executable : public ModuleNode { std::vector functions; private: - /*! \brief Save the globals. */ - void SaveGlobalSection(); - - /*! \brief Save the constant pool. */ - void SaveConstantSection(); + /*! + * \brief Save the globals. + * + * \param strm The output stream to write to. + */ + void SaveGlobalSection(dmlc::Stream* strm); - /*! \brief Save primitive op names. */ - void SavePrimitiveOpNames(); + /*! + * \brief Save the constant pool.
+ * + * \param strm The output stream to write to. + */ + void SaveConstantSection(dmlc::Stream* strm); - /*! \brief Save the vm functions. */ - void SaveCodeSection(); + /*! + * \brief Save primitive op names. + * + * \param strm The output stream to write to. + */ + void SavePrimitiveOpNames(dmlc::Stream* strm); - /*! \brief Load the globals. */ - void LoadGlobalSection(); + /*! + * \brief Save the vm functions. + * + * \param strm The output stream to write to. + */ + void SaveCodeSection(dmlc::Stream* strm); - /*! \brief Load the constant pool. */ - void LoadConstantSection(); + /*! + * \brief Load the globals. + * + * \param strm The input stream. + */ + void LoadGlobalSection(dmlc::Stream* strm); - /*! \brief Load primitive op names. */ - void LoadPrimitiveOpNames(); + /*! + * \brief Load the constant pool. + * + * \param strm The input stream. + */ + void LoadConstantSection(dmlc::Stream* strm); - /*! \brief Load the vm functions.*/ - void LoadCodeSection(); + /*! + * \brief Load primitive op names. + * + * \param strm The input stream. + */ + void LoadPrimitiveOpNames(dmlc::Stream* strm); - /*! \brief The stream used for serialization. */ - dmlc::Stream* strm_; + /*! + * \brief Load the vm functions. + * + * \param strm The input stream. + */ + void LoadCodeSection(dmlc::Stream* strm); - /*! \brief The serialized code. */ + /*! \brief The serialized bytecode. */ std::string code_; }; diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index 8768ddf53232..21f71af4eb8c 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -166,26 +166,32 @@ std::string Executable::Stats() const { return oss.str(); } +void SaveHeader(dmlc::Stream* strm) { + uint64_t header = kTVMVMBytecodeMagic; + strm->Write(header); + std::string version = TVM_VERSION; + strm->Write(version); +} + TVMByteArray Executable::Save() { // Initialize the stream object.
- strm_ = new dmlc::MemoryStringStream(&code_); + code_.clear(); + dmlc::MemoryStringStream strm(&code_); - uint64_t header = kTVMVMBytecodeMagic; - strm_->Write(header); - std::string version = TVM_VERSION; - strm_->Write(version); + // Save header + SaveHeader(&strm); // Global section. - SaveGlobalSection(); + SaveGlobalSection(&strm); // Constant section. - SaveConstantSection(); + SaveConstantSection(&strm); // Primitive names. - SavePrimitiveOpNames(); + SavePrimitiveOpNames(&strm); // Code section. - SaveCodeSection(); + SaveCodeSection(&strm); TVMByteArray arr; arr.data = code_.c_str(); @@ -193,7 +199,7 @@ TVMByteArray Executable::Save() { return arr; } -void Executable::SaveGlobalSection() { +void Executable::SaveGlobalSection(dmlc::Stream* strm) { std::vector > globals(this->global_map.begin(), this->global_map.end()); auto comp = [](const std::pair& a, @@ -206,10 +212,10 @@ void Executable::SaveGlobalSection() { for (const auto& it : globals) { glbs.push_back(it.first); } - strm_->Write(glbs); + strm->Write(glbs); } -void Executable::SaveConstantSection() { +void Executable::SaveConstantSection(dmlc::Stream* strm) { std::vector arrays; for (const auto& obj : this->constants) { const auto* cell = obj.as(); @@ -217,13 +223,13 @@ void Executable::SaveConstantSection() { runtime::NDArray data = cell->data; arrays.push_back(const_cast(data.operator->())); } - strm_->Write(static_cast(this->constants.size())); + strm->Write(static_cast(this->constants.size())); for (const auto& it : arrays) { - runtime::SaveDLTensor(strm_, it); + runtime::SaveDLTensor(strm, it); } } -void Executable::SavePrimitiveOpNames() { +void Executable::SavePrimitiveOpNames(dmlc::Stream* strm) { std::vector primitive_names; for (const auto& it : this->primitive_map) { auto packed_index = static_cast(it.second); @@ -232,7 +238,7 @@ void Executable::SavePrimitiveOpNames() { } primitive_names[packed_index] = it.first; } - strm_->Write(primitive_names); + strm->Write(primitive_names); } // 
Serialize a virtual machine instruction. It creates a list that contains the @@ -384,85 +390,87 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) { return VMInstructionSerializer(static_cast(instr.op), fields); } -void Executable::SaveCodeSection() { +void Executable::SaveCodeSection(dmlc::Stream* strm) { // Save the number of functions. - strm_->Write(static_cast(this->functions.size())); + strm->Write(static_cast(this->functions.size())); for (const auto& func : this->functions) { // Save the function info. VMFunctionSerializer func_format(func.name, func.register_file_size, func.instructions.size(), func.params); - func_format.Save(strm_); + func_format.Save(strm); // Serialize each instruction. for (const auto& instr : func.instructions) { const auto& serialized_instr = SerializeInstruction(instr); - serialized_instr.Save(strm_); + serialized_instr.Save(strm); } } } -runtime::Module Executable::Load(const std::string& code, const runtime::Module lib) { - std::shared_ptr exec = std::make_shared(); - exec->code_ = code; - exec->lib = lib; - // Initialize the stream object. - if (exec->strm_ == nullptr) { - exec->strm_ = new dmlc::MemoryStringStream(&exec->code_); - } - +void LoadHeader(dmlc::Stream* strm) { // Check header. uint64_t header; - STREAM_CHECK(exec->strm_->Read(&header), "header"); + STREAM_CHECK(strm->Read(&header), "header"); STREAM_CHECK(header == kTVMVMBytecodeMagic, "header"); // Check version. std::string version; - STREAM_CHECK(exec->strm_->Read(&version), "version"); + STREAM_CHECK(strm->Read(&version), "version"); STREAM_CHECK(version == TVM_VERSION, "version"); +} + +runtime::Module Executable::Load(const std::string& code, const runtime::Module lib) { + std::shared_ptr exec = std::make_shared(); + exec->lib = lib; + exec->code_ = code; + dmlc::MemoryStringStream strm(&exec->code_); + + // Load header. + LoadHeader(&strm); // Global section. 
- exec->LoadGlobalSection(); + exec->LoadGlobalSection(&strm); // Constant section. - exec->LoadConstantSection(); + exec->LoadConstantSection(&strm); // Primitive names that will be invoked by `InvokePacked` instructions. - exec->LoadPrimitiveOpNames(); + exec->LoadPrimitiveOpNames(&strm); // Code section. - exec->LoadCodeSection(); + exec->LoadCodeSection(&strm); return runtime::Module(exec); } -void Executable::LoadGlobalSection() { +void Executable::LoadGlobalSection(dmlc::Stream* strm) { std::vector globals; - STREAM_CHECK(strm_->Read(&globals), "global"); + STREAM_CHECK(strm->Read(&globals), "global"); for (size_t i = 0; i < globals.size(); i++) { this->global_map.insert({globals[i], i}); } } -void Executable::LoadConstantSection() { +void Executable::LoadConstantSection(dmlc::Stream* strm) { uint64_t sz; // Load the number of constants. - STREAM_CHECK(strm_->Read(&sz, sizeof(sz)), "constant"); + STREAM_CHECK(strm->Read(&sz, sizeof(sz)), "constant"); size_t size = static_cast(sz); // Load each of the constants. for (size_t i = 0; i < size; i++) { runtime::NDArray constant; - STREAM_CHECK(constant.Load(strm_), "constant"); + STREAM_CHECK(constant.Load(strm), "constant"); runtime::ObjectRef obj = runtime::vm::Tensor(constant); this->constants.push_back(obj); } } -void Executable::LoadPrimitiveOpNames() { +void Executable::LoadPrimitiveOpNames(dmlc::Stream* strm) { std::vector primitive_names; - STREAM_CHECK(strm_->Read(&primitive_names), "primitive name"); + STREAM_CHECK(strm->Read(&primitive_names), "primitive name"); for (size_t i = 0; i < primitive_names.size(); i++) { this->primitive_map.insert({primitive_names[i], i}); } @@ -630,24 +638,24 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) { } } -void Executable::LoadCodeSection() { +void Executable::LoadCodeSection(dmlc::Stream* strm) { // Load the number of functions. 
uint64_t sz; - STREAM_CHECK(strm_->Read(&sz, sizeof(sz)), "code"); + STREAM_CHECK(strm->Read(&sz, sizeof(sz)), "code"); size_t num_funcs = static_cast(sz); this->functions.resize(num_funcs); for (size_t i = 0; i < num_funcs; i++) { // Load the function info. VMFunctionSerializer loaded_func; - STREAM_CHECK(loaded_func.Load(strm_), "code/function"); + STREAM_CHECK(loaded_func.Load(strm), "code/function"); // Load the instructions. std::vector instructions; for (size_t j = 0; j < loaded_func.num_instructions; j++) { VMInstructionSerializer instr; std::vector instr_fields; - STREAM_CHECK(instr.Load(strm_), "code/instruction"); + STREAM_CHECK(instr.Load(strm), "code/instruction"); instructions.push_back(DeserializeInstruction(instr)); }