Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/common/task.cc
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ TaskExecutionSpec::TaskExecutionSpec(TaskExecutionSpec *other)
spec_ = std::unique_ptr<TaskSpec[]>(spec_copy);
}

std::vector<ObjectID> TaskExecutionSpec::ExecutionDependencies() const {
const std::vector<ObjectID> &TaskExecutionSpec::ExecutionDependencies() const {
return execution_dependencies_;
}

Expand Down
2 changes: 1 addition & 1 deletion src/common/task.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class TaskExecutionSpec {
///
/// @return A vector of object IDs representing this task's execution
/// dependencies.
std::vector<ObjectID> ExecutionDependencies() const;
const std::vector<ObjectID> &ExecutionDependencies() const;

/// Set the task's execution dependencies.
///
Expand Down
8 changes: 4 additions & 4 deletions src/ray/object_manager/object_store_notification_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,25 +63,25 @@ void ObjectStoreNotificationManager::ProcessStoreNotification(
}

void ObjectStoreNotificationManager::ProcessStoreAdd(const ObjectInfoT &object_info) {
for (auto handler : add_handlers_) {
for (auto &handler : add_handlers_) {
handler(object_info);
}
}

void ObjectStoreNotificationManager::ProcessStoreRemove(const ObjectID &object_id) {
for (auto handler : rem_handlers_) {
for (auto &handler : rem_handlers_) {
handler(object_id);
}
}

void ObjectStoreNotificationManager::SubscribeObjAdded(
std::function<void(const ObjectInfoT &)> callback) {
add_handlers_.push_back(callback);
add_handlers_.push_back(std::move(callback));
}

void ObjectStoreNotificationManager::SubscribeObjDeleted(
std::function<void(const ObjectID &)> callback) {
rem_handlers_.push_back(callback);
rem_handlers_.push_back(std::move(callback));
}

} // namespace ray
30 changes: 18 additions & 12 deletions src/ray/raylet/node_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -229,13 +229,14 @@ void NodeManager::HeartbeatAdded(gcs::AsyncGcsClient *client, const ClientID &cl
}
// Locate the client id in remote client table and update available resources based on
// the received heartbeat information.
if (this->cluster_resource_map_.count(client_id) == 0) {
auto it = this->cluster_resource_map_.find(client_id);
if (it == cluster_resource_map_.end()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The `.count() == 0` form is supposed to be idiomatically equivalent. Do you find your proposed way better in terms of performance?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

count() means we first have to look the item up in the map, and then cluster_resource_map_[client_id] will look it up again. Although find is an average O(1) operation for unordered_map, computing the hash and making the function call still have a cost. The fix reduces the lookup from twice to once.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, I know, I said "nevermind" below, when I saw you reusing the iterator ;)

// Haven't received the client registration for this client yet, skip this heartbeat.
RAY_LOG(INFO) << "[HeartbeatAdded]: received heartbeat from unknown client id "
<< client_id;
return;
}
SchedulingResources &resources = this->cluster_resource_map_[client_id];
SchedulingResources &resources = it->second;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I see. nvm.

ResourceSet heartbeat_resource_available(heartbeat_data.resources_available_label,
heartbeat_data.resources_available_capacity);
resources.SetAvailableResources(
Expand Down Expand Up @@ -300,26 +301,30 @@ void NodeManager::ProcessNewClient(LocalClientConnection &client) {

void NodeManager::DispatchTasks() {
// Work with a copy of scheduled tasks.
auto scheduled_tasks = local_queues_.GetScheduledTasks();
const auto &scheduled_tasks = local_queues_.GetScheduledTasks();
// Return if there are no tasks to schedule.
if (scheduled_tasks.empty()) {
return;
}
const ClientID &my_client_id = gcs_client_->client_table().GetLocalClientId();
const auto &local_resources =
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@atumanov @eric-jj

There may be a bug here (or at least a change in behavior). Before, we called AssignTask inside the for loop, which presumably modified local_resources. However, now we call AssignTask at the very end, so the check if (!task_resources.IsSubset(local_resources)) { may give different results now.

I will try reverting some of this and see if that makes the problem I'm seeing go away.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will investigate it.

Copy link
Contributor

@atumanov atumanov May 20, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@robertnishihara good catch. That's the reason why I didn't batch AssignTask calls before.
this->cluster_resource_map_[my_client_id].Acquire(spec.GetRequiredResources())); this line acquires resources when the task is assigned. This makes sure that the locally available resources are updated at each for loop iteration.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, you are right — I have reverted it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have no idea why I can't reopen the pull request after rebasing onto the master branch, so I have created a new pull request.

cluster_resource_map_[my_client_id].GetAvailableResources();

std::unordered_set<TaskID> dispatched_task_ids;
for (const auto &task : scheduled_tasks) {
const auto &local_resources =
cluster_resource_map_[my_client_id].GetAvailableResources();
const auto &task_resources = task.GetTaskSpecification().GetRequiredResources();
if (!task_resources.IsSubset(local_resources)) {
// Not enough local resources for this task right now, skip this task.
continue;
}
// We have enough resources for this task. Assign task.
// TODO(atumanov): perform the task state/queue transition inside AssignTask.
auto dispatched_task =
local_queues_.RemoveTasks({task.GetTaskSpecification().TaskId()});
AssignTask(dispatched_task.front());
dispatched_task_ids.insert(task.GetTaskSpecification().TaskId());
}

// We have enough resources for this task. Assign task.
// TODO(atumanov): perform the task state/queue transition inside AssignTask.
auto dispatched_tasks = local_queues_.RemoveTasks(std::move(dispatched_task_ids));
for (auto &task : dispatched_tasks) {
AssignTask(task);
}
}

Expand Down Expand Up @@ -786,13 +791,14 @@ ray::Status NodeManager::ForwardTask(const Task &task, const ClientID &node_id)
auto client_info = gcs_client_->client_table().GetClient(node_id);

// Lookup remote server connection for this node_id and use it to send the request.
if (remote_server_connections_.count(node_id) == 0) {
auto it = remote_server_connections_.find(node_id);
if (it == remote_server_connections_.end()) {
// TODO(atumanov): caller must handle failure to ensure tasks are not lost.
RAY_LOG(INFO) << "No NodeManager connection found for GCS client id " << node_id;
return ray::Status::IOError("NodeManager connection not found");
}

auto &server_conn = remote_server_connections_.at(node_id);
auto &server_conn = it->second;
auto status = server_conn.WriteMessage(protocol::MessageType_ForwardTaskRequest,
fbb.GetSize(), fbb.GetBufferPointer());
if (status.ok()) {
Expand Down
2 changes: 1 addition & 1 deletion src/ray/raylet/scheduling_policy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ std::unordered_map<TaskID, ClientID> SchedulingPolicy::Schedule(
for (const auto &client_resource_pair : cluster_resources) {
// pair = ClientID, SchedulingResources
ClientID node_client_id = client_resource_pair.first;
SchedulingResources node_resources = client_resource_pair.second;
const auto &node_resources = client_resource_pair.second;
RAY_LOG(DEBUG) << "client_id " << node_client_id << " resources: "
<< node_resources.GetAvailableResources().ToString();
if (resource_demand.IsSubset(node_resources.GetTotalResources())) {
Expand Down
6 changes: 2 additions & 4 deletions src/ray/raylet/scheduling_queue.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,8 @@ void removeTasksFromQueue(std::list<Task> &queue, std::unordered_set<TaskID> &ta
}

// Helper function to queue the given tasks to the given queue.
void queueTasks(std::list<Task> &queue, const std::vector<Task> &tasks) {
for (auto &task : tasks) {
queue.push_back(task);
}
inline void queueTasks(std::list<Task> &queue, const std::vector<Task> &tasks) {
queue.insert(queue.end(), tasks.begin(), tasks.end());
}

std::vector<Task> SchedulingQueue::RemoveTasks(std::unordered_set<TaskID> task_ids) {
Expand Down
5 changes: 3 additions & 2 deletions src/ray/raylet/worker_pool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,9 @@ void WorkerPool::RegisterWorker(std::shared_ptr<Worker> worker) {
auto pid = worker->Pid();
RAY_LOG(DEBUG) << "Registering worker with pid " << pid;
registered_workers_.push_back(std::move(worker));
RAY_CHECK(started_worker_pids_.count(pid) > 0);
started_worker_pids_.erase(pid);
auto it = started_worker_pids_.find(pid);
RAY_CHECK(it != started_worker_pids_.end());
started_worker_pids_.erase(it);
}

std::shared_ptr<Worker> WorkerPool::GetRegisteredWorker(
Expand Down