Skip to content
52 changes: 25 additions & 27 deletions src/ray/object_manager/object_directory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

namespace ray {

ObjectDirectory::ObjectDirectory(std::shared_ptr<gcs::AsyncGcsClient> &gcs_client)
: gcs_client_(gcs_client) {}
/// Construct an object directory.
///
/// \param io_service The event loop that deferred callbacks are posted to.
/// It is stored by reference (io_service_), so it must outlive this object.
/// \param gcs_client The GCS client used for object and client table
/// queries. The shared pointer is copied into gcs_client_.
ObjectDirectory::ObjectDirectory(boost::asio::io_service &io_service,
std::shared_ptr<gcs::AsyncGcsClient> &gcs_client)
: io_service_(io_service), gcs_client_(gcs_client) {}

namespace {

Expand Down Expand Up @@ -61,6 +62,8 @@ void ObjectDirectory::RegisterBackend() {
// empty, since this may indicate that the objects have been evicted from
// all nodes.
for (const auto &callback_pair : callbacks) {
// It is safe to call the callback directly since this is already running
// in the subscription callback stack.
callback_pair.second(client_id_vec, object_id);
}
};
Expand Down Expand Up @@ -102,38 +105,29 @@ ray::Status ObjectDirectory::ReportObjectRemoved(const ObjectID &object_id,
return status;
};

ray::Status ObjectDirectory::GetInformation(const ClientID &client_id,
const InfoSuccessCallback &success_callback,
const InfoFailureCallback &fail_callback) {
const ClientTableDataT &data = gcs_client_->client_table().GetClient(client_id);
void ObjectDirectory::LookupRemoteConnectionInfo(
RemoteConnectionInfo &connection_info) const {
const ClientTableDataT &data =
gcs_client_->client_table().GetClient(connection_info.client_id);
ClientID result_client_id = ClientID::from_binary(data.client_id);
if (result_client_id == ClientID::nil() || !data.is_insertion) {
fail_callback();
} else {
const auto &info =
RemoteConnectionInfo(client_id, data.node_manager_address,
static_cast<uint16_t>(data.object_manager_port));
success_callback(info);
if (result_client_id != ClientID::nil() && data.is_insertion) {
connection_info.ip = data.node_manager_address;
connection_info.port = static_cast<uint16_t>(data.object_manager_port);
}
return ray::Status::OK();
}

void ObjectDirectory::RunFunctionForEachClient(
const InfoSuccessCallback &client_function) {
std::vector<RemoteConnectionInfo> ObjectDirectory::LookupAllRemoteConnections() const {
std::vector<RemoteConnectionInfo> remote_connections;
const auto &clients = gcs_client_->client_table().GetAllClients();
for (const auto &client_pair : clients) {
const ClientTableDataT &data = client_pair.second;
if (client_pair.first == ClientID::nil() ||
client_pair.first == gcs_client_->client_table().GetLocalClientId() ||
!data.is_insertion) {
continue;
} else {
const auto &info =
RemoteConnectionInfo(client_pair.first, data.node_manager_address,
static_cast<uint16_t>(data.object_manager_port));
client_function(info);
RemoteConnectionInfo info(client_pair.first);
LookupRemoteConnectionInfo(info);
if (info.Connected() &&
info.client_id != gcs_client_->client_table().GetLocalClientId()) {
remote_connections.push_back(info);
}
}
return remote_connections;
}

ray::Status ObjectDirectory::SubscribeObjectLocations(const UniqueID &callback_id,
Expand All @@ -156,7 +150,9 @@ ray::Status ObjectDirectory::SubscribeObjectLocations(const UniqueID &callback_i
// have been evicted from all nodes.
std::vector<ClientID> client_id_vec(listener_state.current_object_locations.begin(),
listener_state.current_object_locations.end());
callback(client_id_vec, object_id);
io_service_.post([this, callback, client_id_vec, object_id]() {
callback(client_id_vec, object_id);
});
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the key fix, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, at least for #3201.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you Stephanie. Could you explain what's happening, and how this change fixes #3201?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem is essentially the same one that was discussed when ray.wait was first implemented (see the linked discussion). The callback deletes from a data structure that is shared with the caller, so an iterator held by the caller has already been invalidated by the time the callback returns.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I'm still confused, because I thought we addressed that issue before merging. The loop in SubscribeRemainingWaitObjects iterates over a copy of the object ids (see the linked code), and that copy is not shared with the callback. The callback does not modify the vector of object ids (see the linked code). Could you explain the precise failure scenario?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but the memory referenced by wait_state in that function gets invalidated because the callback deletes the wait_id from active_wait_requests.

Copy link
Contributor

@elibol elibol Nov 7, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand correctly, one failure scenario is as follows:

  1. The final object id's callback is invoked immediately.
  2. The wait_id is removed within the callback (i.e. WaitComplete is invoked).
  3. At this point, the loop exits, so this is not invoked.
  4. This is invoked, which references memory that doesn't exist.

Does that make sense?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes.

return status;
}

Expand Down Expand Up @@ -187,6 +183,8 @@ ray::Status ObjectDirectory::LookupLocations(const ObjectID &object_id,
std::unordered_set<ClientID> client_ids;
std::vector<ClientID> locations_vector = UpdateObjectLocations(
client_ids, location_history, gcs_client_->client_table());
// It is safe to call the callback directly since this is already running
// in the GCS client's lookup callback stack.
callback(locations_vector, object_id);
});
return status;
Expand Down
66 changes: 33 additions & 33 deletions src/ray/object_manager/object_directory.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,35 +16,36 @@ namespace ray {

/// Connection information for a remote object manager.
///
/// An instance may be created with only the client ID set and have its
/// address fields filled in later. Until then `ip` is empty, `port` is 0 and
/// Connected() returns false.
struct RemoteConnectionInfo {
  RemoteConnectionInfo() = default;

  /// Create fully populated connection information.
  ///
  /// \param id The ID of the remote client.
  /// \param ip_address The address of the remote object manager.
  /// \param port_num The port of the remote object manager.
  RemoteConnectionInfo(const ClientID &id, const std::string &ip_address,
                       uint16_t port_num)
      : client_id(id), ip(ip_address), port(port_num) {}

  /// Create connection information that only identifies the client. The
  /// address fields keep their defaults (empty `ip`, `port` 0).
  ///
  /// \param id The ID of the remote client.
  RemoteConnectionInfo(const ClientID &id) : client_id(id) {}

  /// Returns whether there is enough information to connect to the remote
  /// object manager, i.e. whether an IP address has been filled in.
  bool Connected() const { return !ip.empty(); }

  /// The ID of the remote client.
  ClientID client_id;
  /// The address of the remote object manager; empty if unknown.
  std::string ip;
  /// The port of the remote object manager. Defaulted to 0 so that copying an
  /// instance whose address has not been filled in never reads an
  /// indeterminate value.
  uint16_t port = 0;
};

class ObjectDirectoryInterface {
public:
ObjectDirectoryInterface() = default;
virtual ~ObjectDirectoryInterface() = default;

/// Callbacks for GetInformation.
using InfoSuccessCallback = std::function<void(const ray::RemoteConnectionInfo &info)>;
using InfoFailureCallback = std::function<void()>;
virtual ~ObjectDirectoryInterface() {}

virtual void RegisterBackend() = 0;

/// This is used to establish object manager client connections.
/// Lookup how to connect to a remote object manager.
///
/// \param connection_info The connection information to fill out. This
/// should be pre-populated with the requested client ID. If the directory
/// has information about the requested client, then the rest of the fields
/// in this struct will be populated accordingly.
virtual void LookupRemoteConnectionInfo(
RemoteConnectionInfo &connection_info) const = 0;

/// Get information for all connected remote object managers.
///
/// \param client_id The client for which information is required.
/// \param success_cb A callback which handles the success of this method.
/// \param fail_cb A callback which handles the failure of this method.
/// \return Status of whether this asynchronous request succeeded.
virtual ray::Status GetInformation(const ClientID &client_id,
const InfoSuccessCallback &success_cb,
const InfoFailureCallback &fail_cb) = 0;
/// \return A vector of information for all connected remote object managers.
virtual std::vector<RemoteConnectionInfo> LookupAllRemoteConnections() const = 0;

/// Callback for object location notifications.
using OnLocationsFound = std::function<void(const std::vector<ray::ClientID> &,
Expand Down Expand Up @@ -102,28 +103,27 @@ class ObjectDirectoryInterface {
/// \return Status of whether this method succeeded.
virtual ray::Status ReportObjectRemoved(const ObjectID &object_id,
const ClientID &client_id) = 0;

/// Go through all the client information.
///
/// \param success_cb A callback which handles the success of this method.
/// This function will be called multiple times.
/// \return Void.
virtual void RunFunctionForEachClient(const InfoSuccessCallback &client_function) = 0;
};

/// Ray ObjectDirectory declaration.
class ObjectDirectory : public ObjectDirectoryInterface {
public:
ObjectDirectory() = default;
~ObjectDirectory() override = default;
/// Create an object directory.
///
/// \param io_service The event loop to dispatch callbacks to. This should
/// usually be the same event loop that the given gcs_client runs on.
/// \param gcs_client A Ray GCS client to request object and client
/// information from.
ObjectDirectory(boost::asio::io_service &io_service,
std::shared_ptr<gcs::AsyncGcsClient> &gcs_client);

virtual ~ObjectDirectory() {}

void RegisterBackend() override;

ray::Status GetInformation(const ClientID &client_id,
const InfoSuccessCallback &success_callback,
const InfoFailureCallback &fail_callback) override;
void LookupRemoteConnectionInfo(RemoteConnectionInfo &connection_info) const override;

void RunFunctionForEachClient(const InfoSuccessCallback &client_function) override;
std::vector<RemoteConnectionInfo> LookupAllRemoteConnections() const override;

ray::Status LookupLocations(const ObjectID &object_id,
const OnLocationsFound &callback) override;
Expand All @@ -139,8 +139,6 @@ class ObjectDirectory : public ObjectDirectoryInterface {
const object_manager::protocol::ObjectInfoT &object_info) override;
ray::Status ReportObjectRemoved(const ObjectID &object_id,
const ClientID &client_id) override;
/// Ray only (not part of the OD interface).
ObjectDirectory(std::shared_ptr<gcs::AsyncGcsClient> &gcs_client);

/// ObjectDirectory should not be copied.
RAY_DISALLOW_COPY_AND_ASSIGN(ObjectDirectory);
Expand All @@ -154,6 +152,8 @@ class ObjectDirectory : public ObjectDirectoryInterface {
std::unordered_set<ClientID> current_object_locations;
};

/// Reference to the event loop.
boost::asio::io_service &io_service_;
/// Reference to the gcs client.
std::shared_ptr<gcs::AsyncGcsClient> gcs_client_;
/// Info about subscribers to object locations.
Expand Down
Loading