Skip to content
2 changes: 1 addition & 1 deletion apps/graph_executor/python/tvm_graph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@
from . import _base
from nnvm.symbol import *
from . import op_tvm_def
from .build import build, bind, save_params
from .build import build, bind, save_params, compile_graph


10 changes: 10 additions & 0 deletions apps/graph_executor/python/tvm_graph/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,16 @@ def bind(g, ctx):
m = _create_exec(g.handle, ctx.device_type, ctx.device_id)
return m

_get_module = tvm.get_global_func("tvm_graph._get_module_from_graph")

def compile_graph(sym_fname, lib_fname, params_fname,
                  sym, target, shape, dtype="float32"):
    """Build ``sym`` and dump the compiled artifacts to disk.

    Saves the compiled module to ``lib_fname`` and the graph structure
    (as JSON) to ``sym_fname``.

    NOTE(review): ``params_fname`` is accepted but never written here --
    callers appear to save parameters separately via ``save_params``;
    confirm before relying on this function to persist parameters.
    """
    graph = build(sym, target, shape, dtype)
    module = _get_module(graph.handle)
    module.save(lib_fname)
    graph_json = graph.apply('SaveJSON').json_attr('json')
    with open(sym_fname, 'w') as out_file:
        out_file.write(graph_json)

@tvm.register_func("tvm_graph.lower")
def _lower(sch, inputs, func_name):
Expand Down
85 changes: 67 additions & 18 deletions apps/graph_executor/src/graph_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@
* \file NNVM Graph executor.
*/
#include <dmlc/io.h>
#include <dmlc/memory_io.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/module.h>
#include <nnvm/graph.h>
#include <nnvm/graph_attr_types.h>
#include <nnvm/tuple.h>
#include <nnvm/pass.h>
#include <numeric>
#include <string>

namespace tvm {
namespace contrib {
Expand Down Expand Up @@ -53,6 +56,8 @@ class GraphExecutor : public runtime::ModuleNode {
void SetInput(int index, DLTensor* data_in);
// Copy index-th output to data_out
void GetOutput(int index, DLTensor* data_out);
// Load parameters from stream
void LoadParams(dmlc::Stream* strm);
// Load parameters from file
void LoadParams(std::string fname);
// Execute the graph.
Expand Down Expand Up @@ -98,7 +103,7 @@ PackedFunc GraphExecutor::GetFunction(
});
} else if (name == "load_params") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
this->LoadParams(args[0]);
this->LoadParams(args[0].operator std::string());
});
} else {
return PackedFunc();
Expand Down Expand Up @@ -233,19 +238,17 @@ TVM_REGISTER_GLOBAL("tvm_graph._save_param_dict")
}
});


