diff --git a/src/ray/raylet_rpc_client/raylet_client.cc b/src/ray/raylet_rpc_client/raylet_client.cc index baac6411b6f5..d76b49716331 100644 --- a/src/ray/raylet_rpc_client/raylet_client.cc +++ b/src/ray/raylet_rpc_client/raylet_client.cc @@ -47,7 +47,8 @@ RayletClient::RayletClient(const rpc::Address &address, ::RayConfig::instance().raylet_rpc_server_reconnect_timeout_max_s(), /*server_unavailable_timeout_callback=*/ std::move(raylet_unavailable_timeout_callback), - /*server_name=*/std::string("Raylet ") + address.ip_address())) {} + /*server_name=*/std::string("Raylet ") + address.ip_address())), + pins_in_flight_(std::make_shared>(0)) {} void RayletClient::RequestWorkerLease( const rpc::LeaseSpec &lease_spec, @@ -335,11 +336,13 @@ void RayletClient::PinObjectIDs( if (!generator_id.IsNil()) { request.set_generator_id(generator_id.Binary()); } - auto self = shared_from_this(); - pins_in_flight_++; - auto rpc_callback = [self, callback = std::move(callback)]( + + // NOTE: this callback can execute after the RayletClient instance is destroyed, so + // we capture the shared_ptr to `pins_in_flight_` instead of `this`. + pins_in_flight_->fetch_add(1); + auto rpc_callback = [callback, pins_in_flight = pins_in_flight_]( Status status, rpc::PinObjectIDsReply &&reply) { - self->pins_in_flight_--; + pins_in_flight->fetch_sub(1); callback(status, std::move(reply)); }; INVOKE_RETRYABLE_RPC_CALL(retryable_grpc_client_, diff --git a/src/ray/raylet_rpc_client/raylet_client.h b/src/ray/raylet_rpc_client/raylet_client.h index 39f8b97b4a22..07de0073ef02 100644 --- a/src/ray/raylet_rpc_client/raylet_client.h +++ b/src/ray/raylet_rpc_client/raylet_client.h @@ -39,8 +39,7 @@ namespace rpc { /// Raylet client is responsible for communication with raylet. It implements /// [RayletClientInterface] and works on worker registration, lease management, etc. -class RayletClient : public RayletClientInterface, - public std::enable_shared_from_this { +class RayletClient : public RayletClientInterface { public: /// Connect to the raylet. /// @@ -160,7 +159,7 @@ class RayletClient : public RayletClientInterface, const ResourceMappingType &GetResourceIDs() const { return resource_ids_; } - int64_t GetPinsInFlight() const override { return pins_in_flight_.load(); } + int64_t GetPinsInFlight() const override { return pins_in_flight_->load(); } void GetNodeStats(const rpc::GetNodeStatsRequest &request, const rpc::ClientCallback &callback) override; @@ -188,7 +187,9 @@ class RayletClient : public RayletClientInterface, ResourceMappingType resource_ids_; /// The number of object ID pin RPCs currently in flight. - std::atomic pins_in_flight_ = 0; + /// NOTE: `shared_ptr` because it is captured in a callback that can outlive this + /// instance. + std::shared_ptr> pins_in_flight_; }; } // namespace rpc