diff --git a/.travis.yml b/.travis.yml index 7402e6fb6e78..2266833f8414 100644 --- a/.travis.yml +++ b/.travis.yml @@ -148,6 +148,9 @@ install: - ./ci/suppress_output bazel build //:stats_test -c opt - ./bazel-bin/stats_test + # Core worker tests. + - ./ci/suppress_output bash src/ray/test/run_core_worker_tests.sh + # Raylet tests. - ./ci/suppress_output bash src/ray/test/run_object_manager_tests.sh - ./ci/suppress_output bazel test --build_tests_only --test_lang_filters=cc //:all diff --git a/BUILD.bazel b/BUILD.bazel index 0bdbe5741cf8..2b75d5d04e77 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -77,6 +77,7 @@ cc_library( "src/ray/raylet/mock_gcs_client.cc", "src/ray/raylet/monitor_main.cc", "src/ray/raylet/*_test.cc", + "src/ray/raylet/main.cc", ], ), hdrs = glob([ @@ -122,15 +123,18 @@ cc_library( deps = [ ":ray_common", ":ray_util", + ":raylet_lib", ], ) -cc_test( +# This test is run by src/ray/test/run_core_worker_tests.sh +cc_binary( name = "core_worker_test", srcs = ["src/ray/core_worker/core_worker_test.cc"], copts = COPTS, deps = [ ":core_worker_lib", + ":gcs", "@com_google_googletest//:gtest_main", ], ) @@ -320,6 +324,7 @@ cc_library( ":node_manager_fbs", ":ray_util", "@boost//:asio", + "@plasma//:plasma_client", ], ) diff --git a/src/ray/common/buffer.h b/src/ray/common/buffer.h index 358d903799c7..4340c74a8d38 100644 --- a/src/ray/common/buffer.h +++ b/src/ray/common/buffer.h @@ -3,6 +3,11 @@ #include #include +#include "plasma/client.h" + +namespace arrow { +class Buffer; +} namespace ray { @@ -15,7 +20,7 @@ class Buffer { /// Size of this buffer. virtual size_t Size() const = 0; - virtual ~Buffer() {} + virtual ~Buffer() = default; bool operator==(const Buffer &rhs) const { return this->Data() == rhs.Data() && this->Size() == rhs.Size(); @@ -40,6 +45,21 @@ class LocalMemoryBuffer : public Buffer { size_t size_; }; +/// Represents a byte buffer for a plasma object. 
+class PlasmaBuffer : public Buffer { + public: + PlasmaBuffer(std::shared_ptr buffer) : buffer_(buffer) {} + + uint8_t *Data() const override { return const_cast(buffer_->data()); } + + size_t Size() const override { return buffer_->size(); } + + private: + /// Shared pointer to the arrow buffer, which can potentially hold a reference + /// to the object (when it's a plasma::PlasmaBuffer). + std::shared_ptr buffer_; +}; + } // namespace ray #endif // RAY_COMMON_BUFFER_H diff --git a/src/ray/core_worker/common.h b/src/ray/core_worker/common.h index b53c35b25fa8..ad9485e53826 100644 --- a/src/ray/core_worker/common.h +++ b/src/ray/core_worker/common.h @@ -45,13 +45,13 @@ class TaskArg { bool IsPassedByReference() const { return id_ != nullptr; } /// Get the reference object ID. - ObjectID &GetReference() { + const ObjectID &GetReference() const { RAY_CHECK(id_ != nullptr) << "This argument isn't passed by reference."; return *id_; } /// Get the value. - std::shared_ptr GetValue() { + std::shared_ptr GetValue() const { RAY_CHECK(data_ != nullptr) << "This argument isn't passed by value."; return data_; } diff --git a/src/ray/core_worker/context.cc b/src/ray/core_worker/context.cc new file mode 100644 index 000000000000..fedcfc6625d9 --- /dev/null +++ b/src/ray/core_worker/context.cc @@ -0,0 +1,81 @@ + +#include "context.h" + +namespace ray { + +/// Per-thread context for the core worker: the current task ID plus the task and put counters, which restart from 0 whenever the current task is set. +struct WorkerThreadContext { + WorkerThreadContext() + : current_task_id(TaskID::FromRandom()), task_index(0), put_index(0) {} + + int GetNextTaskIndex() { return ++task_index; } + + int GetNextPutIndex() { return ++put_index; } + + const TaskID &GetCurrentTaskID() const { return current_task_id; } + + void SetCurrentTask(const TaskID &task_id) { + current_task_id = task_id; + task_index = 0; + put_index = 0; + } + + void SetCurrentTask(const raylet::TaskSpecification &spec) { + SetCurrentTask(spec.TaskId()); + } + + private: + /// The task ID of the current task. 
+ TaskID current_task_id; + + /// Number of tasks that have been submitted from the current task. + int task_index; + + /// Number of objects that have been put from the current task. + int put_index; +}; + +thread_local std::unique_ptr WorkerContext::thread_context_ = + nullptr; + +WorkerContext::WorkerContext(WorkerType worker_type, const DriverID &driver_id) + : worker_type(worker_type), + worker_id(worker_type == WorkerType::DRIVER + ? ClientID::FromBinary(driver_id.Binary()) + : ClientID::FromRandom()), + current_driver_id(worker_type == WorkerType::DRIVER ? driver_id : DriverID::Nil()) { + // For the thread that constructs the WorkerContext, set the task ID + // according to the worker type: random for a driver, nil for a worker. + // (Other threads get a random task ID when GetThreadContext creates their context.) + GetThreadContext().SetCurrentTask( + (worker_type == WorkerType::DRIVER) ? TaskID::FromRandom() : TaskID::Nil()); +} + +const WorkerType WorkerContext::GetWorkerType() const { return worker_type; } + +const ClientID &WorkerContext::GetWorkerID() const { return worker_id; } + +int WorkerContext::GetNextTaskIndex() { return GetThreadContext().GetNextTaskIndex(); } + +int WorkerContext::GetNextPutIndex() { return GetThreadContext().GetNextPutIndex(); } + +const DriverID &WorkerContext::GetCurrentDriverID() const { return current_driver_id; } + +const TaskID &WorkerContext::GetCurrentTaskID() const { + return GetThreadContext().GetCurrentTaskID(); +} + +void WorkerContext::SetCurrentTask(const raylet::TaskSpecification &spec) { + current_driver_id = spec.DriverId(); + GetThreadContext().SetCurrentTask(spec); +} + +WorkerThreadContext &WorkerContext::GetThreadContext() { + if (thread_context_ == nullptr) { + thread_context_ = std::unique_ptr(new WorkerThreadContext()); + } + + return *thread_context_; +} + +} // namespace ray diff --git a/src/ray/core_worker/context.h b/src/ray/core_worker/context.h new file mode 100644 index 000000000000..6e0cf3f9f2cf --- /dev/null +++ 
b/src/ray/core_worker/context.h @@ -0,0 +1,48 @@ +#ifndef RAY_CORE_WORKER_CONTEXT_H +#define RAY_CORE_WORKER_CONTEXT_H + +#include "common.h" +#include "ray/raylet/task_spec.h" + +namespace ray { + +struct WorkerThreadContext; + +class WorkerContext { + public: + WorkerContext(WorkerType worker_type, const DriverID &driver_id); + + const WorkerType GetWorkerType() const; + + const ClientID &GetWorkerID() const; + + const DriverID &GetCurrentDriverID() const; + + const TaskID &GetCurrentTaskID() const; + + void SetCurrentTask(const raylet::TaskSpecification &spec); + + int GetNextTaskIndex(); + + int GetNextPutIndex(); + + private: + /// Type of the worker. + const WorkerType worker_type; + + /// ID for this worker. + const ClientID worker_id; + + /// ID of the current driver; SetCurrentTask() replaces it with the task's driver ID. + DriverID current_driver_id; + + private: + static WorkerThreadContext &GetThreadContext(); + + /// Per-thread worker context (static, so shared by every WorkerContext on a thread). + static thread_local std::unique_ptr thread_context_; +}; + +} // namespace ray + +#endif // RAY_CORE_WORKER_CONTEXT_H diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc new file mode 100644 index 000000000000..82f2d885ec58 --- /dev/null +++ b/src/ray/core_worker/core_worker.cc @@ -0,0 +1,39 @@ +#include "core_worker.h" +#include "context.h" + +namespace ray { + +CoreWorker::CoreWorker(const enum WorkerType worker_type, const enum Language language, + const std::string &store_socket, const std::string &raylet_socket, + DriverID driver_id) + : worker_type_(worker_type), + language_(language), + worker_context_(worker_type, driver_id), + store_socket_(store_socket), + raylet_socket_(raylet_socket), + task_interface_(*this), + object_interface_(*this), + task_execution_interface_(*this) {} + +Status CoreWorker::Connect() { + // Connect to the plasma store. + RAY_ARROW_RETURN_NOT_OK(store_client_.Connect(store_socket_)); + + // Connect to raylet. 
+ ::Language lang = ::Language::PYTHON; + if (language_ == ray::Language::JAVA) { + lang = ::Language::JAVA; + } + + // TODO: currently RayletClient would crash in its constructor if it cannot + // connect to Raylet after a number of retries. This needs to be changed + // so that the worker (Java, Python, etc.) can retrieve and handle the error + // instead of crashing. + raylet_client_ = std::unique_ptr( + new RayletClient(raylet_socket_, worker_context_.GetWorkerID(), + (worker_type_ == ray::WorkerType::WORKER), + worker_context_.GetCurrentDriverID(), lang)); + return Status::OK(); +} + +} // namespace ray diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 96e51dbc4532..951b55451f09 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -2,8 +2,10 @@ #define RAY_CORE_WORKER_CORE_WORKER_H #include "common.h" +#include "context.h" #include "object_interface.h" #include "ray/common/buffer.h" +#include "ray/raylet/raylet_client.h" #include "task_execution.h" #include "task_interface.h" @@ -18,15 +20,12 @@ class CoreWorker { /// /// \param[in] worker_type Type of this worker. /// \param[in] langauge Language of this worker. - CoreWorker(const WorkerType worker_type, const Language language) - : worker_type_(worker_type), - language_(language), - task_interface_(*this), - object_interface_(*this), - task_execution_interface_(*this) {} + CoreWorker(const WorkerType worker_type, const Language language, + const std::string &store_socket, const std::string &raylet_socket, + DriverID driver_id = DriverID::Nil()); - /// Connect this worker to Raylet. - Status Connect() { return Status::OK(); } + /// Connect to the plasma store and to raylet. + Status Connect(); /// Type of this worker. enum WorkerType WorkerType() const { return worker_type_; } @@ -53,6 +52,21 @@ class CoreWorker { /// Language of this worker. const enum Language language_; + /// Worker context per thread. 
+ WorkerContext worker_context_; + + /// Plasma store socket name. + std::string store_socket_; + + /// Raylet socket name. + std::string raylet_socket_; + + /// Plasma store client. + plasma::PlasmaClient store_client_; + + /// Raylet client. + std::unique_ptr raylet_client_; + /// The `CoreWorkerTaskInterface` instance. CoreWorkerTaskInterface task_interface_; @@ -61,6 +75,10 @@ class CoreWorker { /// The `CoreWorkerTaskExecutionInterface` instance. CoreWorkerTaskExecutionInterface task_execution_interface_; + + friend class CoreWorkerTaskInterface; + friend class CoreWorkerObjectInterface; + friend class CoreWorkerTaskExecutionInterface; }; } // namespace ray diff --git a/src/ray/core_worker/core_worker_test.cc b/src/ray/core_worker/core_worker_test.cc index 6711c874a973..e440aae24d67 100644 --- a/src/ray/core_worker/core_worker_test.cc +++ b/src/ray/core_worker/core_worker_test.cc @@ -1,20 +1,137 @@ +#include #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "context.h" #include "core_worker.h" #include "ray/common/buffer.h" +#include "ray/raylet/raylet_client.h" + +#include +#include +#include + +#include "ray/thirdparty/hiredis/async.h" +#include "ray/thirdparty/hiredis/hiredis.h" namespace ray { +std::string store_executable; +std::string raylet_executable; + +ray::ObjectID RandomObjectID() { return ObjectID::FromRandom(); } + +static void flushall_redis(void) { + redisContext *context = redisConnect("127.0.0.1", 6379); + freeReplyObject(redisCommand(context, "FLUSHALL")); + freeReplyObject(redisCommand(context, "SET NumRedisShards 1")); + freeReplyObject(redisCommand(context, "LPUSH RedisShards 127.0.0.1:6380")); + redisFree(context); +} + class CoreWorkerTest : public ::testing::Test { public: - CoreWorkerTest() : core_worker_(WorkerType::WORKER, Language::PYTHON) {} + CoreWorkerTest(int num_nodes) { + RAY_CHECK(num_nodes >= 0); + if (num_nodes > 0) { + raylet_socket_names_.resize(num_nodes); + raylet_store_socket_names_.resize(num_nodes); + } + 
+ // Start one plasma store per node (none when num_nodes is 0). + for (auto &store_socket : raylet_store_socket_names_) { + store_socket = StartStore(); + } + + // Start one raylet per node, using that node's plasma store socket. + for (int i = 0; i < num_nodes; i++) { + raylet_socket_names_[i] = StartRaylet(raylet_store_socket_names_[i], "127.0.0.1", + "127.0.0.1", "\"CPU,4.0\""); + } + } + + ~CoreWorkerTest() { + for (const auto &raylet_socket : raylet_socket_names_) { + StopRaylet(raylet_socket); + } + + for (const auto &store_socket : raylet_store_socket_names_) { + StopStore(store_socket); + } + } + + std::string StartStore() { + std::string store_socket_name = "/tmp/store" + RandomObjectID().Hex(); + std::string store_pid = store_socket_name + ".pid"; + std::string plasma_command = store_executable + " -m 10000000 -s " + + store_socket_name + + " 1> /dev/null 2> /dev/null & echo $! > " + store_pid; + RAY_LOG(INFO) << plasma_command; + RAY_CHECK(system(plasma_command.c_str()) == 0); + usleep(200 * 1000); + return store_socket_name; + } + + void StopStore(std::string store_socket_name) { + std::string store_pid = store_socket_name + ".pid"; + std::string kill_9 = "kill -9 `cat " + store_pid + "`"; + RAY_LOG(INFO) << kill_9; + ASSERT_TRUE(system(kill_9.c_str()) == 0); + ASSERT_TRUE(system(("rm -rf " + store_socket_name).c_str()) == 0); + ASSERT_TRUE(system(("rm -rf " + store_socket_name + ".pid").c_str()) == 0); + } + + std::string StartRaylet(std::string store_socket_name, std::string node_ip_address, + std::string redis_address, std::string resource) { + std::string raylet_socket_name = "/tmp/raylet" + RandomObjectID().Hex(); + std::string ray_start_cmd = raylet_executable; + ray_start_cmd.append(" --raylet_socket_name=" + raylet_socket_name) + .append(" --store_socket_name=" + store_socket_name) + .append(" --object_manager_port=0 --node_manager_port=0") + .append(" --node_ip_address=" + node_ip_address) + .append(" --redis_address=" + redis_address) + .append(" --redis_port=6379") + .append(" --num_initial_workers=0") + .append(" 
--maximum_startup_concurrency=10") + .append(" --static_resource_list=" + resource) + .append(" --python_worker_command=NoneCmd") + .append(" & echo $! > " + raylet_socket_name + ".pid"); + + RAY_LOG(INFO) << "Ray Start command: " << ray_start_cmd; + RAY_CHECK(system(ray_start_cmd.c_str()) == 0); + usleep(200 * 1000); + return raylet_socket_name; + } + + void StopRaylet(std::string raylet_socket_name) { + std::string raylet_pid = raylet_socket_name + ".pid"; + std::string kill_9 = "kill -9 `cat " + raylet_pid + "`"; + RAY_LOG(INFO) << kill_9; + ASSERT_TRUE(system(kill_9.c_str()) == 0); + ASSERT_TRUE(system(("rm -rf " + raylet_socket_name).c_str()) == 0); + ASSERT_TRUE(system(("rm -rf " + raylet_socket_name + ".pid").c_str()) == 0); + } + + void SetUp() { flushall_redis(); } + + void TearDown() {} protected: - CoreWorker core_worker_; + std::vector raylet_socket_names_; + std::vector raylet_store_socket_names_; }; -TEST_F(CoreWorkerTest, TestTaskArg) { +class ZeroNodeTest : public CoreWorkerTest { + public: + ZeroNodeTest() : CoreWorkerTest(0) {} +}; + +class SingleNodeTest : public CoreWorkerTest { + public: + SingleNodeTest() : CoreWorkerTest(1) {} +}; + +TEST_F(ZeroNodeTest, TestTaskArg) { // Test by-reference argument. 
ObjectID id = ObjectID::FromRandom(); TaskArg by_ref = TaskArg::PassByReference(id); @@ -30,9 +147,102 @@ TEST_F(CoreWorkerTest, TestTaskArg) { ASSERT_EQ(*data, *buffer); } -TEST_F(CoreWorkerTest, TestAttributeGetters) { - ASSERT_EQ(core_worker_.WorkerType(), WorkerType::WORKER); - ASSERT_EQ(core_worker_.Language(), Language::PYTHON); +TEST_F(ZeroNodeTest, TestAttributeGetters) { + CoreWorker core_worker(WorkerType::DRIVER, Language::PYTHON, "", "", + DriverID::FromRandom()); + ASSERT_EQ(core_worker.WorkerType(), WorkerType::DRIVER); + ASSERT_EQ(core_worker.Language(), Language::PYTHON); +} + +TEST_F(ZeroNodeTest, TestWorkerContext) { + auto driver_id = DriverID::FromRandom(); + + WorkerContext context(WorkerType::WORKER, driver_id); + ASSERT_TRUE(context.GetCurrentTaskID().IsNil()); + ASSERT_EQ(context.GetNextTaskIndex(), 1); + ASSERT_EQ(context.GetNextTaskIndex(), 2); + ASSERT_EQ(context.GetNextPutIndex(), 1); + ASSERT_EQ(context.GetNextPutIndex(), 2); + + auto thread_func = [&context]() { + // Another thread has its own context: a non-nil task ID and fresh counters. + ASSERT_TRUE(!context.GetCurrentTaskID().IsNil()); + ASSERT_EQ(context.GetNextTaskIndex(), 1); + ASSERT_EQ(context.GetNextPutIndex(), 1); + }; + + std::thread async_thread(thread_func); + async_thread.join(); + + // The counters of this thread are unaffected by the other thread. 
+ ASSERT_EQ(context.GetNextTaskIndex(), 3); + ASSERT_EQ(context.GetNextPutIndex(), 3); +} + +TEST_F(SingleNodeTest, TestObjectInterface) { + CoreWorker core_worker(WorkerType::DRIVER, Language::PYTHON, + raylet_store_socket_names_[0], raylet_socket_names_[0], + DriverID::FromRandom()); + RAY_CHECK_OK(core_worker.Connect()); + + uint8_t array1[] = {1, 2, 3, 4, 5, 6, 7, 8}; + uint8_t array2[] = {10, 11, 12, 13, 14, 15}; + + std::vector buffers; + buffers.emplace_back(array1, sizeof(array1)); + buffers.emplace_back(array2, sizeof(array2)); + + std::vector ids(buffers.size()); + for (size_t i = 0; i < ids.size(); i++) { + RAY_CHECK_OK(core_worker.Objects().Put(buffers[i], &ids[i])); + } + + // Test Get(). + std::vector> results; + RAY_CHECK_OK(core_worker.Objects().Get(ids, 0, &results)); + + ASSERT_EQ(results.size(), 2); + for (size_t i = 0; i < ids.size(); i++) { + ASSERT_EQ(results[i]->Size(), buffers[i].Size()); + ASSERT_EQ(memcmp(results[i]->Data(), buffers[i].Data(), buffers[i].Size()), 0); + } + + // Test Wait(). + ObjectID non_existent_id = ObjectID::FromRandom(); + std::vector all_ids(ids); + all_ids.push_back(non_existent_id); + + std::vector wait_results; + RAY_CHECK_OK(core_worker.Objects().Wait(all_ids, 2, -1, &wait_results)); + ASSERT_EQ(wait_results.size(), 3); + ASSERT_EQ(wait_results, std::vector({true, true, false})); + + RAY_CHECK_OK(core_worker.Objects().Wait(all_ids, 3, 100, &wait_results)); + ASSERT_EQ(wait_results.size(), 3); + ASSERT_EQ(wait_results, std::vector({true, true, false})); + + // Test Delete(). + // Clear the references held by the PlasmaBuffers first. + results.clear(); + RAY_CHECK_OK(core_worker.Objects().Delete(ids, true, false)); + + // Note that Delete() calls RayletClient::FreeObjects and would not + // wait for objects being deleted, so wait a while for plasma store + // to process the command. 
+ usleep(200 * 1000); + RAY_CHECK_OK(core_worker.Objects().Get(ids, 0, &results)); + ASSERT_EQ(results.size(), 2); + ASSERT_TRUE(!results[0]); + ASSERT_TRUE(!results[1]); } } // namespace ray + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + RAY_CHECK(argc == 3) + << "Usage: core_worker_test <plasma_store_executable> <raylet_executable>"; + ray::store_executable = std::string(argv[1]); + ray::raylet_executable = std::string(argv[2]); + return RUN_ALL_TESTS(); +} diff --git a/src/ray/core_worker/object_interface.cc b/src/ray/core_worker/object_interface.cc index d5d5d6f883f6..c966192610c7 100644 --- a/src/ray/core_worker/object_interface.cc +++ b/src/ray/core_worker/object_interface.cc @@ -1,25 +1,137 @@ #include "object_interface.h" +#include "context.h" +#include "core_worker.h" +#include "ray/ray_config.h" namespace ray { -Status CoreWorkerObjectInterface::Put(const Buffer &buffer, const ObjectID *object_id) { +CoreWorkerObjectInterface::CoreWorkerObjectInterface(CoreWorker &core_worker) + : core_worker_(core_worker) {} + +Status CoreWorkerObjectInterface::Put(const Buffer &buffer, ObjectID *object_id) { + ObjectID put_id = ObjectID::ForPut(core_worker_.worker_context_.GetCurrentTaskID(), + core_worker_.worker_context_.GetNextPutIndex()); + *object_id = put_id; + + auto plasma_id = put_id.ToPlasmaId(); + std::shared_ptr data; + RAY_ARROW_RETURN_NOT_OK( + core_worker_.store_client_.Create(plasma_id, buffer.Size(), nullptr, 0, &data)); + memcpy(data->mutable_data(), buffer.Data(), buffer.Size()); + RAY_ARROW_RETURN_NOT_OK(core_worker_.store_client_.Seal(plasma_id)); + RAY_ARROW_RETURN_NOT_OK(core_worker_.store_client_.Release(plasma_id)); return Status::OK(); } Status CoreWorkerObjectInterface::Get(const std::vector &ids, - int64_t timeout_ms, std::vector *results) { + int64_t timeout_ms, + std::vector> *results) { + // Reset every slot, so that entries left in a reused vector are never returned. + (*results).assign(ids.size(), nullptr); + + bool was_blocked = false; + // Status to return; only changed when a request to the object store fails. + Status final_status = Status::OK(); + + std::unordered_map unready; + for (int i = 0; i < ids.size(); i++) { + unready.insert({ids[i], i}); + } + + int num_attempts = 0; + bool should_break = false; + 
int64_t remaining_timeout = timeout_ms; + // Repeat until we get all objects. + while (!unready.empty() && !should_break) { + std::vector unready_ids; + for (const auto &entry : unready) { + unready_ids.push_back(entry.first); + } + + // For the initial fetch, we only fetch the objects, do not reconstruct them. + bool fetch_only = num_attempts == 0; + if (!fetch_only) { + // If fetch_only is false, this worker will be blocked. + was_blocked = true; + } + + // TODO: can call `fetchOrReconstruct` in batches as an optimization. + RAY_CHECK_OK(core_worker_.raylet_client_->FetchOrReconstruct( + unready_ids, fetch_only, core_worker_.worker_context_.GetCurrentTaskID())); + + // Get the objects from the object store, and parse the result. + int64_t get_timeout; + if (remaining_timeout >= 0) { + get_timeout = + std::min(remaining_timeout, RayConfig::instance().get_timeout_milliseconds()); + remaining_timeout -= get_timeout; + should_break = remaining_timeout <= 0; + } else { + get_timeout = RayConfig::instance().get_timeout_milliseconds(); + } + + std::vector plasma_ids; + for (const auto &id : unready_ids) { + plasma_ids.push_back(id.ToPlasmaId()); + } + + std::vector object_buffers; + auto status = + core_worker_.store_client_.Get(plasma_ids, get_timeout, &object_buffers); + if (!status.ok()) { + // The request to the store itself failed. Stop here and report the error + // (after the unblock notification below) instead of silently retrying. + final_status = Status::IOError(status.message()); + break; + } + + for (int i = 0; i < object_buffers.size(); i++) { + if (object_buffers[i].data != nullptr) { + const auto &object_id = unready_ids[i]; + (*results)[unready[object_id]] = + std::make_shared(object_buffers[i].data); + unready.erase(object_id); + } + } + + num_attempts += 1; + // TODO: log a message if attempted too many times. 
+ } + + if (was_blocked) { + RAY_CHECK_OK(core_worker_.raylet_client_->NotifyUnblocked( + core_worker_.worker_context_.GetCurrentTaskID())); + } + - return Status::OK(); + return final_status; } Status CoreWorkerObjectInterface::Wait(const std::vector &object_ids, int num_objects, int64_t timeout_ms, std::vector *results) { - return Status::OK(); + WaitResultPair result_pair; + auto status = core_worker_.raylet_client_->Wait( + object_ids, num_objects, timeout_ms, false, + core_worker_.worker_context_.GetCurrentTaskID(), &result_pair); + std::unordered_set ready_ids; + for (const auto &entry : result_pair.first) { + ready_ids.insert(entry); + } + + // TODO: change RayletClient::Wait() to return a bit set, so that we don't need + // to do this translation. + (*results).resize(object_ids.size()); + for (int i = 0; i < object_ids.size(); i++) { + (*results)[i] = ready_ids.count(object_ids[i]) > 0; + } + + return status; } Status CoreWorkerObjectInterface::Delete(const std::vector &object_ids, bool local_only, bool delete_creating_tasks) { - return Status::OK(); + return core_worker_.raylet_client_->FreeObjects(object_ids, local_only, + delete_creating_tasks); } } // namespace ray diff --git a/src/ray/core_worker/object_interface.h b/src/ray/core_worker/object_interface.h index 424c123ee543..f14c5297c456 100644 --- a/src/ray/core_worker/object_interface.h +++ b/src/ray/core_worker/object_interface.h @@ -2,6 +2,7 @@ #define RAY_CORE_WORKER_OBJECT_INTERFACE_H #include "common.h" +#include "plasma/client.h" #include "ray/common/buffer.h" #include "ray/id.h" #include "ray/status.h" @@ -13,14 +14,14 @@ class CoreWorker; /// The interface that contains all `CoreWorker` methods that are related to object store. class CoreWorkerObjectInterface { public: - CoreWorkerObjectInterface(CoreWorker &core_worker) : core_worker_(core_worker) {} + CoreWorkerObjectInterface(CoreWorker &core_worker); /// Put an object into object store. /// /// \param[in] buffer Data buffer of the object. 
/// \param[out] object_id Generated ID of the object. /// \return Status. - Status Put(const Buffer &buffer, const ObjectID *object_id); + Status Put(const Buffer &buffer, ObjectID *object_id); /// Get a list of objects from the object store. /// @@ -29,7 +30,7 @@ class CoreWorkerObjectInterface { /// \param[out] results Result list of objects data. /// \return Status. Status Get(const std::vector &ids, int64_t timeout_ms, - std::vector *results); + std::vector> *results); /// Wait for a list of objects to appear in the object store. /// diff --git a/src/ray/gcs/redis_module/ray_redis_module.cc b/src/ray/gcs/redis_module/ray_redis_module.cc index 6a7742c6b5a4..13450a4b7642 100644 --- a/src/ray/gcs/redis_module/ray_redis_module.cc +++ b/src/ray/gcs/redis_module/ray_redis_module.cc @@ -791,7 +791,7 @@ int TableCancelNotifications_RedisCommand(RedisModuleCtx *ctx, RedisModuleString return REDISMODULE_OK; } -Status is_nil(bool *out, const std::string &data) { +Status IsNil(bool *out, const std::string &data) { if (data.size() != kUniqueIDSize) { return Status::RedisError("Size of data doesn't match size of UniqueID"); } @@ -836,7 +836,7 @@ int TableTestAndUpdate_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **arg static_cast(update->test_state_bitmask()); bool is_nil_result; - REPLY_AND_RETURN_IF_NOT_OK(is_nil(&is_nil_result, update->test_raylet_id()->str())); + REPLY_AND_RETURN_IF_NOT_OK(IsNil(&is_nil_result, update->test_raylet_id()->str())); if (!is_nil_result) { do_update = do_update && update->test_raylet_id()->str() == data->raylet_id()->str(); } diff --git a/src/ray/status.h b/src/ray/status.h index fb6252b34667..340ffb3112cc 100644 --- a/src/ray/status.h +++ b/src/ray/status.h @@ -54,6 +54,17 @@ // This macro is used to replace the "ARROW_CHECK_OK" macro. #define RAY_ARROW_CHECK_OK(s) RAY_ARROW_CHECK_OK_PREPEND(s, "Bad status") +// If the arrow status is not OK, return from the enclosing function +// with a ray IOError status that carries the arrow error message. 
+#define RAY_ARROW_RETURN_NOT_OK(s) \ + do { \ + ::arrow::Status _s = (s); \ + if (RAY_PREDICT_FALSE(!_s.ok())) { \ + return ray::Status::IOError(_s.message()); \ + ; \ + } \ + } while (0) + namespace ray { enum class StatusCode : char { diff --git a/src/ray/test/run_core_worker_tests.sh b/src/ray/test/run_core_worker_tests.sh new file mode 100644 index 000000000000..5f1dd2eda69f --- /dev/null +++ b/src/ray/test/run_core_worker_tests.sh @@ -0,0 +1,51 @@ +#!/usr/bin/env bash + +# This needs to be run in the root directory. + +# Cause the script to exit if a single command fails. +set -e +set -x + +bazel build "//:core_worker_test" "//:raylet" "//:libray_redis_module.so" "@plasma//:plasma_store_server" + +# Get the directory in which this script is executing. +SCRIPT_DIR="`dirname \"$0\"`" +RAY_ROOT="$SCRIPT_DIR/../../.." +# Makes $RAY_ROOT an absolute path. +RAY_ROOT="`( cd \"$RAY_ROOT\" && pwd )`" +if [ -z "$RAY_ROOT" ] ; then + exit 1 +fi +# Ensure we're in the right directory. +if [ ! -d "$RAY_ROOT/python" ]; then + echo "Unable to find root Ray directory. Has this script moved?" + exit 1 +fi + +REDIS_MODULE="./bazel-bin/libray_redis_module.so" +LOAD_MODULE_ARGS="--loadmodule ${REDIS_MODULE}" +STORE_EXEC="./bazel-bin/external/plasma/plasma_store_server" +RAYLET_EXEC="./bazel-bin/raylet" + +# Allow cleanup commands to fail. +bazel run //:redis-cli -- -p 6379 shutdown || true +sleep 1s +bazel run //:redis-cli -- -p 6380 shutdown || true +sleep 1s +bazel run //:redis-server -- --loglevel warning ${LOAD_MODULE_ARGS} --port 6379 & +sleep 2s +bazel run //:redis-server -- --loglevel warning ${LOAD_MODULE_ARGS} --port 6380 & +sleep 2s +# Run tests. Keep going if they fail (despite `set -e`), so that the redis +# servers started above are always shut down; the result is reported at the end. +./bazel-bin/core_worker_test $STORE_EXEC $RAYLET_EXEC || TEST_EXIT_CODE=$? +sleep 1s +bazel run //:redis-cli -- -p 6379 shutdown +bazel run //:redis-cli -- -p 6380 shutdown +sleep 1s + +# Include raylet integration test once it's ready. +# ./bazel-bin/object_manager_integration_test $STORE_EXEC + +# Exit with the status of the test binary. +exit ${TEST_EXIT_CODE:-0}