Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ cc_library(
srcs = glob(
[
"src/ray/core_worker/*.cc",
"src/ray/core_worker/store_provider/*.cc",
"src/ray/core_worker/transport/*.cc",
],
exclude = [
"src/ray/core_worker/*_test.cc",
Expand All @@ -119,6 +121,8 @@ cc_library(
),
hdrs = glob([
"src/ray/core_worker/*.h",
"src/ray/core_worker/store_provider/*.h",
"src/ray/core_worker/transport/*.h",
]),
copts = COPTS,
deps = [
Expand Down
31 changes: 31 additions & 0 deletions src/ray/core_worker/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

#include "ray/common/buffer.h"
#include "ray/common/id.h"
#include "ray/raylet/raylet_client.h"
#include "ray/raylet/task_spec.h"

namespace ray {

Expand Down Expand Up @@ -66,6 +68,35 @@ class TaskArg {
const std::shared_ptr<Buffer> data_;
};

/// Task specification, which includes the immutable information about the task
/// which are determined at the submission time.
/// TODO(zhijunfu): this can be removed after everything is moved to protobuf.
class TaskSpec {
public:
TaskSpec(const raylet::TaskSpecification &task_spec,
const std::vector<ObjectID> &dependencies)
: task_spec_(task_spec), dependencies_(dependencies) {}

TaskSpec(const raylet::TaskSpecification &&task_spec,
const std::vector<ObjectID> &&dependencies)
: task_spec_(task_spec), dependencies_(dependencies) {}

const raylet::TaskSpecification &GetTaskSpecification() const { return task_spec_; }

const std::vector<ObjectID> &GetDependencies() const { return dependencies_; }

private:
/// Raylet task specification.
raylet::TaskSpecification task_spec_;

/// Dependencies.
std::vector<ObjectID> dependencies_;
};

enum class StoreProviderType { PLASMA };

enum class TaskTransportType { RAYLET };

} // namespace ray

#endif // RAY_CORE_WORKER_COMMON_H
44 changes: 21 additions & 23 deletions src/ray/core_worker/core_worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,41 +9,39 @@ CoreWorker::CoreWorker(const enum WorkerType worker_type,
DriverID driver_id)
: worker_type_(worker_type),
language_(language),
worker_context_(worker_type, driver_id),
store_socket_(store_socket),
raylet_socket_(raylet_socket),
is_initialized_(false),
worker_context_(worker_type, driver_id),
raylet_client_(raylet_socket_, worker_context_.GetWorkerID(),
(worker_type_ == ray::WorkerType::WORKER),
worker_context_.GetCurrentDriverID(), ToTaskLanguage(language_)),
task_interface_(*this),
object_interface_(*this),
task_execution_interface_(*this) {
switch (language_) {
// TODO(zhijunfu): currently RayletClient would crash in its constructor if it cannot
// connect to Raylet after a number of retries, this needs to be changed
// so that the worker (java/python .etc) can retrieve and handle the error
// instead of crashing.
auto status = store_client_.Connect(store_socket_);
if (!status.ok()) {
RAY_LOG(ERROR) << "Connecting plasma store failed when trying to construct"
<< " core worker: " << status.message();
throw std::runtime_error(status.message());
}
}

::Language CoreWorker::ToTaskLanguage(WorkerLanguage language) {
switch (language) {
case ray::WorkerLanguage::JAVA:
task_language_ = ::Language::JAVA;
return ::Language::JAVA;
break;
case ray::WorkerLanguage::PYTHON:
task_language_ = ::Language::PYTHON;
return ::Language::PYTHON;
break;
default:
RAY_LOG(FATAL) << "Unsupported worker language: " << static_cast<int>(language_);
RAY_LOG(FATAL) << "invalid language specified: " << static_cast<int>(language);
break;
}
}

Status CoreWorker::Connect() {
// connect to plasma.
RAY_ARROW_RETURN_NOT_OK(store_client_.Connect(store_socket_));

// connect to raylet.
// TODO: currently RayletClient would crash in its constructor if it cannot
// connect to Raylet after a number of retries, this needs to be changed
// so that the worker (java/python .etc) can retrieve and handle the error
// instead of crashing.
raylet_client_ = std::unique_ptr<RayletClient>(
new RayletClient(raylet_socket_, worker_context_.GetWorkerID(),
(worker_type_ == ray::WorkerType::WORKER),
worker_context_.GetCurrentDriverID(), task_language_));
is_initialized_ = true;
return Status::OK();
}

} // namespace ray
29 changes: 14 additions & 15 deletions src/ray/core_worker/core_worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,12 @@ class CoreWorker {
///
/// \param[in] worker_type Type of this worker.
/// \param[in] langauge Language of this worker.
///
/// NOTE(zhijunfu): the constructor would throw if a failure happens.
CoreWorker(const WorkerType worker_type, const WorkerLanguage language,
const std::string &store_socket, const std::string &raylet_socket,
DriverID driver_id = DriverID::Nil());

/// Connect to raylet.
Status Connect();

/// Type of this worker.
enum WorkerType WorkerType() const { return worker_type_; }

Expand All @@ -46,23 +45,26 @@ class CoreWorker {
CoreWorkerTaskExecutionInterface &Execution() { return task_execution_interface_; }

private:
/// Translate from WorkLanguage to Language type (required by raylet client).
///
/// \param[in] language Language for a task.
/// \return Translated task language.
::Language ToTaskLanguage(WorkerLanguage language);

/// Type of this worker.
const enum WorkerType worker_type_;

/// Language of this worker.
const enum WorkerLanguage language_;

/// Language of this worker as specified in flatbuf (used by task spec).
::Language task_language_;

/// Worker context per thread.
WorkerContext worker_context_;

/// Plasma store socket name.
std::string store_socket_;
const std::string store_socket_;

/// raylet socket name.
std::string raylet_socket_;
const std::string raylet_socket_;

/// Worker context.
WorkerContext worker_context_;

/// Plasma store client.
plasma::PlasmaClient store_client_;
Expand All @@ -71,10 +73,7 @@ class CoreWorker {
std::mutex store_client_mutex_;

/// Raylet client.
std::unique_ptr<RayletClient> raylet_client_;

/// Whether this worker has been initialized.
bool is_initialized_;
RayletClient raylet_client_;

/// The `CoreWorkerTaskInterface` instance.
CoreWorkerTaskInterface task_interface_;
Expand Down
22 changes: 9 additions & 13 deletions src/ray/core_worker/core_worker_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,6 @@ class CoreWorkerTest : public ::testing::Test {
raylet_store_socket_names_[0], raylet_socket_names_[0],
DriverID::FromRandom());

RAY_CHECK_OK(driver.Connect());

// Test pass by value.
{
uint8_t array1[] = {1, 2, 3, 4, 5, 6, 7, 8};
Expand Down Expand Up @@ -187,7 +185,6 @@ class CoreWorkerTest : public ::testing::Test {
CoreWorker driver(WorkerType::DRIVER, WorkerLanguage::PYTHON,
raylet_store_socket_names_[0], raylet_socket_names_[0],
DriverID::FromRandom());
RAY_CHECK_OK(driver.Connect());

std::unique_ptr<ActorHandle> actor_handle;

Expand Down Expand Up @@ -277,13 +274,6 @@ TEST_F(ZeroNodeTest, TestTaskArg) {
ASSERT_EQ(*data, *buffer);
}

TEST_F(ZeroNodeTest, TestAttributeGetters) {
CoreWorker core_worker(WorkerType::DRIVER, WorkerLanguage::PYTHON, "", "",
DriverID::FromRandom());
ASSERT_EQ(core_worker.WorkerType(), WorkerType::DRIVER);
ASSERT_EQ(core_worker.Language(), WorkerLanguage::PYTHON);
}

TEST_F(ZeroNodeTest, TestWorkerContext) {
auto driver_id = DriverID::FromRandom();

Expand Down Expand Up @@ -313,7 +303,6 @@ TEST_F(SingleNodeTest, TestObjectInterface) {
CoreWorker core_worker(WorkerType::DRIVER, WorkerLanguage::PYTHON,
raylet_store_socket_names_[0], raylet_socket_names_[0],
DriverID::FromRandom());
RAY_CHECK_OK(core_worker.Connect());

uint8_t array1[] = {1, 2, 3, 4, 5, 6, 7, 8};
uint8_t array2[] = {10, 11, 12, 13, 14, 15};
Expand Down Expand Up @@ -370,12 +359,10 @@ TEST_F(TwoNodeTest, TestObjectInterfaceCrossNodes) {
CoreWorker worker1(WorkerType::DRIVER, WorkerLanguage::PYTHON,
raylet_store_socket_names_[0], raylet_socket_names_[0],
DriverID::FromRandom());
RAY_CHECK_OK(worker1.Connect());

CoreWorker worker2(WorkerType::DRIVER, WorkerLanguage::PYTHON,
raylet_store_socket_names_[1], raylet_socket_names_[1],
DriverID::FromRandom());
RAY_CHECK_OK(worker2.Connect());

uint8_t array1[] = {1, 2, 3, 4, 5, 6, 7, 8};
uint8_t array2[] = {10, 11, 12, 13, 14, 15};
Expand Down Expand Up @@ -456,6 +443,15 @@ TEST_F(TwoNodeTest, TestActorTaskCrossNodes) {
TestActorTask(resources);
}

TEST_F(SingleNodeTest, TestCoreWorkerConstructorFailure) {
try {
CoreWorker core_worker(WorkerType::DRIVER, WorkerLanguage::PYTHON, "",
raylet_socket_names_[0], DriverID::FromRandom());
} catch (const std::exception &e) {
std::cout << "Caught exception when constructing core worker: " << e.what();
}
}

} // namespace ray

int main(int argc, char **argv) {
Expand Down
4 changes: 1 addition & 3 deletions src/ray/core_worker/mock_worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@ class MockWorker {
public:
MockWorker(const std::string &store_socket, const std::string &raylet_socket)
: worker_(WorkerType::WORKER, WorkerLanguage::PYTHON, store_socket, raylet_socket,
DriverID::FromRandom()) {
RAY_CHECK_OK(worker_.Connect());
}
DriverID::FromRandom()) {}

void Run() {
auto executor_func = [this](const RayFunction &ray_function,
Expand Down
Loading