Skip to content

Commit 3b631ef

Browse files
sampan-s-nayaksampanedoakes
authored
[Core] Support token auth in ray_syncer (#58176)
This PR adds support for token-based authentication in the Ray bi-directional syncer, for both client and server sides. It also includes tests to verify the functionality. --------- Signed-off-by: sampan <[email protected]> Signed-off-by: Edward Oakes <[email protected]> Co-authored-by: sampan <[email protected]> Co-authored-by: Edward Oakes <[email protected]>
1 parent a23f6ca commit 3b631ef

File tree

12 files changed

+289
-10
lines changed

12 files changed

+289
-10
lines changed

src/ray/gcs/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,7 @@ ray_cc_library(
531531
"//src/ray/raylet_rpc_client:raylet_client_pool",
532532
"//src/ray/rpc:grpc_server",
533533
"//src/ray/rpc:metrics_agent_client",
534+
"//src/ray/rpc/authentication:authentication_token_loader",
534535
"//src/ray/util:counter_map",
535536
"//src/ray/util:exponential_backoff",
536537
"//src/ray/util:network_util",

src/ray/gcs/gcs_server.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
#include "ray/observability/metric_constants.h"
4040
#include "ray/pubsub/publisher.h"
4141
#include "ray/raylet_rpc_client/raylet_client.h"
42+
#include "ray/rpc/authentication/authentication_token_loader.h"
4243
#include "ray/stats/stats.h"
4344
#include "ray/util/network_util.h"
4445

@@ -615,7 +616,8 @@ void GcsServer::InitRaySyncer(const GcsInitData &gcs_init_data) {
615616
syncer::MessageType::RESOURCE_VIEW, nullptr, gcs_resource_manager_.get());
616617
ray_syncer_->Register(
617618
syncer::MessageType::COMMANDS, nullptr, gcs_resource_manager_.get());
618-
rpc_server_.RegisterService(std::make_unique<syncer::RaySyncerService>(*ray_syncer_));
619+
rpc_server_.RegisterService(std::make_unique<syncer::RaySyncerService>(
620+
*ray_syncer_, ray::rpc::AuthenticationTokenLoader::instance().GetToken()));
619621
}
620622

621623
void GcsServer::InitFunctionManager() {

src/ray/ray_syncer/BUILD.bazel

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,11 @@ ray_cc_library(
1919
],
2020
deps = [
2121
"//src/ray/common:asio",
22+
"//src/ray/common:constants",
2223
"//src/ray/common:id",
2324
"//src/ray/protobuf:ray_syncer_cc_grpc",
25+
"//src/ray/rpc/authentication:authentication_token",
26+
"//src/ray/rpc/authentication:authentication_token_loader",
2427
"@com_github_grpc_grpc//:grpc++",
2528
"@com_google_absl//absl/container:flat_hash_map",
2629
],

src/ray/ray_syncer/ray_syncer.cc

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,9 +244,17 @@ ServerBidiReactor *RaySyncerService::StartSync(grpc::CallbackServerContext *cont
244244
}
245245
RAY_LOG(INFO).WithField(NodeID::FromBinary(node_id)) << "Connection is broken.";
246246
syncer_.node_state_->RemoveNode(node_id);
247-
});
247+
},
248+
/*auth_token=*/auth_token_);
248249
RAY_LOG(DEBUG).WithField(NodeID::FromBinary(reactor->GetRemoteNodeID()))
249250
<< "Get connection";
251+
252+
// If the reactor has already called Finish() (e.g., due to authentication failure),
253+
// skip registration. The reactor will clean itself up via OnDone().
254+
if (reactor->IsFinished()) {
255+
return reactor;
256+
}
257+
250258
// Disconnect the existing connection, if there is one.
251259
// This can happen when there is transient network error
252260
// and the client reconnects.

src/ray/ray_syncer/ray_syncer.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
#include <memory>
2121
#include <string>
22+
#include <utility>
2223
#include <vector>
2324

