Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/ray/raylet/lineage_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,20 @@ bool LineageCache::CommitTask(const Task &task) {
}
}

/// Walk every entry in the local lineage and flush each task whose GCS status
/// is still UNCOMMITTED. Entries in any other state are left untouched. The
/// number of tasks flushed is reported at DEBUG level.
void LineageCache::FlushAllUncommittedTasks() {
  size_t flushed_count = 0;
  for (const auto &id_and_entry : lineage_.GetEntries()) {
    // Only tasks that have not yet started committing need to be flushed.
    if (id_and_entry.second.GetStatus() != GcsStatus::UNCOMMITTED) {
      continue;
    }
    const auto &uncommitted_id = id_and_entry.first;
    // Unsubscribe first, then flush. The unsubscribe is required to succeed
    // for every UNCOMMITTED task; a failure aborts via RAY_CHECK.
    RAY_CHECK(UnsubscribeTask(uncommitted_id));
    FlushTask(uncommitted_id);
    ++flushed_count;
  }

  RAY_LOG(DEBUG) << "Flushed " << flushed_count << " uncommitted tasks";
}

void LineageCache::MarkTaskAsForwarded(const TaskID &task_id, const ClientID &node_id) {
RAY_CHECK(!node_id.IsNil());
auto entry = lineage_.GetEntryMutable(task_id);
Expand Down
7 changes: 7 additions & 0 deletions src/ray/raylet/lineage_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,13 @@ class LineageCache {
/// task was already in the COMMITTING state.
bool CommitTask(const Task &task);

/// Flush all tasks in the local cache that are not already being
/// committed. This is equivalent to all tasks in the UNCOMMITTED
/// state. Each such task is unsubscribed from before it is flushed, and
/// the unsubscribe is checked to succeed. Tasks in any other state are
/// skipped, so calling this again without new uncommitted tasks flushes
/// nothing further.
///
/// \return Void.
void FlushAllUncommittedTasks();

/// Add a task and its (estimated) uncommitted lineage to the local cache. We
/// will subscribe to commit notifications for all uncommitted tasks to
/// determine when it is safe to evict the lineage from the local cache.
Expand Down
55 changes: 54 additions & 1 deletion src/ray/raylet/lineage_cache_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,22 @@ class MockGcs : public gcs::TableInterface<TaskID, protocol::Task>,
std::shared_ptr<protocol::TaskT> &task_data,
const gcs::TableInterface<TaskID, protocol::Task>::WriteCallback &done) {
task_table_[task_id] = task_data;
auto callback = done;
// If we requested notifications for this task ID, send the notification as
// part of the callback.
if (subscribed_tasks_.count(task_id) == 1) {
callback = [this, done](ray::gcs::AsyncGcsClient *client, const TaskID &task_id,
const protocol::TaskT &data) {
done(client, task_id, data);
// If we're subscribed to the task to be added, also send a
// subscription notification.
notification_callback_(client, task_id, data);
};
}

callbacks_.push_back(
std::pair<gcs::raylet::TaskTable::WriteCallback, TaskID>(done, task_id));
std::pair<gcs::raylet::TaskTable::WriteCallback, TaskID>(callback, task_id));
num_task_adds_++;
return ray::Status::OK();
}

Expand Down Expand Up @@ -78,28 +92,34 @@ class MockGcs : public gcs::TableInterface<TaskID, protocol::Task>,

const int NumRequestedNotifications() const { return num_requested_notifications_; }

const int NumTaskAdds() const { return num_task_adds_; }

private:
/// In-memory stand-in for the task table: the data most recently added for
/// each task ID.
std::unordered_map<TaskID, std::shared_ptr<protocol::TaskT>> task_table_;
/// Write callbacks queued by task adds, each paired with the ID of the task
/// that was added.
std::vector<std::pair<gcs::raylet::TaskTable::WriteCallback, TaskID>> callbacks_;
/// Callback invoked to deliver a notification for a subscribed task.
gcs::raylet::TaskTable::WriteCallback notification_callback_;
/// Task IDs for which adding the task also triggers a notification.
std::unordered_set<TaskID> subscribed_tasks_;
/// Counter reported by NumRequestedNotifications().
int num_requested_notifications_ = 0;
/// Counter reported by NumTaskAdds(); incremented on each task add.
int num_task_adds_ = 0;
};

