Skip to content

Commit b674c4a

Browse files
zhijunfuraulchen
authored andcommitted
[Core Worker] implement ObjectInterface and add test framework (#4899)
1 parent 89722ff commit b674c4a

14 files changed

+611
-27
lines changed

.travis.yml

+3
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,9 @@ install:
148148
- ./ci/suppress_output bazel build //:stats_test -c opt
149149
- ./bazel-bin/stats_test
150150

151+
# core worker test.
152+
- ./ci/suppress_output bash src/ray/test/run_core_worker_tests.sh
153+
151154
# Raylet tests.
152155
- ./ci/suppress_output bash src/ray/test/run_object_manager_tests.sh
153156
- ./ci/suppress_output bazel test --build_tests_only --test_lang_filters=cc //:all

BUILD.bazel

+6-1
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ cc_library(
7777
"src/ray/raylet/mock_gcs_client.cc",
7878
"src/ray/raylet/monitor_main.cc",
7979
"src/ray/raylet/*_test.cc",
80+
"src/ray/raylet/main.cc",
8081
],
8182
),
8283
hdrs = glob([
@@ -122,15 +123,18 @@ cc_library(
122123
deps = [
123124
":ray_common",
124125
":ray_util",
126+
":raylet_lib",
125127
],
126128
)
127129

128-
cc_test(
130+
# This test is run by src/ray/test/run_core_worker_tests.sh
131+
cc_binary(
129132
name = "core_worker_test",
130133
srcs = ["src/ray/core_worker/core_worker_test.cc"],
131134
copts = COPTS,
132135
deps = [
133136
":core_worker_lib",
137+
":gcs",
134138
"@com_google_googletest//:gtest_main",
135139
],
136140
)
@@ -320,6 +324,7 @@ cc_library(
320324
":node_manager_fbs",
321325
":ray_util",
322326
"@boost//:asio",
327+
"@plasma//:plasma_client",
323328
],
324329
)
325330

src/ray/common/buffer.h

+21-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,11 @@
33

44
#include <cstdint>
55
#include <cstdio>
6+
#include "plasma/client.h"
7+
8+
namespace arrow {
9+
class Buffer;
10+
}
611

712
namespace ray {
813

@@ -15,7 +20,7 @@ class Buffer {
1520
/// Size of this buffer.
1621
virtual size_t Size() const = 0;
1722

18-
virtual ~Buffer() {}
23+
virtual ~Buffer(){};
1924

2025
bool operator==(const Buffer &rhs) const {
2126
return this->Data() == rhs.Data() && this->Size() == rhs.Size();
@@ -40,6 +45,21 @@ class LocalMemoryBuffer : public Buffer {
4045
size_t size_;
4146
};
4247

48+
/// Represents a byte buffer for plasma object.
49+
class PlasmaBuffer : public Buffer {
50+
public:
51+
PlasmaBuffer(std::shared_ptr<arrow::Buffer> buffer) : buffer_(buffer) {}
52+
53+
uint8_t *Data() const override { return const_cast<uint8_t *>(buffer_->data()); }
54+
55+
size_t Size() const override { return buffer_->size(); }
56+
57+
private:
58+
/// shared_ptr to arrow buffer which can potentially hold a reference
59+
/// for the object (when it's a plasma::PlasmaBuffer).
60+
std::shared_ptr<arrow::Buffer> buffer_;
61+
};
62+
4363
} // namespace ray
4464

4565
#endif // RAY_COMMON_BUFFER_H

src/ray/core_worker/common.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,13 @@ class TaskArg {
4545
bool IsPassedByReference() const { return id_ != nullptr; }
4646

4747
/// Get the reference object ID.
48-
ObjectID &GetReference() {
48+
const ObjectID &GetReference() const {
4949
RAY_CHECK(id_ != nullptr) << "This argument isn't passed by reference.";
5050
return *id_;
5151
}
5252

5353
/// Get the value.
54-
std::shared_ptr<Buffer> GetValue() {
54+
std::shared_ptr<Buffer> GetValue() const {
5555
RAY_CHECK(data_ != nullptr) << "This argument isn't passed by value.";
5656
return data_;
5757
}

src/ray/core_worker/context.cc

+81
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
2+
#include "context.h"
3+
4+
namespace ray {
5+
6+
/// per-thread context for core worker.
7+
struct WorkerThreadContext {
8+
WorkerThreadContext()
9+
: current_task_id(TaskID::FromRandom()), task_index(0), put_index(0) {}
10+
11+
int GetNextTaskIndex() { return ++task_index; }
12+
13+
int GetNextPutIndex() { return ++put_index; }
14+
15+
const TaskID &GetCurrentTaskID() const { return current_task_id; }
16+
17+
void SetCurrentTask(const TaskID &task_id) {
18+
current_task_id = task_id;
19+
task_index = 0;
20+
put_index = 0;
21+
}
22+
23+
void SetCurrentTask(const raylet::TaskSpecification &spec) {
24+
SetCurrentTask(spec.TaskId());
25+
}
26+
27+
private:
28+
/// The task ID for current task.
29+
TaskID current_task_id;
30+
31+
/// Number of tasks that have been submitted from current task.
32+
int task_index;
33+
34+
/// Number of objects that have been put from current task.
35+
int put_index;
36+
};
37+
38+
thread_local std::unique_ptr<WorkerThreadContext> WorkerContext::thread_context_ =
39+
nullptr;
40+
41+
WorkerContext::WorkerContext(WorkerType worker_type, const DriverID &driver_id)
42+
: worker_type(worker_type),
43+
worker_id(worker_type == WorkerType::DRIVER
44+
? ClientID::FromBinary(driver_id.Binary())
45+
: ClientID::FromRandom()),
46+
current_driver_id(worker_type == WorkerType::DRIVER ? driver_id : DriverID::Nil()) {
47+
// For worker main thread which initializes the WorkerContext,
48+
// set task_id according to whether current worker is a driver.
49+
// (For other threads it's set to randmom ID via GetThreadContext).
50+
GetThreadContext().SetCurrentTask(
51+
(worker_type == WorkerType::DRIVER) ? TaskID::FromRandom() : TaskID::Nil());
52+
}
53+
54+
const WorkerType WorkerContext::GetWorkerType() const { return worker_type; }
55+
56+
const ClientID &WorkerContext::GetWorkerID() const { return worker_id; }
57+
58+
int WorkerContext::GetNextTaskIndex() { return GetThreadContext().GetNextTaskIndex(); }
59+
60+
int WorkerContext::GetNextPutIndex() { return GetThreadContext().GetNextPutIndex(); }
61+
62+
const DriverID &WorkerContext::GetCurrentDriverID() const { return current_driver_id; }
63+
64+
const TaskID &WorkerContext::GetCurrentTaskID() const {
65+
return GetThreadContext().GetCurrentTaskID();
66+
}
67+
68+
void WorkerContext::SetCurrentTask(const raylet::TaskSpecification &spec) {
69+
current_driver_id = spec.DriverId();
70+
GetThreadContext().SetCurrentTask(spec);
71+
}
72+
73+
WorkerThreadContext &WorkerContext::GetThreadContext() {
74+
if (thread_context_ == nullptr) {
75+
thread_context_ = std::unique_ptr<WorkerThreadContext>(new WorkerThreadContext());
76+
}
77+
78+
return *thread_context_;
79+
}
80+
81+
} // namespace ray

src/ray/core_worker/context.h

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
#ifndef RAY_CORE_WORKER_CONTEXT_H
2+
#define RAY_CORE_WORKER_CONTEXT_H
3+
4+
#include "common.h"
5+
#include "ray/raylet/task_spec.h"
6+
7+
namespace ray {
8+
9+
struct WorkerThreadContext;
10+
11+
class WorkerContext {
12+
public:
13+
WorkerContext(WorkerType worker_type, const DriverID &driver_id);
14+
15+
const WorkerType GetWorkerType() const;
16+
17+
const ClientID &GetWorkerID() const;
18+
19+
const DriverID &GetCurrentDriverID() const;
20+
21+
const TaskID &GetCurrentTaskID() const;
22+
23+
void SetCurrentTask(const raylet::TaskSpecification &spec);
24+
25+
int GetNextTaskIndex();
26+
27+
int GetNextPutIndex();
28+
29+
private:
30+
/// Type of the worker.
31+
const WorkerType worker_type;
32+
33+
/// ID for this worker.
34+
const ClientID worker_id;
35+
36+
/// Driver ID for this worker.
37+
DriverID current_driver_id;
38+
39+
private:
40+
static WorkerThreadContext &GetThreadContext();
41+
42+
/// Per-thread worker context.
43+
static thread_local std::unique_ptr<WorkerThreadContext> thread_context_;
44+
};
45+
46+
} // namespace ray
47+
48+
#endif // RAY_CORE_WORKER_CONTEXT_H

src/ray/core_worker/core_worker.cc

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#include "core_worker.h"
2+
#include "context.h"
3+
4+
namespace ray {
5+
6+
CoreWorker::CoreWorker(const enum WorkerType worker_type, const enum Language language,
7+
const std::string &store_socket, const std::string &raylet_socket,
8+
DriverID driver_id)
9+
: worker_type_(worker_type),
10+
language_(language),
11+
worker_context_(worker_type, driver_id),
12+
store_socket_(store_socket),
13+
raylet_socket_(raylet_socket),
14+
task_interface_(*this),
15+
object_interface_(*this),
16+
task_execution_interface_(*this) {}
17+
18+
Status CoreWorker::Connect() {
19+
// connect to plasma.
20+
RAY_ARROW_RETURN_NOT_OK(store_client_.Connect(store_socket_));
21+
22+
// connect to raylet.
23+
::Language lang = ::Language::PYTHON;
24+
if (language_ == ray::Language::JAVA) {
25+
lang = ::Language::JAVA;
26+
}
27+
28+
// TODO: currently RayletClient would crash in its constructor if it cannot
29+
// connect to Raylet after a number of retries, this needs to be changed
30+
// so that the worker (java/python .etc) can retrieve and handle the error
31+
// instead of crashing.
32+
raylet_client_ = std::unique_ptr<RayletClient>(
33+
new RayletClient(raylet_socket_, worker_context_.GetWorkerID(),
34+
(worker_type_ == ray::WorkerType::WORKER),
35+
worker_context_.GetCurrentDriverID(), lang));
36+
return Status::OK();
37+
}
38+
39+
} // namespace ray

src/ray/core_worker/core_worker.h

+26-8
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
#define RAY_CORE_WORKER_CORE_WORKER_H
33

44
#include "common.h"
5+
#include "context.h"
56
#include "object_interface.h"
67
#include "ray/common/buffer.h"
8+
#include "ray/raylet/raylet_client.h"
79
#include "task_execution.h"
810
#include "task_interface.h"
911

@@ -18,15 +20,12 @@ class CoreWorker {
1820
///
1921
/// \param[in] worker_type Type of this worker.
2022
/// \param[in] langauge Language of this worker.
21-
CoreWorker(const WorkerType worker_type, const Language language)
22-
: worker_type_(worker_type),
23-
language_(language),
24-
task_interface_(*this),
25-
object_interface_(*this),
26-
task_execution_interface_(*this) {}
23+
CoreWorker(const WorkerType worker_type, const Language language,
24+
const std::string &store_socket, const std::string &raylet_socket,
25+
DriverID driver_id = DriverID::Nil());
2726

28-
/// Connect this worker to Raylet.
29-
Status Connect() { return Status::OK(); }
27+
/// Connect to raylet.
28+
Status Connect();
3029

3130
/// Type of this worker.
3231
enum WorkerType WorkerType() const { return worker_type_; }
@@ -53,6 +52,21 @@ class CoreWorker {
5352
/// Language of this worker.
5453
const enum Language language_;
5554

55+
/// Worker context per thread.
56+
WorkerContext worker_context_;
57+
58+
/// Plasma store socket name.
59+
std::string store_socket_;
60+
61+
/// raylet socket name.
62+
std::string raylet_socket_;
63+
64+
/// Plasma store client.
65+
plasma::PlasmaClient store_client_;
66+
67+
/// Raylet client.
68+
std::unique_ptr<RayletClient> raylet_client_;
69+
5670
/// The `CoreWorkerTaskInterface` instance.
5771
CoreWorkerTaskInterface task_interface_;
5872

@@ -61,6 +75,10 @@ class CoreWorker {
6175

6276
/// The `CoreWorkerTaskExecutionInterface` instance.
6377
CoreWorkerTaskExecutionInterface task_execution_interface_;
78+
79+
friend class CoreWorkerTaskInterface;
80+
friend class CoreWorkerObjectInterface;
81+
friend class CoreWorkerTaskExecutionInterface;
6482
};
6583

6684
} // namespace ray

0 commit comments

Comments
 (0)