Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/ray/common/client_connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ std::shared_ptr<ClientConnection<T>> ClientConnection<T>::Create(
std::shared_ptr<ClientConnection<T>> self(
new ClientConnection(message_handler, std::move(socket)));
// Let our manager process our new connection.
client_handler(self);
client_handler(*self);
return self;
}

Expand Down
2 changes: 1 addition & 1 deletion src/ray/common/client_connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ template <typename T>
class ClientConnection;

template <typename T>
using ClientHandler = std::function<void(std::shared_ptr<ClientConnection<T>>)>;
using ClientHandler = std::function<void(ClientConnection<T> &)>;
template <typename T>
using MessageHandler =
std::function<void(std::shared_ptr<ClientConnection<T>>, int64_t, const uint8_t *)>;
Expand Down
2 changes: 1 addition & 1 deletion src/ray/gcs/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ namespace ray {
namespace gcs {

AsyncGcsClient::AsyncGcsClient(const ClientID &client_id) {
context_.reset(new RedisContext());
context_ = std::make_shared<RedisContext>();
client_table_.reset(new ClientTable(context_, this, client_id));
object_table_.reset(new ObjectTable(context_, this));
actor_table_.reset(new ActorTable(context_, this));
Expand Down
14 changes: 7 additions & 7 deletions src/ray/gcs/client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,9 @@ void TestTableLookup(const JobID &job_id, std::shared_ptr<gcs::AsyncGcsClient> c

// Check that we added the correct task.
auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const UniqueID &id,
const std::shared_ptr<protocol::TaskT> d) {
const protocol::TaskT &d) {
ASSERT_EQ(id, task_id);
ASSERT_EQ(data->task_specification, d->task_specification);
ASSERT_EQ(data->task_specification, d.task_specification);
};

// Check that the lookup returns the added task.
Expand Down Expand Up @@ -139,9 +139,9 @@ void TestLogLookup(const JobID &job_id, std::shared_ptr<gcs::AsyncGcsClient> cli
data->manager = manager;
// Check that we added the correct object entries.
auto add_callback = [object_id, data](gcs::AsyncGcsClient *client, const UniqueID &id,
const std::shared_ptr<ObjectTableDataT> d) {
const ObjectTableDataT &d) {
ASSERT_EQ(id, object_id);
ASSERT_EQ(data->manager, d->manager);
ASSERT_EQ(data->manager, d.manager);
};
RAY_CHECK_OK(client->object_table().Append(job_id, object_id, data, add_callback));
}
Expand Down Expand Up @@ -222,7 +222,7 @@ void TestLogAppendAt(const JobID &job_id, std::shared_ptr<gcs::AsyncGcsClient> c

// Check that we added the correct task.
auto failure_callback = [task_id](gcs::AsyncGcsClient *client, const UniqueID &id,
const std::shared_ptr<TaskReconstructionDataT> d) {
const TaskReconstructionDataT &d) {
ASSERT_EQ(id, task_id);
test->IncrementNumCallbacks();
};
Expand Down Expand Up @@ -265,8 +265,8 @@ TEST_F(TestGcsWithAsio, TestLogAppendAt) {

// Task table callbacks.
void TaskAdded(gcs::AsyncGcsClient *client, const TaskID &id,
const std::shared_ptr<TaskTableDataT> data) {
ASSERT_EQ(data->scheduling_state, SchedulingState_SCHEDULED);
const TaskTableDataT &data) {
ASSERT_EQ(data.scheduling_state, SchedulingState_SCHEDULED);
}

void TaskLookup(gcs::AsyncGcsClient *client, const TaskID &id,
Expand Down
15 changes: 8 additions & 7 deletions src/ray/gcs/redis_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,13 @@ void SubscribeRedisCallback(void *c, void *r, void *privdata) {

int64_t RedisCallbackManager::add(const RedisCallback &function) {
num_callbacks += 1;
callbacks_.emplace(num_callbacks, std::unique_ptr<RedisCallback>(
new RedisCallback(function)));
callbacks_.emplace(num_callbacks, function);
return num_callbacks;
}

RedisCallbackManager::RedisCallback &RedisCallbackManager::get(
int64_t callback_index) {
RedisCallback &RedisCallbackManager::get(int64_t callback_index) {
RAY_CHECK(callbacks_.find(callback_index) != callbacks_.end());
return *callbacks_[callback_index];
return callbacks_[callback_index];
}

void RedisCallbackManager::remove(int64_t callback_index) {
Expand Down Expand Up @@ -185,7 +183,9 @@ Status RedisContext::AttachToEventLoop(aeEventLoop *loop) {
Status RedisContext::RunAsync(const std::string &command, const UniqueID &id,
const uint8_t *data, int64_t length,
const TablePrefix prefix, const TablePubsub pubsub_channel,
int64_t callback_index, int log_length) {
RedisCallback redisCallback, int log_length) {
int64_t callback_index =
redisCallback != nullptr ? RedisCallbackManager::instance().add(redisCallback) : -1;
if (length > 0) {
if (log_length >= 0) {
std::string redis_command = command + " %d %d %b %b %d";
Expand Down Expand Up @@ -222,10 +222,11 @@ Status RedisContext::RunAsync(const std::string &command, const UniqueID &id,

Status RedisContext::SubscribeAsync(const ClientID &client_id,
const TablePubsub pubsub_channel,
int64_t callback_index) {
const RedisCallback &redisCallback) {
RAY_CHECK(pubsub_channel != TablePubsub_NO_PUBLISH)
<< "Client requested subscribe on a table that does not support pubsub";

int64_t callback_index = RedisCallbackManager::instance().add(redisCallback);
int status = 0;
if (client_id.is_nil()) {
// Subscribe to all messages.
Expand Down
14 changes: 7 additions & 7 deletions src/ray/gcs/redis_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ struct aeEventLoop;
namespace ray {

namespace gcs {
/// Every callback should take in a vector of the results from the Redis
/// operation and return a bool indicating whether the callback should be
/// deleted once called.
using RedisCallback = std::function<bool(const std::string &)>;

class RedisCallbackManager {
public:
/// Every callback should take in a vector of the results from the Redis
/// operation and return a bool indicating whether the callback should be
/// deleted once called.
using RedisCallback = std::function<bool(const std::string &)>;

static RedisCallbackManager &instance() {
static RedisCallbackManager instance;
Expand All @@ -44,7 +44,7 @@ class RedisCallbackManager {
~RedisCallbackManager() { printf("shut down callback manager\n"); }

int64_t num_callbacks;
std::unordered_map<int64_t, std::unique_ptr<RedisCallback>> callbacks_;
std::unordered_map<int64_t, RedisCallback> callbacks_;
};

class RedisContext {
Expand All @@ -70,11 +70,11 @@ class RedisContext {
/// -1 for unused. If set, then data must be provided.
Status RunAsync(const std::string &command, const UniqueID &id, const uint8_t *data,
int64_t length, const TablePrefix prefix,
const TablePubsub pubsub_channel, int64_t callback_index,
const TablePubsub pubsub_channel, RedisCallback redisCallback,
int log_length = -1);

Status SubscribeAsync(const ClientID &client_id, const TablePubsub pubsub_channel,
int64_t callback_index);
const RedisCallback &redisCallback);
redisAsyncContext *async_context() { return async_context_; }
redisAsyncContext *subscribe_context() { return subscribe_context_; };

Expand Down
Loading