/// Test fixture that wires a LineageCache to a MockGcs. The mock is passed to
/// the cache twice (as both of its GCS-facing constructor arguments), and the
/// fixture subscribes a callback on the mock that forwards each notification
/// to the cache and counts it.
class LineageCacheTest : public ::testing::Test {
 public:
  LineageCacheTest()
      : mock_gcs_(),
        lineage_cache_(ClientID::FromRandom(), mock_gcs_, mock_gcs_, max_lineage_size_) {
    // On every notification from the mock, tell the cache the task has been
    // committed and record that a notification was delivered.
    auto on_task_notification = [this](ray::gcs::AsyncGcsClient *client,
                                       const TaskID &task_id,
                                       const ray::protocol::TaskT &data) {
      lineage_cache_.HandleEntryCommitted(task_id);
      ++num_notifications_;
    };
    mock_gcs_.Subscribe(on_task_notification);
  }

 protected:
  /// Maximum lineage size handed to the cache under test.
  uint64_t max_lineage_size_ = 10;
  /// Number of notifications delivered through the subscribed callback.
  uint64_t num_notifications_ = 0;
  /// Declared before lineage_cache_ so it is constructed first; the cache
  /// constructor is given this mock.
  MockGcs mock_gcs_;
  LineageCache lineage_cache_;
};
Expand Down Expand Up @@ -529,6 +549,39 @@ TEST_F(LineageCacheTest, TestEvictionUncommittedChildren) {
ASSERT_EQ(lineage_cache_.GetLineage().GetChildrenSize(), 0);
}

/// Checks that FlushAllUncommittedTasks adds every uncommitted task to the
/// task table exactly once, that repeated calls add nothing new, and that no
/// subscription notification is delivered for the flushed tasks.
TEST_F(LineageCacheTest, TestFlushAllUncommittedTasks) {
  // Insert a chain of tasks. The chain's return values are not needed here.
  std::vector<Task> tasks;
  InsertTaskChain(lineage_cache_, tasks, 3, std::vector<ObjectID>(), 1);
  // The mock's counters are ints, so compare against an int to avoid mixing
  // signed and unsigned operands in the assertions.
  const int num_tasks = static_cast<int>(tasks.size());

  // Check that we subscribed to each of the uncommitted tasks.
  ASSERT_EQ(mock_gcs_.NumRequestedNotifications(), num_tasks);

  // Flush all uncommitted tasks and make sure we add all tasks to
  // the task table.
  lineage_cache_.FlushAllUncommittedTasks();
  ASSERT_EQ(mock_gcs_.NumTaskAdds(), num_tasks);
  // Flush again, before any write callback has run, and make sure there
  // are no new tasks added to the task table.
  lineage_cache_.FlushAllUncommittedTasks();
  ASSERT_EQ(mock_gcs_.NumTaskAdds(), num_tasks);

  // Flush all GCS notifications.
  mock_gcs_.Flush();
  // Make sure that we unsubscribed from the uncommitted tasks before
  // we flushed them: no notification should have been delivered.
  ASSERT_EQ(num_notifications_, 0u);

  // Flush once more, after the GCS callbacks have run, and make sure there
  // are still no new tasks added to the task table.
  lineage_cache_.FlushAllUncommittedTasks();
  ASSERT_EQ(mock_gcs_.NumTaskAdds(), num_tasks);
}

} // namespace raylet

} // namespace ray
Expand Down
5 changes: 5 additions & 0 deletions src/ray/raylet/node_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,11 @@ void NodeManager::ClientRemoved(const ClientTableDataT &client_data) {
// Notify the object directory that the client has been removed so that it
// can remove it from any cached locations.
object_directory_->HandleClientRemoved(client_id);

// Flush all uncommitted tasks from the local lineage cache. This is to
// guarantee that all tasks get flushed eventually, in case one of the tasks
// in our local cache was supposed to be flushed by the node that died.
lineage_cache_.FlushAllUncommittedTasks();
}

void NodeManager::ResourceCreateUpdated(const ClientTableDataT &client_data) {
Expand Down