From 3b1f48786e0fa11ef9240297982f6726eec5733a Mon Sep 17 00:00:00 2001
From: davik
Date: Thu, 9 Oct 2025 03:52:49 +0000
Subject: [PATCH 1/5] Refactor reference_counter out of memory store and
 plasma store

Signed-off-by: davik

---
 .../runtime/object/local_mode_object_store.cc |  9 ++-
 src/mock/ray/core_worker/memory_store.h       |  2 +-
 src/ray/core_worker/BUILD.bazel               |  4 +-
 src/ray/core_worker/core_worker.cc            | 60 +++++++++++---
 src/ray/core_worker/core_worker.h             |  8 ++
 src/ray/core_worker/core_worker_process.cc    |  3 +-
 src/ray/core_worker/future_resolver.cc        | 11 ++-
 src/ray/core_worker/future_resolver.h         |  1 +
 .../core_worker/object_recovery_manager.cc    |  7 +-
 .../memory_store/memory_store.cc              | 54 +++----------
 .../memory_store/memory_store.h               | 18 ++---
 .../store_provider/plasma_store_provider.cc   | 51 ++++++++----
 .../store_provider/plasma_store_provider.h    |  9 +--
 src/ray/core_worker/task_manager.cc           | 29 ++++---
 src/ray/core_worker/task_manager.h            |  1 +
 .../core_worker/task_submission/BUILD.bazel   |  1 +
 .../task_submission/actor_task_submitter.h    |  1 +
 .../tests/actor_task_submitter_test.cc        | 15 ++--
 .../tests/dependency_resolver_test.cc         | 30 +++----
 .../tests/normal_task_submitter_test.cc       | 23 +++---
 src/ray/core_worker/tests/core_worker_test.cc | 27 ++++---
 .../core_worker/tests/memory_store_test.cc    | 56 ++++++-------
 .../tests/object_recovery_manager_test.cc     |  6 +-
 .../tests/reference_counter_test.cc           |  9 +--
 .../core_worker/tests/task_manager_test.cc    | 80 +++++++++----------
 25 files changed, 287 insertions(+), 228 deletions(-)

diff --git a/cpp/src/ray/runtime/object/local_mode_object_store.cc b/cpp/src/ray/runtime/object/local_mode_object_store.cc
index 4c6ea46e22a1..b1c8030e7c16 100644
--- a/cpp/src/ray/runtime/object/local_mode_object_store.cc
+++ b/cpp/src/ray/runtime/object/local_mode_object_store.cc
@@ -28,7 +28,8 @@ namespace internal {
 LocalModeObjectStore::LocalModeObjectStore(LocalModeRayRuntime &local_mode_ray_tuntime)
     : io_context_("LocalModeObjectStore"),
       local_mode_ray_tuntime_(local_mode_ray_tuntime) {
-  memory_store_ = std::make_unique<CoreWorkerMemoryStore>(io_context_.GetIoService());
+  memory_store_ = std::make_unique<CoreWorkerMemoryStore>(io_context_.GetIoService(),
+                                                          /*reference_counting=*/false);
 }
 
 void LocalModeObjectStore::PutRaw(std::shared_ptr<msgpack::sbuffer> data,
@@ -41,8 +42,11 @@ void LocalModeObjectStore::PutRaw(std::shared_ptr<msgpack::sbuffer> data,
                                   const ObjectID &object_id) {
   auto buffer = std::make_shared<::ray::LocalMemoryBuffer>(
       reinterpret_cast<uint8_t *>(data->data()), data->size(), true);
+  // NOTE: objects cannot hold references when reference counting is disabled in
+  // local mode.
   memory_store_->Put(
-      ::ray::RayObject(buffer, nullptr, std::vector<rpc::ObjectReference>()), object_id);
+      ::ray::RayObject(buffer, nullptr, std::vector<rpc::ObjectReference>()),
+      object_id,
+      /*has_reference=*/false);
 }
 
 std::shared_ptr<msgpack::sbuffer> LocalModeObjectStore::GetRaw(const ObjectID &object_id,
@@ -61,7 +65,6 @@ std::vector<std::shared_ptr<msgpack::sbuffer>> LocalModeObjectStore::GetRaw(
       (int)ids.size(),
       timeout_ms,
       local_mode_ray_tuntime_.GetWorkerContext(),
-      false,
       &results);
   if (!status.ok()) {
     throw RayException("Get object error: " + status.ToString());
diff --git a/src/mock/ray/core_worker/memory_store.h b/src/mock/ray/core_worker/memory_store.h
index 33753926ed18..28ce9cef618e 100644
--- a/src/mock/ray/core_worker/memory_store.h
+++ b/src/mock/ray/core_worker/memory_store.h
@@ -46,7 +46,7 @@ class DefaultCoreWorkerMemoryStoreWithThread : public CoreWorkerMemoryStore {
  private:
   explicit DefaultCoreWorkerMemoryStoreWithThread(
       std::unique_ptr<InstrumentedIOContextWithThread> io_context)
-      : CoreWorkerMemoryStore(io_context->GetIoService()),
+      : 
CoreWorkerMemoryStore(io_context->GetIoService(), /*reference_counting=*/false), io_context_(std::move(io_context)) {} std::unique_ptr io_context_; diff --git a/src/ray/core_worker/BUILD.bazel b/src/ray/core_worker/BUILD.bazel index 9106b4a32fd1..6ec02b8d0fbe 100644 --- a/src/ray/core_worker/BUILD.bazel +++ b/src/ray/core_worker/BUILD.bazel @@ -272,7 +272,6 @@ ray_cc_library( hdrs = ["store_provider/memory_store/memory_store.h"], deps = [ ":core_worker_context", - ":reference_counter_interface", "//src/ray/common:asio", "//src/ray/common:id", "//src/ray/common:ray_config", @@ -290,6 +289,7 @@ ray_cc_library( name = "task_manager_interface", hdrs = ["task_manager_interface.h"], deps = [ + ":reference_counter_interface", "//src/ray/common:id", "//src/ray/common:status", "//src/ray/common:task_common", @@ -348,6 +348,7 @@ ray_cc_library( hdrs = ["future_resolver.h"], deps = [ ":memory_store", + ":reference_counter_interface", "//src/ray/common:id", "//src/ray/core_worker_rpc_client:core_worker_client_pool", ], @@ -408,7 +409,6 @@ ray_cc_library( deps = [ ":common", ":core_worker_context", - ":reference_counter_interface", "//src/ray/common:buffer", "//src/ray/common:id", "//src/ray/common:ray_config", diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 5681fcde52de..69c773d780ac 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -982,7 +982,9 @@ Status CoreWorker::PutInLocalPlasmaStore(const RayObject &object, RAY_RETURN_NOT_OK(plasma_store_provider_->Release(object_id)); } } - memory_store_->Put(RayObject(rpc::ErrorType::OBJECT_IN_PLASMA), object_id); + memory_store_->Put(RayObject(rpc::ErrorType::OBJECT_IN_PLASMA), + object_id, + reference_counter_->HasReference(object_id)); return Status::OK(); } @@ -993,7 +995,7 @@ Status CoreWorker::Put(const RayObject &object, RAY_RETURN_NOT_OK(WaitForActorRegistered(contained_object_ids)); if (options_.is_local_mode) { RAY_LOG(DEBUG).WithField(object_id) << "Put object in memory store"; - memory_store_->Put(object, object_id); + memory_store_->Put(object, object_id, reference_counter_->HasReference(object_id)); return Status::OK(); } return PutInLocalPlasmaStore(object, object_id, pin_object); @@ -1092,7 +1094,9 @@ Status CoreWorker::CreateOwnedAndIncrementLocalRef( } else if (*data == nullptr) { // Object already exists in plasma. Store the in-memory value so that the // client will check the plasma store. - memory_store_->Put(RayObject(rpc::ErrorType::OBJECT_IN_PLASMA), *object_id); + memory_store_->Put(RayObject(rpc::ErrorType::OBJECT_IN_PLASMA), + *object_id, + reference_counter_->HasReference(*object_id)); } } return Status::OK(); @@ -1189,7 +1193,9 @@ Status CoreWorker::SealExisting(const ObjectID &object_id, RAY_RETURN_NOT_OK(plasma_store_provider_->Release(object_id)); reference_counter_->FreePlasmaObjects({object_id}); } - memory_store_->Put(RayObject(rpc::ErrorType::OBJECT_IN_PLASMA), object_id); + memory_store_->Put(RayObject(rpc::ErrorType::OBJECT_IN_PLASMA), + object_id, + reference_counter_->HasReference(object_id)); return Status::OK(); } @@ -1361,13 +1367,18 @@ Status CoreWorker::GetObjects(const std::vector &ids, // If any of the objects have been promoted to plasma, then we retry their // gets at the provider plasma. Once we get the objects from plasma, we flip // the transport type again and return them for the original direct call ids. 
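  // --------------------------------------------------------------------------
  // [Editorial sketch, not part of this patch] The hunks above share one
  // caller-side pattern introduced by the refactor: the memory store no longer
  // consults the reference counter itself, so every Put() reports the object's
  // current reference state explicitly, e.g.
  //
  //   memory_store_->Put(RayObject(rpc::ErrorType::OBJECT_IN_PLASMA),
  //                      object_id,
  //                      reference_counter_->HasReference(object_id));
  // --------------------------------------------------------------------------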
+ + // Resolve owner addresses of plasma ids + absl::flat_hash_map plasma_object_ids_map = + GetObjectIdToOwnerAddressMap(plasma_object_ids); + int64_t local_timeout_ms = timeout_ms; if (timeout_ms >= 0) { local_timeout_ms = std::max(static_cast(0), timeout_ms - (current_time_ms() - start_time)); } RAY_LOG(DEBUG) << "Plasma GET timeout " << local_timeout_ms; - RAY_RETURN_NOT_OK(plasma_store_provider_->Get(plasma_object_ids, + RAY_RETURN_NOT_OK(plasma_store_provider_->Get(plasma_object_ids_map, local_timeout_ms, *worker_context_, &result_map, @@ -1515,8 +1526,12 @@ Status CoreWorker::Wait(const std::vector &ids, // num_objects ready since we want to at least make the request to start pulling // these objects. if (!plasma_object_ids.empty()) { + // Prepare object ids map + absl::flat_hash_map plasma_object_ids_map = + GetObjectIdToOwnerAddressMap(plasma_object_ids); + RAY_RETURN_NOT_OK(plasma_store_provider_->Wait( - plasma_object_ids, + plasma_object_ids_map, std::min(static_cast(plasma_object_ids.size()), num_objects - static_cast(ready.size())), timeout_ms, @@ -3003,8 +3018,12 @@ bool CoreWorker::PinExistingReturnObject(const ObjectID &return_id, reference_counter_->AddLocalReference(return_id, ""); reference_counter_->AddBorrowedObject(return_id, ObjectID::Nil(), owner_address); + // Resolve owner address of return id + absl::flat_hash_map return_id_map = + GetObjectIdToOwnerAddressMap({return_id}); + auto status = plasma_store_provider_->Get( - {return_id}, 0, *worker_context_, &result_map, &got_exception); + return_id_map, 0, *worker_context_, &result_map, &got_exception); // Remove the temporary ref. RemoveLocalReference(return_id); @@ -3222,7 +3241,8 @@ Status CoreWorker::GetAndPinArgsForExecutor(const TaskSpecification &task, // otherwise, the put is a no-op. if (!options_.is_local_mode) { memory_store_->Put(RayObject(rpc::ErrorType::OBJECT_IN_PLASMA), - task.ArgObjectId(i)); + task.ArgObjectId(i), + reference_counter_->HasReference(task.ArgObjectId(i))); } } else { // A pass-by-value argument. 
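For orientation, the hunks above all repeat the same preparation step. A minimal end-to-end sketch of the new read path (names as used in this patch; the helper GetObjectIdToOwnerAddressMap is defined in a core_worker.cc hunk further below):

    // Sketch only: CoreWorker resolves owners in bulk, then hands ids plus
    // addresses to the provider, which no longer holds a reference counter.
    absl::flat_hash_map<ObjectID, rpc::Address> plasma_ids_to_owners =
        GetObjectIdToOwnerAddressMap(plasma_object_ids);
    RAY_RETURN_NOT_OK(plasma_store_provider_->Get(plasma_ids_to_owners,
                                                  local_timeout_ms,
                                                  *worker_context_,
                                                  &result_map,
                                                  &got_exception));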
@@ -3271,8 +3291,11 @@ Status CoreWorker::GetAndPinArgsForExecutor(const TaskSpecification &task, RAY_RETURN_NOT_OK(memory_store_->Get( by_ref_ids, -1, *worker_context_, &result_map, &got_exception)); } else { + // Resolve owner addresses of by-ref ids + absl::flat_hash_map by_ref_ids_map = + GetObjectIdToOwnerAddressMap(by_ref_ids); RAY_RETURN_NOT_OK(plasma_store_provider_->Get( - by_ref_ids, -1, *worker_context_, &result_map, &got_exception)); + by_ref_ids_map, -1, *worker_context_, &result_map, &got_exception)); } for (const auto &it : result_map) { for (size_t idx : by_ref_indices[it.first]) { @@ -3463,6 +3486,17 @@ void CoreWorker::PopulateObjectStatus(const ObjectID &object_id, } } +absl::flat_hash_map CoreWorker::GetObjectIdToOwnerAddressMap( + const absl::flat_hash_set &object_ids) { + std::vector object_ids_vector(object_ids.begin(), object_ids.end()); + const auto owner_addresses = reference_counter_->GetOwnerAddresses(object_ids_vector); + absl::flat_hash_map object_id_map; + for (size_t i = 0; i < object_ids_vector.size(); i++) { + object_id_map[object_ids_vector[i]] = owner_addresses[i]; + } + return object_id_map; +} + void CoreWorker::HandleWaitForActorRefDeleted( rpc::WaitForActorRefDeletedRequest request, rpc::WaitForActorRefDeletedReply *reply, @@ -4103,7 +4137,9 @@ Status CoreWorker::DeleteImpl(const std::vector &object_ids, bool loca memory_store_->Delete(object_ids); for (const auto &object_id : object_ids) { RAY_LOG(DEBUG).WithField(object_id) << "Freeing object"; - memory_store_->Put(RayObject(rpc::ErrorType::OBJECT_FREED), object_id); + memory_store_->Put(RayObject(rpc::ErrorType::OBJECT_FREED), + object_id, + reference_counter_->HasReference(object_id)); } // We only delete from plasma, which avoids hangs (issue #7105). In-memory @@ -4253,7 +4289,9 @@ void CoreWorker::HandleAssignObjectOwner(rpc::AssignObjectOwnerRequest request, /*add_local_ref=*/false, /*pinned_at_node_id=*/NodeID::FromBinary(borrower_address.node_id())); reference_counter_->AddBorrowerAddress(object_id, borrower_address); - memory_store_->Put(RayObject(rpc::ErrorType::OBJECT_IN_PLASMA), object_id); + memory_store_->Put(RayObject(rpc::ErrorType::OBJECT_IN_PLASMA), + object_id, + reference_counter_->HasReference(object_id)); send_reply_callback(Status::OK(), nullptr, nullptr); } diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index f40d26ecdcb5..1d0a35c5ae4d 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -1429,6 +1429,14 @@ class CoreWorker { const std::shared_ptr &obj, rpc::GetObjectStatusReply *reply); + /// Helper method to construct a map from ObjectIDs to their owner addresses. + /// This method helps prepare inputs for calls to the plasma store. + /// + /// \param[in] object_ids Set of ObjectIDs to look up. + /// \return A map from ObjectID to owner address (rpc::Address). + absl::flat_hash_map GetObjectIdToOwnerAddressMap( + const absl::flat_hash_set &object_ids); + /// /// Private methods related to task submission. 
/// diff --git a/src/ray/core_worker/core_worker_process.cc b/src/ray/core_worker/core_worker_process.cc index c66ea83bc2d8..fcd132236ee2 100644 --- a/src/ray/core_worker/core_worker_process.cc +++ b/src/ray/core_worker/core_worker_process.cc @@ -354,7 +354,6 @@ std::shared_ptr CoreWorkerProcessImpl::CreateCoreWorker( auto plasma_store_provider = std::make_shared( options.store_socket, raylet_ipc_client, - *reference_counter, options.check_signals, /*warmup=*/ (options.worker_type != WorkerType::SPILL_WORKER && @@ -367,7 +366,7 @@ std::shared_ptr CoreWorkerProcessImpl::CreateCoreWorker( }); auto memory_store = std::make_shared( io_service_, - reference_counter.get(), + /*reference_counting=*/reference_counter != nullptr, raylet_ipc_client, options.check_signals, [this](const RayObject &obj) { diff --git a/src/ray/core_worker/future_resolver.cc b/src/ray/core_worker/future_resolver.cc index 153c06a0f84f..8654e64d2228 100644 --- a/src/ray/core_worker/future_resolver.cc +++ b/src/ray/core_worker/future_resolver.cc @@ -52,14 +52,18 @@ void FutureResolver::ProcessResolvedObject(const ObjectID &object_id, if (!status.ok()) { // The owner is unreachable. Store an error so that an exception will be // thrown immediately when the worker tries to get the value. - in_memory_store_->Put(RayObject(rpc::ErrorType::OWNER_DIED), object_id); + in_memory_store_->Put(RayObject(rpc::ErrorType::OWNER_DIED), + object_id, + reference_counter_->HasReference(object_id)); } else if (reply.status() == rpc::GetObjectStatusReply::OUT_OF_SCOPE) { // The owner replied that the object has gone out of scope (this is an edge // case in the distributed ref counting protocol where a borrower dies // before it can notify the owner of another borrower). Store an error so // that an exception will be thrown immediately when the worker tries to // get the value. - in_memory_store_->Put(RayObject(rpc::ErrorType::OBJECT_DELETED), object_id); + in_memory_store_->Put(RayObject(rpc::ErrorType::OBJECT_DELETED), + object_id, + reference_counter_->HasReference(object_id)); } else if (reply.status() == rpc::GetObjectStatusReply::CREATED) { // The object is either an indicator that the object is in Plasma, or // the object has been returned directly in the reply. In either @@ -106,7 +110,8 @@ void FutureResolver::ProcessResolvedObject(const ObjectID &object_id, inlined_ref.owner_address()); } in_memory_store_->Put(RayObject(data_buffer, metadata_buffer, inlined_refs), - object_id); + object_id, + reference_counter_->HasReference(object_id)); } } diff --git a/src/ray/core_worker/future_resolver.h b/src/ray/core_worker/future_resolver.h index 04caaeba8ffb..a48471f5a19b 100644 --- a/src/ray/core_worker/future_resolver.h +++ b/src/ray/core_worker/future_resolver.h @@ -18,6 +18,7 @@ #include #include "ray/common/id.h" +#include "ray/core_worker/reference_counter_interface.h" #include "ray/core_worker/store_provider/memory_store/memory_store.h" #include "ray/core_worker_rpc_client/core_worker_client_pool.h" #include "src/ray/protobuf/core_worker.pb.h" diff --git a/src/ray/core_worker/object_recovery_manager.cc b/src/ray/core_worker/object_recovery_manager.cc index 893fb4fabd6d..83a6cad70648 100644 --- a/src/ray/core_worker/object_recovery_manager.cc +++ b/src/ray/core_worker/object_recovery_manager.cc @@ -81,7 +81,9 @@ bool ObjectRecoveryManager::RecoverObject(const ObjectID &object_id) { // (core_worker.cc removes the object from memory store before calling this method), // we need to add it back to indicate that it's available. 
// If the object is already in the memory store then the put is a no-op. - in_memory_store_.Put(RayObject(rpc::ErrorType::OBJECT_IN_PLASMA), object_id); + in_memory_store_.Put(RayObject(rpc::ErrorType::OBJECT_IN_PLASMA), + object_id, + reference_counter_.HasReference(object_id)); } return true; } @@ -121,7 +123,8 @@ void ObjectRecoveryManager::PinExistingObjectCopy( const Status &status, const rpc::PinObjectIDsReply &reply) mutable { if (status.ok() && reply.successes(0)) { in_memory_store_.Put(RayObject(rpc::ErrorType::OBJECT_IN_PLASMA), - object_id); + object_id, + reference_counter_.HasReference(object_id)); reference_counter_.UpdateObjectPinnedAtRaylet(object_id, node_id); } else { RAY_LOG(INFO).WithField(object_id) diff --git a/src/ray/core_worker/store_provider/memory_store/memory_store.cc b/src/ray/core_worker/store_provider/memory_store/memory_store.cc index bd7e2a63deb5..6c11d7ee50c0 100644 --- a/src/ray/core_worker/store_provider/memory_store/memory_store.cc +++ b/src/ray/core_worker/store_provider/memory_store/memory_store.cc @@ -42,7 +42,6 @@ class GetRequest { public: GetRequest(absl::flat_hash_set object_ids, size_t num_objects, - bool remove_after_get, bool abort_if_any_object_is_exception); const absl::flat_hash_set &ObjectIds() const; @@ -56,8 +55,6 @@ class GetRequest { void Set(const ObjectID &object_id, std::shared_ptr buffer); /// Get the object content for the specific object id. std::shared_ptr Get(const ObjectID &object_id) const; - /// Whether this is a `get` request. - bool ShouldRemoveObjects() const; private: /// The object IDs involved in this request. @@ -67,9 +64,6 @@ class GetRequest { /// Number of objects required. const size_t num_objects_; - // Whether the requested objects should be removed from store - // after `get` returns. - const bool remove_after_get_; // Whether we should abort the waiting if any object is an exception. const bool abort_if_any_object_is_exception_; // Whether all the requested objects are available. 
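With remove_after_get gone from GetRequest, the public Get() below drops the flag as well; removal is now driven by reference counting alone. A before/after sketch of a caller, assuming a memory_store pointer, a std::vector<ObjectID> ids, and a WorkerContext ctx (mirroring the test updates later in this patch):

    std::vector<std::shared_ptr<RayObject>> results;
    // Before: memory_store->Get(ids, 1, 100, ctx, /*remove_after_get=*/false, &results);
    // After: entries persist until their references are released.
    RAY_UNUSED(memory_store->Get(ids, /*num_objects=*/1, /*timeout_ms=*/100, ctx, &results));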
@@ -80,19 +74,15 @@ class GetRequest { GetRequest::GetRequest(absl::flat_hash_set object_ids, size_t num_objects, - bool remove_after_get, bool abort_if_any_object_is_exception) : object_ids_(std::move(object_ids)), num_objects_(num_objects), - remove_after_get_(remove_after_get), abort_if_any_object_is_exception_(abort_if_any_object_is_exception) { RAY_CHECK(num_objects_ <= object_ids_.size()); } const absl::flat_hash_set &GetRequest::ObjectIds() const { return object_ids_; } -bool GetRequest::ShouldRemoveObjects() const { return remove_after_get_; } - bool GetRequest::Wait(int64_t timeout_ms) { RAY_CHECK(timeout_ms >= 0 || timeout_ms == -1); if (timeout_ms == -1) { @@ -137,14 +127,14 @@ std::shared_ptr GetRequest::Get(const ObjectID &object_id) const { CoreWorkerMemoryStore::CoreWorkerMemoryStore( instrumented_io_context &io_context, - ReferenceCounterInterface *counter, + bool reference_counting, std::shared_ptr raylet_ipc_client, std::function check_signals, std::function unhandled_exception_handler, std::function( const ray::RayObject &object, const ObjectID &object_id)> object_allocator) : io_context_(io_context), - ref_counter_(counter), + reference_counting_(reference_counting), raylet_ipc_client_(std::move(raylet_ipc_client)), check_signals_(std::move(check_signals)), unhandled_exception_handler_(std::move(unhandled_exception_handler)), @@ -180,7 +170,9 @@ std::shared_ptr CoreWorkerMemoryStore::GetIfExists(const ObjectID &ob return ptr; } -void CoreWorkerMemoryStore::Put(const RayObject &object, const ObjectID &object_id) { +void CoreWorkerMemoryStore::Put(const RayObject &object, + const ObjectID &object_id, + const bool has_reference) { std::vector)>> async_callbacks; RAY_LOG(DEBUG).WithField(object_id) << "Putting object into memory store."; std::shared_ptr object_entry = nullptr; @@ -217,14 +209,10 @@ void CoreWorkerMemoryStore::Put(const RayObject &object, const ObjectID &object_ auto &get_requests = object_request_iter->second; for (auto &get_request : get_requests) { get_request->Set(object_id, object_entry); - // If ref counting is enabled, override the removal behaviour. - if (get_request->ShouldRemoveObjects() && ref_counter_ == nullptr) { - should_add_entry = false; - } } } // Don't put it in the store, since we won't get a callback for deletion. 
- if (ref_counter_ != nullptr && !ref_counter_->HasReference(object_id)) { + if (reference_counting_ && !has_reference) { should_add_entry = false; } @@ -261,13 +249,11 @@ Status CoreWorkerMemoryStore::Get(const std::vector &object_ids, int num_objects, int64_t timeout_ms, const WorkerContext &ctx, - bool remove_after_get, std::vector> *results) { return GetImpl(object_ids, num_objects, timeout_ms, ctx, - remove_after_get, results, /*abort_if_any_object_is_exception=*/true, /*at_most_num_objects=*/true); @@ -277,7 +263,6 @@ Status CoreWorkerMemoryStore::GetImpl(const std::vector &object_ids, int num_objects, int64_t timeout_ms, const WorkerContext &ctx, - bool remove_after_get, std::vector> *results, bool abort_if_any_object_is_exception, bool at_most_num_objects) { @@ -288,7 +273,6 @@ Status CoreWorkerMemoryStore::GetImpl(const std::vector &object_ids, { absl::flat_hash_set remaining_ids; - absl::flat_hash_set ids_to_remove; bool existing_objects_has_exception = false; absl::MutexLock lock(&mu_); @@ -299,11 +283,6 @@ Status CoreWorkerMemoryStore::GetImpl(const std::vector &object_ids, if (iter != objects_.end()) { iter->second->SetAccessed(); (*results)[i] = iter->second; - if (remove_after_get) { - // Note that we cannot remove the object_id from `objects_` now, - // because `object_ids` might have duplicate ids. - ids_to_remove.insert(object_id); - } num_found += 1; if (abort_if_any_object_is_exception && iter->second->IsException() && !iter->second->IsInPlasmaError()) { @@ -318,13 +297,6 @@ Status CoreWorkerMemoryStore::GetImpl(const std::vector &object_ids, } } - // Clean up the objects if ref counting is off. - if (ref_counter_ == nullptr) { - for (const auto &object_id : ids_to_remove) { - EraseObjectAndUpdateStats(object_id); - } - } - // Return if all the objects are obtained, or any existing objects are known to have // exception. if (remaining_ids.empty() || num_found >= num_objects || @@ -335,10 +307,8 @@ Status CoreWorkerMemoryStore::GetImpl(const std::vector &object_ids, size_t required_objects = num_objects - num_found; // Otherwise, create a GetRequest to track remaining objects. 
-    get_request = std::make_shared<GetRequest>(std::move(remaining_ids),
-                                               required_objects,
-                                               remove_after_get,
-                                               abort_if_any_object_is_exception);
+    get_request = std::make_shared<GetRequest>(
+        std::move(remaining_ids), required_objects, abort_if_any_object_is_exception);
     for (const auto &object_id : get_request->ObjectIds()) {
       object_get_requests_[object_id].push_back(get_request);
     }
@@ -426,12 +396,7 @@ Status CoreWorkerMemoryStore::Get(
     bool *got_exception) {
   const std::vector<ObjectID> id_vector(object_ids.begin(), object_ids.end());
   std::vector<std::shared_ptr<RayObject>> result_objects;
-  RAY_RETURN_NOT_OK(Get(id_vector,
-                        id_vector.size(),
-                        timeout_ms,
-                        ctx,
-                        /*remove_after_get=*/false,
-                        &result_objects));
+  RAY_RETURN_NOT_OK(Get(id_vector, id_vector.size(), timeout_ms, ctx, &result_objects));
 
   for (size_t i = 0; i < id_vector.size(); i++) {
     if (result_objects[i] != nullptr) {
@@ -460,7 +425,6 @@ Status CoreWorkerMemoryStore::Wait(const absl::flat_hash_set<ObjectID> &object_ids,
                                    num_objects,
                                    timeout_ms,
                                    ctx,
-                                   false,
                                    &result_objects,
                                    /*abort_if_any_object_is_exception=*/false,
                                    /*at_most_num_objects=*/false);
diff --git a/src/ray/core_worker/store_provider/memory_store/memory_store.h b/src/ray/core_worker/store_provider/memory_store/memory_store.h
index e305a64b0d01..33b337087c2b 100644
--- a/src/ray/core_worker/store_provider/memory_store/memory_store.h
+++ b/src/ray/core_worker/store_provider/memory_store/memory_store.h
@@ -26,7 +26,6 @@
 #include "ray/common/id.h"
 #include "ray/common/status.h"
 #include "ray/core_worker/context.h"
-#include "ray/core_worker/reference_counter_interface.h"
 #include "ray/raylet_ipc_client/raylet_ipc_client_interface.h"
 #include "ray/rpc/utils.h"
 
@@ -49,12 +48,10 @@ class CoreWorkerMemoryStore {
   /// Create a memory store.
   ///
   /// \param[in] io_context Posts async callbacks to this context.
-  /// \param[in] counter If not null, this enables ref counting for local objects,
-  /// and the `remove_after_get` flag for Get() will be ignored.
   /// \param[in] raylet_ipc_client If not null, used to notify tasks blocked / unblocked.
   explicit CoreWorkerMemoryStore(
       instrumented_io_context &io_context,
-      ReferenceCounterInterface *counter = nullptr,
+      bool reference_counting,
       std::shared_ptr<RayletIpcClientInterface> raylet_ipc_client = nullptr,
       std::function<Status()> check_signals = nullptr,
       std::function<void(const RayObject &)> unhandled_exception_handler = nullptr,
@@ -68,7 +65,9 @@ class CoreWorkerMemoryStore {
   ///
   /// \param[in] object The ray object.
   /// \param[in] object_id Object ID specified by user.
-  void Put(const RayObject &object, const ObjectID &object_id);
+  /// \param[in] has_reference Whether the object has a reference in the reference
+  /// counter.
+  void Put(const RayObject &object, const ObjectID &object_id, const bool has_reference);
 
   /// Get a list of objects from the object store.
   ///
   /// \param[in] object_ids IDs of the objects to get.
   /// \param[in] num_objects Number of objects that should appear.
   /// \param[in] timeout_ms Timeout in milliseconds, wait infinitely if it's negative.
   /// \param[in] ctx The current worker context.
-  /// \param[in] remove_after_get When to remove the objects from store after `Get`
-  /// finishes. This has no effect if ref counting is enabled.
   /// \param[out] results Result list of objects data.
   /// \return Status.
   Status Get(const std::vector<ObjectID> &object_ids,
              int num_objects,
              int64_t timeout_ms,
              const WorkerContext &ctx,
-             bool remove_after_get,
              std::vector<std::shared_ptr<RayObject>> *results);
 
   /// Convenience wrapper around Get() that stores results in a given result map.
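Construction and population under the new interface, as a short sketch (assuming an instrumented_io_context named io_context and a caller-owned reference counter; later constructor parameters keep their defaults):

    // A bool flag replaces the old ReferenceCounterInterface* parameter.
    auto store = std::make_shared<CoreWorkerMemoryStore>(io_context,
                                                         /*reference_counting=*/true);
    // Callers now report reference state themselves; with reference counting
    // enabled, a Put() with has_reference=false is dropped rather than stored.
    store->Put(obj, object_id, reference_counter.HasReference(object_id));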
@@ -186,7 +182,6 @@ class CoreWorkerMemoryStore { int num_objects, int64_t timeout_ms, const WorkerContext &ctx, - bool remove_after_get, std::vector> *results, bool abort_if_any_object_is_exception, bool at_most_num_objects); @@ -206,9 +201,8 @@ class CoreWorkerMemoryStore { instrumented_io_context &io_context_; - /// If enabled, holds a reference to local worker ref counter. TODO(ekl) make this - /// mandatory once Java is supported. - ReferenceCounterInterface *ref_counter_; + /// Set to true if reference counting is enabled (i.e. not local mode). + bool reference_counting_; // If set, this will be used to notify worker blocked / unblocked on get calls. std::shared_ptr raylet_ipc_client_; diff --git a/src/ray/core_worker/store_provider/plasma_store_provider.cc b/src/ray/core_worker/store_provider/plasma_store_provider.cc index dd0d93897182..4a0bdec644d2 100644 --- a/src/ray/core_worker/store_provider/plasma_store_provider.cc +++ b/src/ray/core_worker/store_provider/plasma_store_provider.cc @@ -64,7 +64,6 @@ BufferTracker::UsedObjects() const { CoreWorkerPlasmaStoreProvider::CoreWorkerPlasmaStoreProvider( const std::string &store_socket, const std::shared_ptr raylet_ipc_client, - ReferenceCounterInterface &reference_counter, std::function check_signals, bool warmup, std::shared_ptr store_client, @@ -72,7 +71,6 @@ CoreWorkerPlasmaStoreProvider::CoreWorkerPlasmaStoreProvider( std::function get_current_call_site) : raylet_ipc_client_(raylet_ipc_client), store_client_(std::move(store_client)), - reference_counter_(reference_counter), check_signals_(std::move(check_signals)), fetch_batch_size_(fetch_batch_size) { if (get_current_call_site != nullptr) { @@ -180,11 +178,12 @@ Status CoreWorkerPlasmaStoreProvider::Release(const ObjectID &object_id) { Status CoreWorkerPlasmaStoreProvider::PullObjectsAndGetFromPlasmaStore( absl::flat_hash_set &remaining, const std::vector &batch_ids, + const std::vector &batch_owner_addresses, int64_t timeout_ms, absl::flat_hash_map> *results, bool *got_exception) { - const auto owner_addresses = reference_counter_.GetOwnerAddresses(batch_ids); - RAY_RETURN_NOT_OK(raylet_ipc_client_->AsyncGetObjects(batch_ids, owner_addresses)); + RAY_RETURN_NOT_OK( + raylet_ipc_client_->AsyncGetObjects(batch_ids, batch_owner_addresses)); std::vector plasma_results; RAY_RETURN_NOT_OK(store_client_->Get(batch_ids, timeout_ms, &plasma_results)); @@ -271,25 +270,31 @@ Status UnblockIfNeeded( } Status CoreWorkerPlasmaStoreProvider::Get( - const absl::flat_hash_set &object_ids, + const absl::flat_hash_map &object_ids, int64_t timeout_ms, const WorkerContext &ctx, absl::flat_hash_map> *results, bool *got_exception) { - std::vector batch_ids; - absl::flat_hash_set remaining(object_ids.begin(), object_ids.end()); + absl::flat_hash_set remaining; // Send initial requests to pull all objects in parallel. 
- std::vector id_vector(object_ids.begin(), object_ids.end()); + std::vector> id_vector(object_ids.begin(), + object_ids.end()); int64_t total_size = static_cast(object_ids.size()); for (int64_t start = 0; start < total_size; start += fetch_batch_size_) { - batch_ids.clear(); + std::vector batch_ids; + std::vector batch_owner_addresses; for (int64_t i = start; i < start + fetch_batch_size_ && i < total_size; i++) { - batch_ids.push_back(id_vector[i]); + // Construct remaining set a batch at a time + remaining.insert(id_vector[i].first); + + batch_ids.push_back(id_vector[i].first); + batch_owner_addresses.push_back(id_vector[i].second); } RAY_RETURN_NOT_OK( PullObjectsAndGetFromPlasmaStore(remaining, batch_ids, + batch_owner_addresses, /*timeout_ms=*/0, // Mutable objects must be local before ray.get. results, @@ -310,12 +315,14 @@ Status CoreWorkerPlasmaStoreProvider::Get( int64_t remaining_timeout = timeout_ms; auto fetch_start_time_ms = current_time_ms(); while (!remaining.empty() && !should_break) { - batch_ids.clear(); + std::vector batch_ids; + std::vector batch_owner_addresses; for (const auto &id : remaining) { if (static_cast(batch_ids.size()) == fetch_batch_size_) { break; } batch_ids.push_back(id); + batch_owner_addresses.push_back(object_ids.find(id)->second); } int64_t batch_timeout = @@ -328,8 +335,12 @@ Status CoreWorkerPlasmaStoreProvider::Get( } size_t previous_size = remaining.size(); - RAY_RETURN_NOT_OK(PullObjectsAndGetFromPlasmaStore( - remaining, batch_ids, batch_timeout, results, got_exception)); + RAY_RETURN_NOT_OK(PullObjectsAndGetFromPlasmaStore(remaining, + batch_ids, + batch_owner_addresses, + batch_timeout, + results, + got_exception)); should_break = timed_out || *got_exception; if ((previous_size - remaining.size()) < batch_ids.size()) { @@ -369,12 +380,18 @@ Status CoreWorkerPlasmaStoreProvider::Contains(const ObjectID &object_id, } Status CoreWorkerPlasmaStoreProvider::Wait( - const absl::flat_hash_set &object_ids, + const absl::flat_hash_map &object_ids, int num_objects, int64_t timeout_ms, const WorkerContext &ctx, absl::flat_hash_set *ready) { - std::vector id_vector(object_ids.begin(), object_ids.end()); + // Construct object ids vector and owner addresses vector + std::vector object_id_vector; + std::vector owner_addresses; + for (const auto &[object_id, owner_address] : object_ids) { + object_id_vector.push_back(object_id); + owner_addresses.push_back(owner_address); + } bool should_break = false; int64_t remaining_timeout = timeout_ms; @@ -387,10 +404,10 @@ Status CoreWorkerPlasmaStoreProvider::Wait( should_break = remaining_timeout <= 0; } - const auto owner_addresses = reference_counter_.GetOwnerAddresses(id_vector); RAY_ASSIGN_OR_RETURN( ready_in_plasma, - raylet_ipc_client_->Wait(id_vector, owner_addresses, num_objects, call_timeout)); + raylet_ipc_client_->Wait( + object_id_vector, owner_addresses, num_objects, call_timeout)); if (ready_in_plasma.size() >= static_cast(num_objects)) { should_break = true; diff --git a/src/ray/core_worker/store_provider/plasma_store_provider.h b/src/ray/core_worker/store_provider/plasma_store_provider.h index c2528a25c771..a11028827812 100644 --- a/src/ray/core_worker/store_provider/plasma_store_provider.h +++ b/src/ray/core_worker/store_provider/plasma_store_provider.h @@ -26,7 +26,6 @@ #include "ray/common/status.h" #include "ray/common/status_or.h" #include "ray/core_worker/context.h" -#include "ray/core_worker/reference_counter_interface.h" #include "ray/object_manager/plasma/client.h" #include 
"ray/raylet_ipc_client/raylet_ipc_client_interface.h" #include "src/ray/protobuf/common.pb.h" @@ -96,7 +95,6 @@ class CoreWorkerPlasmaStoreProvider { CoreWorkerPlasmaStoreProvider( const std::string &store_socket, const std::shared_ptr raylet_ipc_client, - ReferenceCounterInterface &reference_counter, std::function check_signals, bool warmup, std::shared_ptr store_client, @@ -154,7 +152,7 @@ class CoreWorkerPlasmaStoreProvider { /// argument to Get to retrieve the object data. Status Release(const ObjectID &object_id); - Status Get(const absl::flat_hash_set &object_ids, + Status Get(const absl::flat_hash_map &object_ids, int64_t timeout_ms, const WorkerContext &ctx, absl::flat_hash_map> *results, @@ -187,7 +185,7 @@ class CoreWorkerPlasmaStoreProvider { Status Contains(const ObjectID &object_id, bool *has_object); - Status Wait(const absl::flat_hash_set &object_ids, + Status Wait(const absl::flat_hash_map &object_ids, int num_objects, int64_t timeout_ms, const WorkerContext &ctx, @@ -221,6 +219,7 @@ class CoreWorkerPlasmaStoreProvider { Status PullObjectsAndGetFromPlasmaStore( absl::flat_hash_set &remaining, const std::vector &batch_ids, + const std::vector &batch_owner_addresses, int64_t timeout_ms, absl::flat_hash_map> *results, bool *got_exception); @@ -238,8 +237,6 @@ class CoreWorkerPlasmaStoreProvider { const std::shared_ptr raylet_ipc_client_; std::shared_ptr store_client_; - /// Used to look up a plasma object's owner. - ReferenceCounterInterface &reference_counter_; std::function check_signals_; std::function get_current_call_site_; uint32_t object_store_full_delay_ms_; diff --git a/src/ray/core_worker/task_manager.cc b/src/ray/core_worker/task_manager.cc index 4de0318afea5..503b6283e133 100644 --- a/src/ray/core_worker/task_manager.cc +++ b/src/ray/core_worker/task_manager.cc @@ -561,7 +561,9 @@ StatusOr TaskManager::HandleTaskReturn(const ObjectID &object_id, // will choose the right raylet for any queued dependent tasks. reference_counter_.UpdateObjectPinnedAtRaylet(object_id, worker_node_id); // Mark it as in plasma with a dummy object. - in_memory_store_.Put(RayObject(rpc::ErrorType::OBJECT_IN_PLASMA), object_id); + in_memory_store_.Put(RayObject(rpc::ErrorType::OBJECT_IN_PLASMA), + object_id, + reference_counter_.HasReference(object_id)); } else { // NOTE(swang): If a direct object was promoted to plasma, then we do not // record the node ID that it was pinned at, which means that we will not @@ -595,7 +597,7 @@ StatusOr TaskManager::HandleTaskReturn(const ObjectID &object_id, return s; } } else { - in_memory_store_.Put(object, object_id); + in_memory_store_.Put(object, object_id, reference_counter_.HasReference(object_id)); direct_return = true; } } @@ -764,7 +766,8 @@ void TaskManager::MarkEndOfStream(const ObjectID &generator_id, // Put a dummy object at the end of the stream. We don't need to check if // the object should be stored in plasma because the end of the stream is a // fake ObjectRef that should never be read by the application. 
- in_memory_store_.Put(error, last_object_id); + in_memory_store_.Put( + error, last_object_id, reference_counter_.HasReference(last_object_id)); } } @@ -1551,10 +1554,11 @@ void TaskManager::MarkTaskReturnObjectsFailed( if (!s.ok()) { RAY_LOG(WARNING).WithField(object_id) << "Failed to put error object in plasma: " << s; - in_memory_store_.Put(error, object_id); + in_memory_store_.Put( + error, object_id, reference_counter_.HasReference(object_id)); } } else { - in_memory_store_.Put(error, object_id); + in_memory_store_.Put(error, object_id, reference_counter_.HasReference(object_id)); } } if (spec.ReturnsDynamic()) { @@ -1564,10 +1568,13 @@ void TaskManager::MarkTaskReturnObjectsFailed( if (!s.ok()) { RAY_LOG(WARNING).WithField(dynamic_return_id) << "Failed to put error object in plasma: " << s; - in_memory_store_.Put(error, dynamic_return_id); + in_memory_store_.Put(error, + dynamic_return_id, + reference_counter_.HasReference(dynamic_return_id)); } } else { - in_memory_store_.Put(error, dynamic_return_id); + in_memory_store_.Put( + error, dynamic_return_id, reference_counter_.HasReference(dynamic_return_id)); } } } @@ -1594,10 +1601,14 @@ void TaskManager::MarkTaskReturnObjectsFailed( if (!s.ok()) { RAY_LOG(WARNING).WithField(generator_return_id) << "Failed to put error object in plasma: " << s; - in_memory_store_.Put(error, generator_return_id); + in_memory_store_.Put(error, + generator_return_id, + reference_counter_.HasReference(generator_return_id)); } } else { - in_memory_store_.Put(error, generator_return_id); + in_memory_store_.Put(error, + generator_return_id, + reference_counter_.HasReference(generator_return_id)); } } } diff --git a/src/ray/core_worker/task_manager.h b/src/ray/core_worker/task_manager.h index 5be416b9552b..dbf62f4ff3f6 100644 --- a/src/ray/core_worker/task_manager.h +++ b/src/ray/core_worker/task_manager.h @@ -27,6 +27,7 @@ #include "absl/synchronization/mutex.h" #include "ray/common/id.h" #include "ray/common/status.h" +#include "ray/core_worker/reference_counter_interface.h" #include "ray/core_worker/store_provider/memory_store/memory_store.h" #include "ray/core_worker/task_event_buffer.h" #include "ray/core_worker/task_manager_interface.h" diff --git a/src/ray/core_worker/task_submission/BUILD.bazel b/src/ray/core_worker/task_submission/BUILD.bazel index 387fba21552c..e8b9cd096993 100644 --- a/src/ray/core_worker/task_submission/BUILD.bazel +++ b/src/ray/core_worker/task_submission/BUILD.bazel @@ -70,6 +70,7 @@ ray_cc_library( "//src/ray/common:id", "//src/ray/common:protobuf_utils", "//src/ray/core_worker:actor_creator", + "//src/ray/core_worker:reference_counter_interface", "//src/ray/core_worker_rpc_client:core_worker_client_pool", "//src/ray/rpc:rpc_callback_types", "//src/ray/util:time", diff --git a/src/ray/core_worker/task_submission/actor_task_submitter.h b/src/ray/core_worker/task_submission/actor_task_submitter.h index f225397768be..ada84d9a0d25 100644 --- a/src/ray/core_worker/task_submission/actor_task_submitter.h +++ b/src/ray/core_worker/task_submission/actor_task_submitter.h @@ -26,6 +26,7 @@ #include "absl/synchronization/mutex.h" #include "ray/common/id.h" #include "ray/core_worker/actor_creator.h" +#include "ray/core_worker/reference_counter_interface.h" #include "ray/core_worker/store_provider/memory_store/memory_store.h" #include "ray/core_worker/task_submission/actor_submit_queue.h" #include "ray/core_worker/task_submission/dependency_resolver.h" diff --git a/src/ray/core_worker/task_submission/tests/actor_task_submitter_test.cc 
b/src/ray/core_worker/task_submission/tests/actor_task_submitter_test.cc index e1536ef89785..4bb0c9a41ce3 100644 --- a/src/ray/core_worker/task_submission/tests/actor_task_submitter_test.cc +++ b/src/ray/core_worker/task_submission/tests/actor_task_submitter_test.cc @@ -89,7 +89,8 @@ class ActorTaskSubmitterTest : public ::testing::TestWithParam { : client_pool_(std::make_shared( [&](const rpc::Address &addr) { return worker_client_; })), worker_client_(std::make_shared()), - store_(std::make_shared(io_context)), + store_(std::make_shared(io_context, + /*reference_counting=*/true)), task_manager_(std::make_shared()), io_work(io_context.get_executor()), reference_counter_(std::make_shared()), @@ -232,11 +233,11 @@ TEST_P(ActorTaskSubmitterTest, TestDependencies) { auto data = GenerateRandomObject(); // Each Put schedules a callback onto io_context, and let's run it. - store_->Put(*data, obj1); + store_->Put(*data, obj1, reference_counter_->HasReference(obj1)); ASSERT_EQ(io_context.poll_one(), 1); ASSERT_EQ(worker_client_->callbacks.size(), 1); - store_->Put(*data, obj2); + store_->Put(*data, obj2, reference_counter_->HasReference(obj2)); ASSERT_EQ(io_context.poll_one(), 1); ASSERT_EQ(worker_client_->callbacks.size(), 2); @@ -280,12 +281,12 @@ TEST_P(ActorTaskSubmitterTest, TestOutOfOrderDependencies) { // submission. auto data = GenerateRandomObject(); // task2 is submitted first as we allow out of order execution. - store_->Put(*data, obj2); + store_->Put(*data, obj2, reference_counter_->HasReference(obj2)); ASSERT_EQ(io_context.poll_one(), 1); ASSERT_EQ(worker_client_->callbacks.size(), 1); ASSERT_THAT(worker_client_->received_seq_nos, ElementsAre(1)); // then task1 is submitted - store_->Put(*data, obj1); + store_->Put(*data, obj1, reference_counter_->HasReference(obj1)); ASSERT_EQ(io_context.poll_one(), 1); ASSERT_EQ(worker_client_->callbacks.size(), 2); ASSERT_THAT(worker_client_->received_seq_nos, ElementsAre(1, 0)); @@ -293,10 +294,10 @@ TEST_P(ActorTaskSubmitterTest, TestOutOfOrderDependencies) { // Put the dependencies in the store in the opposite order of task // submission. 
auto data = GenerateRandomObject(); - store_->Put(*data, obj2); + store_->Put(*data, obj2, reference_counter_->HasReference(obj2)); ASSERT_EQ(io_context.poll_one(), 1); ASSERT_EQ(worker_client_->callbacks.size(), 0); - store_->Put(*data, obj1); + store_->Put(*data, obj1, reference_counter_->HasReference(obj1)); ASSERT_EQ(io_context.poll_one(), 1); ASSERT_EQ(worker_client_->callbacks.size(), 2); ASSERT_THAT(worker_client_->received_seq_nos, ElementsAre(0, 1)); diff --git a/src/ray/core_worker/task_submission/tests/dependency_resolver_test.cc b/src/ray/core_worker/task_submission/tests/dependency_resolver_test.cc index e9766aec1281..1f116e1dcaa9 100644 --- a/src/ray/core_worker/task_submission/tests/dependency_resolver_test.cc +++ b/src/ray/core_worker/task_submission/tests/dependency_resolver_test.cc @@ -187,7 +187,7 @@ TEST(LocalDependencyResolverTest, TestActorAndObjectDependencies1) { auto metadata = const_cast(reinterpret_cast(meta.data())); auto meta_buffer = std::make_shared(metadata, meta.size()); auto data = RayObject(nullptr, meta_buffer, std::vector()); - store->Put(data, obj); + store->Put(data, obj, /*has_reference=*/true); // Wait for the async callback to call ASSERT_TRUE(dependencies_resolved.get_future().get()); ASSERT_EQ(num_resolved, 1); @@ -228,7 +228,7 @@ TEST(LocalDependencyResolverTest, TestActorAndObjectDependencies2) { auto meta_buffer = std::make_shared(metadata, meta.size()); auto data = RayObject(nullptr, meta_buffer, std::vector()); ASSERT_EQ(num_resolved, 0); - store->Put(data, obj); + store->Put(data, obj, /*has_reference=*/true); for (const auto &cb : actor_creator.callbacks) { cb(Status()); @@ -253,7 +253,7 @@ TEST(LocalDependencyResolverTest, TestHandlePlasmaPromotion) { auto metadata = const_cast(reinterpret_cast(meta.data())); auto meta_buffer = std::make_shared(metadata, meta.size()); auto data = RayObject(nullptr, meta_buffer, std::vector()); - store->Put(data, obj1); + store->Put(data, obj1, /*has_reference=*/true); TaskSpecification task; task.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id(obj1.Binary()); bool ok = false; @@ -282,8 +282,8 @@ TEST(LocalDependencyResolverTest, TestInlineLocalDependencies) { ObjectID obj2 = ObjectID::FromRandom(); auto data = GenerateRandomObject(); // Ensure the data is already present in the local store. 
- store->Put(*data, obj1); - store->Put(*data, obj2); + store->Put(*data, obj1, /*has_reference=*/true); + store->Put(*data, obj2, /*has_reference=*/true); TaskSpecification task; task.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id(obj1.Binary()); task.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id(obj2.Binary()); @@ -326,8 +326,8 @@ TEST(LocalDependencyResolverTest, TestInlinePendingDependencies) { }); ASSERT_EQ(resolver.NumPendingTasks(), 1); ASSERT_TRUE(!ok); - store->Put(*data, obj1); - store->Put(*data, obj2); + store->Put(*data, obj1, /*has_reference=*/true); + store->Put(*data, obj2, /*has_reference=*/true); ASSERT_TRUE(dependencies_resolved.get_future().get()); // Tests that the task proto was rewritten to have inline argument values after @@ -365,8 +365,8 @@ TEST(LocalDependencyResolverTest, TestInlinedObjectIds) { }); ASSERT_EQ(resolver.NumPendingTasks(), 1); ASSERT_TRUE(!ok); - store->Put(*data, obj1); - store->Put(*data, obj2); + store->Put(*data, obj1, /*has_reference=*/true); + store->Put(*data, obj2, /*has_reference=*/true); ASSERT_TRUE(dependencies_resolved.get_future().get()); // Tests that the task proto was rewritten to have inline argument values after @@ -383,7 +383,9 @@ TEST(LocalDependencyResolverTest, TestInlinedObjectIds) { TEST(LocalDependencyResolverTest, TestCancelDependencyResolution) { InstrumentedIOContextWithThread io_context("TestCancelDependencyResolution"); - auto store = std::make_shared(io_context.GetIoService()); + // Mock reference counter as enabled + auto store = std::make_shared(io_context.GetIoService(), + /*reference_counting=*/true); auto task_manager = std::make_shared(); FakeActorCreator actor_creator; LocalDependencyResolver resolver( @@ -400,7 +402,7 @@ TEST(LocalDependencyResolverTest, TestCancelDependencyResolution) { resolver.ResolveDependencies(task, [&ok](Status) { ok = true; }); ASSERT_EQ(resolver.NumPendingTasks(), 1); ASSERT_TRUE(!ok); - store->Put(*data, obj1); + store->Put(*data, obj1, /*has_reference=*/true); ASSERT_TRUE(resolver.CancelDependencyResolution(task.TaskId())); // Callback is not called. 
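The test updates in this file and the ones that follow share one convention, sketched here (store construction as in TestCancelDependencyResolution above): without a real ReferenceCounter, tests enable counting at construction and assert the reference explicitly so Put() retains the entry.

    auto store = std::make_shared<CoreWorkerMemoryStore>(io_context.GetIoService(),
                                                         /*reference_counting=*/true);
    store->Put(*data, obj1, /*has_reference=*/true);  // kept; false would drop it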
@@ -428,7 +430,7 @@ TEST(LocalDependencyResolverTest, TestDependenciesAlreadyLocal) { ObjectID obj = ObjectID::FromRandom(); auto data = GenerateRandomObject(); - store->Put(*data, obj); + store->Put(*data, obj, /*has_reference=*/true); TaskSpecification task; task.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id(obj.Binary()); @@ -471,8 +473,8 @@ TEST(LocalDependencyResolverTest, TestMixedTensorTransport) { }); auto data = GenerateRandomObject(); - store->Put(*data, obj1); - store->Put(*data, obj2); + store->Put(*data, obj1, /*has_reference=*/true); + store->Put(*data, obj2, /*has_reference=*/true); TaskSpecification task; task.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id(obj1.Binary()); diff --git a/src/ray/core_worker/task_submission/tests/normal_task_submitter_test.cc b/src/ray/core_worker/task_submission/tests/normal_task_submitter_test.cc index 29338e043751..79250513bd20 100644 --- a/src/ray/core_worker/task_submission/tests/normal_task_submitter_test.cc +++ b/src/ray/core_worker/task_submission/tests/normal_task_submitter_test.cc @@ -1512,7 +1512,9 @@ void TestSchedulingKey(const std::shared_ptr store, TEST(NormalTaskSubmitterSchedulingKeyTest, TestSchedulingKeys) { InstrumentedIOContextWithThread io_context("TestSchedulingKeys"); - auto memory_store = std::make_shared(io_context.GetIoService()); + // Mock reference counter as enabled + auto memory_store = std::make_shared( + io_context.GetIoService(), /*reference_counting=*/true); std::unordered_map resources1({{"a", 1.0}}); std::unordered_map resources2({{"b", 2.0}}); @@ -1555,16 +1557,16 @@ TEST(NormalTaskSubmitterSchedulingKeyTest, TestSchedulingKeys) { ObjectID plasma2 = ObjectID::FromRandom(); // Ensure the data is already present in the local store for direct call objects. auto data = GenerateRandomObject(); - memory_store->Put(*data, direct1); - memory_store->Put(*data, direct2); + memory_store->Put(*data, direct1, /*has_reference=*/true); + memory_store->Put(*data, direct2, /*has_reference=*/true); // Force plasma objects to be promoted. 
std::string meta = std::to_string(static_cast(rpc::ErrorType::OBJECT_IN_PLASMA)); auto metadata = const_cast(reinterpret_cast(meta.data())); auto meta_buffer = std::make_shared(metadata, meta.size()); auto plasma_data = RayObject(nullptr, meta_buffer, std::vector()); - memory_store->Put(plasma_data, plasma1); - memory_store->Put(plasma_data, plasma2); + memory_store->Put(plasma_data, plasma1, /*has_reference=*/true); + memory_store->Put(plasma_data, plasma2, /*has_reference=*/true); TaskSpecification same_deps_1 = BuildTaskSpec(resources1, descriptor1); same_deps_1.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id( @@ -1595,8 +1597,9 @@ TEST(NormalTaskSubmitterSchedulingKeyTest, TestSchedulingKeys) { TEST_F(NormalTaskSubmitterTest, TestBacklogReport) { InstrumentedIOContextWithThread store_io_context("TestBacklogReport"); - auto memory_store = - std::make_shared(store_io_context.GetIoService()); + // Mock reference counter as enabled + auto memory_store = std::make_shared( + store_io_context.GetIoService(), /*reference_counting=*/true); auto submitter = CreateNormalTaskSubmitter(std::make_shared(1), WorkerType::WORKER, @@ -1618,8 +1621,8 @@ TEST_F(NormalTaskSubmitterTest, TestBacklogReport) { auto metadata = const_cast(reinterpret_cast(meta.data())); auto meta_buffer = std::make_shared(metadata, meta.size()); auto plasma_data = RayObject(nullptr, meta_buffer, std::vector()); - memory_store->Put(plasma_data, plasma1); - memory_store->Put(plasma_data, plasma2); + memory_store->Put(plasma_data, plasma1, /*has_reference=*/true); + memory_store->Put(plasma_data, plasma2, /*has_reference=*/true); // Same SchedulingClass, different SchedulingKey TaskSpecification task2 = BuildTaskSpec(resources1, descriptor1); @@ -1786,7 +1789,7 @@ TEST_F(NormalTaskSubmitterTest, TestKillResolvingTask) { ASSERT_EQ(task_manager->num_inlined_dependencies, 0); submitter.CancelTask(task, true, false); auto data = GenerateRandomObject(); - store->Put(*data, obj1); + store->Put(*data, obj1, /*has_reference=*/true); WaitForObjectIdInMemoryStore(*store, obj1); ASSERT_EQ(worker_client->kill_requests.size(), 0); ASSERT_EQ(worker_client->callbacks.size(), 0); diff --git a/src/ray/core_worker/tests/core_worker_test.cc b/src/ray/core_worker/tests/core_worker_test.cc index 452fa7390dae..a5bd0069f48b 100644 --- a/src/ray/core_worker/tests/core_worker_test.cc +++ b/src/ray/core_worker/tests/core_worker_test.cc @@ -153,8 +153,9 @@ class CoreWorkerTest : public ::testing::Test { [](const NodeID &) { return false; }, false); + // Mock reference counter as enabled memory_store_ = std::make_shared( - io_service_, reference_counter_.get(), nullptr); + io_service_, reference_counter_ != nullptr, nullptr); auto future_resolver = std::make_unique( memory_store_, @@ -336,7 +337,7 @@ TEST_F(CoreWorkerTest, HandleGetObjectStatusIdempotency) { owner_address.set_worker_id(core_worker_->GetWorkerID().Binary()); reference_counter_->AddOwnedObject(object_id, {}, owner_address, "", 0, false, true); - memory_store_->Put(*ray_object, object_id); + memory_store_->Put(*ray_object, object_id, reference_counter_->HasReference(object_id)); rpc::GetObjectStatusRequest request; request.set_object_id(object_id.Binary()); @@ -404,7 +405,7 @@ TEST_F(CoreWorkerTest, HandleGetObjectStatusObjectPutAfterFirstRequest) { // Verify that the callback hasn't been called yet since the object doesn't exist ASSERT_FALSE(io_service_.poll_one()); - memory_store_->Put(*ray_object, object_id); + memory_store_->Put(*ray_object, object_id, 
reference_counter_->HasReference(object_id)); io_service_.run_one(); @@ -441,7 +442,7 @@ TEST_F(CoreWorkerTest, HandleGetObjectStatusObjectFreedBetweenRequests) { owner_address.set_worker_id(core_worker_->GetWorkerID().Binary()); reference_counter_->AddOwnedObject(object_id, {}, owner_address, "", 0, false, true); - memory_store_->Put(*ray_object, object_id); + memory_store_->Put(*ray_object, object_id, reference_counter_->HasReference(object_id)); rpc::GetObjectStatusRequest request; request.set_object_id(object_id.Binary()); @@ -491,7 +492,7 @@ TEST_F(CoreWorkerTest, HandleGetObjectStatusObjectOutOfScope) { owner_address.set_worker_id(core_worker_->GetWorkerID().Binary()); reference_counter_->AddOwnedObject(object_id, {}, owner_address, "", 0, false, true); - memory_store_->Put(*ray_object, object_id); + memory_store_->Put(*ray_object, object_id, reference_counter_->HasReference(object_id)); rpc::GetObjectStatusRequest request; request.set_object_id(object_id.Binary()); @@ -556,7 +557,9 @@ ObjectID CreateInlineObjectInMemoryStoreAndRefCounter( /*object_size=*/100, /*is_reconstructable=*/false, /*add_local_ref=*/true); - memory_store.Put(memory_store_object, inlined_dependency_id); + memory_store.Put(memory_store_object, + inlined_dependency_id, + reference_counter.HasReference(inlined_dependency_id)); return inlined_dependency_id; } } // namespace @@ -639,7 +642,6 @@ TEST(BatchingPassesTwoTwoOneIntoPlasmaGet, CallsPlasmaGetInCorrectBatches) { CoreWorkerPlasmaStoreProvider provider( /*store_socket=*/"", fake_raylet, - ref_counter, /*check_signals=*/[] { return Status::OK(); }, /*warmup=*/false, /*store_client=*/fake_plasma, @@ -649,13 +651,20 @@ TEST(BatchingPassesTwoTwoOneIntoPlasmaGet, CallsPlasmaGetInCorrectBatches) { // Build a set of 5 object ids. std::vector ids; for (int i = 0; i < 5; i++) ids.push_back(ObjectID::FromRandom()); - absl::flat_hash_set idset(ids.begin(), ids.end()); + + // Prepare object ids map + const auto owner_addresses = ref_counter.GetOwnerAddresses(ids); + absl::flat_hash_map id_map; + for (size_t i = 0; i < ids.size(); i++) { + id_map[ids[i]] = owner_addresses[i]; + } absl::flat_hash_map> results; bool got_exception = false; WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0)); - ASSERT_TRUE(provider.Get(idset, /*timeout_ms=*/-1, ctx, &results, &got_exception).ok()); + ASSERT_TRUE( + provider.Get(id_map, /*timeout_ms=*/-1, ctx, &results, &got_exception).ok()); // Assert: batches seen by plasma Get are [2,2,1]. ASSERT_EQ(observed_batches.size(), 3U); diff --git a/src/ray/core_worker/tests/memory_store_test.cc b/src/ray/core_worker/tests/memory_store_test.cc index 5a90b26af481..893cf2a3c819 100644 --- a/src/ray/core_worker/tests/memory_store_test.cc +++ b/src/ray/core_worker/tests/memory_store_test.cc @@ -53,7 +53,7 @@ TEST(TestMemoryStore, TestReportUnhandledErrors) { std::shared_ptr memory_store = std::make_shared( io_context.GetIoService(), - nullptr, + /*reference_counting=*/true, nullptr, nullptr, [&](const RayObject &obj) { unhandled_count++; }); @@ -64,8 +64,9 @@ TEST(TestMemoryStore, TestReportUnhandledErrors) { // Check basic put and get. ASSERT_TRUE(memory_store->GetIfExists(id1) == nullptr); - memory_store->Put(obj1, id1); - memory_store->Put(obj2, id2); + // Set has_reference to true to ensure the put doesn't get deleted due to no reference. 
+ memory_store->Put(obj1, id1, /*has_reference=*/true); + memory_store->Put(obj2, id2, /*has_reference=*/true); ASSERT_TRUE(memory_store->GetIfExists(id1) != nullptr); ASSERT_EQ(unhandled_count, 0); @@ -75,17 +76,17 @@ TEST(TestMemoryStore, TestReportUnhandledErrors) { unhandled_count = 0; // Check delete after get. - memory_store->Put(obj1, id1); - memory_store->Put(obj1, id2); - RAY_UNUSED(memory_store->Get({id1}, 1, 100, context, false, &results)); - RAY_UNUSED(memory_store->Get({id2}, 1, 100, context, false, &results)); + memory_store->Put(obj1, id1, /*has_reference=*/true); + memory_store->Put(obj1, id2, /*has_reference=*/true); + RAY_UNUSED(memory_store->Get({id1}, 1, 100, context, &results)); + RAY_UNUSED(memory_store->Get({id2}, 1, 100, context, &results)); memory_store->Delete({id1, id2}); ASSERT_EQ(unhandled_count, 0); // Check delete after async get. memory_store->GetAsync({id2}, [](std::shared_ptr obj) {}); - memory_store->Put(obj1, id1); - memory_store->Put(obj2, id2); + memory_store->Put(obj1, id1, /*has_reference=*/true); + memory_store->Put(obj2, id2, /*has_reference=*/true); memory_store->GetAsync({id1}, [](std::shared_ptr obj) {}); memory_store->Delete({id1, id2}); ASSERT_EQ(unhandled_count, 0); @@ -118,9 +119,9 @@ TEST(TestMemoryStore, TestMemoryStoreStats) { auto id2 = ObjectID::FromRandom(); auto id3 = ObjectID::FromRandom(); - memory_store->Put(obj1, id1); - memory_store->Put(obj2, id2); - memory_store->Put(obj3, id3); + memory_store->Put(obj1, id1, /*has_reference=*/true); + memory_store->Put(obj2, id2, /*has_reference=*/true); + memory_store->Put(obj3, id3, /*has_reference=*/true); memory_store->Delete({id3}); MemoryStoreStats expected_item; @@ -140,9 +141,9 @@ TEST(TestMemoryStore, TestMemoryStoreStats) { ASSERT_EQ(item.num_local_objects, expected_item2.num_local_objects); ASSERT_EQ(item.num_local_objects_bytes, expected_item2.num_local_objects_bytes); - memory_store->Put(obj1, id1); - memory_store->Put(obj2, id2); - memory_store->Put(obj3, id3); + memory_store->Put(obj1, id1, /*has_reference=*/true); + memory_store->Put(obj2, id2, /*has_reference=*/true); + memory_store->Put(obj3, id3, /*has_reference=*/true); MemoryStoreStats expected_item3; fill_expected_memory_stats(expected_item3); item = memory_store->GetMemoryStoreStatisticalData(); @@ -208,7 +209,7 @@ TEST(TestMemoryStore, TestObjectAllocator) { std::shared_ptr memory_store = std::make_shared(io_context.GetIoService(), - nullptr, + /*reference_counting=*/true, nullptr, nullptr, nullptr, @@ -220,7 +221,7 @@ TEST(TestMemoryStore, TestObjectAllocator) { std::vector nested_refs; auto hello_object = std::make_shared(hello_buffer, nullptr, nested_refs, true); - memory_store->Put(*hello_object, ObjectID::FromRandom()); + memory_store->Put(*hello_object, ObjectID::FromRandom(), /*has_reference=*/true); } ASSERT_EQ(max_rounds * hello.size(), mock_buffer_manager.GetBuferPressureInBytes()); } @@ -238,7 +239,8 @@ class TestMemoryStoreWait : public ::testing::Test { protected: TestMemoryStoreWait() : io_context("TestWait"), - memory_store(std::make_shared(io_context.GetIoService())), + memory_store(std::make_shared( + io_context.GetIoService(), /*reference_counting=*/true)), ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(1)), buffer("hello"), memory_store_object( @@ -259,10 +261,10 @@ TEST_F(TestMemoryStoreWait, TestWaitNoWaiting) { absl::flat_hash_set object_ids_set = {object_ids.begin(), object_ids.end()}; int num_objects = 2; - memory_store->Put(memory_store_object, object_ids[0]); - 
memory_store->Put(plasma_store_object, object_ids[1]); - memory_store->Put(plasma_store_object, object_ids[2]); - memory_store->Put(memory_store_object, object_ids[3]); + memory_store->Put(memory_store_object, object_ids[0], /*has_reference=*/true); + memory_store->Put(plasma_store_object, object_ids[1], /*has_reference=*/true); + memory_store->Put(plasma_store_object, object_ids[2], /*has_reference=*/true); + memory_store->Put(memory_store_object, object_ids[3], /*has_reference=*/true); absl::flat_hash_set ready, plasma_object_ids; const auto status = memory_store->Wait( @@ -289,8 +291,8 @@ TEST_F(TestMemoryStoreWait, TestWaitWithWaiting) { absl::flat_hash_set object_ids_set = {object_ids.begin(), object_ids.end()}; int num_objects = 4; - memory_store->Put(memory_store_object, object_ids[0]); - memory_store->Put(plasma_store_object, object_ids[1]); + memory_store->Put(memory_store_object, object_ids[0], /*has_reference=*/true); + memory_store->Put(plasma_store_object, object_ids[1], /*has_reference=*/true); absl::flat_hash_set ready, plasma_object_ids; auto future = std::async(std::launch::async, [&]() { @@ -298,9 +300,9 @@ TEST_F(TestMemoryStoreWait, TestWaitWithWaiting) { object_ids_set, num_objects, 100, ctx, &ready, &plasma_object_ids); }); ASSERT_EQ(future.wait_for(std::chrono::milliseconds(1)), std::future_status::timeout); - memory_store->Put(plasma_store_object, object_ids[2]); + memory_store->Put(plasma_store_object, object_ids[2], /*has_reference=*/true); ASSERT_EQ(future.wait_for(std::chrono::milliseconds(1)), std::future_status::timeout); - memory_store->Put(memory_store_object, object_ids[3]); + memory_store->Put(memory_store_object, object_ids[3], /*has_reference=*/true); const auto status = future.get(); @@ -316,7 +318,7 @@ TEST_F(TestMemoryStoreWait, TestWaitTimeout) { // object 0 in plasma // waits until 10ms timeout for 2 objects absl::flat_hash_set object_ids_set = {ObjectID::FromRandom()}; - memory_store->Put(plasma_store_object, *object_ids_set.begin()); + memory_store->Put(plasma_store_object, *object_ids_set.begin(), /*has_reference=*/true); int num_objects = 2; absl::flat_hash_set ready, plasma_object_ids; diff --git a/src/ray/core_worker/tests/object_recovery_manager_test.cc b/src/ray/core_worker/tests/object_recovery_manager_test.cc index d25709af5cc2..f1be4cc6c9ab 100644 --- a/src/ray/core_worker/tests/object_recovery_manager_test.cc +++ b/src/ray/core_worker/tests/object_recovery_manager_test.cc @@ -126,8 +126,8 @@ class ObjectRecoveryManagerTestBase : public ::testing::Test { publisher_(std::make_shared()), subscriber_(std::make_shared()), object_directory_(std::make_shared()), - memory_store_( - std::make_shared(io_context_.GetIoService())), + memory_store_(std::make_shared( + io_context_.GetIoService(), /*reference_counting=*/true)), raylet_client_pool_(std::make_shared( [&](const rpc::Address &) { return raylet_client_; })), raylet_client_(std::make_shared()), @@ -159,7 +159,7 @@ class ObjectRecoveryManagerTestBase : public ::testing::Test { std::make_shared(metadata, meta.size()); auto data = RayObject(nullptr, meta_buffer, std::vector()); - memory_store_->Put(data, object_id); + memory_store_->Put(data, object_id, ref_counter_->HasReference(object_id)); }) { ref_counter_->SetReleaseLineageCallback( [](const ObjectID &, std::vector *args) { return 0; }); diff --git a/src/ray/core_worker/tests/reference_counter_test.cc b/src/ray/core_worker/tests/reference_counter_test.cc index 695c6178c038..ccbfe013efb5 100644 --- 
a/src/ray/core_worker/tests/reference_counter_test.cc +++ b/src/ray/core_worker/tests/reference_counter_test.cc @@ -822,15 +822,15 @@ TEST(MemoryStoreIntegrationTest, TestSimple) { subscriber.get(), /*is_node_dead=*/[](const NodeID &) { return false; }); InstrumentedIOContextWithThread io_context("TestSimple"); - CoreWorkerMemoryStore store(io_context.GetIoService(), rc.get()); + CoreWorkerMemoryStore store(io_context.GetIoService(), rc != nullptr); // Tests putting an object with no references is ignored. - store.Put(buffer, id2); + store.Put(buffer, id2, rc->HasReference(id2)); ASSERT_EQ(store.Size(), 0); - // Tests ref counting overrides remove after get option. + // Tests that objects with references remain in the store after Get. rc->AddLocalReference(id1, ""); - store.Put(buffer, id1); + store.Put(buffer, id1, rc->HasReference(id1)); ASSERT_EQ(store.Size(), 1); std::vector> results; WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::Nil()); @@ -838,7 +838,6 @@ TEST(MemoryStoreIntegrationTest, TestSimple) { /*num_objects*/ 1, /*timeout_ms*/ -1, ctx, - /*remove_after_get*/ true, &results)); ASSERT_EQ(results.size(), 1); ASSERT_EQ(store.Size(), 1); diff --git a/src/ray/core_worker/tests/task_manager_test.cc b/src/ray/core_worker/tests/task_manager_test.cc index cf187cb6c4f7..6a8f2dfb679b 100644 --- a/src/ray/core_worker/tests/task_manager_test.cc +++ b/src/ray/core_worker/tests/task_manager_test.cc @@ -156,7 +156,7 @@ class TaskManagerTest : public ::testing::Test { lineage_pinning_enabled)), io_context_("TaskManagerTest"), store_(std::make_shared(io_context_.GetIoService(), - reference_counter_.get())), + reference_counter_ != nullptr)), manager_( *store_, *reference_counter_, @@ -281,7 +281,7 @@ TEST_F(TaskManagerTest, TestTaskSuccess) { ASSERT_FALSE(reference_counter_->IsObjectPendingCreation(return_id)); std::vector> results; - RAY_CHECK_OK(store_->Get({return_id}, 1, -1, ctx, false, &results)); + RAY_CHECK_OK(store_->Get({return_id}, 1, -1, ctx, &results)); ASSERT_EQ(results.size(), 1); ASSERT_FALSE(results[0]->IsException()); ASSERT_EQ(std::memcmp(results[0]->GetData()->Data(), @@ -318,7 +318,7 @@ TEST_F(TaskManagerTest, TestTaskFailure) { ASSERT_FALSE(reference_counter_->IsObjectPendingCreation(return_id)); std::vector> results; - RAY_CHECK_OK(store_->Get({return_id}, 1, -1, ctx, false, &results)); + RAY_CHECK_OK(store_->Get({return_id}, 1, -1, ctx, &results)); ASSERT_EQ(results.size(), 1); rpc::ErrorType stored_error; ASSERT_TRUE(results[0]->IsException(&stored_error)); @@ -360,7 +360,7 @@ TEST_F(TaskManagerTest, TestPlasmaConcurrentFailure) { std::vector> results; // Caller of FlushObjectsToRecover is responsible for deleting the object // from the in-memory store and recovering the object. 
- ASSERT_TRUE(store_->Get({return_id}, 1, 0, ctx, false, &results).ok()); + ASSERT_TRUE(store_->Get({return_id}, 1, 0, ctx, &results).ok()); auto objects_to_recover = reference_counter_->FlushObjectsToRecover(); ASSERT_EQ(objects_to_recover.size(), 1); ASSERT_EQ(objects_to_recover[0], return_id); @@ -388,7 +388,7 @@ TEST_F(TaskManagerTest, TestFailPendingTask) { ASSERT_FALSE(reference_counter_->IsObjectPendingCreation(return_id)); std::vector> results; - RAY_CHECK_OK(store_->Get({return_id}, 1, 0, ctx, false, &results)); + RAY_CHECK_OK(store_->Get({return_id}, 1, 0, ctx, &results)); ASSERT_EQ(results.size(), 1); rpc::ErrorType stored_error; ASSERT_TRUE(results[0]->IsException(&stored_error)); @@ -412,7 +412,7 @@ TEST_F(TaskManagerTest, TestFailPendingTaskAfterCancellation) { // Check that the error type is set to TASK_CANCELLED std::vector> results; WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0)); - RAY_CHECK_OK(store_->Get({spec.ReturnId(0)}, 1, 0, ctx, false, &results)); + RAY_CHECK_OK(store_->Get({spec.ReturnId(0)}, 1, 0, ctx, &results)); ASSERT_EQ(results.size(), 1); rpc::ErrorType stored_error; ASSERT_TRUE(results[0]->IsException(&stored_error)); @@ -442,7 +442,7 @@ TEST_F(TaskManagerTest, TestTaskReconstruction) { ASSERT_TRUE(reference_counter_->IsObjectPendingCreation(return_id)); ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 3); std::vector> results; - ASSERT_FALSE(store_->Get({return_id}, 1, 0, ctx, false, &results).ok()); + ASSERT_FALSE(store_->Get({return_id}, 1, 0, ctx, &results).ok()); ASSERT_EQ(num_retries_, i + 1); ASSERT_EQ(last_delay_ms_, RayConfig::instance().task_retry_delay_ms()); } @@ -454,7 +454,7 @@ TEST_F(TaskManagerTest, TestTaskReconstruction) { ASSERT_FALSE(reference_counter_->IsObjectPendingCreation(return_id)); std::vector> results; - RAY_CHECK_OK(store_->Get({return_id}, 1, 0, ctx, false, &results)); + RAY_CHECK_OK(store_->Get({return_id}, 1, 0, ctx, &results)); ASSERT_EQ(results.size(), 1); rpc::ErrorType stored_error; ASSERT_TRUE(results[0]->IsException(&stored_error)); @@ -483,7 +483,7 @@ TEST_F(TaskManagerTest, TestTaskKill) { manager_.FailOrRetryPendingTask(spec.TaskId(), error); ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId())); std::vector> results; - RAY_CHECK_OK(store_->Get({return_id}, 1, 0, ctx, false, &results)); + RAY_CHECK_OK(store_->Get({return_id}, 1, 0, ctx, &results)); ASSERT_EQ(results.size(), 1); rpc::ErrorType stored_error; ASSERT_TRUE(results[0]->IsException(&stored_error)); @@ -535,7 +535,7 @@ TEST_F(TaskManagerTest, TestTaskOomKillNoOomRetryFailsImmediately) { std::vector> results; WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0)); - RAY_CHECK_OK(store_->Get({return_id}, 1, 0, ctx, false, &results)); + RAY_CHECK_OK(store_->Get({return_id}, 1, 0, ctx, &results)); ASSERT_EQ(results.size(), 1); rpc::ErrorType stored_error; ASSERT_TRUE(results[0]->IsException(&stored_error)); @@ -555,7 +555,7 @@ TEST_F(TaskManagerTest, TestTaskOomKillNoOomRetryFailsImmediately) { std::vector> results; WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0)); - RAY_CHECK_OK(store_->Get({return_id}, 1, 0, ctx, false, &results)); + RAY_CHECK_OK(store_->Get({return_id}, 1, 0, ctx, &results)); ASSERT_EQ(results.size(), 1); rpc::ErrorType stored_error; ASSERT_TRUE(results[0]->IsException(&stored_error)); @@ -591,7 +591,7 @@ TEST_F(TaskManagerTest, TestTaskOomAndNonOomKillReturnsLastError) { std::vector> results; WorkerContext ctx(WorkerType::WORKER, 
WorkerID::FromRandom(), JobID::FromInt(0)); - RAY_CHECK_OK(store_->Get({return_id}, 1, 0, ctx, false, &results)); + RAY_CHECK_OK(store_->Get({return_id}, 1, 0, ctx, &results)); ASSERT_EQ(results.size(), 1); rpc::ErrorType stored_error; ASSERT_TRUE(results[0]->IsException(&stored_error)); @@ -633,7 +633,7 @@ TEST_F(TaskManagerTest, TestTaskNotRetriableOomFailsImmediatelyEvenWithOomRetryC std::vector> results; WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0)); - RAY_CHECK_OK(store_->Get({return_id}, 1, 0, ctx, false, &results)); + RAY_CHECK_OK(store_->Get({return_id}, 1, 0, ctx, &results)); ASSERT_EQ(results.size(), 1); rpc::ErrorType stored_error; ASSERT_TRUE(results[0]->IsException(&stored_error)); @@ -660,7 +660,7 @@ TEST_F(TaskManagerTest, TestFailsImmediatelyOverridesRetry) { std::vector> results; WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0)); - RAY_CHECK_OK(store_->Get({return_id}, 1, 0, ctx, false, &results)); + RAY_CHECK_OK(store_->Get({return_id}, 1, 0, ctx, &results)); ASSERT_EQ(results.size(), 1); rpc::ErrorType stored_error; ASSERT_TRUE(results[0]->IsException(&stored_error)); @@ -684,7 +684,7 @@ TEST_F(TaskManagerTest, TestFailsImmediatelyOverridesRetry) { std::vector> results; WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0)); - RAY_CHECK_OK(store_->Get({return_id}, 1, 0, ctx, false, &results)); + RAY_CHECK_OK(store_->Get({return_id}, 1, 0, ctx, &results)); ASSERT_EQ(results.size(), 1); rpc::ErrorType stored_error; ASSERT_TRUE(results[0]->IsException(&stored_error)); @@ -1278,7 +1278,7 @@ TEST_F(TaskManagerLineageTest, TestDynamicReturnsTask) { WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0)); std::vector> results; - RAY_CHECK_OK(store_->Get(dynamic_return_ids, 3, -1, ctx, false, &results)); + RAY_CHECK_OK(store_->Get(dynamic_return_ids, 3, -1, ctx, &results)); ASSERT_EQ(results.size(), 3); for (int i = 0; i < 3; i++) { ASSERT_TRUE(results[i]->IsInPlasmaError()); @@ -1356,10 +1356,10 @@ TEST_F(TaskManagerLineageTest, TestResubmittedDynamicReturnsTaskFails) { // No error stored for the generator ID, which should have gone out of scope. WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0)); std::vector> results; - ASSERT_FALSE(store_->Get({generator_id}, 1, 0, ctx, false, &results).ok()); + ASSERT_FALSE(store_->Get({generator_id}, 1, 0, ctx, &results).ok()); // The internal ObjectRefs have the right error. 
- RAY_CHECK_OK(store_->Get(dynamic_return_ids, 3, -1, ctx, false, &results)); + RAY_CHECK_OK(store_->Get(dynamic_return_ids, 3, -1, ctx, &results)); ASSERT_EQ(results.size(), 3); for (int i = 0; i < 3; i++) { rpc::ErrorType stored_error; @@ -1377,8 +1377,8 @@ TEST_F(TaskManagerTest, PlasmaPut_ObjectStoreFull_FailsTaskAndWritesError) { subscriber_.get(), /*is_node_dead=*/[this](const NodeID &) { return node_died_; }, lineage_pinning_enabled_); - auto local_store = std::make_shared(io_context_.GetIoService(), - local_ref_counter.get()); + auto local_store = std::make_shared( + io_context_.GetIoService(), reference_counter_ != nullptr); TaskManager failing_mgr( *local_store, @@ -1424,7 +1424,7 @@ TEST_F(TaskManagerTest, PlasmaPut_ObjectStoreFull_FailsTaskAndWritesError) { ASSERT_FALSE(failing_mgr.IsTaskPending(spec.TaskId())); std::vector> results; WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0)); - RAY_CHECK_OK(local_store->Get({return_id}, 1, 0, ctx, false, &results)); + RAY_CHECK_OK(local_store->Get({return_id}, 1, 0, ctx, &results)); ASSERT_EQ(results.size(), 1); ASSERT_TRUE(results[0]->IsException()); } @@ -1437,8 +1437,8 @@ TEST_F(TaskManagerTest, PlasmaPut_TransientFull_RetriesThenSucceeds) { subscriber_.get(), /*is_node_dead=*/[this](const NodeID &) { return node_died_; }, lineage_pinning_enabled_); - auto local_store = std::make_shared(io_context_.GetIoService(), - local_ref_counter.get()); + auto local_store = std::make_shared( + io_context_.GetIoService(), reference_counter_ != nullptr); TaskManager retry_mgr( *local_store, *local_ref_counter, @@ -1485,7 +1485,7 @@ TEST_F(TaskManagerTest, PlasmaPut_TransientFull_RetriesThenSucceeds) { std::vector> results; WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0)); - RAY_CHECK_OK(local_store->Get({return_id}, 1, 0, ctx, false, &results)); + RAY_CHECK_OK(local_store->Get({return_id}, 1, 0, ctx, &results)); ASSERT_EQ(results.size(), 1); ASSERT_TRUE(results[0]->IsInPlasmaError()); } @@ -1498,8 +1498,8 @@ TEST_F(TaskManagerTest, DynamicReturn_PlasmaPutFailure_FailsTaskImmediately) { subscriber_.get(), /*is_node_dead=*/[this](const NodeID &) { return node_died_; }, lineage_pinning_enabled_); - auto local_store = std::make_shared(io_context_.GetIoService(), - local_ref_counter.get()); + auto local_store = std::make_shared( + io_context_.GetIoService(), reference_counter_ != nullptr); TaskManager dyn_mgr( *local_store, *local_ref_counter, @@ -2030,7 +2030,7 @@ TEST_F(TaskManagerTest, TestObjectRefStreamEndtoEnd) { ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 2); std::vector> results; WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0)); - RAY_CHECK_OK(store_->Get({dynamic_return_id}, 1, 1, ctx, false, &results)); + RAY_CHECK_OK(store_->Get({dynamic_return_id}, 1, 1, ctx, &results)); ASSERT_EQ(results.size(), 1); // Make sure you can read. @@ -2059,7 +2059,7 @@ TEST_F(TaskManagerTest, TestObjectRefStreamEndtoEnd) { // NumObjectIDsInScope == Generator + 2 intermediate result. results.clear(); - RAY_CHECK_OK(store_->Get({dynamic_return_id2}, 1, 1, ctx, false, &results)); + RAY_CHECK_OK(store_->Get({dynamic_return_id2}, 1, 1, ctx, &results)); ASSERT_EQ(results.size(), 1); // Make sure you can read. 
@@ -2122,10 +2122,10 @@ TEST_F(TaskManagerTest, TestObjectRefStreamDelCleanReferences) { ASSERT_EQ(store_->Size(), 2); std::vector> results; WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0)); - RAY_CHECK_OK(store_->Get({dynamic_return_id}, 1, 1, ctx, false, &results)); + RAY_CHECK_OK(store_->Get({dynamic_return_id}, 1, 1, ctx, &results)); ASSERT_EQ(results.size(), 1); results.clear(); - RAY_CHECK_OK(store_->Get({dynamic_return_id2}, 1, 1, ctx, false, &results)); + RAY_CHECK_OK(store_->Get({dynamic_return_id2}, 1, 1, ctx, &results)); ASSERT_EQ(results.size(), 1); results.clear(); @@ -2135,9 +2135,9 @@ TEST_F(TaskManagerTest, TestObjectRefStreamDelCleanReferences) { // All the in memory objects should be cleaned up. The generator ref returns // a direct result that would be GCed once it goes out of scope. ASSERT_EQ(store_->Size(), 1); - ASSERT_TRUE(store_->Get({dynamic_return_id}, 1, 1, ctx, false, &results).IsTimedOut()); + ASSERT_TRUE(store_->Get({dynamic_return_id}, 1, 1, ctx, &results).IsTimedOut()); results.clear(); - ASSERT_TRUE(store_->Get({dynamic_return_id2}, 1, 1, ctx, false, &results).IsTimedOut()); + ASSERT_TRUE(store_->Get({dynamic_return_id2}, 1, 1, ctx, &results).IsTimedOut()); results.clear(); // Clean up the generator ID. Now all lineage is safe to remove. @@ -2169,7 +2169,7 @@ TEST_F(TaskManagerTest, TestObjectRefStreamDelCleanReferences) { ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 0); // All the in memory objects should be cleaned up. ASSERT_EQ(store_->Size(), 1); - ASSERT_TRUE(store_->Get({dynamic_return_id3}, 1, 1, ctx, false, &results).IsTimedOut()); + ASSERT_TRUE(store_->Get({dynamic_return_id3}, 1, 1, ctx, &results).IsTimedOut()); results.clear(); } @@ -2223,7 +2223,7 @@ TEST_F(TaskManagerTest, TestObjectRefStreamDelCleanReferencesLineageInScope) { ASSERT_EQ(store_->Size(), 2); std::vector> results; WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0)); - RAY_CHECK_OK(store_->Get({dynamic_return_id}, 1, 1, ctx, false, &results)); + RAY_CHECK_OK(store_->Get({dynamic_return_id}, 1, 1, ctx, &results)); ASSERT_EQ(results.size(), 1); results.clear(); @@ -2234,7 +2234,7 @@ TEST_F(TaskManagerTest, TestObjectRefStreamDelCleanReferencesLineageInScope) { ASSERT_EQ(obj_id, dynamic_return_id); // Write one ref that will stay unconsumed. - RAY_CHECK_OK(store_->Get({dynamic_return_id2}, 1, 1, ctx, false, &results)); + RAY_CHECK_OK(store_->Get({dynamic_return_id2}, 1, 1, ctx, &results)); ASSERT_EQ(results.size(), 1); results.clear(); @@ -2244,7 +2244,7 @@ TEST_F(TaskManagerTest, TestObjectRefStreamDelCleanReferencesLineageInScope) { // All the unconsumed objects should be cleaned up. The generator ref returns // a direct result that would be GCed once it goes out of scope. ASSERT_EQ(store_->Size(), 2); - ASSERT_TRUE(store_->Get({dynamic_return_id2}, 1, 1, ctx, false, &results).IsTimedOut()); + ASSERT_TRUE(store_->Get({dynamic_return_id2}, 1, 1, ctx, &results).IsTimedOut()); results.clear(); // Clean up the generator ID. @@ -2258,7 +2258,7 @@ TEST_F(TaskManagerTest, TestObjectRefStreamDelCleanReferencesLineageInScope) { // All the unconsumed in memory objects should be cleaned up. Check for 2 // in-memory objects: one consumed object ref and the generator ref. 
   ASSERT_EQ(store_->Size(), 2);
-  ASSERT_TRUE(store_->Get({dynamic_return_id2}, 1, 1, ctx, false, &results).IsTimedOut());
+  ASSERT_TRUE(store_->Get({dynamic_return_id2}, 1, 1, ctx, &results).IsTimedOut());
   results.clear();
 
   // NOTE: We panic if READ is called after DELETE. The
@@ -2283,7 +2283,7 @@ TEST_F(TaskManagerTest, TestObjectRefStreamDelCleanReferencesLineageInScope) {
   // All the unconsumed in memory objects should be cleaned up. Check for 2
   // in-memory objects: one consumed object ref and the generator ref.
   ASSERT_EQ(store_->Size(), 2);
-  ASSERT_TRUE(store_->Get({dynamic_return_id3}, 1, 1, ctx, false, &results).IsTimedOut());
+  ASSERT_TRUE(store_->Get({dynamic_return_id3}, 1, 1, ctx, &results).IsTimedOut());
   results.clear();
 }
 
@@ -2992,7 +2992,7 @@ TEST_F(TaskManagerTest, TestRetryErrorMessageSentToCallback) {
       /*is_node_dead=*/[this](const NodeID &) { return node_died_; },
       false);
   auto local_store = std::make_shared<CoreWorkerMemoryStore>(
-      io_context_.GetIoService(), local_reference_counter.get());
+      io_context_.GetIoService(), reference_counter_ != nullptr);
 
   TaskManager test_manager(
       *local_store,
@@ -3072,7 +3072,7 @@ TEST_F(TaskManagerTest, TestErrorLogWhenPushErrorCallbackFails) {
       /*is_node_dead=*/[this](const NodeID &) { return node_died_; },
       false);
   auto local_store = std::make_shared<CoreWorkerMemoryStore>(
-      io_context_.GetIoService(), local_reference_counter.get());
+      io_context_.GetIoService(), reference_counter_ != nullptr);
 
   TaskManager test_manager(
       *local_store,

From 4ac6478143bca5d55e7b2bd20d0c7cb7d95aea66 Mon Sep 17 00:00:00 2001
From: davik
Date: Fri, 10 Oct 2025 23:56:50 +0000
Subject: [PATCH 2/5] Optimize Get/Wait performance in plasma store

Signed-off-by: davik
---
 src/ray/core_worker/core_worker.cc            | 42 ++++++-------
 src/ray/core_worker/core_worker.h             |  8 ---
 .../store_provider/plasma_store_provider.cc   | 61 ++++++-----------
 .../store_provider/plasma_store_provider.h    | 13 ++--
 src/ray/core_worker/tests/core_worker_test.cc |  9 +--
 5 files changed, 54 insertions(+), 79 deletions(-)

diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc
index 69c773d780ac..aa439e2068ca 100644
--- a/src/ray/core_worker/core_worker.cc
+++ b/src/ray/core_worker/core_worker.cc
@@ -1368,9 +1368,10 @@ Status CoreWorker::GetObjects(const std::vector<ObjectID> &ids,
   // gets at the provider plasma. Once we get the objects from plasma, we flip
   // the transport type again and return them for the original direct call ids.
 
-  // Resolve owner addresses of plasma ids
-  absl::flat_hash_map<ObjectID, rpc::Address> plasma_object_ids_map =
-      GetObjectIdToOwnerAddressMap(plasma_object_ids);
+  // Prepare parallel object id and owner address vectors
+  std::vector<ObjectID> object_ids(plasma_object_ids.begin(),
+                                   plasma_object_ids.end());
+  auto owner_addresses = reference_counter_->GetOwnerAddresses(object_ids);
 
   int64_t local_timeout_ms = timeout_ms;
   if (timeout_ms >= 0) {
@@ -1378,7 +1379,8 @@ Status CoreWorker::GetObjects(const std::vector<ObjectID> &ids,
                                 timeout_ms - (current_time_ms() - start_time));
   }
   RAY_LOG(DEBUG) << "Plasma GET timeout " << local_timeout_ms;
-  RAY_RETURN_NOT_OK(plasma_store_provider_->Get(plasma_object_ids_map,
+  RAY_RETURN_NOT_OK(plasma_store_provider_->Get(object_ids,
+                                                owner_addresses,
                                                 local_timeout_ms,
                                                 *worker_context_,
                                                 &result_map,
@@ -1527,11 +1529,13 @@ Status CoreWorker::Wait(const std::vector<ObjectID> &ids,
   // these objects.
   if (!plasma_object_ids.empty()) {
-    // Prepare object ids map
-    absl::flat_hash_map<ObjectID, rpc::Address> plasma_object_ids_map =
-        GetObjectIdToOwnerAddressMap(plasma_object_ids);
+    // Prepare parallel object id and owner address vectors
+    std::vector<ObjectID> object_ids(plasma_object_ids.begin(),
+                                     plasma_object_ids.end());
+    auto owner_addresses = reference_counter_->GetOwnerAddresses(object_ids);
 
     RAY_RETURN_NOT_OK(plasma_store_provider_->Wait(
-        plasma_object_ids_map,
+        object_ids,
+        owner_addresses,
         std::min(static_cast<int>(plasma_object_ids.size()),
                  num_objects - static_cast<int>(ready.size())),
         timeout_ms,
@@ -3019,11 +3023,11 @@ bool CoreWorker::PinExistingReturnObject(const ObjectID &return_id,
   reference_counter_->AddBorrowedObject(return_id, ObjectID::Nil(), owner_address);
 
   // Resolve owner address of return id
-  absl::flat_hash_map<ObjectID, rpc::Address> return_id_map =
-      GetObjectIdToOwnerAddressMap({return_id});
+  std::vector<ObjectID> object_ids = {return_id};
+  auto owner_addresses = reference_counter_->GetOwnerAddresses(object_ids);
 
   auto status = plasma_store_provider_->Get(
-      return_id_map, 0, *worker_context_, &result_map, &got_exception);
+      object_ids, owner_addresses, 0, *worker_context_, &result_map, &got_exception);
 
   // Remove the temporary ref.
   RemoveLocalReference(return_id);
@@ -3292,10 +3296,10 @@ Status CoreWorker::GetAndPinArgsForExecutor(const TaskSpecification &task,
         by_ref_ids, -1, *worker_context_, &result_map, &got_exception));
   } else {
     // Resolve owner addresses of by-ref ids
-    absl::flat_hash_map<ObjectID, rpc::Address> by_ref_ids_map =
-        GetObjectIdToOwnerAddressMap(by_ref_ids);
+    std::vector<ObjectID> object_ids(by_ref_ids.begin(),
+                                     by_ref_ids.end());
+    auto owner_addresses = reference_counter_->GetOwnerAddresses(object_ids);
     RAY_RETURN_NOT_OK(plasma_store_provider_->Get(
-        by_ref_ids_map, -1, *worker_context_, &result_map, &got_exception));
+        object_ids, owner_addresses, -1, *worker_context_, &result_map, &got_exception));
   }
   for (const auto &it : result_map) {
     for (size_t idx : by_ref_indices[it.first]) {
@@ -3486,17 +3491,6 @@ void CoreWorker::PopulateObjectStatus(const ObjectID &object_id,
   }
 }
 
-absl::flat_hash_map<ObjectID, rpc::Address> CoreWorker::GetObjectIdToOwnerAddressMap(
-    const absl::flat_hash_set<ObjectID> &object_ids) {
-  std::vector<ObjectID> object_ids_vector(object_ids.begin(), object_ids.end());
-  const auto owner_addresses = reference_counter_->GetOwnerAddresses(object_ids_vector);
-  absl::flat_hash_map<ObjectID, rpc::Address> object_id_map;
-  for (size_t i = 0; i < object_ids_vector.size(); i++) {
-    object_id_map[object_ids_vector[i]] = owner_addresses[i];
-  }
-  return object_id_map;
-}
-
 void CoreWorker::HandleWaitForActorRefDeleted(
     rpc::WaitForActorRefDeletedRequest request,
     rpc::WaitForActorRefDeletedReply *reply,
diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h
index 1d0a35c5ae4d..f40d26ecdcb5 100644
--- a/src/ray/core_worker/core_worker.h
+++ b/src/ray/core_worker/core_worker.h
@@ -1429,14 +1429,6 @@ class CoreWorker {
                             const std::shared_ptr<RayObject> &obj,
                             rpc::GetObjectStatusReply *reply);
 
-  /// Helper method to construct a map from ObjectIDs to their owner addresses.
-  /// This method helps prepare inputs for calls to the plasma store.
-  ///
-  /// \param[in] object_ids Set of ObjectIDs to look up.
-  /// \return A map from ObjectID to owner address (rpc::Address).
-  absl::flat_hash_map<ObjectID, rpc::Address> GetObjectIdToOwnerAddressMap(
-      const absl::flat_hash_set<ObjectID> &object_ids);
-
   ///
   /// Private methods related to task submission.
/// diff --git a/src/ray/core_worker/store_provider/plasma_store_provider.cc b/src/ray/core_worker/store_provider/plasma_store_provider.cc index 4a0bdec644d2..fb70e0eca1cb 100644 --- a/src/ray/core_worker/store_provider/plasma_store_provider.cc +++ b/src/ray/core_worker/store_provider/plasma_store_provider.cc @@ -176,7 +176,7 @@ Status CoreWorkerPlasmaStoreProvider::Release(const ObjectID &object_id) { } Status CoreWorkerPlasmaStoreProvider::PullObjectsAndGetFromPlasmaStore( - absl::flat_hash_set &remaining, + absl::flat_hash_map &remaining_object_id_to_idx, const std::vector &batch_ids, const std::vector &batch_owner_addresses, int64_t timeout_ms, @@ -207,7 +207,7 @@ Status CoreWorkerPlasmaStoreProvider::PullObjectsAndGetFromPlasmaStore( } auto result_object = std::make_shared( data, metadata, std::vector()); - remaining.erase(object_id); + remaining_object_id_to_idx.erase(object_id); if (result_object->IsException()) { RAY_CHECK(!result_object->IsInPlasmaError()); *got_exception = true; @@ -270,29 +270,27 @@ Status UnblockIfNeeded( } Status CoreWorkerPlasmaStoreProvider::Get( - const absl::flat_hash_map &object_ids, + const std::vector &object_ids, + const std::vector &owner_addresses, int64_t timeout_ms, const WorkerContext &ctx, absl::flat_hash_map> *results, bool *got_exception) { - absl::flat_hash_set remaining; + absl::flat_hash_map remaining_object_id_to_idx; // Send initial requests to pull all objects in parallel. - std::vector> id_vector(object_ids.begin(), - object_ids.end()); int64_t total_size = static_cast(object_ids.size()); for (int64_t start = 0; start < total_size; start += fetch_batch_size_) { std::vector batch_ids; std::vector batch_owner_addresses; for (int64_t i = start; i < start + fetch_batch_size_ && i < total_size; i++) { - // Construct remaining set a batch at a time - remaining.insert(id_vector[i].first); + remaining_object_id_to_idx[object_ids[i]] = i; - batch_ids.push_back(id_vector[i].first); - batch_owner_addresses.push_back(id_vector[i].second); + batch_ids.push_back(object_ids[i]); + batch_owner_addresses.push_back(owner_addresses[i]); } RAY_RETURN_NOT_OK( - PullObjectsAndGetFromPlasmaStore(remaining, + PullObjectsAndGetFromPlasmaStore(remaining_object_id_to_idx, batch_ids, batch_owner_addresses, /*timeout_ms=*/0, @@ -303,7 +301,7 @@ Status CoreWorkerPlasmaStoreProvider::Get( // If all objects were fetched already, return. Note that we always need to // call UnblockIfNeeded() to cancel the get request. 
- if (remaining.empty() || *got_exception) { + if (remaining_object_id_to_idx.empty() || *got_exception) { return UnblockIfNeeded(raylet_ipc_client_, ctx); } @@ -314,15 +312,15 @@ Status CoreWorkerPlasmaStoreProvider::Get( bool timed_out = false; int64_t remaining_timeout = timeout_ms; auto fetch_start_time_ms = current_time_ms(); - while (!remaining.empty() && !should_break) { + while (!remaining_object_id_to_idx.empty() && !should_break) { std::vector batch_ids; std::vector batch_owner_addresses; - for (const auto &id : remaining) { + for (const auto &[id, idx] : remaining_object_id_to_idx) { if (static_cast(batch_ids.size()) == fetch_batch_size_) { break; } batch_ids.push_back(id); - batch_owner_addresses.push_back(object_ids.find(id)->second); + batch_owner_addresses.push_back(owner_addresses[idx]); } int64_t batch_timeout = @@ -334,8 +332,8 @@ Status CoreWorkerPlasmaStoreProvider::Get( timed_out = remaining_timeout <= 0; } - size_t previous_size = remaining.size(); - RAY_RETURN_NOT_OK(PullObjectsAndGetFromPlasmaStore(remaining, + size_t previous_size = remaining_object_id_to_idx.size(); + RAY_RETURN_NOT_OK(PullObjectsAndGetFromPlasmaStore(remaining_object_id_to_idx, batch_ids, batch_owner_addresses, batch_timeout, @@ -343,8 +341,8 @@ Status CoreWorkerPlasmaStoreProvider::Get( got_exception)); should_break = timed_out || *got_exception; - if ((previous_size - remaining.size()) < batch_ids.size()) { - WarnIfFetchHanging(fetch_start_time_ms, remaining); + if ((previous_size - remaining_object_id_to_idx.size()) < batch_ids.size()) { + WarnIfFetchHanging(fetch_start_time_ms, remaining_object_id_to_idx); } if (check_signals_) { Status status = check_signals_(); @@ -355,7 +353,7 @@ Status CoreWorkerPlasmaStoreProvider::Get( } } if (RayConfig::instance().yield_plasma_lock_workaround() && !should_break && - remaining.size() > 0) { + remaining_object_id_to_idx.size() > 0) { // Yield the plasma lock to other threads. 
This is a temporary workaround since we // are holding the lock for a long time, so it can easily starve inbound RPC // requests to Release() buffers which only require holding the lock for brief @@ -364,7 +362,7 @@ Status CoreWorkerPlasmaStoreProvider::Get( } } - if (!remaining.empty() && timed_out) { + if (!remaining_object_id_to_idx.empty() && timed_out) { RAY_RETURN_NOT_OK(UnblockIfNeeded(raylet_ipc_client_, ctx)); return Status::TimedOut("Get timed out: some object(s) not ready."); } @@ -380,19 +378,12 @@ Status CoreWorkerPlasmaStoreProvider::Contains(const ObjectID &object_id, } Status CoreWorkerPlasmaStoreProvider::Wait( - const absl::flat_hash_map &object_ids, + const std::vector &object_ids, + const std::vector &owner_addresses, int num_objects, int64_t timeout_ms, const WorkerContext &ctx, absl::flat_hash_set *ready) { - // Construct object ids vector and owner addresses vector - std::vector object_id_vector; - std::vector owner_addresses; - for (const auto &[object_id, owner_address] : object_ids) { - object_id_vector.push_back(object_id); - owner_addresses.push_back(owner_address); - } - bool should_break = false; int64_t remaining_timeout = timeout_ms; absl::flat_hash_set ready_in_plasma; @@ -406,8 +397,7 @@ Status CoreWorkerPlasmaStoreProvider::Wait( RAY_ASSIGN_OR_RETURN( ready_in_plasma, - raylet_ipc_client_->Wait( - object_id_vector, owner_addresses, num_objects, call_timeout)); + raylet_ipc_client_->Wait(object_ids, owner_addresses, num_objects, call_timeout)); if (ready_in_plasma.size() >= static_cast(num_objects)) { should_break = true; @@ -441,12 +431,13 @@ CoreWorkerPlasmaStoreProvider::UsedObjectsList() const { } void CoreWorkerPlasmaStoreProvider::WarnIfFetchHanging( - int64_t fetch_start_time_ms, const absl::flat_hash_set &remaining) { + int64_t fetch_start_time_ms, + const absl::flat_hash_map &remaining_object_id_to_idx) { int64_t duration_ms = current_time_ms() - fetch_start_time_ms; if (duration_ms > RayConfig::instance().fetch_warn_timeout_milliseconds()) { std::ostringstream oss; size_t printed = 0; - for (auto &id : remaining) { + for (auto &[id, _] : remaining_object_id_to_idx) { if (printed >= RayConfig::instance().object_store_get_max_ids_to_print_in_warning()) { break; @@ -457,7 +448,7 @@ void CoreWorkerPlasmaStoreProvider::WarnIfFetchHanging( oss << id.Hex(); printed++; } - if (printed < remaining.size()) { + if (printed < remaining_object_id_to_idx.size()) { oss << ", etc"; } RAY_LOG(WARNING) diff --git a/src/ray/core_worker/store_provider/plasma_store_provider.h b/src/ray/core_worker/store_provider/plasma_store_provider.h index a11028827812..2cbc376881b4 100644 --- a/src/ray/core_worker/store_provider/plasma_store_provider.h +++ b/src/ray/core_worker/store_provider/plasma_store_provider.h @@ -152,7 +152,8 @@ class CoreWorkerPlasmaStoreProvider { /// argument to Get to retrieve the object data. Status Release(const ObjectID &object_id); - Status Get(const absl::flat_hash_map &object_ids, + Status Get(const std::vector &object_ids, + const std::vector &owner_addresses, int64_t timeout_ms, const WorkerContext &ctx, absl::flat_hash_map> *results, @@ -185,7 +186,8 @@ class CoreWorkerPlasmaStoreProvider { Status Contains(const ObjectID &object_id, bool *has_object); - Status Wait(const absl::flat_hash_map &object_ids, + Status Wait(const std::vector &object_ids, + const std::vector &owner_addresses, int num_objects, int64_t timeout_ms, const WorkerContext &ctx, @@ -217,7 +219,7 @@ class CoreWorkerPlasmaStoreProvider { /// exception. /// \return Status. 
   Status PullObjectsAndGetFromPlasmaStore(
-      absl::flat_hash_set<ObjectID> &remaining,
+      absl::flat_hash_map<ObjectID, int64_t> &remaining_object_id_to_idx,
       const std::vector<ObjectID> &batch_ids,
       const std::vector<rpc::Address> &batch_owner_addresses,
       int64_t timeout_ms,
@@ -226,8 +228,9 @@
 
   /// Print a warning if we've attempted the fetch for too long and some
   /// objects are still unavailable.
-  static void WarnIfFetchHanging(int64_t fetch_start_time_ms,
-                                 const absl::flat_hash_set<ObjectID> &remaining);
+  static void WarnIfFetchHanging(
+      int64_t fetch_start_time_ms,
+      const absl::flat_hash_map<ObjectID, int64_t> &remaining_object_id_to_idx);
 
   /// Put something in the plasma store so that subsequent plasma store accesses
   /// will be faster. Currently the first access is always slow, and we don't
diff --git a/src/ray/core_worker/tests/core_worker_test.cc b/src/ray/core_worker/tests/core_worker_test.cc
index a5bd0069f48b..5346d1f76909 100644
--- a/src/ray/core_worker/tests/core_worker_test.cc
+++ b/src/ray/core_worker/tests/core_worker_test.cc
@@ -651,20 +651,15 @@ TEST(BatchingPassesTwoTwoOneIntoPlasmaGet, CallsPlasmaGetInCorrectBatches) {
   // Build a set of 5 object ids.
   std::vector<ObjectID> ids;
   for (int i = 0; i < 5; i++) ids.push_back(ObjectID::FromRandom());
-
-  // Prepare object ids map
   const auto owner_addresses = ref_counter.GetOwnerAddresses(ids);
-  absl::flat_hash_map<ObjectID, rpc::Address> id_map;
-  for (size_t i = 0; i < ids.size(); i++) {
-    id_map[ids[i]] = owner_addresses[i];
-  }
 
   absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> results;
   bool got_exception = false;
   WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0));
   ASSERT_TRUE(
-      provider.Get(id_map, /*timeout_ms=*/-1, ctx, &results, &got_exception).ok());
+      provider.Get(ids, owner_addresses, /*timeout_ms=*/-1, ctx, &results, &got_exception)
+          .ok());
 
   // Assert: batches seen by plasma Get are [2,2,1].
ASSERT_EQ(observed_batches.size(), 3U); From a14329eb078c7c0e288489785d324c0058f50e91 Mon Sep 17 00:00:00 2001 From: davik Date: Fri, 24 Oct 2025 19:48:08 +0000 Subject: [PATCH 3/5] Address comments to set reference counting as enabled for memory store by default Signed-off-by: davik --- src/mock/ray/core_worker/memory_store.h | 2 +- .../memory_store/memory_store.cc | 6 +++--- .../store_provider/memory_store/memory_store.h | 4 ++-- .../store_provider/plasma_store_provider.cc | 1 - .../tests/actor_task_submitter_test.cc | 3 +-- .../tests/dependency_resolver_test.cc | 4 +--- .../tests/normal_task_submitter_test.cc | 7 +++---- src/ray/core_worker/tests/memory_store_test.cc | 3 +-- .../tests/object_recovery_manager_test.cc | 4 ++-- .../tests/reference_counter_test.cc | 2 +- src/ray/core_worker/tests/task_manager_test.cc | 18 ++++++------------ 11 files changed, 21 insertions(+), 33 deletions(-) diff --git a/src/mock/ray/core_worker/memory_store.h b/src/mock/ray/core_worker/memory_store.h index 28ce9cef618e..33753926ed18 100644 --- a/src/mock/ray/core_worker/memory_store.h +++ b/src/mock/ray/core_worker/memory_store.h @@ -46,7 +46,7 @@ class DefaultCoreWorkerMemoryStoreWithThread : public CoreWorkerMemoryStore { private: explicit DefaultCoreWorkerMemoryStoreWithThread( std::unique_ptr io_context) - : CoreWorkerMemoryStore(io_context->GetIoService(), /*reference_counting=*/false), + : CoreWorkerMemoryStore(io_context->GetIoService()), io_context_(std::move(io_context)) {} std::unique_ptr io_context_; diff --git a/src/ray/core_worker/store_provider/memory_store/memory_store.cc b/src/ray/core_worker/store_provider/memory_store/memory_store.cc index 6c11d7ee50c0..fe97f3c1fba9 100644 --- a/src/ray/core_worker/store_provider/memory_store/memory_store.cc +++ b/src/ray/core_worker/store_provider/memory_store/memory_store.cc @@ -127,14 +127,14 @@ std::shared_ptr GetRequest::Get(const ObjectID &object_id) const { CoreWorkerMemoryStore::CoreWorkerMemoryStore( instrumented_io_context &io_context, - bool reference_counting, + bool reference_counting_enabled, std::shared_ptr raylet_ipc_client, std::function check_signals, std::function unhandled_exception_handler, std::function( const ray::RayObject &object, const ObjectID &object_id)> object_allocator) : io_context_(io_context), - reference_counting_(reference_counting), + reference_counting_enabled_(reference_counting_enabled), raylet_ipc_client_(std::move(raylet_ipc_client)), check_signals_(std::move(check_signals)), unhandled_exception_handler_(std::move(unhandled_exception_handler)), @@ -212,7 +212,7 @@ void CoreWorkerMemoryStore::Put(const RayObject &object, } } // Don't put it in the store, since we won't get a callback for deletion. - if (reference_counting_ && !has_reference) { + if (reference_counting_enabled_ && !has_reference) { should_add_entry = false; } diff --git a/src/ray/core_worker/store_provider/memory_store/memory_store.h b/src/ray/core_worker/store_provider/memory_store/memory_store.h index 33b337087c2b..75d00385412a 100644 --- a/src/ray/core_worker/store_provider/memory_store/memory_store.h +++ b/src/ray/core_worker/store_provider/memory_store/memory_store.h @@ -51,7 +51,7 @@ class CoreWorkerMemoryStore { /// \param[in] raylet_ipc_client If not null, used to notify tasks blocked / unblocked. 
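+  /// \param[in] reference_counting_enabled Whether reference counting is enabled
+  /// (true everywhere except local mode). When enabled, objects that have no
+  /// known reference are not added to the store.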
explicit CoreWorkerMemoryStore( instrumented_io_context &io_context, - bool reference_counting, + bool reference_counting_enabled = true, std::shared_ptr raylet_ipc_client = nullptr, std::function check_signals = nullptr, std::function unhandled_exception_handler = nullptr, @@ -202,7 +202,7 @@ class CoreWorkerMemoryStore { instrumented_io_context &io_context_; /// Set to true if reference counting is enabled (i.e. not local mode). - bool reference_counting_; + bool reference_counting_enabled_; // If set, this will be used to notify worker blocked / unblocked on get calls. std::shared_ptr raylet_ipc_client_; diff --git a/src/ray/core_worker/store_provider/plasma_store_provider.cc b/src/ray/core_worker/store_provider/plasma_store_provider.cc index 6d3c7dd138f8..3ce7478e2a84 100644 --- a/src/ray/core_worker/store_provider/plasma_store_provider.cc +++ b/src/ray/core_worker/store_provider/plasma_store_provider.cc @@ -259,7 +259,6 @@ Status CoreWorkerPlasmaStoreProvider::Get( // TODO(57923): Need to understand if batching is necessary. If it's necessary, // then the reason needs to be documented. - // Send initial requests to pull all objects in parallel. bool got_exception = false; int64_t total_size = static_cast(object_ids.size()); for (int64_t start = 0; start < total_size; start += fetch_batch_size_) { diff --git a/src/ray/core_worker/task_submission/tests/actor_task_submitter_test.cc b/src/ray/core_worker/task_submission/tests/actor_task_submitter_test.cc index 4bb0c9a41ce3..54ccf6a30962 100644 --- a/src/ray/core_worker/task_submission/tests/actor_task_submitter_test.cc +++ b/src/ray/core_worker/task_submission/tests/actor_task_submitter_test.cc @@ -89,8 +89,7 @@ class ActorTaskSubmitterTest : public ::testing::TestWithParam { : client_pool_(std::make_shared( [&](const rpc::Address &addr) { return worker_client_; })), worker_client_(std::make_shared()), - store_(std::make_shared(io_context, - /*reference_counting=*/true)), + store_(std::make_shared(io_context)), task_manager_(std::make_shared()), io_work(io_context.get_executor()), reference_counter_(std::make_shared()), diff --git a/src/ray/core_worker/task_submission/tests/dependency_resolver_test.cc b/src/ray/core_worker/task_submission/tests/dependency_resolver_test.cc index 1f116e1dcaa9..28c87febf837 100644 --- a/src/ray/core_worker/task_submission/tests/dependency_resolver_test.cc +++ b/src/ray/core_worker/task_submission/tests/dependency_resolver_test.cc @@ -383,9 +383,7 @@ TEST(LocalDependencyResolverTest, TestInlinedObjectIds) { TEST(LocalDependencyResolverTest, TestCancelDependencyResolution) { InstrumentedIOContextWithThread io_context("TestCancelDependencyResolution"); - // Mock reference counter as enabled - auto store = std::make_shared(io_context.GetIoService(), - /*reference_counting=*/true); + auto store = std::make_shared(io_context.GetIoService()); auto task_manager = std::make_shared(); FakeActorCreator actor_creator; LocalDependencyResolver resolver( diff --git a/src/ray/core_worker/task_submission/tests/normal_task_submitter_test.cc b/src/ray/core_worker/task_submission/tests/normal_task_submitter_test.cc index c7d51b3e42eb..815824029014 100644 --- a/src/ray/core_worker/task_submission/tests/normal_task_submitter_test.cc +++ b/src/ray/core_worker/task_submission/tests/normal_task_submitter_test.cc @@ -1513,8 +1513,7 @@ void TestSchedulingKey(const std::shared_ptr store, TEST(NormalTaskSubmitterSchedulingKeyTest, TestSchedulingKeys) { InstrumentedIOContextWithThread io_context("TestSchedulingKeys"); // Mock reference 
counter as enabled - auto memory_store = std::make_shared( - io_context.GetIoService(), /*reference_counting=*/true); + auto memory_store = std::make_shared(io_context.GetIoService()); std::unordered_map resources1({{"a", 1.0}}); std::unordered_map resources2({{"b", 2.0}}); @@ -1598,8 +1597,8 @@ TEST(NormalTaskSubmitterSchedulingKeyTest, TestSchedulingKeys) { TEST_F(NormalTaskSubmitterTest, TestBacklogReport) { InstrumentedIOContextWithThread store_io_context("TestBacklogReport"); // Mock reference counter as enabled - auto memory_store = std::make_shared( - store_io_context.GetIoService(), /*reference_counting=*/true); + auto memory_store = + std::make_shared(store_io_context.GetIoService()); auto submitter = CreateNormalTaskSubmitter(std::make_shared(1), WorkerType::WORKER, diff --git a/src/ray/core_worker/tests/memory_store_test.cc b/src/ray/core_worker/tests/memory_store_test.cc index 893cf2a3c819..c927987afd3d 100644 --- a/src/ray/core_worker/tests/memory_store_test.cc +++ b/src/ray/core_worker/tests/memory_store_test.cc @@ -239,8 +239,7 @@ class TestMemoryStoreWait : public ::testing::Test { protected: TestMemoryStoreWait() : io_context("TestWait"), - memory_store(std::make_shared( - io_context.GetIoService(), /*reference_counting=*/true)), + memory_store(std::make_shared(io_context.GetIoService())), ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(1)), buffer("hello"), memory_store_object( diff --git a/src/ray/core_worker/tests/object_recovery_manager_test.cc b/src/ray/core_worker/tests/object_recovery_manager_test.cc index f1be4cc6c9ab..6e811c45e5b0 100644 --- a/src/ray/core_worker/tests/object_recovery_manager_test.cc +++ b/src/ray/core_worker/tests/object_recovery_manager_test.cc @@ -126,8 +126,8 @@ class ObjectRecoveryManagerTestBase : public ::testing::Test { publisher_(std::make_shared()), subscriber_(std::make_shared()), object_directory_(std::make_shared()), - memory_store_(std::make_shared( - io_context_.GetIoService(), /*reference_counting=*/true)), + memory_store_( + std::make_shared(io_context_.GetIoService())), raylet_client_pool_(std::make_shared( [&](const rpc::Address &) { return raylet_client_; })), raylet_client_(std::make_shared()), diff --git a/src/ray/core_worker/tests/reference_counter_test.cc b/src/ray/core_worker/tests/reference_counter_test.cc index ccbfe013efb5..943bcb9cf3cd 100644 --- a/src/ray/core_worker/tests/reference_counter_test.cc +++ b/src/ray/core_worker/tests/reference_counter_test.cc @@ -822,7 +822,7 @@ TEST(MemoryStoreIntegrationTest, TestSimple) { subscriber.get(), /*is_node_dead=*/[](const NodeID &) { return false; }); InstrumentedIOContextWithThread io_context("TestSimple"); - CoreWorkerMemoryStore store(io_context.GetIoService(), rc != nullptr); + CoreWorkerMemoryStore store(io_context.GetIoService()); // Tests putting an object with no references is ignored. 
store.Put(buffer, id2, rc->HasReference(id2)); diff --git a/src/ray/core_worker/tests/task_manager_test.cc b/src/ray/core_worker/tests/task_manager_test.cc index 04eaab0027f9..59a22cb7f992 100644 --- a/src/ray/core_worker/tests/task_manager_test.cc +++ b/src/ray/core_worker/tests/task_manager_test.cc @@ -155,8 +155,7 @@ class TaskManagerTest : public ::testing::Test { /*is_node_dead=*/[this](const NodeID &) { return node_died_; }, lineage_pinning_enabled)), io_context_("TaskManagerTest"), - store_(std::make_shared(io_context_.GetIoService(), - reference_counter_ != nullptr)), + store_(std::make_shared(io_context_.GetIoService())), manager_( *store_, *reference_counter_, @@ -1377,8 +1376,7 @@ TEST_F(TaskManagerTest, PlasmaPut_ObjectStoreFull_FailsTaskAndWritesError) { subscriber_.get(), /*is_node_dead=*/[this](const NodeID &) { return node_died_; }, lineage_pinning_enabled_); - auto local_store = std::make_shared( - io_context_.GetIoService(), reference_counter_ != nullptr); + auto local_store = std::make_shared(io_context_.GetIoService()); TaskManager failing_mgr( *local_store, @@ -1437,8 +1435,7 @@ TEST_F(TaskManagerTest, PlasmaPut_TransientFull_RetriesThenSucceeds) { subscriber_.get(), /*is_node_dead=*/[this](const NodeID &) { return node_died_; }, lineage_pinning_enabled_); - auto local_store = std::make_shared( - io_context_.GetIoService(), reference_counter_ != nullptr); + auto local_store = std::make_shared(io_context_.GetIoService()); TaskManager retry_mgr( *local_store, *local_ref_counter, @@ -1498,8 +1495,7 @@ TEST_F(TaskManagerTest, DynamicReturn_PlasmaPutFailure_FailsTaskImmediately) { subscriber_.get(), /*is_node_dead=*/[this](const NodeID &) { return node_died_; }, lineage_pinning_enabled_); - auto local_store = std::make_shared( - io_context_.GetIoService(), reference_counter_ != nullptr); + auto local_store = std::make_shared(io_context_.GetIoService()); TaskManager dyn_mgr( *local_store, *local_ref_counter, @@ -2992,8 +2988,7 @@ TEST_F(TaskManagerTest, TestRetryErrorMessageSentToCallback) { subscriber_.get(), /*is_node_dead=*/[this](const NodeID &) { return node_died_; }, false); - auto local_store = std::make_shared( - io_context_.GetIoService(), reference_counter_ != nullptr); + auto local_store = std::make_shared(io_context_.GetIoService()); TaskManager test_manager( *local_store, @@ -3072,8 +3067,7 @@ TEST_F(TaskManagerTest, TestErrorLogWhenPushErrorCallbackFails) { subscriber_.get(), /*is_node_dead=*/[this](const NodeID &) { return node_died_; }, false); - auto local_store = std::make_shared( - io_context_.GetIoService(), reference_counter_ != nullptr); + auto local_store = std::make_shared(io_context_.GetIoService()); TaskManager test_manager( *local_store, From 60e327b5470e7f79d2331f792966b0803463bc16 Mon Sep 17 00:00:00 2001 From: davik Date: Tue, 4 Nov 2025 01:35:14 +0000 Subject: [PATCH 4/5] Move request id logic to worker side Signed-off-by: davik --- .../store_provider/plasma_store_provider.cc | 6 +++-- .../store_provider/plasma_store_provider.h | 1 + src/ray/flatbuffers/node_manager.fbs | 1 + src/ray/raylet/lease_dependency_manager.cc | 10 +++----- src/ray/raylet/lease_dependency_manager.h | 10 +++----- src/ray/raylet/node_manager.cc | 25 +++++-------------- src/ray/raylet/node_manager.h | 7 +++--- .../tests/lease_dependency_manager_test.cc | 18 ++++++------- .../raylet_ipc_client/raylet_ipc_client.cc | 19 +++++++------- src/ray/raylet_ipc_client/raylet_ipc_client.h | 3 ++- .../raylet_ipc_client_interface.h | 5 ++-- 11 files changed, 46 insertions(+), 59 
deletions(-) diff --git a/src/ray/core_worker/store_provider/plasma_store_provider.cc b/src/ray/core_worker/store_provider/plasma_store_provider.cc index daa2316cf3d0..8796cca5b5da 100644 --- a/src/ray/core_worker/store_provider/plasma_store_provider.cc +++ b/src/ray/core_worker/store_provider/plasma_store_provider.cc @@ -75,7 +75,8 @@ CoreWorkerPlasmaStoreProvider::CoreWorkerPlasmaStoreProvider( store_client_(std::move(store_client)), reference_counter_(reference_counter), check_signals_(std::move(check_signals)), - fetch_batch_size_(fetch_batch_size) { + fetch_batch_size_(fetch_batch_size), + get_request_counter_(0) { if (get_current_call_site != nullptr) { get_current_call_site_ = get_current_call_site; } else { @@ -256,6 +257,7 @@ Status CoreWorkerPlasmaStoreProvider::Get( int64_t timeout_ms, absl::flat_hash_map> *results) { std::vector get_request_cleanup_handlers; + int64_t get_request_id = get_request_counter_.fetch_add(1); bool got_exception = false; absl::flat_hash_set remaining(object_ids.begin(), object_ids.end()); @@ -276,7 +278,7 @@ Status CoreWorkerPlasmaStoreProvider::Get( std::vector owner_addresses = reference_counter_.GetOwnerAddresses(batch_ids); StatusOr status_or_cleanup = - raylet_ipc_client_->AsyncGetObjects(batch_ids, owner_addresses); + raylet_ipc_client_->AsyncGetObjects(batch_ids, owner_addresses, get_request_id); RAY_RETURN_NOT_OK(status_or_cleanup.status()); get_request_cleanup_handlers.emplace_back(std::move(status_or_cleanup.value())); diff --git a/src/ray/core_worker/store_provider/plasma_store_provider.h b/src/ray/core_worker/store_provider/plasma_store_provider.h index eed757a17c0f..42308b905aaa 100644 --- a/src/ray/core_worker/store_provider/plasma_store_provider.h +++ b/src/ray/core_worker/store_provider/plasma_store_provider.h @@ -261,6 +261,7 @@ class CoreWorkerPlasmaStoreProvider { // Pointer to the shared buffer tracker. std::shared_ptr buffer_tracker_; int64_t fetch_batch_size_ = 0; + std::atomic get_request_counter_; }; } // namespace core diff --git a/src/ray/flatbuffers/node_manager.fbs b/src/ray/flatbuffers/node_manager.fbs index d1cfa299f48b..e5526b2545f5 100644 --- a/src/ray/flatbuffers/node_manager.fbs +++ b/src/ray/flatbuffers/node_manager.fbs @@ -134,6 +134,7 @@ table AsyncGetObjectsRequest { // Object IDs that we want the Raylet to pull locally. 
object_ids: [string]; owner_addresses: [Address]; + get_request_id: long; } table AsyncGetObjectsReply { diff --git a/src/ray/raylet/lease_dependency_manager.cc b/src/ray/raylet/lease_dependency_manager.cc index 678e42c3ccf1..a86a6313bdba 100644 --- a/src/ray/raylet/lease_dependency_manager.cc +++ b/src/ray/raylet/lease_dependency_manager.cc @@ -115,8 +115,10 @@ void LeaseDependencyManager::CancelWaitRequest(const WorkerID &worker_id) { wait_requests_.erase(req_iter); } -GetRequestId LeaseDependencyManager::StartGetRequest( - const WorkerID &worker_id, std::vector &&required_objects) { +void LeaseDependencyManager::StartGetRequest( + const WorkerID &worker_id, + std::vector &&required_objects, + int64_t get_request_id) { std::vector object_ids; object_ids.reserve(required_objects.size()); @@ -130,16 +132,12 @@ GetRequestId LeaseDependencyManager::StartGetRequest( uint64_t new_pull_request_id = object_manager_.Pull( std::move(required_objects), BundlePriority::GET_REQUEST, {"", false}); - const GetRequestId get_request_id = get_request_counter_++; - const std::pair worker_and_request_ids = std::make_pair(worker_id, get_request_id); get_requests_.emplace(std::move(worker_and_request_ids), std::make_pair(std::move(object_ids), new_pull_request_id)); worker_to_requests_[worker_id].emplace(get_request_id); - - return get_request_id; } void LeaseDependencyManager::CancelGetRequest(const WorkerID &worker_id, diff --git a/src/ray/raylet/lease_dependency_manager.h b/src/ray/raylet/lease_dependency_manager.h index 92396fd5dade..174504f3f4fe 100644 --- a/src/ray/raylet/lease_dependency_manager.h +++ b/src/ray/raylet/lease_dependency_manager.h @@ -135,9 +135,10 @@ class LeaseDependencyManager : public LeaseDependencyManagerInterface { /// \param worker_id The ID of the worker that called `ray.get`. /// \param required_objects The objects required by the worker. - /// \return the request id which will be used for cleanup. - GetRequestId StartGetRequest(const WorkerID &worker_id, - std::vector &&required_objects); + /// \param get_request_id The ID of the get request. + void StartGetRequest(const WorkerID &worker_id, + std::vector &&required_objects, + int64_t get_request_id); /// Cleans up either an inflight or finished get request. Cancels the underlying /// pull if necessary. @@ -302,9 +303,6 @@ class LeaseDependencyManager : public LeaseDependencyManagerInterface { /// dependencies are all local or not. absl::flat_hash_map> queued_lease_requests_; - /// Used to generate monotonically increasing get request ids. - GetRequestId get_request_counter_; - // Maps a GetRequest to the PullRequest Id and the set of ObjectIDs. // Used to cleanup a finished or cancel an inflight get request. // TODO(57911): This can be slimmed down. We do not need to track the ObjectIDs. 
diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 9e2a45641ba8..f2618d11b5d1 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -1608,21 +1608,7 @@ void NodeManager::HandleAsyncGetObjectsRequest( auto request = flatbuffers::GetRoot(message_data); std::vector refs = FlatbufferToObjectReferences(*request->object_ids(), *request->owner_addresses()); - int64_t request_id = AsyncGet(client, refs); - flatbuffers::FlatBufferBuilder fbb; - auto get_reply = protocol::CreateAsyncGetObjectsReply(fbb, request_id); - fbb.Finish(get_reply); - Status status = client->WriteMessage( - static_cast(protocol::MessageType::AsyncGetObjectsReply), - fbb.GetSize(), - fbb.GetBufferPointer()); - if (!status.ok()) { - DisconnectClient(client, - /*graceful=*/false, - rpc::WorkerExitType::SYSTEM_ERROR, - absl::StrFormat("Could not send AsyncGetObjectsReply because of %s", - status.ToString())); - } + AsyncGet(client, refs, request->get_request_id()); } void NodeManager::ProcessWaitRequestMessage( @@ -2357,15 +2343,16 @@ void NodeManager::HandleNotifyWorkerUnblocked( } } -int64_t NodeManager::AsyncGet(const std::shared_ptr &client, - std::vector &object_refs) { +void NodeManager::AsyncGet(const std::shared_ptr &client, + std::vector &object_refs, + int64_t get_request_id) { std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); if (!worker) { worker = worker_pool_.GetRegisteredDriver(client); } RAY_CHECK(worker); - return lease_dependency_manager_.StartGetRequest(worker->WorkerId(), - std::move(object_refs)); + lease_dependency_manager_.StartGetRequest( + worker->WorkerId(), std::move(object_refs), get_request_id); } void NodeManager::AsyncWait(const std::shared_ptr &client, diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index 843f37e28f5e..5f1ab6eafcd6 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -415,9 +415,10 @@ class NodeManager : public rpc::NodeManagerServiceHandler, /// \param client The client that is requesting the objects. /// \param object_refs The objects that are requested. /// - /// \return the request_id that will be used to cancel the get request. - int64_t AsyncGet(const std::shared_ptr &client, - std::vector &object_refs); + /// \param get_request_id The ID of the get request. + void AsyncGet(const std::shared_ptr &client, + std::vector &object_refs, + int64_t get_request_id); /// Cancel all ongoing get requests from the client. 
/// diff --git a/src/ray/raylet/tests/lease_dependency_manager_test.cc b/src/ray/raylet/tests/lease_dependency_manager_test.cc index 6c662e368504..18dcb2e79095 100644 --- a/src/ray/raylet/tests/lease_dependency_manager_test.cc +++ b/src/ray/raylet/tests/lease_dependency_manager_test.cc @@ -243,15 +243,14 @@ TEST_F(LeaseDependencyManagerTest, TestLeaseArgEviction) { TEST_F(LeaseDependencyManagerTest, TestCancelingSingleGetRequestForWorker) { WorkerID worker_id = WorkerID::FromRandom(); int num_requests = 5; - std::vector requests; - for (int i = 0; i < num_requests; i++) { + for (int64_t i = 0; i < num_requests; i++) { ObjectID argument_id = ObjectID::FromRandom(); - requests.emplace_back(lease_dependency_manager_.StartGetRequest( - worker_id, ObjectIdsToRefs({argument_id}))); + lease_dependency_manager_.StartGetRequest( + worker_id, ObjectIdsToRefs({argument_id}), i); } ASSERT_EQ(object_manager_mock_.active_get_requests.size(), num_requests); - for (int i = 0; i < num_requests; i++) { - lease_dependency_manager_.CancelGetRequest(worker_id, requests[i]); + for (int64_t i = 0; i < num_requests; i++) { + lease_dependency_manager_.CancelGetRequest(worker_id, i); ASSERT_EQ(object_manager_mock_.active_get_requests.size(), num_requests - (i + 1)); } AssertNoLeaks(); @@ -260,11 +259,10 @@ TEST_F(LeaseDependencyManagerTest, TestCancelingSingleGetRequestForWorker) { TEST_F(LeaseDependencyManagerTest, TestCancelingAllGetRequestsForWorker) { WorkerID worker_id = WorkerID::FromRandom(); int num_requests = 5; - std::vector requests; - for (int i = 0; i < num_requests; i++) { + for (int64_t i = 0; i < num_requests; i++) { ObjectID argument_id = ObjectID::FromRandom(); - requests.emplace_back(lease_dependency_manager_.StartGetRequest( - worker_id, ObjectIdsToRefs({argument_id}))); + lease_dependency_manager_.StartGetRequest( + worker_id, ObjectIdsToRefs({argument_id}), i); } ASSERT_EQ(object_manager_mock_.active_get_requests.size(), num_requests); lease_dependency_manager_.CancelGetRequest(worker_id); diff --git a/src/ray/raylet_ipc_client/raylet_ipc_client.cc b/src/ray/raylet_ipc_client/raylet_ipc_client.cc index c17e3f947476..1307ecd570ce 100644 --- a/src/ray/raylet_ipc_client/raylet_ipc_client.cc +++ b/src/ray/raylet_ipc_client/raylet_ipc_client.cc @@ -195,23 +195,22 @@ Status RayletIpcClient::ActorCreationTaskDone() { StatusOr RayletIpcClient::AsyncGetObjects( const std::vector &object_ids, - const std::vector &owner_addresses) { + const std::vector &owner_addresses, + int64_t get_request_id) { RAY_CHECK(object_ids.size() == owner_addresses.size()); flatbuffers::FlatBufferBuilder fbb; auto object_ids_message = flatbuf::to_flatbuf(fbb, object_ids); - auto message = protocol::CreateAsyncGetObjectsRequest( - fbb, object_ids_message, AddressesToFlatbuffer(fbb, owner_addresses)); + auto message = + protocol::CreateAsyncGetObjectsRequest(fbb, + object_ids_message, + AddressesToFlatbuffer(fbb, owner_addresses), + get_request_id); fbb.Finish(message); std::vector reply; // TODO(57923): This should be FATAL. Local sockets are reliable. If a worker is unable // to communicate with the raylet, there's no way to recover. 
-  RAY_RETURN_NOT_OK(AtomicRequestReply(MessageType::AsyncGetObjectsRequest,
-                                       MessageType::AsyncGetObjectsReply,
-                                       &reply,
-                                       &fbb));
-  auto reply_message = flatbuffers::GetRoot<protocol::AsyncGetObjectsReply>(reply.data());
-  int64_t request_id = reply_message->request_id();
-  return ScopedResponse([this, request_id_to_cleanup = request_id]() {
+  RAY_RETURN_NOT_OK(WriteMessage(MessageType::AsyncGetObjectsRequest, &fbb));
+  return ScopedResponse([this, request_id_to_cleanup = get_request_id]() {
     return CancelGetRequest(request_id_to_cleanup);
   });
 }
diff --git a/src/ray/raylet_ipc_client/raylet_ipc_client.h b/src/ray/raylet_ipc_client/raylet_ipc_client.h
index 8fbf67c1f7d9..b47d34c1c8e3 100644
--- a/src/ray/raylet_ipc_client/raylet_ipc_client.h
+++ b/src/ray/raylet_ipc_client/raylet_ipc_client.h
@@ -72,7 +72,8 @@ class RayletIpcClient : public RayletIpcClientInterface {
 
   StatusOr<ScopedResponse> AsyncGetObjects(
       const std::vector<ObjectID> &object_ids,
-      const std::vector<rpc::Address> &owner_addresses) override;
+      const std::vector<rpc::Address> &owner_addresses,
+      int64_t get_request_id) override;
 
   StatusOr> Wait(
       const std::vector<ObjectID> &object_ids,
diff --git a/src/ray/raylet_ipc_client/raylet_ipc_client_interface.h b/src/ray/raylet_ipc_client/raylet_ipc_client_interface.h
index 04ededbd43b9..31cb6cab2190 100644
--- a/src/ray/raylet_ipc_client/raylet_ipc_client_interface.h
+++ b/src/ray/raylet_ipc_client/raylet_ipc_client_interface.h
@@ -51,7 +51,7 @@ class ScopedResponse {
   }
 
   ScopedResponse &operator=(ScopedResponse &&other) {
-    if (this == &other) {
+    if (this != &other) {
       HandleCleanup();
       this->cleanup_ = other.cleanup_;
       other.cleanup_ = nullptr;
@@ -150,7 +150,8 @@ class RayletIpcClientInterface {
   /// request to clean up the GetObjectsRequest upon destruction.
   virtual StatusOr<ScopedResponse> AsyncGetObjects(
       const std::vector<ObjectID> &object_ids,
-      const std::vector<rpc::Address> &owner_addresses) = 0;
+      const std::vector<rpc::Address> &owner_addresses,
+      int64_t get_request_id) = 0;
 
   /// Wait for the given objects until timeout expires or num_return objects are
   /// found.
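A note on the protocol change in this patch: AsyncGetObjects no longer does a
blocking request/reply round trip to learn its request ID from the raylet. The
caller now chooses get_request_id, the request is sent as a one-way WriteMessage,
and the returned ScopedResponse cancels the outstanding get request when it is
destroyed. A minimal caller-side sketch, assuming the types introduced in this
series; the next_get_request_id counter and DoAsyncGet wrapper are illustrative
helpers, not code from the patch:

    // Sketch only. Requires <atomic>; Ray types (Status, StatusOr, ObjectID,
    // rpc::Address, RayletIpcClient, ScopedResponse) as in this repo.
    // IDs must be unique per worker, since the raylet no longer hands one back.
    static std::atomic<int64_t> next_get_request_id{0};  // assumed helper

    Status DoAsyncGet(RayletIpcClient &client,
                      const std::vector<ObjectID> &ids,
                      const std::vector<rpc::Address> &owners) {
      const int64_t get_request_id = next_get_request_id.fetch_add(1);
      // One-way write; no AsyncGetObjectsReply is awaited anymore.
      StatusOr<ScopedResponse> scoped =
          client.AsyncGetObjects(ids, owners, get_request_id);
      if (!scoped.ok()) {
        return scoped.status();
      }
      // ... block until the requested objects arrive ...
      // When `scoped` is destroyed, its cleanup lambda issues
      // CancelGetRequest(get_request_id) to the raylet.
      return Status::OK();
    }
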
From f9cc6913c17f0dedef6525a8b3a850bd3f4ac5d4 Mon Sep 17 00:00:00 2001
From: davik
Date: Wed, 12 Nov 2025 20:24:54 +0000
Subject: [PATCH 5/5] Address function header comments

Signed-off-by: davik
---
 .../runtime/object/local_mode_object_store.cc |  5 +++--
 src/ray/core_worker/core_worker.cc            |  3 ++-
 src/ray/core_worker/core_worker_process.cc    |  2 +-
 .../memory_store/memory_store.cc              | 10 ++++------
 .../store_provider/plasma_store_provider.h    | 24 ++++++++++++------------
 .../core_worker/tests/memory_store_test.cc    |  4 ++--
 6 files changed, 24 insertions(+), 24 deletions(-)

diff --git a/cpp/src/ray/runtime/object/local_mode_object_store.cc b/cpp/src/ray/runtime/object/local_mode_object_store.cc
index b1c8030e7c16..c886009cea74 100644
--- a/cpp/src/ray/runtime/object/local_mode_object_store.cc
+++ b/cpp/src/ray/runtime/object/local_mode_object_store.cc
@@ -28,8 +28,9 @@ namespace internal {
 LocalModeObjectStore::LocalModeObjectStore(LocalModeRayRuntime &local_mode_ray_tuntime)
     : io_context_("LocalModeObjectStore"),
       local_mode_ray_tuntime_(local_mode_ray_tuntime) {
-  memory_store_ = std::make_unique<CoreWorkerMemoryStore>(io_context_.GetIoService(),
-                                                          /*reference_counting=*/false);
+  memory_store_ =
+      std::make_unique<CoreWorkerMemoryStore>(io_context_.GetIoService(),
+                                              /*reference_counting_enabled=*/false);
 }
 
 void LocalModeObjectStore::PutRaw(std::shared_ptr<msgpack::sbuffer> data,
diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc
index d6eaff5462cb..a22d76b9274a 100644
--- a/src/ray/core_worker/core_worker.cc
+++ b/src/ray/core_worker/core_worker.cc
@@ -3100,7 +3100,8 @@ bool CoreWorker::PinExistingReturnObject(const ObjectID &return_id,
 
   std::vector<ObjectID> object_ids = {return_id};
   auto owner_addresses = reference_counter_->GetOwnerAddresses(object_ids);
-  auto status = plasma_store_provider_->Get(object_ids, owner_addresses, 0, &result_map);
+  Status status =
+      plasma_store_provider_->Get(object_ids, owner_addresses, 0, &result_map);
 
   // Remove the temporary ref.
   RemoveLocalReference(return_id);
diff --git a/src/ray/core_worker/core_worker_process.cc b/src/ray/core_worker/core_worker_process.cc
index 7bf39f1f9b65..b4894d244b24 100644
--- a/src/ray/core_worker/core_worker_process.cc
+++ b/src/ray/core_worker/core_worker_process.cc
@@ -364,7 +364,7 @@ std::shared_ptr<CoreWorker> CoreWorkerProcessImpl::CreateCoreWorker(
       });
   auto memory_store = std::make_shared<CoreWorkerMemoryStore>(
       io_service_,
-      /*reference_counting=*/reference_counter != nullptr,
+      /*reference_counting_enabled=*/reference_counter != nullptr,
       raylet_ipc_client,
       options.check_signals,
       [this](const RayObject &obj) {
diff --git a/src/ray/core_worker/store_provider/memory_store/memory_store.cc b/src/ray/core_worker/store_provider/memory_store/memory_store.cc
index 22425d564b3e..fbf5e4a4e24a 100644
--- a/src/ray/core_worker/store_provider/memory_store/memory_store.cc
+++ b/src/ray/core_worker/store_provider/memory_store/memory_store.cc
@@ -203,7 +203,6 @@ void CoreWorkerMemoryStore::Put(const RayObject &object,
     object_async_get_requests_.erase(async_callback_it);
   }
 
-  bool should_add_entry = true;
   auto object_request_iter = object_get_requests_.find(object_id);
   if (object_request_iter != object_get_requests_.end()) {
     auto &get_requests = object_request_iter->second;
@@ -211,12 +210,11 @@
       get_request->Set(object_id, object_entry);
     }
   }
 
-  // Don't put it in the store, since we won't get a callback for deletion.
-  if (reference_counting_enabled_ && !has_reference) {
-    should_add_entry = false;
-  }
-  if (should_add_entry) {
+  // Don't put the object in the store if we won't get a deletion callback for it.
+  // The exception is local mode (reference counting disabled), where Put should
+  // still add the object to the store.
+  if (!reference_counting_enabled_ || has_reference) {
     // If there is no existing get request, then add the `RayObject` to map.
     EmplaceObjectAndUpdateStats(object_id, object_entry);
   } else {
diff --git a/src/ray/core_worker/store_provider/plasma_store_provider.h b/src/ray/core_worker/store_provider/plasma_store_provider.h
index 47689264286a..f98e83ba0a3b 100644
--- a/src/ray/core_worker/store_provider/plasma_store_provider.h
+++ b/src/ray/core_worker/store_provider/plasma_store_provider.h
@@ -223,18 +223,18 @@ class CoreWorkerPlasmaStoreProvider {
   /// Successfully fetched objects will be removed from the input set of remaining IDs and
   /// added to the results map.
   ///
-  /// \param[in/out] Map of object IDs to their indices left to get.
-  /// \param[in] batch_ids IDs of the objects to get.
-  /// \param[in] batch_owner_addresses owner addresses of the objects.
-  /// \param[in] timeout_ms Timeout in milliseconds.
-  /// \param[out] results Map of objects to write results into. This method will only
-  /// add to this map, not clear or remove from it, so the caller can pass in a non-empty
-  /// map.
-  /// \param[out] got_exception Set to true if any of the fetched objects contained an
-  /// exception.
-  /// \return Status::IOError if there is an error in communicating with the raylet or the
-  /// plasma store.
-  /// \return Status::OK if successful.
+  /// \param[in/out] remaining_object_id_to_idx Map of object IDs to their indices left
+  /// to get.
+  /// \param[in] ids IDs of the objects to get.
+  /// \param[in] timeout_ms Timeout in milliseconds.
+  /// \param[out] results Map of objects to write results into. This method will only
+  /// add to this map, not clear or remove from it, so the caller can pass in a
+  /// non-empty map.
+  /// \param[out] got_exception Set to true if any of the fetched objects contained an
+  /// exception.
+  /// \return Status::IOError if there is an error in communicating with the raylet or
+  /// the plasma store.
+  /// \return Status::OK if successful.
   Status GetObjectsFromPlasmaStore(
       absl::flat_hash_map &remaining_object_id_to_idx,
       const std::vector<ObjectID> &ids,
diff --git a/src/ray/core_worker/tests/memory_store_test.cc b/src/ray/core_worker/tests/memory_store_test.cc
index c927987afd3d..30f9ac914409 100644
--- a/src/ray/core_worker/tests/memory_store_test.cc
+++ b/src/ray/core_worker/tests/memory_store_test.cc
@@ -53,7 +53,7 @@ TEST(TestMemoryStore, TestReportUnhandledErrors) {
   std::shared_ptr<CoreWorkerMemoryStore> memory_store =
       std::make_shared<CoreWorkerMemoryStore>(
           io_context.GetIoService(),
-          /*reference_counting=*/true,
+          /*reference_counting_enabled=*/true,
           nullptr,
           nullptr,
           [&](const RayObject &obj) { unhandled_count++; });
@@ -209,7 +209,7 @@ TEST(TestMemoryStore, TestObjectAllocator) {
 
   std::shared_ptr<CoreWorkerMemoryStore> memory_store =
       std::make_shared<CoreWorkerMemoryStore>(io_context.GetIoService(),
-                                              /*reference_counting=*/true,
+                                              /*reference_counting_enabled=*/true,
                                               nullptr,
                                               nullptr,
                                               nullptr,
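To make the Put semantics in this patch concrete: with reference counting enabled,
the memory store keeps an object only while the caller still holds a reference,
because removal is driven by the reference counter's deletion callback; with
reference counting disabled (local mode), every Put is stored unconditionally.
A standalone restatement of that branch; ShouldStoreEntry is an illustrative name,
not a function in the patch:

    // Mirrors the condition in CoreWorkerMemoryStore::Put after this series:
    //   if (!reference_counting_enabled_ || has_reference) { /* store */ }
    bool ShouldStoreEntry(bool reference_counting_enabled, bool has_reference) {
      // Local mode: no deletion callbacks exist, so always keep the object.
      // Otherwise keep it only while a reference is held; the deletion
      // callback is what later evicts it from the store.
      return !reference_counting_enabled || has_reference;
    }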