Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions python/tvm/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ class RPCError(TVMError):
"""Error thrown by the remote server handling the RPC call."""


@register_error
class RPCSessionTimeoutError(RPCError, TimeoutError):
"""Error thrown by the remote server when the RPC session has expired."""


@register_error
class OpError(TVMError):
"""Base class of all operator errors in frontends."""
Expand Down
88 changes: 45 additions & 43 deletions python/tvm/rpc/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
- {server|client}:device-type[:random-key] [-timeout=timeout]
"""
# pylint: disable=invalid-name
import os
import ctypes
import socket
import select
Expand Down Expand Up @@ -118,16 +119,6 @@ def download_linked_module(file_name):
return temp


def _serve_loop(sock, addr, load_library, work_path=None):
"""Server loop"""
sockfd = sock.fileno()
temp = _server_env(load_library, work_path)
_ffi_api.ServerLoop(sockfd)
if not work_path:
temp.remove()
logger.info("Finish serving %s", addr)


def _parse_server_opt(opts):
# parse client options
ret = {}
Expand All @@ -137,6 +128,47 @@ def _parse_server_opt(opts):
return ret


def _serving(sock, addr, opts, load_library):
logger.info(f"connected from {addr}")
work_path = utils.tempdir()
old_cwd = os.getcwd()
os.chdir(work_path.path) # Avoiding file name conflict between sessions.
logger.info(f"start serving at {work_path.path}")

def _serve_loop():
_server_env(load_library, work_path)
_ffi_api.ServerLoop(sock.fileno())

server_proc = multiprocessing.Process(target=_serve_loop)
server_proc.start()
server_proc.join(opts.get("timeout", None)) # Wait until finish or timeout.

if server_proc.is_alive():
logger.info("timeout in RPC session, kill..")
_ffi_api.ReturnException(
sock.fileno(),
f'RPCSessionTimeoutError: Your {opts["timeout"]}s session has expired, '
f'try to increase the "session_timeout" value.',
)

try:
import psutil # pylint: disable=import-outside-toplevel

# Terminate worker children firstly.
for child in psutil.Process(server_proc.pid).children(recursive=True):
child.terminate()
except ImportError:
# Don't dependent `psutil` hardly, because it isn't a pure Python
# package and maybe hard to be installed on some platforms.
pass
server_proc.terminate()

logger.info(f"finish serving {addr}")
os.chdir(old_cwd)
work_path.remove()
sock.close()


def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
"""Listening loop of the server."""

Expand Down Expand Up @@ -237,30 +269,7 @@ def _accept_conn(listen_sock, tracker_conn, ping_period=2):
raise exc

# step 3: serving
work_path = utils.tempdir()
logger.info("connection from %s", addr)
server_proc = multiprocessing.Process(
target=_serve_loop, args=(conn, addr, load_library, work_path)
)

server_proc.start()
# close from our side.
conn.close()
# wait until server process finish or timeout
server_proc.join(opts.get("timeout", None))

if server_proc.is_alive():
logger.info("Timeout in RPC session, kill..")
# pylint: disable=import-outside-toplevel
import psutil

parent = psutil.Process(server_proc.pid)
# terminate worker children
for child in parent.children(recursive=True):
child.terminate()
# terminate the worker
server_proc.terminate()
work_path.remove()
_serving(conn, addr, opts, load_library)


def _connect_proxy_loop(addr, key, load_library):
Expand All @@ -285,15 +294,8 @@ def _connect_proxy_loop(addr, key, load_library):
raise RuntimeError(f"{str(addr)} is not RPC Proxy")
keylen = struct.unpack("<i", base.recvall(sock, 4))[0]
remote_key = py_str(base.recvall(sock, keylen))
opts = _parse_server_opt(remote_key.split()[1:])
logger.info("connected to %s", str(addr))
process = multiprocessing.Process(target=_serve_loop, args=(sock, addr, load_library))
process.start()
sock.close()
process.join(opts.get("timeout", None))
if process.is_alive():
logger.info("Timeout in RPC session, kill..")
process.terminate()

_serving(sock, addr, _parse_server_opt(remote_key.split()[1:]), load_library)
retry_count = 0
except (socket.error, IOError) as err:
retry_count += 1
Expand Down
8 changes: 6 additions & 2 deletions src/runtime/rpc/rpc_endpoint.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@

#include "../../support/arena.h"
#include "../../support/ring_buffer.h"
#include "../../support/utils.h"
#include "../object_internal.h"
#include "rpc_local_session.h"

Expand Down Expand Up @@ -372,8 +373,11 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
if (code == RPCCode::kException) {
// switch to the state before sending exception.
this->SwitchToState(kRecvPacketNumBytes);
std::string msg = args[0];
LOG(FATAL) << "RPCError: Error caught from RPC call:\n" << msg;
String msg = args[0];
if (!support::StartsWith(msg, "RPCSessionTimeoutError: ")) {
msg = "RPCError: Error caught from RPC call:\n" + msg;
}
LOG(FATAL) << msg;
}

ICHECK(setreturn != nullptr) << "fsetreturn not available";
Expand Down
34 changes: 34 additions & 0 deletions src/runtime/rpc/rpc_socket_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,5 +142,39 @@ TVM_REGISTER_GLOBAL("rpc.ServerLoop").set_body([](TVMArgs args, TVMRetValue* rv)
}
});

class SimpleSockHandler : public dmlc::Stream {
// Things that will interface with user directly.
public:
explicit SimpleSockHandler(int sockfd)
: sock_(static_cast<support::TCPSocket::SockType>(sockfd)) {}
using dmlc::Stream::Read;
using dmlc::Stream::ReadArray;
using dmlc::Stream::Write;
using dmlc::Stream::WriteArray;

// Unused here, implemented for microTVM framing layer.
void MessageStart(uint64_t packet_nbytes) {}
void MessageDone() {}

// Internal supporting.
// Override methods that inherited from dmlc::Stream.
private:
size_t Read(void* data, size_t size) final {
ICHECK_EQ(sock_.RecvAll(data, size), size);
return size;
}
void Write(const void* data, size_t size) final { ICHECK_EQ(sock_.SendAll(data, size), size); }

// Things of current class.
private:
support::TCPSocket sock_;
};

TVM_REGISTER_GLOBAL("rpc.ReturnException").set_body_typed([](int sockfd, String msg) {
auto handler = SimpleSockHandler(sockfd);
RPCReference::ReturnException(msg.c_str(), &handler);
return;
});

} // namespace runtime
} // namespace tvm
31 changes: 31 additions & 0 deletions tests/python/unittest/test_runtime_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,3 +606,34 @@ def test_rpc_tracker_via_proxy(device_key):
server1.terminate()
proxy_server.terminate()
tracker_server.terminate()


@tvm.testing.requires_rpc
@pytest.mark.parametrize("with_proxy", (True, False))
def test_rpc_session_timeout_error(with_proxy):
port = 9000
port_end = 10000

tracker = Tracker(port=port, port_end=port_end)
time.sleep(0.5)
tracker_addr = (tracker.host, tracker.port)

if with_proxy:
proxy = Proxy(host="0.0.0.0", port=port, port_end=port_end, tracker_addr=tracker_addr)
time.sleep(0.5)
server = rpc.Server(host=proxy.host, port=proxy.port, is_proxy=True, key="x1")
else:
server = rpc.Server(port=port, port_end=port_end, tracker_addr=tracker_addr, key="x1")
time.sleep(0.5)

rpc_sess = rpc.connect_tracker(*tracker_addr).request(key="x1", session_timeout=1)

with pytest.raises(tvm.error.RPCSessionTimeoutError):
f1 = rpc_sess.get_function("rpc.test.addone")
time.sleep(2)
f1(10)

server.terminate()
if with_proxy:
proxy.terminate()
tracker.terminate()