Skip to content

Commit

Permalink
[CPP_RPC] allow user supplied work dir (apache#7670)
Browse files Browse the repository at this point in the history
* [CPP_RPC] allow user supplied work dir

* clang format
  • Loading branch information
euntaik authored and Trevor Morris committed May 6, 2021
1 parent 6121b49 commit d9041ce
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 27 deletions.
10 changes: 9 additions & 1 deletion apps/cpp_rpc/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ static const string kUsage =
"--tracker - The RPC tracker address in host:port format e.g. 10.1.1.2:9190 Default=\"\"\n"
"--key - The key used to identify the device type in tracker. Default=\"\"\n"
"--custom-addr - Custom IP Address to Report to RPC Tracker. Default=\"\"\n"
"--work-dir - Custom work directory. Default=\"\"\n"
"--silent - Whether to run in silent mode. Default=False\n"
"\n"
" Example\n"
Expand All @@ -70,6 +71,7 @@ static const string kUsage =
* \arg tracker The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default=""
* \arg key The key used to identify the device type in tracker. Default=""
* \arg custom_addr Custom IP Address to Report to RPC Tracker. Default=""
* \arg work_dir Custom work directory. Default=""
* \arg silent Whether run in silent mode. Default=False
*/
struct RpcServerArgs {
Expand All @@ -79,6 +81,7 @@ struct RpcServerArgs {
string tracker;
string key;
string custom_addr;
string work_dir;
bool silent = false;
#if defined(WIN32)
std::string mmap_path;
Expand All @@ -96,6 +99,7 @@ void PrintArgs(const RpcServerArgs& args) {
LOG(INFO) << "tracker = " << args.tracker;
LOG(INFO) << "key = " << args.key;
LOG(INFO) << "custom_addr = " << args.custom_addr;
LOG(INFO) << "work_dir = " << args.work_dir;
LOG(INFO) << "silent = " << ((args.silent) ? ("True") : ("False"));
}

Expand Down Expand Up @@ -238,6 +242,10 @@ void ParseCmdArgs(int argc, char* argv[], struct RpcServerArgs& args) {
dmlc::InitLogging("--minloglevel=0");
}
#endif
const string work_dir = GetCmdOption(argc, argv, "--work-dir=");
if (!work_dir.empty()) {
args.work_dir = work_dir;
}
}

/*!
Expand Down Expand Up @@ -274,7 +282,7 @@ int RpcServer(int argc, char* argv[]) {
#endif

RPCServerCreate(args.host, args.port, args.port_end, args.tracker, args.key, args.custom_addr,
args.silent);
args.work_dir, args.silent);
return 0;
}

Expand Down
35 changes: 20 additions & 15 deletions apps/cpp_rpc/rpc_env.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ int mkdir(const char* path, int /* ignored */) { return _mkdir(path); }
#include <iostream>
#include <string>
#include <vector>

#include "../../src/support/utils.h"
#include "rpc_env.h"

Expand Down Expand Up @@ -85,25 +84,31 @@ void CleanDir(const std::string& dirname);
*/
std::string BuildSharedLibrary(std::string file_in);

RPCEnv::RPCEnv() {
RPCEnv::RPCEnv(const std::string& wd) {
if (wd != "") {
base_ = wd + "/.cache";
mkdir(wd.c_str(), 0777);
mkdir(base_.c_str(), 0777);
} else {
#if defined(ANDROID) || defined(__ANDROID__)
char cwd[PATH_MAX];
auto cmdline = fopen("/proc/self/cmdline", "r");
fread(cwd, 1, sizeof(cwd), cmdline);
fclose(cmdline);
base_ = "/data/data/" + std::string(cwd) + "/cache/rpc";
char cwd[PATH_MAX];
auto cmdline = fopen("/proc/self/cmdline", "r");
fread(cwd, 1, sizeof(cwd), cmdline);
fclose(cmdline);
base_ = "/data/data/" + std::string(cwd) + "/cache/rpc";
#elif !defined(_WIN32)
char cwd[PATH_MAX];
if (getcwd(cwd, sizeof(cwd))) {
base_ = std::string(cwd) + "/rpc";
} else {
base_ = "./rpc";
}
char cwd[PATH_MAX];
if (getcwd(cwd, sizeof(cwd))) {
base_ = std::string(cwd) + "/rpc";
} else {
base_ = "./rpc";
}
#else
base_ = "./rpc";
base_ = "./rpc";
#endif
mkdir(base_.c_str(), 0777);
}

mkdir(base_.c_str(), 0777);
TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath").set_body([this](TVMArgs args, TVMRetValue* rv) {
*rv = this->GetPath(args[0]);
});
Expand Down
2 changes: 1 addition & 1 deletion apps/cpp_rpc/rpc_env.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ struct RPCEnv {
/*!
* \brief Constructor Init The RPC Environment initialize function
*/
RPCEnv();
RPCEnv(const std::string& word_dir = "");
/*!
* \brief GetPath To get the workpath from packed function
* \param name The file name
Expand Down
21 changes: 12 additions & 9 deletions apps/cpp_rpc/rpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,15 @@ class RPCServer {
* \brief Constructor.
*/
RPCServer(std::string host, int port, int port_end, std::string tracker_addr, std::string key,
std::string custom_addr)
std::string custom_addr, std::string work_dir)
: host_(std::move(host)),
port_(port),
my_port_(0),
port_end_(port_end),
tracker_addr_(std::move(tracker_addr)),
key_(std::move(key)),
custom_addr_(std::move(custom_addr)) {}
custom_addr_(std::move(custom_addr)),
work_dir_(std::move(work_dir)) {}

/*!
* \brief Destructor.
Expand Down Expand Up @@ -174,7 +175,7 @@ class RPCServer {
const pid_t worker_pid = fork();
if (worker_pid == 0) {
// Worker process
ServerLoopProc(conn, addr);
ServerLoopProc(conn, addr, work_dir_);
_exit(0);
}

Expand All @@ -201,7 +202,7 @@ class RPCServer {
} else {
auto pid = fork();
if (pid == 0) {
ServerLoopProc(conn, addr);
ServerLoopProc(conn, addr, work_dir_);
exit(0);
}
// Wait for the result
Expand Down Expand Up @@ -308,9 +309,10 @@ class RPCServer {
* \param sock The socket information
* \param addr The socket address information
*/
static void ServerLoopProc(support::TCPSocket sock, support::SockAddr addr) {
static void ServerLoopProc(support::TCPSocket sock, support::SockAddr addr,
std::string work_dir) {
// Server loop
const auto env = RPCEnv();
const auto env = RPCEnv(work_dir);
RPCServerLoop(int(sock.sockfd));
LOG(INFO) << "Finish serving " << addr.AsString();
env.CleanUp();
Expand Down Expand Up @@ -339,6 +341,7 @@ class RPCServer {
std::string tracker_addr_;
std::string key_;
std::string custom_addr_;
std::string work_dir_;
support::TCPSocket listen_sock_;
support::TCPSocket tracker_sock_;
};
Expand Down Expand Up @@ -370,19 +373,19 @@ void ServerLoopFromChild(SOCKET socket) {
* silent mode. Default=True
*/
void RPCServerCreate(std::string host, int port, int port_end, std::string tracker_addr,
std::string key, std::string custom_addr, bool silent) {
std::string key, std::string custom_addr, std::string work_dir, bool silent) {
if (silent) {
// Only errors and fatal is logged
dmlc::InitLogging("--minloglevel=2");
}
// Start the rpc server
RPCServer rpc(std::move(host), port, port_end, std::move(tracker_addr), std::move(key),
std::move(custom_addr));
std::move(custom_addr), std::move(work_dir));
rpc.Start();
}

TVM_REGISTER_GLOBAL("rpc.ServerCreate").set_body([](TVMArgs args, TVMRetValue* rv) {
RPCServerCreate(args[0], args[1], args[2], args[3], args[4], args[5], args[6]);
RPCServerCreate(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7]);
});
} // namespace runtime
} // namespace tvm
3 changes: 2 additions & 1 deletion apps/cpp_rpc/rpc_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,12 @@ void ServerLoopFromChild(SOCKET socket);
* \param tracker The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default=""
* \param key The key used to identify the device type in tracker. Default=""
* \param custom_addr Custom IP Address to Report to RPC Tracker. Default=""
* \param work_dir Custom work directory. Default=""
* \param silent Whether run in silent mode. Default=True
*/
void RPCServerCreate(std::string host = "", int port = 9090, int port_end = 9099,
std::string tracker_addr = "", std::string key = "",
std::string custom_addr = "", bool silent = true);
std::string custom_addr = "", std::string work_dir = "", bool silent = true);
} // namespace runtime
} // namespace tvm
#endif // TVM_APPS_CPP_RPC_SERVER_H_

0 comments on commit d9041ce

Please sign in to comment.