diff --git a/src/ray/gcs/BUILD.bazel b/src/ray/gcs/BUILD.bazel index d81c266b0469..b4d88585bb96 100644 --- a/src/ray/gcs/BUILD.bazel +++ b/src/ray/gcs/BUILD.bazel @@ -531,6 +531,7 @@ ray_cc_library( "//src/ray/raylet_rpc_client:raylet_client_pool", "//src/ray/rpc:grpc_server", "//src/ray/rpc:metrics_agent_client", + "//src/ray/rpc/authentication:authentication_token_loader", "//src/ray/util:counter_map", "//src/ray/util:exponential_backoff", "//src/ray/util:network_util", diff --git a/src/ray/gcs/gcs_server.cc b/src/ray/gcs/gcs_server.cc index 337e49ed3f67..ddc1d697de6a 100644 --- a/src/ray/gcs/gcs_server.cc +++ b/src/ray/gcs/gcs_server.cc @@ -39,6 +39,7 @@ #include "ray/observability/metric_constants.h" #include "ray/pubsub/publisher.h" #include "ray/raylet_rpc_client/raylet_client.h" +#include "ray/rpc/authentication/authentication_token_loader.h" #include "ray/stats/stats.h" #include "ray/util/network_util.h" @@ -615,7 +616,8 @@ void GcsServer::InitRaySyncer(const GcsInitData &gcs_init_data) { syncer::MessageType::RESOURCE_VIEW, nullptr, gcs_resource_manager_.get()); ray_syncer_->Register( syncer::MessageType::COMMANDS, nullptr, gcs_resource_manager_.get()); - rpc_server_.RegisterService(std::make_unique(*ray_syncer_)); + rpc_server_.RegisterService(std::make_unique( + *ray_syncer_, ray::rpc::AuthenticationTokenLoader::instance().GetToken())); } void GcsServer::InitFunctionManager() { diff --git a/src/ray/ray_syncer/BUILD.bazel b/src/ray/ray_syncer/BUILD.bazel index 680ea322d920..c7cd8ca2a0b0 100644 --- a/src/ray/ray_syncer/BUILD.bazel +++ b/src/ray/ray_syncer/BUILD.bazel @@ -19,8 +19,11 @@ ray_cc_library( ], deps = [ "//src/ray/common:asio", + "//src/ray/common:constants", "//src/ray/common:id", "//src/ray/protobuf:ray_syncer_cc_grpc", + "//src/ray/rpc/authentication:authentication_token", + "//src/ray/rpc/authentication:authentication_token_loader", "@com_github_grpc_grpc//:grpc++", "@com_google_absl//absl/container:flat_hash_map", ], diff --git a/src/ray/ray_syncer/ray_syncer.cc b/src/ray/ray_syncer/ray_syncer.cc index 49c0e6894313..837b314ed104 100644 --- a/src/ray/ray_syncer/ray_syncer.cc +++ b/src/ray/ray_syncer/ray_syncer.cc @@ -244,9 +244,17 @@ ServerBidiReactor *RaySyncerService::StartSync(grpc::CallbackServerContext *cont } RAY_LOG(INFO).WithField(NodeID::FromBinary(node_id)) << "Connection is broken."; syncer_.node_state_->RemoveNode(node_id); - }); + }, + /*auth_token=*/auth_token_); RAY_LOG(DEBUG).WithField(NodeID::FromBinary(reactor->GetRemoteNodeID())) << "Get connection"; + + // If the reactor has already called Finish() (e.g., due to authentication failure), + // skip registration. The reactor will clean itself up via OnDone(). + if (reactor->IsFinished()) { + return reactor; + } + // Disconnect exiting connection if there is any. // This can happen when there is transient network error // and the client reconnects. diff --git a/src/ray/ray_syncer/ray_syncer.h b/src/ray/ray_syncer/ray_syncer.h index 37cebb4f03ad..b842ed1f749b 100644 --- a/src/ray/ray_syncer/ray_syncer.h +++ b/src/ray/ray_syncer/ray_syncer.h @@ -19,6 +19,7 @@ #include #include +#include #include #include "absl/container/flat_hash_map.h" @@ -28,6 +29,7 @@ #include "ray/common/asio/periodical_runner.h" #include "ray/common/id.h" #include "ray/ray_syncer/common.h" +#include "ray/rpc/authentication/authentication_token.h" #include "src/ray/protobuf/ray_syncer.grpc.pb.h" namespace ray::syncer { @@ -197,7 +199,10 @@ class RaySyncer { /// like tree-based one. class RaySyncerService : public ray::rpc::syncer::RaySyncer::CallbackService { public: - explicit RaySyncerService(RaySyncer &syncer) : syncer_(syncer) {} + explicit RaySyncerService( + RaySyncer &syncer, + std::optional auth_token = std::nullopt) + : syncer_(syncer), auth_token_(std::move(auth_token)) {} grpc::ServerBidiReactor *StartSync( grpc::CallbackServerContext *context) override; @@ -205,6 +210,9 @@ class RaySyncerService : public ray::rpc::syncer::RaySyncer::CallbackService { private: // The ray syncer this RPC wrappers of. RaySyncer &syncer_; + // Authentication token for validation, will be empty if token authentication is + // disabled + std::optional auth_token_; }; } // namespace ray::syncer diff --git a/src/ray/ray_syncer/ray_syncer_client.cc b/src/ray/ray_syncer/ray_syncer_client.cc index 745262176ff0..2f91ce3b1560 100644 --- a/src/ray/ray_syncer/ray_syncer_client.cc +++ b/src/ray/ray_syncer/ray_syncer_client.cc @@ -18,6 +18,8 @@ #include #include +#include "ray/rpc/authentication/authentication_token_loader.h" + namespace ray::syncer { RayClientBidiReactor::RayClientBidiReactor( @@ -32,6 +34,11 @@ RayClientBidiReactor::RayClientBidiReactor( cleanup_cb_(std::move(cleanup_cb)), stub_(std::move(stub)) { client_context_.AddMetadata("node_id", NodeID::FromBinary(local_node_id).Hex()); + // Add authentication token if token authentication is enabled + auto auth_token = ray::rpc::AuthenticationTokenLoader::instance().GetToken(); + if (auth_token.has_value() && !auth_token->empty()) { + auth_token->SetMetadata(client_context_); + } stub_->async()->StartSync(&client_context_, this); // Prevent this call from being terminated. // Check https://github.com/grpc/proposal/blob/master/L67-cpp-callback-api.md diff --git a/src/ray/ray_syncer/ray_syncer_server.cc b/src/ray/ray_syncer/ray_syncer_server.cc index cefb1c52a2fa..b254a3b2b475 100644 --- a/src/ray/ray_syncer/ray_syncer_server.cc +++ b/src/ray/ray_syncer/ray_syncer_server.cc @@ -17,6 +17,8 @@ #include #include +#include "ray/common/constants.h" + namespace ray::syncer { namespace { @@ -35,13 +37,39 @@ RayServerBidiReactor::RayServerBidiReactor( instrumented_io_context &io_context, const std::string &local_node_id, std::function)> message_processor, - std::function cleanup_cb) + std::function cleanup_cb, + const std::optional &auth_token) : RaySyncerBidiReactorBase( io_context, GetNodeIDFromServerContext(server_context), std::move(message_processor)), cleanup_cb_(std::move(cleanup_cb)), - server_context_(server_context) { + server_context_(server_context), + auth_token_(auth_token) { + if (auth_token_.has_value() && !auth_token_->empty()) { + // Validate authentication token + const auto &metadata = server_context->client_metadata(); + auto it = metadata.find(kAuthTokenKey); + if (it == metadata.end()) { + RAY_LOG(WARNING) << "Missing authorization header in syncer connection from node " + << NodeID::FromBinary(GetRemoteNodeID()); + Finish(grpc::Status(grpc::StatusCode::UNAUTHENTICATED, + "Missing authorization header")); + return; + } + + const std::string_view header(it->second.data(), it->second.length()); + ray::rpc::AuthenticationToken provided_token = + ray::rpc::AuthenticationToken::FromMetadata(header); + + if (!auth_token_->Equals(provided_token)) { + RAY_LOG(WARNING) << "Invalid bearer token in syncer connection from node " + << NodeID::FromBinary(GetRemoteNodeID()); + Finish(grpc::Status(grpc::StatusCode::UNAUTHENTICATED, "Invalid bearer token")); + return; + } + } + // Send the local node id to the remote server_context_->AddInitialMetadata("node_id", NodeID::FromBinary(local_node_id).Hex()); StartSendInitialMetadata(); diff --git a/src/ray/ray_syncer/ray_syncer_server.h b/src/ray/ray_syncer/ray_syncer_server.h index 531720cddd19..6db427958667 100644 --- a/src/ray/ray_syncer/ray_syncer_server.h +++ b/src/ray/ray_syncer/ray_syncer_server.h @@ -16,11 +16,14 @@ #include +#include +#include #include #include "ray/ray_syncer/common.h" #include "ray/ray_syncer/ray_syncer_bidi_reactor.h" #include "ray/ray_syncer/ray_syncer_bidi_reactor_base.h" +#include "ray/rpc/authentication/authentication_token.h" namespace ray::syncer { @@ -35,20 +38,36 @@ class RayServerBidiReactor : public RaySyncerBidiReactorBase instrumented_io_context &io_context, const std::string &local_node_id, std::function)> message_processor, - std::function cleanup_cb); + std::function cleanup_cb, + const std::optional &auth_token); ~RayServerBidiReactor() override = default; + bool IsFinished() const { return finished_.load(); } + private: void DoDisconnect() override; void OnCancel() override; void OnDone() override; + void Finish(grpc::Status status) { + finished_.store(true); + ServerBidiReactor::Finish(status); + } + /// Cleanup callback when the call ends. const std::function cleanup_cb_; /// grpc callback context grpc::CallbackServerContext *server_context_; + + /// Authentication token for validation, will be empty if token authentication is + /// disabled + std::optional auth_token_; + + /// Track if Finish() has been called to avoid using a reactor that is terminating + std::atomic finished_{false}; + FRIEND_TEST(SyncerReactorTest, TestReactorFailure); }; diff --git a/src/ray/ray_syncer/tests/BUILD.bazel b/src/ray/ray_syncer/tests/BUILD.bazel index 435bc4366ed4..9d06b8c35e03 100644 --- a/src/ray/ray_syncer/tests/BUILD.bazel +++ b/src/ray/ray_syncer/tests/BUILD.bazel @@ -13,6 +13,7 @@ ray_cc_test( "//src/mock/ray/ray_syncer:mock_ray_syncer", "//src/ray/ray_syncer", "//src/ray/rpc:grpc_server", + "//src/ray/rpc/authentication:authentication_token", "//src/ray/util:network_util", "//src/ray/util:path_utils", "//src/ray/util:raii", diff --git a/src/ray/ray_syncer/tests/ray_syncer_test.cc b/src/ray/ray_syncer/tests/ray_syncer_test.cc index 869c930bacb6..be2e14ae53f7 100644 --- a/src/ray/ray_syncer/tests/ray_syncer_test.cc +++ b/src/ray/ray_syncer/tests/ray_syncer_test.cc @@ -37,6 +37,7 @@ #include "ray/ray_syncer/ray_syncer.h" #include "ray/ray_syncer/ray_syncer_client.h" #include "ray/ray_syncer/ray_syncer_server.h" +#include "ray/rpc/authentication/authentication_token.h" #include "ray/rpc/grpc_server.h" #include "ray/util/network_util.h" #include "ray/util/path_utils.h" @@ -840,8 +841,12 @@ struct MockRaySyncerService : public ray::rpc::syncer::RaySyncer::CallbackServic io_context(_io_context) {} grpc::ServerBidiReactor *StartSync( grpc::CallbackServerContext *context) override { - reactor = new RayServerBidiReactor( - context, io_context, node_id.Binary(), message_processor, cleanup_cb); + reactor = new RayServerBidiReactor(context, + io_context, + node_id.Binary(), + message_processor, + cleanup_cb, + std::nullopt); return reactor; } @@ -983,6 +988,200 @@ TEST_F(SyncerReactorTest, TestReactorFailure) { ASSERT_EQ(true, c_cleanup.second); } +// Authentication tests +class SyncerAuthenticationTest : public ::testing::Test { + protected: + void SetUp() override { + // Clear any existing environment variables and reset state + unsetenv("RAY_AUTH_TOKEN"); + ray::rpc::AuthenticationTokenLoader::instance().ResetCache(); + RayConfig::instance().auth_mode() = "disabled"; + } + + void TearDown() override { + unsetenv("RAY_AUTH_TOKEN"); + ray::rpc::AuthenticationTokenLoader::instance().ResetCache(); + RayConfig::instance().auth_mode() = "disabled"; + } + + struct AuthenticatedSyncerServerTest { + std::string server_port; + instrumented_io_context io_context; + boost::asio::executor_work_guard work_guard; + std::unique_ptr thread; + std::unique_ptr syncer; + std::unique_ptr service; + std::unique_ptr server; + + AuthenticatedSyncerServerTest(const std::string &port, const std::string &token) + : server_port(port), work_guard(io_context.get_executor()) { + // Setup syncer and grpc server + syncer = std::make_unique(io_context, NodeID::FromRandom().Binary()); + thread = std::make_unique([this] { io_context.run(); }); + + // Create service with authentication token + service = std::make_unique( + *syncer, + token.empty() ? std::nullopt + : std::make_optional(ray::rpc::AuthenticationToken(token))); + + auto server_address = BuildAddress("0.0.0.0", port); + grpc::ServerBuilder builder; + builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); + builder.RegisterService(service.get()); + server = builder.BuildAndStart(); + } + + ~AuthenticatedSyncerServerTest() { + server->Shutdown(); + server->Wait(); + work_guard.reset(); + io_context.stop(); + thread->join(); + } + }; + + std::unique_ptr CreateAuthenticatedServer( + const std::string &port, const std::string &token) { + return std::make_unique(port, token); + } + + // Helper struct to manage client io_context and syncer + struct ClientSyncer { + instrumented_io_context io_context; + boost::asio::executor_work_guard work_guard; + std::thread thread; + std::unique_ptr syncer; + std::string remote_node_id; + + ClientSyncer() + : work_guard(boost::asio::make_work_guard(io_context.get_executor())), + thread([this]() { io_context.run(); }) { + syncer = std::make_unique(io_context, NodeID::FromRandom().Binary()); + remote_node_id = NodeID::FromRandom().Binary(); + } + + ~ClientSyncer() { + if (syncer) { + syncer->Disconnect(remote_node_id); + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + syncer.reset(); + } + work_guard.reset(); + io_context.stop(); + thread.join(); + } + + void Connect(const std::shared_ptr &channel) { + syncer->Connect(remote_node_id, channel); + } + }; +}; + +TEST_F(SyncerAuthenticationTest, MatchingTokens) { + // Test that connections succeed when client and server use the same token + const std::string test_token = "matching-test-token-12345"; + + // Set client token via environment variable + setenv("RAY_AUTH_TOKEN", test_token.c_str(), 1); + // Enable token authentication + RayConfig::instance().auth_mode() = "token"; + ray::rpc::AuthenticationTokenLoader::instance().ResetCache(); + + // Create authenticated server + auto server = CreateAuthenticatedServer("37892", test_token); + + // Create client with separate io_context + ClientSyncer client; + auto channel = grpc::CreateChannel(BuildAddress("0.0.0.0", "37892"), + grpc::InsecureChannelCredentials()); + + // Should connect successfully with matching token + client.Connect(channel); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + // Verify connection is established + ASSERT_GT(client.syncer->GetAllConnectedNodeIDs().size(), 0); +} + +TEST_F(SyncerAuthenticationTest, MismatchedTokens) { + // Test that connections fail when client and server use different tokens + const std::string server_token = "server-token-12345"; + const std::string client_token = "different-client-token"; + + // Set client token via environment variable + setenv("RAY_AUTH_TOKEN", client_token.c_str(), 1); + // Enable token authentication + RayConfig::instance().auth_mode() = "token"; + ray::rpc::AuthenticationTokenLoader::instance().ResetCache(); + + // Create authenticated server with different token + auto server = CreateAuthenticatedServer("37893", server_token); + + // Create client with separate io_context + ClientSyncer client; + auto channel = grpc::CreateChannel(BuildAddress("0.0.0.0", "37893"), + grpc::InsecureChannelCredentials()); + + // Should fail to connect with mismatched token + client.Connect(channel); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + // Verify connection fails - no connected nodes + ASSERT_EQ(client.syncer->GetAllConnectedNodeIDs().size(), 0); +} + +TEST_F(SyncerAuthenticationTest, ServerHasTokenClientDoesNot) { + // Test that connections fail when server requires token but client doesn't provide it + const std::string server_token = "server-token-12345"; + + // Client has no token - auth mode is disabled (default from SetUp) + unsetenv("RAY_AUTH_TOKEN"); + ray::rpc::AuthenticationTokenLoader::instance().ResetCache(); + + // Create authenticated server + auto server = CreateAuthenticatedServer("37895", server_token); + + // Create client with separate io_context + ClientSyncer client; + auto channel = grpc::CreateChannel(BuildAddress("0.0.0.0", "37895"), + grpc::InsecureChannelCredentials()); + + // Should fail to connect without token + client.Connect(channel); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + // Verify connection fails - no connected nodes + ASSERT_EQ(client.syncer->GetAllConnectedNodeIDs().size(), 0); +} + +TEST_F(SyncerAuthenticationTest, ClientHasTokenServerDoesNotRequire) { + // Test that connections succeed when client has token but server doesn't require it + const std::string server_token = ""; + const std::string client_token = "different-client-token"; + + // Set client token + setenv("RAY_AUTH_TOKEN", client_token.c_str(), 1); + // Enable token authentication + RayConfig::instance().auth_mode() = "token"; + ray::rpc::AuthenticationTokenLoader::instance().ResetCache(); + + // Create server without authentication (empty token) + auto server = CreateAuthenticatedServer("37896", server_token); + + // Create client with separate io_context + ClientSyncer client; + auto channel = grpc::CreateChannel(BuildAddress("0.0.0.0", "37896"), + grpc::InsecureChannelCredentials()); + + // Should connect successfully - server accepts any client when auth is not required + client.Connect(channel); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + // Verify connection is established + ASSERT_GT(client.syncer->GetAllConnectedNodeIDs().size(), 0); +} + } // namespace syncer } // namespace ray diff --git a/src/ray/raylet/BUILD.bazel b/src/ray/raylet/BUILD.bazel index 0db9a2516d0b..3090f94aa001 100644 --- a/src/ray/raylet/BUILD.bazel +++ b/src/ray/raylet/BUILD.bazel @@ -260,6 +260,7 @@ ray_cc_library( "//src/ray/raylet/scheduling:scheduler", "//src/ray/rpc:node_manager_server", "//src/ray/rpc:rpc_callback_types", + "//src/ray/rpc/authentication:authentication_token_loader", "//src/ray/stats:stats_lib", "//src/ray/util:cmd_line_utils", "//src/ray/util:container_util", diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 9e2a45641ba8..a916676b9227 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -50,6 +50,7 @@ #include "ray/raylet/worker_killing_policy_group_by_owner.h" #include "ray/raylet/worker_pool.h" #include "ray/raylet_ipc_client/client_connection.h" +#include "ray/rpc/authentication/authentication_token_loader.h" #include "ray/stats/metric_defs.h" #include "ray/util/cmd_line_utils.h" #include "ray/util/event.h" @@ -254,8 +255,9 @@ NodeManager::NodeManager( // Run the node manager rpc server. node_manager_server_.RegisterService( std::make_unique(io_service, *this), false); - node_manager_server_.RegisterService( - std::make_unique(ray_syncer_)); + // Pass auth token from the RPC server to the syncer service + node_manager_server_.RegisterService(std::make_unique( + ray_syncer_, ray::rpc::AuthenticationTokenLoader::instance().GetToken())); node_manager_server_.Run(); // GCS will check the health of the service named with the node id. // Fail to setup this will lead to the health check failure.