diff --git a/python/tvm/contrib/hexagon/session.py b/python/tvm/contrib/hexagon/session.py index 783e1cd3a014..87a37b66a77c 100644 --- a/python/tvm/contrib/hexagon/session.py +++ b/python/tvm/contrib/hexagon/session.py @@ -57,10 +57,12 @@ def __init__( remote_kw: dict, session_name: str = "hexagon-rpc", remote_stack_size_bytes: int = 128 * 1024, + rpc_receive_buffer_size_bytes: int = 2 * 1024 * 1024, ): self._launcher = launcher self._session_name = session_name self._remote_stack_size_bytes = remote_stack_size_bytes + self._rpc_receive_buffer_size_bytes = rpc_receive_buffer_size_bytes self._remote_kw = remote_kw self._rpc = None self.device = None @@ -81,6 +83,7 @@ def __enter__(self): self._session_name, self._remote_stack_size_bytes, os.environ.get("HEXAGON_SIM_ARGS", ""), + self._rpc_receive_buffer_size_bytes, ], ) self.device = self._rpc.hexagon(0) diff --git a/src/runtime/hexagon/rpc/android/session.cc b/src/runtime/hexagon/rpc/android/session.cc index 89fcc54f9a33..7c8b81445323 100644 --- a/src/runtime/hexagon/rpc/android/session.cc +++ b/src/runtime/hexagon/rpc/android/session.cc @@ -45,13 +45,19 @@ namespace hexagon { class HexagonTransportChannel : public RPCChannel { public: - explicit HexagonTransportChannel(const std::string& uri, int remote_stack_size_bytes) { + explicit HexagonTransportChannel(const std::string& uri, int remote_stack_size_bytes, + uint32_t receive_buf_size_bytes) { if (_handle != AEE_EUNKNOWN) return; enable_unsigned_pd(true); set_remote_stack_size(remote_stack_size_bytes); + AEEResult rc = hexagon_rpc_open(uri.c_str(), &_handle); ICHECK(rc == AEE_SUCCESS) << "hexagon_rpc_open failed. URI: " << uri.c_str(); + + rc = hexagon_rpc_init(_handle, receive_buf_size_bytes); + ICHECK(rc == AEE_SUCCESS) << "hexagon_rpc_set_receive_buf_size failed. receive_buf_size_bytes: " + << receive_buf_size_bytes; } size_t Send(const void* data, size_t size) override { @@ -105,10 +111,15 @@ class HexagonTransportChannel : public RPCChannel { TVM_REGISTER_GLOBAL("tvm.contrib.hexagon.create_hexagon_session") .set_body([](TVMArgs args, TVMRetValue* rv) { + ICHECK(args.size() >= 4) << args.size() << " is less than 4"; + std::string session_name = args[0]; int remote_stack_size_bytes = args[1]; + // For simulator, the third parameter is sim_args, ignore it. + int hexagon_rpc_receive_buf_size_bytes = args[3]; HexagonTransportChannel* hexagon_channel = - new HexagonTransportChannel(hexagon_rpc_URI CDSP_DOMAIN, remote_stack_size_bytes); + new HexagonTransportChannel(hexagon_rpc_URI CDSP_DOMAIN, remote_stack_size_bytes, + static_cast(hexagon_rpc_receive_buf_size_bytes)); std::unique_ptr channel(hexagon_channel); auto ep = RPCEndpoint::Create(std::move(channel), session_name, "", NULL); auto sess = CreateClientSession(ep); diff --git a/src/runtime/hexagon/rpc/hexagon/rpc_server.cc b/src/runtime/hexagon/rpc/hexagon/rpc_server.cc index f61b1b6b4040..af91dd3b4e6d 100644 --- a/src/runtime/hexagon/rpc/hexagon/rpc_server.cc +++ b/src/runtime/hexagon/rpc/hexagon/rpc_server.cc @@ -40,9 +40,6 @@ extern "C" { #include "../../hexagon/hexagon_common.h" #include "hexagon_rpc.h" -// TODO(mehrdadh): make this configurable. -#define TVM_HEXAGON_RPC_BUFF_SIZE_BYTES 2 * 1024 * 1024 - namespace tvm { namespace runtime { namespace hexagon { @@ -190,10 +187,17 @@ class HexagonRPCServer { } // namespace tvm namespace { -tvm::runtime::hexagon::HexagonRPCServer* get_hexagon_rpc_server() { - static tvm::runtime::hexagon::HexagonRPCServer g_hexagon_rpc_server( - new uint8_t[TVM_HEXAGON_RPC_BUFF_SIZE_BYTES], TVM_HEXAGON_RPC_BUFF_SIZE_BYTES); - return &g_hexagon_rpc_server; +static tvm::runtime::hexagon::HexagonRPCServer* g_hexagon_rpc_server; +tvm::runtime::hexagon::HexagonRPCServer* get_hexagon_rpc_server( + uint32_t rpc_receive_buff_size_bytes = 0) { + if (g_hexagon_rpc_server) { + return g_hexagon_rpc_server; + } + CHECK_GT(rpc_receive_buff_size_bytes, 0) << "RPC receive buffer size is not valid."; + static tvm::runtime::hexagon::HexagonRPCServer hexagon_rpc_server( + new uint8_t[rpc_receive_buff_size_bytes], rpc_receive_buff_size_bytes); + g_hexagon_rpc_server = &hexagon_rpc_server; + return g_hexagon_rpc_server; } } // namespace @@ -216,7 +220,6 @@ int __QAIC_HEADER(hexagon_rpc_open)(const char* uri, remote_handle64* handle) { return AEE_ENOMEMORY; } reset_device_api(); - get_hexagon_rpc_server(); return AEE_SUCCESS; } @@ -229,6 +232,11 @@ int __QAIC_HEADER(hexagon_rpc_close)(remote_handle64 handle) { return AEE_SUCCESS; } +int __QAIC_HEADER(hexagon_rpc_init)(remote_handle64 _h, uint32_t buff_size_bytes) { + get_hexagon_rpc_server(buff_size_bytes); + return AEE_SUCCESS; +} + /*! * \brief Send data from Host to Hexagon over RPCSession. * \param _handle The remote handle diff --git a/src/runtime/hexagon/rpc/hexagon_rpc.idl b/src/runtime/hexagon/rpc/hexagon_rpc.idl index 55b8d39bcb02..6b05324e3c87 100644 --- a/src/runtime/hexagon/rpc/hexagon_rpc.idl +++ b/src/runtime/hexagon/rpc/hexagon_rpc.idl @@ -25,4 +25,5 @@ typedef sequence buffer; interface hexagon_rpc : remote_handle64 { AEEResult send(in buffer data); AEEResult receive(rout buffer buf, rout int64_t buf_written_size); + AEEResult init(in uint32_t buff_size_bytes); }; diff --git a/src/runtime/hexagon/rpc/simulator/session.cc b/src/runtime/hexagon/rpc/simulator/session.cc index d03df7f9e573..d1cc6c4613b3 100644 --- a/src/runtime/hexagon/rpc/simulator/session.cc +++ b/src/runtime/hexagon/rpc/simulator/session.cc @@ -1311,6 +1311,8 @@ detail::Optional SimulatorRPCChannel::to_nullptr(const detail::M TVM_REGISTER_GLOBAL("tvm.contrib.hexagon.create_hexagon_session") .set_body([](TVMArgs args, TVMRetValue* rv) { + ICHECK(args.size() >= 4) << args.size() << " is less than 4"; + std::string session_name = args[0]; // For target, the second parameter is remote_stack_size_bytes, ignore it. std::string sim_args = args[2];