Skip to content

Commit

Permalink
Add support for using the VM across the RPC boundary. (#7746)
Browse files Browse the repository at this point in the history
* Get basic verison of VM RPC working

* Test case passes

* Clean up PR

* Lint

* Format

* Address Andrew R and TK feedback

* Add comment for Andrew

* Address Zhi's comment

* Format

* Fix broken test
  • Loading branch information
jroesch committed Mar 30, 2021
1 parent 1c2555a commit fd18751
Show file tree
Hide file tree
Showing 9 changed files with 245 additions and 42 deletions.
34 changes: 28 additions & 6 deletions include/tvm/runtime/vm/executable.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,19 @@ class Executable : public ModuleNode {
*/
PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final;

/*!
* \brief Write the Executable to the binary stream in serialized form.
* \param stream The binary stream to save the executable to.
*/
void SaveToBinary(dmlc::Stream* stream) final;

/*!
* \brief Write the Executable to the provided path as a file contianing its serialized content.
* \param path The path to write the serialized data to.
* \param format The format of the serialized blob.
*/
void SaveToFile(const std::string& path, const std::string& format) final;

/*!
* \brief Serialize the executable into global section, constant section, and
* code section.
Expand Down Expand Up @@ -125,12 +138,24 @@ class Executable : public ModuleNode {
* \brief Get the `lib` module in an executable. Users have the flexibility to call
* `export_library` from the frontend to save the library to disk.
*
* \return The runtime module that contains the hardwre dependent code.
* \return The runtime module that contains the hardware dependent code.
*/
runtime::Module GetLib() const;

/*!
* \brief Set the `lib` module in an executable.
*
* This allows us to do partial initialization in the case of (de|ser)ialization cases.
* This method also ensures correct initialization of library ensuring we only Import a
* single library.
*
* NB: This also provides some abstraction over how libraries are stored as there are plans
* to iterate on the way runtime::Module works in the backend of the compiler.
*/
runtime::Module GetLib() const { return lib; }
void SetLib(const runtime::Module& lib);

/*!
* \brief Get the arity of the VM Fucntion.
* \brief Get the arity of the VMFunction.
* \param func Function name.
* \return The number of parameters.
*/
Expand All @@ -148,9 +173,6 @@ class Executable : public ModuleNode {

const char* type_key() const final { return "VMExecutable"; }

/*! \brief The runtime module/library that contains both the host and also the device
* code when executing on non-CPU devices. */
runtime::Module lib;
/*! \brief The global constant pool. */
std::vector<ObjectRef> constants;
/*! \brief A map from globals (as strings) to their index in the function map. */
Expand Down
23 changes: 17 additions & 6 deletions python/tvm/runtime/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,24 +269,35 @@ def _collect_dso_modules(self):
return self._collect_from_import_tree(is_dso_exportable)

def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=None, **kwargs):
"""Export the module and its imported device code one library.
"""
Export the module and all imported modules into a single device library.
This function only works on host llvm modules.
It will pack all the imported modules
This function only works on host LLVM modules, other runtime::Module
subclasses will work with this API but they must support implement
the save and load mechanisms of modules completely including saving
from streams and files. This will pack your non-shared library module
into a single shared library which can later be loaded by TVM.
Parameters
----------
file_name : str
The name of the shared library.
fcompile : function(target, file_list, kwargs), optional
Compilation function to use create dynamic library.
The compilation function to use create the final library object during
export.
For example, when fcompile=_cc.create_shared, or when it is not supplied but
module is "llvm," this is used to link all produced artifacts
into a final dynamic library.
This behavior is controlled by the type of object exported.
If fcompile has attribute object_format, will compile host library
to that format. Otherwise, will use default format "o".
workspace_dir : str, optional
the path to a directory used to create intermediary
artifacts for the process exporting of the library.
The path of the directory used to create the intermediate
artifacts when exporting the module.
If this is not provided a temporary dir will be created.
kwargs : dict, optional
Expand Down
37 changes: 35 additions & 2 deletions python/tvm/runtime/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import numpy as np

import tvm
from tvm.runtime import Module
from tvm._ffi.runtime_ctypes import TVMByteArray
from tvm._ffi import base as _base
from .object import Object
Expand Down Expand Up @@ -299,12 +300,44 @@ class VirtualMachine(object):
POOLED_ALLOCATOR = 2

def __init__(self, exe, device, memory_cfg=None):
if not isinstance(exe, Executable):
"""
Construct a VirtualMachine wrapper class which provides a simple
interface over the raw C++ Module based API.
Parameters
----------
exe: Union[Executable, Module]
The executable either with the wrapper Python type or the raw runtime.Module.
In most cases this will be the Python wrapper class tvm.runtime.vm.Executable but
if you instead get the underlying runtime.Module subclass (i.e `exe.mod`) you
can directly pass it to this method.
This case can occur when doing things such as RPC where TVM's module APIs
return the raw modules, not the wrapped modules. This constructor will
handle this internally.
device: Union[Device, List[Device]]
The device, or devices on which to execute the VM code.
memory_cfg: Optional[str]
The allocator behavior to use for the VM.
Returns
-------
vm: VirtualMachine
A VM wrapper object.
"""
if not isinstance(exe, Executable) and not isinstance(exe, Module):
raise TypeError(
"exe is expected to be the type of Executable, "
+ "but received {}".format(type(exe))
)
self.module = _ffi_api._VirtualMachine(exe.module)

if not isinstance(exe, Executable):
exe = Executable(exe)

self.module = exe.mod["vm_load_executable"]()
self._exec = exe
self._init = self.module["init"]
self._invoke = self.module["invoke"]
Expand Down
8 changes: 5 additions & 3 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1155,18 +1155,20 @@ void VMCompiler::Codegen() {

auto compile_engine = CompileEngine::Global();
auto ext_mods = compile_engine->LowerExternalFunctions();
runtime::Module lib;
if (funcs.size() > 0) {
Map<String, IRModule> build_funcs;
for (const auto& i : funcs) {
build_funcs.Set(i.first, i.second);
}
exec_->lib = tvm::build(build_funcs, target_host_);
lib = tvm::build(build_funcs, target_host_);
} else {
// There is no function handled by TVM. We create a virtual main module
// to make sure a DSO module will be also available.
exec_->lib = codegen::CSourceModuleCreate(";", "", Array<String>{});
lib = codegen::CSourceModuleCreate(";", "", Array<String>{});
}
exec_->lib = codegen::CreateMetadataModule(params_, exec_->lib, ext_mods, target_host_);
lib = codegen::CreateMetadataModule(params_, lib, ext_mods, target_host_);
exec_->SetLib(lib);
}

ExprDeviceMap VMCompiler::AnalyzeContext() const {
Expand Down
42 changes: 23 additions & 19 deletions src/runtime/library_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,28 @@ void InitContextFunctions(std::function<void*(const char*)> fgetsymbol) {
#undef TVM_INIT_CONTEXT_FUNC
}

Module LoadModuleFromBinary(const std::string& type_key, dmlc::Stream* stream) {
std::string loadkey = "runtime.module.loadbinary_";
std::string fkey = loadkey + type_key;
const PackedFunc* f = Registry::Get(fkey);
if (f == nullptr) {
std::string loaders = "";
for (auto name : Registry::ListNames()) {
if (name.find(loadkey, 0) == 0) {
if (loaders.size() > 0) {
loaders += ", ";
}
loaders += name.substr(loadkey.size());
}
}
LOG(FATAL) << "Binary was created using " << type_key
<< " but a loader of that name is not registered. Available loaders are " << loaders
<< ". Perhaps you need to recompile with this runtime enabled.";
}

return (*f)(static_cast<void*>(stream));
}

/*!
* \brief Load and append module blob to module list
* \param mblob The module blob.
Expand Down Expand Up @@ -133,25 +155,7 @@ runtime::Module ProcessModuleBlob(const char* mblob, ObjectPtr<Library> lib) {
ICHECK(stream->Read(&import_tree_row_ptr));
ICHECK(stream->Read(&import_tree_child_indices));
} else {
std::string loadkey = "runtime.module.loadbinary_";
std::string fkey = loadkey + tkey;
const PackedFunc* f = Registry::Get(fkey);
if (f == nullptr) {
std::string loaders = "";
for (auto name : Registry::ListNames()) {
if (name.rfind(loadkey, 0) == 0) {
if (loaders.size() > 0) {
loaders += ", ";
}
loaders += name.substr(loadkey.size());
}
}
ICHECK(f != nullptr)
<< "Binary was created using " << tkey
<< " but a loader of that name is not registered. Available loaders are " << loaders
<< ". Perhaps you need to recompile with this runtime enabled.";
}
Module m = (*f)(static_cast<void*>(stream));
auto m = LoadModuleFromBinary(tkey, stream);
modules.emplace_back(m);
}
}
Expand Down
12 changes: 12 additions & 0 deletions src/runtime/library_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,21 @@
#include <tvm/runtime/module.h>

#include <functional>
#include <string>

namespace tvm {
namespace runtime {

/*! \brief Load a module with the given type key directly from the stream.
* This function wraps the registry mechanism used to store type based deserializers
* for each runtime::Module sub-class.
*
* \param type_key The type key of the serialized module.
* \param stream A pointer to the stream containing the serialized module.
* \return module The deserialized module.
*/
Module LoadModuleFromBinary(const std::string& type_key, dmlc::Stream* stream);

/*!
* \brief Library is the common interface
* for storing data in the form of shared libaries.
Expand Down
76 changes: 75 additions & 1 deletion src/runtime/vm/executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
#include <utility>
#include <vector>

#include "../file_utils.h"
#include "../library_module.h"
#include "serialize_utils.h"

namespace tvm {
Expand Down Expand Up @@ -74,6 +76,12 @@ PackedFunc Executable::GetFunction(const std::string& name, const ObjectPtr<Obje
int index = args[1];
*rv = this->GetFunctionParameterName(func_name, index);
});
} else if (name == "vm_load_executable") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
auto vm = make_object<VirtualMachine>();
vm->LoadExecutable(this);
*rv = Module(vm);
});
} else {
LOG(FATAL) << "Unknown packed function: " << name;
return PackedFunc(nullptr);
Expand Down Expand Up @@ -475,9 +483,37 @@ void LoadHeader(dmlc::Stream* strm) {
STREAM_CHECK(version == TVM_VERSION, "version");
}

runtime::Module Executable::GetLib() const {
ICHECK_LE(this->imports_.size(), 1)
<< "The kernel library must be imported as the only module in an Executable";

if (this->imports().size() == 0) {
return Module(nullptr);
} else {
return this->imports_[0];
}
}

void Executable::SetLib(const runtime::Module& lib) {
ICHECK(lib.defined()) << "the provided library can not be null";

ICHECK_EQ(this->imports_.size(), 0)
<< "A VMExecutable should never have more than one import inside an the executable, \n"
<< "the first import should *always* be the library containing"
<< "the platform specific kernel code";

this->Import(lib);
}

runtime::Module Executable::Load(const std::string& code, const runtime::Module lib) {
auto exec = make_object<Executable>();
exec->lib = lib;

// Support null-initialization of lib, to enable initialization during
// deserialization before we have we have deserialized the imports.
if (lib.defined()) {
exec->SetLib(lib);
}

exec->code_ = code;
dmlc::MemoryStringStream strm(&exec->code_);

Expand Down Expand Up @@ -765,6 +801,44 @@ void Executable::LoadCodeSection(dmlc::Stream* strm) {
}
}

void Executable::SaveToBinary(dmlc::Stream* stream) {
auto code_bytes = this->Save();
std::string code(code_bytes.data, code_bytes.size);
stream->Write(code);

ICHECK(this->imports()[0].defined()) << "the library must be imported before serialization";
}

Module ExecutableLoadBinary(void* strm) {
dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
std::string code;
stream->Read(&code);
auto exec = Executable::Load(code, Module());
return exec;
}

void Executable::SaveToFile(const std::string& path, const std::string& format) {
std::string data;
dmlc::MemoryStringStream writer(&data);
dmlc::SeekStream* strm = &writer;
SaveToBinary(strm);
SaveBinaryToFile(path, data);
}

TVM_REGISTER_GLOBAL("runtime.module.loadbinary_VMExecutable").set_body_typed(ExecutableLoadBinary);

// Load module from module.
Module ExecutableLoadFile(const std::string& file_name, const std::string& format) {
std::string data;
LoadBinaryFromFile(file_name, &data);
dmlc::MemoryStringStream reader(&data);
dmlc::Stream* strm = &reader;
auto exec = ExecutableLoadBinary(reinterpret_cast<void*>(strm));
return exec;
}

TVM_REGISTER_GLOBAL("runtime.module.loadfile_VMExecutable").set_body_typed(ExecutableLoadFile);

TVM_REGISTER_GLOBAL("runtime.GetNumOfGlobals").set_body([](TVMArgs args, TVMRetValue* rv) {
runtime::Module mod = args[0];
const auto* exec = dynamic_cast<Executable*>(mod.operator->());
Expand Down
9 changes: 5 additions & 4 deletions src/runtime/vm/vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -281,11 +281,12 @@ void VirtualMachine::LoadExecutable(const Executable* exec) {
ICHECK(exec) << "The executable is not created yet.";
exec_ = exec;

runtime::Module lib = exec_->lib;
// Get the list of packed functions.
runtime::Module lib = exec_->GetLib();

ICHECK(exec->primitive_map.empty() || lib.operator->())
<< "runtime module should have been built for primitive functions"
<< "\n";
<< "If the executable has declared primitive functions, the"
<< "generated kernel library must non-be null.";

for (const auto& it : exec_->primitive_map) {
const auto& packed_name = it.first;
auto packed_index = static_cast<size_t>(it.second);
Expand Down
Loading

0 comments on commit fd18751

Please sign in to comment.