2425
#include "absl/container/flat_hash_map.h"
@@ -28,6 +29,7 @@
2829
#include "ray/common/asio/periodical_runner.h"
2930
#include "ray/common/id.h"
3031
#include "ray/ray_syncer/common.h"
32+
#include "ray/rpc/authentication/authentication_token.h"
3133
#include "src/ray/protobuf/ray_syncer.grpc.pb.h"
3234

3335
namespace ray::syncer {
@@ -197,14 +199,20 @@ class RaySyncer {
197199
/// like tree-based one.
198200
class RaySyncerService : public ray::rpc::syncer::RaySyncer::CallbackService {
199201
public:
200-
explicit RaySyncerService(RaySyncer &syncer) : syncer_(syncer) {}
202+
// Binds the service to `syncer` and stores the optional `auth_token`; when the caller
// omits the token it defaults to std::nullopt.
explicit RaySyncerService(
203+
RaySyncer &syncer,
204+
std::optional<ray::rpc::AuthenticationToken> auth_token = std::nullopt)
205+
: syncer_(syncer), auth_token_(std::move(auth_token)) {}
201206

202207
// gRPC callback entry point for the StartSync bidirectional stream; returns the
// reactor that drives the call. Declared here, defined out of line.
grpc::ServerBidiReactor<RaySyncMessage, RaySyncMessage> *StartSync(
203208
grpc::CallbackServerContext *context) override;
204209

205210
private:
206211
// The RaySyncer instance that this RPC service wraps (held by reference, not owned).
207212
RaySyncer &syncer_;
213+
// Authentication token handed to the constructor for validating connections;
214+
// std::nullopt (the constructor default) when no token was supplied.
215+
std::optional<ray::rpc::AuthenticationToken> auth_token_;
208216
};
209217

210218
} // namespace ray::syncer

src/ray/ray_syncer/ray_syncer_client.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
#include <string>
1919
#include <utility>
2020

21+
#include "ray/rpc/authentication/authentication_token_loader.h"
22+
2123
namespace ray::syncer {
2224

2325
RayClientBidiReactor::RayClientBidiReactor(
@@ -32,6 +34,11 @@ RayClientBidiReactor::RayClientBidiReactor(
3234
cleanup_cb_(std::move(cleanup_cb)),
3335
stub_(std::move(stub)) {
3436
client_context_.AddMetadata("node_id", NodeID::FromBinary(local_node_id).Hex());
37+
// Add authentication token if token authentication is enabled
38+
auto auth_token = ray::rpc::AuthenticationTokenLoader::instance().GetToken();
39+
if (auth_token.has_value() && !auth_token->empty()) {
40+
auth_token->SetMetadata(client_context_);
41+
}
3542
stub_->async()->StartSync(&client_context_, this);
3643
// Prevent this call from being terminated.
3744
// Check https://github.com/grpc/proposal/blob/master/L67-cpp-callback-api.md

src/ray/ray_syncer/ray_syncer_server.cc

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
#include <string>
1818
#include <utility>
1919

20+
#include "ray/common/constants.h"
21+
2022
namespace ray::syncer {
2123

2224
namespace {
@@ -35,13 +37,39 @@ RayServerBidiReactor::RayServerBidiReactor(
3537
instrumented_io_context &io_context,
3638
const std::string &local_node_id,
3739
std::function<void(std::shared_ptr<const RaySyncMessage>)> message_processor,
38-
std::function<void(RaySyncerBidiReactor *, bool)> cleanup_cb)
40+
std::function<void(RaySyncerBidiReactor *, bool)> cleanup_cb,
41+
const std::optional<ray::rpc::AuthenticationToken> &auth_token)
3942
: RaySyncerBidiReactorBase<ServerBidiReactor>(
4043
io_context,
4144
GetNodeIDFromServerContext(server_context),
4245
std::move(message_processor)),
4346
cleanup_cb_(std::move(cleanup_cb)),
44-
server_context_(server_context) {
47+
server_context_(server_context),
48+
auth_token_(auth_token) {
49+
if (auth_token_.has_value() && !auth_token_->empty()) {
50+
// Validate authentication token
51+
const auto &metadata = server_context->client_metadata();
52+
auto it = metadata.find(kAuthTokenKey);
53+
if (it == metadata.end()) {
54+
RAY_LOG(WARNING) << "Missing authorization header in syncer connection from node "
55+
<< NodeID::FromBinary(GetRemoteNodeID());
56+
Finish(grpc::Status(grpc::StatusCode::UNAUTHENTICATED,
57+
"Missing authorization header"));
58+
return;
59+
}
60+
61+
const std::string_view header(it->second.data(), it->second.length());
62+
ray::rpc::AuthenticationToken provided_token =
63+
ray::rpc::AuthenticationToken::FromMetadata(header);
64+
65+
if (!auth_token_->Equals(provided_token)) {
66+
RAY_LOG(WARNING) << "Invalid bearer token in syncer connection from node "
67+
<< NodeID::FromBinary(GetRemoteNodeID());
68+
Finish(grpc::Status(grpc::StatusCode::UNAUTHENTICATED, "Invalid bearer token"));
69+
return;
70+
}
71+
}
72+
4573
// Send the local node id to the remote
4674
server_context_->AddInitialMetadata("node_id", NodeID::FromBinary(local_node_id).Hex());
4775
StartSendInitialMetadata();

src/ray/ray_syncer/ray_syncer_server.h

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,14 @@
1616

1717
#include <gtest/gtest_prod.h>
1818

19+
#include <atomic>
20+
#include <optional>
1921
#include <string>
2022

2123
#include "ray/ray_syncer/common.h"
2224
#include "ray/ray_syncer/ray_syncer_bidi_reactor.h"
2325
#include "ray/ray_syncer/ray_syncer_bidi_reactor_base.h"
26+
#include "ray/rpc/authentication/authentication_token.h"
2427

2528
namespace ray::syncer {
2629

@@ -35,20 +38,36 @@ class RayServerBidiReactor : public RaySyncerBidiReactorBase<ServerBidiReactor>
3538
instrumented_io_context &io_context,
3639
const std::string &local_node_id,
3740
std::function<void(std::shared_ptr<const RaySyncMessage>)> message_processor,
38-
std::function<void(RaySyncerBidiReactor *, bool)> cleanup_cb);
41+
std::function<void(RaySyncerBidiReactor *, bool)> cleanup_cb,
42+
const std::optional<ray::rpc::AuthenticationToken> &auth_token);
3943

4044
~RayServerBidiReactor() override = default;
4145

46+
/// Reports whether Finish() has already been requested on this reactor.
bool IsFinished() const {
  const bool already_finished = finished_.load();
  return already_finished;
}
47+
4248
private:
4349
void DoDisconnect() override;
4450
void OnCancel() override;
4551
void OnDone() override;
4652

53+
/// Records that this reactor is finishing, then forwards `status` to the base
/// ServerBidiReactor::Finish(). After this call IsFinished() returns true.
void Finish(grpc::Status status) {
54+
// Set the flag before delegating so it is already true when the base Finish() runs.
finished_.store(true);
55+
ServerBidiReactor::Finish(status);
56+
}
57+
4758
/// Cleanup callback when the call ends.
4859
const std::function<void(RaySyncerBidiReactor *, bool)> cleanup_cb_;
4960

5061
/// grpc callback context
5162
grpc::CallbackServerContext *server_context_;
63+
64+
/// Authentication token for validation, will be empty if token authentication is
65+
/// disabled
66+
std::optional<ray::rpc::AuthenticationToken> auth_token_;
67+
68+
/// Track if Finish() has been called to avoid using a reactor that is terminating
69+
std::atomic<bool> finished_{false};
70+
5271
FRIEND_TEST(SyncerReactorTest, TestReactorFailure);
5372
};
5473

src/ray/ray_syncer/tests/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ ray_cc_test(
1313
"//src/mock/ray/ray_syncer:mock_ray_syncer",
1414
"//src/ray/ray_syncer",
1515
"//src/ray/rpc:grpc_server",
16+
"//src/ray/rpc/authentication:authentication_token",
1617
"//src/ray/util:network_util",
1718
"//src/ray/util:path_utils",
1819
"//src/ray/util:raii",

0 commit comments

Comments
 (0)