Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[relay][vm] Separate VM runtime with executable #4100

Merged
merged 6 commits into from
Oct 17, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 184 additions & 26 deletions include/tvm/runtime/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

#include <tvm/runtime/object.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <memory>
#include <string>
#include <unordered_map>
Expand Down Expand Up @@ -430,15 +431,184 @@ struct VMFrame {
caller_return_register(0) {}
};

/*! \brief The executable emitted by the VM compiler.
*
* The executable contains information (e.g. data in different memory regions)
* to run in a virtual machine.
*
* - Global section, containing all globals.
* - Constant section, storing the constant pool.
* - Primitive name section, containing the function name of the primitive ops
* used by the virtual machine.
* - Code section, handling the VM functions and bytecode.
*/
class Executable : public ModuleNode {
public:
/*!
* \brief Get a PackedFunc from an executable module.
*
* \param name the name of the function.
* \param sptr_to_self The shared_ptr that points to this module node.
*
* \return PackedFunc or nullptr when it is not available.
*/
PackedFunc GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final;

/*!
* \brief Serialize the executable into global section, constant section, and
* code section.
*
* \return The binary representation of the VM.
*/
TVMByteArray Save();

/*!
* \brief Load the saved VM executable.
*
* \param code The bytecode in string.
* \param lib The compiled runtime library.
*
* \return exe The constructed executable.
*/
static runtime::Module Load(const std::string& code, const runtime::Module lib);

/*!
* \brief Get the serialized form of the `functions`. This is
* essentially bytecode serialization.
*
* \return The serialized vm bytecode.
*
* \note The bytecode is in the following format:
* func_name reg_file_size num_instructions
* param1 param2 ... paramM
* instruction1
* instruction2
* ...
* instructionN
*
* Each instruction is printed in the following format:
* opcode num_fields field1 ... fieldX # The text format.
*
* Serializing an `Instruction` requires us to deal with the bytecode. Each line
* of the instructions could be serialized as the following format:
* hash, opcode, f1, f2, ..., fX, field with variable length
* 1. hash: the hash of the instruction. This number will be used to help us
* validate if an instruction is well-formed during deserialization.
* 2. opcode: the opcode code of the instruction.
* 3. f1, f2, ..., fX. These fields together represent the fixed fields in
* an instruction, e.g., `from` and `dst` fields of a `Move` instruction. For
* example, `DLDataType` will be unpacked into three fields (code, bits, lanes).
* 4. The rest of the line indicates the field with variable length, e.g.,
* the shape of a tensor, the args used by an `InvokPacked` instruction, etc.

* The field starting from # is only used for debugging. The serialized code
* doesn't contain it, therefore the deserializer doens't need to handle it.
*/
std::string GetBytecode() const;
zhiics marked this conversation as resolved.
Show resolved Hide resolved

/*!
* \brief Print the detailed statistics of the given code, i.e. number of
* globls and constants, etc.
*/
std::string Stats() const;

/*! \brief Get the `lib` module in an executable. Users have the flexibility to call
* `export_library` from the frontend to save the library to disk.
*
* \return The runtime module that contains the hardwre dependent code.
*/
runtime::Module GetLib() const { return lib; }

virtual ~Executable() {}

const char* type_key() const final {
return "VMExecutable";
}

/*! \brief The runtime module/library that contains both the host and also the device
* code when executing on non-CPU devices. */
runtime::Module lib;
/*! \brief The global constant pool. */
std::vector<ObjectRef> constants;
/*! \brief A map from globals (as strings) to their index in the function map. */
std::unordered_map<std::string, Index> global_map;
/*! \brief A mapping from the packed function (as string) to the index that
* corresponds to the position of the `packed_funcs` list in a `VirtualMachine` object.
*/
std::unordered_map<std::string, Index> primitive_map;
/*! \brief The virtual machine's function table. */
std::vector<VMFunction> functions;

private:
/*!
* \brief Save the globals.
*
* \param strm The input stream.
*/
void SaveGlobalSection(dmlc::Stream* strm);

/*!
* \brief Save the constant pool.
*
* \param strm The input stream.
*/
void SaveConstantSection(dmlc::Stream* strm);

/*!
* \brief Save primitive op names.
*
* \param strm The input stream.
*/
void SavePrimitiveOpNames(dmlc::Stream* strm);

/*!
* \brief Save the vm functions.
*
* \param strm The input stream.
*/
void SaveCodeSection(dmlc::Stream* strm);

/*!
* \brief Load the globals.
*
* \param strm The input stream.
*/
void LoadGlobalSection(dmlc::Stream* strm);

/*!
* \brief Load the constant pool.
*
* \param strm The input stream.
*/
void LoadConstantSection(dmlc::Stream* strm);

/*!
* \brief Load primitive op names.
*
* \param strm The input stream.
*/
void LoadPrimitiveOpNames(dmlc::Stream* strm);

/*!
* \brief Load the vm functions.
*
* \param strm The input stream.
*/
void LoadCodeSection(dmlc::Stream* strm);

/*! \brief The serialized bytecode. */
std::string code_;
};

