Merged
49 commits
0e18ca7
Use pubsub instead of timeout.
elibol May 16, 2018
fc2572c
Correct status message.
elibol May 17, 2018
2e8af60
eric's feedback!
elibol May 17, 2018
b57a548
Changes from Stephanie's review.
elibol May 17, 2018
a128698
object directory changes for ray.wait.
elibol May 18, 2018
fa5c32d
Merge branch 'master' into om_pubsub
elibol May 18, 2018
0ccf46b
Merge branch 'master' into om_pubsub
elibol May 18, 2018
b02de4f
Merge branch 'om_pubsub' into om_wait
elibol May 18, 2018
f9a9e16
wait without testing or timeout=0.
elibol May 18, 2018
15b7f61
Handle remaining cases for wait.
elibol May 18, 2018
a22263b
linting
elibol May 18, 2018
8ab41f0
added tests of om wait imp.
elibol May 18, 2018
98bacfa
add local test.
elibol May 18, 2018
d518a89
Merge branch 'master' into om_wait
elibol May 21, 2018
53f33e0
plasma imp.
elibol May 24, 2018
8ef35f7
block worker as with pull.
elibol May 29, 2018
6e10f9e
local scheduler implementation of wait.
elibol May 30, 2018
9a95c65
with passing tests.
elibol May 30, 2018
aa12bd7
minor adjustments.
elibol May 30, 2018
9e1602d
Merge branch 'master' into om_wait_local_scheduler
elibol May 30, 2018
304b39c
handle return statuses.
elibol May 30, 2018
5d63bb3
enable more tests.
elibol May 30, 2018
cf1fdb2
add test for existing num_returns semantics, and maintain existing nu…
elibol May 31, 2018
531d024
move error handling to both code paths.
elibol May 31, 2018
d0d3ea4
implementing another round of feedback.
elibol May 31, 2018
62ae832
Comment on OM tests.
elibol May 31, 2018
67eef67
remove check for length zero list.
elibol Jun 1, 2018
0796a17
remove elapsed.
elibol Jun 1, 2018
dd9f0db
Preserve input/output order.
elibol Jun 1, 2018
9d4ed2b
debias local objects.
elibol Jun 1, 2018
541b88c
Merge branch 'master' into om_wait_local_scheduler
elibol Jun 1, 2018
58af739
use common helper function in object directory.
elibol Jun 1, 2018
d9ef29b
updated documentation
elibol Jun 1, 2018
fa1928b
linting.
elibol Jun 1, 2018
d41b1d0
handle return status.
elibol Jun 1, 2018
aeaab5b
simplify order preservation test + fix valgrind test error.
elibol Jun 1, 2018
048f45f
update name of final Lookup callback.
elibol Jun 2, 2018
0aa7525
Merge branch 'master' into om_wait_local_scheduler
elibol Jun 2, 2018
833939f
linting
elibol Jun 2, 2018
8e1947c
c++ style casting.
elibol Jun 2, 2018
83d04dd
linting.
elibol Jun 2, 2018
080282f
linting.
elibol Jun 2, 2018
a58f5c9
incorporate second round of feedback.
elibol Jun 5, 2018
c6d8ba5
correct python tests.
elibol Jun 5, 2018
7d8d756
test comments.
elibol Jun 5, 2018
6b6e2f3
incorporate reviews.
elibol Jun 6, 2018
3a86c93
Fixes with regression tests.
elibol Jun 6, 2018
1a99f25
update documentation.
elibol Jun 6, 2018
00eafd7
reference to avoid copy.
elibol Jun 6, 2018
42 changes: 28 additions & 14 deletions python/ray/worker.py
@@ -2529,6 +2529,11 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
correspond to objects that are stored in the object store. The second list
corresponds to the rest of the object IDs (which may or may not be ready).

Ordering of the input list of object IDs is preserved: if A precedes B in
the input list, and both are in the ready list, then A will precede B in
the ready list. This also holds true if A and B are both in the remaining
list.

Args:
object_ids (List[ObjectID]): List of object IDs for objects that may or
may not be ready. Note that these IDs must be unique.
@@ -2540,9 +2545,6 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
A list of object IDs that are ready and a list of the remaining object
IDs.
"""
if worker.use_raylet:
print("plasma_client.wait has not been implemented yet")
return

if isinstance(object_ids, ray.ObjectID):
raise TypeError(
@@ -2574,18 +2576,30 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
if len(object_ids) == 0:
return [], []

object_id_strs = [
plasma.ObjectID(object_id.id()) for object_id in object_ids
]
if len(object_ids) != len(set(object_ids)):
raise Exception("Wait requires a list of unique object IDs.")
if num_returns <= 0:
raise Exception(
"Invalid number of objects to return %d." % num_returns)
if num_returns > len(object_ids):
raise Exception("num_returns cannot be greater than the number "
"of objects provided to ray.wait.")
timeout = timeout if timeout is not None else 2**30
ready_ids, remaining_ids = worker.plasma_client.wait(
object_id_strs, timeout, num_returns)
ready_ids = [
ray.ObjectID(object_id.binary()) for object_id in ready_ids
]
remaining_ids = [
ray.ObjectID(object_id.binary()) for object_id in remaining_ids
]
if worker.use_raylet:
ready_ids, remaining_ids = worker.local_scheduler_client.wait(
object_ids, num_returns, timeout, False)
else:
object_id_strs = [
plasma.ObjectID(object_id.id()) for object_id in object_ids
]
ready_ids, remaining_ids = worker.plasma_client.wait(
object_id_strs, timeout, num_returns)
ready_ids = [
ray.ObjectID(object_id.binary()) for object_id in ready_ids
]
remaining_ids = [
ray.ObjectID(object_id.binary()) for object_id in remaining_ids
]
return ready_ids, remaining_ids


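To illustrate the semantics documented above — input order preserved in both returned lists, num_returns capped at the number of inputs, and timeout given in milliseconds — here is a minimal usage sketch. It is not part of the diff; the remote function and the sleep durations are made up for illustration.

import time
import ray

ray.init()

@ray.remote
def delayed(i):
    # Hypothetical task: larger i finishes later.
    time.sleep(i)
    return i

object_ids = [delayed.remote(i) for i in range(4)]

# Wait for up to two ready objects, or until 5000 ms have passed.
ready, remaining = ray.wait(object_ids, num_returns=2, timeout=5000)

# With these sleeps, ready will normally be [object_ids[0], object_ids[1]] and
# remaining [object_ids[2], object_ids[3]]; both lists preserve the input order.
print(ready, remaining)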
54 changes: 54 additions & 0 deletions src/local_scheduler/lib/python/local_scheduler_extension.cc
@@ -179,6 +179,58 @@ static PyObject *PyLocalSchedulerClient_set_actor_frontier(PyObject *self,
Py_RETURN_NONE;
}

static PyObject *PyLocalSchedulerClient_wait(PyObject *self, PyObject *args) {
PyObject *py_object_ids;
int num_returns;
int64_t timeout_ms;
PyObject *py_wait_local;

if (!PyArg_ParseTuple(args, "OilO", &py_object_ids, &num_returns, &timeout_ms,
&py_wait_local)) {
return NULL;
}

bool wait_local = PyObject_IsTrue(py_wait_local);

// Convert object ids.
PyObject *iter = PyObject_GetIter(py_object_ids);
if (!iter) {
return NULL;
}
std::vector<ObjectID> object_ids;
while (true) {
PyObject *next = PyIter_Next(iter);
ObjectID object_id;
if (!next) {
break;
}
if (!PyObjectToUniqueID(next, &object_id)) {
// Error parsing object id.
return NULL;
}
object_ids.push_back(object_id);
}

// Invoke wait.
std::pair<std::vector<ObjectID>, std::vector<ObjectID>> result =
local_scheduler_wait(reinterpret_cast<PyLocalSchedulerClient *>(self)
->local_scheduler_connection,
object_ids, num_returns, timeout_ms,
static_cast<bool>(wait_local));

// Convert result to py object.
PyObject *py_found = PyList_New(static_cast<Py_ssize_t>(result.first.size()));
for (uint i = 0; i < result.first.size(); ++i) {
PyList_SetItem(py_found, i, PyObjectID_make(result.first[i]));
}
PyObject *py_remaining =
PyList_New(static_cast<Py_ssize_t>(result.second.size()));
for (uint i = 0; i < result.second.size(); ++i) {
PyList_SetItem(py_remaining, i, PyObjectID_make(result.second[i]));
}
return Py_BuildValue("(OO)", py_found, py_remaining);
}

static PyMethodDef PyLocalSchedulerClient_methods[] = {
{"disconnect", (PyCFunction) PyLocalSchedulerClient_disconnect, METH_NOARGS,
"Notify the local scheduler that this client is exiting gracefully."},
@@ -201,6 +253,8 @@ static PyMethodDef PyLocalSchedulerClient_methods[] = {
(PyCFunction) PyLocalSchedulerClient_get_actor_frontier, METH_VARARGS, ""},
{"set_actor_frontier",
(PyCFunction) PyLocalSchedulerClient_set_actor_frontier, METH_VARARGS, ""},
{"wait", (PyCFunction) PyLocalSchedulerClient_wait, METH_VARARGS,
"Wait for a list of objects to be created."},
{NULL} /* Sentinel */
};

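The new "wait" entry in PyLocalSchedulerClient_methods is what the worker.py change above calls on the raylet code path. As a rough sketch only (it assumes the raylet path is enabled so that the worker's local_scheduler_client exposes wait(); the task f is hypothetical, and in practice this is reached through ray.wait rather than called directly): note the argument order (object_ids, num_returns, timeout_ms, wait_local) follows the "OilO" format string and differs from plasma_client.wait's (ids, timeout, num_returns).

import ray
from ray.worker import global_worker

ray.init()  # assumption: raylet code path enabled, so local_scheduler_client has wait()

@ray.remote
def f():
    return 1

oids = [f.remote() for _ in range(3)]
ready, remaining = global_worker.local_scheduler_client.wait(
    oids,    # "O": iterable of ray.ObjectID
    2,       # "i": num_returns
    1000,    # "l": timeout in milliseconds
    False)   # "O": wait_local, truth-tested with PyObject_IsTrue
# Both results come back as Python lists of ray.ObjectID built via PyObjectID_make.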
39 changes: 39 additions & 0 deletions src/local_scheduler/local_scheduler_client.cc
@@ -2,6 +2,7 @@

#include "common_protocol.h"
#include "format/local_scheduler_generated.h"
#include "ray/raylet/format/node_manager_generated.h"

#include "common/io.h"
#include "common/task.h"
@@ -186,3 +187,41 @@ void local_scheduler_set_actor_frontier(LocalSchedulerConnection *conn,
write_message(conn->conn, MessageType_SetActorFrontier, frontier.size(),
const_cast<uint8_t *>(frontier.data()));
}

std::pair<std::vector<ObjectID>, std::vector<ObjectID>> local_scheduler_wait(
LocalSchedulerConnection *conn,
const std::vector<ObjectID> &object_ids,
int num_returns,
int64_t timeout_milliseconds,
bool wait_local) {
// Write request.
flatbuffers::FlatBufferBuilder fbb;
auto message = ray::protocol::CreateWaitRequest(
fbb, to_flatbuf(fbb, object_ids), num_returns, timeout_milliseconds,
wait_local);
fbb.Finish(message);
write_message(conn->conn, ray::protocol::MessageType_WaitRequest,
fbb.GetSize(), fbb.GetBufferPointer());
// Read result.
int64_t type;
int64_t reply_size;
uint8_t *reply;
Collaborator: This is a memory leak, right? reply needs to get freed somewhere (there is a malloc in read_message).
Contributor Author: Added free at the end.

read_message(conn->conn, &type, &reply_size, &reply);
Collaborator: add RAY_CHECK(type == MessageType_WaitReply);

RAY_CHECK(type == ray::protocol::MessageType_WaitReply);
auto reply_message = flatbuffers::GetRoot<ray::protocol::WaitReply>(reply);
// Convert result.
std::pair<std::vector<ObjectID>, std::vector<ObjectID>> result;
auto found = reply_message->found();
for (uint i = 0; i < found->size(); i++) {
ObjectID object_id = ObjectID::from_binary(found->Get(i)->str());
result.first.push_back(object_id);
}
auto remaining = reply_message->remaining();
for (uint i = 0; i < remaining->size(); i++) {
ObjectID object_id = ObjectID::from_binary(remaining->Get(i)->str());
result.second.push_back(object_id);
}
/* Free the original message from the local scheduler. */
free(reply);
return result;
}
18 changes: 18 additions & 0 deletions src/local_scheduler/local_scheduler_client.h
@@ -169,4 +169,22 @@ const std::vector<uint8_t> local_scheduler_get_actor_frontier(
void local_scheduler_set_actor_frontier(LocalSchedulerConnection *conn,
const std::vector<uint8_t> &frontier);

/// Wait for the given objects until timeout expires or num_returns objects are
/// found.
///
/// \param conn The connection information.
/// \param object_ids The objects to wait for.
/// \param num_returns The number of objects to wait for.
/// \param timeout_milliseconds Duration, in milliseconds, to wait before
/// returning.
/// \param wait_local Whether to wait for objects to appear on this node.
/// \return A pair with the first element containing the object ids that were
/// found, and the second element the objects that were not found.
std::pair<std::vector<ObjectID>, std::vector<ObjectID>> local_scheduler_wait(
LocalSchedulerConnection *conn,
const std::vector<ObjectID> &object_ids,
int num_returns,
int64_t timeout_milliseconds,
bool wait_local);

#endif
109 changes: 80 additions & 29 deletions src/ray/object_manager/object_directory.cc
@@ -6,32 +6,49 @@ ObjectDirectory::ObjectDirectory(std::shared_ptr<gcs::AsyncGcsClient> &gcs_clien
gcs_client_ = gcs_client;
}

std::vector<ClientID> UpdateObjectLocations(
std::unordered_set<ClientID> &client_ids,
const std::vector<ObjectTableDataT> &location_history) {
// location_history contains the history of locations of the object (it is a log),
// which might look like the following:
// client1.is_eviction = false
// client1.is_eviction = true
// client2.is_eviction = false
// In such a scenario, we want to indicate client2 is the only client that contains
// the object, which the following code achieves.
for (const auto &object_table_data : location_history) {
ClientID client_id = ClientID::from_binary(object_table_data.manager);
if (!object_table_data.is_eviction) {
client_ids.insert(client_id);
} else {
client_ids.erase(client_id);
}
}
return std::vector<ClientID>(client_ids.begin(), client_ids.end());
}
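The replay logic in UpdateObjectLocations can be summarized with a small Python model (illustration only, not code from this PR): additions insert the client, evictions remove it, and whatever survives the whole log is the current location set.

def update_object_locations(client_ids, location_history):
    # client_ids: set of clients currently believed to hold the object.
    # location_history: ordered log of (client_id, is_eviction) entries.
    for client_id, is_eviction in location_history:
        if is_eviction:
            client_ids.discard(client_id)
        else:
            client_ids.add(client_id)
    return list(client_ids)

# Replaying the example from the comment above:
history = [("client1", False), ("client1", True), ("client2", False)]
print(update_object_locations(set(), history))  # -> ['client2']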

void ObjectDirectory::RegisterBackend() {
auto object_notification_callback = [this](gcs::AsyncGcsClient *client,
const ObjectID &object_id,
const std::vector<ObjectTableDataT> &data) {
auto object_notification_callback = [this](
gcs::AsyncGcsClient *client, const ObjectID &object_id,
const std::vector<ObjectTableDataT> &location_history) {
// Objects are added to this map in SubscribeObjectLocations.
auto entry = listeners_.find(object_id);
auto object_id_listener_pair = listeners_.find(object_id);
// Do nothing for objects we are not listening for.
if (entry == listeners_.end()) {
if (object_id_listener_pair == listeners_.end()) {
return;
}
// Update entries for this object.
auto client_id_set = entry->second.client_ids;
for (auto &object_table_data : data) {
ClientID client_id = ClientID::from_binary(object_table_data.manager);
if (!object_table_data.is_eviction) {
client_id_set.insert(client_id);
} else {
client_id_set.erase(client_id);
std::vector<ClientID> client_id_vec = UpdateObjectLocations(
object_id_listener_pair->second.current_object_locations, location_history);
if (!client_id_vec.empty()) {
// Copy the callbacks so that the callbacks can unsubscribe without interrupting
// looping over the callbacks.
auto callbacks = object_id_listener_pair->second.callbacks;
// Call all callbacks associated with the object id locations we have received.
for (const auto &callback_pair : callbacks) {
callback_pair.second(client_id_vec, object_id);
}
}
if (!client_id_set.empty()) {
// Only call the callback if we have object locations.
std::vector<ClientID> client_id_vec(client_id_set.begin(), client_id_set.end());
auto callback = entry->second.locations_found_callback;
callback(client_id_vec, object_id);
}
};
RAY_CHECK_OK(gcs_client_->object_table().Subscribe(
UniqueID::nil(), gcs_client_->client_table().GetLocalClientId(),
@@ -86,25 +103,59 @@ ray::Status ObjectDirectory::GetInformation(const ClientID &client_id,
return ray::Status::OK();
}

ray::Status ObjectDirectory::SubscribeObjectLocations(const ObjectID &object_id,
ray::Status ObjectDirectory::SubscribeObjectLocations(const UniqueID &callback_id,
const ObjectID &object_id,
const OnLocationsFound &callback) {
if (listeners_.find(object_id) != listeners_.end()) {
RAY_LOG(ERROR) << "Duplicate calls to SubscribeObjectLocations for " << object_id;
ray::Status status = ray::Status::OK();
if (listeners_.find(object_id) == listeners_.end()) {
listeners_.emplace(object_id, LocationListenerState());
status = gcs_client_->object_table().RequestNotifications(
JobID::nil(), object_id, gcs_client_->client_table().GetLocalClientId());
}
auto &listener_state = listeners_.find(object_id)->second;
// TODO(hme): Make this fatal after implementing Pull suppression.
if (listener_state.callbacks.count(callback_id) > 0) {
return ray::Status::OK();
}
listeners_.emplace(object_id, LocationListenerState(callback));
return gcs_client_->object_table().RequestNotifications(
JobID::nil(), object_id, gcs_client_->client_table().GetLocalClientId());
listener_state.callbacks.emplace(callback_id, callback);
// Immediately notify of found object locations.
if (!listener_state.current_object_locations.empty()) {
std::vector<ClientID> client_id_vec(listener_state.current_object_locations.begin(),
listener_state.current_object_locations.end());
callback(client_id_vec, object_id);
}
return status;
}

ray::Status ObjectDirectory::UnsubscribeObjectLocations(const ObjectID &object_id) {
ray::Status ObjectDirectory::UnsubscribeObjectLocations(const UniqueID &callback_id,
const ObjectID &object_id) {
ray::Status status = ray::Status::OK();
auto entry = listeners_.find(object_id);
if (entry == listeners_.end()) {
return ray::Status::OK();
return status;
}
ray::Status status = gcs_client_->object_table().CancelNotifications(
JobID::nil(), object_id, gcs_client_->client_table().GetLocalClientId());
listeners_.erase(entry);
entry->second.callbacks.erase(callback_id);
if (entry->second.callbacks.empty()) {
status = gcs_client_->object_table().CancelNotifications(
JobID::nil(), object_id, gcs_client_->client_table().GetLocalClientId());
listeners_.erase(entry);
}
return status;
}
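SubscribeObjectLocations and UnsubscribeObjectLocations now keep a per-object map of callbacks keyed by callback_id: GCS notifications are requested on the first subscription for an object and cancelled only when the last callback is removed, and a new subscriber is notified immediately if locations are already known. A toy Python model of that bookkeeping (the print calls stand in for the GCS RequestNotifications/CancelNotifications calls; none of this is code from the PR):

class ObjectDirectoryModel:
    def __init__(self):
        # object_id -> {"locations": set of client ids, "callbacks": {callback_id: fn}}
        self.listeners = {}

    def subscribe(self, callback_id, object_id, callback):
        state = self.listeners.get(object_id)
        if state is None:
            state = self.listeners[object_id] = {"locations": set(), "callbacks": {}}
            print("RequestNotifications for", object_id)  # only the first subscriber triggers this
        if callback_id in state["callbacks"]:
            return  # duplicate subscription for the same callback_id is ignored
        state["callbacks"][callback_id] = callback
        if state["locations"]:
            # Locations already known: notify the new subscriber immediately.
            callback(list(state["locations"]), object_id)

    def unsubscribe(self, callback_id, object_id):
        state = self.listeners.get(object_id)
        if state is None:
            return
        state["callbacks"].pop(callback_id, None)
        if not state["callbacks"]:
            print("CancelNotifications for", object_id)  # last callback removed
            del self.listeners[object_id]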

ray::Status ObjectDirectory::LookupLocations(const ObjectID &object_id,
const OnLocationsFound &callback) {
JobID job_id = JobID::nil();
ray::Status status = gcs_client_->object_table().Lookup(
job_id, object_id,
[this, callback](gcs::AsyncGcsClient *client, const ObjectID &object_id,
const std::vector<ObjectTableDataT> &location_history) {
// Build the set of current locations based on the entries in the log.
std::unordered_set<ClientID> client_ids;
std::vector<ClientID> locations_vector =
UpdateObjectLocations(client_ids, location_history);
callback(locations_vector, object_id);
});
return status;
}
