diff --git a/cpp/src/ray/runtime/object/local_mode_object_store.cc b/cpp/src/ray/runtime/object/local_mode_object_store.cc index 4c6ea46e22a1..c886009cea74 100644 --- a/cpp/src/ray/runtime/object/local_mode_object_store.cc +++ b/cpp/src/ray/runtime/object/local_mode_object_store.cc @@ -28,7 +28,9 @@ namespace internal { LocalModeObjectStore::LocalModeObjectStore(LocalModeRayRuntime &local_mode_ray_tuntime) : io_context_("LocalModeObjectStore"), local_mode_ray_tuntime_(local_mode_ray_tuntime) { - memory_store_ = std::make_unique(io_context_.GetIoService()); + memory_store_ = + std::make_unique(io_context_.GetIoService(), + /*reference_counting_enabled=*/false); } void LocalModeObjectStore::PutRaw(std::shared_ptr data, @@ -41,8 +43,11 @@ void LocalModeObjectStore::PutRaw(std::shared_ptr data, const ObjectID &object_id) { auto buffer = std::make_shared<::ray::LocalMemoryBuffer>( reinterpret_cast(data->data()), data->size(), true); + // NOTE: you can't have reference when reference counting is disabled in local mode memory_store_->Put( - ::ray::RayObject(buffer, nullptr, std::vector()), object_id); + ::ray::RayObject(buffer, nullptr, std::vector()), + object_id, + /*has_reference=*/false); } std::shared_ptr LocalModeObjectStore::GetRaw(const ObjectID &object_id, @@ -61,7 +66,6 @@ std::vector> LocalModeObjectStore::GetRaw( (int)ids.size(), timeout_ms, local_mode_ray_tuntime_.GetWorkerContext(), - false, &results); if (!status.ok()) { throw RayException("Get object error: " + status.ToString()); diff --git a/src/ray/core_worker/BUILD.bazel b/src/ray/core_worker/BUILD.bazel index a1a5054d6aba..25c3baa4004a 100644 --- a/src/ray/core_worker/BUILD.bazel +++ b/src/ray/core_worker/BUILD.bazel @@ -271,7 +271,6 @@ ray_cc_library( hdrs = ["store_provider/memory_store/memory_store.h"], deps = [ ":core_worker_context", - ":reference_counter_interface", "//src/ray/common:asio", "//src/ray/common:id", "//src/ray/common:metrics", @@ -290,6 +289,7 @@ ray_cc_library( name = 
"task_manager_interface", hdrs = ["task_manager_interface.h"], deps = [ + ":reference_counter_interface", "//src/ray/common:id", "//src/ray/common:status", "//src/ray/common:task_common", @@ -348,6 +348,7 @@ ray_cc_library( hdrs = ["future_resolver.h"], deps = [ ":memory_store", + ":reference_counter_interface", "//src/ray/common:id", "//src/ray/core_worker_rpc_client:core_worker_client_pool", ], @@ -408,7 +409,6 @@ ray_cc_library( deps = [ ":common", ":core_worker_context", - ":reference_counter_interface", "//src/ray/common:buffer", "//src/ray/common:id", "//src/ray/common:ray_config", diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 68128a78436c..cbf67e4b1087 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -1005,7 +1005,9 @@ Status CoreWorker::PutInLocalPlasmaStore(const RayObject &object, RAY_RETURN_NOT_OK(plasma_store_provider_->Release(object_id)); } } - memory_store_->Put(RayObject(rpc::ErrorType::OBJECT_IN_PLASMA), object_id); + memory_store_->Put(RayObject(rpc::ErrorType::OBJECT_IN_PLASMA), + object_id, + reference_counter_->HasReference(object_id)); return Status::OK(); } @@ -1016,7 +1018,7 @@ Status CoreWorker::Put(const RayObject &object, RAY_RETURN_NOT_OK(WaitForActorRegistered(contained_object_ids)); if (options_.is_local_mode) { RAY_LOG(DEBUG).WithField(object_id) << "Put object in memory store"; - memory_store_->Put(object, object_id); + memory_store_->Put(object, object_id, reference_counter_->HasReference(object_id)); return Status::OK(); } return PutInLocalPlasmaStore(object, object_id, pin_object); @@ -1115,7 +1117,9 @@ Status CoreWorker::CreateOwnedAndIncrementLocalRef( } else if (*data == nullptr) { // Object already exists in plasma. Store the in-memory value so that the // client will check the plasma store. 
- memory_store_->Put(RayObject(rpc::ErrorType::OBJECT_IN_PLASMA), *object_id); + memory_store_->Put(RayObject(rpc::ErrorType::OBJECT_IN_PLASMA), + *object_id, + reference_counter_->HasReference(*object_id)); } } return Status::OK(); @@ -1212,7 +1216,9 @@ Status CoreWorker::SealExisting(const ObjectID &object_id, RAY_RETURN_NOT_OK(plasma_store_provider_->Release(object_id)); reference_counter_->FreePlasmaObjects({object_id}); } - memory_store_->Put(RayObject(rpc::ErrorType::OBJECT_IN_PLASMA), object_id); + memory_store_->Put(RayObject(rpc::ErrorType::OBJECT_IN_PLASMA), + object_id, + reference_counter_->HasReference(object_id)); return Status::OK(); } @@ -1384,14 +1390,20 @@ Status CoreWorker::GetObjects(const std::vector &ids, // If any of the objects have been promoted to plasma, then we retry their // gets at the provider plasma. Once we get the objects from plasma, we flip // the transport type again and return them for the original direct call ids. + + // Prepare object ids vector and owner addresses vector + std::vector object_ids = + std::vector(plasma_object_ids.begin(), plasma_object_ids.end()); + auto owner_addresses = reference_counter_->GetOwnerAddresses(object_ids); + int64_t local_timeout_ms = timeout_ms; if (timeout_ms >= 0) { local_timeout_ms = std::max(static_cast(0), timeout_ms - (current_time_ms() - start_time)); } RAY_LOG(DEBUG) << "Plasma GET timeout " << local_timeout_ms; - RAY_RETURN_NOT_OK( - plasma_store_provider_->Get(plasma_object_ids, local_timeout_ms, &result_map)); + RAY_RETURN_NOT_OK(plasma_store_provider_->Get( + object_ids, owner_addresses, local_timeout_ms, &result_map)); } // Loop through `ids` and fill each entry for the `results` vector, @@ -1535,8 +1547,14 @@ Status CoreWorker::Wait(const std::vector &ids, // num_objects ready since we want to at least make the request to start pulling // these objects. 
if (!plasma_object_ids.empty()) { + // Prepare object ids map + std::vector object_ids = + std::vector(plasma_object_ids.begin(), plasma_object_ids.end()); + auto owner_addresses = reference_counter_->GetOwnerAddresses(object_ids); + RAY_RETURN_NOT_OK(plasma_store_provider_->Wait( - plasma_object_ids, + object_ids, + owner_addresses, std::min(static_cast(plasma_object_ids.size()), num_objects - static_cast(ready.size())), timeout_ms, @@ -3079,7 +3097,12 @@ bool CoreWorker::PinExistingReturnObject(const ObjectID &return_id, reference_counter_->AddLocalReference(return_id, ""); reference_counter_->AddBorrowedObject(return_id, ObjectID::Nil(), owner_address); - Status status = plasma_store_provider_->Get({return_id}, 0, &result_map); + // Resolve owner address of return id + std::vector object_ids = {return_id}; + auto owner_addresses = reference_counter_->GetOwnerAddresses(object_ids); + + Status status = + plasma_store_provider_->Get(object_ids, owner_addresses, 0, &result_map); // Remove the temporary ref. RemoveLocalReference(return_id); @@ -3297,7 +3320,8 @@ Status CoreWorker::GetAndPinArgsForExecutor(const TaskSpecification &task, // otherwise, the put is a no-op. if (!options_.is_local_mode) { memory_store_->Put(RayObject(rpc::ErrorType::OBJECT_IN_PLASMA), - task.ArgObjectId(i)); + task.ArgObjectId(i), + reference_counter_->HasReference(task.ArgObjectId(i))); } } else { // A pass-by-value argument. 
@@ -3346,7 +3370,12 @@ Status CoreWorker::GetAndPinArgsForExecutor(const TaskSpecification &task, RAY_RETURN_NOT_OK(memory_store_->Get( by_ref_ids, -1, *worker_context_, &result_map, &got_exception)); } else { - RAY_RETURN_NOT_OK(plasma_store_provider_->Get(by_ref_ids, -1, &result_map)); + // Resolve owner addresses of by-ref ids + std::vector object_ids = + std::vector(by_ref_ids.begin(), by_ref_ids.end()); + auto owner_addresses = reference_counter_->GetOwnerAddresses(object_ids); + RAY_RETURN_NOT_OK( + plasma_store_provider_->Get(object_ids, owner_addresses, -1, &result_map)); } for (const auto &it : result_map) { for (size_t idx : by_ref_indices[it.first]) { @@ -4182,7 +4211,9 @@ Status CoreWorker::DeleteImpl(const std::vector &object_ids, bool loca memory_store_->Delete(object_ids); for (const auto &object_id : object_ids) { RAY_LOG(DEBUG).WithField(object_id) << "Freeing object"; - memory_store_->Put(RayObject(rpc::ErrorType::OBJECT_FREED), object_id); + memory_store_->Put(RayObject(rpc::ErrorType::OBJECT_FREED), + object_id, + reference_counter_->HasReference(object_id)); } // We only delete from plasma, which avoids hangs (issue #7105). 
In-memory @@ -4332,7 +4363,9 @@ void CoreWorker::HandleAssignObjectOwner(rpc::AssignObjectOwnerRequest request, /*add_local_ref=*/false, /*pinned_at_node_id=*/NodeID::FromBinary(borrower_address.node_id())); reference_counter_->AddBorrowerAddress(object_id, borrower_address); - memory_store_->Put(RayObject(rpc::ErrorType::OBJECT_IN_PLASMA), object_id); + memory_store_->Put(RayObject(rpc::ErrorType::OBJECT_IN_PLASMA), + object_id, + reference_counter_->HasReference(object_id)); send_reply_callback(Status::OK(), nullptr, nullptr); } diff --git a/src/ray/core_worker/core_worker_process.cc b/src/ray/core_worker/core_worker_process.cc index 8056e627e345..0c8fd8e45d0c 100644 --- a/src/ray/core_worker/core_worker_process.cc +++ b/src/ray/core_worker/core_worker_process.cc @@ -355,7 +355,6 @@ std::shared_ptr CoreWorkerProcessImpl::CreateCoreWorker( auto plasma_store_provider = std::make_shared( options.store_socket, raylet_ipc_client, - *reference_counter, options.check_signals, /*warmup=*/ (options.worker_type != WorkerType::SPILL_WORKER && @@ -368,7 +367,7 @@ std::shared_ptr CoreWorkerProcessImpl::CreateCoreWorker( }); auto memory_store = std::make_shared( io_service_, - reference_counter.get(), + /*reference_counting_enabled=*/reference_counter != nullptr, raylet_ipc_client, options.check_signals, [this](const RayObject &obj) { diff --git a/src/ray/core_worker/future_resolver.cc b/src/ray/core_worker/future_resolver.cc index 153c06a0f84f..8654e64d2228 100644 --- a/src/ray/core_worker/future_resolver.cc +++ b/src/ray/core_worker/future_resolver.cc @@ -52,14 +52,18 @@ void FutureResolver::ProcessResolvedObject(const ObjectID &object_id, if (!status.ok()) { // The owner is unreachable. Store an error so that an exception will be // thrown immediately when the worker tries to get the value. 
- in_memory_store_->Put(RayObject(rpc::ErrorType::OWNER_DIED), object_id); + in_memory_store_->Put(RayObject(rpc::ErrorType::OWNER_DIED), + object_id, + reference_counter_->HasReference(object_id)); } else if (reply.status() == rpc::GetObjectStatusReply::OUT_OF_SCOPE) { // The owner replied that the object has gone out of scope (this is an edge // case in the distributed ref counting protocol where a borrower dies // before it can notify the owner of another borrower). Store an error so // that an exception will be thrown immediately when the worker tries to // get the value. - in_memory_store_->Put(RayObject(rpc::ErrorType::OBJECT_DELETED), object_id); + in_memory_store_->Put(RayObject(rpc::ErrorType::OBJECT_DELETED), + object_id, + reference_counter_->HasReference(object_id)); } else if (reply.status() == rpc::GetObjectStatusReply::CREATED) { // The object is either an indicator that the object is in Plasma, or // the object has been returned directly in the reply. In either @@ -106,7 +110,8 @@ void FutureResolver::ProcessResolvedObject(const ObjectID &object_id, inlined_ref.owner_address()); } in_memory_store_->Put(RayObject(data_buffer, metadata_buffer, inlined_refs), - object_id); + object_id, + reference_counter_->HasReference(object_id)); } } diff --git a/src/ray/core_worker/future_resolver.h b/src/ray/core_worker/future_resolver.h index 04caaeba8ffb..a48471f5a19b 100644 --- a/src/ray/core_worker/future_resolver.h +++ b/src/ray/core_worker/future_resolver.h @@ -18,6 +18,7 @@ #include #include "ray/common/id.h" +#include "ray/core_worker/reference_counter_interface.h" #include "ray/core_worker/store_provider/memory_store/memory_store.h" #include "ray/core_worker_rpc_client/core_worker_client_pool.h" #include "src/ray/protobuf/core_worker.pb.h" diff --git a/src/ray/core_worker/object_recovery_manager.cc b/src/ray/core_worker/object_recovery_manager.cc index 893fb4fabd6d..83a6cad70648 100644 --- a/src/ray/core_worker/object_recovery_manager.cc +++ 
b/src/ray/core_worker/object_recovery_manager.cc @@ -81,7 +81,9 @@ bool ObjectRecoveryManager::RecoverObject(const ObjectID &object_id) { // (core_worker.cc removes the object from memory store before calling this method), // we need to add it back to indicate that it's available. // If the object is already in the memory store then the put is a no-op. - in_memory_store_.Put(RayObject(rpc::ErrorType::OBJECT_IN_PLASMA), object_id); + in_memory_store_.Put(RayObject(rpc::ErrorType::OBJECT_IN_PLASMA), + object_id, + reference_counter_.HasReference(object_id)); } return true; } @@ -121,7 +123,8 @@ void ObjectRecoveryManager::PinExistingObjectCopy( const Status &status, const rpc::PinObjectIDsReply &reply) mutable { if (status.ok() && reply.successes(0)) { in_memory_store_.Put(RayObject(rpc::ErrorType::OBJECT_IN_PLASMA), - object_id); + object_id, + reference_counter_.HasReference(object_id)); reference_counter_.UpdateObjectPinnedAtRaylet(object_id, node_id); } else { RAY_LOG(INFO).WithField(object_id) diff --git a/src/ray/core_worker/store_provider/memory_store/memory_store.cc b/src/ray/core_worker/store_provider/memory_store/memory_store.cc index f3e444cca5c4..fbf5e4a4e24a 100644 --- a/src/ray/core_worker/store_provider/memory_store/memory_store.cc +++ b/src/ray/core_worker/store_provider/memory_store/memory_store.cc @@ -42,7 +42,6 @@ class GetRequest { public: GetRequest(absl::flat_hash_set object_ids, size_t num_objects, - bool remove_after_get, bool abort_if_any_object_is_exception); const absl::flat_hash_set &ObjectIds() const; @@ -56,8 +55,6 @@ class GetRequest { void Set(const ObjectID &object_id, std::shared_ptr buffer); /// Get the object content for the specific object id. std::shared_ptr Get(const ObjectID &object_id) const; - /// Whether this is a `get` request. - bool ShouldRemoveObjects() const; private: /// The object IDs involved in this request. @@ -67,9 +64,6 @@ class GetRequest { /// Number of objects required. 
const size_t num_objects_; - // Whether the requested objects should be removed from store - // after `get` returns. - const bool remove_after_get_; // Whether we should abort the waiting if any object is an exception. const bool abort_if_any_object_is_exception_; // Whether all the requested objects are available. @@ -80,19 +74,15 @@ class GetRequest { GetRequest::GetRequest(absl::flat_hash_set object_ids, size_t num_objects, - bool remove_after_get, bool abort_if_any_object_is_exception) : object_ids_(std::move(object_ids)), num_objects_(num_objects), - remove_after_get_(remove_after_get), abort_if_any_object_is_exception_(abort_if_any_object_is_exception) { RAY_CHECK(num_objects_ <= object_ids_.size()); } const absl::flat_hash_set &GetRequest::ObjectIds() const { return object_ids_; } -bool GetRequest::ShouldRemoveObjects() const { return remove_after_get_; } - bool GetRequest::Wait(int64_t timeout_ms) { RAY_CHECK(timeout_ms >= 0 || timeout_ms == -1); if (timeout_ms == -1) { @@ -137,14 +127,14 @@ std::shared_ptr GetRequest::Get(const ObjectID &object_id) const { CoreWorkerMemoryStore::CoreWorkerMemoryStore( instrumented_io_context &io_context, - ReferenceCounterInterface *counter, + bool reference_counting_enabled, std::shared_ptr raylet_ipc_client, std::function check_signals, std::function unhandled_exception_handler, std::function( const ray::RayObject &object, const ObjectID &object_id)> object_allocator) : io_context_(io_context), - ref_counter_(counter), + reference_counting_enabled_(reference_counting_enabled), raylet_ipc_client_(std::move(raylet_ipc_client)), check_signals_(std::move(check_signals)), unhandled_exception_handler_(std::move(unhandled_exception_handler)), @@ -180,7 +170,9 @@ std::shared_ptr CoreWorkerMemoryStore::GetIfExists(const ObjectID &ob return ptr; } -void CoreWorkerMemoryStore::Put(const RayObject &object, const ObjectID &object_id) { +void CoreWorkerMemoryStore::Put(const RayObject &object, + const ObjectID &object_id, + const bool 
has_reference) { std::vector)>> async_callbacks; RAY_LOG(DEBUG).WithField(object_id) << "Putting object into memory store."; std::shared_ptr object_entry = nullptr; @@ -211,24 +203,18 @@ void CoreWorkerMemoryStore::Put(const RayObject &object, const ObjectID &object_ object_async_get_requests_.erase(async_callback_it); } - bool should_add_entry = true; auto object_request_iter = object_get_requests_.find(object_id); if (object_request_iter != object_get_requests_.end()) { auto &get_requests = object_request_iter->second; for (auto &get_request : get_requests) { get_request->Set(object_id, object_entry); - // If ref counting is enabled, override the removal behaviour. - if (get_request->ShouldRemoveObjects() && ref_counter_ == nullptr) { - should_add_entry = false; - } } } - // Don't put it in the store, since we won't get a callback for deletion. - if (ref_counter_ != nullptr && !ref_counter_->HasReference(object_id)) { - should_add_entry = false; - } - if (should_add_entry) { + // Don't put it in the store, if we can't get a callback for deletion. + // The exception here is if we are in local mode, put should still put the object in + // the store. + if (!reference_counting_enabled_ || has_reference) { // If there is no existing get request, then add the `RayObject` to map. 
EmplaceObjectAndUpdateStats(object_id, object_entry); } else { @@ -261,13 +247,11 @@ Status CoreWorkerMemoryStore::Get(const std::vector &object_ids, int num_objects, int64_t timeout_ms, const WorkerContext &ctx, - bool remove_after_get, std::vector> *results) { return GetImpl(object_ids, num_objects, timeout_ms, ctx, - remove_after_get, results, /*abort_if_any_object_is_exception=*/true, /*at_most_num_objects=*/true); @@ -277,7 +261,6 @@ Status CoreWorkerMemoryStore::GetImpl(const std::vector &object_ids, int num_objects, int64_t timeout_ms, const WorkerContext &ctx, - bool remove_after_get, std::vector> *results, bool abort_if_any_object_is_exception, bool at_most_num_objects) { @@ -288,7 +271,6 @@ Status CoreWorkerMemoryStore::GetImpl(const std::vector &object_ids, { absl::flat_hash_set remaining_ids; - absl::flat_hash_set ids_to_remove; bool existing_objects_has_exception = false; absl::MutexLock lock(&mu_); @@ -299,11 +281,6 @@ Status CoreWorkerMemoryStore::GetImpl(const std::vector &object_ids, if (iter != objects_.end()) { iter->second->SetAccessed(); (*results)[i] = iter->second; - if (remove_after_get) { - // Note that we cannot remove the object_id from `objects_` now, - // because `object_ids` might have duplicate ids. - ids_to_remove.insert(object_id); - } num_found += 1; if (abort_if_any_object_is_exception && iter->second->IsException() && !iter->second->IsInPlasmaError()) { @@ -318,13 +295,6 @@ Status CoreWorkerMemoryStore::GetImpl(const std::vector &object_ids, } } - // Clean up the objects if ref counting is off. - if (ref_counter_ == nullptr) { - for (const auto &object_id : ids_to_remove) { - EraseObjectAndUpdateStats(object_id); - } - } - // Return if all the objects are obtained, or any existing objects are known to have // exception. 
if (remaining_ids.empty() || num_found >= num_objects || @@ -335,10 +305,8 @@ Status CoreWorkerMemoryStore::GetImpl(const std::vector &object_ids, size_t required_objects = num_objects - num_found; // Otherwise, create a GetRequest to track remaining objects. - get_request = std::make_shared(std::move(remaining_ids), - required_objects, - remove_after_get, - abort_if_any_object_is_exception); + get_request = std::make_shared( + std::move(remaining_ids), required_objects, abort_if_any_object_is_exception); for (const auto &object_id : get_request->ObjectIds()) { object_get_requests_[object_id].push_back(get_request); } @@ -426,12 +394,7 @@ Status CoreWorkerMemoryStore::Get( bool *got_exception) { const std::vector id_vector(object_ids.begin(), object_ids.end()); std::vector> result_objects; - RAY_RETURN_NOT_OK(Get(id_vector, - id_vector.size(), - timeout_ms, - ctx, - /*remove_after_get=*/false, - &result_objects)); + RAY_RETURN_NOT_OK(Get(id_vector, id_vector.size(), timeout_ms, ctx, &result_objects)); for (size_t i = 0; i < id_vector.size(); i++) { if (result_objects[i] != nullptr) { @@ -460,7 +423,6 @@ Status CoreWorkerMemoryStore::Wait(const absl::flat_hash_set &object_i num_objects, timeout_ms, ctx, - false, &result_objects, /*abort_if_any_object_is_exception=*/false, /*at_most_num_objects=*/false); diff --git a/src/ray/core_worker/store_provider/memory_store/memory_store.h b/src/ray/core_worker/store_provider/memory_store/memory_store.h index 1852a3dc1f7b..cd28cae8c98e 100644 --- a/src/ray/core_worker/store_provider/memory_store/memory_store.h +++ b/src/ray/core_worker/store_provider/memory_store/memory_store.h @@ -27,7 +27,6 @@ #include "ray/common/metrics.h" #include "ray/common/status.h" #include "ray/core_worker/context.h" -#include "ray/core_worker/reference_counter_interface.h" #include "ray/raylet_ipc_client/raylet_ipc_client_interface.h" #include "ray/rpc/utils.h" @@ -50,12 +49,10 @@ class CoreWorkerMemoryStore { /// Create a memory store. 
/// /// \param[in] io_context Posts async callbacks to this context. - /// \param[in] counter If not null, this enables ref counting for local objects, - /// and the `remove_after_get` flag for Get() will be ignored. /// \param[in] raylet_ipc_client If not null, used to notify tasks blocked / unblocked. explicit CoreWorkerMemoryStore( instrumented_io_context &io_context, - ReferenceCounterInterface *counter = nullptr, + bool reference_counting_enabled = true, std::shared_ptr raylet_ipc_client = nullptr, std::function check_signals = nullptr, std::function unhandled_exception_handler = nullptr, @@ -69,7 +66,9 @@ class CoreWorkerMemoryStore { /// /// \param[in] object The ray object. /// \param[in] object_id Object ID specified by user. - void Put(const RayObject &object, const ObjectID &object_id); + /// \param[in] has_reference Whether the object has a reference in the reference + /// counter. + void Put(const RayObject &object, const ObjectID &object_id, const bool has_reference); /// Get a list of objects from the object store. /// @@ -77,15 +76,12 @@ class CoreWorkerMemoryStore { /// \param[in] num_objects Number of objects that should appear. /// \param[in] timeout_ms Timeout in milliseconds, wait infinitely if it's negative. /// \param[in] ctx The current worker context. - /// \param[in] remove_after_get When to remove the objects from store after `Get` - /// finishes. This has no effect if ref counting is enabled. /// \param[out] results Result list of objects data. /// \return Status. Status Get(const std::vector &object_ids, int num_objects, int64_t timeout_ms, const WorkerContext &ctx, - bool remove_after_get, std::vector> *results); /// Convenience wrapper around Get() that stores results in a given result map. 
@@ -187,7 +183,6 @@ class CoreWorkerMemoryStore { int num_objects, int64_t timeout_ms, const WorkerContext &ctx, - bool remove_after_get, std::vector> *results, bool abort_if_any_object_is_exception, bool at_most_num_objects); @@ -207,9 +202,8 @@ class CoreWorkerMemoryStore { instrumented_io_context &io_context_; - /// If enabled, holds a reference to local worker ref counter. TODO(ekl) make this - /// mandatory once Java is supported. - ReferenceCounterInterface *ref_counter_; + /// Set to true if reference counting is enabled (i.e. not local mode). + bool reference_counting_enabled_; // If set, this will be used to notify worker blocked / unblocked on get calls. std::shared_ptr raylet_ipc_client_; diff --git a/src/ray/core_worker/store_provider/plasma_store_provider.cc b/src/ray/core_worker/store_provider/plasma_store_provider.cc index 9e1d2f61dd43..c14e624df16c 100644 --- a/src/ray/core_worker/store_provider/plasma_store_provider.cc +++ b/src/ray/core_worker/store_provider/plasma_store_provider.cc @@ -65,7 +65,6 @@ BufferTracker::UsedObjects() const { CoreWorkerPlasmaStoreProvider::CoreWorkerPlasmaStoreProvider( const std::string &store_socket, const std::shared_ptr raylet_ipc_client, - ReferenceCounterInterface &reference_counter, std::function check_signals, bool warmup, std::shared_ptr store_client, @@ -73,7 +72,6 @@ CoreWorkerPlasmaStoreProvider::CoreWorkerPlasmaStoreProvider( std::function get_current_call_site) : raylet_ipc_client_(raylet_ipc_client), store_client_(std::move(store_client)), - reference_counter_(reference_counter), check_signals_(std::move(check_signals)), fetch_batch_size_(fetch_batch_size), get_request_counter_(0) { @@ -180,7 +178,7 @@ Status CoreWorkerPlasmaStoreProvider::Release(const ObjectID &object_id) { } Status CoreWorkerPlasmaStoreProvider::GetObjectsFromPlasmaStore( - absl::flat_hash_set &remaining, + absl::flat_hash_map &remaining_object_id_to_idx, const std::vector &ids, int64_t timeout_ms, absl::flat_hash_map> *results, @@ 
-207,7 +205,7 @@ Status CoreWorkerPlasmaStoreProvider::GetObjectsFromPlasmaStore( } auto result_object = std::make_shared( data, metadata, std::vector()); - remaining.erase(object_id); + remaining_object_id_to_idx.erase(object_id); if (result_object->IsException()) { RAY_CHECK(!result_object->IsInPlasmaError()); *got_exception = true; @@ -253,37 +251,36 @@ Status CoreWorkerPlasmaStoreProvider::GetExperimentalMutableObject( } Status CoreWorkerPlasmaStoreProvider::Get( - const absl::flat_hash_set &object_ids, + const std::vector &object_ids, + const std::vector &owner_addresses, int64_t timeout_ms, absl::flat_hash_map> *results) { std::vector get_request_cleanup_handlers; - - bool got_exception = false; - absl::flat_hash_set remaining(object_ids.begin(), object_ids.end()); - std::vector id_vector(object_ids.begin(), object_ids.end()); - std::vector batch_ids; - - int64_t num_total_objects = static_cast(object_ids.size()); + absl::flat_hash_map remaining_object_id_to_idx; // TODO(57923): Need to understand if batching is necessary. If it's necessary, // then the reason needs to be documented. + bool got_exception = false; + int64_t num_total_objects = static_cast(object_ids.size()); for (int64_t start = 0; start < num_total_objects; start += fetch_batch_size_) { - batch_ids.clear(); + std::vector batch_ids; + std::vector batch_owner_addresses; for (int64_t i = start; i < start + fetch_batch_size_ && i < num_total_objects; i++) { - batch_ids.push_back(id_vector[i]); + remaining_object_id_to_idx[object_ids[i]] = i; + + batch_ids.push_back(object_ids[i]); + batch_owner_addresses.push_back(owner_addresses[i]); } // 1. Make the request to pull all objects into local plasma if not local already. 
- std::vector owner_addresses = - reference_counter_.GetOwnerAddresses(batch_ids); StatusOr status_or_cleanup = raylet_ipc_client_->AsyncGetObjects( - batch_ids, owner_addresses, get_request_counter_.fetch_add(1)); + batch_ids, batch_owner_addresses, get_request_counter_.fetch_add(1)); RAY_RETURN_NOT_OK(status_or_cleanup.status()); get_request_cleanup_handlers.emplace_back(std::move(status_or_cleanup.value())); // 2. Try to Get all objects that are already local from the plasma store. RAY_RETURN_NOT_OK( - GetObjectsFromPlasmaStore(remaining, + GetObjectsFromPlasmaStore(remaining_object_id_to_idx, batch_ids, /*timeout_ms=*/0, // Mutable objects must be local before ray.get. @@ -291,7 +288,7 @@ Status CoreWorkerPlasmaStoreProvider::Get( &got_exception)); } - if (remaining.empty() || got_exception) { + if (remaining_object_id_to_idx.empty() || got_exception) { return Status::OK(); } @@ -302,13 +299,15 @@ Status CoreWorkerPlasmaStoreProvider::Get( bool timed_out = false; int64_t remaining_timeout = timeout_ms; auto fetch_start_time_ms = current_time_ms(); - while (!remaining.empty() && !should_break) { - batch_ids.clear(); - for (const auto &id : remaining) { + while (!remaining_object_id_to_idx.empty() && !should_break) { + std::vector batch_ids; + std::vector batch_owner_addresses; + for (const auto &[id, idx] : remaining_object_id_to_idx) { if (static_cast(batch_ids.size()) == fetch_batch_size_) { break; } batch_ids.push_back(id); + batch_owner_addresses.push_back(owner_addresses[idx]); } int64_t batch_timeout = @@ -320,13 +319,13 @@ Status CoreWorkerPlasmaStoreProvider::Get( timed_out = remaining_timeout <= 0; } - size_t previous_size = remaining.size(); + size_t previous_size = remaining_object_id_to_idx.size(); RAY_RETURN_NOT_OK(GetObjectsFromPlasmaStore( - remaining, batch_ids, batch_timeout, results, &got_exception)); + remaining_object_id_to_idx, batch_ids, batch_timeout, results, &got_exception)); should_break = timed_out || got_exception; - if 
((previous_size - remaining.size()) < batch_ids.size()) { - WarnIfFetchHanging(fetch_start_time_ms, remaining); + if ((previous_size - remaining_object_id_to_idx.size()) < batch_ids.size()) { + WarnIfFetchHanging(fetch_start_time_ms, remaining_object_id_to_idx); } if (check_signals_) { Status status = check_signals_(); @@ -335,7 +334,7 @@ Status CoreWorkerPlasmaStoreProvider::Get( } } if (RayConfig::instance().yield_plasma_lock_workaround() && !should_break && - remaining.size() > 0) { + remaining_object_id_to_idx.size() > 0) { // Yield the plasma lock to other threads. This is a temporary workaround since we // are holding the lock for a long time, so it can easily starve inbound RPC // requests to Release() buffers which only require holding the lock for brief @@ -344,13 +343,13 @@ Status CoreWorkerPlasmaStoreProvider::Get( } } - if (!remaining.empty() && timed_out) { + if (!remaining_object_id_to_idx.empty() && timed_out) { return Status::TimedOut(absl::StrFormat( "Could not fetch %d objects within the timeout of %dms. 
%d objects were not " "ready.", object_ids.size(), timeout_ms, - remaining.size())); + remaining_object_id_to_idx.size())); } return Status::OK(); } @@ -361,13 +360,12 @@ Status CoreWorkerPlasmaStoreProvider::Contains(const ObjectID &object_id, } Status CoreWorkerPlasmaStoreProvider::Wait( - const absl::flat_hash_set &object_ids, + const std::vector &object_ids, + const std::vector &owner_addresses, int num_objects, int64_t timeout_ms, const WorkerContext &ctx, absl::flat_hash_set *ready) { - std::vector id_vector(object_ids.begin(), object_ids.end()); - bool should_break = false; int64_t remaining_timeout = timeout_ms; absl::flat_hash_set ready_in_plasma; @@ -379,10 +377,9 @@ Status CoreWorkerPlasmaStoreProvider::Wait( should_break = remaining_timeout <= 0; } - const auto owner_addresses = reference_counter_.GetOwnerAddresses(id_vector); RAY_ASSIGN_OR_RETURN( ready_in_plasma, - raylet_ipc_client_->Wait(id_vector, owner_addresses, num_objects, call_timeout)); + raylet_ipc_client_->Wait(object_ids, owner_addresses, num_objects, call_timeout)); if (ready_in_plasma.size() >= static_cast(num_objects)) { should_break = true; @@ -416,12 +413,13 @@ CoreWorkerPlasmaStoreProvider::UsedObjectsList() const { } void CoreWorkerPlasmaStoreProvider::WarnIfFetchHanging( - int64_t fetch_start_time_ms, const absl::flat_hash_set &remaining) { + int64_t fetch_start_time_ms, + const absl::flat_hash_map &remaining_object_id_to_idx) { int64_t duration_ms = current_time_ms() - fetch_start_time_ms; if (duration_ms > RayConfig::instance().fetch_warn_timeout_milliseconds()) { std::ostringstream oss; size_t printed = 0; - for (auto &id : remaining) { + for (auto &[id, _] : remaining_object_id_to_idx) { if (printed >= RayConfig::instance().object_store_get_max_ids_to_print_in_warning()) { break; @@ -432,7 +430,7 @@ void CoreWorkerPlasmaStoreProvider::WarnIfFetchHanging( oss << id.Hex(); printed++; } - if (printed < remaining.size()) { + if (printed < remaining_object_id_to_idx.size()) { oss << 
", etc"; } RAY_LOG(WARNING) diff --git a/src/ray/core_worker/store_provider/plasma_store_provider.h b/src/ray/core_worker/store_provider/plasma_store_provider.h index 42308b905aaa..f98e83ba0a3b 100644 --- a/src/ray/core_worker/store_provider/plasma_store_provider.h +++ b/src/ray/core_worker/store_provider/plasma_store_provider.h @@ -26,7 +26,6 @@ #include "ray/common/status.h" #include "ray/common/status_or.h" #include "ray/core_worker/context.h" -#include "ray/core_worker/reference_counter_interface.h" #include "ray/object_manager/plasma/client.h" #include "ray/raylet_ipc_client/raylet_ipc_client_interface.h" #include "src/ray/protobuf/common.pb.h" @@ -96,7 +95,6 @@ class CoreWorkerPlasmaStoreProvider { CoreWorkerPlasmaStoreProvider( const std::string &store_socket, const std::shared_ptr raylet_ipc_client, - ReferenceCounterInterface &reference_counter, std::function check_signals, bool warmup, std::shared_ptr store_client, @@ -159,6 +157,7 @@ class CoreWorkerPlasmaStoreProvider { /// into the local plasma store from another node. /// /// \param[in] object_ids objects to fetch if they are not already in local plasma. + /// \param[in] owner_addresses owner addresses of the objects. /// \param[in] timeout_ms if the timeout elapses, the request will be canceled. /// \param[out] results objects fetched from plasma. This is only valid if the function /// @@ -169,7 +168,8 @@ class CoreWorkerPlasmaStoreProvider { /// \return Status::IntentionalSystemExit if a SIGTERM signal was was received. /// \return Status::UnexpectedSystemExit if any other signal was received. /// \return Status::OK otherwise. 
- Status Get(const absl::flat_hash_set &object_ids, + Status Get(const std::vector &object_ids, + const std::vector &owner_addresses, int64_t timeout_ms, absl::flat_hash_map> *results); @@ -200,7 +200,8 @@ class CoreWorkerPlasmaStoreProvider { Status Contains(const ObjectID &object_id, bool *has_object); - Status Wait(const absl::flat_hash_set &object_ids, + Status Wait(const std::vector &object_ids, + const std::vector &owner_addresses, int num_objects, int64_t timeout_ms, const WorkerContext &ctx, @@ -222,19 +223,20 @@ class CoreWorkerPlasmaStoreProvider { /// Successfully fetched objects will be removed from the input set of remaining IDs and /// added to the results map. /// - /// \param[in/out] remaining IDs of the remaining objects to get. - /// \param[in] ids IDs of the objects to get. - /// \param[in] timeout_ms Timeout in milliseconds. - /// \param[out] results Map of objects to write results into. This method will only - /// add to this map, not clear or remove from it, so the caller can pass in a non-empty - /// map. - /// \param[out] got_exception Set to true if any of the fetched objects contained an - /// exception. - /// \return Status::IOError if there is an error in communicating with the raylet or the - /// plasma store. - /// \return Status::OK if successful. + /// \param[in/out] remaining_object_id_to_idx Map of the remaining object IDs to their + /// indices. + /// \param[in] ids IDs of the objects to get. + /// \param[in] timeout_ms Timeout in milliseconds. + /// \param[out] results Map of objects to write results into. This method will only + /// add to this map, not clear or remove from it, so the caller can pass in a non-empty + /// map. + /// \param[out] got_exception Set to true if any of the fetched objects contained an + /// exception. + /// \return Status::IOError if there is an error in communicating with the raylet or the + /// plasma store. + /// \return Status::OK if successful. 
Status GetObjectsFromPlasmaStore( - absl::flat_hash_set &remaining, + absl::flat_hash_map &remaining_object_id_to_idx, const std::vector &ids, int64_t timeout_ms, absl::flat_hash_map> *results, @@ -242,8 +244,9 @@ class CoreWorkerPlasmaStoreProvider { /// Print a warning if we've attempted the fetch for too long and some /// objects are still unavailable. - static void WarnIfFetchHanging(int64_t fetch_start_time_ms, - const absl::flat_hash_set &remaining); + static void WarnIfFetchHanging( + int64_t fetch_start_time_ms, + const absl::flat_hash_map &remaining_object_id_to_idx); /// Put something in the plasma store so that subsequent plasma store accesses /// will be faster. Currently the first access is always slow, and we don't @@ -253,8 +256,6 @@ class CoreWorkerPlasmaStoreProvider { const std::shared_ptr raylet_ipc_client_; std::shared_ptr store_client_; - /// Used to look up a plasma object's owner. - ReferenceCounterInterface &reference_counter_; std::function check_signals_; std::function get_current_call_site_; uint32_t object_store_full_delay_ms_; diff --git a/src/ray/core_worker/task_manager.cc b/src/ray/core_worker/task_manager.cc index 0b26a89d7d24..b55328668315 100644 --- a/src/ray/core_worker/task_manager.cc +++ b/src/ray/core_worker/task_manager.cc @@ -561,7 +561,9 @@ StatusOr TaskManager::HandleTaskReturn(const ObjectID &object_id, // will choose the right raylet for any queued dependent tasks. reference_counter_.UpdateObjectPinnedAtRaylet(object_id, worker_node_id); // Mark it as in plasma with a dummy object. 
- in_memory_store_.Put(RayObject(rpc::ErrorType::OBJECT_IN_PLASMA), object_id); + in_memory_store_.Put(RayObject(rpc::ErrorType::OBJECT_IN_PLASMA), + object_id, + reference_counter_.HasReference(object_id)); } else { // NOTE(swang): If a direct object was promoted to plasma, then we do not // record the node ID that it was pinned at, which means that we will not @@ -595,7 +597,7 @@ StatusOr TaskManager::HandleTaskReturn(const ObjectID &object_id, return s; } } else { - in_memory_store_.Put(object, object_id); + in_memory_store_.Put(object, object_id, reference_counter_.HasReference(object_id)); direct_return = true; } } @@ -764,7 +766,8 @@ void TaskManager::MarkEndOfStream(const ObjectID &generator_id, // Put a dummy object at the end of the stream. We don't need to check if // the object should be stored in plasma because the end of the stream is a // fake ObjectRef that should never be read by the application. - in_memory_store_.Put(error, last_object_id); + in_memory_store_.Put( + error, last_object_id, reference_counter_.HasReference(last_object_id)); } } @@ -1552,10 +1555,11 @@ void TaskManager::MarkTaskReturnObjectsFailed( if (!s.ok()) { RAY_LOG(WARNING).WithField(object_id) << "Failed to put error object in plasma: " << s; - in_memory_store_.Put(error, object_id); + in_memory_store_.Put( + error, object_id, reference_counter_.HasReference(object_id)); } } else { - in_memory_store_.Put(error, object_id); + in_memory_store_.Put(error, object_id, reference_counter_.HasReference(object_id)); } } if (spec.ReturnsDynamic()) { @@ -1565,10 +1569,13 @@ void TaskManager::MarkTaskReturnObjectsFailed( if (!s.ok()) { RAY_LOG(WARNING).WithField(dynamic_return_id) << "Failed to put error object in plasma: " << s; - in_memory_store_.Put(error, dynamic_return_id); + in_memory_store_.Put(error, + dynamic_return_id, + reference_counter_.HasReference(dynamic_return_id)); } } else { - in_memory_store_.Put(error, dynamic_return_id); + in_memory_store_.Put( + error, 
dynamic_return_id, reference_counter_.HasReference(dynamic_return_id)); } } } @@ -1595,10 +1602,14 @@ void TaskManager::MarkTaskReturnObjectsFailed( if (!s.ok()) { RAY_LOG(WARNING).WithField(generator_return_id) << "Failed to put error object in plasma: " << s; - in_memory_store_.Put(error, generator_return_id); + in_memory_store_.Put(error, + generator_return_id, + reference_counter_.HasReference(generator_return_id)); } } else { - in_memory_store_.Put(error, generator_return_id); + in_memory_store_.Put(error, + generator_return_id, + reference_counter_.HasReference(generator_return_id)); } } } diff --git a/src/ray/core_worker/task_manager.h b/src/ray/core_worker/task_manager.h index bb450413817c..0699a96c198b 100644 --- a/src/ray/core_worker/task_manager.h +++ b/src/ray/core_worker/task_manager.h @@ -27,6 +27,7 @@ #include "absl/synchronization/mutex.h" #include "ray/common/id.h" #include "ray/common/status.h" +#include "ray/core_worker/reference_counter_interface.h" #include "ray/core_worker/store_provider/memory_store/memory_store.h" #include "ray/core_worker/task_event_buffer.h" #include "ray/core_worker/task_manager_interface.h" diff --git a/src/ray/core_worker/task_submission/BUILD.bazel b/src/ray/core_worker/task_submission/BUILD.bazel index 387fba21552c..e8b9cd096993 100644 --- a/src/ray/core_worker/task_submission/BUILD.bazel +++ b/src/ray/core_worker/task_submission/BUILD.bazel @@ -70,6 +70,7 @@ ray_cc_library( "//src/ray/common:id", "//src/ray/common:protobuf_utils", "//src/ray/core_worker:actor_creator", + "//src/ray/core_worker:reference_counter_interface", "//src/ray/core_worker_rpc_client:core_worker_client_pool", "//src/ray/rpc:rpc_callback_types", "//src/ray/util:time", diff --git a/src/ray/core_worker/task_submission/actor_task_submitter.h b/src/ray/core_worker/task_submission/actor_task_submitter.h index f225397768be..ada84d9a0d25 100644 --- a/src/ray/core_worker/task_submission/actor_task_submitter.h +++ 
b/src/ray/core_worker/task_submission/actor_task_submitter.h @@ -26,6 +26,7 @@ #include "absl/synchronization/mutex.h" #include "ray/common/id.h" #include "ray/core_worker/actor_creator.h" +#include "ray/core_worker/reference_counter_interface.h" #include "ray/core_worker/store_provider/memory_store/memory_store.h" #include "ray/core_worker/task_submission/actor_submit_queue.h" #include "ray/core_worker/task_submission/dependency_resolver.h" diff --git a/src/ray/core_worker/task_submission/tests/actor_task_submitter_test.cc b/src/ray/core_worker/task_submission/tests/actor_task_submitter_test.cc index e1536ef89785..54ccf6a30962 100644 --- a/src/ray/core_worker/task_submission/tests/actor_task_submitter_test.cc +++ b/src/ray/core_worker/task_submission/tests/actor_task_submitter_test.cc @@ -232,11 +232,11 @@ TEST_P(ActorTaskSubmitterTest, TestDependencies) { auto data = GenerateRandomObject(); // Each Put schedules a callback onto io_context, and let's run it. - store_->Put(*data, obj1); + store_->Put(*data, obj1, reference_counter_->HasReference(obj1)); ASSERT_EQ(io_context.poll_one(), 1); ASSERT_EQ(worker_client_->callbacks.size(), 1); - store_->Put(*data, obj2); + store_->Put(*data, obj2, reference_counter_->HasReference(obj2)); ASSERT_EQ(io_context.poll_one(), 1); ASSERT_EQ(worker_client_->callbacks.size(), 2); @@ -280,12 +280,12 @@ TEST_P(ActorTaskSubmitterTest, TestOutOfOrderDependencies) { // submission. auto data = GenerateRandomObject(); // task2 is submitted first as we allow out of order execution. 
- store_->Put(*data, obj2); + store_->Put(*data, obj2, reference_counter_->HasReference(obj2)); ASSERT_EQ(io_context.poll_one(), 1); ASSERT_EQ(worker_client_->callbacks.size(), 1); ASSERT_THAT(worker_client_->received_seq_nos, ElementsAre(1)); // then task1 is submitted - store_->Put(*data, obj1); + store_->Put(*data, obj1, reference_counter_->HasReference(obj1)); ASSERT_EQ(io_context.poll_one(), 1); ASSERT_EQ(worker_client_->callbacks.size(), 2); ASSERT_THAT(worker_client_->received_seq_nos, ElementsAre(1, 0)); @@ -293,10 +293,10 @@ TEST_P(ActorTaskSubmitterTest, TestOutOfOrderDependencies) { // Put the dependencies in the store in the opposite order of task // submission. auto data = GenerateRandomObject(); - store_->Put(*data, obj2); + store_->Put(*data, obj2, reference_counter_->HasReference(obj2)); ASSERT_EQ(io_context.poll_one(), 1); ASSERT_EQ(worker_client_->callbacks.size(), 0); - store_->Put(*data, obj1); + store_->Put(*data, obj1, reference_counter_->HasReference(obj1)); ASSERT_EQ(io_context.poll_one(), 1); ASSERT_EQ(worker_client_->callbacks.size(), 2); ASSERT_THAT(worker_client_->received_seq_nos, ElementsAre(0, 1)); diff --git a/src/ray/core_worker/task_submission/tests/dependency_resolver_test.cc b/src/ray/core_worker/task_submission/tests/dependency_resolver_test.cc index e9766aec1281..28c87febf837 100644 --- a/src/ray/core_worker/task_submission/tests/dependency_resolver_test.cc +++ b/src/ray/core_worker/task_submission/tests/dependency_resolver_test.cc @@ -187,7 +187,7 @@ TEST(LocalDependencyResolverTest, TestActorAndObjectDependencies1) { auto metadata = const_cast(reinterpret_cast(meta.data())); auto meta_buffer = std::make_shared(metadata, meta.size()); auto data = RayObject(nullptr, meta_buffer, std::vector()); - store->Put(data, obj); + store->Put(data, obj, /*has_reference=*/true); // Wait for the async callback to call ASSERT_TRUE(dependencies_resolved.get_future().get()); ASSERT_EQ(num_resolved, 1); @@ -228,7 +228,7 @@ 
TEST(LocalDependencyResolverTest, TestActorAndObjectDependencies2) { auto meta_buffer = std::make_shared(metadata, meta.size()); auto data = RayObject(nullptr, meta_buffer, std::vector()); ASSERT_EQ(num_resolved, 0); - store->Put(data, obj); + store->Put(data, obj, /*has_reference=*/true); for (const auto &cb : actor_creator.callbacks) { cb(Status()); @@ -253,7 +253,7 @@ TEST(LocalDependencyResolverTest, TestHandlePlasmaPromotion) { auto metadata = const_cast(reinterpret_cast(meta.data())); auto meta_buffer = std::make_shared(metadata, meta.size()); auto data = RayObject(nullptr, meta_buffer, std::vector()); - store->Put(data, obj1); + store->Put(data, obj1, /*has_reference=*/true); TaskSpecification task; task.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id(obj1.Binary()); bool ok = false; @@ -282,8 +282,8 @@ TEST(LocalDependencyResolverTest, TestInlineLocalDependencies) { ObjectID obj2 = ObjectID::FromRandom(); auto data = GenerateRandomObject(); // Ensure the data is already present in the local store. 
- store->Put(*data, obj1); - store->Put(*data, obj2); + store->Put(*data, obj1, /*has_reference=*/true); + store->Put(*data, obj2, /*has_reference=*/true); TaskSpecification task; task.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id(obj1.Binary()); task.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id(obj2.Binary()); @@ -326,8 +326,8 @@ TEST(LocalDependencyResolverTest, TestInlinePendingDependencies) { }); ASSERT_EQ(resolver.NumPendingTasks(), 1); ASSERT_TRUE(!ok); - store->Put(*data, obj1); - store->Put(*data, obj2); + store->Put(*data, obj1, /*has_reference=*/true); + store->Put(*data, obj2, /*has_reference=*/true); ASSERT_TRUE(dependencies_resolved.get_future().get()); // Tests that the task proto was rewritten to have inline argument values after @@ -365,8 +365,8 @@ TEST(LocalDependencyResolverTest, TestInlinedObjectIds) { }); ASSERT_EQ(resolver.NumPendingTasks(), 1); ASSERT_TRUE(!ok); - store->Put(*data, obj1); - store->Put(*data, obj2); + store->Put(*data, obj1, /*has_reference=*/true); + store->Put(*data, obj2, /*has_reference=*/true); ASSERT_TRUE(dependencies_resolved.get_future().get()); // Tests that the task proto was rewritten to have inline argument values after @@ -400,7 +400,7 @@ TEST(LocalDependencyResolverTest, TestCancelDependencyResolution) { resolver.ResolveDependencies(task, [&ok](Status) { ok = true; }); ASSERT_EQ(resolver.NumPendingTasks(), 1); ASSERT_TRUE(!ok); - store->Put(*data, obj1); + store->Put(*data, obj1, /*has_reference=*/true); ASSERT_TRUE(resolver.CancelDependencyResolution(task.TaskId())); // Callback is not called. 
@@ -428,7 +428,7 @@ TEST(LocalDependencyResolverTest, TestDependenciesAlreadyLocal) { ObjectID obj = ObjectID::FromRandom(); auto data = GenerateRandomObject(); - store->Put(*data, obj); + store->Put(*data, obj, /*has_reference=*/true); TaskSpecification task; task.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id(obj.Binary()); @@ -471,8 +471,8 @@ TEST(LocalDependencyResolverTest, TestMixedTensorTransport) { }); auto data = GenerateRandomObject(); - store->Put(*data, obj1); - store->Put(*data, obj2); + store->Put(*data, obj1, /*has_reference=*/true); + store->Put(*data, obj2, /*has_reference=*/true); TaskSpecification task; task.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id(obj1.Binary()); diff --git a/src/ray/core_worker/task_submission/tests/normal_task_submitter_test.cc b/src/ray/core_worker/task_submission/tests/normal_task_submitter_test.cc index b59ada5fa33f..b5325598f788 100644 --- a/src/ray/core_worker/task_submission/tests/normal_task_submitter_test.cc +++ b/src/ray/core_worker/task_submission/tests/normal_task_submitter_test.cc @@ -1517,6 +1517,7 @@ void TestSchedulingKey(const std::shared_ptr store, TEST(NormalTaskSubmitterSchedulingKeyTest, TestSchedulingKeys) { InstrumentedIOContextWithThread io_context("TestSchedulingKeys"); + // Mock reference counter as enabled auto memory_store = std::make_shared(io_context.GetIoService()); std::unordered_map resources1({{"a", 1.0}}); @@ -1560,16 +1561,16 @@ TEST(NormalTaskSubmitterSchedulingKeyTest, TestSchedulingKeys) { ObjectID plasma2 = ObjectID::FromRandom(); // Ensure the data is already present in the local store for direct call objects. auto data = GenerateRandomObject(); - memory_store->Put(*data, direct1); - memory_store->Put(*data, direct2); + memory_store->Put(*data, direct1, /*has_reference=*/true); + memory_store->Put(*data, direct2, /*has_reference=*/true); // Force plasma objects to be promoted. 
std::string meta = std::to_string(static_cast(rpc::ErrorType::OBJECT_IN_PLASMA)); auto metadata = const_cast(reinterpret_cast(meta.data())); auto meta_buffer = std::make_shared(metadata, meta.size()); auto plasma_data = RayObject(nullptr, meta_buffer, std::vector()); - memory_store->Put(plasma_data, plasma1); - memory_store->Put(plasma_data, plasma2); + memory_store->Put(plasma_data, plasma1, /*has_reference=*/true); + memory_store->Put(plasma_data, plasma2, /*has_reference=*/true); TaskSpecification same_deps_1 = BuildTaskSpec(resources1, descriptor1); same_deps_1.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id( @@ -1600,6 +1601,7 @@ TEST(NormalTaskSubmitterSchedulingKeyTest, TestSchedulingKeys) { TEST_F(NormalTaskSubmitterTest, TestBacklogReport) { InstrumentedIOContextWithThread store_io_context("TestBacklogReport"); + // Mock reference counter as enabled auto memory_store = std::make_shared(store_io_context.GetIoService()); auto submitter = @@ -1623,8 +1625,8 @@ TEST_F(NormalTaskSubmitterTest, TestBacklogReport) { auto metadata = const_cast(reinterpret_cast(meta.data())); auto meta_buffer = std::make_shared(metadata, meta.size()); auto plasma_data = RayObject(nullptr, meta_buffer, std::vector()); - memory_store->Put(plasma_data, plasma1); - memory_store->Put(plasma_data, plasma2); + memory_store->Put(plasma_data, plasma1, /*has_reference=*/true); + memory_store->Put(plasma_data, plasma2, /*has_reference=*/true); // Same SchedulingClass, different SchedulingKey TaskSpecification task2 = BuildTaskSpec(resources1, descriptor1); @@ -1791,7 +1793,7 @@ TEST_F(NormalTaskSubmitterTest, TestKillResolvingTask) { ASSERT_EQ(task_manager->num_inlined_dependencies, 0); submitter.CancelTask(task, true, false); auto data = GenerateRandomObject(); - store->Put(*data, obj1); + store->Put(*data, obj1, /*has_reference=*/true); WaitForObjectIdInMemoryStore(*store, obj1); ASSERT_EQ(worker_client->kill_requests.size(), 0); 
ASSERT_EQ(worker_client->callbacks.size(), 0); diff --git a/src/ray/core_worker/tests/core_worker_test.cc b/src/ray/core_worker/tests/core_worker_test.cc index 8219dc6fe639..e1cabcbd2c8b 100644 --- a/src/ray/core_worker/tests/core_worker_test.cc +++ b/src/ray/core_worker/tests/core_worker_test.cc @@ -155,8 +155,9 @@ class CoreWorkerTest : public ::testing::Test { fake_owned_object_size_gauge_, false); + // Mock reference counter as enabled memory_store_ = std::make_shared( - io_service_, reference_counter_.get(), nullptr); + io_service_, reference_counter_ != nullptr, nullptr); auto future_resolver = std::make_unique( memory_store_, @@ -349,7 +350,7 @@ TEST_F(CoreWorkerTest, HandleGetObjectStatusIdempotency) { owner_address.set_worker_id(core_worker_->GetWorkerID().Binary()); reference_counter_->AddOwnedObject(object_id, {}, owner_address, "", 0, false, true); - memory_store_->Put(*ray_object, object_id); + memory_store_->Put(*ray_object, object_id, reference_counter_->HasReference(object_id)); rpc::GetObjectStatusRequest request; request.set_object_id(object_id.Binary()); @@ -417,7 +418,7 @@ TEST_F(CoreWorkerTest, HandleGetObjectStatusObjectPutAfterFirstRequest) { // Verify that the callback hasn't been called yet since the object doesn't exist ASSERT_FALSE(io_service_.poll_one()); - memory_store_->Put(*ray_object, object_id); + memory_store_->Put(*ray_object, object_id, reference_counter_->HasReference(object_id)); io_service_.run_one(); @@ -454,7 +455,7 @@ TEST_F(CoreWorkerTest, HandleGetObjectStatusObjectFreedBetweenRequests) { owner_address.set_worker_id(core_worker_->GetWorkerID().Binary()); reference_counter_->AddOwnedObject(object_id, {}, owner_address, "", 0, false, true); - memory_store_->Put(*ray_object, object_id); + memory_store_->Put(*ray_object, object_id, reference_counter_->HasReference(object_id)); rpc::GetObjectStatusRequest request; request.set_object_id(object_id.Binary()); @@ -504,7 +505,7 @@ TEST_F(CoreWorkerTest, 
HandleGetObjectStatusObjectOutOfScope) { owner_address.set_worker_id(core_worker_->GetWorkerID().Binary()); reference_counter_->AddOwnedObject(object_id, {}, owner_address, "", 0, false, true); - memory_store_->Put(*ray_object, object_id); + memory_store_->Put(*ray_object, object_id, reference_counter_->HasReference(object_id)); rpc::GetObjectStatusRequest request; request.set_object_id(object_id.Binary()); @@ -569,7 +570,9 @@ ObjectID CreateInlineObjectInMemoryStoreAndRefCounter( /*object_size=*/100, /*is_reconstructable=*/false, /*add_local_ref=*/true); - memory_store.Put(memory_store_object, inlined_dependency_id); + memory_store.Put(memory_store_object, + inlined_dependency_id, + reference_counter.HasReference(inlined_dependency_id)); return inlined_dependency_id; } } // namespace @@ -654,7 +657,6 @@ TEST(BatchingPassesTwoTwoOneIntoPlasmaGet, CallsPlasmaGetInCorrectBatches) { CoreWorkerPlasmaStoreProvider provider( /*store_socket=*/"", fake_raylet, - ref_counter, /*check_signals=*/[] { return Status::OK(); }, /*warmup=*/false, /*store_client=*/fake_plasma, @@ -664,11 +666,11 @@ TEST(BatchingPassesTwoTwoOneIntoPlasmaGet, CallsPlasmaGetInCorrectBatches) { // Build a set of 5 object ids. std::vector ids; for (int i = 0; i < 5; i++) ids.push_back(ObjectID::FromRandom()); - absl::flat_hash_set idset(ids.begin(), ids.end()); + const auto owner_addresses = ref_counter.GetOwnerAddresses(ids); absl::flat_hash_map> results; - ASSERT_TRUE(provider.Get(idset, /*timeout_ms=*/-1, &results).ok()); + ASSERT_TRUE(provider.Get(ids, owner_addresses, /*timeout_ms=*/-1, &results).ok()); // Assert: batches seen by plasma Get are [2,2,1]. 
ASSERT_EQ(observed_batches.size(), 3U); diff --git a/src/ray/core_worker/tests/memory_store_test.cc b/src/ray/core_worker/tests/memory_store_test.cc index 5a90b26af481..30f9ac914409 100644 --- a/src/ray/core_worker/tests/memory_store_test.cc +++ b/src/ray/core_worker/tests/memory_store_test.cc @@ -53,7 +53,7 @@ TEST(TestMemoryStore, TestReportUnhandledErrors) { std::shared_ptr memory_store = std::make_shared( io_context.GetIoService(), - nullptr, + /*reference_counting_enabled=*/true, nullptr, nullptr, [&](const RayObject &obj) { unhandled_count++; }); @@ -64,8 +64,9 @@ TEST(TestMemoryStore, TestReportUnhandledErrors) { // Check basic put and get. ASSERT_TRUE(memory_store->GetIfExists(id1) == nullptr); - memory_store->Put(obj1, id1); - memory_store->Put(obj2, id2); + // Set has_reference to true to ensure the put doesn't get deleted due to no reference. + memory_store->Put(obj1, id1, /*has_reference=*/true); + memory_store->Put(obj2, id2, /*has_reference=*/true); ASSERT_TRUE(memory_store->GetIfExists(id1) != nullptr); ASSERT_EQ(unhandled_count, 0); @@ -75,17 +76,17 @@ TEST(TestMemoryStore, TestReportUnhandledErrors) { unhandled_count = 0; // Check delete after get. - memory_store->Put(obj1, id1); - memory_store->Put(obj1, id2); - RAY_UNUSED(memory_store->Get({id1}, 1, 100, context, false, &results)); - RAY_UNUSED(memory_store->Get({id2}, 1, 100, context, false, &results)); + memory_store->Put(obj1, id1, /*has_reference=*/true); + memory_store->Put(obj1, id2, /*has_reference=*/true); + RAY_UNUSED(memory_store->Get({id1}, 1, 100, context, &results)); + RAY_UNUSED(memory_store->Get({id2}, 1, 100, context, &results)); memory_store->Delete({id1, id2}); ASSERT_EQ(unhandled_count, 0); // Check delete after async get. 
memory_store->GetAsync({id2}, [](std::shared_ptr obj) {}); - memory_store->Put(obj1, id1); - memory_store->Put(obj2, id2); + memory_store->Put(obj1, id1, /*has_reference=*/true); + memory_store->Put(obj2, id2, /*has_reference=*/true); memory_store->GetAsync({id1}, [](std::shared_ptr obj) {}); memory_store->Delete({id1, id2}); ASSERT_EQ(unhandled_count, 0); @@ -118,9 +119,9 @@ TEST(TestMemoryStore, TestMemoryStoreStats) { auto id2 = ObjectID::FromRandom(); auto id3 = ObjectID::FromRandom(); - memory_store->Put(obj1, id1); - memory_store->Put(obj2, id2); - memory_store->Put(obj3, id3); + memory_store->Put(obj1, id1, /*has_reference=*/true); + memory_store->Put(obj2, id2, /*has_reference=*/true); + memory_store->Put(obj3, id3, /*has_reference=*/true); memory_store->Delete({id3}); MemoryStoreStats expected_item; @@ -140,9 +141,9 @@ TEST(TestMemoryStore, TestMemoryStoreStats) { ASSERT_EQ(item.num_local_objects, expected_item2.num_local_objects); ASSERT_EQ(item.num_local_objects_bytes, expected_item2.num_local_objects_bytes); - memory_store->Put(obj1, id1); - memory_store->Put(obj2, id2); - memory_store->Put(obj3, id3); + memory_store->Put(obj1, id1, /*has_reference=*/true); + memory_store->Put(obj2, id2, /*has_reference=*/true); + memory_store->Put(obj3, id3, /*has_reference=*/true); MemoryStoreStats expected_item3; fill_expected_memory_stats(expected_item3); item = memory_store->GetMemoryStoreStatisticalData(); @@ -208,7 +209,7 @@ TEST(TestMemoryStore, TestObjectAllocator) { std::shared_ptr memory_store = std::make_shared(io_context.GetIoService(), - nullptr, + /*reference_counting_enabled=*/true, nullptr, nullptr, nullptr, @@ -220,7 +221,7 @@ TEST(TestMemoryStore, TestObjectAllocator) { std::vector nested_refs; auto hello_object = std::make_shared(hello_buffer, nullptr, nested_refs, true); - memory_store->Put(*hello_object, ObjectID::FromRandom()); + memory_store->Put(*hello_object, ObjectID::FromRandom(), /*has_reference=*/true); } ASSERT_EQ(max_rounds * 
hello.size(), mock_buffer_manager.GetBuferPressureInBytes()); } @@ -259,10 +260,10 @@ TEST_F(TestMemoryStoreWait, TestWaitNoWaiting) { absl::flat_hash_set object_ids_set = {object_ids.begin(), object_ids.end()}; int num_objects = 2; - memory_store->Put(memory_store_object, object_ids[0]); - memory_store->Put(plasma_store_object, object_ids[1]); - memory_store->Put(plasma_store_object, object_ids[2]); - memory_store->Put(memory_store_object, object_ids[3]); + memory_store->Put(memory_store_object, object_ids[0], /*has_reference=*/true); + memory_store->Put(plasma_store_object, object_ids[1], /*has_reference=*/true); + memory_store->Put(plasma_store_object, object_ids[2], /*has_reference=*/true); + memory_store->Put(memory_store_object, object_ids[3], /*has_reference=*/true); absl::flat_hash_set ready, plasma_object_ids; const auto status = memory_store->Wait( @@ -289,8 +290,8 @@ TEST_F(TestMemoryStoreWait, TestWaitWithWaiting) { absl::flat_hash_set object_ids_set = {object_ids.begin(), object_ids.end()}; int num_objects = 4; - memory_store->Put(memory_store_object, object_ids[0]); - memory_store->Put(plasma_store_object, object_ids[1]); + memory_store->Put(memory_store_object, object_ids[0], /*has_reference=*/true); + memory_store->Put(plasma_store_object, object_ids[1], /*has_reference=*/true); absl::flat_hash_set ready, plasma_object_ids; auto future = std::async(std::launch::async, [&]() { @@ -298,9 +299,9 @@ TEST_F(TestMemoryStoreWait, TestWaitWithWaiting) { object_ids_set, num_objects, 100, ctx, &ready, &plasma_object_ids); }); ASSERT_EQ(future.wait_for(std::chrono::milliseconds(1)), std::future_status::timeout); - memory_store->Put(plasma_store_object, object_ids[2]); + memory_store->Put(plasma_store_object, object_ids[2], /*has_reference=*/true); ASSERT_EQ(future.wait_for(std::chrono::milliseconds(1)), std::future_status::timeout); - memory_store->Put(memory_store_object, object_ids[3]); + memory_store->Put(memory_store_object, object_ids[3], 
/*has_reference=*/true); const auto status = future.get(); @@ -316,7 +317,7 @@ TEST_F(TestMemoryStoreWait, TestWaitTimeout) { // object 0 in plasma // waits until 10ms timeout for 2 objects absl::flat_hash_set object_ids_set = {ObjectID::FromRandom()}; - memory_store->Put(plasma_store_object, *object_ids_set.begin()); + memory_store->Put(plasma_store_object, *object_ids_set.begin(), /*has_reference=*/true); int num_objects = 2; absl::flat_hash_set ready, plasma_object_ids; diff --git a/src/ray/core_worker/tests/object_recovery_manager_test.cc b/src/ray/core_worker/tests/object_recovery_manager_test.cc index 8aef24f5b9a4..493de35793f2 100644 --- a/src/ray/core_worker/tests/object_recovery_manager_test.cc +++ b/src/ray/core_worker/tests/object_recovery_manager_test.cc @@ -162,7 +162,7 @@ class ObjectRecoveryManagerTestBase : public ::testing::Test { std::make_shared(metadata, meta.size()); auto data = RayObject(nullptr, meta_buffer, std::vector()); - memory_store_->Put(data, object_id); + memory_store_->Put(data, object_id, ref_counter_->HasReference(object_id)); }) { ref_counter_->SetReleaseLineageCallback( [](const ObjectID &, std::vector *args) { return 0; }); diff --git a/src/ray/core_worker/tests/reference_counter_test.cc b/src/ray/core_worker/tests/reference_counter_test.cc index 7b56bb39a7ba..efaac61e9cf7 100644 --- a/src/ray/core_worker/tests/reference_counter_test.cc +++ b/src/ray/core_worker/tests/reference_counter_test.cc @@ -848,15 +848,15 @@ TEST(MemoryStoreIntegrationTest, TestSimple) { *owned_object_count_metric, *owned_object_size_metric); InstrumentedIOContextWithThread io_context("TestSimple"); - CoreWorkerMemoryStore store(io_context.GetIoService(), rc.get()); + CoreWorkerMemoryStore store(io_context.GetIoService()); // Tests putting an object with no references is ignored. - store.Put(buffer, id2); + store.Put(buffer, id2, rc->HasReference(id2)); ASSERT_EQ(store.Size(), 0); - // Tests ref counting overrides remove after get option. 
+ // Tests that objects with references remain in the store after Get. rc->AddLocalReference(id1, ""); - store.Put(buffer, id1); + store.Put(buffer, id1, rc->HasReference(id1)); ASSERT_EQ(store.Size(), 1); std::vector> results; WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::Nil()); @@ -864,7 +864,6 @@ TEST(MemoryStoreIntegrationTest, TestSimple) { /*num_objects*/ 1, /*timeout_ms*/ -1, ctx, - /*remove_after_get*/ true, &results)); ASSERT_EQ(results.size(), 1); ASSERT_EQ(store.Size(), 1); diff --git a/src/ray/core_worker/tests/task_manager_test.cc b/src/ray/core_worker/tests/task_manager_test.cc index 353a4b61fa5b..14f80e4211c0 100644 --- a/src/ray/core_worker/tests/task_manager_test.cc +++ b/src/ray/core_worker/tests/task_manager_test.cc @@ -157,8 +157,7 @@ class TaskManagerTest : public ::testing::Test { *std::make_shared(), lineage_pinning_enabled)), io_context_("TaskManagerTest"), - store_(std::make_shared(io_context_.GetIoService(), - reference_counter_.get())), + store_(std::make_shared(io_context_.GetIoService())), manager_( *store_, *reference_counter_, @@ -285,7 +284,7 @@ TEST_F(TaskManagerTest, TestTaskSuccess) { ASSERT_FALSE(reference_counter_->IsObjectPendingCreation(return_id)); std::vector> results; - RAY_CHECK_OK(store_->Get({return_id}, 1, -1, ctx, false, &results)); + RAY_CHECK_OK(store_->Get({return_id}, 1, -1, ctx, &results)); ASSERT_EQ(results.size(), 1); ASSERT_FALSE(results[0]->IsException()); ASSERT_EQ(std::memcmp(results[0]->GetData()->Data(), @@ -322,7 +321,7 @@ TEST_F(TaskManagerTest, TestTaskFailure) { ASSERT_FALSE(reference_counter_->IsObjectPendingCreation(return_id)); std::vector> results; - RAY_CHECK_OK(store_->Get({return_id}, 1, -1, ctx, false, &results)); + RAY_CHECK_OK(store_->Get({return_id}, 1, -1, ctx, &results)); ASSERT_EQ(results.size(), 1); rpc::ErrorType stored_error; ASSERT_TRUE(results[0]->IsException(&stored_error)); @@ -364,7 +363,7 @@ TEST_F(TaskManagerTest, TestPlasmaConcurrentFailure) { 
std::vector> results; // Caller of FlushObjectsToRecover is responsible for deleting the object // from the in-memory store and recovering the object. - ASSERT_TRUE(store_->Get({return_id}, 1, 0, ctx, false, &results).ok()); + ASSERT_TRUE(store_->Get({return_id}, 1, 0, ctx, &results).ok()); auto objects_to_recover = reference_counter_->FlushObjectsToRecover(); ASSERT_EQ(objects_to_recover.size(), 1); ASSERT_EQ(objects_to_recover[0], return_id); @@ -392,7 +391,7 @@ TEST_F(TaskManagerTest, TestFailPendingTask) { ASSERT_FALSE(reference_counter_->IsObjectPendingCreation(return_id)); std::vector> results; - RAY_CHECK_OK(store_->Get({return_id}, 1, 0, ctx, false, &results)); + RAY_CHECK_OK(store_->Get({return_id}, 1, 0, ctx, &results)); ASSERT_EQ(results.size(), 1); rpc::ErrorType stored_error; ASSERT_TRUE(results[0]->IsException(&stored_error)); @@ -416,7 +415,7 @@ TEST_F(TaskManagerTest, TestFailPendingTaskAfterCancellation) { // Check that the error type is set to TASK_CANCELLED std::vector> results; WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0)); - RAY_CHECK_OK(store_->Get({spec.ReturnId(0)}, 1, 0, ctx, false, &results)); + RAY_CHECK_OK(store_->Get({spec.ReturnId(0)}, 1, 0, ctx, &results)); ASSERT_EQ(results.size(), 1); rpc::ErrorType stored_error; ASSERT_TRUE(results[0]->IsException(&stored_error)); @@ -446,7 +445,7 @@ TEST_F(TaskManagerTest, TestTaskReconstruction) { ASSERT_TRUE(reference_counter_->IsObjectPendingCreation(return_id)); ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 3); std::vector> results; - ASSERT_FALSE(store_->Get({return_id}, 1, 0, ctx, false, &results).ok()); + ASSERT_FALSE(store_->Get({return_id}, 1, 0, ctx, &results).ok()); ASSERT_EQ(num_retries_, i + 1); ASSERT_EQ(last_delay_ms_, RayConfig::instance().task_retry_delay_ms()); } @@ -458,7 +457,7 @@ TEST_F(TaskManagerTest, TestTaskReconstruction) { ASSERT_FALSE(reference_counter_->IsObjectPendingCreation(return_id)); std::vector> results; - 
RAY_CHECK_OK(store_->Get({return_id}, 1, 0, ctx, false, &results)); + RAY_CHECK_OK(store_->Get({return_id}, 1, 0, ctx, &results)); ASSERT_EQ(results.size(), 1); rpc::ErrorType stored_error; ASSERT_TRUE(results[0]->IsException(&stored_error)); @@ -487,7 +486,7 @@ TEST_F(TaskManagerTest, TestTaskKill) { manager_.FailOrRetryPendingTask(spec.TaskId(), error); ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId())); std::vector> results; - RAY_CHECK_OK(store_->Get({return_id}, 1, 0, ctx, false, &results)); + RAY_CHECK_OK(store_->Get({return_id}, 1, 0, ctx, &results)); ASSERT_EQ(results.size(), 1); rpc::ErrorType stored_error; ASSERT_TRUE(results[0]->IsException(&stored_error)); @@ -539,7 +538,7 @@ TEST_F(TaskManagerTest, TestTaskOomKillNoOomRetryFailsImmediately) { std::vector> results; WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0)); - RAY_CHECK_OK(store_->Get({return_id}, 1, 0, ctx, false, &results)); + RAY_CHECK_OK(store_->Get({return_id}, 1, 0, ctx, &results)); ASSERT_EQ(results.size(), 1); rpc::ErrorType stored_error; ASSERT_TRUE(results[0]->IsException(&stored_error)); @@ -559,7 +558,7 @@ TEST_F(TaskManagerTest, TestTaskOomKillNoOomRetryFailsImmediately) { std::vector> results; WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0)); - RAY_CHECK_OK(store_->Get({return_id}, 1, 0, ctx, false, &results)); + RAY_CHECK_OK(store_->Get({return_id}, 1, 0, ctx, &results)); ASSERT_EQ(results.size(), 1); rpc::ErrorType stored_error; ASSERT_TRUE(results[0]->IsException(&stored_error)); @@ -595,7 +594,7 @@ TEST_F(TaskManagerTest, TestTaskOomAndNonOomKillReturnsLastError) { std::vector> results; WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0)); - RAY_CHECK_OK(store_->Get({return_id}, 1, 0, ctx, false, &results)); + RAY_CHECK_OK(store_->Get({return_id}, 1, 0, ctx, &results)); ASSERT_EQ(results.size(), 1); rpc::ErrorType stored_error; ASSERT_TRUE(results[0]->IsException(&stored_error)); @@ 
-637,7 +636,7 @@ TEST_F(TaskManagerTest, TestTaskNotRetriableOomFailsImmediatelyEvenWithOomRetryC std::vector> results; WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0)); - RAY_CHECK_OK(store_->Get({return_id}, 1, 0, ctx, false, &results)); + RAY_CHECK_OK(store_->Get({return_id}, 1, 0, ctx, &results)); ASSERT_EQ(results.size(), 1); rpc::ErrorType stored_error; ASSERT_TRUE(results[0]->IsException(&stored_error)); @@ -664,7 +663,7 @@ TEST_F(TaskManagerTest, TestFailsImmediatelyOverridesRetry) { std::vector> results; WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0)); - RAY_CHECK_OK(store_->Get({return_id}, 1, 0, ctx, false, &results)); + RAY_CHECK_OK(store_->Get({return_id}, 1, 0, ctx, &results)); ASSERT_EQ(results.size(), 1); rpc::ErrorType stored_error; ASSERT_TRUE(results[0]->IsException(&stored_error)); @@ -688,7 +687,7 @@ TEST_F(TaskManagerTest, TestFailsImmediatelyOverridesRetry) { std::vector> results; WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0)); - RAY_CHECK_OK(store_->Get({return_id}, 1, 0, ctx, false, &results)); + RAY_CHECK_OK(store_->Get({return_id}, 1, 0, ctx, &results)); ASSERT_EQ(results.size(), 1); rpc::ErrorType stored_error; ASSERT_TRUE(results[0]->IsException(&stored_error)); @@ -1282,7 +1281,7 @@ TEST_F(TaskManagerLineageTest, TestDynamicReturnsTask) { WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0)); std::vector> results; - RAY_CHECK_OK(store_->Get(dynamic_return_ids, 3, -1, ctx, false, &results)); + RAY_CHECK_OK(store_->Get(dynamic_return_ids, 3, -1, ctx, &results)); ASSERT_EQ(results.size(), 3); for (int i = 0; i < 3; i++) { ASSERT_TRUE(results[i]->IsInPlasmaError()); @@ -1360,10 +1359,10 @@ TEST_F(TaskManagerLineageTest, TestResubmittedDynamicReturnsTaskFails) { // No error stored for the generator ID, which should have gone out of scope. 
WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0)); std::vector> results; - ASSERT_FALSE(store_->Get({generator_id}, 1, 0, ctx, false, &results).ok()); + ASSERT_FALSE(store_->Get({generator_id}, 1, 0, ctx, &results).ok()); // The internal ObjectRefs have the right error. - RAY_CHECK_OK(store_->Get(dynamic_return_ids, 3, -1, ctx, false, &results)); + RAY_CHECK_OK(store_->Get(dynamic_return_ids, 3, -1, ctx, &results)); ASSERT_EQ(results.size(), 3); for (int i = 0; i < 3; i++) { rpc::ErrorType stored_error; @@ -1383,8 +1382,7 @@ TEST_F(TaskManagerTest, PlasmaPut_ObjectStoreFull_FailsTaskAndWritesError) { *std::make_shared(), *std::make_shared(), lineage_pinning_enabled_); - auto local_store = std::make_shared(io_context_.GetIoService(), - local_ref_counter.get()); + auto local_store = std::make_shared(io_context_.GetIoService()); TaskManager failing_mgr( *local_store, @@ -1431,7 +1429,7 @@ TEST_F(TaskManagerTest, PlasmaPut_ObjectStoreFull_FailsTaskAndWritesError) { ASSERT_FALSE(failing_mgr.IsTaskPending(spec.TaskId())); std::vector> results; WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0)); - RAY_CHECK_OK(local_store->Get({return_id}, 1, 0, ctx, false, &results)); + RAY_CHECK_OK(local_store->Get({return_id}, 1, 0, ctx, &results)); ASSERT_EQ(results.size(), 1); ASSERT_TRUE(results[0]->IsException()); } @@ -1446,8 +1444,7 @@ TEST_F(TaskManagerTest, PlasmaPut_TransientFull_RetriesThenSucceeds) { *std::make_shared(), *std::make_shared(), lineage_pinning_enabled_); - auto local_store = std::make_shared(io_context_.GetIoService(), - local_ref_counter.get()); + auto local_store = std::make_shared(io_context_.GetIoService()); TaskManager retry_mgr( *local_store, *local_ref_counter, @@ -1495,7 +1492,7 @@ TEST_F(TaskManagerTest, PlasmaPut_TransientFull_RetriesThenSucceeds) { std::vector> results; WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0)); - 
RAY_CHECK_OK(local_store->Get({return_id}, 1, 0, ctx, false, &results)); + RAY_CHECK_OK(local_store->Get({return_id}, 1, 0, ctx, &results)); ASSERT_EQ(results.size(), 1); ASSERT_TRUE(results[0]->IsInPlasmaError()); } @@ -1510,8 +1507,7 @@ TEST_F(TaskManagerTest, DynamicReturn_PlasmaPutFailure_FailsTaskImmediately) { *std::make_shared(), *std::make_shared(), lineage_pinning_enabled_); - auto local_store = std::make_shared(io_context_.GetIoService(), - local_ref_counter.get()); + auto local_store = std::make_shared(io_context_.GetIoService()); TaskManager dyn_mgr( *local_store, *local_ref_counter, @@ -2043,7 +2039,7 @@ TEST_F(TaskManagerTest, TestObjectRefStreamEndtoEnd) { ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 2); std::vector> results; WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0)); - RAY_CHECK_OK(store_->Get({dynamic_return_id}, 1, 1, ctx, false, &results)); + RAY_CHECK_OK(store_->Get({dynamic_return_id}, 1, 1, ctx, &results)); ASSERT_EQ(results.size(), 1); // Make sure you can read. @@ -2072,7 +2068,7 @@ TEST_F(TaskManagerTest, TestObjectRefStreamEndtoEnd) { // NumObjectIDsInScope == Generator + 2 intermediate result. results.clear(); - RAY_CHECK_OK(store_->Get({dynamic_return_id2}, 1, 1, ctx, false, &results)); + RAY_CHECK_OK(store_->Get({dynamic_return_id2}, 1, 1, ctx, &results)); ASSERT_EQ(results.size(), 1); // Make sure you can read. 
@@ -2135,10 +2131,10 @@ TEST_F(TaskManagerTest, TestObjectRefStreamDelCleanReferences) { ASSERT_EQ(store_->Size(), 2); std::vector> results; WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0)); - RAY_CHECK_OK(store_->Get({dynamic_return_id}, 1, 1, ctx, false, &results)); + RAY_CHECK_OK(store_->Get({dynamic_return_id}, 1, 1, ctx, &results)); ASSERT_EQ(results.size(), 1); results.clear(); - RAY_CHECK_OK(store_->Get({dynamic_return_id2}, 1, 1, ctx, false, &results)); + RAY_CHECK_OK(store_->Get({dynamic_return_id2}, 1, 1, ctx, &results)); ASSERT_EQ(results.size(), 1); results.clear(); @@ -2148,9 +2144,9 @@ TEST_F(TaskManagerTest, TestObjectRefStreamDelCleanReferences) { // All the in memory objects should be cleaned up. The generator ref returns // a direct result that would be GCed once it goes out of scope. ASSERT_EQ(store_->Size(), 1); - ASSERT_TRUE(store_->Get({dynamic_return_id}, 1, 1, ctx, false, &results).IsTimedOut()); + ASSERT_TRUE(store_->Get({dynamic_return_id}, 1, 1, ctx, &results).IsTimedOut()); results.clear(); - ASSERT_TRUE(store_->Get({dynamic_return_id2}, 1, 1, ctx, false, &results).IsTimedOut()); + ASSERT_TRUE(store_->Get({dynamic_return_id2}, 1, 1, ctx, &results).IsTimedOut()); results.clear(); // Clean up the generator ID. Now all lineage is safe to remove. @@ -2182,7 +2178,7 @@ TEST_F(TaskManagerTest, TestObjectRefStreamDelCleanReferences) { ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 0); // All the in memory objects should be cleaned up. 
ASSERT_EQ(store_->Size(), 1); - ASSERT_TRUE(store_->Get({dynamic_return_id3}, 1, 1, ctx, false, &results).IsTimedOut()); + ASSERT_TRUE(store_->Get({dynamic_return_id3}, 1, 1, ctx, &results).IsTimedOut()); results.clear(); } @@ -2236,7 +2232,7 @@ TEST_F(TaskManagerTest, TestObjectRefStreamDelCleanReferencesLineageInScope) { ASSERT_EQ(store_->Size(), 2); std::vector> results; WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0)); - RAY_CHECK_OK(store_->Get({dynamic_return_id}, 1, 1, ctx, false, &results)); + RAY_CHECK_OK(store_->Get({dynamic_return_id}, 1, 1, ctx, &results)); ASSERT_EQ(results.size(), 1); results.clear(); @@ -2247,7 +2243,7 @@ TEST_F(TaskManagerTest, TestObjectRefStreamDelCleanReferencesLineageInScope) { ASSERT_EQ(obj_id, dynamic_return_id); // Write one ref that will stay unconsumed. - RAY_CHECK_OK(store_->Get({dynamic_return_id2}, 1, 1, ctx, false, &results)); + RAY_CHECK_OK(store_->Get({dynamic_return_id2}, 1, 1, ctx, &results)); ASSERT_EQ(results.size(), 1); results.clear(); @@ -2257,7 +2253,7 @@ TEST_F(TaskManagerTest, TestObjectRefStreamDelCleanReferencesLineageInScope) { // All the unconsumed objects should be cleaned up. The generator ref returns // a direct result that would be GCed once it goes out of scope. ASSERT_EQ(store_->Size(), 2); - ASSERT_TRUE(store_->Get({dynamic_return_id2}, 1, 1, ctx, false, &results).IsTimedOut()); + ASSERT_TRUE(store_->Get({dynamic_return_id2}, 1, 1, ctx, &results).IsTimedOut()); results.clear(); // Clean up the generator ID. @@ -2271,7 +2267,7 @@ TEST_F(TaskManagerTest, TestObjectRefStreamDelCleanReferencesLineageInScope) { // All the unconsumed in memory objects should be cleaned up. Check for 2 // in-memory objects: one consumed object ref and the generator ref. 
ASSERT_EQ(store_->Size(), 2); - ASSERT_TRUE(store_->Get({dynamic_return_id2}, 1, 1, ctx, false, &results).IsTimedOut()); + ASSERT_TRUE(store_->Get({dynamic_return_id2}, 1, 1, ctx, &results).IsTimedOut()); results.clear(); // NOTE: We panic if READ is called after DELETE. The @@ -2296,7 +2292,7 @@ TEST_F(TaskManagerTest, TestObjectRefStreamDelCleanReferencesLineageInScope) { // All the unconsumed in memory objects should be cleaned up. Check for 2 // in-memory objects: one consumed object ref and the generator ref. ASSERT_EQ(store_->Size(), 2); - ASSERT_TRUE(store_->Get({dynamic_return_id3}, 1, 1, ctx, false, &results).IsTimedOut()); + ASSERT_TRUE(store_->Get({dynamic_return_id3}, 1, 1, ctx, &results).IsTimedOut()); results.clear(); } @@ -3007,8 +3003,7 @@ TEST_F(TaskManagerTest, TestRetryErrorMessageSentToCallback) { *std::make_shared(), *std::make_shared(), false); - auto local_store = std::make_shared( - io_context_.GetIoService(), local_reference_counter.get()); + auto local_store = std::make_shared(io_context_.GetIoService()); TaskManager test_manager( *local_store, @@ -3090,8 +3085,7 @@ TEST_F(TaskManagerTest, TestErrorLogWhenPushErrorCallbackFails) { *std::make_shared(), *std::make_shared(), false); - auto local_store = std::make_shared( - io_context_.GetIoService(), local_reference_counter.get()); + auto local_store = std::make_shared(io_context_.GetIoService()); TaskManager test_manager( *local_store,