diff --git a/src/ray/raylet/lineage_cache.cc b/src/ray/raylet/lineage_cache.cc index 795f2b54a6cb..fcae3f8e04c9 100644 --- a/src/ray/raylet/lineage_cache.cc +++ b/src/ray/raylet/lineage_cache.cc @@ -221,6 +221,20 @@ bool LineageCache::CommitTask(const Task &task) { } } +void LineageCache::FlushAllUncommittedTasks() { + size_t num_flushed = 0; + for (const auto &entry : lineage_.GetEntries()) { + // Only UNCOMMITTED entries are flushed; each is unsubscribed first. + if (entry.second.GetStatus() == GcsStatus::UNCOMMITTED) { + RAY_CHECK(UnsubscribeTask(entry.first)); + FlushTask(entry.first); + num_flushed++; + } + } + + RAY_LOG(DEBUG) << "Flushed " << num_flushed << " uncommitted tasks"; +} + void LineageCache::MarkTaskAsForwarded(const TaskID &task_id, const ClientID &node_id) { RAY_CHECK(!node_id.IsNil()); auto entry = lineage_.GetEntryMutable(task_id); diff --git a/src/ray/raylet/lineage_cache.h b/src/ray/raylet/lineage_cache.h index 2dff0e94a4d1..5436fa372fa4 100644 --- a/src/ray/raylet/lineage_cache.h +++ b/src/ray/raylet/lineage_cache.h @@ -231,6 +231,13 @@ class LineageCache { /// task was already in the COMMITTING state. bool CommitTask(const Task &task); + /// Flush every task in the local cache that is still in the UNCOMMITTED + /// state, i.e. not already being committed. Each such task is + /// unsubscribed (RAY_CHECKed) and then flushed. + /// + /// \return Void. + void FlushAllUncommittedTasks(); + /// Add a task and its (estimated) uncommitted lineage to the local cache. We /// will subscribe to commit notifications for all uncommitted tasks to /// determine when it is safe to evict the lineage from the local cache. 
diff --git a/src/ray/raylet/lineage_cache_test.cc b/src/ray/raylet/lineage_cache_test.cc index e5c126bcf078..43e64e400292 100644 --- a/src/ray/raylet/lineage_cache_test.cc +++ b/src/ray/raylet/lineage_cache_test.cc @@ -26,8 +26,22 @@ class MockGcs : public gcs::TableInterface, std::shared_ptr &task_data, const gcs::TableInterface::WriteCallback &done) { task_table_[task_id] = task_data; + auto callback = done; + // If the task ID is subscribed at the time of this add, wrap the callback so + // that it also delivers the subscription notification. + if (subscribed_tasks_.count(task_id) == 1) { + callback = [this, done](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, + const protocol::TaskT &data) { + done(client, task_id, data); + // Deliver the subscription notification right after the write + // callback has run. + notification_callback_(client, task_id, data); + }; + } + callbacks_.push_back( - std::pair(done, task_id)); + std::pair(callback, task_id)); + num_task_adds_++; return ray::Status::OK(); } @@ -78,28 +92,34 @@ class MockGcs : public gcs::TableInterface, const int NumRequestedNotifications() const { return num_requested_notifications_; } + const int NumTaskAdds() const { return num_task_adds_; } + private: std::unordered_map> task_table_; std::vector> callbacks_; gcs::raylet::TaskTable::WriteCallback notification_callback_; std::unordered_set subscribed_tasks_; int num_requested_notifications_ = 0; + int num_task_adds_ = 0; }; class LineageCacheTest : public ::testing::Test { public: LineageCacheTest() : max_lineage_size_(10), + num_notifications_(0), mock_gcs_(), lineage_cache_(ClientID::FromRandom(), mock_gcs_, mock_gcs_, max_lineage_size_) { mock_gcs_.Subscribe([this](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, const ray::protocol::TaskT &data) { lineage_cache_.HandleEntryCommitted(task_id); + num_notifications_++; }); } protected: uint64_t max_lineage_size_; + uint64_t num_notifications_; MockGcs mock_gcs_; LineageCache lineage_cache_; }; @@ -529,6 +549,39 
@@ TEST_F(LineageCacheTest, TestEvictionUncommittedChildren) { ASSERT_EQ(lineage_cache_.GetLineage().GetChildrenSize(), 0); } +TEST_F(LineageCacheTest, TestFlushAllUncommittedTasks) { + // Insert a chain of tasks. + std::vector tasks; + auto return_values = + InsertTaskChain(lineage_cache_, tasks, 3, std::vector(), 1); + std::vector task_ids; + for (const auto &task : tasks) { + task_ids.push_back(task.GetTaskSpecification().TaskId()); + } + // Check that we subscribed to each of the uncommitted tasks. + ASSERT_EQ(mock_gcs_.NumRequestedNotifications(), task_ids.size()); + + // Flush all uncommitted tasks and make sure we add all tasks to + // the task table. + lineage_cache_.FlushAllUncommittedTasks(); + ASSERT_EQ(mock_gcs_.NumTaskAdds(), tasks.size()); + // Flush again and make sure there are no new tasks added to the + // task table. + lineage_cache_.FlushAllUncommittedTasks(); + ASSERT_EQ(mock_gcs_.NumTaskAdds(), tasks.size()); + + // Flush all GCS notifications. + mock_gcs_.Flush(); + // Make sure that we unsubscribed from the uncommitted tasks before + // we flushed them: no notification may have been delivered. + ASSERT_EQ(num_notifications_, 0); + + // Flush once more after the mock GCS flush and make sure there are still + // no new tasks added to the task table. + lineage_cache_.FlushAllUncommittedTasks(); + ASSERT_EQ(mock_gcs_.NumTaskAdds(), tasks.size()); +} + } // namespace raylet } // namespace ray diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index b710b0873b0c..83f4636541e1 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -475,6 +475,11 @@ void NodeManager::ClientRemoved(const ClientTableDataT &client_data) { // Notify the object directory that the client has been removed so that it // can remove it from any cached locations. object_directory_->HandleClientRemoved(client_id); + + // Flush all uncommitted tasks from the local lineage cache. 
This is to + // guarantee that all tasks get flushed eventually, in case one of the tasks + // in our local cache was supposed to be flushed by the node that died. + lineage_cache_.FlushAllUncommittedTasks(); } void NodeManager::ResourceCreateUpdated(const ClientTableDataT &client_data) {