void GraphExecutor::LoadParams(std::string fname) {
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname.c_str(), "r"));
void GraphExecutor::LoadParams(dmlc::Stream *strm) {
uint64_t header, reserved;
CHECK(fi->Read(&header))
CHECK(strm->Read(&header))
<< "Invalid parameters file format";
CHECK(header == kTVMNDArrayListMagic)
<< "Invalid parameters file format";
CHECK(fi->Read(&reserved))
CHECK(strm->Read(&reserved))
<< "Invalid parameters file format";

std::vector<std::string> names;
CHECK(fi->Read(&names))
CHECK(strm->Read(&names))
<< "Invalid parameters file format";

nnvm::Symbol s;
Expand All @@ -257,20 +260,22 @@ void GraphExecutor::LoadParams(std::string fname) {
name_index.emplace(input_names[i], i);
}

{
uint64_t sz;
fi->Read(&sz, sizeof(sz));
size_t size = static_cast<size_t>(sz);
CHECK(size == names.size())
uint64_t sz;
strm->Read(&sz, sizeof(sz));
size_t size = static_cast<size_t>(sz);
CHECK(size == names.size())
<< "Invalid parameters file format";
for (size_t i = 0; i < size; ++i) {
size_t idx = name_index.at(names[i]);
CHECK(LoadDLTensor(strm, &data_entry_[idx]))
<< "Invalid parameters file format";
for (size_t i = 0; i < size; ++i) {
size_t idx = name_index.at(names[i]);
CHECK(LoadDLTensor(fi.get(), &data_entry_[idx]))
<< "Invalid parameters file format";
}
}
}

// Load parameters from the file at `fname`.
// Opens the file as a dmlc::Stream and delegates to the stream-based overload.
void GraphExecutor::LoadParams(std::string fname) {
  std::unique_ptr<dmlc::Stream> fstrm(
      dmlc::Stream::Create(fname.c_str(), "r"));
  this->LoadParams(fstrm.get());
}

void GraphExecutor::SetupStorage() {
const auto& idx = graph_.indexed_graph();
Expand Down Expand Up @@ -383,7 +388,8 @@ FOpExec GraphExecutor::CreateTVMOp(const nnvm::NodeAttrs& attrs,
}
// get compiled function from module.
runtime::PackedFunc pf = module_.GetFunction(it->second, false);
auto fexec = [arg_ptr, pf] () {
CHECK(pf != nullptr) << "no such function in module: " << it->second;
auto fexec = [arg_ptr, pf] () {
runtime::TVMRetValue rv;
runtime::TVMArgs targs(arg_ptr->arg_values.data(),
arg_ptr->arg_tcodes.data(),
Expand All @@ -410,5 +416,48 @@ TVM_REGISTER_GLOBAL("tvm_graph._create_executor")
nnvm::Graph g = static_cast<nnvm::Graph*>(graph_handle)[0];
*rv = CreateExecutor(g, ctx);
});


// Global function: extract the compiled runtime::Module from a built graph.
// Uses MoveCopyAttr, so the graph's "module" attribute is moved out of the
// graph object the handle points to.
TVM_REGISTER_GLOBAL("tvm_graph._get_module_from_graph")
.set_body([](TVMArgs args, TVMRetValue *rv) {
    void* handle = args[0];
    nnvm::Graph* graph = static_cast<nnvm::Graph*>(handle);
    *rv = graph->MoveCopyAttr<tvm::runtime::Module>("module");
  });


// Global function used by the RPC server: reconstruct a GraphExecutor from
// (graph json, uploaded library file name, serialized parameter blob,
// device type, device id) and return it wrapped as a runtime::Module.
TVM_REGISTER_GLOBAL("tvm_graph._load_executor")
.set_body([](TVMArgs args, TVMRetValue *rv) {
std::string sym_json = args[0];
std::string lib_fname = args[1];
std::string param_blob = args[2];
// Target device the executor will run on.
TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(args[3].operator int());
ctx.device_id = args[4];

// load graph from json string
nnvm::Graph g;
g.attrs["json"] = std::make_shared<nnvm::any>(sym_json);
g = nnvm::ApplyPass(std::move(g), "LoadJSON");

// load module from file; reuse the RPC server's own module loader so
// lib_fname is resolved relative to the server's upload temp folder.
static const PackedFunc* fsys_load_ = nullptr;
if (fsys_load_ == nullptr) {
fsys_load_ = runtime::Registry::Get("tvm.contrib.rpc.server.load_module");
CHECK(fsys_load_ != nullptr);
}
runtime::Module m = (*fsys_load_)(lib_fname);
g.attrs["module"] = std::make_shared<nnvm::any>(m);

std::shared_ptr<GraphExecutor> exec =
std::make_shared<GraphExecutor>();
exec->Init(g, ctx);

// load params from an in-memory stream over the parameter blob
dmlc::MemoryStringStream strm(&param_blob);
exec->LoadParams(&strm);

*rv = tvm::runtime::Module(exec);
});
} // namespace contrib
} // namespace tvm
25 changes: 25 additions & 0 deletions apps/graph_executor/src/graph_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
* Copyright (c) 2017 by Contributors
* \file Additional optimization pass of NNVM.
*/
#include <dmlc/json.h>
#include <nnvm/graph.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/graph_attr_types.h>
Expand Down Expand Up @@ -514,3 +515,27 @@ NNVM_REGISTER_OP(layout_transform)
.set_num_outputs(1);
} // namespace contrib
} // namespace tvm

