diff --git a/include/mp/proxy-io.h b/include/mp/proxy-io.h
index 25a63354..720f2da5 100644
--- a/include/mp/proxy-io.h
+++ b/include/mp/proxy-io.h
@@ -44,7 +44,7 @@ struct ServerInvokeContext : InvokeContext
     int req;

    ServerInvokeContext(ProxyServer& proxy_server, CallContext& call_context, int req)
-        : InvokeContext{*proxy_server.m_connection}, proxy_server{proxy_server}, call_context{call_context}, req{req}
+        : InvokeContext{proxy_server.m_connection}, proxy_server{proxy_server}, call_context{call_context}, req{req}
    {
    }
};
@@ -207,9 +207,6 @@ class EventLoop
     LogFn m_log_fn;
};

-void AddClient(EventLoop& loop);
-void RemoveClient(EventLoop& loop);
-
//! Single element task queue used to handle recursive capnp calls. (If server
//! makes an callback into the client in the middle of a request, while client
//! thread is blocked waiting for server response, this is what allows the
@@ -263,15 +260,13 @@ struct Waiter
class Connection
{
public:
-    Connection(EventLoop& loop, kj::Own<kj::AsyncIoStream>&& stream_, bool add_client)
+    Connection(EventLoop& loop, kj::Own<kj::AsyncIoStream>&& stream_)
        : m_loop(loop), m_stream(kj::mv(stream_)),
          m_network(*m_stream, ::capnp::rpc::twoparty::Side::CLIENT, ::capnp::ReaderOptions()),
          m_rpc_system(::capnp::makeRpcClient(m_network))
    {
-        if (add_client) {
-            std::unique_lock<std::mutex> lock(m_loop.m_mutex);
-            m_loop.addClient(lock);
-        }
+        std::unique_lock<std::mutex> lock(m_loop.m_mutex);
+        m_loop.addClient(lock);
    }