/*! \brief The virtual machine.
*
* The virtual machine contains all the current execution state,
* as well as the global view of functions, the global constant
* table, the compiled operators.
* as well as the executable.
*
* The goal is to have a single self-contained object,
* enabling one to easily pass around VMs, execute them on
* multiple threads, or serialized them to disk or over the
* multiple threads, or serialize them to disk or over the
* wire.
*/
class VirtualMachine : public runtime::ModuleNode {
Expand Down Expand Up @@ -486,16 +656,18 @@ class VirtualMachine : public runtime::ModuleNode {
return "VirtualMachine";
}

/*! \brief The runtime module/library that contains generated code. */
runtime::Module lib;
VirtualMachine() : frames(), func_index(0), code(nullptr), pc(0), exec(nullptr) {}

/*! \brief load the executable for the virtual machine.
* \param exec The executable.
*/
void LoadExecutable(const Executable* exec);

protected:
/*! \brief The virtual machine's packed function table. */
std::vector<PackedFunc> packed_funcs;
/*! \brief The virtual machine's function table. */
std::vector<VMFunction> functions;
/*! \brief The current stack of call frames. */
std::vector<VMFrame> frames;
/*! \brief The global constant pool. */
std::vector<ObjectRef> constants;
/*! \brief The fuction table index of the current function. */
Index func_index;
/*! \brief The current pointer to the code section. */
Expand All @@ -506,6 +678,9 @@ class VirtualMachine : public runtime::ModuleNode {
/*! \brief The special return register. */
ObjectRef return_register;

/*! \brief The executable the VM will operate on. */
const Executable* exec;
zhiics marked this conversation as resolved.
Show resolved Hide resolved

/*! \brief The set of TVM contexts the VM is currently executing on. */
std::vector<TVMContext> ctxs;

Expand Down Expand Up @@ -550,8 +725,6 @@ class VirtualMachine : public runtime::ModuleNode {
*/
ObjectRef Invoke(const std::string& name, const std::vector<ObjectRef>& args);

VirtualMachine() : functions(), frames(), func_index(0), code(nullptr), pc(0) {}

/*! \brief Initialize the virtual machine for a set of contexts.
* \param contexts The set of TVM contexts.
*/
Expand All @@ -565,21 +738,6 @@ class VirtualMachine : public runtime::ModuleNode {
*/
TVMContext GetParamsContext() const;

/*!
* \brief Load parameters from the parameter bytearray.
* \param params The binary file that contains parameters.
*/
void LoadParams(const std::string& params);

/*! \brief A map from globals (as strings) to their index in the function map.
*/
std::unordered_map<std::string, Index> global_map;

/*! \brief A mapping from the packed function (as string) to the index that
* corresponds to the position of the `packed_funcs` list.
*/
std::unordered_map<std::string, Index> primitive_map;

private:
/*! \brief Invoke a global setting up the VM state to execute.
*
Expand Down
2 changes: 0 additions & 2 deletions python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@
from . import feature
from .backend import vm
from .backend import profiler_vm
from .backend import serializer
from .backend import deserializer
from .backend import vmobj

# Root operators
Expand Down
81 changes: 0 additions & 81 deletions python/tvm/relay/backend/deserializer.py

This file was deleted.

12 changes: 8 additions & 4 deletions python/tvm/relay/backend/profiler_vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ def compile(mod, target=None, target_host=None, params=None):

Returns
-------
vm : VirtualMachineProfiler
The profile VM runtime.
exec : Executable
The executable with profiling code.
"""
compiler = VMCompilerProfiler()
target = compiler.update_target(target)
Expand All @@ -60,21 +60,25 @@ def compile(mod, target=None, target_host=None, params=None):
tophub_context = compiler.tophub_context(target)
with tophub_context:
compiler._compile(mod, target, target_host)
return VirtualMachineProfiler(compiler._get_vm())
return vm.Executable(compiler._get_exec())

class VMCompilerProfiler(vm.VMCompiler):
"""Build Relay module to run on VM runtime."""
def __init__(self):
super().__init__()
self.mod = _vm._VMCompilerProfiler()
self._compile = self.mod["compile"]
self._get_vm = self.mod["get_vm"]
self._get_exec = self.mod["get_executable"]
self._set_params_func = self.mod["set_params"]

class VirtualMachineProfiler(vm.VirtualMachine):
"""Relay profile VM runtime."""
def __init__(self, mod):
super().__init__(mod)
m = mod.module if isinstance(mod, vm.Executable) else mod
self.mod = _vm._VirtualMachineDebug(m)
self._init = self.mod["init"]
self._invoke = self.mod["invoke"]
self._get_stat = self.mod["get_stat"]

def get_stat(self):
Expand Down
Loading