namespace dmlc {
namespace json {

template<>
struct Handler<DLDataType> {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Put this in c_runtime_api.cc?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just put this here is fine

static void Write(JSONWriter *writer, const DLDataType& data) {
std::vector<int> tmp({data.code, data.bits, data.lanes});
writer->Write(tmp);
}

static void Read(JSONReader *reader, DLDataType* data) {
std::vector<int> tmp;
reader->Read(&tmp);
data->code = tmp[0];
data->bits = tmp[1];
data->lanes = tmp[2];
}
};

DMLC_JSON_ENABLE_ANY(std::vector<DLDataType>, list_dltype);

} // namespace dmlc
} // namespace json
49 changes: 49 additions & 0 deletions apps/graph_executor/tests/test_rpc_executor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import tvm
from tvm.contrib import util, rpc
import tvm_graph as tg
import numpy as np
import os

def test_rpc_executor():
    """End-to-end check: compile exp(x + y), serve it over RPC, run remotely."""
    # Spin up a local RPC server to act as the "remote" device.
    host = 'localhost'
    port = 9090
    server = rpc.Server(host, port)

    workdir = util.tempdir()
    sym_fname = workdir.relpath('net.json')
    lib_fname = workdir.relpath('net.o')
    param_fname = workdir.relpath('net.param')

    # Build a tiny two-input graph: exp(y + x).
    x = tg.Variable('x')
    y = tg.Variable('y')
    sym = tg.exp(y + x)

    shape = (10, 128)
    dtype = tvm.float32
    na = tvm.nd.array(np.ones(shape).astype(dtype))
    nb = tvm.nd.array(np.ones(shape).astype(dtype))
    tg.save_params(param_fname, {'x': na, 'y': nb})

    target = "llvm"
    shapes = {'x': shape, 'y': shape}
    tg.compile_graph(sym_fname, lib_fname, param_fname,
                     sym, target, shapes)

    # Connect back, upload the compiled library, and load the executor there.
    remote = rpc.connect(host, port)
    ctx = remote.cpu(0)

    remote.upload(lib_fname)
    rm = remote.load_executor(sym_fname, os.path.basename(lib_fname),
                              param_fname, ctx)
    run, get_output = rm['run'], rm['get_output']

    out = tvm.nd.array(np.zeros(shape, dtype=dtype), ctx)
    run()
    get_output(0, out)

    # Remote result must match the locally computed reference.
    np.testing.assert_allclose(
        out.asnumpy(), np.exp(na.asnumpy() + nb.asnumpy()))
    server.terminate()

if __name__ == "__main__":
test_rpc_executor()
34 changes: 34 additions & 0 deletions python/tvm/contrib/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ def __init__(self, sess):
self._tbl_index = _SessTableIndex(sess)
self._upload_func = None
self._download_func = None
self._load_executor_func = None

def get_function(self, name):
"""Get function from the session.
Expand Down Expand Up @@ -297,6 +298,39 @@ def load_module(self, path):
"""
return _LoadRemoteModule(self._sess, path)

def load_executor(self, sym_fname, lib_fname, param_fname, ctx):
    """Load a remote graph executor, with the local files.

    Parameters
    ----------
    sym_fname : str
        The local path to the symbol json file.

    lib_fname : str
        The relative library location to remote temp folder. The
        library need to be uploaded first.

    param_fname : str
        The local path to the parameters file.

    ctx : TVMContext
        The remote context on which to create the executor.

    Returns
    -------
    exec : GraphExecutor
        The remote graph executor containing remote function.
    """
    # Read the graph definition and parameter blob locally; use context
    # managers so the file handles are closed promptly (the original code
    # leaked them to the garbage collector).
    with open(sym_fname, "r") as f:
        sym_json = f.read()
    with open(param_fname, "rb") as f:
        param_blob = bytearray(f.read())
    if not self._load_executor_func:
        self._load_executor_func = self.get_function(
            "tvm_graph._load_executor")
    # The remote context encodes the session index in the high bits of
    # device_type; check it belongs to this session, then strip the mask.
    # NOTE: floor division is required -- with `/`, Python 3 would produce
    # a float and the equality check would always fail.
    assert ctx.device_type // RPC_SESS_MASK == self._tbl_index + 1
    device_type = ctx.device_type % RPC_SESS_MASK
    return self._load_executor_func(sym_json,
                                    lib_fname,
                                    param_blob,
                                    device_type,
                                    ctx.device_id)


