diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index 94644d797c1a..92f477b058fd 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -72,8 +72,10 @@ struct TypeIndex { kRuntimeShapeTuple = 6, /*! \brief runtime::PackedFunc. */ kRuntimePackedFunc = 7, - /*! \brief runtime::DRef */ + /*! \brief runtime::DRef for disco distributed runtime */ kRuntimeDiscoDRef = 8, + /*! \brief runtime::RPCObjectRef */ + kRuntimeRPCObjectRef = 9, // static assignments that may subject to change. kRuntimeClosure, kRuntimeADT, diff --git a/src/runtime/minrpc/minrpc_server.h b/src/runtime/minrpc/minrpc_server.h index cca47f80b9df..96a4dbce79cd 100644 --- a/src/runtime/minrpc/minrpc_server.h +++ b/src/runtime/minrpc/minrpc_server.h @@ -206,7 +206,8 @@ class MinRPCExecute : public MinRPCExecInterface { ret_tcode[1] = kTVMBytes; ret_handler_->ReturnPackedSeq(ret_value, ret_tcode, 2); TVMByteArrayFree(reinterpret_cast(ret_value[1].v_handle)); // NOLINT(*) - } else if (rv_tcode == kTVMPackedFuncHandle || rv_tcode == kTVMModuleHandle) { + } else if (rv_tcode == kTVMPackedFuncHandle || rv_tcode == kTVMModuleHandle || + rv_tcode == kTVMObjectHandle) { ret_tcode[1] = kTVMOpaqueHandle; ret_handler_->ReturnPackedSeq(ret_value, ret_tcode, 2); } else { @@ -755,7 +756,17 @@ class MinRPCServer { } void ReadObject(int* tcode, TVMValue* value) { - this->ThrowError(RPCServerStatus::kUnknownTypeCode); + // handles RPCObject in minRPC + // NOTE: object needs to be supported by C runtime + // because minrpc's restriction of C only + // we only handle RPCObjectRef + uint32_t type_index; + Read(&type_index); + MINRPC_CHECK(type_index == kRuntimeRPCObjectRefTypeIndex); + uint64_t object_handle; + Read(&object_handle); + tcode[0] = kTVMObjectHandle; + value[0].v_handle = reinterpret_cast(object_handle); } private: diff --git a/src/runtime/minrpc/rpc_reference.h b/src/runtime/minrpc/rpc_reference.h index e16f09cb9dee..732b017e44fe 100644 --- a/src/runtime/minrpc/rpc_reference.h +++ b/src/runtime/minrpc/rpc_reference.h @@ -33,6 +33,14 @@ class Object; /*! \brief The current RPC procotol version. */ constexpr const char* kRPCProtocolVer = "0.8.0"; +/*! + * \brief type index of kRuntimeRPCObjectRefTypeIndex + * \note this needs to be kept consistent with runtime/object.h + * but we explicitly declare it here because minrpc needs to be minimum dep + * only c C API + */ +constexpr const int kRuntimeRPCObjectRefTypeIndex = 9; + // When tvm.rpc.server.GetCRTMaxPacketSize global function is not registered. const uint64_t kRPCMaxTransferSizeBytesDefault = UINT64_MAX; diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc index f2c09132fc70..2c431cdb643c 100644 --- a/src/runtime/rpc/rpc_endpoint.cc +++ b/src/runtime/rpc/rpc_endpoint.cc @@ -175,8 +175,11 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { for (int i = 0; i < num_args; ++i) { int tcode = type_codes[i]; if (tcode == kTVMObjectHandle || tcode == kTVMObjectRValueRefArg) { - LOG(FATAL) << "ValueError: Cannot pass argument " << i << ", type " - << args[i].AsObjectRef()->GetTypeKey() << " is not supported by RPC"; + if (!args[i].IsObjectRef()) { + LOG(FATAL) << "ValueError: Cannot pass argument " << i << ", type " + << args[i].AsObjectRef()->GetTypeKey() + << " is not supported by RPC"; + } } else if (tcode == kDLDevice) { DLDevice dev = args[i]; ICHECK(!IsRPCSessionDevice(dev)) << "InternalError: cannot pass RPC device in the channel"; @@ -219,14 +222,48 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { this->Write(cdata); } - void WriteObject(void* obj) { this->ThrowError(RPCServerStatus::kUnknownTypeCode); } - uint64_t GetObjectBytes(void* obj) { - this->ThrowError(RPCServerStatus::kUnknownTypeCode); - return 0; + void WriteObject(Object* obj) { + // NOTE: for now all remote object are encoded as RPCObjectRef + // follow the same disco protocol in case we would like to upgrade later + // + // Rationale note: Only handle remote object allows the same mechanism to work for minRPC + // which is needed for wasm and other env that goes through C API + if (obj->IsInstance()) { + auto* ref = static_cast(obj); + this->template Write(kRuntimeRPCObjectRefTypeIndex); + uint64_t handle = reinterpret_cast(ref->object_handle()); + this->template Write(handle); + } else { + LOG(FATAL) << "ValueError: Object type is not supported in RPC calling convention: " + << obj->GetTypeKey() << " (type_index = " << obj->type_index() << ")"; + } + } + uint64_t GetObjectBytes(Object* obj) { + if (obj->IsInstance()) { + return sizeof(uint32_t) + sizeof(int64_t); + } else { + LOG(FATAL) << "ValueError: Object type is not supported in RPC calling convention: " + << obj->GetTypeKey() << " (type_index = " << obj->type_index() << ")"; + } } void ReadObject(int* tcode, TVMValue* value) { - this->ThrowError(RPCServerStatus::kUnknownTypeCode); + // NOTE: for now all remote object are encoded as RPCObjectRef + // follow the same disco protocol in case we would like to upgrade later + // + // Rationale note: Only handle remote object allows the same mechanism to work for minRPC + // which is needed for wasm and other env that goes through C API + uint32_t type_index; + this->template Read(&type_index); + if (type_index == kRuntimeRPCObjectRefTypeIndex) { + uint64_t handle; + this->template Read(&handle); + tcode[0] = kTVMObjectHandle; + value[0].v_handle = reinterpret_cast(handle); + } else { + LOG(FATAL) << "ValueError: Object type is not supported in Disco calling convention: " + << Object::TypeIndex2Key(type_index) << " (type_index = " << type_index << ")"; + } } void MessageDone() { diff --git a/src/runtime/rpc/rpc_local_session.cc b/src/runtime/rpc/rpc_local_session.cc index d4aec5596f37..92691ee6fd28 100644 --- a/src/runtime/rpc/rpc_local_session.cc +++ b/src/runtime/rpc/rpc_local_session.cc @@ -27,6 +27,7 @@ #include #include +#include namespace tvm { namespace runtime { @@ -64,7 +65,8 @@ void LocalSession::EncodeReturn(TVMRetValue rv, const FEncodeReturn& encode_retu ret_value_pack[2].v_handle = ret_value_pack[1].v_handle; ret_tcode_pack[2] = kTVMOpaqueHandle; encode_return(TVMArgs(ret_value_pack, ret_tcode_pack, 3)); - } else if (rv_tcode == kTVMPackedFuncHandle || rv_tcode == kTVMModuleHandle) { + } else if (rv_tcode == kTVMPackedFuncHandle || rv_tcode == kTVMModuleHandle || + rv_tcode == kTVMObjectHandle) { // MoveToCHost means rv no longer manages the object. // return handle instead. rv.MoveToCHost(&ret_value_pack[1], &ret_tcode_pack[1]); @@ -88,7 +90,21 @@ void LocalSession::CallFunc(RPCSession::PackedFuncHandle func, const TVMValue* a const FEncodeReturn& encode_return) { PackedFuncObj* pf = static_cast(func); TVMRetValue rv; - pf->CallPacked(TVMArgs(arg_values, arg_type_codes, num_args), &rv); + + // unwrap RPCObjectRef in case we are directly using it to call LocalSession + std::vector values(arg_values, arg_values + num_args); + std::vector type_codes(arg_type_codes, arg_type_codes + num_args); + TVMArgs args(arg_values, arg_type_codes, num_args); + + for (int i = 0; i < num_args; ++i) { + if (args[i].IsObjectRef()) { + RPCObjectRef obj_ref = args[i]; + values[i].v_handle = obj_ref->object_handle(); + continue; + } + } + + pf->CallPacked(TVMArgs(values.data(), type_codes.data(), args.size()), &rv); this->EncodeReturn(std::move(rv), encode_return); } diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc index 94f6720ca8da..a696005ab836 100644 --- a/src/runtime/rpc/rpc_module.cc +++ b/src/runtime/rpc/rpc_module.cc @@ -157,6 +157,8 @@ class RPCWrappedFunc : public Object { } }; +TVM_REGISTER_OBJECT_TYPE(RPCObjectRefObj); + // RPC that represents a remote module session. class RPCModuleNode final : public ModuleNode { public: @@ -294,6 +296,11 @@ void RPCWrappedFunc::WrapRemoteReturnToValue(TVMArgs args, TVMRetValue* rv) cons void* handle = args[1]; auto n = make_object(handle, sess_); *rv = Module(n); + } else if (tcode == kTVMObjectHandle) { + ICHECK_EQ(args.size(), 2); + void* handle = args[1]; + auto n = make_object(handle, sess_); + *rv = ObjectRef(n); } else if (tcode == kTVMDLTensorHandle || tcode == kTVMNDArrayHandle) { ICHECK_EQ(args.size(), 3); DLTensor* tensor = args[1]; diff --git a/src/runtime/rpc/rpc_session.h b/src/runtime/rpc/rpc_session.h index 60d067e49d3f..b09900d0abaa 100644 --- a/src/runtime/rpc/rpc_session.h +++ b/src/runtime/rpc/rpc_session.h @@ -142,7 +142,7 @@ class RPCSession { /*! * \brief Free a remote function. - * \param handle The remote handle, can be NDArray/PackedFunc/Module + * \param handle The remote handle, can be NDArray/PackedFunc/Module/Object * \param type_code The type code of the underlying type. */ virtual void FreeHandle(void* handle, int type_code) = 0; @@ -287,6 +287,55 @@ struct RemoteSpace { std::shared_ptr sess; }; +/*! + * \brief Object wrapper that represents a reference to a remote object + */ +class RPCObjectRefObj : public Object { + public: + /*! + * \brief constructor + * \param object_handle handle that points to the remote object + * \param sess The remote session + */ + RPCObjectRefObj(void* object_handle, std::shared_ptr sess) + : object_handle_(object_handle), sess_(sess) {} + + ~RPCObjectRefObj() { + if (object_handle_ != nullptr) { + try { + sess_->FreeHandle(object_handle_, kTVMObjectHandle); + } catch (const Error& e) { + // fault tolerance to remote close + } + object_handle_ = nullptr; + } + } + + const std::shared_ptr& sess() const { return sess_; } + + void* object_handle() const { return object_handle_; } + + static constexpr const uint32_t _type_index = TypeIndex::kRuntimeRPCObjectRef; + static constexpr const char* _type_key = "runtime.RPCObjectRef"; + TVM_DECLARE_FINAL_OBJECT_INFO(RPCObjectRefObj, Object); + + private: + // The object handle + void* object_handle_{nullptr}; + // The local channel + std::shared_ptr sess_; +}; + +/*! + * \brief Managed reference to RPCObjectRefObj. + * \sa RPCObjectRefObj + * \note No public constructor is provided as it is not supposed to be directly created by users. + */ +class RPCObjectRef : public ObjectRef { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RPCObjectRef, ObjectRef, RPCObjectRefObj); +}; + /*! * \brief Create a Global RPC module that refers to the session. * \param sess The RPC session of the global module. diff --git a/tests/python/runtime/test_runtime_rpc.py b/tests/python/runtime/test_runtime_rpc.py index 9591e3ea4d60..fff203df0051 100644 --- a/tests/python/runtime/test_runtime_rpc.py +++ b/tests/python/runtime/test_runtime_rpc.py @@ -426,6 +426,7 @@ def test_rpc_return_ndarray(): ref_count = m("ref_count") get_elem = m("get_elem") get_arr_elem = m("get_arr_elem") + # array test def run_arr_test(): arr = get_arr() @@ -435,6 +436,36 @@ def run_arr_test(): run_arr_test() +@tvm.testing.requires_rpc +def test_rpc_return_remote_object(): + def check(client, is_local): + make_shape = client.get_function("runtime.ShapeTuple") + get_elem = client.get_function("runtime.GetShapeTupleElem") + get_size = client.get_function("runtime.GetShapeTupleSize") + shape = make_shape(2, 3) + assert shape.type_key == "runtime.RPCObjectRef" + assert get_elem(shape, 0) == 2 + assert get_elem(shape, 1) == 3 + assert get_size(shape) == 2 + + # start server + server = rpc.Server(key="x1") + client = rpc.connect("127.0.0.1", server.port, key="x1") + check(rpc.LocalSession(), True) + check(client, False) + + def check_minrpc(): + if tvm.get_global_func("rpc.CreatePipeClient", allow_missing=True) is None: + return + # Test minrpc server. + temp = utils.tempdir() + minrpc_exec = temp.relpath("minrpc") + tvm.rpc.with_minrpc(cc.create_executable)(minrpc_exec, []) + check(rpc.PopenSession(minrpc_exec), False) + + check_minrpc() + + @tvm.testing.requires_rpc def test_local_func(): client = rpc.LocalSession()