-
Notifications
You must be signed in to change notification settings - Fork 7k
[xray] Implements ray.wait #2162
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 20 commits
0e18ca7
fc2572c
2e8af60
b57a548
a128698
fa5c32d
0ccf46b
b02de4f
f9a9e16
15b7f61
a22263b
8ab41f0
98bacfa
d518a89
53f33e0
8ef35f7
6e10f9e
9a95c65
aa12bd7
9e1602d
304b39c
5d63bb3
cf1fdb2
531d024
d0d3ea4
62ae832
67eef67
0796a17
dd9f0db
9d4ed2b
541b88c
58af739
d9ef29b
fa1928b
d41b1d0
aeaab5b
048f45f
0aa7525
833939f
8e1947c
83d04dd
080282f
a58f5c9
c6d8ba5
7d8d756
6b6e2f3
3a86c93
1a99f25
00eafd7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2540,9 +2540,6 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker): | |
| A list of object IDs that are ready and a list of the remaining object | ||
| IDs. | ||
| """ | ||
| if worker.use_raylet: | ||
| print("plasma_client.wait has not been implemented yet") | ||
| return | ||
|
|
||
| if isinstance(object_ids, ray.ObjectID): | ||
| raise TypeError( | ||
|
|
@@ -2574,18 +2571,32 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker): | |
| if len(object_ids) == 0: | ||
| return [], [] | ||
|
|
||
| object_id_strs = [ | ||
| plasma.ObjectID(object_id.id()) for object_id in object_ids | ||
| ] | ||
| timeout = timeout if timeout is not None else 2**30 | ||
| ready_ids, remaining_ids = worker.plasma_client.wait( | ||
| object_id_strs, timeout, num_returns) | ||
| ready_ids = [ | ||
| ray.ObjectID(object_id.binary()) for object_id in ready_ids | ||
| ] | ||
| remaining_ids = [ | ||
| ray.ObjectID(object_id.binary()) for object_id in remaining_ids | ||
| ] | ||
| if worker.use_raylet: | ||
| if len(object_ids) != len(set(object_ids)): | ||
| raise Exception("Wait requires a list of unique object IDs.") | ||
| if len(object_ids) <= 0: | ||
| raise Exception("Invalid number of objects %d." % len(object_ids)) | ||
| if num_returns <= 0: | ||
| raise Exception("Invalid number of objects to return %d." % num_returns) | ||
| if num_returns > len(object_ids): | ||
| raise Exception("num_returns cannot be greater than the number " | ||
| "of objects provided to ray.wait.") | ||
|
||
| timeout = timeout if timeout is not None else 2**30 | ||
|
||
| ready_ids, remaining_ids = worker.local_scheduler_client.wait( | ||
| object_ids, num_returns, timeout, False) | ||
| else: | ||
| object_id_strs = [ | ||
| plasma.ObjectID(object_id.id()) for object_id in object_ids | ||
| ] | ||
| timeout = timeout if timeout is not None else 2**30 | ||
|
||
| ready_ids, remaining_ids = worker.plasma_client.wait( | ||
| object_id_strs, timeout, num_returns) | ||
| ready_ids = [ | ||
| ray.ObjectID(object_id.binary()) for object_id in ready_ids | ||
| ] | ||
| remaining_ids = [ | ||
| ray.ObjectID(object_id.binary()) for object_id in remaining_ids | ||
| ] | ||
| return ready_ids, remaining_ids | ||
|
|
||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -179,6 +179,53 @@ static PyObject *PyLocalSchedulerClient_set_actor_frontier(PyObject *self, | |
| Py_RETURN_NONE; | ||
| } | ||
|
|
||
| static PyObject *PyLocalSchedulerClient_wait(PyObject *self, PyObject *args) { | ||
| PyObject *py_object_ids; | ||
| int num_returns; | ||
| int64_t timeout_ms; | ||
| int wait_local; | ||
|
|
||
| if (!PyArg_ParseTuple(args, "Oili", &py_object_ids, &num_returns, &timeout_ms, | ||
| &wait_local)) { | ||
|
||
| return NULL; | ||
| } | ||
| // Convert object ids. | ||
| PyObject *iter = PyObject_GetIter(py_object_ids); | ||
| if (!iter) { | ||
| return NULL; | ||
| } | ||
| std::vector<ObjectID> object_ids; | ||
| while (true) { | ||
| PyObject *next = PyIter_Next(iter); | ||
| ObjectID object_id; | ||
| if (!next) { | ||
| break; | ||
| } | ||
| if (!PyObjectToUniqueID(next, &object_id)) { | ||
| // Error parsing object id. | ||
| return NULL; | ||
| } | ||
| object_ids.push_back(object_id); | ||
| } | ||
|
|
||
| // Invoke wait. | ||
| std::pair<std::vector<ObjectID>, std::vector<ObjectID>> result = | ||
| local_scheduler_wait( | ||
| ((PyLocalSchedulerClient *) self)->local_scheduler_connection, | ||
|
||
| object_ids, num_returns, timeout_ms, static_cast<bool>(wait_local)); | ||
|
|
||
| // Convert result to py object. | ||
| PyObject *py_found = PyList_New((Py_ssize_t) result.first.size()); | ||
|
||
| for (uint i = 0; i < result.first.size(); ++i) { | ||
| PyList_SetItem(py_found, i, PyObjectID_make(result.first[i])); | ||
| } | ||
| PyObject *py_remaining = PyList_New((Py_ssize_t) result.second.size()); | ||
|
||
| for (uint i = 0; i < result.second.size(); ++i) { | ||
| PyList_SetItem(py_remaining, i, PyObjectID_make(result.second[i])); | ||
| } | ||
| return Py_BuildValue("(OO)", py_found, py_remaining); | ||
| } | ||
|
|
||
| static PyMethodDef PyLocalSchedulerClient_methods[] = { | ||
| {"disconnect", (PyCFunction) PyLocalSchedulerClient_disconnect, METH_NOARGS, | ||
| "Notify the local scheduler that this client is exiting gracefully."}, | ||
|
|
@@ -201,6 +248,8 @@ static PyMethodDef PyLocalSchedulerClient_methods[] = { | |
| (PyCFunction) PyLocalSchedulerClient_get_actor_frontier, METH_VARARGS, ""}, | ||
| {"set_actor_frontier", | ||
| (PyCFunction) PyLocalSchedulerClient_set_actor_frontier, METH_VARARGS, ""}, | ||
| {"wait", (PyCFunction) PyLocalSchedulerClient_wait, METH_VARARGS, | ||
| "Wait for a list of objects to be created."}, | ||
| {NULL} /* Sentinel */ | ||
| }; | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,6 +2,7 @@ | |
|
|
||
| #include "common_protocol.h" | ||
| #include "format/local_scheduler_generated.h" | ||
| #include "ray/raylet/format/node_manager_generated.h" | ||
|
|
||
| #include "common/io.h" | ||
| #include "common/task.h" | ||
|
|
@@ -186,3 +187,37 @@ void local_scheduler_set_actor_frontier(LocalSchedulerConnection *conn, | |
| write_message(conn->conn, MessageType_SetActorFrontier, frontier.size(), | ||
| const_cast<uint8_t *>(frontier.data())); | ||
| } | ||
|
|
||
| std::pair<std::vector<ObjectID>, std::vector<ObjectID>> local_scheduler_wait( | ||
| LocalSchedulerConnection *conn, | ||
| const std::vector<ObjectID> &object_ids, | ||
| int num_returns, | ||
| int64_t timeout, | ||
| bool wait_local) { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. does
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No; I just added the piping for the argument. The back-end returns an unimplemented status if it's set to true. I have an idea of how to implement this whenever we'd like to add it. |
||
| // Write request. | ||
| flatbuffers::FlatBufferBuilder fbb; | ||
| auto message = ray::protocol::CreateWaitRequest( | ||
| fbb, to_flatbuf(fbb, object_ids), num_returns, timeout, wait_local); | ||
| fbb.Finish(message); | ||
| write_message(conn->conn, ray::protocol::MessageType_WaitRequest, | ||
| fbb.GetSize(), fbb.GetBufferPointer()); | ||
| // Read result. | ||
| int64_t type; | ||
| int64_t reply_size; | ||
| uint8_t *reply; | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a memory leak, right?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added |
||
| read_message(conn->conn, &type, &reply_size, &reply); | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add RAY_CHECK(type == MessageType_WaitReply); |
||
| auto reply_message = flatbuffers::GetRoot<ray::protocol::WaitReply>(reply); | ||
| // Convert result. | ||
| std::pair<std::vector<ObjectID>, std::vector<ObjectID>> result; | ||
| auto found = reply_message->found(); | ||
| for (uint i = 0; i < found->size(); i++) { | ||
| ObjectID object_id = ObjectID::from_binary(found->Get(i)->str()); | ||
| result.first.push_back(object_id); | ||
| } | ||
| auto remaining = reply_message->remaining(); | ||
| for (uint i = 0; i < remaining->size(); i++) { | ||
| ObjectID object_id = ObjectID::from_binary(remaining->Get(i)->str()); | ||
| result.second.push_back(object_id); | ||
| } | ||
| return result; | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -169,4 +169,22 @@ const std::vector<uint8_t> local_scheduler_get_actor_frontier( | |
| void local_scheduler_set_actor_frontier(LocalSchedulerConnection *conn, | ||
| const std::vector<uint8_t> &frontier); | ||
|
|
||
| /// Wait for the given objects until timeout expires or num_return objects are | ||
| /// found. | ||
| /// | ||
| /// \param conn The connection information. | ||
| /// \param object_ids The objects to wait for. | ||
| /// \param num_returns The number of objects to wait for. | ||
| /// \param timeout The duration to wait before returning. | ||
|
||
| /// \param wait_local Whether to wait for objects to appear on this node. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This argument doesn't seem to match the current semantics for
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| /// \return A pair with the first element containing the object ids that were | ||
| /// found, | ||
| /// and the second element the objects that were not found. | ||
|
||
| std::pair<std::vector<ObjectID>, std::vector<ObjectID>> local_scheduler_wait( | ||
| LocalSchedulerConnection *conn, | ||
| const std::vector<ObjectID> &object_ids, | ||
| int num_returns, | ||
| int64_t timeout, | ||
| bool wait_local); | ||
|
|
||
| #endif | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,35 +7,47 @@ ObjectDirectory::ObjectDirectory(std::shared_ptr<gcs::AsyncGcsClient> &gcs_clien | |
| } | ||
|
|
||
| void ObjectDirectory::RegisterBackend() { | ||
| auto object_notification_callback = [this](gcs::AsyncGcsClient *client, | ||
| const ObjectID &object_id, | ||
| const std::vector<ObjectTableDataT> &data) { | ||
| auto object_notification_callback = [this]( | ||
| gcs::AsyncGcsClient *client, const ObjectID &object_id, | ||
| const std::vector<ObjectTableDataT> &object_location_ids) { | ||
| // Objects are added to this map in SubscribeObjectLocations. | ||
| auto entry = listeners_.find(object_id); | ||
| auto object_id_listener_pair = listeners_.find(object_id); | ||
| // Do nothing for objects we are not listening for. | ||
| if (entry == listeners_.end()) { | ||
| if (object_id_listener_pair == listeners_.end()) { | ||
| return; | ||
| } | ||
| // Update entries for this object. | ||
| auto client_id_set = entry->second.client_ids; | ||
| for (auto &object_table_data : data) { | ||
| auto &location_client_id_set = object_id_listener_pair->second.location_client_ids; | ||
| // object_location_ids contains the history of locations of the object (it is a log), | ||
| // which might look like the following: | ||
| // client1.is_eviction = false | ||
| // client1.is_eviction = true | ||
| // client2.is_eviction = false | ||
| // In such a scenario, we want to indicate client2 is the only client that contains | ||
| // the object, which the following code achieves. | ||
| for (const auto &object_table_data : object_location_ids) { | ||
| ClientID client_id = ClientID::from_binary(object_table_data.manager); | ||
| if (!object_table_data.is_eviction) { | ||
| client_id_set.insert(client_id); | ||
| location_client_id_set.insert(client_id); | ||
| } else { | ||
| client_id_set.erase(client_id); | ||
| location_client_id_set.erase(client_id); | ||
| } | ||
| } | ||
| if (!client_id_set.empty()) { | ||
| // Only call the callback if we have object locations. | ||
| std::vector<ClientID> client_id_vec(client_id_set.begin(), client_id_set.end()); | ||
| auto callback = entry->second.locations_found_callback; | ||
| callback(client_id_vec, object_id); | ||
| if (!location_client_id_set.empty()) { | ||
| std::vector<ClientID> client_id_vec(location_client_id_set.begin(), | ||
| location_client_id_set.end()); | ||
| // Copy the callbacks so that the callbacks can unsubscribe without interrupting | ||
| // looping over the callbacks. | ||
| auto callbacks = object_id_listener_pair->second.callbacks; | ||
| // Call all callbacks associated with the object id locations we have received. | ||
| for (const auto &callback_pair : callbacks) { | ||
| callback_pair.second(client_id_vec, object_id); | ||
| } | ||
| } | ||
| }; | ||
| RAY_CHECK_OK(gcs_client_->object_table().Subscribe( | ||
| UniqueID::nil(), gcs_client_->client_table().GetLocalClientId(), | ||
| object_notification_callback, nullptr)); | ||
| RAY_CHECK_OK(gcs_client_->object_table().Subscribe(UniqueID::nil(), | ||
| gcs_client_->client_table().GetLocalClientId(), | ||
| object_notification_callback, nullptr)); | ||
| } | ||
|
|
||
| ray::Status ObjectDirectory::ReportObjectAdded(const ObjectID &object_id, | ||
|
|
@@ -86,25 +98,65 @@ ray::Status ObjectDirectory::GetInformation(const ClientID &client_id, | |
| return ray::Status::OK(); | ||
| } | ||
|
|
||
| ray::Status ObjectDirectory::SubscribeObjectLocations(const ObjectID &object_id, | ||
| ray::Status ObjectDirectory::SubscribeObjectLocations(const UniqueID &callback_id, | ||
| const ObjectID &object_id, | ||
| const OnLocationsFound &callback) { | ||
| if (listeners_.find(object_id) != listeners_.end()) { | ||
| RAY_LOG(ERROR) << "Duplicate calls to SubscribeObjectLocations for " << object_id; | ||
| ray::Status status = ray::Status::OK(); | ||
| if (listeners_.find(object_id) == listeners_.end()) { | ||
| listeners_.emplace(object_id, LocationListenerState()); | ||
| status = gcs_client_->object_table().RequestNotifications( | ||
| JobID::nil(), object_id, gcs_client_->client_table().GetLocalClientId()); | ||
| } | ||
| if (listeners_[object_id].callbacks.count(callback_id) > 0) { | ||
|
||
| return ray::Status::OK(); | ||
| } | ||
| listeners_.emplace(object_id, LocationListenerState(callback)); | ||
| return gcs_client_->object_table().RequestNotifications( | ||
| JobID::nil(), object_id, gcs_client_->client_table().GetLocalClientId()); | ||
| listeners_[object_id].callbacks.emplace(callback_id, callback); | ||
| // Immediately notify of found object locations. | ||
| if (!listeners_[object_id].location_client_ids.empty()) { | ||
| std::vector<ClientID> client_id_vec(listeners_[object_id].location_client_ids.begin(), | ||
| listeners_[object_id].location_client_ids.end()); | ||
| callback(client_id_vec, object_id); | ||
| } | ||
| return status; | ||
| } | ||
|
|
||
| ray::Status ObjectDirectory::UnsubscribeObjectLocations(const ObjectID &object_id) { | ||
| ray::Status ObjectDirectory::UnsubscribeObjectLocations(const UniqueID &callback_id, | ||
| const ObjectID &object_id) { | ||
| ray::Status status = ray::Status::OK(); | ||
| auto entry = listeners_.find(object_id); | ||
| if (entry == listeners_.end()) { | ||
| return ray::Status::OK(); | ||
| return status; | ||
| } | ||
| ray::Status status = gcs_client_->object_table().CancelNotifications( | ||
| JobID::nil(), object_id, gcs_client_->client_table().GetLocalClientId()); | ||
| listeners_.erase(entry); | ||
| entry->second.callbacks.erase(callback_id); | ||
| if (entry->second.callbacks.empty()) { | ||
| status = gcs_client_->object_table().CancelNotifications( | ||
| JobID::nil(), object_id, gcs_client_->client_table().GetLocalClientId()); | ||
| listeners_.erase(entry); | ||
| } | ||
| return status; | ||
| } | ||
|
|
||
| ray::Status ObjectDirectory::LookupLocations(const ObjectID &object_id, | ||
| const OnLocationsFound &callback) { | ||
| JobID job_id = JobID::nil(); | ||
| ray::Status status = gcs_client_->object_table().Lookup( | ||
| job_id, object_id, | ||
| [this, callback](gcs::AsyncGcsClient *client, const ObjectID &object_id, | ||
| const std::vector<ObjectTableDataT> &location_entries) { | ||
| // Build the set of current locations based on the entries in the log. | ||
| std::unordered_set<ClientID> locations; | ||
| for (auto entry : location_entries) { | ||
| ClientID client_id = ClientID::from_binary(entry.manager); | ||
| if (!entry.is_eviction) { | ||
| locations.insert(client_id); | ||
| } else { | ||
| locations.erase(client_id); | ||
| } | ||
| } | ||
| // Invoke the callback. | ||
| std::vector<ClientID> locations_vector(locations.begin(), locations.end()); | ||
| callback(locations_vector, object_id); | ||
| }); | ||
| return status; | ||
| } | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When would
len(object_ids) < 0ever be true?Is the
len(object_ids) == 0)case not handled by the backend?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ah, it should be handled by the backend, but there's a bug and it was quicker to patch it here
#1969
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
on the topic of bugs, if you're wondering about the
2**30, that's #58