def connect(url, port, key=""):
"""Connect to RPC Server
Expand Down
41 changes: 40 additions & 1 deletion python/tvm/exec/rpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,57 @@

import logging
import argparse
import os
import ctypes
from ..contrib import rpc

def find_lib_path(name):
    """Find a dynamic library by searching the usual build output folders.

    Parameters
    ----------
    name : str
        File name of the library, e.g. ``libtvm_graph_exec.so``.

    Returns
    -------
    lib_found : list of str
        Absolute paths of every candidate location where the file exists.

    Raises
    ------
    RuntimeError
        If the library is not found in any candidate directory.
    """
    curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
    base_path = os.path.join(curr_path, "../../../")
    apps_path = os.path.join(base_path, "apps/graph_executor/lib/")
    api_path = os.path.join(base_path, 'lib/')
    cmake_build_path = os.path.join(base_path, 'build/Release/')
    dll_path = [curr_path, base_path, apps_path, api_path, cmake_build_path]
    if os.name == 'nt':
        # Imported locally: only the Windows branch needs it, and the
        # original code referenced `platform` without any import at all,
        # which raised NameError on Windows.
        import platform
        vs_configuration = 'Release'
        if platform.architecture()[0] == '64bit':
            dll_path.append(os.path.join(base_path, 'build', vs_configuration))
            dll_path.append(os.path.join(base_path, 'windows/x64', vs_configuration))
        else:
            dll_path.append(os.path.join(base_path, 'build', vs_configuration))
            dll_path.append(os.path.join(base_path, 'windows', vs_configuration))
    elif os.name == "posix" and os.environ.get('LD_LIBRARY_PATH', None):
        dll_path.extend([p.strip() for p in os.environ['LD_LIBRARY_PATH'].split(":")])
    dll_path = [os.path.abspath(x) for x in dll_path]
    lib_dll_path = [os.path.join(p, name) for p in dll_path]

    # Keep only candidates that actually exist as regular files.
    lib_found = [p for p in lib_dll_path if os.path.exists(p) and os.path.isfile(p)]
    if not lib_found:
        raise RuntimeError('Cannot find the files.\n' +
                           'List of candidates:\n' +
                           str('\n'.join(lib_dll_path)))
    return lib_found


def main():
    """Main function: parse arguments and start the RPC server."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--host', type=str, default="0.0.0.0",
                        help='the hostname of the server')
    parser.add_argument('--port', type=int, default=9090,
                        help='The port of the RPC')
    parser.add_argument('--port-end', type=int, default=9199,
                        help='The end search port of the RPC')
    # `type=bool` was a bug: argparse would turn ANY non-empty string
    # (including "False") into True.  A store_true flag keeps the default
    # (False) backward compatible and behaves as users expect.
    parser.add_argument('--with-executor', action='store_true',
                        help="Whether to load executor runtime")
    args = parser.parse_args()

    if args.with_executor:
        # Load the graph-executor runtime globally so its registered
        # functions become visible to the RPC server process.
        lib_path = find_lib_path('libtvm_graph_exec.so')
        ctypes.CDLL(lib_path[0], ctypes.RTLD_GLOBAL)

    logging.basicConfig(level=logging.INFO)
    server = rpc.Server(args.host, args.port, args.port_end)
    server.proc.join()
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/rpc/rpc_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ class RPCSession::EventHandler {
arg_recv_stage_ = 0;
arg_buf_.reset();
}
// strip sessionon mask
// strip session on mask
TVMContext StripSessMask(TVMContext ctx) {
int dev_type = ctx.device_type;
CHECK_EQ(dev_type / kRPCSessMask, rpc_sess_table_index_ + 1)
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/rpc/rpc_session.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ enum class RPCCode : int {
kModuleLoad,
kModuleFree,
kModuleGetFunc,
kModuleGetSource
kModuleGetSource,
};

/*!
Expand Down