    Connection(EventLoop& loop, kj::Own<kj::AsyncIoStream>&& stream_,
@@ -381,7 +376,7 @@ ProxyClientBase<Interface, Impl>::~ProxyClientBase() noexcept
    // down while external code is still holding client references.
    //
    // The first case is handled here in destructor when m_loop is not null. The
-    // second case is handled by the m_cleanup function, which sets m_loop to
+    // second case is handled by the m_cleanup function, which sets m_connection to
    // null so nothing happens here.
    if (m_connection) {
        // Remove m_cleanup callback so it doesn't run and try to access
@@ -412,10 +407,11 @@

template <typename Interface, typename Impl>
ProxyServerBase<Interface, Impl>::ProxyServerBase(Impl* impl, bool owned, Connection& connection)
-    : m_impl(impl), m_owned(owned), m_connection(&connection)
+    : m_impl(impl), m_owned(owned), m_connection(connection)
{
    assert(impl != nullptr);
-    AddClient(connection.m_loop);
+    std::unique_lock<std::mutex> lock(m_connection.m_loop.m_mutex);
+    m_connection.m_loop.addClient(lock);
}

template <typename Interface, typename Impl>
@@ -428,12 +424,13 @@ ProxyServerBase<Interface, Impl>::~ProxyServerBase()
        // (event loop) thread since destructors could be making IPC calls or
        // doing expensive cleanup.
        if (m_owned) {
-            m_connection->addAsyncCleanup([impl] { delete impl; });
+            m_connection.addAsyncCleanup([impl] { delete impl; });
        }
        m_impl = nullptr;
        m_owned = false;
    }
-    RemoveClient(m_connection->m_loop); // FIXME: Broken when connection is null?
+    std::unique_lock<std::mutex> lock(m_connection.m_loop.m_mutex);
+    m_connection.m_loop.removeClient(lock);
}

template <typename Interface, typename Impl>
@@ -479,14 +476,14 @@ struct ThreadContext
//! over the stream. Also create a new Connection object embedded in the
//! client that is freed when the client is closed.
template <typename InitInterface>
-std::unique_ptr<ProxyClient<InitInterface>> ConnectStream(EventLoop& loop, int fd, bool add_client)
+std::unique_ptr<ProxyClient<InitInterface>> ConnectStream(EventLoop& loop, int fd)
{
    typename InitInterface::Client init_client(nullptr);
    std::unique_ptr<Connection> connection;
    loop.sync([&] {
        auto stream =
            loop.m_io_context.lowLevelProvider->wrapSocketFd(fd, kj::LowLevelAsyncIoProvider::TAKE_OWNERSHIP);
-        connection = std::make_unique<Connection>(loop, kj::mv(stream), add_client);
+        connection = std::make_unique<Connection>(loop, kj::mv(stream));
        init_client = connection->m_rpc_system.bootstrap(ServerVatId().vat_id).castAs<InitInterface>();
        Connection* connection_ptr = connection.get();
        connection->onDisconnect([&loop, connection_ptr] {
@@ -507,9 +504,6 @@ void ServeStream(EventLoop& loop,
    kj::Own<kj::AsyncIoStream>&& stream,
    std::function<capnp::Capability::Client(Connection&)> make_server);

-//! Same as above but accept file descriptor rather than stream object.
-void ServeStream(EventLoop& loop, int fd, std::function<capnp::Capability::Client(Connection&)> make_server);
-
extern thread_local ThreadContext g_thread_context;

} // namespace mp
diff --git a/include/mp/proxy-types.h b/include/mp/proxy-types.h
index fd7532b9..e4f58358 100644
--- a/include/mp/proxy-types.h
+++ b/include/mp/proxy-types.h
@@ -122,12 +122,12 @@ auto PassField(TypeList<>, ServerContext& server_context, const Fn& fn, const Ar
            ServerContext server_context{server, call_context, req};
            {
                auto& request_threads = g_thread_context.request_threads;
-                auto request_thread = request_threads.find(server.m_connection);
+                auto request_thread = request_threads.find(&server.m_connection);
                if (request_thread == request_threads.end()) {
                    request_thread =
                        g_thread_context.request_threads
-                            .emplace(std::piecewise_construct, std::forward_as_tuple(server.m_connection),
-                                std::forward_as_tuple(context_arg.getCallbackThread(), server.m_connection,
+                            .emplace(std::piecewise_construct, std::forward_as_tuple(&server.m_connection),
+                                std::forward_as_tuple(context_arg.getCallbackThread(), &server.m_connection,
                                    /* destroy_connection= */ false))
                            .first;
                } else {
@@ -139,13 +139,13 @@ auto PassField(TypeList<>, ServerContext& server_context, const Fn& fn, const Ar
                fn.invoke(server_context, args...);
            }
            KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() {
-                server.m_connection->m_loop.sync([&] {
+                server.m_connection.m_loop.sync([&] {
                    auto fulfiller_dispose = kj::mv(fulfiller);
                    fulfiller_dispose->fulfill(kj::mv(call_context));
                });
            }))
            {
-                server.m_connection->m_loop.sync([&]() {
+                server.m_connection.m_loop.sync([&]() {
                    auto fulfiller_dispose = kj::mv(fulfiller);
                    fulfiller_dispose->reject(kj::mv(*exception));
                });
            }
@@ -153,19 +153,19 @@ auto PassField(TypeList<>, ServerContext& server_context, const Fn& fn, const Ar
    })));

    auto thread_client = context_arg.getThread();
-    return JoinPromises(server.m_connection->m_threads.getLocalServer(thread_client)
+    return JoinPromises(server.m_connection.m_threads.getLocalServer(thread_client)
        .then([&server, invoke, req](kj::Maybe<Thread::Server&> perhaps) {
            KJ_IF_MAYBE(thread_server, perhaps)
            {
                const auto& thread = static_cast<ProxyServer<Thread>&>(*thread_server);
-                server.m_connection->m_loop.log() << "IPC server post request #" << req << " {"
-                                                  << thread.m_thread_context.thread_name << "}";
+                server.m_connection.m_loop.log() << "IPC server post request #" << req << " {"
+                                                 << thread.m_thread_context.thread_name << "}";
                thread.m_thread_context.waiter->post(std::move(invoke));
            }
            else
            {
-                server.m_connection->m_loop.log() << "IPC server error request #" << req
-                                                  << ", missing thread to execute request";
+                server.m_connection.m_loop.log() << "IPC server error request #" << req
+                                                 << ", missing thread to execute request";
                throw std::runtime_error("invalid thread handle");
            }
        }),
@@ -1327,7 +1327,7 @@ void clientDestroy(Client& client)
template <typename Server>
void serverDestroy(Server& server)
{
-    server.m_connection->m_loop.log() << "IPC server destroy" << typeid(server).name();
+    server.m_connection.m_loop.log() << "IPC server destroy" << typeid(server).name();
}

template <typename ProxyClient, typename GetRequest, typename... FieldObjs>
@@ -1418,8 +1418,8 @@ kj::Promise<void> serverInvoke(Server& server, CallContext& call_context, Fn fn)
    using Results = typename decltype(call_context.getResults())::Builds;

    int req = ++server_reqs;
-    server.m_connection->m_loop.log() << "IPC server recv request #" << req << " "
-                                      << TypeName() << " " << LogEscape(params.toString());
+    server.m_connection.m_loop.log() << "IPC server recv request #" << req << " "
+                                     << TypeName() << " " << LogEscape(params.toString());

    try {
        using ServerContext = ServerInvokeContext<Server, CallContext>;
@@ -1428,11 +1428,11 @@ kj::Promise<void> serverInvoke(Server& server, CallContext& call_context, Fn fn)
        return ReplaceVoid([&]() { return fn.invoke(server_context, ArgList()); },
            [&]() { return kj::Promise<CallContext>(kj::mv(call_context)); })
            .then([&server, req](CallContext call_context) {
-                server.m_connection->m_loop.log() << "IPC server send response #" << req << " " << TypeName()
-                                                  << " " << LogEscape(call_context.getResults().toString());
+                server.m_connection.m_loop.log() << "IPC server send response #" << req << " " << TypeName()
+                                                 << " " << LogEscape(call_context.getResults().toString());
            });
    } catch (...) {
-        server.m_connection->m_loop.log()
+        server.m_connection.m_loop.log()
            << "IPC server unhandled exception " << boost::current_exception_diagnostic_information();
        throw;
    }
diff --git a/include/mp/proxy.h b/include/mp/proxy.h
index 62f267ca..16ae5f89 100644
--- a/include/mp/proxy.h
+++ b/include/mp/proxy.h
@@ -101,11 +101,7 @@ struct ProxyServerBase : public virtual Interface_::Server
     * appropriate times depending on semantics of the particular method being
     * wrapped. */
    bool m_owned;
-    /**
-     * Connection is a pointer rather than a reference because for the Init
-     * server, the server object needs to be created before the connection.
-     */
-    Connection* m_connection;
+    Connection& m_connection;
};

//! Customizable (through template specialization) base class used in generated ProxyServer implementations from
diff --git a/src/mp/proxy.cpp b/src/mp/proxy.cpp
index 94a71a34..9207550b 100644
--- a/src/mp/proxy.cpp
+++ b/src/mp/proxy.cpp
@@ -236,17 +236,6 @@ void EventLoop::startAsyncThread(std::unique_lock<std::mutex>& lock)
    }
}

-void AddClient(EventLoop& loop)
-{
-    std::unique_lock<std::mutex> lock(loop.m_mutex);
-    loop.addClient(lock);
-}
-void RemoveClient(EventLoop& loop)
-{
-    std::unique_lock<std::mutex> lock(loop.m_mutex);
-    loop.removeClient(lock);
-}
-
ProxyServer<Thread>::ProxyServer(ThreadContext& thread_context, std::thread&& thread)
    : m_thread_context(thread_context), m_thread(std::move(thread))
{
diff --git a/src/mp/test/test.cpp b/src/mp/test/test.cpp
index 3a7d34cc..409cfdb7 100644
--- a/src/mp/test/test.cpp
+++ b/src/mp/test/test.cpp
@@ -23,7 +23,7 @@ KJ_TEST("Call FooInterface methods")
    EventLoop loop("mptest", [](bool raise, const std::string& log) {});
    auto pipe = loop.m_io_context.provider->newTwoWayPipe();

-    auto connection_client = std::make_unique<Connection>(loop, kj::mv(pipe.ends[0]), true);
+    auto connection_client = std::make_unique<Connection>(loop, kj::mv(pipe.ends[0]));
    auto foo_client = std::make_unique<ProxyClient<messages::FooInterface>>(
        connection_client->m_rpc_system.bootstrap(ServerVatId().vat_id).castAs<messages::FooInterface>(),
        connection_client.get(), /* destroy_connection= */ false);