From 10401fb6657f813dd319deeeec2c04c1cc682af8 Mon Sep 17 00:00:00 2001 From: Wu Zhao Date: Sun, 3 Nov 2019 19:57:11 +0800 Subject: [PATCH 01/13] Support C++ RPC --- apps/cpp_rpc/Makefile | 28 +++ apps/cpp_rpc/README.md | 38 ++++ apps/cpp_rpc/main.cc | 270 ++++++++++++++++++++++ apps/cpp_rpc/rpc_env.cc | 248 ++++++++++++++++++++ apps/cpp_rpc/rpc_env.h | 84 +++++++ apps/cpp_rpc/rpc_server.cc | 363 ++++++++++++++++++++++++++++++ apps/cpp_rpc/rpc_server.h | 55 +++++ apps/cpp_rpc/rpc_tracker_client.h | 249 ++++++++++++++++++++ src/common/socket.h | 187 ++++++++++++++- src/common/util.h | 77 +++++++ src/runtime/rpc/rpc_session.h | 21 ++ src/runtime/rpc/rpc_socket_impl.h | 39 ++++ 12 files changed, 1657 insertions(+), 2 deletions(-) create mode 100644 apps/cpp_rpc/Makefile create mode 100644 apps/cpp_rpc/README.md create mode 100644 apps/cpp_rpc/main.cc create mode 100644 apps/cpp_rpc/rpc_env.cc create mode 100644 apps/cpp_rpc/rpc_env.h create mode 100644 apps/cpp_rpc/rpc_server.cc create mode 100644 apps/cpp_rpc/rpc_server.h create mode 100644 apps/cpp_rpc/rpc_tracker_client.h create mode 100644 src/common/util.h create mode 100644 src/runtime/rpc/rpc_socket_impl.h diff --git a/apps/cpp_rpc/Makefile b/apps/cpp_rpc/Makefile new file mode 100644 index 000000000000..5b464f9123ca --- /dev/null +++ b/apps/cpp_rpc/Makefile @@ -0,0 +1,28 @@ +# Makefile to compile RPC Server. +TVM_ROOT=$(shell cd ../..; pwd) +DMLC_CORE=${TVM_ROOT}/3rdparty/dmlc-core +TVM_RUNTIME_DIR?= + +PKG_CFLAGS = -std=c++11 -O2 -fPIC -Wall\ + -I${TVM_ROOT}/include\ + -I${DMLC_CORE}/include\ + -I${TVM_ROOT}/3rdparty/dlpack/include + +PKG_LDFLAGS = -L$(TVM_RUNTIME_DIR) -ltvm_runtime -ldl -Wl,-R$(TVM_RUNTIME_DIR) + +ifeq ($(USE_GLOG), 1) + PKG_CFLAGS += -DDMLC_USE_GLOG=1 + PKG_LDFLAGS += -lglog +endif + +.PHONY: clean all + +all: tvm_rpc + +# Build rule for all in one TVM package library +tvm_rpc: *.cc + @mkdir -p $(@D) + $(CXX) $(PKG_CFLAGS) -o $@ $(filter %.cc %.o %.a, $^) $(PKG_LDFLAGS) + +clean: + -rm -f tvm_rpc \ No newline at end of file diff --git a/apps/cpp_rpc/README.md b/apps/cpp_rpc/README.md new file mode 100644 index 000000000000..f7846408bab1 --- /dev/null +++ b/apps/cpp_rpc/README.md @@ -0,0 +1,38 @@ +# TVM RPC Server +This folder contains a simple recipe to make RPC server in c++. + +## Usage +- Build tvm runtime +- Make the rpc executable [Makefile](Makefile). + `make CXX=/path/to/cross compiler g++/ TVM_RUNTIME_DIR=/path/to/tvm runtime library directory/` + You could cross compile the TVM runtime like this: +``` + cd tvm + mkdir arm_runtime + cp cmake/config.cmake arm_runtime + cd arm_runtime + cmake .. -DCMAKE_CXX_COMPILER="/path/to/cross compiler g++/" + make runtime +``` +- Use `./tvm_rpc server` to start the RPC server + +## How it works +- The tvm runtime dll is linked along with this executable and when the RPC server starts it will load the tvm runtime library. + +``` +Command line usage + server - Start the server +--host - The hostname of the server, Default=0.0.0.0 +--port - The port of the RPC, Default=9090 +--port-end - The end search port of the RPC, Default=9199 +--tracker - The RPC tracker address in host:port format e.g. 10.1.1.2:9190 Default="" +--key - The key used to identify the device type in tracker. Default="" +--custom-addr - Custom IP Address to Report to RPC Tracker. Default="" +--silent - Whether to run in silent mode. Default=True +--proxy - Whether to run in proxy mode. Default=False + Example + ./tvm_rpc server --host=0.0.0.0 --port=9000 --port-end=9090 --tracker=127.0.0.1:9190 --key=rasp +``` + +## Note +Currently support is only there for Linux / Android environment. \ No newline at end of file diff --git a/apps/cpp_rpc/main.cc b/apps/cpp_rpc/main.cc new file mode 100644 index 000000000000..1c8b01a2f69e --- /dev/null +++ b/apps/cpp_rpc/main.cc @@ -0,0 +1,270 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file rpc_server.cc + * \brief RPC Server for TVM. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../../src/common/util.h" +#include "../../src/common/socket.h" +#include "rpc_server.h" + +using namespace std; +using namespace tvm::runtime; +using namespace tvm::common; + +static const string kUSAGE = \ +"Command line usage\n" \ +" server - Start the server\n" \ +"--host - The hostname of the server, Default=0.0.0.0\n" \ +"--port - The port of the RPC, Default=9090\n" \ +"--port-end - The end search port of the RPC, Default=9199\n" \ +"--tracker - The RPC tracker address in host:port format e.g. 10.1.1.2:9190 Default=\"\"\n" \ +"--key - The key used to identify the device type in tracker. Default=\"\"\n" \ +"--custom-addr - Custom IP Address to Report to RPC Tracker. Default=\"\"\n" \ +"--silent - Whether to run in silent mode. Default=True\n" \ +"--proxy - Whether to run in proxy mode. Default=False\n" \ +"\n" \ +" Example\n" \ +" ./tvm_rpc server --host=0.0.0.0 --port=9000 --port-end=9090 " +" --tracker=127.0.0.1:9190 --key=rasp" \ +"\n"; + +/*! + * \brief RpcServerArgs. + * \arg host The hostname of the server, Default=0.0.0.0 + * \arg port The port of the RPC, Default=9090 + * \arg port_end The end search port of the RPC, Default=9199 + * \arg tracker The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default="" + * \arg key The key used to identify the device type in tracker. Default="" + * \arg custom_addr Custom IP Address to Report to RPC Tracker. Default="" + * \arg silent Whether run in silent mode. Default=True + * \arg isProxy Whether to run in proxy mode. Default=False + */ +struct RpcServerArgs { + string host = "0.0.0.0"; + int port = 9090; + int port_end = 9099; + string tracker; + string key; + string custom_addr; + bool silent = false; + bool isProxy = false; +}; + +/*! + * \brief PrintArgs print the contents of RpcServerArgs + * \param args RpcServerArgs structure + */ +void PrintArgs(struct RpcServerArgs args) { + LOG(INFO) << "host = " << args.host; + LOG(INFO) << "port = " << args.port; + LOG(INFO) << "port_end = " << args.port_end; + LOG(INFO) << "tracker = " << args.tracker; + LOG(INFO) << "key = " << args.key; + LOG(INFO) << "custom_addr = " << args.custom_addr; + LOG(INFO) << "silent = " << ((args.silent) ? ("True"): ("False")); + LOG(INFO) << "proxy = " << ((args.isProxy) ? ("True"): ("False")); +} + +/*! + * \brief CtrlCHandler, exits if Ctrl+C is pressed + * \param s signal + */ +void CtrlCHandler(int s) { + LOG(INFO) << "\nUser pressed Ctrl+C, Exiting"; + exit(1); +} + +/*! + * \brief HandleCtrlC Register for handling Ctrl+C event. + */ +void HandleCtrlC() { + // Ctrl+C handler + struct sigaction sigIntHandler; + sigIntHandler.sa_handler = CtrlCHandler; + sigemptyset(&sigIntHandler.sa_mask); + sigIntHandler.sa_flags = 0; + sigaction(SIGINT, &sigIntHandler, NULL); +} + +/*! + * \brief GetCmdOption Parse and find the command option. + * \param argc arg counter + * \param argv arg values + * \param option command line option to search for. + * \param key whether the option itself is key + * \return value corresponding to option. + */ +string GetCmdOption(int argc, char* argv[], string option, bool key = false) { + string cmd; + for (int i = 1; i < argc; ++i) { + string arg = argv[i]; + if (arg.find(option) == 0) { + if (key) { + cmd = argv[i]; + return cmd; + } + // We assume "=" is the end of option. + CHECK_EQ(*option.rbegin(), '='); + cmd = arg.substr(arg.find("=") + 1); + return cmd; + } + } + return cmd; +} + +/*! + * \brief ValidateTracker Check the tracker address format is correct and changes the format. + * \param tracker The tracker input. + * \return result of operation. + */ +bool ValidateTracker(string &tracker) { + vector list = Split(tracker, ':'); + if ((list.size() != 2) || (!ValidateIP(list[0])) || (!IsNumber(list[1]))) { + return false; + } + ostringstream ss; + ss << "('" << list[0] << "', " << list[1] << ")"; + tracker = ss.str(); + return true; +} + +/*! + * \brief ParseCmdArgs parses the command line arguments. + * \param argc arg counter + * \param argv arg values + * \param args, the output structure which holds the parsed values + */ +void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) { + string silent = GetCmdOption(argc, argv, "--silent", true); + if (!silent.empty()) { + args.silent = true; + // Only errors and fatal is logged + dmlc::InitLogging("--minloglevel=2"); + } + + string host = GetCmdOption(argc, argv, "--host="); + if (!host.empty()) { + if (!ValidateIP(host)) { + LOG(WARNING) << "Wrong host address format."; + LOG(INFO) << kUSAGE; + exit(1); + } + args.host = host; + } + + string port = GetCmdOption(argc, argv, "--port="); + if (!port.empty()) { + if (!IsNumber(port) || stoi(port) > 65535) { + LOG(WARNING) << "Wrong port number."; + LOG(INFO) << kUSAGE; + exit(1); + } + args.port = stoi(port); + } + + string port_end = GetCmdOption(argc, argv, "--port_end="); + if (!port_end.empty()) { + if (!IsNumber(port_end) || stoi(port_end) > 65535) { + LOG(WARNING) << "Wrong port_end number."; + LOG(INFO) << kUSAGE; + exit(1); + } + args.port_end = stoi(port_end); + } + + string tracker = GetCmdOption(argc, argv, "--tracker="); + if (!tracker.empty()) { + if (!ValidateTracker(tracker)) { + LOG(WARNING) << "Wrong tracker address format."; + LOG(INFO) << kUSAGE; + exit(1); + } + args.tracker = tracker; + } + + string key = GetCmdOption(argc, argv, "--key="); + if (!key.empty()) { + args.key = key; + } + + string custom_addr = GetCmdOption(argc, argv, "--custom_addr="); + if (!custom_addr.empty()) { + if (!ValidateIP(custom_addr)) { + LOG(WARNING) << "Wrong custom address format."; + LOG(INFO) << kUSAGE; + exit(1); + } + args.custom_addr = custom_addr; + } +} + +/*! + * \brief RpcServer Starts the RPC server. + * \param argc arg counter + * \param argv arg values + * \return result of operation. + */ +int RpcServer(int argc, char * argv[]) { + struct RpcServerArgs args; + + /* parse the command line args */ + ParseCmdArgs(argc, argv, args); + PrintArgs(args); + + // Ctrl+C handler + LOG(INFO) << "Starting CPP Server, Press Ctrl+C to stop."; + HandleCtrlC(); + tvm::runtime::RPCServerCreate(args.host, args.port, args.port_end, args.tracker, + args.key, args.custom_addr, args.silent); + return 0; +} + +/*! + * \brief main The main function. + * \param argc arg counter + * \param argv arg values + * \return result of operation. + */ +int main(int argc, char * argv[]) { + if (argc <= 1) { + LOG(INFO) << kUSAGE; + return 0; + } + + if (0 == strcmp(argv[1], "server")) { + RpcServer(argc, argv); + } else { + LOG(INFO) << kUSAGE; + } + + return 0; +} diff --git a/apps/cpp_rpc/rpc_env.cc b/apps/cpp_rpc/rpc_env.cc new file mode 100644 index 000000000000..b5c2a8c186e2 --- /dev/null +++ b/apps/cpp_rpc/rpc_env.cc @@ -0,0 +1,248 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file rpc_env.cc + * \brief Server environment of the RPC. + */ +#include +#include +#ifndef _MSC_VER +#include +#include +#include +#else +#include +#endif +#include +#include +#include +#include +#include + +#include "rpc_env.h" +#include "../../src/common/util.h" +#include "../../src/runtime/file_util.h" + +namespace tvm { +namespace runtime { + +RPCEnv::RPCEnv() { + #if defined(__linux__) || defined(__ANDROID__) + base_ = "./rpc"; + mkdir(&base_[0], 0777); + + TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath") + .set_body([](TVMArgs args, TVMRetValue* rv) { + static RPCEnv env; + *rv = env.GetPath(args[0]); + }); + + TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module") + .set_body([](TVMArgs args, TVMRetValue *rv) { + static RPCEnv env; + std::string file_name = env.GetPath(args[0]); + *rv = Load(&file_name, ""); + LOG(INFO) << "Load module from " << file_name << " ..."; + }); + #else + LOG(FATAL) << "Only support RPC in linux environment"; + #endif +} +/*! + * \brief GetPath To get the workpath from packed function + * \param name The file name + * \return The full path of file. + */ +std::string RPCEnv::GetPath(const std::string& file_name) { + // we assume file_name has "/" means file_name is the exact path + // and does not create /.rpc/ + if (file_name.find("/") != std::string::npos) { + return file_name; + } else { + return base_ + "/" + file_name; + } +} +/*! + * \brief Remove The RPC Environment cleanup function + */ +void RPCEnv::Remove() { + #if defined(__linux__) || defined(__ANDROID__) + CleanDir(&base_[0]); + int ret = rmdir(&base_[0]); + if (ret != 0) { + LOG(WARNING) << "Remove directory " << base_ << " failed"; + } + #else + LOG(FATAL) << "Only support RPC in linux environment"; + #endif +} + +/*! + * \brief ListDir get the list of files in a directory + * \param dirname The root directory name + * \return vector Files in directory. + */ +std::vector ListDir(const std::string &dirname) { + std::vector vec; + #ifndef _MSC_VER + DIR *dp = opendir(dirname.c_str()); + if (dp == NULL) { + int errsv = errno; + LOG(FATAL) << "ListDir " << dirname <<" error: " << strerror(errsv); + } + dirent *d; + while ((d = readdir(dp)) != NULL) { + std::string filename = d->d_name; + if (filename != "." && filename != "..") { + std::string f = dirname; + if (f[f.length() - 1] != '/') { + f += '/'; + } + f += d->d_name; + vec.push_back(f); + } + } + closedir(dp); + #else + WIN32_FIND_DATA fd; + std::string pattern = dirname + "/*"; + HANDLE handle = FindFirstFile(pattern.c_str(), &fd); + if (handle == INVALID_HANDLE_VALUE) { + int errsv = GetLastError(); + LOG(FATAL) << "ListDir " << dirname << " error: " << strerror(errsv); + } + do { + if (fd.cFileName != "." && fd.cFileName != "..") { + std::string f = dirname; + char clast = f[f.length() - 1]; + if (f == ".") { + f = fd.cFileName; + } else if (clast != '/' && clast != '\\') { + f += '/'; + f += fd.cFileName; + } + vec.push_back(f); + } + } while (FindNextFile(handle, &fd)); + FindClose(handle); + #endif + return vec; +} + +/*! + * \brief LinuxShared Creates a linux shared library + * \param output The output file name + * \param files The files for building + * \param options The compiler options + * \param cc The compiler + */ +void LinuxShared(const std::string output, + const std::vector &files, + std::string options = "", + std::string cc = "g++") { + std::string cmd = cc; + cmd += " -shared -fPIC "; + cmd += " -o " + output; + for (auto f = files.begin(); f != files.end(); ++f) { + cmd += " " + *f; + } + cmd += " " + options; + CHECK(system(cmd.c_str()) == 0) << "Compilation error."; +} + +/*! + * \brief CreateShared Creates a shared library + * \param output The output file name + * \param files The files for building + */ +void CreateShared(const std::string output, const std::vector &files) { + #if defined(__linux__) || defined(__ANDROID__) + LinuxShared(output, files); + #else + LOG(FATAL) << "Do not support creating shared library"; + #endif +} + +/*! + * \brief Load Load module from file + This function will automatically call + cc.create_shared if the path is in format .o or .tar + High level handling for .o and .tar file. + We support this to be consistent with RPC module load. + * \param fileIn The input file, file name will be updated + * \param fmt The format of file + * \return Module The loaded module + */ +Module Load(std::string *fileIn, const std::string fmt) { + std::string file = *fileIn; + if (common::EndsWith(file, ".so")) { + return Module::LoadFromFile(file, fmt); + } + + #if defined(__linux__) || defined(__ANDROID__) + std::string file_name = file + ".so"; + if (common::EndsWith(file, ".o")) { + std::vector files; + files.push_back(file); + CreateShared(file_name, files); + } else if (common::EndsWith(file, ".tar")) { + std::string tmp_dir = "./rpc/tmp/"; + mkdir(&tmp_dir[0], 0777); + std::string cmd = "tar -C " + tmp_dir + " -zxf " + file; + CHECK(system(cmd.c_str()) == 0) << "Untar library error."; + CreateShared(file_name, ListDir(tmp_dir)); + CleanDir(tmp_dir); + rmdir(&tmp_dir[0]); + } else { + file_name = file; + } + *fileIn = file_name; + return Module::LoadFromFile(file_name, fmt); + #else + LOG(FATAL) << "Do not support creating shared library"; + #endif +} + +/*! + * \brief CleanDir Removes the files from the directory + * \param dirname The name of the directory + */ +void CleanDir(const std::string &dirname) { + #if defined(__linux__) || defined(__ANDROID__) + DIR *dp = opendir(dirname.c_str()); + dirent *d; + while ((d = readdir(dp)) != NULL) { + std::string filename = d->d_name; + if (filename != "." && filename != "..") { + filename = dirname + "/" + d->d_name; + int ret = std::remove(&filename[0]); + if (ret != 0) { + LOG(WARNING) << "Remove file " << filename << " failed"; + } + } + } + #else + LOG(FATAL) << "Only support RPC in linux environment"; + #endif +} + +} // namespace runtime +} // namespace tvm diff --git a/apps/cpp_rpc/rpc_env.h b/apps/cpp_rpc/rpc_env.h new file mode 100644 index 000000000000..becdf5daf9dd --- /dev/null +++ b/apps/cpp_rpc/rpc_env.h @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file rpc_env.h + * \brief Server environment of the RPC. + */ +#ifndef TVM_APPS_CPP_RPC_ENV_H_ +#define TVM_APPS_CPP_RPC_ENV_H_ + +#include +#if defined(__linux__) || defined(__ANDROID__) +#include +#include +#endif +#include + +namespace tvm { +namespace runtime { + +/*! + * \brief Load Load module from file + This function will automatically call + cc.create_shared if the path is in format .o or .tar + High level handling for .o and .tar file. + We support this to be consistent with RPC module load. + * \param file The input file + * \param file The format of file + * \return Module The loaded module + */ +Module Load(std::string *path, const std::string fmt = ""); + +/*! + * \brief CleanDir Removes the files from the directory + * \param dirname THe name of the directory + */ +void CleanDir(const std::string &dirname); + +/*! + * \brief RPCEnv The RPC Environment parameters for c++ rpc server + */ +struct RPCEnv { + public: + /*! + * \brief Constructor Init The RPC Environment initialize function + */ + RPCEnv(); + /*! + * \brief GetPath To get the workpath from packed function + * \param name The file name + * \return The full path of file. + */ + std::string GetPath(const std::string& file_name); + /*! + * \brief Remove The RPC Environment cleanup function + */ + void Remove(); + + /*! + * \base_ Holds the environment path. + */ + std::string base_; +}; // RPCEnv + +} // namespace runtime +} // namespace tvm +#endif // TVM_APPS_CPP_RPC_ENV_H_ diff --git a/apps/cpp_rpc/rpc_server.cc b/apps/cpp_rpc/rpc_server.cc new file mode 100644 index 000000000000..6eb244ca26c7 --- /dev/null +++ b/apps/cpp_rpc/rpc_server.cc @@ -0,0 +1,363 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file rpc_server.cc + * \brief RPC Server implementation. + */ + +#include + +#if defined(__linux__) || defined(__ANDROID__) +#include +#include +#endif +#include +#include +#include +#include +#include +#include + +#include "rpc_server.h" +#include "rpc_env.h" +#include "rpc_tracker_client.h" +#include "../../src/runtime/rpc/rpc_session.h" +#include "../../src/runtime/rpc/rpc_socket_impl.h" +#include "../../src/common/socket.h" + +#if defined(__linux__) || defined(__ANDROID__) +static pid_t waitpid_eintr(int *status) { + pid_t pid = 0; + while ((pid = waitpid(-1, status, 0)) == -1) { + if (errno == EINTR) { + continue; + } else { + perror("waitpid"); + abort(); + } + } + return pid; +} +#endif + +namespace tvm { +namespace runtime { + +/*! + * \brief RPCServer RPC Server class. + * \param host The hostname of the server, Default=0.0.0.0 + * \param port The port of the RPC, Default=9090 + * \param port_end The end search port of the RPC, Default=9199 + * \param tracker The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default="" + * \param key The key used to identify the device type in tracker. Default="" + * \param custom_addr Custom IP Address to Report to RPC Tracker. Default="" + * \param isProxy Whether to run in proxy mode. Default=False + */ +class RPCServer { + public: + /*! + * \brief Constructor. + */ + RPCServer(const std::string &host, + int port, + int port_end, + const std::string &tracker_addr, + const std::string &key, + const std::string &custom_addr, + bool is_proxy) { + // Init the values + host_ = host; + port_ = port; + port_end_ = port_end; + tracker_addr_ = tracker_addr; + key_ = key; + custom_addr_ = custom_addr; + is_proxy_ = is_proxy; + } + + /*! + * \brief Destructor. + */ + ~RPCServer() { + // Free the resources + tracker_sock_.Close(); + listen_sock_.Close(); + } + + /*! + * \brief Start Creates the RPC listen process and execution. + */ + void Start() { + listen_sock_.Create(); + my_port_ = listen_sock_.TryBindHost(host_, port_, port_end_); + LOG(INFO) << "bind to " << host_ << ":" << my_port_; + listen_sock_.Listen(1); + std::future proc(std::async(std::launch::async, &RPCServer::ListenLoopProc, this)); + proc.get(); + // Close the listen socket + listen_sock_.Close(); + } + + private: + /*! + * \brief ListenLoopProc The listen process. + */ + void ListenLoopProc() { + TrackerClient tracker(tracker_addr_, key_, custom_addr_); + while (1) { + common::TCPSocket conn; + common::SockAddr addr("0.0.0.0", 0); + std::string opts; + try { + // step 1: setup tracker and report to tracker + tracker.TryConnect(); + // step 2: wait for in-coming connections + AcceptConnection(&tracker, &conn, &addr, &opts); + } + catch (const char* msg) { + LOG(WARNING) << "Socket exception: " << msg; + // close tracker resource + tracker.Close(); + continue; + } + catch (std::exception& e) { + // Other errors + LOG(WARNING) << "Exception standard: " << e.what(); + continue; + } + + int timeout = GetTimeOutFromOpts(opts); + #if defined(__linux__) || defined(__ANDROID__) + // step 3: serving + if (timeout) { + const pid_t timer_pid = fork(); + if (timer_pid == 0) { + // Timer process + sleep(timeout); + exit(0); + } + + const pid_t worker_pid = fork(); + if (worker_pid == 0) { + // Worker process + ServerLoopProc(conn, addr); + exit(0); + } + + int status = 0; + const pid_t finished_first = waitpid_eintr(&status); + if (finished_first == timer_pid) { + kill(worker_pid, SIGKILL); + } else if (finished_first == worker_pid) { + kill(timer_pid, SIGKILL); + } else { + LOG(INFO) << "Child pid=" << finished_first << " unexpected, but still continue."; + } + + int status_second = 0; + waitpid_eintr(&status_second); + + // Logging. + if (finished_first == timer_pid) { + LOG(INFO) << "Child pid=" << worker_pid << " killed (timeout = " << timeout + << "), Process status =" << status_second; + } else if (finished_first == worker_pid) { + LOG(INFO) << "Child pid=" << timer_pid << " killed, Process status =" << status_second; + } + } else { + auto pid = fork(); + if (pid == 0) { + ServerLoopProc(conn, addr); + exit(0); + } + // Wait for the result + int status = 0; + wait(&status); + LOG(INFO) << "Child pid=" << pid << " exited, Process status =" << status; + } + #else + // step 3: serving + std::future proc(std::async(std::launch::async, + &RPCServer::ServerLoopProc, this, conn, addr)); + // wait until server process finish or timeout + if (timeout) { + // Autoterminate after timeout + proc.wait_for(std::chrono::seconds(timeout)); + } else { + // Wait for the result + proc.get(); + } + #endif + // close from our side. + LOG(INFO) << "Socket Connection Closed"; + conn.Close(); + } + } + + + /*! + * \brief AcceptConnection Accepts the RPC Server connection. + * \param tracker Tracker details. + * \param conn New connection information. + * \param addr New connection address information. + * \param opts Parsed options for socket + * \param ping_period Timeout for select call waiting + */ + void AcceptConnection(TrackerClient *tracker, + common::TCPSocket *conn_sock, + common::SockAddr *addr, + std::string *opts, + int ping_period = 2) { + std::set old_keyset; + std::string matchkey; + + // Report resource to tracker and get key + tracker->ReportResourceAndGetKey(my_port_, &matchkey); + + while (1) { + tracker->WaitConnectionAndUpdateKey(listen_sock_, my_port_, ping_period, &matchkey); + common::TCPSocket conn = listen_sock_.Accept(addr); + + int code = kRPCMagic; + CHECK_EQ(conn.RecvAll(&code, sizeof(code)), sizeof(code)); + if (code != kRPCMagic) { + conn.Close(); + LOG(FATAL) << "Client connected is not TVM RPC server"; + continue; + } + + int keylen = 0; + CHECK_EQ(conn.RecvAll(&keylen, sizeof(keylen)), sizeof(keylen)); + + #define CLIENT_HEADER "client:" + #define SERVER_HEADER "server:" + + std::string expect_header = CLIENT_HEADER + matchkey; + std::string server_key = SERVER_HEADER + key_; + if (size_t(keylen) < expect_header.length()) { + conn.Close(); + LOG(FATAL) << "Wrong client header length"; + continue; + } + + std::string remote_key; + remote_key.resize(keylen); + CHECK_EQ(conn.RecvAll(&remote_key[0], keylen), keylen); + + std::stringstream ssin(remote_key); + std::string arg0; + ssin >> arg0; + if (arg0 != expect_header) { + code = kRPCMismatch; + CHECK_EQ(conn.SendAll(&code, sizeof(code)), sizeof(code)); + conn.Close(); + LOG(WARNING) << "Mismatch key from" << addr->AsString(); + continue; + } else { + code = kRPCSuccess; + CHECK_EQ(conn.SendAll(&code, sizeof(code)), sizeof(code)); + keylen = server_key.length(); + CHECK_EQ(conn.SendAll(&keylen, sizeof(keylen)), sizeof(keylen)); + CHECK_EQ(conn.SendAll(server_key.c_str(), keylen), keylen); + LOG(INFO) << "Connection success " << addr->AsString(); + ssin >> *opts; + *conn_sock = conn; + return; + } + } + } + + /*! + * \brief ServerLoopProc The Server loop process. + * \param sock The socket information + * \param addr The socket address information + */ + void ServerLoopProc(common::TCPSocket sock, common::SockAddr addr) { + // Server loop + auto env = RPCEnv(); + RPCServerLoop(sock.sockfd); + LOG(INFO) << "Finish serving " << addr.AsString(); + env.Remove(); + } + + /*! + * \brief GetTimeOutFromOpts Parse and get the timeout option. + * \param opts The option string + * \param timeout value after parsing. + */ + int GetTimeOutFromOpts(std::string opts) { + std::string cmd; + std::string option = "-timeout="; + + if (opts.find(option) == 0) { + cmd = opts.substr(opts.find_last_of(option) + 1); + CHECK(common::IsNumber(cmd)) << "Timeout is not valid"; + return std::stoi(cmd); + } + return 0; + } + + std::string host_; + int port_; + int my_port_; + int port_end_; + std::string tracker_addr_; + std::string key_; + std::string custom_addr_; + bool is_proxy_; + common::TCPSocket listen_sock_; + common::TCPSocket tracker_sock_; +}; + +/*! + * \brief RPCServerCreate Creates the RPC Server. + * \param host The hostname of the server, Default=0.0.0.0 + * \param port The port of the RPC, Default=9090 + * \param port_end The end search port of the RPC, Default=9199 + * \param tracker The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default="" + * \param key The key used to identify the device type in tracker. Default="" + * \param custom_addr Custom IP Address to Report to RPC Tracker. Default="" + * \param silent Whether run in silent mode. Default=True + * \param isProxy Whether to run in proxy mode. Default=False + */ +void RPCServerCreate(std::string host, + int port, + int port_end, + std::string tracker_addr, + std::string key, + std::string custom_addr, + bool silent, + bool is_proxy) { + if (silent) { + // Only errors and fatal is logged + dmlc::InitLogging("--minloglevel=2"); + } + // Start the rpc server + RPCServer rpc(host, port, port_end, tracker_addr, key, custom_addr, is_proxy); + rpc.Start(); +} + +TVM_REGISTER_GLOBAL("rpc._ServerCreate") +.set_body([](TVMArgs args, TVMRetValue* rv) { + RPCServerCreate(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7]); + }); +} // namespace runtime +} // namespace tvm diff --git a/apps/cpp_rpc/rpc_server.h b/apps/cpp_rpc/rpc_server.h new file mode 100644 index 000000000000..9d39fc6713b0 --- /dev/null +++ b/apps/cpp_rpc/rpc_server.h @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file rpc_server.h + * \brief RPC Server implementation. + */ +#ifndef TVM_APPS_CPP_RPC_SERVER_H_ +#define TVM_APPS_CPP_RPC_SERVER_H_ + +#include +#include "tvm/runtime/c_runtime_api.h" + +namespace tvm { +namespace runtime { + +/*! + * \brief RPCServerCreate Creates the RPC Server. + * \param host The hostname of the server, Default=0.0.0.0 + * \param port The port of the RPC, Default=9090 + * \param port_end The end search port of the RPC, Default=9199 + * \param tracker The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default="" + * \param key The key used to identify the device type in tracker. Default="" + * \param custom_addr Custom IP Address to Report to RPC Tracker. Default="" + * \param silent Whether run in silent mode. Default=True + * \param isProxy Whether to run in proxy mode. Default=False + */ +TVM_DLL void RPCServerCreate(std::string host = "", + int port = 9090, + int port_end = 9099, + std::string tracker_addr = "", + std::string key = "", + std::string custom_addr = "", + bool silent = true, + bool is_proxy = false); +} // namespace runtime +} // namespace tvm +#endif // TVM_APPS_CPP_RPC_SERVER_H_ diff --git a/apps/cpp_rpc/rpc_tracker_client.h b/apps/cpp_rpc/rpc_tracker_client.h new file mode 100644 index 000000000000..f4db35a7030c --- /dev/null +++ b/apps/cpp_rpc/rpc_tracker_client.h @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file rpc_tracker_client.h + * \brief RPC Tracker client to report resources. + */ +#ifndef TVM_APPS_CPP_RPC_TRACKER_CLIENT_H_ +#define TVM_APPS_CPP_RPC_TRACKER_CLIENT_H_ + +#include +#include +#include +#include +#include +#include + +#include "../../src/runtime/rpc/rpc_session.h" +#include "../../src/common/socket.h" + +namespace tvm { +namespace runtime { + +/*! + * \brief TrackerClient Tracker client class. + * \param tracker The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default="" + * \param key The key used to identify the device type in tracker. Default="" + * \param custom_addr Custom IP Address to Report to RPC Tracker. Default="" + */ +class TrackerClient { + public: + /*! + * \brief Constructor. + */ + TrackerClient(const std::string &tracker_addr, + const std::string &key, + const std::string &custom_addr) { + tracker_addr_ = tracker_addr; + key_ = key; + custom_addr_ = custom_addr; + } + /*! + * \brief Destructor. + */ + ~TrackerClient() { + // Free the resources + Close(); + } + /*! + * \brief IsValid Check tracker is valid. + */ + bool IsValid() { + return (!tracker_addr_.empty() && !tracker_sock_.IsClosed()); + } + /*! + * \brief TryConnect Connect to tracker if the tracker address is valid. + */ + void TryConnect() { + if (!tracker_addr_.empty() && (tracker_sock_.IsClosed())) { + tracker_sock_ = ConnectWithRetry(); + + int code = kRPCTrackerMagic; + CHECK_EQ(tracker_sock_.SendAll(&code, sizeof(code)), sizeof(code)); + CHECK_EQ(tracker_sock_.RecvAll(&code, sizeof(code)), sizeof(code)); + CHECK_EQ(code, kRPCTrackerMagic) << tracker_addr_.c_str() << " is not RPC Tracker"; + + std::ostringstream ss; + ss << "[" << static_cast(TrackerCode::kUpdateInfo) + << ", {\"key\": \"server:"<< key_ << "\"}]"; + tracker_sock_.SendBytes(ss.str()); + + // Receive status and validate + std::string remote_status = tracker_sock_.RecvBytes(); + CHECK_EQ(std::stoi(remote_status), static_cast(TrackerCode::kSuccess)); + } + } + /*! + * \brief Close Clean up tracker resources. + */ + void Close() { + // close tracker resource + if (!tracker_sock_.IsClosed()) { + tracker_sock_.Close(); + } + } + /*! + * \brief ReportResourceAndGetKey Report resource to tracker. + * \param port listening port. + * \param matchkey Random match key output. + */ + void ReportResourceAndGetKey(int port, + std::string *matchkey) { + if (!tracker_sock_.IsClosed()) { + *matchkey = RandomKey(key_ + ":", old_keyset_); + if (custom_addr_.empty()) { + custom_addr_ = "null"; + } + + std::ostringstream ss; + ss << "[" << static_cast(TrackerCode::kPut) << ", \"" << key_ << "\", [" + << port << ", \"" << *matchkey << "\"], " << custom_addr_ << "]"; + + tracker_sock_.SendBytes(ss.str()); + + // Receive status and validate + std::string remote_status = tracker_sock_.RecvBytes(); + CHECK_EQ(std::stoi(remote_status), static_cast(TrackerCode::kSuccess)); + } else { + *matchkey = key_; + } + } + + /*! + * \brief ReportResourceAndGetKey Report resource to tracker. + * \param listen_sock Listen socket details for select. + * \param port listening port. + * \param ping_period Select wait time. + * \param matchkey Random match key output. + */ + void WaitConnectionAndUpdateKey(common::TCPSocket listen_sock, + int port, + int ping_period, + std::string *matchkey) { + int unmatch_period_count = 0; + int unmatch_timeout = 4; + while (1) { + if (!tracker_sock_.IsClosed()) { + common::SelectHelper selecter; + selecter.WatchRead(listen_sock.sockfd); + + int ready = selecter.Select(ping_period * 1000); + if ((ready <= 0) || (!selecter.CheckRead(listen_sock.sockfd))) { + std::ostringstream ss; + ss << "[" << int(TrackerCode::kGetPendingMatchKeys) << "]"; + tracker_sock_.SendBytes(ss.str()); + + // Receive status and validate + std::string pending_keys = tracker_sock_.RecvBytes(); + old_keyset_.insert(*matchkey); + + // if match key not in pending key set + // it means the key is acquired by a client but not used. + if (pending_keys.find(*matchkey) == std::string::npos) { + unmatch_period_count += 1; + } else { + unmatch_period_count = 0; + } + // regenerate match key if key is acquired but not used for a while + if (unmatch_period_count * ping_period > unmatch_timeout + ping_period) { + LOG(INFO) << "no incoming connections, regenerate key ..."; + + *matchkey = RandomKey(key_ + ":", old_keyset_); + + std::ostringstream ss; + ss << "[" << static_cast(TrackerCode::kPut) << ", \"" << key_ << "\", [" + << port << ", \"" << *matchkey << "\"], " << custom_addr_ << "]"; + tracker_sock_.SendBytes(ss.str()); + + std::string remote_status = tracker_sock_.RecvBytes(); + CHECK_EQ(std::stoi(remote_status), static_cast(TrackerCode::kSuccess)); + unmatch_period_count = 0; + } + continue; + } + } + break; + } + } + + private: + /*! + * \brief Connect to a RPC address with retry. + This function is only reliable to short period of server restart. + * \param timeout Timeout during retry + * \param retry_period Number of seconds before we retry again. + * \return TCPSocket The socket information if connect is success. + */ + common::TCPSocket ConnectWithRetry(int timeout = 60, int retry_period = 5) { + auto tbegin = std::chrono::system_clock::now(); + while (1) { + common::SockAddr addr(tracker_addr_); + common::TCPSocket sock; + sock.Create(); + LOG(INFO) << "Tracker connecting to " << addr.AsString(); + if (sock.Connect(addr)) { + return sock; + } + + auto period = (std::chrono::duration_cast( + std::chrono::system_clock::now() - tbegin)).count(); + CHECK(period < timeout) << "Failed to connect to server" << addr.AsString(); + LOG(WARNING) << "Cannot connect to tracker " << addr.AsString() + << " retry in " << retry_period << " seconds."; + std::this_thread::sleep_for(std::chrono::seconds(retry_period)); + } + } + /*! + * \brief Random Generate a random number between 0 and 1. + * \return random float value. + */ + float Random() { + std::random_device rd; // Will be used to obtain a seed for the random number engine + std::mt19937 gen(rd()); // Standard mersenne_twister_engine seeded with rd() + std::uniform_real_distribution<> dis(0.0, 1.0); + return dis(gen); + } + /*! + * \brief Generate a random key. + * \param prefix The string prefix. + * \return cmap The conflict map set. + */ + std::string RandomKey(const std::string& prefix, const std::set &cmap) { + if (!cmap.empty()) { + while (1) { + std::string key = prefix + std::to_string(Random()); + if (cmap.find(key) == cmap.end()) { + return key; + } + } + } + return prefix + std::to_string(Random()); + } + + std::string tracker_addr_; + std::string key_; + std::string custom_addr_; + common::TCPSocket tracker_sock_; + std::set old_keyset_; +}; +} // namespace runtime +} // namespace tvm +#endif // TVM_APPS_CPP_RPC_TRACKER_CLIENT_H_ diff --git a/src/common/socket.h b/src/common/socket.h index 39bcff863c10..7d3ae4284a9c 100644 --- a/src/common/socket.h +++ b/src/common/socket.h @@ -43,11 +43,14 @@ using ssize_t = int; #include #include #include +#include #include #endif #include #include #include +#include +#include "../common/util.h" namespace tvm { @@ -62,6 +65,25 @@ inline std::string GetHostName() { return std::string(buf.c_str()); } +/*! + * \brief ValidateIP validates an ip address. + * \param ip The ip address in string format localhost or x.x.x.x format + * \return result of operation. + */ +inline bool ValidateIP(std::string ip) { + if (ip == "localhost") { + return true; + } + std::vector list = Split(ip, '.'); + if (list.size() != 4) + return false; + for (std::string str : list) { + if (!IsNumber(str) || std::stoi(str) > 255 || std::stoi(str) < 0) + return false; + } + return true; +} + /*! * \brief Common data structure for network address. */ @@ -76,6 +98,23 @@ struct SockAddr { SockAddr(const char *url, int port) { this->Set(url, port); } + + /*! + * \brief SockAddr Get the socket address from tracker. + * \param tracker The url containing the ip and port number. Format is ('192.169.1.100', 9090) + * \return SockAddr parsed from url. + */ + explicit SockAddr(const std::string &url) { + size_t sep = url.find(","); + std::string host = url.substr(2, sep - 3); + std::string port = url.substr(sep + 1, url.length() - 1); + CHECK(ValidateIP(host)) << "Url address is not valid " << url; + if (host == "localhost") { + host = "127.0.0.1"; + } + this->Set(host.c_str(), std::stoi(port)); + } + /*! * \brief set the address * \param host the url of the address @@ -203,17 +242,20 @@ class Socket { } /*! * \brief try bind the socket to host, from start_port to end_port + * \param host host_address to bind the socket * \param start_port starting port number to try * \param end_port ending port number to try * \return the port successfully bind to, return -1 if failed to bind any port */ - inline int TryBindHost(int start_port, int end_port) { + inline int TryBindHost(std::string host, int start_port, int end_port) { for (int port = start_port; port < end_port; ++port) { - SockAddr addr("0.0.0.0", port); + SockAddr addr(host.c_str(), port); if (bind(sockfd, reinterpret_cast(&addr.addr), (addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) : sizeof(sockaddr_in))) == 0) { return port; + } else { + LOG(WARNING) << "Bind failed to " << host << ":" << port; } #if defined(_WIN32) if (WSAGetLastError() != WSAEADDRINUSE) { @@ -373,6 +415,20 @@ class TCPSocket : public Socket { } return TCPSocket(newfd); } + /*! + * \brief get a new connection + * \param addr client address from which connection accepted + * \return The accepted socket connection. + */ + TCPSocket Accept(SockAddr *addr) { + socklen_t addrlen = sizeof(addr->addr); + SockType newfd = accept(sockfd, reinterpret_cast(&addr->addr), + &addrlen); + if (newfd == INVALID_SOCKET) { + Socket::Error("Accept"); + } + return TCPSocket(newfd); + } /*! * \brief decide whether the socket is at OOB mark * \return 1 if at mark, 0 if not, -1 if an error occurred @@ -468,7 +524,134 @@ class TCPSocket : public Socket { } return ndone; } + /*! + * \brief Send the data to remote. + * \param data The data to be sent. + */ + void SendBytes(std::string data) { + int datalen = data.length(); + CHECK_EQ(SendAll(&datalen, sizeof(datalen)), sizeof(datalen)); + CHECK_EQ(SendAll(data.c_str(), datalen), datalen); + } + /*! + * \brief Receive the data to remote. + * \return The data received. + */ + std::string RecvBytes() { + int datalen = 0; + CHECK_EQ(RecvAll(&datalen, sizeof(datalen)), sizeof(datalen)); + std::string data; + data.resize(datalen); + CHECK_EQ(RecvAll(&data[0], datalen), datalen); + return data; + } }; + +/*! \brief helper data structure to perform select */ +struct SelectHelper { + public: + SelectHelper(void) { + FD_ZERO(&read_set); + FD_ZERO(&write_set); + FD_ZERO(&except_set); + maxfd = 0; + } + /*! + * \brief add file descriptor to watch for read + * \param fd file descriptor to be watched + */ + inline void WatchRead(TCPSocket::SockType fd) { + FD_SET(fd, &read_set); + if (fd > maxfd) maxfd = fd; + } + /*! + * \brief add file descriptor to watch for write + * \param fd file descriptor to be watched + */ + inline void WatchWrite(TCPSocket::SockType fd) { + FD_SET(fd, &write_set); + if (fd > maxfd) maxfd = fd; + } + /*! + * \brief add file descriptor to watch for exception + * \param fd file descriptor to be watched + */ + inline void WatchException(TCPSocket::SockType fd) { + FD_SET(fd, &except_set); + if (fd > maxfd) maxfd = fd; + } + /*! + * \brief Check if the descriptor is ready for read + * \param fd file descriptor to check status + */ + inline bool CheckRead(TCPSocket::SockType fd) const { + return FD_ISSET(fd, &read_set) != 0; + } + /*! + * \brief Check if the descriptor is ready for write + * \param fd file descriptor to check status + */ + inline bool CheckWrite(TCPSocket::SockType fd) const { + return FD_ISSET(fd, &write_set) != 0; + } + /*! + * \brief Check if the descriptor has any exception + * \param fd file descriptor to check status + */ + inline bool CheckExcept(TCPSocket::SockType fd) const { + return FD_ISSET(fd, &except_set) != 0; + } + /*! + * \brief wait for exception event on a single descriptor + * \param fd the file descriptor to wait the event for + * \param timeout the timeout counter, can be 0, which means wait until the event happen + * \return 1 if success, 0 if timeout, and -1 if error occurs + */ + inline static int WaitExcept(TCPSocket::SockType fd, long timeout = 0) { // NOLINT(*) + fd_set wait_set; + FD_ZERO(&wait_set); + FD_SET(fd, &wait_set); + return Select_(static_cast(fd + 1), + NULL, NULL, &wait_set, timeout); + } + /*! + * \brief peform select on the set defined + * \param select_read whether to watch for read event + * \param select_write whether to watch for write event + * \param select_except whether to watch for exception event + * \param timeout specify timeout in micro-seconds(ms) if equals 0, means select will always block + * \return number of active descriptors selected, + * return -1 if error occurs + */ + inline int Select(long timeout = 0) { // NOLINT(*) + int ret = Select_(static_cast(maxfd + 1), + &read_set, &write_set, &except_set, timeout); + if (ret == -1) { + Socket::Error("Select"); + } + return ret; + } + + private: + inline static int Select_(int maxfd, fd_set *rfds, + fd_set *wfds, fd_set *efds, long timeout) { // NOLINT(*) +#if !defined(_WIN32) + CHECK(maxfd < FD_SETSIZE) << "maxfd must be smaller than FDSETSIZE"; +#endif + if (timeout == 0) { + return select(maxfd, rfds, wfds, efds, NULL); + } else { + timeval tm; + tm.tv_usec = (timeout % 1000) * 1000; + tm.tv_sec = timeout / 1000; + return select(maxfd, rfds, wfds, efds, &tm); + } + } + + TCPSocket::SockType maxfd; + fd_set read_set, write_set, except_set; +}; + } // namespace common } // namespace tvm #endif // TVM_COMMON_SOCKET_H_ diff --git a/src/common/util.h b/src/common/util.h new file mode 100644 index 000000000000..2430cde3472c --- /dev/null +++ b/src/common/util.h @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file util.h + * \brief Defines some common utility function.. + */ +#ifndef TVM_COMMON_UTIL_H_ +#define TVM_COMMON_UTIL_H_ + +#include +#include +#include +#include + +namespace tvm { +namespace common { + +/*! + * \brief IsNumber check whether string is a number. + * \param str input string + * \return result of operation. + */ +inline bool IsNumber(const std::string& str) { + return !str.empty() && std::find_if(str.begin(), + str.end(), [](char c) { return !std::isdigit(c); }) == str.end(); +} + +/*! + * \brief split Split the string based on delimiter + * \param str Input string + * \param delim The delimiter. + * \return vector of strings which are splitted. + */ +inline std::vector Split(const std::string& str, char delim) { + std::string item; + std::istringstream is(str); + std::vector ret; + while (std::getline(is, item, delim)) { + ret.push_back(item); + } + return ret; +} + +/*! + * \brief EndsWith check whether the strings ends with + * \param value The full string + * \param end The end substring + * \return bool The result. + */ +inline bool EndsWith(std::string const & value, std::string const & end) { + if (end.size() <= value.size()) { + return std::equal(end.rbegin(), end.rend(), value.rbegin()); + } + return false; +} + +} // namespace common +} // namespace tvm +#endif // TVM_COMMON_UTIL_H_ \ No newline at end of file diff --git a/src/runtime/rpc/rpc_session.h b/src/runtime/rpc/rpc_session.h index d982f68bcb6e..611a36bda8df 100644 --- a/src/runtime/rpc/rpc_session.h +++ b/src/runtime/rpc/rpc_session.h @@ -36,8 +36,29 @@ namespace tvm { namespace runtime { +// Magic header for RPC data plane const int kRPCMagic = 0xff271; +// magic header for RPC tracker(control plane) +const int kRPCTrackerMagic = 0x2f271; +// sucess response +const int kRPCSuccess = kRPCMagic + 0; +// duplicate key in proxy +const int kRPCDupicate = kRPCMagic + 1; +// cannot found matched key in server +const int kRPCMismatch = kRPCMagic + 2; +/*! \brief Enumeration code for the RPC tracker */ +enum class TrackerCode : int { + kFail = -1, + kSuccess = 0, + kPing = 1, + kStop = 2, + kPut = 3, + kRequest = 4, + kUpdateInfo = 5, + kSummary = 6, + kGetPendingMatchKeys = 7 +}; /*! \brief The remote functio handle */ using RPCFuncHandle = void*; diff --git a/src/runtime/rpc/rpc_socket_impl.h b/src/runtime/rpc/rpc_socket_impl.h new file mode 100644 index 000000000000..33abcef1af9e --- /dev/null +++ b/src/runtime/rpc/rpc_socket_impl.h @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file rpc_socket_impl.h + * \brief Socket based RPC implementation. + */ +#ifndef TVM_RUNTIME_RPC_RPC_SOCKET_IMPL_H_ +#define TVM_RUNTIME_RPC_RPC_SOCKET_IMPL_H_ + +namespace tvm { +namespace runtime { + +/*! + * \brief RPCServerLoop Start the rpc server loop. + * \param sockfd Socket file descriptor + */ +void RPCServerLoop(int sockfd); + +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_RPC_RPC_SOCKET_IMPL_H_ \ No newline at end of file From 24bbff2a4bf235af44eac989339bac7114993242 Mon Sep 17 00:00:00 2001 From: Wu Zhao Date: Fri, 8 Nov 2019 13:35:21 +0800 Subject: [PATCH 02/13] add ASF header --- apps/cpp_rpc/Makefile | 17 +++++++++++++++++ apps/cpp_rpc/README.md | 17 +++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/apps/cpp_rpc/Makefile b/apps/cpp_rpc/Makefile index 5b464f9123ca..450c2b9135d3 100644 --- a/apps/cpp_rpc/Makefile +++ b/apps/cpp_rpc/Makefile @@ -1,3 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + # Makefile to compile RPC Server. TVM_ROOT=$(shell cd ../..; pwd) DMLC_CORE=${TVM_ROOT}/3rdparty/dmlc-core diff --git a/apps/cpp_rpc/README.md b/apps/cpp_rpc/README.md index f7846408bab1..07bba414919f 100644 --- a/apps/cpp_rpc/README.md +++ b/apps/cpp_rpc/README.md @@ -1,3 +1,20 @@ + + + + + + + + + + + + + + + + + # TVM RPC Server This folder contains a simple recipe to make RPC server in c++. From cf03d76922996ddf8d21a1ef22a1cf61ae591977 Mon Sep 17 00:00:00 2001 From: Wu Zhao Date: Fri, 8 Nov 2019 13:41:03 +0800 Subject: [PATCH 03/13] CPPLint need whitespace at end of file --- src/common/util.h | 2 +- src/runtime/rpc/rpc_socket_impl.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/common/util.h b/src/common/util.h index 2430cde3472c..0d2c5eb73f3d 100644 --- a/src/common/util.h +++ b/src/common/util.h @@ -74,4 +74,4 @@ inline bool EndsWith(std::string const & value, std::string const & end) { } // namespace common } // namespace tvm -#endif // TVM_COMMON_UTIL_H_ \ No newline at end of file +#endif // TVM_COMMON_UTIL_H_ diff --git a/src/runtime/rpc/rpc_socket_impl.h b/src/runtime/rpc/rpc_socket_impl.h index 33abcef1af9e..ea7c8394bff8 100644 --- a/src/runtime/rpc/rpc_socket_impl.h +++ b/src/runtime/rpc/rpc_socket_impl.h @@ -36,4 +36,4 @@ void RPCServerLoop(int sockfd); } // namespace runtime } // namespace tvm -#endif // TVM_RUNTIME_RPC_RPC_SOCKET_IMPL_H_ \ No newline at end of file +#endif // TVM_RUNTIME_RPC_RPC_SOCKET_IMPL_H_ From c2e82279d49f09335602b232fedc7959a6b2d898 Mon Sep 17 00:00:00 2001 From: Wu Zhao Date: Sat, 9 Nov 2019 03:56:44 +0800 Subject: [PATCH 04/13] Correct as Tianqi's suggestion --- apps/cpp_rpc/main.cc | 1 - apps/cpp_rpc/rpc_env.cc | 5 +- apps/cpp_rpc/rpc_env.h | 18 ++--- apps/cpp_rpc/rpc_server.cc | 38 +++++----- apps/cpp_rpc/rpc_server.h | 1 - apps/cpp_rpc/rpc_tracker_client.h | 35 +++++---- src/common/socket.h | 115 +++++++++++++++--------------- 7 files changed, 101 insertions(+), 112 deletions(-) diff --git a/apps/cpp_rpc/main.cc b/apps/cpp_rpc/main.cc index 1c8b01a2f69e..689f9c8f8922 100644 --- a/apps/cpp_rpc/main.cc +++ b/apps/cpp_rpc/main.cc @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2019 by Contributors * \file rpc_server.cc * \brief RPC Server for TVM. */ diff --git a/apps/cpp_rpc/rpc_env.cc b/apps/cpp_rpc/rpc_env.cc index b5c2a8c186e2..f6d4a5842ae6 100644 --- a/apps/cpp_rpc/rpc_env.cc +++ b/apps/cpp_rpc/rpc_env.cc @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2019 by Contributors * \file rpc_env.cc * \brief Server environment of the RPC. */ @@ -71,7 +70,7 @@ RPCEnv::RPCEnv() { * \param name The file name * \return The full path of file. */ -std::string RPCEnv::GetPath(const std::string& file_name) { +std::string RPCEnv::GetPath(std::string file_name) { // we assume file_name has "/" means file_name is the exact path // and does not create /.rpc/ if (file_name.find("/") != std::string::npos) { @@ -83,7 +82,7 @@ std::string RPCEnv::GetPath(const std::string& file_name) { /*! * \brief Remove The RPC Environment cleanup function */ -void RPCEnv::Remove() { +void RPCEnv::CleanUp() { #if defined(__linux__) || defined(__ANDROID__) CleanDir(&base_[0]); int ret = rmdir(&base_[0]); diff --git a/apps/cpp_rpc/rpc_env.h b/apps/cpp_rpc/rpc_env.h index becdf5daf9dd..58df23792d21 100644 --- a/apps/cpp_rpc/rpc_env.h +++ b/apps/cpp_rpc/rpc_env.h @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2019 by Contributors * \file rpc_env.h * \brief Server environment of the RPC. */ @@ -26,10 +25,6 @@ #define TVM_APPS_CPP_RPC_ENV_H_ #include -#if defined(__linux__) || defined(__ANDROID__) -#include -#include -#endif #include namespace tvm { @@ -58,21 +53,22 @@ void CleanDir(const std::string &dirname); */ struct RPCEnv { public: - /*! - * \brief Constructor Init The RPC Environment initialize function - */ + /*! + * \brief Constructor Init The RPC Environment initialize function + */ RPCEnv(); /*! * \brief GetPath To get the workpath from packed function * \param name The file name * \return The full path of file. */ - std::string GetPath(const std::string& file_name); + std::string GetPath(std::string file_name); /*! - * \brief Remove The RPC Environment cleanup function + * \brief The RPC Environment cleanup function */ - void Remove(); + void CleanUp(); + private: /*! * \base_ Holds the environment path. */ diff --git a/apps/cpp_rpc/rpc_server.cc b/apps/cpp_rpc/rpc_server.cc index 6eb244ca26c7..0c2edfb6fffa 100644 --- a/apps/cpp_rpc/rpc_server.cc +++ b/apps/cpp_rpc/rpc_server.cc @@ -18,11 +18,9 @@ */ /*! - * Copyright (c) 2019 by Contributors * \file rpc_server.cc * \brief RPC Server implementation. */ - #include #if defined(__linux__) || defined(__ANDROID__) @@ -43,8 +41,11 @@ #include "../../src/runtime/rpc/rpc_socket_impl.h" #include "../../src/common/socket.h" +namespace tvm { +namespace runtime { + #if defined(__linux__) || defined(__ANDROID__) -static pid_t waitpid_eintr(int *status) { +static pid_t waitPidEintr(int *status) { pid_t pid = 0; while ((pid = waitpid(-1, status, 0)) == -1) { if (errno == EINTR) { @@ -58,9 +59,6 @@ static pid_t waitpid_eintr(int *status) { } #endif -namespace tvm { -namespace runtime { - /*! * \brief RPCServer RPC Server class. * \param host The hostname of the server, Default=0.0.0.0 @@ -69,7 +67,7 @@ namespace runtime { * \param tracker The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default="" * \param key The key used to identify the device type in tracker. Default="" * \param custom_addr Custom IP Address to Report to RPC Tracker. Default="" - * \param isProxy Whether to run in proxy mode. Default=False + * \param is_proxy Whether to run in proxy mode. Default=False */ class RPCServer { public: @@ -122,7 +120,7 @@ class RPCServer { */ void ListenLoopProc() { TrackerClient tracker(tracker_addr_, key_, custom_addr_); - while (1) { + while (true) { common::TCPSocket conn; common::SockAddr addr("0.0.0.0", 0); std::string opts; @@ -163,7 +161,7 @@ class RPCServer { } int status = 0; - const pid_t finished_first = waitpid_eintr(&status); + const pid_t finished_first = waitPidEintr(&status); if (finished_first == timer_pid) { kill(worker_pid, SIGKILL); } else if (finished_first == worker_pid) { @@ -173,7 +171,7 @@ class RPCServer { } int status_second = 0; - waitpid_eintr(&status_second); + waitPidEintr(&status_second); // Logging. if (finished_first == timer_pid) { @@ -221,10 +219,10 @@ class RPCServer { * \param opts Parsed options for socket * \param ping_period Timeout for select call waiting */ - void AcceptConnection(TrackerClient *tracker, - common::TCPSocket *conn_sock, - common::SockAddr *addr, - std::string *opts, + void AcceptConnection(TrackerClient* tracker, + common::TCPSocket* conn_sock, + common::SockAddr* addr, + std::string* opts, int ping_period = 2) { std::set old_keyset; std::string matchkey; @@ -232,7 +230,7 @@ class RPCServer { // Report resource to tracker and get key tracker->ReportResourceAndGetKey(my_port_, &matchkey); - while (1) { + while (true) { tracker->WaitConnectionAndUpdateKey(listen_sock_, my_port_, ping_period, &matchkey); common::TCPSocket conn = listen_sock_.Accept(addr); @@ -247,17 +245,17 @@ class RPCServer { int keylen = 0; CHECK_EQ(conn.RecvAll(&keylen, sizeof(keylen)), sizeof(keylen)); - #define CLIENT_HEADER "client:" - #define SERVER_HEADER "server:" - + const char* CLIENT_HEADER = "client:"; + const char* SERVER_HEADER = "server:"; std::string expect_header = CLIENT_HEADER + matchkey; std::string server_key = SERVER_HEADER + key_; if (size_t(keylen) < expect_header.length()) { conn.Close(); - LOG(FATAL) << "Wrong client header length"; + LOG(INFO) << "Wrong client header length"; continue; } + CHECK_NE(keylen, 0); std::string remote_key; remote_key.resize(keylen); CHECK_EQ(conn.RecvAll(&remote_key[0], keylen), keylen); @@ -295,7 +293,7 @@ class RPCServer { auto env = RPCEnv(); RPCServerLoop(sock.sockfd); LOG(INFO) << "Finish serving " << addr.AsString(); - env.Remove(); + env.CleanUp(); } /*! diff --git a/apps/cpp_rpc/rpc_server.h b/apps/cpp_rpc/rpc_server.h index 9d39fc6713b0..2c255353c29b 100644 --- a/apps/cpp_rpc/rpc_server.h +++ b/apps/cpp_rpc/rpc_server.h @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2019 by Contributors * \file rpc_server.h * \brief RPC Server implementation. */ diff --git a/apps/cpp_rpc/rpc_tracker_client.h b/apps/cpp_rpc/rpc_tracker_client.h index f4db35a7030c..89424c7511f0 100644 --- a/apps/cpp_rpc/rpc_tracker_client.h +++ b/apps/cpp_rpc/rpc_tracker_client.h @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2019 by Contributors * \file rpc_tracker_client.h * \brief RPC Tracker client to report resources. */ @@ -49,12 +48,11 @@ class TrackerClient { /*! * \brief Constructor. */ - TrackerClient(const std::string &tracker_addr, - const std::string &key, - const std::string &custom_addr) { - tracker_addr_ = tracker_addr; - key_ = key; - custom_addr_ = custom_addr; + TrackerClient(const std::string& tracker_addr, + const std::string& key, + const std::string& custom_addr) + : tracker_addr_(tracker_addr), key_(key), custom_addr_(custom_addr), + gen_(std::random_device{}()), dis_(0.0, 1.0) { } /*! * \brief Destructor. @@ -140,13 +138,12 @@ class TrackerClient { std::string *matchkey) { int unmatch_period_count = 0; int unmatch_timeout = 4; - while (1) { + while (true) { if (!tracker_sock_.IsClosed()) { - common::SelectHelper selecter; - selecter.WatchRead(listen_sock.sockfd); - - int ready = selecter.Select(ping_period * 1000); - if ((ready <= 0) || (!selecter.CheckRead(listen_sock.sockfd))) { + common::PollHelper poller; + poller.WatchRead(listen_sock.sockfd); + poller.Poll(ping_period * 1000); + if (!poller.CheckRead(listen_sock.sockfd)) { std::ostringstream ss; ss << "[" << int(TrackerCode::kGetPendingMatchKeys) << "]"; tracker_sock_.SendBytes(ss.str()); @@ -194,7 +191,7 @@ class TrackerClient { */ common::TCPSocket ConnectWithRetry(int timeout = 60, int retry_period = 5) { auto tbegin = std::chrono::system_clock::now(); - while (1) { + while (true) { common::SockAddr addr(tracker_addr_); common::TCPSocket sock; sock.Create(); @@ -216,10 +213,7 @@ class TrackerClient { * \return random float value. */ float Random() { - std::random_device rd; // Will be used to obtain a seed for the random number engine - std::mt19937 gen(rd()); // Standard mersenne_twister_engine seeded with rd() - std::uniform_real_distribution<> dis(0.0, 1.0); - return dis(gen); + return dis_(gen_); } /*! * \brief Generate a random key. @@ -228,7 +222,7 @@ class TrackerClient { */ std::string RandomKey(const std::string& prefix, const std::set &cmap) { if (!cmap.empty()) { - while (1) { + while (true) { std::string key = prefix + std::to_string(Random()); if (cmap.find(key) == cmap.end()) { return key; @@ -243,6 +237,9 @@ class TrackerClient { std::string custom_addr_; common::TCPSocket tracker_sock_; std::set old_keyset_; + std::mt19937 gen_; + std::uniform_real_distribution dis_; + }; } // namespace runtime } // namespace tvm diff --git a/src/common/socket.h b/src/common/socket.h index 7d3ae4284a9c..01d4ccfdd6f3 100644 --- a/src/common/socket.h +++ b/src/common/socket.h @@ -50,8 +50,17 @@ using ssize_t = int; #include #include #include +#include #include "../common/util.h" +#if defined(_WIN32) +typedef int sock_size_t; + +static inline int poll(struct pollfd *pfd, int nfds, + int timeout) { return WSAPoll ( pfd, nfds, timeout ); } +#else +#include +#endif // defined(_WIN32) namespace tvm { namespace common { @@ -75,9 +84,10 @@ inline bool ValidateIP(std::string ip) { return true; } std::vector list = Split(ip, '.'); - if (list.size() != 4) + if (list.size() != 4) { return false; - for (std::string str : list) { + } + for (const auto& str : list) { if (!IsNumber(str) || std::stoi(str) > 255 || std::stoi(str) < 0) return false; } @@ -547,109 +557,100 @@ class TCPSocket : public Socket { } }; -/*! \brief helper data structure to perform select */ -struct SelectHelper { +/*! \brief helper data structure to perform poll */ +struct PollHelper { public: - SelectHelper(void) { - FD_ZERO(&read_set); - FD_ZERO(&write_set); - FD_ZERO(&except_set); - maxfd = 0; - } /*! * \brief add file descriptor to watch for read * \param fd file descriptor to be watched */ inline void WatchRead(TCPSocket::SockType fd) { - FD_SET(fd, &read_set); - if (fd > maxfd) maxfd = fd; + auto& pfd = fds[fd]; + pfd.fd = fd; + pfd.events |= POLLIN; } /*! * \brief add file descriptor to watch for write * \param fd file descriptor to be watched */ inline void WatchWrite(TCPSocket::SockType fd) { - FD_SET(fd, &write_set); - if (fd > maxfd) maxfd = fd; + auto& pfd = fds[fd]; + pfd.fd = fd; + pfd.events |= POLLOUT; } /*! * \brief add file descriptor to watch for exception * \param fd file descriptor to be watched */ inline void WatchException(TCPSocket::SockType fd) { - FD_SET(fd, &except_set); - if (fd > maxfd) maxfd = fd; + auto& pfd = fds[fd]; + pfd.fd = fd; + pfd.events |= POLLPRI; } /*! * \brief Check if the descriptor is ready for read * \param fd file descriptor to check status */ inline bool CheckRead(TCPSocket::SockType fd) const { - return FD_ISSET(fd, &read_set) != 0; + const auto& pfd = fds.find(fd); + return pfd != fds.end() && ((pfd->second.events & POLLIN) != 0); } /*! * \brief Check if the descriptor is ready for write * \param fd file descriptor to check status */ inline bool CheckWrite(TCPSocket::SockType fd) const { - return FD_ISSET(fd, &write_set) != 0; + const auto& pfd = fds.find(fd); + return pfd != fds.end() && ((pfd->second.events & POLLOUT) != 0); } /*! * \brief Check if the descriptor has any exception * \param fd file descriptor to check status */ inline bool CheckExcept(TCPSocket::SockType fd) const { - return FD_ISSET(fd, &except_set) != 0; + const auto& pfd = fds.find(fd); + return pfd != fds.end() && ((pfd->second.events & POLLPRI) != 0); } /*! * \brief wait for exception event on a single descriptor * \param fd the file descriptor to wait the event for - * \param timeout the timeout counter, can be 0, which means wait until the event happen + * \param timeout the timeout counter, can be negative, which means wait until the event happen * \return 1 if success, 0 if timeout, and -1 if error occurs */ - inline static int WaitExcept(TCPSocket::SockType fd, long timeout = 0) { // NOLINT(*) - fd_set wait_set; - FD_ZERO(&wait_set); - FD_SET(fd, &wait_set); - return Select_(static_cast(fd + 1), - NULL, NULL, &wait_set, timeout); - } - /*! - * \brief peform select on the set defined - * \param select_read whether to watch for read event - * \param select_write whether to watch for write event - * \param select_except whether to watch for exception event - * \param timeout specify timeout in micro-seconds(ms) if equals 0, means select will always block - * \return number of active descriptors selected, - * return -1 if error occurs - */ - inline int Select(long timeout = 0) { // NOLINT(*) - int ret = Select_(static_cast(maxfd + 1), - &read_set, &write_set, &except_set, timeout); - if (ret == -1) { - Socket::Error("Select"); - } - return ret; + inline static int WaitExcept(TCPSocket::SockType fd, long timeout = -1) { // NOLINT(*) + pollfd pfd; + pfd.fd = fd; + pfd.events = POLLPRI; + return poll(&pfd, 1, timeout); } - private: - inline static int Select_(int maxfd, fd_set *rfds, - fd_set *wfds, fd_set *efds, long timeout) { // NOLINT(*) -#if !defined(_WIN32) - CHECK(maxfd < FD_SETSIZE) << "maxfd must be smaller than FDSETSIZE"; -#endif - if (timeout == 0) { - return select(maxfd, rfds, wfds, efds, NULL); + /*! + * \brief peform poll on the set defined, read, write, exception + * \param timeout specify timeout in milliseconds(ms) if negative, means poll will block + * \return + */ + inline void Poll(long timeout = -1) { // NOLINT(*) + std::vector fdset; + fdset.reserve(fds.size()); + for (auto kv : fds) { + fdset.push_back(kv.second); + } + int ret = poll(fdset.data(), fdset.size(), timeout); + if (ret == -1) { + Socket::Error("Poll"); } else { - timeval tm; - tm.tv_usec = (timeout % 1000) * 1000; - tm.tv_sec = timeout / 1000; - return select(maxfd, rfds, wfds, efds, &tm); + for (auto& pfd : fdset) { + auto revents = pfd.revents & pfd.events; + if (!revents) { + fds.erase(pfd.fd); + } else { + fds[pfd.fd].events = revents; + } + } } } - TCPSocket::SockType maxfd; - fd_set read_set, write_set, except_set; + std::unordered_map fds; }; } // namespace common From 1686822f8ac528fe90b5e22545b33f87bebf3fe8 Mon Sep 17 00:00:00 2001 From: Wu Zhao Date: Sat, 9 Nov 2019 13:31:49 +0800 Subject: [PATCH 05/13] modify validateip implementation --- src/common/socket.h | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/src/common/socket.h b/src/common/socket.h index 01d4ccfdd6f3..73bb3fcab6c7 100644 --- a/src/common/socket.h +++ b/src/common/socket.h @@ -54,10 +54,13 @@ using ssize_t = int; #include "../common/util.h" #if defined(_WIN32) -typedef int sock_size_t; - static inline int poll(struct pollfd *pfd, int nfds, - int timeout) { return WSAPoll ( pfd, nfds, timeout ); } + int timeout) { + return WSAPoll ( pfd, nfds, timeout ); +} +static inline int inet_pton(int family, const char* addr_str, void* addr_buf) { + return InetPton(family, addr_str, addr_buf); +} #else #include #endif // defined(_WIN32) @@ -83,15 +86,11 @@ inline bool ValidateIP(std::string ip) { if (ip == "localhost") { return true; } - std::vector list = Split(ip, '.'); - if (list.size() != 4) { - return false; - } - for (const auto& str : list) { - if (!IsNumber(str) || std::stoi(str) > 255 || std::stoi(str) < 0) - return false; - } - return true; + struct sockaddr_in sa_ipv4; + struct sockaddr_in6 sa_ipv6; + bool is_ipv4 = inet_pton(AF_INET, ip.c_str(), &(sa_ipv4.sin_addr)); + bool is_ipv6 = inet_pton(AF_INET6, ip.c_str(), &(sa_ipv6.sin6_addr)); + return is_ipv4 || is_ipv6; } /*! From 116be2d98efb86e3f65db70d9dcc276c57e666a8 Mon Sep 17 00:00:00 2001 From: Wu Zhao Date: Sat, 9 Nov 2019 13:53:21 +0800 Subject: [PATCH 06/13] Code Format --- src/common/socket.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/common/socket.h b/src/common/socket.h index 73bb3fcab6c7..247a2896540e 100644 --- a/src/common/socket.h +++ b/src/common/socket.h @@ -56,7 +56,7 @@ using ssize_t = int; #if defined(_WIN32) static inline int poll(struct pollfd *pfd, int nfds, int timeout) { - return WSAPoll ( pfd, nfds, timeout ); + return WSAPoll(pfd, nfds, timeout); } static inline int inet_pton(int family, const char* addr_str, void* addr_buf) { return InetPton(family, addr_str, addr_buf); From 331cfd3d854d2bc30f06804dae8cc599d802db84 Mon Sep 17 00:00:00 2001 From: Wu Zhao Date: Sat, 9 Nov 2019 17:29:30 +0800 Subject: [PATCH 07/13] Code Style Format --- apps/cpp_rpc/main.cc | 6 +++--- apps/cpp_rpc/rpc_server.cc | 2 +- apps/cpp_rpc/rpc_server.h | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/apps/cpp_rpc/main.cc b/apps/cpp_rpc/main.cc index 689f9c8f8922..18d14a1b0b3a 100644 --- a/apps/cpp_rpc/main.cc +++ b/apps/cpp_rpc/main.cc @@ -65,7 +65,7 @@ static const string kUSAGE = \ * \arg key The key used to identify the device type in tracker. Default="" * \arg custom_addr Custom IP Address to Report to RPC Tracker. Default="" * \arg silent Whether run in silent mode. Default=True - * \arg isProxy Whether to run in proxy mode. Default=False + * \arg is_proxy Whether to run in proxy mode. Default=False */ struct RpcServerArgs { string host = "0.0.0.0"; @@ -75,7 +75,7 @@ struct RpcServerArgs { string key; string custom_addr; bool silent = false; - bool isProxy = false; + bool is_proxy = false; }; /*! @@ -90,7 +90,7 @@ void PrintArgs(struct RpcServerArgs args) { LOG(INFO) << "key = " << args.key; LOG(INFO) << "custom_addr = " << args.custom_addr; LOG(INFO) << "silent = " << ((args.silent) ? ("True"): ("False")); - LOG(INFO) << "proxy = " << ((args.isProxy) ? ("True"): ("False")); + LOG(INFO) << "proxy = " << ((args.is_proxy) ? ("True"): ("False")); } /*! diff --git a/apps/cpp_rpc/rpc_server.cc b/apps/cpp_rpc/rpc_server.cc index 0c2edfb6fffa..f0649e7f0be9 100644 --- a/apps/cpp_rpc/rpc_server.cc +++ b/apps/cpp_rpc/rpc_server.cc @@ -334,7 +334,7 @@ class RPCServer { * \param key The key used to identify the device type in tracker. Default="" * \param custom_addr Custom IP Address to Report to RPC Tracker. Default="" * \param silent Whether run in silent mode. Default=True - * \param isProxy Whether to run in proxy mode. Default=False + * \param is_proxy Whether to run in proxy mode. Default=False */ void RPCServerCreate(std::string host, int port, diff --git a/apps/cpp_rpc/rpc_server.h b/apps/cpp_rpc/rpc_server.h index 2c255353c29b..c88e9b4750a5 100644 --- a/apps/cpp_rpc/rpc_server.h +++ b/apps/cpp_rpc/rpc_server.h @@ -39,7 +39,7 @@ namespace runtime { * \param key The key used to identify the device type in tracker. Default="" * \param custom_addr Custom IP Address to Report to RPC Tracker. Default="" * \param silent Whether run in silent mode. Default=True - * \param isProxy Whether to run in proxy mode. Default=False + * \param is_proxy Whether to run in proxy mode. Default=False */ TVM_DLL void RPCServerCreate(std::string host = "", int port = 9090, From e576ae73ad4afa87fa7498ae848de066298de3b9 Mon Sep 17 00:00:00 2001 From: Wu Zhao Date: Sun, 10 Nov 2019 03:00:35 +0800 Subject: [PATCH 08/13] Add popen --- apps/cpp_rpc/main.cc | 2 +- apps/cpp_rpc/rpc_env.cc | 12 ++++++----- apps/cpp_rpc/rpc_server.cc | 20 ++++++++--------- src/common/util.h | 44 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 62 insertions(+), 16 deletions(-) diff --git a/apps/cpp_rpc/main.cc b/apps/cpp_rpc/main.cc index 18d14a1b0b3a..b68ce501877c 100644 --- a/apps/cpp_rpc/main.cc +++ b/apps/cpp_rpc/main.cc @@ -111,7 +111,7 @@ void HandleCtrlC() { sigIntHandler.sa_handler = CtrlCHandler; sigemptyset(&sigIntHandler.sa_mask); sigIntHandler.sa_flags = 0; - sigaction(SIGINT, &sigIntHandler, NULL); + sigaction(SIGINT, &sigIntHandler, nullptr); } /*! diff --git a/apps/cpp_rpc/rpc_env.cc b/apps/cpp_rpc/rpc_env.cc index f6d4a5842ae6..79fa3366eac2 100644 --- a/apps/cpp_rpc/rpc_env.cc +++ b/apps/cpp_rpc/rpc_env.cc @@ -103,12 +103,12 @@ std::vector ListDir(const std::string &dirname) { std::vector vec; #ifndef _MSC_VER DIR *dp = opendir(dirname.c_str()); - if (dp == NULL) { + if (dp == nullptr) { int errsv = errno; LOG(FATAL) << "ListDir " << dirname <<" error: " << strerror(errsv); } dirent *d; - while ((d = readdir(dp)) != NULL) { + while ((d = readdir(dp)) != nullptr) { std::string filename = d->d_name; if (filename != "." && filename != "..") { std::string f = dirname; @@ -164,7 +164,8 @@ void LinuxShared(const std::string output, cmd += " " + *f; } cmd += " " + options; - CHECK(system(cmd.c_str()) == 0) << "Compilation error."; + std::string executed_result = common::Execute(cmd); + CHECK(executed_result.length() == 0) << executed_result; } /*! @@ -206,7 +207,8 @@ Module Load(std::string *fileIn, const std::string fmt) { std::string tmp_dir = "./rpc/tmp/"; mkdir(&tmp_dir[0], 0777); std::string cmd = "tar -C " + tmp_dir + " -zxf " + file; - CHECK(system(cmd.c_str()) == 0) << "Untar library error."; + std::string executed_result = common::Execute(cmd); + CHECK(executed_result.length() == 0) << executed_result; CreateShared(file_name, ListDir(tmp_dir)); CleanDir(tmp_dir); rmdir(&tmp_dir[0]); @@ -228,7 +230,7 @@ void CleanDir(const std::string &dirname) { #if defined(__linux__) || defined(__ANDROID__) DIR *dp = opendir(dirname.c_str()); dirent *d; - while ((d = readdir(dp)) != NULL) { + while ((d = readdir(dp)) != nullptr) { std::string filename = d->d_name; if (filename != "." && filename != "..") { filename = dirname + "/" + d->d_name; diff --git a/apps/cpp_rpc/rpc_server.cc b/apps/cpp_rpc/rpc_server.cc index f0649e7f0be9..fef616b7da40 100644 --- a/apps/cpp_rpc/rpc_server.cc +++ b/apps/cpp_rpc/rpc_server.cc @@ -104,14 +104,14 @@ class RPCServer { * \brief Start Creates the RPC listen process and execution. */ void Start() { - listen_sock_.Create(); - my_port_ = listen_sock_.TryBindHost(host_, port_, port_end_); - LOG(INFO) << "bind to " << host_ << ":" << my_port_; - listen_sock_.Listen(1); - std::future proc(std::async(std::launch::async, &RPCServer::ListenLoopProc, this)); - proc.get(); - // Close the listen socket - listen_sock_.Close(); + listen_sock_.Create(); + my_port_ = listen_sock_.TryBindHost(host_, port_, port_end_); + LOG(INFO) << "bind to " << host_ << ":" << my_port_; + listen_sock_.Listen(1); + std::future proc(std::async(std::launch::async, &RPCServer::ListenLoopProc, this)); + proc.get(); + // Close the listen socket + listen_sock_.Close(); } private: @@ -145,7 +145,7 @@ class RPCServer { int timeout = GetTimeOutFromOpts(opts); #if defined(__linux__) || defined(__ANDROID__) // step 3: serving - if (timeout) { + if (timeout != 0) { const pid_t timer_pid = fork(); if (timer_pid == 0) { // Timer process @@ -196,7 +196,7 @@ class RPCServer { std::future proc(std::async(std::launch::async, &RPCServer::ServerLoopProc, this, conn, addr)); // wait until server process finish or timeout - if (timeout) { + if (timeout != 0) { // Autoterminate after timeout proc.wait_for(std::chrono::seconds(timeout)); } else { diff --git a/src/common/util.h b/src/common/util.h index 0d2c5eb73f3d..2eb70406e388 100644 --- a/src/common/util.h +++ b/src/common/util.h @@ -25,14 +25,42 @@ #ifndef TVM_COMMON_UTIL_H_ #define TVM_COMMON_UTIL_H_ +#include #include #include #include #include +#include +#include namespace tvm { namespace common { +/*! + * \brief TVMPOpen wrapper of popen between windows / unix. + * \param command executed command + * \param type "r" is for reading or "w" for writing. + * \return normal standard stream + */ +inline FILE * TVMPOpen(const char* command, const char* type) { +#if defined(_WIN32) + return _popen(command, type); +#else + return popen(command, type); +#endif +} +/*! + * \brief TVMPClose wrapper of pclose between windows / linux + * \param stream the stream needed to be close. + * \return exit status + */ +inline int TVMPClose(FILE* stream) { +#if defined(_WIN32) + return _pclose(stream); +#else + return pclose(stream); +#endif +} /*! * \brief IsNumber check whether string is a number. * \param str input string @@ -72,6 +100,22 @@ inline bool EndsWith(std::string const & value, std::string const & end) { return false; } +/*! + * \brief Execute the command + * \param cmd The command we want to execute + * \return executed output message + */ +inline std::string Execute(std::string cmd) { + std::array buffer; + std::string result; + cmd += " 2>&1"; + std::unique_ptr pipe(TVMPOpen(cmd.c_str(), "r"), TVMPClose); + while (fgets(buffer.data(), buffer.size(), pipe.get()) != nullptr) { + result += buffer.data(); + } + return result; +} + } // namespace common } // namespace tvm #endif // TVM_COMMON_UTIL_H_ From f6f0c6d9be8f44691ddd943eac8bb2ebbdc27429 Mon Sep 17 00:00:00 2001 From: Wu Zhao Date: Sun, 10 Nov 2019 03:24:19 +0800 Subject: [PATCH 09/13] GCC 5 on Linux need pthread, so for Linux better add lpthread all. --- apps/cpp_rpc/Makefile | 10 +++++++++- apps/cpp_rpc/README.md | 4 +++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/apps/cpp_rpc/Makefile b/apps/cpp_rpc/Makefile index 450c2b9135d3..9cd39b446acc 100644 --- a/apps/cpp_rpc/Makefile +++ b/apps/cpp_rpc/Makefile @@ -19,13 +19,21 @@ TVM_ROOT=$(shell cd ../..; pwd) DMLC_CORE=${TVM_ROOT}/3rdparty/dmlc-core TVM_RUNTIME_DIR?= +OS?= + +# Android can not link pthrad, but Linux need. +ifeq ($(OS), Linux) +LINK_PTHREAD=-lpthread +else +LINK_PTHREAD= +endif PKG_CFLAGS = -std=c++11 -O2 -fPIC -Wall\ -I${TVM_ROOT}/include\ -I${DMLC_CORE}/include\ -I${TVM_ROOT}/3rdparty/dlpack/include -PKG_LDFLAGS = -L$(TVM_RUNTIME_DIR) -ltvm_runtime -ldl -Wl,-R$(TVM_RUNTIME_DIR) +PKG_LDFLAGS = -L$(TVM_RUNTIME_DIR) $(LINK_PTHREAD) -ltvm_runtime -ldl -Wl,-R$(TVM_RUNTIME_DIR) ifeq ($(USE_GLOG), 1) PKG_CFLAGS += -DDMLC_USE_GLOG=1 diff --git a/apps/cpp_rpc/README.md b/apps/cpp_rpc/README.md index 07bba414919f..f8d80040e257 100644 --- a/apps/cpp_rpc/README.md +++ b/apps/cpp_rpc/README.md @@ -21,7 +21,9 @@ This folder contains a simple recipe to make RPC server in c++. ## Usage - Build tvm runtime - Make the rpc executable [Makefile](Makefile). - `make CXX=/path/to/cross compiler g++/ TVM_RUNTIME_DIR=/path/to/tvm runtime library directory/` + `make CXX=/path/to/cross compiler g++/ TVM_RUNTIME_DIR=/path/to/tvm runtime library directory/ OS=Linux` + if you want to compile it for embedded Linux, you should add `OS=Linux`. + if the target os is Android, you doesn't need to pass OS argument. You could cross compile the TVM runtime like this: ``` cd tvm From 33f82c0b65e7f912978573602e33a4c1f2f561a3 Mon Sep 17 00:00:00 2001 From: Wu Zhao Date: Sun, 10 Nov 2019 16:14:40 +0800 Subject: [PATCH 10/13] popen implementation --- apps/cpp_rpc/README.md | 4 +-- apps/cpp_rpc/main.cc | 8 +++-- apps/cpp_rpc/rpc_env.cc | 15 ++++++--- apps/cpp_rpc/rpc_env.h | 2 +- apps/cpp_rpc/rpc_server.cc | 8 +++-- src/common/socket.h | 4 +-- src/common/util.h | 60 ++++++++++++++++++++++++++++++----- src/runtime/rpc/rpc_session.h | 2 +- 8 files changed, 80 insertions(+), 23 deletions(-) diff --git a/apps/cpp_rpc/README.md b/apps/cpp_rpc/README.md index f8d80040e257..6eaa8aaccc49 100644 --- a/apps/cpp_rpc/README.md +++ b/apps/cpp_rpc/README.md @@ -47,11 +47,11 @@ Command line usage --tracker - The RPC tracker address in host:port format e.g. 10.1.1.2:9190 Default="" --key - The key used to identify the device type in tracker. Default="" --custom-addr - Custom IP Address to Report to RPC Tracker. Default="" ---silent - Whether to run in silent mode. Default=True +--silent - Whether to run in silent mode. Default=False --proxy - Whether to run in proxy mode. Default=False Example ./tvm_rpc server --host=0.0.0.0 --port=9000 --port-end=9090 --tracker=127.0.0.1:9190 --key=rasp ``` ## Note -Currently support is only there for Linux / Android environment. \ No newline at end of file +Currently support is only there for Linux / Android environment and proxy mode doesn't be supported currently. \ No newline at end of file diff --git a/apps/cpp_rpc/main.cc b/apps/cpp_rpc/main.cc index b68ce501877c..94d8246e3b34 100644 --- a/apps/cpp_rpc/main.cc +++ b/apps/cpp_rpc/main.cc @@ -48,7 +48,7 @@ static const string kUSAGE = \ "--tracker - The RPC tracker address in host:port format e.g. 10.1.1.2:9190 Default=\"\"\n" \ "--key - The key used to identify the device type in tracker. Default=\"\"\n" \ "--custom-addr - Custom IP Address to Report to RPC Tracker. Default=\"\"\n" \ -"--silent - Whether to run in silent mode. Default=True\n" \ +"--silent - Whether to run in silent mode. Default=False\n" \ "--proxy - Whether to run in proxy mode. Default=False\n" \ "\n" \ " Example\n" \ @@ -64,7 +64,7 @@ static const string kUSAGE = \ * \arg tracker The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default="" * \arg key The key used to identify the device type in tracker. Default="" * \arg custom_addr Custom IP Address to Report to RPC Tracker. Default="" - * \arg silent Whether run in silent mode. Default=True + * \arg silent Whether run in silent mode. Default=False * \arg is_proxy Whether to run in proxy mode. Default=False */ struct RpcServerArgs { @@ -169,6 +169,10 @@ void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) { // Only errors and fatal is logged dmlc::InitLogging("--minloglevel=2"); } + string proxy = GetCmdOption(argc, argv, "--proxy", true); + if (!proxy.empty()) { + args.is_proxy = true; + } string host = GetCmdOption(argc, argv, "--host="); if (!host.empty()) { diff --git a/apps/cpp_rpc/rpc_env.cc b/apps/cpp_rpc/rpc_env.cc index 79fa3366eac2..c562e408a784 100644 --- a/apps/cpp_rpc/rpc_env.cc +++ b/apps/cpp_rpc/rpc_env.cc @@ -16,7 +16,6 @@ * specific language governing permissions and limitations * under the License. */ - /*! * \file rpc_env.cc * \brief Server environment of the RPC. @@ -164,8 +163,11 @@ void LinuxShared(const std::string output, cmd += " " + *f; } cmd += " " + options; - std::string executed_result = common::Execute(cmd); - CHECK(executed_result.length() == 0) << executed_result; + std::string err_msg; + auto executed_status = common::Execute(cmd, err_msg); + if (executed_status) { + LOG(ERROR) << err_msg; + } } /*! @@ -207,8 +209,11 @@ Module Load(std::string *fileIn, const std::string fmt) { std::string tmp_dir = "./rpc/tmp/"; mkdir(&tmp_dir[0], 0777); std::string cmd = "tar -C " + tmp_dir + " -zxf " + file; - std::string executed_result = common::Execute(cmd); - CHECK(executed_result.length() == 0) << executed_result; + std::string err_msg; + int executed_status = common::Execute(cmd, err_msg); + if (executed_status) { + LOG(ERROR) << err_msg; + } CreateShared(file_name, ListDir(tmp_dir)); CleanDir(tmp_dir); rmdir(&tmp_dir[0]); diff --git a/apps/cpp_rpc/rpc_env.h b/apps/cpp_rpc/rpc_env.h index 58df23792d21..82409bae81a1 100644 --- a/apps/cpp_rpc/rpc_env.h +++ b/apps/cpp_rpc/rpc_env.h @@ -70,7 +70,7 @@ struct RPCEnv { private: /*! - * \base_ Holds the environment path. + * \brief Holds the environment path. */ std::string base_; }; // RPCEnv diff --git a/apps/cpp_rpc/rpc_server.cc b/apps/cpp_rpc/rpc_server.cc index fef616b7da40..6e803251f690 100644 --- a/apps/cpp_rpc/rpc_server.cc +++ b/apps/cpp_rpc/rpc_server.cc @@ -44,6 +44,10 @@ namespace tvm { namespace runtime { +/*! + * \brief wait the child process end. + * \param status status value + */ #if defined(__linux__) || defined(__ANDROID__) static pid_t waitPidEintr(int *status) { pid_t pid = 0; @@ -176,9 +180,9 @@ class RPCServer { // Logging. if (finished_first == timer_pid) { LOG(INFO) << "Child pid=" << worker_pid << " killed (timeout = " << timeout - << "), Process status =" << status_second; + << "), Process status = " << status_second; } else if (finished_first == worker_pid) { - LOG(INFO) << "Child pid=" << timer_pid << " killed, Process status =" << status_second; + LOG(INFO) << "Child pid=" << timer_pid << " killed, Process status = " << status_second; } } else { auto pid = fork(); diff --git a/src/common/socket.h b/src/common/socket.h index 247a2896540e..616991d021d1 100644 --- a/src/common/socket.h +++ b/src/common/socket.h @@ -251,7 +251,7 @@ class Socket { } /*! * \brief try bind the socket to host, from start_port to end_port - * \param host host_address to bind the socket + * \param host host address to bind the socket * \param start_port starting port number to try * \param end_port ending port number to try * \return the port successfully bind to, return -1 if failed to bind any port @@ -264,7 +264,7 @@ class Socket { sizeof(sockaddr_in))) == 0) { return port; } else { - LOG(WARNING) << "Bind failed to " << host << ":" << port; + LOG(WARNING) << "Bind failed to " << host << ":" << port; } #if defined(_WIN32) if (WSAGetLastError() != WSAEADDRINUSE) { diff --git a/src/common/util.h b/src/common/util.h index 2eb70406e388..30148f8bf891 100644 --- a/src/common/util.h +++ b/src/common/util.h @@ -26,6 +26,10 @@ #define TVM_COMMON_UTIL_H_ #include +#ifndef _WIN32 +#include +#include +#endif #include #include #include @@ -41,7 +45,7 @@ namespace common { * \param type "r" is for reading or "w" for writing. * \return normal standard stream */ -inline FILE * TVMPOpen(const char* command, const char* type) { +inline FILE* TVMPOpen(const char* command, const char* type) { #if defined(_WIN32) return _popen(command, type); #else @@ -61,6 +65,41 @@ inline int TVMPClose(FILE* stream) { return pclose(stream); #endif } + +/* + * gnulib sys_wait.h.in says on Windows + * When an unhandled fatal signal terminates a process, the exit code is 3. + * # define WIFSIGNALED(x) ((x) == 3) + * # define WIFEXITED(x) ((x) != 3) + * # define WIFSTOPPED(x) 0 + */ +/*! + * \brief TVMWifexited wrapper of WIFEXITED between windows / linux + * \param status The status field that was filled in by the wait or waitpid function + * \return the exit code of the child process + */ +inline int TVMWifexited(int status) { +#if defined(_WIN32) + return (status != 3); +#else + return WIFEXITED(status); +#endif +} + +/*! + * \brief TVMWexitstatus wrapper of WEXITSTATUS between windows / linux + * \param status The status field that was filled in by the wait or waitpid function. + * \return the child process exited normally or not + */ +inline int TVMWexitstatus(int status) { +#if defined(_WIN32) + return status; +#else + return WEXITSTATUS(status); +#endif +} + + /*! * \brief IsNumber check whether string is a number. * \param str input string @@ -93,7 +132,7 @@ inline std::vector Split(const std::string& str, char delim) { * \param end The end substring * \return bool The result. */ -inline bool EndsWith(std::string const & value, std::string const & end) { +inline bool EndsWith(std::string const& value, std::string const& end) { if (end.size() <= value.size()) { return std::equal(end.rbegin(), end.rend(), value.rbegin()); } @@ -103,17 +142,22 @@ inline bool EndsWith(std::string const & value, std::string const & end) { /*! * \brief Execute the command * \param cmd The command we want to execute - * \return executed output message + * \param err_msg The error message if we have + * \return executed output status */ -inline std::string Execute(std::string cmd) { +inline int Execute(std::string cmd, std::string& err_msg) { std::array buffer; std::string result; cmd += " 2>&1"; - std::unique_ptr pipe(TVMPOpen(cmd.c_str(), "r"), TVMPClose); - while (fgets(buffer.data(), buffer.size(), pipe.get()) != nullptr) { - result += buffer.data(); + FILE* fd = TVMPOpen(cmd.c_str(), "r"); + while (fgets(buffer.data(), buffer.size(), fd) != nullptr) { + err_msg += buffer.data(); + } + int status = TVMPClose(fd); + if (TVMWifexited(status)) { + return TVMWexitstatus(status); } - return result; + return 255; } } // namespace common diff --git a/src/runtime/rpc/rpc_session.h b/src/runtime/rpc/rpc_session.h index 611a36bda8df..8492a04b1dec 100644 --- a/src/runtime/rpc/rpc_session.h +++ b/src/runtime/rpc/rpc_session.h @@ -43,7 +43,7 @@ const int kRPCTrackerMagic = 0x2f271; // sucess response const int kRPCSuccess = kRPCMagic + 0; // duplicate key in proxy -const int kRPCDupicate = kRPCMagic + 1; +const int kRPCDuplicate = kRPCMagic + 1; // cannot found matched key in server const int kRPCMismatch = kRPCMagic + 2; From 3a7412d57d416c57cd8c14242bef43fe68a9357b Mon Sep 17 00:00:00 2001 From: Wu Zhao Date: Sun, 10 Nov 2019 16:23:29 +0800 Subject: [PATCH 11/13] Remove proxy unnessary code --- apps/cpp_rpc/README.md | 1 - apps/cpp_rpc/main.cc | 8 -------- apps/cpp_rpc/rpc_server.cc | 14 ++++---------- apps/cpp_rpc/rpc_server.h | 4 +--- src/runtime/rpc/rpc_session.h | 2 -- 5 files changed, 5 insertions(+), 24 deletions(-) diff --git a/apps/cpp_rpc/README.md b/apps/cpp_rpc/README.md index 6eaa8aaccc49..4baecaf25150 100644 --- a/apps/cpp_rpc/README.md +++ b/apps/cpp_rpc/README.md @@ -48,7 +48,6 @@ Command line usage --key - The key used to identify the device type in tracker. Default="" --custom-addr - Custom IP Address to Report to RPC Tracker. Default="" --silent - Whether to run in silent mode. Default=False ---proxy - Whether to run in proxy mode. Default=False Example ./tvm_rpc server --host=0.0.0.0 --port=9000 --port-end=9090 --tracker=127.0.0.1:9190 --key=rasp ``` diff --git a/apps/cpp_rpc/main.cc b/apps/cpp_rpc/main.cc index 94d8246e3b34..3cf2ed6a5d59 100644 --- a/apps/cpp_rpc/main.cc +++ b/apps/cpp_rpc/main.cc @@ -49,7 +49,6 @@ static const string kUSAGE = \ "--key - The key used to identify the device type in tracker. Default=\"\"\n" \ "--custom-addr - Custom IP Address to Report to RPC Tracker. Default=\"\"\n" \ "--silent - Whether to run in silent mode. Default=False\n" \ -"--proxy - Whether to run in proxy mode. Default=False\n" \ "\n" \ " Example\n" \ " ./tvm_rpc server --host=0.0.0.0 --port=9000 --port-end=9090 " @@ -65,7 +64,6 @@ static const string kUSAGE = \ * \arg key The key used to identify the device type in tracker. Default="" * \arg custom_addr Custom IP Address to Report to RPC Tracker. Default="" * \arg silent Whether run in silent mode. Default=False - * \arg is_proxy Whether to run in proxy mode. Default=False */ struct RpcServerArgs { string host = "0.0.0.0"; @@ -75,7 +73,6 @@ struct RpcServerArgs { string key; string custom_addr; bool silent = false; - bool is_proxy = false; }; /*! @@ -90,7 +87,6 @@ void PrintArgs(struct RpcServerArgs args) { LOG(INFO) << "key = " << args.key; LOG(INFO) << "custom_addr = " << args.custom_addr; LOG(INFO) << "silent = " << ((args.silent) ? ("True"): ("False")); - LOG(INFO) << "proxy = " << ((args.is_proxy) ? ("True"): ("False")); } /*! @@ -169,10 +165,6 @@ void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) { // Only errors and fatal is logged dmlc::InitLogging("--minloglevel=2"); } - string proxy = GetCmdOption(argc, argv, "--proxy", true); - if (!proxy.empty()) { - args.is_proxy = true; - } string host = GetCmdOption(argc, argv, "--host="); if (!host.empty()) { diff --git a/apps/cpp_rpc/rpc_server.cc b/apps/cpp_rpc/rpc_server.cc index 6e803251f690..b35a63bd67dc 100644 --- a/apps/cpp_rpc/rpc_server.cc +++ b/apps/cpp_rpc/rpc_server.cc @@ -71,7 +71,6 @@ static pid_t waitPidEintr(int *status) { * \param tracker The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default="" * \param key The key used to identify the device type in tracker. Default="" * \param custom_addr Custom IP Address to Report to RPC Tracker. Default="" - * \param is_proxy Whether to run in proxy mode. Default=False */ class RPCServer { public: @@ -83,8 +82,7 @@ class RPCServer { int port_end, const std::string &tracker_addr, const std::string &key, - const std::string &custom_addr, - bool is_proxy) { + const std::string &custom_addr) { // Init the values host_ = host; port_ = port; @@ -92,7 +90,6 @@ class RPCServer { tracker_addr_ = tracker_addr; key_ = key; custom_addr_ = custom_addr; - is_proxy_ = is_proxy; } /*! @@ -324,7 +321,6 @@ class RPCServer { std::string tracker_addr_; std::string key_; std::string custom_addr_; - bool is_proxy_; common::TCPSocket listen_sock_; common::TCPSocket tracker_sock_; }; @@ -338,7 +334,6 @@ class RPCServer { * \param key The key used to identify the device type in tracker. Default="" * \param custom_addr Custom IP Address to Report to RPC Tracker. Default="" * \param silent Whether run in silent mode. Default=True - * \param is_proxy Whether to run in proxy mode. Default=False */ void RPCServerCreate(std::string host, int port, @@ -346,20 +341,19 @@ void RPCServerCreate(std::string host, std::string tracker_addr, std::string key, std::string custom_addr, - bool silent, - bool is_proxy) { + bool silent) { if (silent) { // Only errors and fatal is logged dmlc::InitLogging("--minloglevel=2"); } // Start the rpc server - RPCServer rpc(host, port, port_end, tracker_addr, key, custom_addr, is_proxy); + RPCServer rpc(host, port, port_end, tracker_addr, key, custom_addr); rpc.Start(); } TVM_REGISTER_GLOBAL("rpc._ServerCreate") .set_body([](TVMArgs args, TVMRetValue* rv) { - RPCServerCreate(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7]); + RPCServerCreate(args[0], args[1], args[2], args[3], args[4], args[5], args[6]); }); } // namespace runtime } // namespace tvm diff --git a/apps/cpp_rpc/rpc_server.h b/apps/cpp_rpc/rpc_server.h index c88e9b4750a5..205182e4449a 100644 --- a/apps/cpp_rpc/rpc_server.h +++ b/apps/cpp_rpc/rpc_server.h @@ -39,7 +39,6 @@ namespace runtime { * \param key The key used to identify the device type in tracker. Default="" * \param custom_addr Custom IP Address to Report to RPC Tracker. Default="" * \param silent Whether run in silent mode. Default=True - * \param is_proxy Whether to run in proxy mode. Default=False */ TVM_DLL void RPCServerCreate(std::string host = "", int port = 9090, @@ -47,8 +46,7 @@ TVM_DLL void RPCServerCreate(std::string host = "", std::string tracker_addr = "", std::string key = "", std::string custom_addr = "", - bool silent = true, - bool is_proxy = false); + bool silent = true); } // namespace runtime } // namespace tvm #endif // TVM_APPS_CPP_RPC_SERVER_H_ diff --git a/src/runtime/rpc/rpc_session.h b/src/runtime/rpc/rpc_session.h index 8492a04b1dec..3518455c83d1 100644 --- a/src/runtime/rpc/rpc_session.h +++ b/src/runtime/rpc/rpc_session.h @@ -42,8 +42,6 @@ const int kRPCMagic = 0xff271; const int kRPCTrackerMagic = 0x2f271; // sucess response const int kRPCSuccess = kRPCMagic + 0; -// duplicate key in proxy -const int kRPCDuplicate = kRPCMagic + 1; // cannot found matched key in server const int kRPCMismatch = kRPCMagic + 2; From b7d03d1e808f679b144b9ab631456f49b3868e89 Mon Sep 17 00:00:00 2001 From: Wu Zhao Date: Sun, 10 Nov 2019 16:29:27 +0800 Subject: [PATCH 12/13] CppLint doesn't like pass by reference --- apps/cpp_rpc/rpc_env.cc | 4 ++-- src/common/util.h | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/apps/cpp_rpc/rpc_env.cc b/apps/cpp_rpc/rpc_env.cc index c562e408a784..6b6ebd41fcb8 100644 --- a/apps/cpp_rpc/rpc_env.cc +++ b/apps/cpp_rpc/rpc_env.cc @@ -164,7 +164,7 @@ void LinuxShared(const std::string output, } cmd += " " + options; std::string err_msg; - auto executed_status = common::Execute(cmd, err_msg); + auto executed_status = common::Execute(cmd, &err_msg); if (executed_status) { LOG(ERROR) << err_msg; } @@ -210,7 +210,7 @@ Module Load(std::string *fileIn, const std::string fmt) { mkdir(&tmp_dir[0], 0777); std::string cmd = "tar -C " + tmp_dir + " -zxf " + file; std::string err_msg; - int executed_status = common::Execute(cmd, err_msg); + int executed_status = common::Execute(cmd, &err_msg); if (executed_status) { LOG(ERROR) << err_msg; } diff --git a/src/common/util.h b/src/common/util.h index 30148f8bf891..c126f83004a4 100644 --- a/src/common/util.h +++ b/src/common/util.h @@ -145,13 +145,13 @@ inline bool EndsWith(std::string const& value, std::string const& end) { * \param err_msg The error message if we have * \return executed output status */ -inline int Execute(std::string cmd, std::string& err_msg) { +inline int Execute(std::string cmd, std::string* err_msg) { std::array buffer; std::string result; cmd += " 2>&1"; FILE* fd = TVMPOpen(cmd.c_str(), "r"); while (fgets(buffer.data(), buffer.size(), fd) != nullptr) { - err_msg += buffer.data(); + *err_msg += buffer.data(); } int status = TVMPClose(fd); if (TVMWifexited(status)) { From 3fd1572502303cdbb358732930c7edbae9f4e1fe Mon Sep 17 00:00:00 2001 From: Wu Zhao Date: Mon, 11 Nov 2019 02:11:36 +0800 Subject: [PATCH 13/13] LOG Error becomes FATAL and license --- apps/cpp_rpc/rpc_env.cc | 4 ++-- src/common/util.h | 7 ------- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/apps/cpp_rpc/rpc_env.cc b/apps/cpp_rpc/rpc_env.cc index 6b6ebd41fcb8..44f848dc749e 100644 --- a/apps/cpp_rpc/rpc_env.cc +++ b/apps/cpp_rpc/rpc_env.cc @@ -166,7 +166,7 @@ void LinuxShared(const std::string output, std::string err_msg; auto executed_status = common::Execute(cmd, &err_msg); if (executed_status) { - LOG(ERROR) << err_msg; + LOG(FATAL) << err_msg; } } @@ -212,7 +212,7 @@ Module Load(std::string *fileIn, const std::string fmt) { std::string err_msg; int executed_status = common::Execute(cmd, &err_msg); if (executed_status) { - LOG(ERROR) << err_msg; + LOG(FATAL) << err_msg; } CreateShared(file_name, ListDir(tmp_dir)); CleanDir(tmp_dir); diff --git a/src/common/util.h b/src/common/util.h index c126f83004a4..93f32f48a2a6 100644 --- a/src/common/util.h +++ b/src/common/util.h @@ -66,13 +66,6 @@ inline int TVMPClose(FILE* stream) { #endif } -/* - * gnulib sys_wait.h.in says on Windows - * When an unhandled fatal signal terminates a process, the exit code is 3. - * # define WIFSIGNALED(x) ((x) == 3) - * # define WIFEXITED(x) ((x) != 3) - * # define WIFSTOPPED(x) 0 - */ /*! * \brief TVMWifexited wrapper of WIFEXITED between windows / linux * \param status The status field that was filled in by the wait or waitpid function