From 021df65da32a9d37709a88a5b065cec6fce2c5e5 Mon Sep 17 00:00:00 2001 From: sampan Date: Thu, 16 Oct 2025 08:32:37 +0000 Subject: [PATCH 01/94] [Core] Authentication for ray core rpc calls - part 1 Signed-off-by: sampan --- src/ray/common/constants.h | 1 + src/ray/common/ray_config_def.h | 7 + src/ray/core_worker/grpc_service.cc | 72 ++++--- src/ray/gcs/grpc_services.cc | 2 +- src/ray/rpc/client_call.h | 26 ++- src/ray/rpc/grpc_server.h | 42 ++-- .../rpc/node_manager/node_manager_server.h | 3 +- src/ray/rpc/object_manager_server.h | 3 +- src/ray/rpc/server_call.h | 75 +++++-- src/ray/rpc/tests/grpc_bench/grpc_bench.cc | 2 +- src/ray/rpc/tests/grpc_server_client_test.cc | 183 +++++++++++++++++- 11 files changed, 341 insertions(+), 75 deletions(-) diff --git a/src/ray/common/constants.h b/src/ray/common/constants.h index aa9d858f2811..4905daedc733 100644 --- a/src/ray/common/constants.h +++ b/src/ray/common/constants.h @@ -42,6 +42,7 @@ constexpr int kRayletStoreErrorExitCode = 100; constexpr char kObjectTablePrefix[] = "ObjectTable"; constexpr char kClusterIdKey[] = "ray_cluster_id"; +constexpr char kAuthTokenKey[] = "ray_auth_token"; constexpr char kWorkerDynamicOptionPlaceholder[] = "RAY_WORKER_DYNAMIC_OPTION_PLACEHOLDER"; diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h index 6e8d21956162..25cd4ddc5487 100644 --- a/src/ray/common/ray_config_def.h +++ b/src/ray/common/ray_config_def.h @@ -35,6 +35,13 @@ RAY_CONFIG(bool, emit_main_service_metrics, true) /// Whether to enable cluster authentication. RAY_CONFIG(bool, enable_cluster_auth, true) +/// Whether to enable token-based authentication for RPC calls. +RAY_CONFIG(bool, enable_token_auth, true) + +/// Authentication token for RPC calls. If empty, token auth is effectively disabled +/// even if enable_token_auth is true. +RAY_CONFIG(std::string, auth_token, "") + /// The interval of periodic event loop stats print. /// -1 means the feature is disabled. 
In this case, stats are available /// in the associated process's log file. diff --git a/src/ray/core_worker/grpc_service.cc b/src/ray/core_worker/grpc_service.cc index a3a550dbaa55..2f1cd9574787 100644 --- a/src/ray/core_worker/grpc_service.cc +++ b/src/ray/core_worker/grpc_service.cc @@ -26,88 +26,100 @@ void CoreWorkerGrpcService::InitServerCallFactories( const ClusterID &cluster_id) { /// TODO(vitsai): Remove this when auth is implemented for node manager. /// Disable gRPC server metrics since it incurs too high cardinality. - RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED( - CoreWorkerService, PushTask, max_active_rpcs_per_handler_, AuthType::NO_AUTH); + RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, + PushTask, + max_active_rpcs_per_handler_, + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, ActorCallArgWaitComplete, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, RayletNotifyGCSRestart, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, GetObjectStatus, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, WaitForActorRefDeleted, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, PubsubLongPolling, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, PubsubCommandBatch, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, 
UpdateObjectLocationBatch, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, GetObjectLocationsOwner, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, ReportGeneratorItemReturns, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); - RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED( - CoreWorkerService, KillActor, max_active_rpcs_per_handler_, AuthType::NO_AUTH); - RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED( - CoreWorkerService, CancelTask, max_active_rpcs_per_handler_, AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); + RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, + KillActor, + max_active_rpcs_per_handler_, + ClusterIdAuthType::NO_AUTH); + RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, + CancelTask, + max_active_rpcs_per_handler_, + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, RemoteCancelTask, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, RegisterMutableObjectReader, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, GetCoreWorkerStats, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); - RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED( - CoreWorkerService, LocalGC, max_active_rpcs_per_handler_, AuthType::NO_AUTH); - RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED( - CoreWorkerService, DeleteObjects, max_active_rpcs_per_handler_, AuthType::NO_AUTH); - RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED( - CoreWorkerService, SpillObjects, max_active_rpcs_per_handler_, 
AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); + RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, + LocalGC, + max_active_rpcs_per_handler_, + ClusterIdAuthType::NO_AUTH); + RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, + DeleteObjects, + max_active_rpcs_per_handler_, + ClusterIdAuthType::NO_AUTH); + RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, + SpillObjects, + max_active_rpcs_per_handler_, + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, RestoreSpilledObjects, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, DeleteSpilledObjects, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, PlasmaObjectReady, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED( - CoreWorkerService, Exit, max_active_rpcs_per_handler_, AuthType::NO_AUTH); + CoreWorkerService, Exit, max_active_rpcs_per_handler_, ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, AssignObjectOwner, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, NumPendingTasks, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); } } // namespace rpc diff --git a/src/ray/gcs/grpc_services.cc b/src/ray/gcs/grpc_services.cc index f1f3c55af3f1..72113b3be474 100644 --- a/src/ray/gcs/grpc_services.cc +++ b/src/ray/gcs/grpc_services.cc @@ -48,7 +48,7 @@ void NodeInfoGrpcService::InitServerCallFactories( RPC_SERVICE_HANDLER_CUSTOM_AUTH(NodeInfoGcsService, GetClusterId, max_active_rpcs_per_handler_, - 
AuthType::EMPTY_AUTH); + ClusterIdAuthType::EMPTY_AUTH); RPC_SERVICE_HANDLER(NodeInfoGcsService, RegisterNode, max_active_rpcs_per_handler_) RPC_SERVICE_HANDLER(NodeInfoGcsService, UnregisterNode, max_active_rpcs_per_handler_) RPC_SERVICE_HANDLER(NodeInfoGcsService, DrainNode, max_active_rpcs_per_handler_) diff --git a/src/ray/rpc/client_call.h b/src/ray/rpc/client_call.h index 1808409c6dc7..ab14e1314fe6 100644 --- a/src/ray/rpc/client_call.h +++ b/src/ray/rpc/client_call.h @@ -27,6 +27,7 @@ #include "absl/synchronization/mutex.h" #include "ray/common/asio/asio_chaos.h" #include "ray/common/asio/instrumented_io_context.h" +#include "ray/common/constants.h" #include "ray/common/grpc_util.h" #include "ray/common/id.h" #include "ray/common/status.h" @@ -69,8 +70,14 @@ class ClientCallImpl : public ClientCall { /// Constructor. /// /// \param[in] callback The callback function to handle the reply. + /// \param[in] cluster_id The cluster ID for authentication. + /// \param[in] auth_token The authentication token (empty = disabled). + /// \param[in] stats_handle Statistics handle for this call. + /// \param[in] record_stats Whether to record statistics. + /// \param[in] timeout_ms The timeout for this call in milliseconds. explicit ClientCallImpl(const ClientCallback &callback, const ClusterID &cluster_id, + const std::string &auth_token, std::shared_ptr stats_handle, bool record_stats, int64_t timeout_ms = -1) @@ -85,6 +92,10 @@ class ClientCallImpl : public ClientCall { if (!cluster_id.IsNil()) { context_.AddMetadata(kClusterIdKey, cluster_id.Hex()); } + // Add authentication token if provided (empty = disabled) + if (!auth_token.empty()) { + context_.AddMetadata(kAuthTokenKey, auth_token); + } } Status GetStatus() override { @@ -213,6 +224,9 @@ class ClientCallManager { int num_threads = 1, int64_t call_timeout_ms = -1) : cluster_id_(cluster_id), + auth_token_(::RayConfig::instance().enable_token_auth() + ? 
::RayConfig::instance().auth_token() + : ""), main_service_(main_service), num_threads_(num_threads), record_stats_(record_stats), @@ -271,8 +285,12 @@ class ClientCallManager { method_timeout_ms = call_timeout_ms_; } - auto call = std::make_shared>( - callback, cluster_id_, std::move(stats_handle), record_stats_, method_timeout_ms); + auto call = std::make_shared>(callback, + cluster_id_, + auth_token_, + std::move(stats_handle), + record_stats_, + method_timeout_ms); // Send request. // Find the next completion queue to wait for response. call->response_reader_ = (stub.*prepare_async_function)( @@ -356,6 +374,10 @@ class ClientCallManager { /// and setting the cluster ID. ClusterID cluster_id_; + /// Cached authentication token for token-based authentication. + /// Empty string means no token authentication. + const std::string auth_token_; + /// The main event loop, to which the callback functions will be posted. instrumented_io_context &main_service_; diff --git a/src/ray/rpc/grpc_server.h b/src/ray/rpc/grpc_server.h index 0727b4d550f3..68edec84b574 100644 --- a/src/ray/rpc/grpc_server.h +++ b/src/ray/rpc/grpc_server.h @@ -30,33 +30,35 @@ namespace ray { namespace rpc { /// \param MAX_ACTIVE_RPCS Maximum number of RPCs to handle at the same time. -1 means no /// limit. -#define _RPC_SERVICE_HANDLER( \ - SERVICE, HANDLER, MAX_ACTIVE_RPCS, AUTH_TYPE, RECORD_METRICS) \ - std::unique_ptr HANDLER##_call_factory( \ - new ServerCallFactoryImpl( \ - service_, \ - &SERVICE::AsyncService::Request##HANDLER, \ - service_handler_, \ - &SERVICE##Handler::Handle##HANDLER, \ - cq, \ - main_service_, \ - #SERVICE ".grpc_server." #HANDLER, \ - AUTH_TYPE == AuthType::NO_AUTH ? 
ClusterID::Nil() : cluster_id, \ - MAX_ACTIVE_RPCS, \ - RECORD_METRICS)); \ +#define _RPC_SERVICE_HANDLER( \ + SERVICE, HANDLER, MAX_ACTIVE_RPCS, AUTH_TYPE, RECORD_METRICS) \ + std::unique_ptr HANDLER##_call_factory( \ + new ServerCallFactoryImpl( \ + service_, \ + &SERVICE::AsyncService::Request##HANDLER, \ + service_handler_, \ + &SERVICE##Handler::Handle##HANDLER, \ + cq, \ + main_service_, \ + #SERVICE ".grpc_server." #HANDLER, \ + AUTH_TYPE == ClusterIdAuthType::NO_AUTH ? ClusterID::Nil() : cluster_id, \ + MAX_ACTIVE_RPCS, \ + RECORD_METRICS)); \ server_call_factories->emplace_back(std::move(HANDLER##_call_factory)); /// Define a RPC service handler with gRPC server metrics enabled. #define RPC_SERVICE_HANDLER(SERVICE, HANDLER, MAX_ACTIVE_RPCS) \ - _RPC_SERVICE_HANDLER(SERVICE, HANDLER, MAX_ACTIVE_RPCS, AuthType::LAZY_AUTH, true) + _RPC_SERVICE_HANDLER( \ + SERVICE, HANDLER, MAX_ACTIVE_RPCS, ClusterIdAuthType::LAZY_AUTH, true) /// Define a RPC service handler with gRPC server metrics disabled. #define RPC_SERVICE_HANDLER_SERVER_METRICS_DISABLED(SERVICE, HANDLER, MAX_ACTIVE_RPCS) \ - _RPC_SERVICE_HANDLER(SERVICE, HANDLER, MAX_ACTIVE_RPCS, AuthType::LAZY_AUTH, false) + _RPC_SERVICE_HANDLER( \ + SERVICE, HANDLER, MAX_ACTIVE_RPCS, ClusterIdAuthType::LAZY_AUTH, false) /// Define a RPC service handler with gRPC server metrics enabled. 
#define RPC_SERVICE_HANDLER_CUSTOM_AUTH(SERVICE, HANDLER, MAX_ACTIVE_RPCS, AUTH_TYPE) \ diff --git a/src/ray/rpc/node_manager/node_manager_server.h b/src/ray/rpc/node_manager/node_manager_server.h index eba08ba9b0af..2ca9713e0f17 100644 --- a/src/ray/rpc/node_manager/node_manager_server.h +++ b/src/ray/rpc/node_manager/node_manager_server.h @@ -29,7 +29,8 @@ class ServerCallFactory; /// TODO(vitsai): Remove this when auth is implemented for node manager #define RAY_NODE_MANAGER_RPC_SERVICE_HANDLER(METHOD) \ - RPC_SERVICE_HANDLER_CUSTOM_AUTH(NodeManagerService, METHOD, -1, AuthType::NO_AUTH) + RPC_SERVICE_HANDLER_CUSTOM_AUTH( \ + NodeManagerService, METHOD, -1, ClusterIdAuthType::NO_AUTH) /// NOTE: See src/ray/core_worker/core_worker.h on how to add a new grpc handler. #define RAY_NODE_MANAGER_RPC_HANDLERS \ diff --git a/src/ray/rpc/object_manager_server.h b/src/ray/rpc/object_manager_server.h index 4d294b483fff..b19f2d540db2 100644 --- a/src/ray/rpc/object_manager_server.h +++ b/src/ray/rpc/object_manager_server.h @@ -28,7 +28,8 @@ namespace rpc { class ServerCallFactory; #define RAY_OBJECT_MANAGER_RPC_SERVICE_HANDLER(METHOD) \ - RPC_SERVICE_HANDLER_CUSTOM_AUTH(ObjectManagerService, METHOD, -1, AuthType::NO_AUTH) + RPC_SERVICE_HANDLER_CUSTOM_AUTH( \ + ObjectManagerService, METHOD, -1, ClusterIdAuthType::NO_AUTH) #define RAY_OBJECT_MANAGER_RPC_HANDLERS \ RAY_OBJECT_MANAGER_RPC_SERVICE_HANDLER(Push) \ diff --git a/src/ray/rpc/server_call.h b/src/ray/rpc/server_call.h index 0d0779525ea6..16142b2f98f9 100644 --- a/src/ray/rpc/server_call.h +++ b/src/ray/rpc/server_call.h @@ -24,6 +24,7 @@ #include "ray/common/asio/asio_chaos.h" #include "ray/common/asio/instrumented_io_context.h" +#include "ray/common/constants.h" #include "ray/common/grpc_util.h" #include "ray/common/id.h" #include "ray/common/status.h" @@ -34,8 +35,8 @@ namespace ray { namespace rpc { -// Authentication type of ServerCall. -enum class AuthType { +// Cluster ID authentication type of ServerCall. 
+enum class ClusterIdAuthType { NO_AUTH, // Do not authenticate (accept all). LAZY_AUTH, // Accept missing cluster ID, but reject incorrect one. EMPTY_AUTH, // Accept only empty cluster ID. @@ -149,7 +150,7 @@ using HandleRequestFunction = void (ServiceHandler::*)(Request, template + ClusterIdAuthType EnableAuth = ClusterIdAuthType::NO_AUTH> class ServerCallImpl : public ServerCall { public: /// Constructor. @@ -159,6 +160,7 @@ class ServerCallImpl : public ServerCall { /// \param[in] handle_request_function Pointer to the service handler function. /// \param[in] io_service The event loop. /// \param[in] call_name The name of the RPC call. + /// \param[in] cluster_id The cluster ID for authentication. /// \param[in] record_metrics If true, it records and exports the gRPC server metrics. /// \param[in] preprocess_function If not nullptr, it will be called before handling /// request. @@ -179,6 +181,9 @@ class ServerCallImpl : public ServerCall { io_service_(io_service), call_name_(std::move(call_name)), cluster_id_(cluster_id), + auth_token_(::RayConfig::instance().enable_token_auth() + ? 
::RayConfig::instance().auth_token() + : ""), start_time_(0), record_metrics_(record_metrics) { reply_ = google::protobuf::Arena::CreateMessage(&arena_); @@ -194,8 +199,24 @@ class ServerCallImpl : public ServerCall { void HandleRequest() override { stats_handle_ = io_service_.stats().RecordStart(call_name_); bool auth_success = true; + bool token_auth_failed = false; + bool cluster_id_auth_failed = false; + + // Token authentication + // Empty token = no authentication required + if (!auth_token_.empty()) { + auto &metadata = context_.client_metadata(); + auto it = metadata.find(kAuthTokenKey); + if (it == metadata.end() || it->second != auth_token_) { + RAY_LOG(WARNING) << "Invalid or missing auth token in request!"; + auth_success = false; + token_auth_failed = true; + } + } + + // Cluster ID authentication if (::RayConfig::instance().enable_cluster_auth()) { - if constexpr (EnableAuth == AuthType::LAZY_AUTH) { + if constexpr (EnableAuth == ClusterIdAuthType::LAZY_AUTH) { RAY_CHECK(!cluster_id_.IsNil()) << "Expected cluster ID in server call!"; auto &metadata = context_.client_metadata(); if (auto it = metadata.find(kClusterIdKey); @@ -203,8 +224,9 @@ class ServerCallImpl : public ServerCall { RAY_LOG(WARNING) << "Wrong cluster ID token in request! Expected: " << cluster_id_.Hex() << ", but got: " << it->second; auth_success = false; + cluster_id_auth_failed = true; } - } else if constexpr (EnableAuth == AuthType::EMPTY_AUTH) { + } else if constexpr (EnableAuth == ClusterIdAuthType::EMPTY_AUTH) { RAY_CHECK(!cluster_id_.IsNil()) << "Expected cluster ID in server call!"; auto &metadata = context_.client_metadata(); if (auto it = metadata.find(kClusterIdKey); @@ -212,6 +234,7 @@ class ServerCallImpl : public ServerCall { RAY_LOG(WARNING) << "Cluster ID token in request! 
Expected Nil, " << "but got: " << it->second; auth_success = false; + cluster_id_auth_failed = true; } } } @@ -221,24 +244,32 @@ class ServerCallImpl : public ServerCall { ray::stats::STATS_grpc_server_req_handling.Record(1.0, call_name_); } if (!io_service_.stopped()) { - io_service_.post([this, auth_success] { HandleRequestImpl(auth_success); }, - call_name_ + ".HandleRequestImpl", - // Implement the delay of the rpc server call as the - // delay of HandleRequestImpl(). - ray::asio::testing::GetDelayUs(call_name_)); + io_service_.post( + [this, auth_success, token_auth_failed, cluster_id_auth_failed] { + HandleRequestImpl(auth_success, token_auth_failed, cluster_id_auth_failed); + }, + call_name_ + ".HandleRequestImpl", + // Implement the delay of the rpc server call as the + // delay of HandleRequestImpl(). + ray::asio::testing::GetDelayUs(call_name_)); } else { // Handle service for rpc call has stopped, we must handle the call here // to send reply and remove it from cq RAY_LOG(DEBUG) << "Handle service has been closed."; if (auth_success) { SendReply(Status::Invalid("HandleServiceClosed")); + } else if (token_auth_failed) { + SendReply(Status::AuthError( + "InvalidAuthToken: Authentication token is missing or incorrect")); } else { SendReply(Status::AuthError("WrongClusterID")); } } } - void HandleRequestImpl(bool auth_success) { + void HandleRequestImpl(bool auth_success, + bool token_auth_failed, + bool cluster_id_auth_failed) { if constexpr (std::is_base_of_v) { if (!service_handler_initialized_) { service_handler_.WaitUntilInitialized(); @@ -254,10 +285,15 @@ class ServerCallImpl : public ServerCall { factory_.CreateCall(); } if (!auth_success) { - boost::asio::post(GetServerCallExecutor(), [this]() { - SendReply( - Status::AuthError("WrongClusterID: Perhaps the client is accessing GCS " - "after it has restarted.")); + boost::asio::post(GetServerCallExecutor(), [this, token_auth_failed]() { + if (token_auth_failed) { + SendReply(Status::AuthError( + 
"InvalidAuthToken: Authentication token is missing or incorrect")); + } else { + SendReply( + Status::AuthError("WrongClusterID: Perhaps the client is accessing GCS " + "after it has restarted.")); + } }); } else { (service_handler_.*handle_request_function_)( @@ -373,6 +409,10 @@ class ServerCallImpl : public ServerCall { /// Check skipped if empty. const ClusterID &cluster_id_; + /// Authentication token for token-based authentication. + /// Empty string means no token authentication. + const std::string auth_token_; + /// The callback when sending reply successes. std::function send_reply_success_callback_ = nullptr; @@ -385,7 +425,7 @@ class ServerCallImpl : public ServerCall { /// If true, the server call will generate gRPC server metrics. bool record_metrics_; - template + template friend class ServerCallFactoryImpl; }; @@ -413,7 +453,7 @@ template + ClusterIdAuthType EnableAuth = ClusterIdAuthType::NO_AUTH> class ServerCallFactoryImpl : public ServerCallFactory { using AsyncService = typename GrpcService::AsyncService; @@ -428,6 +468,7 @@ class ServerCallFactoryImpl : public ServerCallFactory { /// \param[in] cq The `CompletionQueue`. /// \param[in] io_service The event loop. /// \param[in] call_name The name of the RPC call. + /// \param[in] cluster_id The cluster ID for authentication. /// \param[in] max_active_rpcs Maximum request number to handle at the same time. -1 /// means no limit. /// \param[in] record_metrics If true, it records and exports the gRPC server metrics. 
diff --git a/src/ray/rpc/tests/grpc_bench/grpc_bench.cc b/src/ray/rpc/tests/grpc_bench/grpc_bench.cc index 552b83bff3bc..6bdf24c26106 100644 --- a/src/ray/rpc/tests/grpc_bench/grpc_bench.cc +++ b/src/ray/rpc/tests/grpc_bench/grpc_bench.cc @@ -59,7 +59,7 @@ class GreeterGrpcService : public GrpcService { std::vector> *server_call_factories, const ClusterID &cluster_id) override{ RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED( - Greeter, SayHello, -1, AuthType::NO_AUTH)} + Greeter, SayHello, -1, ClusterIdAuthType::NO_AUTH)} /// The grpc async service object. Greeter::AsyncService service_; diff --git a/src/ray/rpc/tests/grpc_server_client_test.cc b/src/ray/rpc/tests/grpc_server_client_test.cc index f51a80b99f73..15396d1b681f 100644 --- a/src/ray/rpc/tests/grpc_server_client_test.cc +++ b/src/ray/rpc/tests/grpc_server_client_test.cc @@ -89,9 +89,9 @@ class TestGrpcService : public GrpcService { std::vector> *server_call_factories, const ClusterID &cluster_id) override { RPC_SERVICE_HANDLER_CUSTOM_AUTH( - TestService, Ping, /*max_active_rpcs=*/1, AuthType::NO_AUTH); + TestService, Ping, /*max_active_rpcs=*/1, ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH( - TestService, PingTimeout, /*max_active_rpcs=*/1, AuthType::NO_AUTH); + TestService, PingTimeout, /*max_active_rpcs=*/1, ClusterIdAuthType::NO_AUTH); } private: @@ -326,6 +326,185 @@ TEST_F(TestGrpcServerClientFixture, TestTimeoutMacro) { std::this_thread::sleep_for(std::chrono::milliseconds(1000)); } } +class TestGrpcServerClientTokenAuthFixture : public ::testing::Test { + public: + void SetUpServerWithConfig(const std::string &config_json) { + RayConfig::instance().initialize(config_json); + + // Start handler thread + handler_thread_ = std::make_unique([this]() { + boost::asio::executor_work_guard + handler_io_service_work_(handler_io_service_.get_executor()); + handler_io_service_.run(); + }); + + // Create and start server + grpc_server_.reset(new GrpcServer("test", 0, true)); + 
grpc_server_->RegisterService( + std::make_unique(handler_io_service_, test_service_handler_), + false); + grpc_server_->Run(); + + while (grpc_server_->GetPort() == 0) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + } + + void SetUpClientWithConfig(const std::string &config_json) { + // Reconfigure for client (allows different token) + RayConfig::instance().initialize(config_json); + + // Start client thread + client_thread_ = std::make_unique([this]() { + boost::asio::executor_work_guard + client_io_service_work_(client_io_service_.get_executor()); + client_io_service_.run(); + }); + + // Create client + client_call_manager_.reset( + new ClientCallManager(client_io_service_, false, /*local_address=*/"")); + grpc_client_.reset(new GrpcClient( + "127.0.0.1", grpc_server_->GetPort(), *client_call_manager_)); + } + + void TearDown() override { + if (grpc_client_) { + grpc_client_.reset(); + } + if (client_call_manager_) { + client_call_manager_.reset(); + } + if (client_thread_) { + client_io_service_.stop(); + if (client_thread_->joinable()) { + client_thread_->join(); + } + } + + if (grpc_server_) { + grpc_server_->Shutdown(); + } + if (handler_thread_) { + handler_io_service_.stop(); + if (handler_thread_->joinable()) { + handler_thread_->join(); + } + } + } + + // Helper to execute RPC and wait for result + struct PingResult { + bool completed; + bool success; + std::string error_msg; + }; + + PingResult ExecutePingAndWait() { + PingRequest request; + std::atomic done(false); + bool success = false; + std::string error_msg; + + Ping(request, + [&done, &success, &error_msg](const Status &status, const PingReply &reply) { + RAY_LOG(INFO) << "Token auth test replied, status=" << status; + success = status.ok(); + if (!status.ok()) { + error_msg = status.message(); + } + done = true; + }); + + // Wait for response with timeout + auto start = std::chrono::steady_clock::now(); + while (!done && std::chrono::steady_clock::now() - start < 
std::chrono::seconds(5)) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + + return {done, success, error_msg}; + } + + protected: + VOID_RPC_CLIENT_METHOD(TestService, Ping, grpc_client_, /*method_timeout_ms*/ -1, ) + + TestServiceHandler test_service_handler_; + instrumented_io_context handler_io_service_; + std::unique_ptr handler_thread_; + std::unique_ptr grpc_server_; + + instrumented_io_context client_io_service_; + std::unique_ptr client_thread_; + std::unique_ptr client_call_manager_; + std::unique_ptr> grpc_client_; +}; + +TEST_F(TestGrpcServerClientTokenAuthFixture, TestTokenAuthSuccess) { + // Both server and client have the same token + const std::string config = + R"({"enable_token_auth": true, "auth_token": "test_secret_token_123"})"; + SetUpServerWithConfig(config); + SetUpClientWithConfig(config); + + auto result = ExecutePingAndWait(); + + ASSERT_TRUE(result.completed) << "Request did not complete in time"; + ASSERT_TRUE(result.success) << "Request should succeed with matching token"; +} + +TEST_F(TestGrpcServerClientTokenAuthFixture, TestTokenAuthFailureWrongToken) { + // Server and client have different tokens + SetUpServerWithConfig(R"({"enable_token_auth": true, "auth_token": "server_token"})"); + SetUpClientWithConfig( + R"({"enable_token_auth": true, "auth_token": "wrong_client_token"})"); + + auto result = ExecutePingAndWait(); + + ASSERT_TRUE(result.completed) << "Request did not complete in time"; + ASSERT_FALSE(result.success) << "Request should fail with wrong token"; + ASSERT_TRUE(result.error_msg.find("InvalidAuthToken") != std::string::npos || + result.error_msg.find("Authentication") != std::string::npos) + << "Error message should indicate auth failure, got: " << result.error_msg; +} + +TEST_F(TestGrpcServerClientTokenAuthFixture, TestTokenAuthFailureMissingToken) { + // Server expects token, client doesn't send one (empty token) + SetUpServerWithConfig(R"({"enable_token_auth": true, "auth_token": 
"server_token"})"); + SetUpClientWithConfig(R"({"enable_token_auth": true, "auth_token": ""})"); + + auto result = ExecutePingAndWait(); + + ASSERT_TRUE(result.completed) << "Request did not complete in time"; + ASSERT_FALSE(result.success) << "Request should fail when token is missing"; + ASSERT_TRUE(result.error_msg.find("InvalidAuthToken") != std::string::npos || + result.error_msg.find("Authentication") != std::string::npos) + << "Error message should indicate auth failure, got: " << result.error_msg; +} + +TEST_F(TestGrpcServerClientTokenAuthFixture, TestTokenAuthDisabled) { + // Token auth disabled, should succeed regardless + const std::string config = R"({"enable_token_auth": false})"; + SetUpServerWithConfig(config); + SetUpClientWithConfig(config); + + auto result = ExecutePingAndWait(); + + ASSERT_TRUE(result.completed) << "Request did not complete in time"; + ASSERT_TRUE(result.success) << "Request should succeed when token auth is disabled"; +} + +TEST_F(TestGrpcServerClientTokenAuthFixture, TestEmptyTokenNoEnforcement) { + // Empty token with token auth enabled should not enforce + const std::string config = R"({"enable_token_auth": true, "auth_token": ""})"; + SetUpServerWithConfig(config); + SetUpClientWithConfig(config); + + auto result = ExecutePingAndWait(); + + ASSERT_TRUE(result.completed) << "Request did not complete in time"; + ASSERT_TRUE(result.success) + << "Request should succeed with empty token (no enforcement)"; +} } // namespace rpc } // namespace ray From c96d1f4319e468bcc4b60cef45ea8be058266b9e Mon Sep 17 00:00:00 2001 From: sampan Date: Fri, 17 Oct 2025 05:28:56 +0000 Subject: [PATCH 02/94] [Core] Token auth improvements - C++ RayAuthTokenLoader singleton - Created RayAuthTokenLoader singleton class with thread-safe token caching - Loads tokens from RAY_AUTH_TOKEN env, RAY_AUTH_TOKEN_PATH, or ~/.ray/auth_token - Support for token generation with UUID (cross-platform) - Modified GrpcServer to store and pass auth token to 
ServerCallImpl - Updated RPC_SERVICE_HANDLER macros to pass auth token - GCS server now loads token using RayAuthTokenLoader - Removed auth_token from RayConfig (now loaded via loader) - Token precedence: env var -> path env var -> default file path Signed-off-by: sampan --- src/ray/common/ray_config_def.h | 7 +- src/ray/gcs/gcs_server.cc | 7 ++ src/ray/rpc/auth_token_loader.cc | 198 +++++++++++++++++++++++++++++++ src/ray/rpc/auth_token_loader.h | 69 +++++++++++ src/ray/rpc/grpc_server.cc | 9 +- src/ray/rpc/grpc_server.h | 12 +- src/ray/rpc/server_call.h | 13 +- 7 files changed, 302 insertions(+), 13 deletions(-) create mode 100644 src/ray/rpc/auth_token_loader.cc create mode 100644 src/ray/rpc/auth_token_loader.h diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h index 25cd4ddc5487..ec904fc213e9 100644 --- a/src/ray/common/ray_config_def.h +++ b/src/ray/common/ray_config_def.h @@ -36,11 +36,8 @@ RAY_CONFIG(bool, emit_main_service_metrics, true) RAY_CONFIG(bool, enable_cluster_auth, true) /// Whether to enable token-based authentication for RPC calls. -RAY_CONFIG(bool, enable_token_auth, true) - -/// Authentication token for RPC calls. If empty, token auth is effectively disabled -/// even if enable_token_auth is true. -RAY_CONFIG(std::string, auth_token, "") +/// The authentication token is loaded via RayAuthTokenLoader. +RAY_CONFIG(bool, enable_token_auth, false) /// The interval of periodic event loop stats print. /// -1 means the feature is disabled. 
In this case, stats are available diff --git a/src/ray/gcs/gcs_server.cc b/src/ray/gcs/gcs_server.cc index 1d93198b5f36..acd6d262216f 100644 --- a/src/ray/gcs/gcs_server.cc +++ b/src/ray/gcs/gcs_server.cc @@ -39,6 +39,7 @@ #include "ray/observability/metric_constants.h" #include "ray/pubsub/publisher.h" #include "ray/raylet_rpc_client/raylet_client.h" +#include "ray/rpc/auth_token_loader.h" #include "ray/stats/stats.h" #include "ray/util/network_util.h" @@ -218,6 +219,12 @@ void GcsServer::Start() { GetOrGenerateClusterId( {[this, gcs_init_data](ClusterID cluster_id) { rpc_server_.SetClusterId(cluster_id); + // Load and set authentication token if enabled + if (RayConfig::instance().enable_token_auth()) { + rpc_server_.SetAuthToken( + rpc::RayAuthTokenLoader::instance().GetToken( + false)); + } DoStart(*gcs_init_data); }, io_context_provider_.GetDefaultIOContext()}); diff --git a/src/ray/rpc/auth_token_loader.cc b/src/ray/rpc/auth_token_loader.cc new file mode 100644 index 000000000000..ea48a1c0ebc2 --- /dev/null +++ b/src/ray/rpc/auth_token_loader.cc @@ -0,0 +1,198 @@ +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "ray/rpc/auth_token_loader.h" + +#include +#include +#include + +#include "ray/util/logging.h" +#include "ray/util/util.h" + +#ifdef _WIN32 +#include +#else +#include +#include +#endif + +namespace ray { +namespace rpc { + +RayAuthTokenLoader &RayAuthTokenLoader::instance() { + static RayAuthTokenLoader instance; + return instance; +} + +const std::string &RayAuthTokenLoader::GetToken(bool generate_if_not_found) { + std::lock_guard lock(token_mutex_); + + if (token_loaded_) { + return cached_token_; + } + + // Try to load from sources + cached_token_ = LoadTokenFromSources(); + token_loaded_ = true; + + // If not found and generation is requested, generate a new token + if (cached_token_.empty() && generate_if_not_found) { + cached_token_ = GenerateToken(); + } + + return cached_token_; +} + +bool RayAuthTokenLoader::HasToken() { + // This will trigger loading if not already loaded + const std::string &token = GetToken(false); + return !token.empty(); +} + +std::string RayAuthTokenLoader::LoadTokenFromSources() { + // Precedence 1: RAY_AUTH_TOKEN environment variable + const char *env_token = std::getenv("RAY_AUTH_TOKEN"); + if (env_token != nullptr && std::string(env_token).length() > 0) { + RAY_LOG(DEBUG) << "Loaded authentication token from RAY_AUTH_TOKEN environment " + "variable"; + return std::string(env_token); + } + + // Precedence 2: RAY_AUTH_TOKEN_PATH environment variable + const char *env_token_path = std::getenv("RAY_AUTH_TOKEN_PATH"); + if (env_token_path != nullptr && std::string(env_token_path).length() > 0) { + std::ifstream token_file(env_token_path); + if (token_file.is_open()) { + std::string token; + std::getline(token_file, token); + token_file.close(); + // Trim whitespace + token.erase(0, token.find_first_not_of(" \t\n\r\f\v")); + token.erase(token.find_last_not_of(" \t\n\r\f\v") + 1); + if (!token.empty()) { + RAY_LOG(DEBUG) << "Loaded authentication token from file: " << env_token_path; + return token; + } + } else { + 
RAY_LOG(WARNING) << "RAY_AUTH_TOKEN_PATH is set but file cannot be opened: " + << env_token_path; + } + } + + // Precedence 3: Default token path ~/.ray/auth_token + std::string default_path = GetDefaultTokenPath(); + std::ifstream token_file(default_path); + if (token_file.is_open()) { + std::string token; + std::getline(token_file, token); + token_file.close(); + // Trim whitespace + token.erase(0, token.find_first_not_of(" \t\n\r\f\v")); + token.erase(token.find_last_not_of(" \t\n\r\f\v") + 1); + if (!token.empty()) { + RAY_LOG(DEBUG) << "Loaded authentication token from default path: " + << default_path; + return token; + } + } + + // No token found + RAY_LOG(DEBUG) << "No authentication token found in any source"; + return ""; +} + +std::string RayAuthTokenLoader::GenerateToken() { + // Generate a UUID-like token using random hex string + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<> dis(0, 15); + + const char *hex_chars = "0123456789abcdef"; + std::stringstream ss; + + // Generate a 32-character hex string (similar to UUID without dashes) + for (int i = 0; i < 32; i++) { + ss << hex_chars[dis(gen)]; + } + + std::string token = ss.str(); + std::string token_path = GetDefaultTokenPath(); + + // Try to save the token to the default path + try { + // Create directory if it doesn't exist + std::string dir_path = token_path.substr(0, token_path.find_last_of("/\\")); +#ifdef _WIN32 + CreateDirectoryA(dir_path.c_str(), NULL); +#else + mkdir(dir_path.c_str(), 0700); +#endif + + // Write token to file + std::ofstream token_file(token_path, std::ios::trunc); + if (token_file.is_open()) { + token_file << token; + token_file.close(); + +#ifndef _WIN32 + // Set file permissions to 0600 on Unix systems + chmod(token_path.c_str(), S_IRUSR | S_IWUSR); +#endif + + RAY_LOG(INFO) << "Generated new authentication token and saved to " + << token_path; + } else { + RAY_LOG(WARNING) << "Failed to save generated token to " << token_path + << ". 
Token will only be available in memory."; + } + } catch (const std::exception &e) { + RAY_LOG(WARNING) << "Exception while saving token: " << e.what(); + } + + return token; +} + +std::string RayAuthTokenLoader::GetDefaultTokenPath() { + std::string home_dir; + +#ifdef _WIN32 + const char *userprofile = std::getenv("USERPROFILE"); + if (userprofile != nullptr) { + home_dir = userprofile; + } else { + const char *homedrive = std::getenv("HOMEDRIVE"); + const char *homepath = std::getenv("HOMEPATH"); + if (homedrive != nullptr && homepath != nullptr) { + home_dir = std::string(homedrive) + std::string(homepath); + } + } +#else + const char *home = std::getenv("HOME"); + if (home != nullptr) { + home_dir = home; + } +#endif + + if (home_dir.empty()) { + RAY_LOG(WARNING) << "Cannot determine home directory for token storage"; + return ".ray/auth_token"; + } + + return home_dir + "/.ray/auth_token"; +} + +} // namespace rpc +} // namespace ray + diff --git a/src/ray/rpc/auth_token_loader.h b/src/ray/rpc/auth_token_loader.h new file mode 100644 index 000000000000..c63596772082 --- /dev/null +++ b/src/ray/rpc/auth_token_loader.h @@ -0,0 +1,69 @@ +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +namespace ray { +namespace rpc { + +/// Singleton class for loading and caching authentication tokens. +/// Supports loading tokens from multiple sources with precedence: +/// 1. 
RAY_AUTH_TOKEN environment variable +/// 2. RAY_AUTH_TOKEN_PATH environment variable (path to token file) +/// 3. Default token path: ~/.ray/auth_token +/// +/// Thread-safe with internal caching to avoid repeated file I/O. +class RayAuthTokenLoader { + public: + /// Get the singleton instance. + static RayAuthTokenLoader &instance(); + + /// Get the authentication token. + /// \param generate_if_not_found If true, generate and save a new token if not found. + /// \return The authentication token, or empty string if not found and generation is + /// disabled. + const std::string &GetToken(bool generate_if_not_found = false); + + /// Check if an authentication token exists. + /// \return True if a token is available (cached or can be loaded). + bool HasToken(); + + // Prevent copying and moving + RayAuthTokenLoader(const RayAuthTokenLoader &) = delete; + RayAuthTokenLoader &operator=(const RayAuthTokenLoader &) = delete; + + private: + RayAuthTokenLoader() = default; + ~RayAuthTokenLoader() = default; + + /// Load token from available sources (env vars and file). + std::string LoadTokenFromSources(); + + /// Generate a new UUID token and save it to the default path. + std::string GenerateToken(); + + /// Get the default token file path (~/.ray/auth_token). 
+ std::string GetDefaultTokenPath(); + + std::mutex token_mutex_; + std::string cached_token_; + bool token_loaded_ = false; +}; + +} // namespace rpc +} // namespace ray + diff --git a/src/ray/rpc/grpc_server.cc b/src/ray/rpc/grpc_server.cc index 542326de0bce..ff03df9e7104 100644 --- a/src/ray/rpc/grpc_server.cc +++ b/src/ray/rpc/grpc_server.cc @@ -179,11 +179,12 @@ void GrpcServer::RegisterService(std::unique_ptr &&grpc_service) void GrpcServer::RegisterService(std::unique_ptr &&service, bool token_auth) { + if (token_auth && cluster_id_.IsNil()) { + RAY_LOG(FATAL) << "Expected cluster ID for token auth!"; + } for (int i = 0; i < num_threads_; i++) { - if (token_auth && cluster_id_.IsNil()) { - RAY_LOG(FATAL) << "Expected cluster ID for token auth!"; - } - service->InitServerCallFactories(cqs_[i], &server_call_factories_, cluster_id_); + service->InitServerCallFactories( + cqs_[i], &server_call_factories_, cluster_id_, auth_token_); } services_.push_back(std::move(service)); } diff --git a/src/ray/rpc/grpc_server.h b/src/ray/rpc/grpc_server.h index 68edec84b574..13a56b8c630e 100644 --- a/src/ray/rpc/grpc_server.h +++ b/src/ray/rpc/grpc_server.h @@ -46,6 +46,7 @@ namespace rpc { main_service_, \ #SERVICE ".grpc_server." #HANDLER, \ AUTH_TYPE == ClusterIdAuthType::NO_AUTH ? ClusterID::Nil() : cluster_id, \ + auth_token, \ MAX_ACTIVE_RPCS, \ RECORD_METRICS)); \ server_call_factories->emplace_back(std::move(HANDLER##_call_factory)); @@ -141,6 +142,10 @@ class GrpcServer { cluster_id_ = cluster_id; } + void SetAuthToken(const std::string &auth_token) { auth_token_ = auth_token; } + + const std::string &GetAuthToken() const { return auth_token_; } + protected: /// Initialize this server. void Init(); @@ -159,6 +164,8 @@ class GrpcServer { const bool listen_to_localhost_only_; /// Token representing ID of this cluster. ClusterID cluster_id_; + /// Authentication token for token-based authentication. 
+ std::string auth_token_; /// Indicates whether this server is in shutdown state. std::atomic is_shutdown_; /// The `grpc::Service` objects which should be registered to `ServerBuilder`. @@ -210,10 +217,13 @@ class GrpcService { /// \param[in] cq The grpc completion queue. /// \param[out] server_call_factories The `ServerCallFactory` objects, /// and the maximum number of concurrent requests that this gRPC server can handle. + /// \param[in] cluster_id The cluster ID for authentication. + /// \param[in] auth_token The authentication token for token-based authentication. virtual void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) = 0; + const ClusterID &cluster_id, + const std::string &auth_token) = 0; /// The main event loop, to which the service handler functions will be posted. instrumented_io_context &main_service_; diff --git a/src/ray/rpc/server_call.h b/src/ray/rpc/server_call.h index 16142b2f98f9..79e2c58179d0 100644 --- a/src/ray/rpc/server_call.h +++ b/src/ray/rpc/server_call.h @@ -161,6 +161,7 @@ class ServerCallImpl : public ServerCall { /// \param[in] io_service The event loop. /// \param[in] call_name The name of the RPC call. /// \param[in] cluster_id The cluster ID for authentication. + /// \param[in] auth_token The authentication token for token-based authentication. /// \param[in] record_metrics If true, it records and exports the gRPC server metrics. /// \param[in] preprocess_function If not nullptr, it will be called before handling /// request. 
@@ -171,6 +172,7 @@ class ServerCallImpl : public ServerCall { instrumented_io_context &io_service, std::string call_name, const ClusterID &cluster_id, + const std::string &auth_token, bool record_metrics, std::function preprocess_function = nullptr) : state_(ServerCallState::PENDING), @@ -181,9 +183,7 @@ class ServerCallImpl : public ServerCall { io_service_(io_service), call_name_(std::move(call_name)), cluster_id_(cluster_id), - auth_token_(::RayConfig::instance().enable_token_auth() - ? ::RayConfig::instance().auth_token() - : ""), + auth_token_(auth_token), start_time_(0), record_metrics_(record_metrics) { reply_ = google::protobuf::Arena::CreateMessage(&arena_); @@ -469,6 +469,7 @@ class ServerCallFactoryImpl : public ServerCallFactory { /// \param[in] io_service The event loop. /// \param[in] call_name The name of the RPC call. /// \param[in] cluster_id The cluster ID for authentication. + /// \param[in] auth_token The authentication token for token-based authentication. /// \param[in] max_active_rpcs Maximum request number to handle at the same time. -1 /// means no limit. /// \param[in] record_metrics If true, it records and exports the gRPC server metrics. @@ -481,6 +482,7 @@ class ServerCallFactoryImpl : public ServerCallFactory { instrumented_io_context &io_service, std::string call_name, const ClusterID &cluster_id, + const std::string &auth_token, int64_t max_active_rpcs, bool record_metrics) : service_(service), @@ -491,6 +493,7 @@ class ServerCallFactoryImpl : public ServerCallFactory { io_service_(io_service), call_name_(std::move(call_name)), cluster_id_(cluster_id), + auth_token_(auth_token), max_active_rpcs_(max_active_rpcs), record_metrics_(record_metrics) {} @@ -504,6 +507,7 @@ class ServerCallFactoryImpl : public ServerCallFactory { io_service_, call_name_, cluster_id_, + auth_token_, record_metrics_); /// Request gRPC runtime to starting accepting this kind of request, using the call as /// the tag. 
@@ -543,6 +547,9 @@ class ServerCallFactoryImpl : public ServerCallFactory { /// Check skipped if empty. const ClusterID cluster_id_; + /// Authentication token for token-based authentication. + const std::string auth_token_; + /// Maximum request number to handle at the same time. /// -1 means no limit. uint64_t max_active_rpcs_; From 91f783e91cc92e35357b55ba75f9f1421bb6a5e7 Mon Sep 17 00:00:00 2001 From: sampan Date: Fri, 17 Oct 2025 05:29:23 +0000 Subject: [PATCH 03/94] [Core] Token auth improvements - Python token loader and CLI - Created Python auth_token_loader module with thread-safe token caching - Loads tokens from same precedence as C++: RAY_AUTH_TOKEN, RAY_AUTH_TOKEN_PATH, ~/.ray/auth_token - Added enable_token_auth parameter to ray.init() with auto-generation support - Added --enable-token-auth flag to ray start CLI (fails if no token found) - Only pass enable_token_auth flag via system_config, not the token - Each side (C++/Python) loads tokens independently using their own loaders - ray.init() auto-generates token if not found, ray start fails with helpful error Signed-off-by: sampan --- python/ray/_private/auth_token_loader.py | 168 +++++++++++++++++++++++ python/ray/_private/worker.py | 17 +++ python/ray/scripts/scripts.py | 27 ++++ 3 files changed, 212 insertions(+) create mode 100644 python/ray/_private/auth_token_loader.py diff --git a/python/ray/_private/auth_token_loader.py b/python/ray/_private/auth_token_loader.py new file mode 100644 index 000000000000..008861f3a849 --- /dev/null +++ b/python/ray/_private/auth_token_loader.py @@ -0,0 +1,168 @@ +"""Authentication token loader for Ray. + +This module provides functions to load, generate, and cache authentication tokens +for Ray's token-based authentication system. Tokens are loaded with the following +precedence: +1. RAY_AUTH_TOKEN environment variable +2. RAY_AUTH_TOKEN_PATH environment variable (path to token file) +3. 
Default token path: ~/.ray/auth_token +""" + +import logging +import os +import threading +import uuid +from pathlib import Path +from typing import Optional + +logger = logging.getLogger(__name__) + +# Module-level cached variables +_cached_token: Optional[str] = None +_token_lock = threading.Lock() + + +def load_auth_token(generate_if_not_found: bool = False) -> str: + """Load the authentication token with caching. + + This function loads the token from available sources with the following precedence: + 1. RAY_AUTH_TOKEN environment variable + 2. RAY_AUTH_TOKEN_PATH environment variable (path to token file) + 3. Default token path: ~/.ray/auth_token + + The token is cached after the first successful load to avoid repeated file I/O. + + Args: + generate_if_not_found: If True, generate and save a new token if not found. + If False, return empty string if no token is found. + + Returns: + The authentication token, or empty string if not found and generation is disabled. + """ + global _cached_token + + with _token_lock: + # Return cached token if already loaded + if _cached_token is not None: + return _cached_token + + # Try to load from sources + token = _load_token_from_sources() + + # Generate if requested and not found + if not token and generate_if_not_found: + token = _generate_and_save_token() + + # Cache the result (even if empty) + _cached_token = token + return _cached_token + + +def has_auth_token() -> bool: + """Check if an authentication token exists. + + Returns: + True if a token is available (cached or can be loaded), False otherwise. + """ + token = load_auth_token(generate_if_not_found=False) + return bool(token) + + +def _load_token_from_sources() -> str: + """Load token from available sources (env vars and file). + + Returns: + The authentication token, or empty string if not found. 
+ """ + # Precedence 1: RAY_AUTH_TOKEN environment variable + env_token = os.environ.get("RAY_AUTH_TOKEN", "").strip() + if env_token: + logger.debug( + "Loaded authentication token from RAY_AUTH_TOKEN environment variable" + ) + return env_token + + # Precedence 2: RAY_AUTH_TOKEN_PATH environment variable + env_token_path = os.environ.get("RAY_AUTH_TOKEN_PATH", "").strip() + if env_token_path: + try: + token_path = Path(env_token_path).expanduser() + if token_path.exists(): + token = token_path.read_text().strip() + if token: + logger.debug(f"Loaded authentication token from file: {token_path}") + return token + else: + logger.warning( + f"RAY_AUTH_TOKEN_PATH is set but file does not exist: {token_path}" + ) + except Exception as e: + logger.warning( + f"Failed to read token from RAY_AUTH_TOKEN_PATH ({env_token_path}): {e}" + ) + + # Precedence 3: Default token path ~/.ray/auth_token + default_path = _get_default_token_path() + try: + if default_path.exists(): + token = default_path.read_text().strip() + if token: + logger.debug( + f"Loaded authentication token from default path: {default_path}" + ) + return token + except Exception as e: + logger.debug(f"Failed to read token from default path ({default_path}): {e}") + + # No token found + logger.debug("No authentication token found in any source") + return "" + + +def _generate_and_save_token() -> str: + """Generate a new UUID token and save it to the default path. + + Returns: + The newly generated authentication token. 
+ """ + # Generate a UUID-based token + token = uuid.uuid4().hex + + # Try to save the token to the default path + token_path = _get_default_token_path() + try: + # Create directory if it doesn't exist + token_path.parent.mkdir(parents=True, exist_ok=True) + + # Write token to file + token_path.write_text(token) + + # Set file permissions to 0600 on Unix systems + try: + # This will work on Unix systems, but not on Windows + os.chmod(token_path, 0o600) + except (OSError, AttributeError): + # chmod may not work on Windows or may fail for other reasons + # This is not critical, so we just log a debug message + logger.debug( + f"Could not set file permissions to 0600 for {token_path}. " + "This is expected on Windows." + ) + + logger.info(f"Generated new authentication token and saved to {token_path}") + except Exception as e: + logger.warning( + f"Failed to save generated token to {token_path}: {e}. " + "Token will only be available in memory." + ) + + return token + + +def _get_default_token_path() -> Path: + """Get the default token file path (~/.ray/auth_token). + + Returns: + Path object pointing to ~/.ray/auth_token + """ + return Path.home() / ".ray" / "auth_token" diff --git a/python/ray/_private/worker.py b/python/ray/_private/worker.py index 6b015cdcdc81..2d3dd00923b3 100644 --- a/python/ray/_private/worker.py +++ b/python/ray/_private/worker.py @@ -1448,6 +1448,7 @@ def init( enable_resource_isolation: bool = False, system_reserved_cpu: Optional[float] = None, system_reserved_memory: Optional[int] = None, + enable_token_auth: bool = False, **kwargs, ) -> BaseContext: """ @@ -1569,6 +1570,10 @@ def init( By default, the min of 10% and 25GB plus object_store_memory will be reserved. Must be >= 100MB and system_reserved_memory + object_store_bytes < total available memory. This option only works if enable_resource_isolation is True. + enable_token_auth: If True, enable token-based authentication for Ray cluster + communication. 
If no token is found in the environment (RAY_AUTH_TOKEN or + RAY_AUTH_TOKEN_PATH) or default path (~/.ray/auth_token), a new token will be + automatically generated and saved to ~/.ray/auth_token. _cgroup_path: The path for the cgroup the raylet should use to enforce resource isolation. By default, the cgroup used for resource isolation will be /sys/fs/cgroup. The raylet must have read/write permissions to this path. @@ -1665,6 +1670,18 @@ def init( # Fix for https://github.com/ray-project/ray/issues/26729 _skip_env_hook: bool = kwargs.pop("_skip_env_hook", False) + # Handle token-based authentication + if enable_token_auth: + from ray._private.auth_token_loader import load_auth_token + + # Load or generate token + _token = load_auth_token(generate_if_not_found=True) + # Only pass the flag via system_config, NOT the token + # C++ will load token using its own RayAuthTokenLoader + if _system_config is None: + _system_config = {} + _system_config["enable_token_auth"] = "true" + resource_isolation_config = ResourceIsolationConfig( enable_resource_isolation=enable_resource_isolation, cgroup_path=_cgroup_path, diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py index 3c1dffe3b8f9..a5b02d01d170 100644 --- a/python/ray/scripts/scripts.py +++ b/python/ray/scripts/scripts.py @@ -674,6 +674,14 @@ def debug(address: str, verbose: bool): "Cgroup memory and cpu controllers be enabled for this cgroup. " "This option only works if --enable-resource-isolation is set.", ) +@click.option( + "--enable-token-auth", + is_flag=True, + default=False, + help="Enable token-based authentication. Requires an existing token from " + "environment variables (RAY_AUTH_TOKEN or RAY_AUTH_TOKEN_PATH) or ~/.ray/auth_token. 
" + "Use ray.init(enable_token_auth=True) to auto-generate a token.", +) @add_click_logging_options @PublicAPI def start( @@ -722,6 +730,7 @@ def start( system_reserved_cpu, system_reserved_memory, cgroup_path, + enable_token_auth, ): """Start Ray processes manually on the local machine.""" @@ -784,6 +793,24 @@ def start( system_reserved_memory=system_reserved_memory, ) + # Handle token-based authentication + if enable_token_auth: + from ray._private.auth_token_loader import load_auth_token + + # Try to load token (don't generate in CLI) + token = load_auth_token(generate_if_not_found=False) + if not token: + cli_logger.abort( + "Token authentication is enabled but no token found. " + "Please set RAY_AUTH_TOKEN environment variable, " + "RAY_AUTH_TOKEN_PATH, or create ~/.ray/auth_token. " + "Alternatively, use ray.init(enable_token_auth=True) to auto-generate a token." + ) + # Only pass the flag, not the token + if system_config is None: + system_config = {} + system_config["enable_token_auth"] = "true" + redirect_output = None if not no_redirect_output else True # no client, no port -> ok From 54d4eac7cd20ab948a383a8cdadfffffc1e6e924 Mon Sep 17 00:00:00 2001 From: sampan Date: Fri, 17 Oct 2025 05:31:56 +0000 Subject: [PATCH 04/94] [Core][Tests] Add unit tests for RayAuthTokenLoader - Test token loading from RAY_AUTH_TOKEN environment variable - Test token loading from RAY_AUTH_TOKEN_PATH file - Test token loading from default ~/.ray/auth_token path - Test precedence order (env var > path env var > default file) - Test token generation with GetToken(true) - Test token caching behavior - Test thread safety with concurrent GetToken calls - Test whitespace trimming from token files - Test behavior when no token is found Signed-off-by: sampan --- src/ray/rpc/tests/auth_token_loader_test.cc | 213 ++++++++++++++++++++ 1 file changed, 213 insertions(+) create mode 100644 src/ray/rpc/tests/auth_token_loader_test.cc diff --git a/src/ray/rpc/tests/auth_token_loader_test.cc 
b/src/ray/rpc/tests/auth_token_loader_test.cc new file mode 100644 index 000000000000..f824146d44f2 --- /dev/null +++ b/src/ray/rpc/tests/auth_token_loader_test.cc @@ -0,0 +1,213 @@ +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ray/rpc/auth_token_loader.h" + +#include +#include + +#include "gtest/gtest.h" +#include "ray/util/logging.h" + +namespace ray { +namespace rpc { + +class RayAuthTokenLoaderTest : public ::testing::Test { + protected: + void SetUp() override { + // Clean up environment variables before each test + unsetenv("RAY_AUTH_TOKEN"); + unsetenv("RAY_AUTH_TOKEN_PATH"); + + // Clean up default token file + std::string home_dir = getenv("HOME"); + default_token_path_ = home_dir + "/.ray/auth_token"; + remove(default_token_path_.c_str()); + } + + void TearDown() override { + // Clean up after test + unsetenv("RAY_AUTH_TOKEN"); + unsetenv("RAY_AUTH_TOKEN_PATH"); + remove(default_token_path_.c_str()); + } + + std::string default_token_path_; +}; + +TEST_F(RayAuthTokenLoaderTest, TestLoadFromEnvVariable) { + // Set token in environment variable + setenv("RAY_AUTH_TOKEN", "test-token-from-env", 1); + + // Create a new instance to avoid cached state + auto &loader = RayAuthTokenLoader::instance(); + std::string token = loader.GetToken(false); + + EXPECT_EQ(token, "test-token-from-env"); + EXPECT_TRUE(loader.HasToken()); +} + +TEST_F(RayAuthTokenLoaderTest, TestLoadFromEnvPath) { + // Create a 
temporary token file + std::string temp_token_path = "/tmp/ray_test_token_" + std::to_string(getpid()); + std::ofstream token_file(temp_token_path); + token_file << "test-token-from-file"; + token_file.close(); + + // Set path in environment variable + setenv("RAY_AUTH_TOKEN_PATH", temp_token_path.c_str(), 1); + + auto &loader = RayAuthTokenLoader::instance(); + std::string token = loader.GetToken(false); + + EXPECT_EQ(token, "test-token-from-file"); + EXPECT_TRUE(loader.HasToken()); + + // Clean up + remove(temp_token_path.c_str()); +} + +TEST_F(RayAuthTokenLoaderTest, TestLoadFromDefaultPath) { + // Create directory + std::string ray_dir = std::string(getenv("HOME")) + "/.ray"; + mkdir(ray_dir.c_str(), 0700); + + // Create token file in default location + std::ofstream token_file(default_token_path_); + token_file << "test-token-from-default"; + token_file.close(); + + auto &loader = RayAuthTokenLoader::instance(); + std::string token = loader.GetToken(false); + + EXPECT_EQ(token, "test-token-from-default"); + EXPECT_TRUE(loader.HasToken()); +} + +TEST_F(RayAuthTokenLoaderTest, TestPrecedenceOrder) { + // Set all three sources + setenv("RAY_AUTH_TOKEN", "token-from-env", 1); + + std::string temp_token_path = "/tmp/ray_test_token_" + std::to_string(getpid()); + std::ofstream temp_file(temp_token_path); + temp_file << "token-from-path"; + temp_file.close(); + setenv("RAY_AUTH_TOKEN_PATH", temp_token_path.c_str(), 1); + + std::string ray_dir = std::string(getenv("HOME")) + "/.ray"; + mkdir(ray_dir.c_str(), 0700); + std::ofstream default_file(default_token_path_); + default_file << "token-from-default"; + default_file.close(); + + // Environment variable should have highest precedence + auto &loader = RayAuthTokenLoader::instance(); + std::string token = loader.GetToken(false); + + EXPECT_EQ(token, "token-from-env"); + + // Clean up + remove(temp_token_path.c_str()); +} + +TEST_F(RayAuthTokenLoaderTest, TestNoTokenFound) { + // No token set anywhere + auto &loader = 
RayAuthTokenLoader::instance(); + std::string token = loader.GetToken(false); + + EXPECT_EQ(token, ""); + EXPECT_FALSE(loader.HasToken()); +} + +TEST_F(RayAuthTokenLoaderTest, TestGenerateToken) { + // No token exists, but request generation + auto &loader = RayAuthTokenLoader::instance(); + std::string token = loader.GetToken(true); + + // Token should be generated (32 character hex string) + EXPECT_EQ(token.length(), 32); + EXPECT_TRUE(loader.HasToken()); + + // Token should be saved to default path + std::ifstream token_file(default_token_path_); + EXPECT_TRUE(token_file.is_open()); + std::string saved_token; + std::getline(token_file, saved_token); + EXPECT_EQ(saved_token, token); +} + +TEST_F(RayAuthTokenLoaderTest, TestCaching) { + // Set token in environment + setenv("RAY_AUTH_TOKEN", "cached-token", 1); + + auto &loader = RayAuthTokenLoader::instance(); + std::string token1 = loader.GetToken(false); + + // Change environment variable (shouldn't affect cached value) + setenv("RAY_AUTH_TOKEN", "new-token", 1); + std::string token2 = loader.GetToken(false); + + // Should still return the cached token + EXPECT_EQ(token1, token2); + EXPECT_EQ(token2, "cached-token"); +} + +TEST_F(RayAuthTokenLoaderTest, TestThreadSafety) { + // Set a token + setenv("RAY_AUTH_TOKEN", "thread-safe-token", 1); + + auto &loader = RayAuthTokenLoader::instance(); + + // Create multiple threads that try to get token simultaneously + std::vector threads; + std::vector results(10); + + for (int i = 0; i < 10; i++) { + threads.emplace_back([&loader, &results, i]() { results[i] = loader.GetToken(false); }); + } + + // Wait for all threads to complete + for (auto &thread : threads) { + thread.join(); + } + + // All threads should get the same token + for (const auto &result : results) { + EXPECT_EQ(result, "thread-safe-token"); + } +} + +TEST_F(RayAuthTokenLoaderTest, TestWhitespaceHandling) { + // Create token file with whitespace + std::string ray_dir = std::string(getenv("HOME")) + 
"/.ray"; + mkdir(ray_dir.c_str(), 0700); + std::ofstream token_file(default_token_path_); + token_file << " token-with-spaces \n\t"; + token_file.close(); + + auto &loader = RayAuthTokenLoader::instance(); + std::string token = loader.GetToken(false); + + // Whitespace should be trimmed + EXPECT_EQ(token, "token-with-spaces"); +} + +} // namespace rpc +} // namespace ray + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} + From 092f29e3107f2d0cdb72a139855c47bc5ac6216e Mon Sep 17 00:00:00 2001 From: sampan Date: Fri, 17 Oct 2025 05:33:45 +0000 Subject: [PATCH 05/94] [Core][Tests] Add unit tests for Python auth_token_loader - Test token loading from RAY_AUTH_TOKEN environment variable - Test token loading from RAY_AUTH_TOKEN_PATH file - Test token loading from default ~/.ray/auth_token path - Test precedence order (env var > path env var > default file) - Test token generation with generate_if_not_found=True - Test token caching behavior across multiple calls - Test has_auth_token() function - Test thread safety with concurrent loads and generation - Test whitespace handling and empty values - Test file permissions on Unix systems (0600) - Test error handling for permission errors - Test integration with fixtures and cleanup Signed-off-by: sampan --- .../ray/tests/unit/test_auth_token_loader.py | 320 ++++++++++++++++++ 1 file changed, 320 insertions(+) create mode 100644 python/ray/tests/unit/test_auth_token_loader.py diff --git a/python/ray/tests/unit/test_auth_token_loader.py b/python/ray/tests/unit/test_auth_token_loader.py new file mode 100644 index 000000000000..51a3871ce4a1 --- /dev/null +++ b/python/ray/tests/unit/test_auth_token_loader.py @@ -0,0 +1,320 @@ +"""Unit tests for ray._private.auth_token_loader module.""" + +import os +import tempfile +import threading +from pathlib import Path +from unittest import mock + +import pytest + +from ray._private import auth_token_loader + + 
+@pytest.fixture(autouse=True) +def reset_cached_token(): + """Reset the cached token before each test.""" + auth_token_loader._cached_token = None + yield + auth_token_loader._cached_token = None + + +@pytest.fixture(autouse=True) +def clean_env_vars(): + """Clean up environment variables before and after each test.""" + env_vars = ["RAY_AUTH_TOKEN", "RAY_AUTH_TOKEN_PATH"] + old_values = {var: os.environ.get(var) for var in env_vars} + + # Clear environment variables + for var in env_vars: + if var in os.environ: + del os.environ[var] + + yield + + # Restore old values + for var, value in old_values.items(): + if value is not None: + os.environ[var] = value + elif var in os.environ: + del os.environ[var] + + +@pytest.fixture +def temp_token_file(): + """Create a temporary token file and clean it up after the test.""" + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".token") as f: + temp_path = f.name + f.write("test-token-from-file") + yield temp_path + try: + os.unlink(temp_path) + except FileNotFoundError: + pass + + +@pytest.fixture +def default_token_path(): + """Return the default token path and clean it up after the test.""" + path = Path.home() / ".ray" / "auth_token" + yield path + try: + path.unlink() + except FileNotFoundError: + pass + + +class TestLoadAuthToken: + """Tests for load_auth_token function.""" + + def test_load_from_env_variable(self): + """Test loading token from RAY_AUTH_TOKEN environment variable.""" + os.environ["RAY_AUTH_TOKEN"] = "token-from-env" + token = auth_token_loader.load_auth_token(generate_if_not_found=False) + assert token == "token-from-env" + + def test_load_from_env_path(self, temp_token_file): + """Test loading token from RAY_AUTH_TOKEN_PATH environment variable.""" + os.environ["RAY_AUTH_TOKEN_PATH"] = temp_token_file + token = auth_token_loader.load_auth_token(generate_if_not_found=False) + assert token == "test-token-from-file" + + def test_load_from_default_path(self, default_token_path): + """Test 
loading token from default ~/.ray/auth_token path.""" + default_token_path.parent.mkdir(parents=True, exist_ok=True) + default_token_path.write_text("token-from-default") + + token = auth_token_loader.load_auth_token(generate_if_not_found=False) + assert token == "token-from-default" + + def test_precedence_order(self, temp_token_file, default_token_path): + """Test that token loading follows correct precedence order.""" + # Set all three sources + os.environ["RAY_AUTH_TOKEN"] = "token-from-env" + os.environ["RAY_AUTH_TOKEN_PATH"] = temp_token_file + + default_token_path.parent.mkdir(parents=True, exist_ok=True) + default_token_path.write_text("token-from-default") + + # Environment variable should have highest precedence + token = auth_token_loader.load_auth_token(generate_if_not_found=False) + assert token == "token-from-env" + + def test_env_path_over_default(self, temp_token_file, default_token_path): + """Test that RAY_AUTH_TOKEN_PATH has precedence over default path.""" + os.environ["RAY_AUTH_TOKEN_PATH"] = temp_token_file + + default_token_path.parent.mkdir(parents=True, exist_ok=True) + default_token_path.write_text("token-from-default") + + token = auth_token_loader.load_auth_token(generate_if_not_found=False) + assert token == "test-token-from-file" + + def test_no_token_found(self): + """Test behavior when no token is found.""" + token = auth_token_loader.load_auth_token(generate_if_not_found=False) + assert token == "" + + def test_whitespace_handling(self, temp_token_file): + """Test that whitespace is properly trimmed from token files.""" + # Overwrite the temp file with whitespace + with open(temp_token_file, "w") as f: + f.write(" token-with-spaces \n\t") + + os.environ["RAY_AUTH_TOKEN_PATH"] = temp_token_file + token = auth_token_loader.load_auth_token(generate_if_not_found=False) + assert token == "token-with-spaces" + + def test_empty_env_variable(self): + """Test that empty environment variable is ignored.""" + os.environ["RAY_AUTH_TOKEN"] = " " 
# Empty after strip + token = auth_token_loader.load_auth_token(generate_if_not_found=False) + assert token == "" + + def test_nonexistent_path_in_env(self): + """Test that nonexistent path in RAY_AUTH_TOKEN_PATH is handled gracefully.""" + os.environ["RAY_AUTH_TOKEN_PATH"] = "/nonexistent/path/to/token" + token = auth_token_loader.load_auth_token(generate_if_not_found=False) + assert token == "" + + +class TestTokenGeneration: + """Tests for token generation functionality.""" + + def test_generate_token(self, default_token_path): + """Test token generation when no token exists.""" + token = auth_token_loader.load_auth_token(generate_if_not_found=True) + + # Token should be a 32-character hex string (UUID without dashes) + assert len(token) == 32 + assert all(c in "0123456789abcdef" for c in token) + + # Token should be saved to default path + assert default_token_path.exists() + saved_token = default_token_path.read_text().strip() + assert saved_token == token + + def test_no_generation_without_flag(self): + """Test that token is not generated when flag is False.""" + token = auth_token_loader.load_auth_token(generate_if_not_found=False) + assert token == "" + + def test_dont_generate_when_token_exists(self): + """Test that token is not generated when one already exists.""" + os.environ["RAY_AUTH_TOKEN"] = "existing-token" + token = auth_token_loader.load_auth_token(generate_if_not_found=True) + assert token == "existing-token" + + +class TestTokenCaching: + """Tests for token caching behavior.""" + + def test_caching_behavior(self): + """Test that token is cached after first load.""" + os.environ["RAY_AUTH_TOKEN"] = "cached-token" + token1 = auth_token_loader.load_auth_token(generate_if_not_found=False) + + # Change environment variable (shouldn't affect cached value) + os.environ["RAY_AUTH_TOKEN"] = "new-token" + token2 = auth_token_loader.load_auth_token(generate_if_not_found=False) + + # Should still return the cached token + assert token1 == token2 == 
"cached-token" + + def test_cache_empty_result(self): + """Test that even empty results are cached.""" + # First call with no token + token1 = auth_token_loader.load_auth_token(generate_if_not_found=False) + assert token1 == "" + + # Set environment variable after first call + os.environ["RAY_AUTH_TOKEN"] = "new-token" + token2 = auth_token_loader.load_auth_token(generate_if_not_found=False) + + # Should still return cached empty string + assert token2 == "" + + +class TestHasAuthToken: + """Tests for has_auth_token function.""" + + def test_has_token_true(self): + """Test has_auth_token returns True when token exists.""" + os.environ["RAY_AUTH_TOKEN"] = "test-token" + assert auth_token_loader.has_auth_token() is True + + def test_has_token_false(self): + """Test has_auth_token returns False when no token exists.""" + assert auth_token_loader.has_auth_token() is False + + def test_has_token_caches_result(self): + """Test that has_auth_token doesn't trigger generation.""" + # This should return False without generating a token + assert auth_token_loader.has_auth_token() is False + + # Verify no token was generated + default_path = Path.home() / ".ray" / "auth_token" + assert not default_path.exists() + + +class TestThreadSafety: + """Tests for thread safety of token loading.""" + + def test_concurrent_loads(self): + """Test that concurrent token loads are thread-safe.""" + os.environ["RAY_AUTH_TOKEN"] = "thread-safe-token" + + results = [] + threads = [] + + def load_token(): + token = auth_token_loader.load_auth_token(generate_if_not_found=False) + results.append(token) + + # Create multiple threads that try to load token simultaneously + for _ in range(10): + thread = threading.Thread(target=load_token) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + # All threads should get the same token + assert len(results) == 10 + assert all(result == "thread-safe-token" for result in results) + + 
def test_concurrent_generation(self, default_token_path): + """Test that concurrent token generation is thread-safe.""" + results = [] + threads = [] + + def generate_token(): + token = auth_token_loader.load_auth_token(generate_if_not_found=True) + results.append(token) + + # Create multiple threads that try to generate token simultaneously + for _ in range(5): + thread = threading.Thread(target=generate_token) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + # All threads should get the same token (only generated once) + assert len(results) == 5 + assert len(set(results)) == 1 # All tokens should be identical + + +class TestFilePermissions: + """Tests for file permissions when saving tokens.""" + + def test_file_permissions_on_unix(self, default_token_path, monkeypatch): + """Test that token file has 0600 permissions on Unix systems.""" + # Skip on Windows + if os.name == "nt": + pytest.skip("Test only relevant on Unix systems") + + token = auth_token_loader.load_auth_token(generate_if_not_found=True) + assert token + + # Check file permissions (should be 0600) + stat_info = default_token_path.stat() + assert stat_info.st_mode & 0o777 == 0o600 + + def test_file_permissions_error_handling(self, monkeypatch): + """Test that permission errors are handled gracefully.""" + + # Mock os.chmod to raise an exception + def mock_chmod(path, mode): + raise OSError("Permission denied") + + monkeypatch.setattr(os, "chmod", mock_chmod) + + # Should still generate and return token, just not set permissions + token = auth_token_loader.load_auth_token(generate_if_not_found=True) + assert len(token) == 32 + + +class TestIntegration: + """Integration tests with ray.init() and ray start CLI.""" + + def test_token_loader_with_ray_init(self, default_token_path): + """Test that token loader works with ray.init() enable_token_auth parameter.""" + # This is more of a smoke test to ensure the module can be 
imported + # and used in the context where it will be called + from ray._private import auth_token_loader + + # Generate a token + token = auth_token_loader.load_auth_token(generate_if_not_found=True) + assert token + assert len(token) == 32 + + # Verify it was saved + assert default_token_path.exists() + saved_token = default_token_path.read_text().strip() + assert saved_token == token + From fcd1d10a97399158264b620eb447348a4b85b218 Mon Sep 17 00:00:00 2001 From: sampan Date: Fri, 17 Oct 2025 06:00:45 +0000 Subject: [PATCH 06/94] fix lint errors Signed-off-by: sampan --- .../ray/tests/unit/test_auth_token_loader.py | 320 ++++++++++++++++++ src/ray/gcs/gcs_server.cc | 30 +- src/ray/rpc/auth_token_loader.cc | 7 +- src/ray/rpc/auth_token_loader.h | 1 - src/ray/rpc/tests/auth_token_loader_test.cc | 4 +- 5 files changed, 339 insertions(+), 23 deletions(-) create mode 100644 python/ray/tests/unit/test_auth_token_loader.py diff --git a/python/ray/tests/unit/test_auth_token_loader.py b/python/ray/tests/unit/test_auth_token_loader.py new file mode 100644 index 000000000000..51a3871ce4a1 --- /dev/null +++ b/python/ray/tests/unit/test_auth_token_loader.py @@ -0,0 +1,320 @@ +"""Unit tests for ray._private.auth_token_loader module.""" + +import os +import tempfile +import threading +from pathlib import Path +from unittest import mock + +import pytest + +from ray._private import auth_token_loader + + +@pytest.fixture(autouse=True) +def reset_cached_token(): + """Reset the cached token before each test.""" + auth_token_loader._cached_token = None + yield + auth_token_loader._cached_token = None + + +@pytest.fixture(autouse=True) +def clean_env_vars(): + """Clean up environment variables before and after each test.""" + env_vars = ["RAY_AUTH_TOKEN", "RAY_AUTH_TOKEN_PATH"] + old_values = {var: os.environ.get(var) for var in env_vars} + + # Clear environment variables + for var in env_vars: + if var in os.environ: + del os.environ[var] + + yield + + # Restore old values + for var, 
value in old_values.items(): + if value is not None: + os.environ[var] = value + elif var in os.environ: + del os.environ[var] + + +@pytest.fixture +def temp_token_file(): + """Create a temporary token file and clean it up after the test.""" + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".token") as f: + temp_path = f.name + f.write("test-token-from-file") + yield temp_path + try: + os.unlink(temp_path) + except FileNotFoundError: + pass + + +@pytest.fixture +def default_token_path(): + """Return the default token path and clean it up after the test.""" + path = Path.home() / ".ray" / "auth_token" + yield path + try: + path.unlink() + except FileNotFoundError: + pass + + +class TestLoadAuthToken: + """Tests for load_auth_token function.""" + + def test_load_from_env_variable(self): + """Test loading token from RAY_AUTH_TOKEN environment variable.""" + os.environ["RAY_AUTH_TOKEN"] = "token-from-env" + token = auth_token_loader.load_auth_token(generate_if_not_found=False) + assert token == "token-from-env" + + def test_load_from_env_path(self, temp_token_file): + """Test loading token from RAY_AUTH_TOKEN_PATH environment variable.""" + os.environ["RAY_AUTH_TOKEN_PATH"] = temp_token_file + token = auth_token_loader.load_auth_token(generate_if_not_found=False) + assert token == "test-token-from-file" + + def test_load_from_default_path(self, default_token_path): + """Test loading token from default ~/.ray/auth_token path.""" + default_token_path.parent.mkdir(parents=True, exist_ok=True) + default_token_path.write_text("token-from-default") + + token = auth_token_loader.load_auth_token(generate_if_not_found=False) + assert token == "token-from-default" + + def test_precedence_order(self, temp_token_file, default_token_path): + """Test that token loading follows correct precedence order.""" + # Set all three sources + os.environ["RAY_AUTH_TOKEN"] = "token-from-env" + os.environ["RAY_AUTH_TOKEN_PATH"] = temp_token_file + + 
default_token_path.parent.mkdir(parents=True, exist_ok=True) + default_token_path.write_text("token-from-default") + + # Environment variable should have highest precedence + token = auth_token_loader.load_auth_token(generate_if_not_found=False) + assert token == "token-from-env" + + def test_env_path_over_default(self, temp_token_file, default_token_path): + """Test that RAY_AUTH_TOKEN_PATH has precedence over default path.""" + os.environ["RAY_AUTH_TOKEN_PATH"] = temp_token_file + + default_token_path.parent.mkdir(parents=True, exist_ok=True) + default_token_path.write_text("token-from-default") + + token = auth_token_loader.load_auth_token(generate_if_not_found=False) + assert token == "test-token-from-file" + + def test_no_token_found(self): + """Test behavior when no token is found.""" + token = auth_token_loader.load_auth_token(generate_if_not_found=False) + assert token == "" + + def test_whitespace_handling(self, temp_token_file): + """Test that whitespace is properly trimmed from token files.""" + # Overwrite the temp file with whitespace + with open(temp_token_file, "w") as f: + f.write(" token-with-spaces \n\t") + + os.environ["RAY_AUTH_TOKEN_PATH"] = temp_token_file + token = auth_token_loader.load_auth_token(generate_if_not_found=False) + assert token == "token-with-spaces" + + def test_empty_env_variable(self): + """Test that empty environment variable is ignored.""" + os.environ["RAY_AUTH_TOKEN"] = " " # Empty after strip + token = auth_token_loader.load_auth_token(generate_if_not_found=False) + assert token == "" + + def test_nonexistent_path_in_env(self): + """Test that nonexistent path in RAY_AUTH_TOKEN_PATH is handled gracefully.""" + os.environ["RAY_AUTH_TOKEN_PATH"] = "/nonexistent/path/to/token" + token = auth_token_loader.load_auth_token(generate_if_not_found=False) + assert token == "" + + +class TestTokenGeneration: + """Tests for token generation functionality.""" + + def test_generate_token(self, default_token_path): + """Test token 
generation when no token exists.""" + token = auth_token_loader.load_auth_token(generate_if_not_found=True) + + # Token should be a 32-character hex string (UUID without dashes) + assert len(token) == 32 + assert all(c in "0123456789abcdef" for c in token) + + # Token should be saved to default path + assert default_token_path.exists() + saved_token = default_token_path.read_text().strip() + assert saved_token == token + + def test_no_generation_without_flag(self): + """Test that token is not generated when flag is False.""" + token = auth_token_loader.load_auth_token(generate_if_not_found=False) + assert token == "" + + def test_dont_generate_when_token_exists(self): + """Test that token is not generated when one already exists.""" + os.environ["RAY_AUTH_TOKEN"] = "existing-token" + token = auth_token_loader.load_auth_token(generate_if_not_found=True) + assert token == "existing-token" + + +class TestTokenCaching: + """Tests for token caching behavior.""" + + def test_caching_behavior(self): + """Test that token is cached after first load.""" + os.environ["RAY_AUTH_TOKEN"] = "cached-token" + token1 = auth_token_loader.load_auth_token(generate_if_not_found=False) + + # Change environment variable (shouldn't affect cached value) + os.environ["RAY_AUTH_TOKEN"] = "new-token" + token2 = auth_token_loader.load_auth_token(generate_if_not_found=False) + + # Should still return the cached token + assert token1 == token2 == "cached-token" + + def test_cache_empty_result(self): + """Test that even empty results are cached.""" + # First call with no token + token1 = auth_token_loader.load_auth_token(generate_if_not_found=False) + assert token1 == "" + + # Set environment variable after first call + os.environ["RAY_AUTH_TOKEN"] = "new-token" + token2 = auth_token_loader.load_auth_token(generate_if_not_found=False) + + # Should still return cached empty string + assert token2 == "" + + +class TestHasAuthToken: + """Tests for has_auth_token function.""" + + def 
test_has_token_true(self): + """Test has_auth_token returns True when token exists.""" + os.environ["RAY_AUTH_TOKEN"] = "test-token" + assert auth_token_loader.has_auth_token() is True + + def test_has_token_false(self): + """Test has_auth_token returns False when no token exists.""" + assert auth_token_loader.has_auth_token() is False + + def test_has_token_caches_result(self): + """Test that has_auth_token doesn't trigger generation.""" + # This should return False without generating a token + assert auth_token_loader.has_auth_token() is False + + # Verify no token was generated + default_path = Path.home() / ".ray" / "auth_token" + assert not default_path.exists() + + +class TestThreadSafety: + """Tests for thread safety of token loading.""" + + def test_concurrent_loads(self): + """Test that concurrent token loads are thread-safe.""" + os.environ["RAY_AUTH_TOKEN"] = "thread-safe-token" + + results = [] + threads = [] + + def load_token(): + token = auth_token_loader.load_auth_token(generate_if_not_found=False) + results.append(token) + + # Create multiple threads that try to load token simultaneously + for _ in range(10): + thread = threading.Thread(target=load_token) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + # All threads should get the same token + assert len(results) == 10 + assert all(result == "thread-safe-token" for result in results) + + def test_concurrent_generation(self, default_token_path): + """Test that concurrent token generation is thread-safe.""" + results = [] + threads = [] + + def generate_token(): + token = auth_token_loader.load_auth_token(generate_if_not_found=True) + results.append(token) + + # Create multiple threads that try to generate token simultaneously + for _ in range(5): + thread = threading.Thread(target=generate_token) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + # 
All threads should get the same token (only generated once) + assert len(results) == 5 + assert len(set(results)) == 1 # All tokens should be identical + + +class TestFilePermissions: + """Tests for file permissions when saving tokens.""" + + def test_file_permissions_on_unix(self, default_token_path, monkeypatch): + """Test that token file has 0600 permissions on Unix systems.""" + # Skip on Windows + if os.name == "nt": + pytest.skip("Test only relevant on Unix systems") + + token = auth_token_loader.load_auth_token(generate_if_not_found=True) + assert token + + # Check file permissions (should be 0600) + stat_info = default_token_path.stat() + assert stat_info.st_mode & 0o777 == 0o600 + + def test_file_permissions_error_handling(self, monkeypatch): + """Test that permission errors are handled gracefully.""" + + # Mock os.chmod to raise an exception + def mock_chmod(path, mode): + raise OSError("Permission denied") + + monkeypatch.setattr(os, "chmod", mock_chmod) + + # Should still generate and return token, just not set permissions + token = auth_token_loader.load_auth_token(generate_if_not_found=True) + assert len(token) == 32 + + +class TestIntegration: + """Integration tests with ray.init() and ray start CLI.""" + + def test_token_loader_with_ray_init(self, default_token_path): + """Test that token loader works with ray.init() enable_token_auth parameter.""" + # This is more of a smoke test to ensure the module can be imported + # and used in the context where it will be called + from ray._private import auth_token_loader + + # Generate a token + token = auth_token_loader.load_auth_token(generate_if_not_found=True) + assert token + assert len(token) == 32 + + # Verify it was saved + assert default_token_path.exists() + saved_token = default_token_path.read_text().strip() + assert saved_token == token + diff --git a/src/ray/gcs/gcs_server.cc b/src/ray/gcs/gcs_server.cc index acd6d262216f..07970c9c064d 100644 --- a/src/ray/gcs/gcs_server.cc +++ 
b/src/ray/gcs/gcs_server.cc @@ -215,21 +215,21 @@ void GcsServer::Start() { // Init KV Manager. This needs to be initialized first here so that // it can be used to retrieve the cluster ID. InitKVManager(); - gcs_init_data->AsyncLoad({[this, gcs_init_data] { - GetOrGenerateClusterId( - {[this, gcs_init_data](ClusterID cluster_id) { - rpc_server_.SetClusterId(cluster_id); - // Load and set authentication token if enabled - if (RayConfig::instance().enable_token_auth()) { - rpc_server_.SetAuthToken( - rpc::RayAuthTokenLoader::instance().GetToken( - false)); - } - DoStart(*gcs_init_data); - }, - io_context_provider_.GetDefaultIOContext()}); - }, - io_context_provider_.GetDefaultIOContext()}); + gcs_init_data->AsyncLoad( + {[this, gcs_init_data] { + GetOrGenerateClusterId( + {[this, gcs_init_data](ClusterID cluster_id) { + rpc_server_.SetClusterId(cluster_id); + // Load and set authentication token if enabled + if (RayConfig::instance().enable_token_auth()) { + rpc_server_.SetAuthToken( + rpc::RayAuthTokenLoader::instance().GetToken(false)); + } + DoStart(*gcs_init_data); + }, + io_context_provider_.GetDefaultIOContext()}); + }, + io_context_provider_.GetDefaultIOContext()}); } void GcsServer::GetOrGenerateClusterId( diff --git a/src/ray/rpc/auth_token_loader.cc b/src/ray/rpc/auth_token_loader.cc index ea48a1c0ebc2..f9f21932f5c4 100644 --- a/src/ray/rpc/auth_token_loader.cc +++ b/src/ray/rpc/auth_token_loader.cc @@ -102,8 +102,7 @@ std::string RayAuthTokenLoader::LoadTokenFromSources() { token.erase(0, token.find_first_not_of(" \t\n\r\f\v")); token.erase(token.find_last_not_of(" \t\n\r\f\v") + 1); if (!token.empty()) { - RAY_LOG(DEBUG) << "Loaded authentication token from default path: " - << default_path; + RAY_LOG(DEBUG) << "Loaded authentication token from default path: " << default_path; return token; } } @@ -151,8 +150,7 @@ std::string RayAuthTokenLoader::GenerateToken() { chmod(token_path.c_str(), S_IRUSR | S_IWUSR); #endif - RAY_LOG(INFO) << "Generated new 
authentication token and saved to " - << token_path; + RAY_LOG(INFO) << "Generated new authentication token and saved to " << token_path; } else { RAY_LOG(WARNING) << "Failed to save generated token to " << token_path << ". Token will only be available in memory."; @@ -195,4 +193,3 @@ std::string RayAuthTokenLoader::GetDefaultTokenPath() { } // namespace rpc } // namespace ray - diff --git a/src/ray/rpc/auth_token_loader.h b/src/ray/rpc/auth_token_loader.h index c63596772082..9e4f835550f6 100644 --- a/src/ray/rpc/auth_token_loader.h +++ b/src/ray/rpc/auth_token_loader.h @@ -66,4 +66,3 @@ class RayAuthTokenLoader { } // namespace rpc } // namespace ray - diff --git a/src/ray/rpc/tests/auth_token_loader_test.cc b/src/ray/rpc/tests/auth_token_loader_test.cc index f824146d44f2..c8ad62c1d790 100644 --- a/src/ray/rpc/tests/auth_token_loader_test.cc +++ b/src/ray/rpc/tests/auth_token_loader_test.cc @@ -174,7 +174,8 @@ TEST_F(RayAuthTokenLoaderTest, TestThreadSafety) { std::vector results(10); for (int i = 0; i < 10; i++) { - threads.emplace_back([&loader, &results, i]() { results[i] = loader.GetToken(false); }); + threads.emplace_back( + [&loader, &results, i]() { results[i] = loader.GetToken(false); }); } // Wait for all threads to complete @@ -210,4 +211,3 @@ int main(int argc, char **argv) { ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } - From 411f6f45326c055d5bb7d0cb33886b083d799a4d Mon Sep 17 00:00:00 2001 From: sampan Date: Fri, 17 Oct 2025 06:01:19 +0000 Subject: [PATCH 07/94] missed change Signed-off-by: sampan --- python/ray/tests/unit/test_auth_token_loader.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/ray/tests/unit/test_auth_token_loader.py b/python/ray/tests/unit/test_auth_token_loader.py index 51a3871ce4a1..5c755dea70c0 100644 --- a/python/ray/tests/unit/test_auth_token_loader.py +++ b/python/ray/tests/unit/test_auth_token_loader.py @@ -4,7 +4,6 @@ import tempfile import threading from pathlib import Path -from unittest 
import mock import pytest @@ -317,4 +316,3 @@ def test_token_loader_with_ray_init(self, default_token_path): assert default_token_path.exists() saved_token = default_token_path.read_text().strip() assert saved_token == token - From 40fcdb513245878b3cd8c0971e3bcf8faac95b03 Mon Sep 17 00:00:00 2001 From: sampan Date: Fri, 17 Oct 2025 06:06:39 +0000 Subject: [PATCH 08/94] more lint issues Signed-off-by: sampan --- src/ray/rpc/auth_token_loader.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/ray/rpc/auth_token_loader.cc b/src/ray/rpc/auth_token_loader.cc index f9f21932f5c4..8bb9d189e364 100644 --- a/src/ray/rpc/auth_token_loader.cc +++ b/src/ray/rpc/auth_token_loader.cc @@ -14,6 +14,8 @@ #include "ray/rpc/auth_token_loader.h" +#include + #include #include #include From 223dbf5ab230a59c3a9d4a8cfe5f8cf065c46508 Mon Sep 17 00:00:00 2001 From: sampan Date: Fri, 17 Oct 2025 06:11:32 +0000 Subject: [PATCH 09/94] fix library Signed-off-by: sampan --- src/ray/rpc/auth_token_loader.cc | 3 +-- src/ray/rpc/tests/grpc_server_client_test.cc | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ray/rpc/auth_token_loader.cc b/src/ray/rpc/auth_token_loader.cc index 8bb9d189e364..b47421f2c07a 100644 --- a/src/ray/rpc/auth_token_loader.cc +++ b/src/ray/rpc/auth_token_loader.cc @@ -14,11 +14,10 @@ #include "ray/rpc/auth_token_loader.h" -#include - #include #include #include +#include #include "ray/util/logging.h" #include "ray/util/util.h" diff --git a/src/ray/rpc/tests/grpc_server_client_test.cc b/src/ray/rpc/tests/grpc_server_client_test.cc index 15396d1b681f..883e2306d622 100644 --- a/src/ray/rpc/tests/grpc_server_client_test.cc +++ b/src/ray/rpc/tests/grpc_server_client_test.cc @@ -14,6 +14,7 @@ #include #include +#include #include #include "gtest/gtest.h" From cc89a639891e41864d268d36ea8f6822c1b9fdb5 Mon Sep 17 00:00:00 2001 From: sampan Date: Fri, 17 Oct 2025 06:20:41 +0000 Subject: [PATCH 10/94] more lint Signed-off-by: sampan --- 
src/ray/rpc/tests/auth_token_loader_test.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/ray/rpc/tests/auth_token_loader_test.cc b/src/ray/rpc/tests/auth_token_loader_test.cc index c8ad62c1d790..bfc14f68a0d9 100644 --- a/src/ray/rpc/tests/auth_token_loader_test.cc +++ b/src/ray/rpc/tests/auth_token_loader_test.cc @@ -15,7 +15,9 @@ #include "ray/rpc/auth_token_loader.h" #include +#include #include +#include #include "gtest/gtest.h" #include "ray/util/logging.h" From c07929886694f092c7c39f4dd6f5db47c0575724 Mon Sep 17 00:00:00 2001 From: sampan Date: Fri, 17 Oct 2025 06:22:36 +0000 Subject: [PATCH 11/94] move python side changes to new pr Signed-off-by: sampan --- .../ray/tests/unit/test_auth_token_loader.py | 318 ------------------ 1 file changed, 318 deletions(-) delete mode 100644 python/ray/tests/unit/test_auth_token_loader.py diff --git a/python/ray/tests/unit/test_auth_token_loader.py b/python/ray/tests/unit/test_auth_token_loader.py deleted file mode 100644 index 5c755dea70c0..000000000000 --- a/python/ray/tests/unit/test_auth_token_loader.py +++ /dev/null @@ -1,318 +0,0 @@ -"""Unit tests for ray._private.auth_token_loader module.""" - -import os -import tempfile -import threading -from pathlib import Path - -import pytest - -from ray._private import auth_token_loader - - -@pytest.fixture(autouse=True) -def reset_cached_token(): - """Reset the cached token before each test.""" - auth_token_loader._cached_token = None - yield - auth_token_loader._cached_token = None - - -@pytest.fixture(autouse=True) -def clean_env_vars(): - """Clean up environment variables before and after each test.""" - env_vars = ["RAY_AUTH_TOKEN", "RAY_AUTH_TOKEN_PATH"] - old_values = {var: os.environ.get(var) for var in env_vars} - - # Clear environment variables - for var in env_vars: - if var in os.environ: - del os.environ[var] - - yield - - # Restore old values - for var, value in old_values.items(): - if value is not None: - os.environ[var] = value - elif var in 
os.environ: - del os.environ[var] - - -@pytest.fixture -def temp_token_file(): - """Create a temporary token file and clean it up after the test.""" - with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".token") as f: - temp_path = f.name - f.write("test-token-from-file") - yield temp_path - try: - os.unlink(temp_path) - except FileNotFoundError: - pass - - -@pytest.fixture -def default_token_path(): - """Return the default token path and clean it up after the test.""" - path = Path.home() / ".ray" / "auth_token" - yield path - try: - path.unlink() - except FileNotFoundError: - pass - - -class TestLoadAuthToken: - """Tests for load_auth_token function.""" - - def test_load_from_env_variable(self): - """Test loading token from RAY_AUTH_TOKEN environment variable.""" - os.environ["RAY_AUTH_TOKEN"] = "token-from-env" - token = auth_token_loader.load_auth_token(generate_if_not_found=False) - assert token == "token-from-env" - - def test_load_from_env_path(self, temp_token_file): - """Test loading token from RAY_AUTH_TOKEN_PATH environment variable.""" - os.environ["RAY_AUTH_TOKEN_PATH"] = temp_token_file - token = auth_token_loader.load_auth_token(generate_if_not_found=False) - assert token == "test-token-from-file" - - def test_load_from_default_path(self, default_token_path): - """Test loading token from default ~/.ray/auth_token path.""" - default_token_path.parent.mkdir(parents=True, exist_ok=True) - default_token_path.write_text("token-from-default") - - token = auth_token_loader.load_auth_token(generate_if_not_found=False) - assert token == "token-from-default" - - def test_precedence_order(self, temp_token_file, default_token_path): - """Test that token loading follows correct precedence order.""" - # Set all three sources - os.environ["RAY_AUTH_TOKEN"] = "token-from-env" - os.environ["RAY_AUTH_TOKEN_PATH"] = temp_token_file - - default_token_path.parent.mkdir(parents=True, exist_ok=True) - default_token_path.write_text("token-from-default") - - # 
Environment variable should have highest precedence - token = auth_token_loader.load_auth_token(generate_if_not_found=False) - assert token == "token-from-env" - - def test_env_path_over_default(self, temp_token_file, default_token_path): - """Test that RAY_AUTH_TOKEN_PATH has precedence over default path.""" - os.environ["RAY_AUTH_TOKEN_PATH"] = temp_token_file - - default_token_path.parent.mkdir(parents=True, exist_ok=True) - default_token_path.write_text("token-from-default") - - token = auth_token_loader.load_auth_token(generate_if_not_found=False) - assert token == "test-token-from-file" - - def test_no_token_found(self): - """Test behavior when no token is found.""" - token = auth_token_loader.load_auth_token(generate_if_not_found=False) - assert token == "" - - def test_whitespace_handling(self, temp_token_file): - """Test that whitespace is properly trimmed from token files.""" - # Overwrite the temp file with whitespace - with open(temp_token_file, "w") as f: - f.write(" token-with-spaces \n\t") - - os.environ["RAY_AUTH_TOKEN_PATH"] = temp_token_file - token = auth_token_loader.load_auth_token(generate_if_not_found=False) - assert token == "token-with-spaces" - - def test_empty_env_variable(self): - """Test that empty environment variable is ignored.""" - os.environ["RAY_AUTH_TOKEN"] = " " # Empty after strip - token = auth_token_loader.load_auth_token(generate_if_not_found=False) - assert token == "" - - def test_nonexistent_path_in_env(self): - """Test that nonexistent path in RAY_AUTH_TOKEN_PATH is handled gracefully.""" - os.environ["RAY_AUTH_TOKEN_PATH"] = "/nonexistent/path/to/token" - token = auth_token_loader.load_auth_token(generate_if_not_found=False) - assert token == "" - - -class TestTokenGeneration: - """Tests for token generation functionality.""" - - def test_generate_token(self, default_token_path): - """Test token generation when no token exists.""" - token = auth_token_loader.load_auth_token(generate_if_not_found=True) - - # Token should 
be a 32-character hex string (UUID without dashes) - assert len(token) == 32 - assert all(c in "0123456789abcdef" for c in token) - - # Token should be saved to default path - assert default_token_path.exists() - saved_token = default_token_path.read_text().strip() - assert saved_token == token - - def test_no_generation_without_flag(self): - """Test that token is not generated when flag is False.""" - token = auth_token_loader.load_auth_token(generate_if_not_found=False) - assert token == "" - - def test_dont_generate_when_token_exists(self): - """Test that token is not generated when one already exists.""" - os.environ["RAY_AUTH_TOKEN"] = "existing-token" - token = auth_token_loader.load_auth_token(generate_if_not_found=True) - assert token == "existing-token" - - -class TestTokenCaching: - """Tests for token caching behavior.""" - - def test_caching_behavior(self): - """Test that token is cached after first load.""" - os.environ["RAY_AUTH_TOKEN"] = "cached-token" - token1 = auth_token_loader.load_auth_token(generate_if_not_found=False) - - # Change environment variable (shouldn't affect cached value) - os.environ["RAY_AUTH_TOKEN"] = "new-token" - token2 = auth_token_loader.load_auth_token(generate_if_not_found=False) - - # Should still return the cached token - assert token1 == token2 == "cached-token" - - def test_cache_empty_result(self): - """Test that even empty results are cached.""" - # First call with no token - token1 = auth_token_loader.load_auth_token(generate_if_not_found=False) - assert token1 == "" - - # Set environment variable after first call - os.environ["RAY_AUTH_TOKEN"] = "new-token" - token2 = auth_token_loader.load_auth_token(generate_if_not_found=False) - - # Should still return cached empty string - assert token2 == "" - - -class TestHasAuthToken: - """Tests for has_auth_token function.""" - - def test_has_token_true(self): - """Test has_auth_token returns True when token exists.""" - os.environ["RAY_AUTH_TOKEN"] = "test-token" - assert 
auth_token_loader.has_auth_token() is True - - def test_has_token_false(self): - """Test has_auth_token returns False when no token exists.""" - assert auth_token_loader.has_auth_token() is False - - def test_has_token_caches_result(self): - """Test that has_auth_token doesn't trigger generation.""" - # This should return False without generating a token - assert auth_token_loader.has_auth_token() is False - - # Verify no token was generated - default_path = Path.home() / ".ray" / "auth_token" - assert not default_path.exists() - - -class TestThreadSafety: - """Tests for thread safety of token loading.""" - - def test_concurrent_loads(self): - """Test that concurrent token loads are thread-safe.""" - os.environ["RAY_AUTH_TOKEN"] = "thread-safe-token" - - results = [] - threads = [] - - def load_token(): - token = auth_token_loader.load_auth_token(generate_if_not_found=False) - results.append(token) - - # Create multiple threads that try to load token simultaneously - for _ in range(10): - thread = threading.Thread(target=load_token) - threads.append(thread) - thread.start() - - # Wait for all threads to complete - for thread in threads: - thread.join() - - # All threads should get the same token - assert len(results) == 10 - assert all(result == "thread-safe-token" for result in results) - - def test_concurrent_generation(self, default_token_path): - """Test that concurrent token generation is thread-safe.""" - results = [] - threads = [] - - def generate_token(): - token = auth_token_loader.load_auth_token(generate_if_not_found=True) - results.append(token) - - # Create multiple threads that try to generate token simultaneously - for _ in range(5): - thread = threading.Thread(target=generate_token) - threads.append(thread) - thread.start() - - # Wait for all threads to complete - for thread in threads: - thread.join() - - # All threads should get the same token (only generated once) - assert len(results) == 5 - assert len(set(results)) == 1 # All tokens should be 
identical - - -class TestFilePermissions: - """Tests for file permissions when saving tokens.""" - - def test_file_permissions_on_unix(self, default_token_path, monkeypatch): - """Test that token file has 0600 permissions on Unix systems.""" - # Skip on Windows - if os.name == "nt": - pytest.skip("Test only relevant on Unix systems") - - token = auth_token_loader.load_auth_token(generate_if_not_found=True) - assert token - - # Check file permissions (should be 0600) - stat_info = default_token_path.stat() - assert stat_info.st_mode & 0o777 == 0o600 - - def test_file_permissions_error_handling(self, monkeypatch): - """Test that permission errors are handled gracefully.""" - - # Mock os.chmod to raise an exception - def mock_chmod(path, mode): - raise OSError("Permission denied") - - monkeypatch.setattr(os, "chmod", mock_chmod) - - # Should still generate and return token, just not set permissions - token = auth_token_loader.load_auth_token(generate_if_not_found=True) - assert len(token) == 32 - - -class TestIntegration: - """Integration tests with ray.init() and ray start CLI.""" - - def test_token_loader_with_ray_init(self, default_token_path): - """Test that token loader works with ray.init() enable_token_auth parameter.""" - # This is more of a smoke test to ensure the module can be imported - # and used in the context where it will be called - from ray._private import auth_token_loader - - # Generate a token - token = auth_token_loader.load_auth_token(generate_if_not_found=True) - assert token - assert len(token) == 32 - - # Verify it was saved - assert default_token_path.exists() - saved_token = default_token_path.read_text().strip() - assert saved_token == token From 34aa7a38311d2917e200af1c657c1faf79314b9e Mon Sep 17 00:00:00 2001 From: sampan Date: Fri, 17 Oct 2025 06:25:17 +0000 Subject: [PATCH 12/94] remove unused import Signed-off-by: sampan --- python/ray/tests/unit/test_auth_token_loader.py | 2 -- 1 file changed, 2 deletions(-) diff --git 
a/python/ray/tests/unit/test_auth_token_loader.py b/python/ray/tests/unit/test_auth_token_loader.py index 51a3871ce4a1..5c755dea70c0 100644 --- a/python/ray/tests/unit/test_auth_token_loader.py +++ b/python/ray/tests/unit/test_auth_token_loader.py @@ -4,7 +4,6 @@ import tempfile import threading from pathlib import Path -from unittest import mock import pytest @@ -317,4 +316,3 @@ def test_token_loader_with_ray_init(self, default_token_path): assert default_token_path.exists() saved_token = default_token_path.read_text().strip() assert saved_token == token - From 733efca00321524159e322dd95656c38a2115f24 Mon Sep 17 00:00:00 2001 From: sampan Date: Fri, 17 Oct 2025 07:51:06 +0000 Subject: [PATCH 13/94] remove generate token method from c++ code Signed-off-by: sampan --- src/ray/common/ray_config_def.h | 1 - src/ray/gcs/gcs_server.cc | 17 ++--- src/ray/rpc/auth_token_loader.cc | 69 ++++----------------- src/ray/rpc/auth_token_loader.h | 10 +-- src/ray/rpc/tests/auth_token_loader_test.cc | 45 ++++++-------- 5 files changed, 41 insertions(+), 101 deletions(-) diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h index ec904fc213e9..799dd4ed70a2 100644 --- a/src/ray/common/ray_config_def.h +++ b/src/ray/common/ray_config_def.h @@ -36,7 +36,6 @@ RAY_CONFIG(bool, emit_main_service_metrics, true) RAY_CONFIG(bool, enable_cluster_auth, true) /// Whether to enable token-based authentication for RPC calls. -/// The authentication token is loaded via RayAuthTokenLoader. RAY_CONFIG(bool, enable_token_auth, false) /// The interval of periodic event loop stats print. 
diff --git a/src/ray/gcs/gcs_server.cc b/src/ray/gcs/gcs_server.cc index 07970c9c064d..d7f1918f6c8d 100644 --- a/src/ray/gcs/gcs_server.cc +++ b/src/ray/gcs/gcs_server.cc @@ -217,17 +217,12 @@ void GcsServer::Start() { InitKVManager(); gcs_init_data->AsyncLoad( {[this, gcs_init_data] { - GetOrGenerateClusterId( - {[this, gcs_init_data](ClusterID cluster_id) { - rpc_server_.SetClusterId(cluster_id); - // Load and set authentication token if enabled - if (RayConfig::instance().enable_token_auth()) { - rpc_server_.SetAuthToken( - rpc::RayAuthTokenLoader::instance().GetToken(false)); - } - DoStart(*gcs_init_data); - }, - io_context_provider_.GetDefaultIOContext()}); + GetOrGenerateClusterId({[this, gcs_init_data](ClusterID cluster_id) { + rpc_server_.SetClusterId(cluster_id); + rpc_server_.SetAuthToken(RayAuthTokenLoader::instance().GetToken()); + DoStart(*gcs_init_data); + }, + io_context_provider_.GetDefaultIOContext()}); }, io_context_provider_.GetDefaultIOContext()}); } diff --git a/src/ray/rpc/auth_token_loader.cc b/src/ray/rpc/auth_token_loader.cc index b47421f2c07a..a5a302ac985f 100644 --- a/src/ray/rpc/auth_token_loader.cc +++ b/src/ray/rpc/auth_token_loader.cc @@ -15,10 +15,10 @@ #include "ray/rpc/auth_token_loader.h" #include -#include -#include +#include #include +#include "ray/common/ray_config.h" #include "ray/util/logging.h" #include "ray/util/util.h" @@ -37,7 +37,7 @@ RayAuthTokenLoader &RayAuthTokenLoader::instance() { return instance; } -const std::string &RayAuthTokenLoader::GetToken(bool generate_if_not_found) { +const std::string &RayAuthTokenLoader::GetToken() { std::lock_guard lock(token_mutex_); if (token_loaded_) { @@ -48,9 +48,14 @@ const std::string &RayAuthTokenLoader::GetToken(bool generate_if_not_found) { cached_token_ = LoadTokenFromSources(); token_loaded_ = true; - // If not found and generation is requested, generate a new token - if (cached_token_.empty() && generate_if_not_found) { - cached_token_ = GenerateToken(); + // If token auth 
is enabled but no token is found, throw an error + if (RayConfig::instance().enable_token_auth() && cached_token_.empty()) { + RAY_LOG(ERROR) << "Token authentication is enabled but no authentication token was " + "found. Please set RAY_AUTH_TOKEN environment variable, " + "RAY_AUTH_TOKEN_PATH to a file containing the token, or create a " + "token file at ~/.ray/auth_token"; + throw std::runtime_error( + "Token authentication is enabled but no authentication token was found"); } return cached_token_; @@ -58,7 +63,7 @@ const std::string &RayAuthTokenLoader::GetToken(bool generate_if_not_found) { bool RayAuthTokenLoader::HasToken() { // This will trigger loading if not already loaded - const std::string &token = GetToken(false); + const std::string &token = GetToken(); return !token.empty(); } @@ -113,56 +118,6 @@ std::string RayAuthTokenLoader::LoadTokenFromSources() { return ""; } -std::string RayAuthTokenLoader::GenerateToken() { - // Generate a UUID-like token using random hex string - std::random_device rd; - std::mt19937 gen(rd()); - std::uniform_int_distribution<> dis(0, 15); - - const char *hex_chars = "0123456789abcdef"; - std::stringstream ss; - - // Generate a 32-character hex string (similar to UUID without dashes) - for (int i = 0; i < 32; i++) { - ss << hex_chars[dis(gen)]; - } - - std::string token = ss.str(); - std::string token_path = GetDefaultTokenPath(); - - // Try to save the token to the default path - try { - // Create directory if it doesn't exist - std::string dir_path = token_path.substr(0, token_path.find_last_of("/\\")); -#ifdef _WIN32 - CreateDirectoryA(dir_path.c_str(), NULL); -#else - mkdir(dir_path.c_str(), 0700); -#endif - - // Write token to file - std::ofstream token_file(token_path, std::ios::trunc); - if (token_file.is_open()) { - token_file << token; - token_file.close(); - -#ifndef _WIN32 - // Set file permissions to 0600 on Unix systems - chmod(token_path.c_str(), S_IRUSR | S_IWUSR); -#endif - - RAY_LOG(INFO) << "Generated new 
authentication token and saved to " << token_path; - } else { - RAY_LOG(WARNING) << "Failed to save generated token to " << token_path - << ". Token will only be available in memory."; - } - } catch (const std::exception &e) { - RAY_LOG(WARNING) << "Exception while saving token: " << e.what(); - } - - return token; -} - std::string RayAuthTokenLoader::GetDefaultTokenPath() { std::string home_dir; diff --git a/src/ray/rpc/auth_token_loader.h b/src/ray/rpc/auth_token_loader.h index 9e4f835550f6..c91276f287d6 100644 --- a/src/ray/rpc/auth_token_loader.h +++ b/src/ray/rpc/auth_token_loader.h @@ -33,10 +33,9 @@ class RayAuthTokenLoader { static RayAuthTokenLoader &instance(); /// Get the authentication token. - /// \param generate_if_not_found If true, generate and save a new token if not found. - /// \return The authentication token, or empty string if not found and generation is - /// disabled. - const std::string &GetToken(bool generate_if_not_found = false); + /// If token authentication is enabled but no token is found, throws an error. + /// \return The authentication token, or empty string if auth is disabled. + const std::string &GetToken(); /// Check if an authentication token exists. /// \return True if a token is available (cached or can be loaded). @@ -53,9 +52,6 @@ class RayAuthTokenLoader { /// Load token from available sources (env vars and file). std::string LoadTokenFromSources(); - /// Generate a new UUID token and save it to the default path. - std::string GenerateToken(); - /// Get the default token file path (~/.ray/auth_token). 
std::string GetDefaultTokenPath(); diff --git a/src/ray/rpc/tests/auth_token_loader_test.cc b/src/ray/rpc/tests/auth_token_loader_test.cc index bfc14f68a0d9..4cd214a0b9ff 100644 --- a/src/ray/rpc/tests/auth_token_loader_test.cc +++ b/src/ray/rpc/tests/auth_token_loader_test.cc @@ -20,6 +20,7 @@ #include #include "gtest/gtest.h" +#include "ray/common/ray_config.h" #include "ray/util/logging.h" namespace ray { @@ -54,7 +55,7 @@ TEST_F(RayAuthTokenLoaderTest, TestLoadFromEnvVariable) { // Create a new instance to avoid cached state auto &loader = RayAuthTokenLoader::instance(); - std::string token = loader.GetToken(false); + std::string token = loader.GetToken(); EXPECT_EQ(token, "test-token-from-env"); EXPECT_TRUE(loader.HasToken()); @@ -71,7 +72,7 @@ TEST_F(RayAuthTokenLoaderTest, TestLoadFromEnvPath) { setenv("RAY_AUTH_TOKEN_PATH", temp_token_path.c_str(), 1); auto &loader = RayAuthTokenLoader::instance(); - std::string token = loader.GetToken(false); + std::string token = loader.GetToken(); EXPECT_EQ(token, "test-token-from-file"); EXPECT_TRUE(loader.HasToken()); @@ -91,7 +92,7 @@ TEST_F(RayAuthTokenLoaderTest, TestLoadFromDefaultPath) { token_file.close(); auto &loader = RayAuthTokenLoader::instance(); - std::string token = loader.GetToken(false); + std::string token = loader.GetToken(); EXPECT_EQ(token, "test-token-from-default"); EXPECT_TRUE(loader.HasToken()); @@ -115,7 +116,7 @@ TEST_F(RayAuthTokenLoaderTest, TestPrecedenceOrder) { // Environment variable should have highest precedence auto &loader = RayAuthTokenLoader::instance(); - std::string token = loader.GetToken(false); + std::string token = loader.GetToken(); EXPECT_EQ(token, "token-from-env"); @@ -123,30 +124,25 @@ TEST_F(RayAuthTokenLoaderTest, TestPrecedenceOrder) { remove(temp_token_path.c_str()); } -TEST_F(RayAuthTokenLoaderTest, TestNoTokenFound) { - // No token set anywhere +TEST_F(RayAuthTokenLoaderTest, TestNoTokenFoundWhenAuthDisabled) { + // No token set anywhere, but auth is disabled 
(default) auto &loader = RayAuthTokenLoader::instance(); - std::string token = loader.GetToken(false); + std::string token = loader.GetToken(); EXPECT_EQ(token, ""); EXPECT_FALSE(loader.HasToken()); } -TEST_F(RayAuthTokenLoaderTest, TestGenerateToken) { - // No token exists, but request generation - auto &loader = RayAuthTokenLoader::instance(); - std::string token = loader.GetToken(true); +TEST_F(RayAuthTokenLoaderTest, TestErrorWhenAuthEnabledButNoToken) { + // Enable token auth + RayConfig::instance().initialize(R"({"enable_token_auth": true})"); - // Token should be generated (32 character hex string) - EXPECT_EQ(token.length(), 32); - EXPECT_TRUE(loader.HasToken()); + // No token exists, should throw an error + auto &loader = RayAuthTokenLoader::instance(); + EXPECT_THROW(loader.GetToken(), std::runtime_error); - // Token should be saved to default path - std::ifstream token_file(default_token_path_); - EXPECT_TRUE(token_file.is_open()); - std::string saved_token; - std::getline(token_file, saved_token); - EXPECT_EQ(saved_token, token); + // Reset config for other tests + RayConfig::instance().initialize(R"({"enable_token_auth": false})"); } TEST_F(RayAuthTokenLoaderTest, TestCaching) { @@ -154,11 +150,11 @@ TEST_F(RayAuthTokenLoaderTest, TestCaching) { setenv("RAY_AUTH_TOKEN", "cached-token", 1); auto &loader = RayAuthTokenLoader::instance(); - std::string token1 = loader.GetToken(false); + std::string token1 = loader.GetToken(); // Change environment variable (shouldn't affect cached value) setenv("RAY_AUTH_TOKEN", "new-token", 1); - std::string token2 = loader.GetToken(false); + std::string token2 = loader.GetToken(); // Should still return the cached token EXPECT_EQ(token1, token2); @@ -176,8 +172,7 @@ TEST_F(RayAuthTokenLoaderTest, TestThreadSafety) { std::vector results(10); for (int i = 0; i < 10; i++) { - threads.emplace_back( - [&loader, &results, i]() { results[i] = loader.GetToken(false); }); + threads.emplace_back([&loader, &results, i]() { 
results[i] = loader.GetToken(); }); } // Wait for all threads to complete @@ -200,7 +195,7 @@ TEST_F(RayAuthTokenLoaderTest, TestWhitespaceHandling) { token_file.close(); auto &loader = RayAuthTokenLoader::instance(); - std::string token = loader.GetToken(false); + std::string token = loader.GetToken(); // Whitespace should be trimmed EXPECT_EQ(token, "token-with-spaces"); From 16fd74ece8c7405497d0eeddeeb25a40cfeeb3d5 Mon Sep 17 00:00:00 2001 From: sampan Date: Fri, 17 Oct 2025 07:51:31 +0000 Subject: [PATCH 14/94] fix lint Signed-off-by: sampan --- src/ray/gcs/gcs_server.cc | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/src/ray/gcs/gcs_server.cc b/src/ray/gcs/gcs_server.cc index d7f1918f6c8d..11a925a44bbd 100644 --- a/src/ray/gcs/gcs_server.cc +++ b/src/ray/gcs/gcs_server.cc @@ -215,16 +215,17 @@ void GcsServer::Start() { // Init KV Manager. This needs to be initialized first here so that // it can be used to retrieve the cluster ID. InitKVManager(); - gcs_init_data->AsyncLoad( - {[this, gcs_init_data] { - GetOrGenerateClusterId({[this, gcs_init_data](ClusterID cluster_id) { - rpc_server_.SetClusterId(cluster_id); - rpc_server_.SetAuthToken(RayAuthTokenLoader::instance().GetToken()); - DoStart(*gcs_init_data); - }, - io_context_provider_.GetDefaultIOContext()}); - }, - io_context_provider_.GetDefaultIOContext()}); + gcs_init_data->AsyncLoad({[this, gcs_init_data] { + GetOrGenerateClusterId( + {[this, gcs_init_data](ClusterID cluster_id) { + rpc_server_.SetClusterId(cluster_id); + rpc_server_.SetAuthToken( + RayAuthTokenLoader::instance().GetToken()); + DoStart(*gcs_init_data); + }, + io_context_provider_.GetDefaultIOContext()}); + }, + io_context_provider_.GetDefaultIOContext()}); } void GcsServer::GetOrGenerateClusterId( From 7094efd4f8a69f7e2fbd6b165dc7258db1211a5c Mon Sep 17 00:00:00 2001 From: sampan Date: Fri, 17 Oct 2025 08:47:57 +0000 Subject: [PATCH 15/94] refactor code files Signed-off-by: sampan --- 
src/ray/core_worker/grpc_service.cc | 3 +- src/ray/core_worker/grpc_service.h | 3 +- src/ray/gcs/BUILD.bazel | 1 + src/ray/gcs/grpc_services.cc | 36 ++++-- src/ray/gcs/grpc_services.h | 36 ++++-- src/ray/rpc/BUILD.bazel | 12 ++ src/ray/rpc/auth_token_loader.cc | 3 +- src/ray/rpc/auth_token_loader.h | 2 +- src/ray/rpc/client_call.h | 11 +- src/ray/rpc/grpc_server.cc | 4 +- src/ray/rpc/grpc_server.h | 2 - .../rpc/node_manager/node_manager_server.h | 3 +- src/ray/rpc/object_manager_server.h | 3 +- src/ray/rpc/tests/BUILD.bazel | 13 ++ src/ray/rpc/tests/auth_token_loader_test.cc | 114 +++++++++++------- 15 files changed, 157 insertions(+), 89 deletions(-) diff --git a/src/ray/core_worker/grpc_service.cc b/src/ray/core_worker/grpc_service.cc index 2f1cd9574787..613f0f55f7c6 100644 --- a/src/ray/core_worker/grpc_service.cc +++ b/src/ray/core_worker/grpc_service.cc @@ -23,7 +23,8 @@ namespace rpc { void CoreWorkerGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) { + const ClusterID &cluster_id, + const std::string &auth_token) { /// TODO(vitsai): Remove this when auth is implemented for node manager. /// Disable gRPC server metrics since it incurs too high cardinality. 
RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, diff --git a/src/ray/core_worker/grpc_service.h b/src/ray/core_worker/grpc_service.h index e70b9c8ee475..0cd2096882bc 100644 --- a/src/ray/core_worker/grpc_service.h +++ b/src/ray/core_worker/grpc_service.h @@ -158,7 +158,8 @@ class CoreWorkerGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) override; + const ClusterID &cluster_id, + const std::string &auth_token) override; private: CoreWorkerService::AsyncService service_; diff --git a/src/ray/gcs/BUILD.bazel b/src/ray/gcs/BUILD.bazel index 2511da321245..850150a28b93 100644 --- a/src/ray/gcs/BUILD.bazel +++ b/src/ray/gcs/BUILD.bazel @@ -526,6 +526,7 @@ ray_cc_library( "//src/ray/raylet/scheduling:scheduler", "//src/ray/raylet_rpc_client:raylet_client_lib", "//src/ray/raylet_rpc_client:raylet_client_pool", + "//src/ray/rpc:auth_token_loader", "//src/ray/rpc:grpc_server", "//src/ray/rpc:metrics_agent_client", "//src/ray/util:counter_map", diff --git a/src/ray/gcs/grpc_services.cc b/src/ray/gcs/grpc_services.cc index 72113b3be474..912c7a267eaf 100644 --- a/src/ray/gcs/grpc_services.cc +++ b/src/ray/gcs/grpc_services.cc @@ -22,7 +22,8 @@ namespace rpc { void ActorInfoGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) { + const ClusterID &cluster_id, + const std::string &auth_token) { /// The register & create actor RPCs take a long time, so we shouldn't limit their /// concurrency to avoid distributed deadlock. 
RPC_SERVICE_HANDLER(ActorInfoGcsService, RegisterActor, -1) @@ -42,7 +43,8 @@ void ActorInfoGrpcService::InitServerCallFactories( void NodeInfoGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) { + const ClusterID &cluster_id, + const std::string &auth_token) { // We only allow one cluster ID in the lifetime of a client. // So, if a client connects, it should not have a pre-existing different ID. RPC_SERVICE_HANDLER_CUSTOM_AUTH(NodeInfoGcsService, @@ -61,7 +63,8 @@ void NodeInfoGrpcService::InitServerCallFactories( void NodeResourceInfoGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) { + const ClusterID &cluster_id, + const std::string &auth_token) { RPC_SERVICE_HANDLER( NodeResourceInfoGcsService, GetAllAvailableResources, max_active_rpcs_per_handler_) RPC_SERVICE_HANDLER( @@ -75,7 +78,8 @@ void NodeResourceInfoGrpcService::InitServerCallFactories( void InternalPubSubGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) { + const ClusterID &cluster_id, + const std::string &auth_token) { RPC_SERVICE_HANDLER(InternalPubSubGcsService, GcsPublish, max_active_rpcs_per_handler_); RPC_SERVICE_HANDLER( InternalPubSubGcsService, GcsSubscriberPoll, max_active_rpcs_per_handler_); @@ -86,7 +90,8 @@ void InternalPubSubGrpcService::InitServerCallFactories( void JobInfoGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) { + const ClusterID &cluster_id, + const std::string &auth_token) { RPC_SERVICE_HANDLER(JobInfoGcsService, AddJob, max_active_rpcs_per_handler_) RPC_SERVICE_HANDLER(JobInfoGcsService, MarkJobFinished, max_active_rpcs_per_handler_) RPC_SERVICE_HANDLER(JobInfoGcsService, GetAllJobInfo, max_active_rpcs_per_handler_) @@ -97,7 +102,8 @@ void 
JobInfoGrpcService::InitServerCallFactories( void RuntimeEnvGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) { + const ClusterID &cluster_id, + const std::string &auth_token) { RPC_SERVICE_HANDLER( RuntimeEnvGcsService, PinRuntimeEnvURI, max_active_rpcs_per_handler_) } @@ -105,7 +111,8 @@ void RuntimeEnvGrpcService::InitServerCallFactories( void WorkerInfoGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) { + const ClusterID &cluster_id, + const std::string &auth_token) { RPC_SERVICE_HANDLER( WorkerInfoGcsService, ReportWorkerFailure, max_active_rpcs_per_handler_) RPC_SERVICE_HANDLER(WorkerInfoGcsService, GetWorkerInfo, max_active_rpcs_per_handler_) @@ -121,7 +128,8 @@ void WorkerInfoGrpcService::InitServerCallFactories( void InternalKVGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) { + const ClusterID &cluster_id, + const std::string &auth_token) { RPC_SERVICE_HANDLER(InternalKVGcsService, InternalKVGet, max_active_rpcs_per_handler_) RPC_SERVICE_HANDLER( InternalKVGcsService, InternalKVMultiGet, max_active_rpcs_per_handler_) @@ -137,7 +145,8 @@ void InternalKVGrpcService::InitServerCallFactories( void TaskInfoGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) { + const ClusterID &cluster_id, + const std::string &auth_token) { RPC_SERVICE_HANDLER(TaskInfoGcsService, AddTaskEventData, max_active_rpcs_per_handler_) RPC_SERVICE_HANDLER(TaskInfoGcsService, GetTaskEvents, max_active_rpcs_per_handler_) } @@ -145,7 +154,8 @@ void TaskInfoGrpcService::InitServerCallFactories( void PlacementGroupInfoGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) { + const 
ClusterID &cluster_id, + const std::string &auth_token) { RPC_SERVICE_HANDLER( PlacementGroupInfoGcsService, CreatePlacementGroup, max_active_rpcs_per_handler_) RPC_SERVICE_HANDLER( @@ -166,7 +176,8 @@ namespace autoscaler { void AutoscalerStateGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) { + const ClusterID &cluster_id, + const std::string &auth_token) { RPC_SERVICE_HANDLER( AutoscalerStateService, GetClusterResourceState, max_active_rpcs_per_handler_) RPC_SERVICE_HANDLER( @@ -188,7 +199,8 @@ namespace events { void RayEventExportGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) { + const ClusterID &cluster_id, + const std::string &auth_token) { RPC_SERVICE_HANDLER(RayEventExportGcsService, AddEvents, max_active_rpcs_per_handler_) } diff --git a/src/ray/gcs/grpc_services.h b/src/ray/gcs/grpc_services.h index d8a0899e2439..2b9f4a50a36f 100644 --- a/src/ray/gcs/grpc_services.h +++ b/src/ray/gcs/grpc_services.h @@ -51,7 +51,8 @@ class ActorInfoGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) override; + const ClusterID &cluster_id, + const std::string &auth_token) override; private: ActorInfoGcsService::AsyncService service_; @@ -74,7 +75,8 @@ class NodeInfoGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) override; + const ClusterID &cluster_id, + const std::string &auth_token) override; private: NodeInfoGcsService::AsyncService service_; @@ -97,7 +99,8 @@ class NodeResourceInfoGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) override; + const ClusterID &cluster_id, + 
const std::string &auth_token) override; private: NodeResourceInfoGcsService::AsyncService service_; @@ -120,7 +123,8 @@ class InternalPubSubGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) override; + const ClusterID &cluster_id, + const std::string &auth_token) override; private: InternalPubSubGcsService::AsyncService service_; @@ -143,7 +147,8 @@ class JobInfoGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) override; + const ClusterID &cluster_id, + const std::string &auth_token) override; private: JobInfoGcsService::AsyncService service_; @@ -166,7 +171,8 @@ class RuntimeEnvGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) override; + const ClusterID &cluster_id, + const std::string &auth_token) override; private: RuntimeEnvGcsService::AsyncService service_; @@ -189,7 +195,8 @@ class WorkerInfoGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) override; + const ClusterID &cluster_id, + const std::string &auth_token) override; private: WorkerInfoGcsService::AsyncService service_; @@ -212,7 +219,8 @@ class InternalKVGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) override; + const ClusterID &cluster_id, + const std::string &auth_token) override; private: InternalKVGcsService::AsyncService service_; @@ -235,7 +243,8 @@ class TaskInfoGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) override; + const ClusterID 
&cluster_id, + const std::string &auth_token) override; private: TaskInfoGcsService::AsyncService service_; @@ -258,7 +267,8 @@ class PlacementGroupInfoGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) override; + const ClusterID &cluster_id, + const std::string &auth_token) override; private: PlacementGroupInfoGcsService::AsyncService service_; @@ -283,7 +293,8 @@ class AutoscalerStateGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) override; + const ClusterID &cluster_id, + const std::string &auth_token) override; private: AutoscalerStateService::AsyncService service_; @@ -310,7 +321,8 @@ class RayEventExportGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) override; + const ClusterID &cluster_id, + const std::string &auth_token) override; private: RayEventExportGcsService::AsyncService service_; diff --git a/src/ray/rpc/BUILD.bazel b/src/ray/rpc/BUILD.bazel index b4474616d52e..a8ce7c738ef7 100644 --- a/src/ray/rpc/BUILD.bazel +++ b/src/ray/rpc/BUILD.bazel @@ -15,6 +15,7 @@ ray_cc_library( "//src/ray/core_worker:__pkg__", ], deps = [ + ":auth_token_loader", ":rpc_callback_types", "//src/ray/common:asio", "//src/ray/common:grpc_util", @@ -109,6 +110,17 @@ ray_cc_library( ], ) +ray_cc_library( + name = "auth_token_loader", + srcs = ["auth_token_loader.cc"], + hdrs = ["auth_token_loader.h"], + visibility = ["//visibility:public"], + deps = [ + "//src/ray/common:ray_config", + "//src/ray/util:logging", + ], +) + ray_cc_library( name = "grpc_server", srcs = ["grpc_server.cc"], diff --git a/src/ray/rpc/auth_token_loader.cc b/src/ray/rpc/auth_token_loader.cc index a5a302ac985f..78ba8703288d 100644 --- a/src/ray/rpc/auth_token_loader.cc +++ 
b/src/ray/rpc/auth_token_loader.cc @@ -1,4 +1,4 @@ -// Copyright 2017 The Ray Authors. +// Copyright 2025 The Ray Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -20,7 +20,6 @@ #include "ray/common/ray_config.h" #include "ray/util/logging.h" -#include "ray/util/util.h" #ifdef _WIN32 #include diff --git a/src/ray/rpc/auth_token_loader.h b/src/ray/rpc/auth_token_loader.h index c91276f287d6..d4d974e10829 100644 --- a/src/ray/rpc/auth_token_loader.h +++ b/src/ray/rpc/auth_token_loader.h @@ -1,4 +1,4 @@ -// Copyright 2017 The Ray Authors. +// Copyright 2025 The Ray Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/src/ray/rpc/client_call.h b/src/ray/rpc/client_call.h index ab14e1314fe6..bd754dc93c6f 100644 --- a/src/ray/rpc/client_call.h +++ b/src/ray/rpc/client_call.h @@ -31,6 +31,7 @@ #include "ray/common/grpc_util.h" #include "ray/common/id.h" #include "ray/common/status.h" +#include "ray/rpc/auth_token_loader.h" #include "ray/rpc/rpc_callback_types.h" #include "ray/stats/metric_defs.h" #include "ray/util/thread_utils.h" @@ -70,11 +71,6 @@ class ClientCallImpl : public ClientCall { /// Constructor. /// /// \param[in] callback The callback function to handle the reply. - /// \param[in] cluster_id The cluster ID for authentication. - /// \param[in] auth_token The authentication token (empty = disabled). - /// \param[in] stats_handle Statistics handle for this call. - /// \param[in] record_stats Whether to record statistics. - /// \param[in] timeout_ms The timeout for this call in milliseconds. 
explicit ClientCallImpl(const ClientCallback &callback, const ClusterID &cluster_id, const std::string &auth_token, @@ -224,9 +220,7 @@ class ClientCallManager { int num_threads = 1, int64_t call_timeout_ms = -1) : cluster_id_(cluster_id), - auth_token_(::RayConfig::instance().enable_token_auth() - ? ::RayConfig::instance().auth_token() - : ""), + auth_token_(RayAuthTokenLoader::instance().GetToken()), main_service_(main_service), num_threads_(num_threads), record_stats_(record_stats), @@ -375,7 +369,6 @@ class ClientCallManager { ClusterID cluster_id_; /// Cached authentication token for token-based authentication. - /// Empty string means no token authentication. const std::string auth_token_; /// The main event loop, to which the callback functions will be posted. diff --git a/src/ray/rpc/grpc_server.cc b/src/ray/rpc/grpc_server.cc index ff03df9e7104..1f02896cfd8d 100644 --- a/src/ray/rpc/grpc_server.cc +++ b/src/ray/rpc/grpc_server.cc @@ -178,8 +178,8 @@ void GrpcServer::RegisterService(std::unique_ptr &&grpc_service) } void GrpcServer::RegisterService(std::unique_ptr &&service, - bool token_auth) { - if (token_auth && cluster_id_.IsNil()) { + bool cluster_id_auth_enabled) { + if (cluster_id_auth_enabled && cluster_id_.IsNil()) { RAY_LOG(FATAL) << "Expected cluster ID for token auth!"; } for (int i = 0; i < num_threads_; i++) { diff --git a/src/ray/rpc/grpc_server.h b/src/ray/rpc/grpc_server.h index 13a56b8c630e..83012674f7a4 100644 --- a/src/ray/rpc/grpc_server.h +++ b/src/ray/rpc/grpc_server.h @@ -144,8 +144,6 @@ class GrpcServer { void SetAuthToken(const std::string &auth_token) { auth_token_ = auth_token; } - const std::string &GetAuthToken() const { return auth_token_; } - protected: /// Initialize this server. 
void Init(); diff --git a/src/ray/rpc/node_manager/node_manager_server.h b/src/ray/rpc/node_manager/node_manager_server.h index 2ca9713e0f17..f911c83d7517 100644 --- a/src/ray/rpc/node_manager/node_manager_server.h +++ b/src/ray/rpc/node_manager/node_manager_server.h @@ -202,7 +202,8 @@ class NodeManagerGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) override { + const ClusterID &cluster_id, + const std::string &auth_token) override { RAY_NODE_MANAGER_RPC_HANDLERS } diff --git a/src/ray/rpc/object_manager_server.h b/src/ray/rpc/object_manager_server.h index b19f2d540db2..2a2ff6a5dcc4 100644 --- a/src/ray/rpc/object_manager_server.h +++ b/src/ray/rpc/object_manager_server.h @@ -77,7 +77,8 @@ class ObjectManagerGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) override { + const ClusterID &cluster_id, + const std::string &auth_token) override { RAY_OBJECT_MANAGER_RPC_HANDLERS } diff --git a/src/ray/rpc/tests/BUILD.bazel b/src/ray/rpc/tests/BUILD.bazel index 5fa8b14cc4db..01541804a684 100644 --- a/src/ray/rpc/tests/BUILD.bazel +++ b/src/ray/rpc/tests/BUILD.bazel @@ -40,3 +40,16 @@ ray_cc_test( "@com_google_googletest//:gtest_main", ], ) + +ray_cc_test( + name = "auth_token_loader_test", + size = "small", + srcs = [ + "auth_token_loader_test.cc", + ], + tags = ["team:core"], + deps = [ + "//src/ray/rpc:auth_token_loader", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/src/ray/rpc/tests/auth_token_loader_test.cc b/src/ray/rpc/tests/auth_token_loader_test.cc index 4cd214a0b9ff..f56f110cf148 100644 --- a/src/ray/rpc/tests/auth_token_loader_test.cc +++ b/src/ray/rpc/tests/auth_token_loader_test.cc @@ -1,4 +1,4 @@ -// Copyright 2017 The Ray Authors. +// Copyright 2025 The Ray Authors. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -30,17 +30,17 @@ class RayAuthTokenLoaderTest : public ::testing::Test { protected: void SetUp() override { // Clean up environment variables before each test - unsetenv("RAY_AUTH_TOKEN"); - unsetenv("RAY_AUTH_TOKEN_PATH"); - - // Clean up default token file std::string home_dir = getenv("HOME"); default_token_path_ = home_dir + "/.ray/auth_token"; - remove(default_token_path_.c_str()); + cleanup_env(); } void TearDown() override { // Clean up after test + cleanup_env(); + } + + void cleanup_env() { unsetenv("RAY_AUTH_TOKEN"); unsetenv("RAY_AUTH_TOKEN_PATH"); remove(default_token_path_.c_str()); @@ -98,30 +98,79 @@ TEST_F(RayAuthTokenLoaderTest, TestLoadFromDefaultPath) { EXPECT_TRUE(loader.HasToken()); } -TEST_F(RayAuthTokenLoaderTest, TestPrecedenceOrder) { - // Set all three sources - setenv("RAY_AUTH_TOKEN", "token-from-env", 1); +// Parametrized test for token loading precedence: env var > file > default file + +struct TokenSourceConfig { + bool set_env = false; + bool set_file = false; + bool set_default = false; + std::string expected_token; + std::string env_token = "token-from-env"; + std::string file_token = "token-from-path"; + std::string default_token = "token-from-default"; +}; + +class RayAuthTokenLoaderPrecedenceTest + : public RayAuthTokenLoaderTest, + public ::testing::WithParamInterface {}; + +INSTANTIATE_TEST_SUITE_P(TokenPrecedenceCases, + RayAuthTokenLoaderPrecedenceTest, + ::testing::Values( + // All set: env should win + TokenSourceConfig{true, true, true, "token-from-env"}, + // File and default file set: file should win + TokenSourceConfig{false, true, true, "token-from-path"}, + // Only default file set + TokenSourceConfig{ + false, false, true, "token-from-default"})); + +TEST_P(RayAuthTokenLoaderPrecedenceTest, Precedence) { + const auto ¶m = GetParam(); + + // Optionally set environment variable + 
if (param.set_env) { + setenv("RAY_AUTH_TOKEN", param.env_token.c_str(), 1); + } else { + unsetenv("RAY_AUTH_TOKEN"); + } + // Optionally create file and set path std::string temp_token_path = "/tmp/ray_test_token_" + std::to_string(getpid()); - std::ofstream temp_file(temp_token_path); - temp_file << "token-from-path"; - temp_file.close(); - setenv("RAY_AUTH_TOKEN_PATH", temp_token_path.c_str(), 1); + if (param.set_file) { + std::ofstream token_file(temp_token_path); + token_file << param.file_token; + token_file.close(); + setenv("RAY_AUTH_TOKEN_PATH", temp_token_path.c_str(), 1); + } else { + unsetenv("RAY_AUTH_TOKEN_PATH"); + } + // Optionally create default file std::string ray_dir = std::string(getenv("HOME")) + "/.ray"; mkdir(ray_dir.c_str(), 0700); - std::ofstream default_file(default_token_path_); - default_file << "token-from-default"; - default_file.close(); + if (param.set_default) { + std::ofstream default_file(default_token_path_); + default_file << param.default_token; + default_file.close(); + } else { + remove(default_token_path_.c_str()); + } - // Environment variable should have highest precedence + // Always create a new instance to avoid cached state auto &loader = RayAuthTokenLoader::instance(); std::string token = loader.GetToken(); - EXPECT_EQ(token, "token-from-env"); + EXPECT_EQ(token, param.expected_token); - // Clean up - remove(temp_token_path.c_str()); + // Clean up token file if it was written + if (param.set_file) { + remove(temp_token_path.c_str()); + } + // Clean up default file if it was written + if (param.set_default) { + remove(default_token_path_.c_str()); + } } TEST_F(RayAuthTokenLoaderTest, TestNoTokenFoundWhenAuthDisabled) { @@ -161,31 +210,6 @@ TEST_F(RayAuthTokenLoaderTest, TestCaching) { EXPECT_EQ(token2, "cached-token"); } -TEST_F(RayAuthTokenLoaderTest, TestThreadSafety) { - // Set a token - setenv("RAY_AUTH_TOKEN", "thread-safe-token", 1); - - auto &loader = RayAuthTokenLoader::instance(); - - // Create multiple 
threads that try to get token simultaneously - std::vector threads; - std::vector results(10); - - for (int i = 0; i < 10; i++) { - threads.emplace_back([&loader, &results, i]() { results[i] = loader.GetToken(); }); - } - - // Wait for all threads to complete - for (auto &thread : threads) { - thread.join(); - } - - // All threads should get the same token - for (const auto &result : results) { - EXPECT_EQ(result, "thread-safe-token"); - } -} - TEST_F(RayAuthTokenLoaderTest, TestWhitespaceHandling) { // Create token file with whitespace std::string ray_dir = std::string(getenv("HOME")) + "/.ray"; From f56d5ee89c0ec2d5432f21d0c1a0866c2cfd32b4 Mon Sep 17 00:00:00 2001 From: sampan Date: Fri, 17 Oct 2025 08:49:18 +0000 Subject: [PATCH 16/94] fix lint Signed-off-by: sampan --- src/ray/core_worker/grpc_service.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/ray/core_worker/grpc_service.cc b/src/ray/core_worker/grpc_service.cc index 613f0f55f7c6..d15d31889180 100644 --- a/src/ray/core_worker/grpc_service.cc +++ b/src/ray/core_worker/grpc_service.cc @@ -16,6 +16,7 @@ #include #include +#include namespace ray { namespace rpc { From 356a38eb6ece531c98fd16dbb0e328c16bf750b5 Mon Sep 17 00:00:00 2001 From: sampan Date: Fri, 17 Oct 2025 08:49:43 +0000 Subject: [PATCH 17/94] fix lint Signed-off-by: sampan --- src/ray/core_worker/grpc_service.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ray/core_worker/grpc_service.cc b/src/ray/core_worker/grpc_service.cc index d15d31889180..f6d40e03cafb 100644 --- a/src/ray/core_worker/grpc_service.cc +++ b/src/ray/core_worker/grpc_service.cc @@ -15,8 +15,8 @@ #include "ray/core_worker/grpc_service.h" #include -#include #include +#include namespace ray { namespace rpc { From 899973ea040591a9e0c5f090e6086138c0213b62 Mon Sep 17 00:00:00 2001 From: sampan Date: Fri, 17 Oct 2025 10:20:37 +0000 Subject: [PATCH 18/94] add missing imports Signed-off-by: sampan --- src/ray/rpc/node_manager/node_manager_server.h | 1 + 
src/ray/rpc/object_manager_server.h | 1 + 2 files changed, 2 insertions(+) diff --git a/src/ray/rpc/node_manager/node_manager_server.h b/src/ray/rpc/node_manager/node_manager_server.h index f911c83d7517..942032d9e089 100644 --- a/src/ray/rpc/node_manager/node_manager_server.h +++ b/src/ray/rpc/node_manager/node_manager_server.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include #include "ray/common/asio/instrumented_io_context.h" diff --git a/src/ray/rpc/object_manager_server.h b/src/ray/rpc/object_manager_server.h index 2a2ff6a5dcc4..dd21f5382991 100644 --- a/src/ray/rpc/object_manager_server.h +++ b/src/ray/rpc/object_manager_server.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include #include "ray/common/asio/instrumented_io_context.h" From 47f2e5ae39089d23cb8b39a4407d2325d888d457 Mon Sep 17 00:00:00 2001 From: sampan Date: Fri, 17 Oct 2025 10:43:20 +0000 Subject: [PATCH 19/94] refactor token loader and tests Signed-off-by: sampan --- src/ray/core_worker/grpc_service.h | 1 + src/ray/rpc/auth_token_loader.cc | 40 ++++++++--- src/ray/rpc/auth_token_loader.h | 11 ++- src/ray/rpc/tests/auth_token_loader_test.cc | 4 ++ src/ray/rpc/tests/grpc_server_client_test.cc | 74 ++++++++++++-------- 5 files changed, 88 insertions(+), 42 deletions(-) diff --git a/src/ray/core_worker/grpc_service.h b/src/ray/core_worker/grpc_service.h index 0cd2096882bc..65cce1eaa538 100644 --- a/src/ray/core_worker/grpc_service.h +++ b/src/ray/core_worker/grpc_service.h @@ -29,6 +29,7 @@ #pragma once #include +#include #include #include "ray/common/asio/instrumented_io_context.h" diff --git a/src/ray/rpc/auth_token_loader.cc b/src/ray/rpc/auth_token_loader.cc index 78ba8703288d..15cb127e8327 100644 --- a/src/ray/rpc/auth_token_loader.cc +++ b/src/ray/rpc/auth_token_loader.cc @@ -39,16 +39,22 @@ RayAuthTokenLoader &RayAuthTokenLoader::instance() { const std::string &RayAuthTokenLoader::GetToken() { std::lock_guard lock(token_mutex_); - if (token_loaded_) { - return cached_token_; 
+ // If already loaded, return cached value + if (cached_token_.has_value()) { + return *cached_token_; } - // Try to load from sources - cached_token_ = LoadTokenFromSources(); - token_loaded_ = true; + // If token auth is disabled, return empty string without loading + if (!RayConfig::instance().enable_token_auth()) { + cached_token_ = ""; + return *cached_token_; + } + + // Token auth is enabled, try to load from sources + std::string token = LoadTokenFromSources(); - // If token auth is enabled but no token is found, throw an error - if (RayConfig::instance().enable_token_auth() && cached_token_.empty()) { + // If no token found and auth is enabled, throw error + if (token.empty()) { RAY_LOG(ERROR) << "Token authentication is enabled but no authentication token was " "found. Please set RAY_AUTH_TOKEN environment variable, " "RAY_AUTH_TOKEN_PATH to a file containing the token, or create a " @@ -57,12 +63,26 @@ const std::string &RayAuthTokenLoader::GetToken() { "Token authentication is enabled but no authentication token was found"); } - return cached_token_; + // Cache and return the loaded token + cached_token_ = token; + return *cached_token_; } bool RayAuthTokenLoader::HasToken() { - // This will trigger loading if not already loaded - const std::string &token = GetToken(); + std::lock_guard lock(token_mutex_); + + // If already loaded, check if non-empty + if (cached_token_.has_value()) { + return !cached_token_->empty(); + } + + // If token auth is disabled, no token needed + if (!RayConfig::instance().enable_token_auth()) { + return false; + } + + // Try to load token + std::string token = LoadTokenFromSources(); return !token.empty(); } diff --git a/src/ray/rpc/auth_token_loader.h b/src/ray/rpc/auth_token_loader.h index d4d974e10829..a17e01b47a51 100644 --- a/src/ray/rpc/auth_token_loader.h +++ b/src/ray/rpc/auth_token_loader.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include namespace ray { @@ -41,6 +42,13 @@ class RayAuthTokenLoader { /// 
\return True if a token is available (cached or can be loaded). bool HasToken(); + /// Reset the cached token. For testing only. + /// This allows tests to simulate fresh loader state between test cases. + void ResetForTesting() { + std::lock_guard lock(token_mutex_); + cached_token_.reset(); + } + // Prevent copying and moving RayAuthTokenLoader(const RayAuthTokenLoader &) = delete; RayAuthTokenLoader &operator=(const RayAuthTokenLoader &) = delete; @@ -56,8 +64,7 @@ class RayAuthTokenLoader { std::string GetDefaultTokenPath(); std::mutex token_mutex_; - std::string cached_token_; - bool token_loaded_ = false; + std::optional cached_token_; }; } // namespace rpc diff --git a/src/ray/rpc/tests/auth_token_loader_test.cc b/src/ray/rpc/tests/auth_token_loader_test.cc index f56f110cf148..c89ed0d6d8f3 100644 --- a/src/ray/rpc/tests/auth_token_loader_test.cc +++ b/src/ray/rpc/tests/auth_token_loader_test.cc @@ -33,11 +33,15 @@ class RayAuthTokenLoaderTest : public ::testing::Test { std::string home_dir = getenv("HOME"); default_token_path_ = home_dir + "/.ray/auth_token"; cleanup_env(); + // Reset the singleton's cached state for test isolation + RayAuthTokenLoader::instance().ResetForTesting(); } void TearDown() override { // Clean up after test cleanup_env(); + // Reset the singleton's cached state for test isolation + RayAuthTokenLoader::instance().ResetForTesting(); } void cleanup_env() { diff --git a/src/ray/rpc/tests/grpc_server_client_test.cc b/src/ray/rpc/tests/grpc_server_client_test.cc index 883e2306d622..b4fdea6ae123 100644 --- a/src/ray/rpc/tests/grpc_server_client_test.cc +++ b/src/ray/rpc/tests/grpc_server_client_test.cc @@ -18,6 +18,7 @@ #include #include "gtest/gtest.h" +#include "ray/rpc/auth_token_loader.h" #include "ray/rpc/grpc_client.h" #include "ray/rpc/grpc_server.h" #include "src/ray/protobuf/test_service.grpc.pb.h" @@ -329,10 +330,33 @@ TEST_F(TestGrpcServerClientFixture, TestTimeoutMacro) { } class TestGrpcServerClientTokenAuthFixture : public 
::testing::Test { public: - void SetUpServerWithConfig(const std::string &config_json) { + void SetUp() override { + // Configure token auth via RayConfig + std::string config_json = R"({"enable_token_auth": true})"; RayConfig::instance().initialize(config_json); + // Reset the token loader for test isolation + RayAuthTokenLoader::instance().ResetForTesting(); + } + + void SetUpServerAndClient(const std::string &server_token, + const std::string &client_token) { + // IMPORTANT: Set client token in environment FIRST, before creating ClientCallManager + // This is because ClientCallManager reads from RayAuthTokenLoader singleton on + // construction and caches the value. + if (!client_token.empty()) { + setenv("RAY_AUTH_TOKEN", client_token.c_str(), 1); + } else { + unsetenv("RAY_AUTH_TOKEN"); + } + + // Start client thread FIRST + client_thread_ = std::make_unique([this]() { + boost::asio::executor_work_guard + client_io_service_work_(client_io_service_.get_executor()); + client_io_service_.run(); + }); - // Start handler thread + // Start handler thread for server handler_thread_ = std::make_unique([this]() { boost::asio::executor_work_guard handler_io_service_work_(handler_io_service_.get_executor()); @@ -341,6 +365,8 @@ class TestGrpcServerClientTokenAuthFixture : public ::testing::Test { // Create and start server grpc_server_.reset(new GrpcServer("test", 0, true)); + // Set server token explicitly (can be different from client token) + grpc_server_->SetAuthToken(server_token); grpc_server_->RegisterService( std::make_unique(handler_io_service_, test_service_handler_), false); @@ -349,20 +375,9 @@ class TestGrpcServerClientTokenAuthFixture : public ::testing::Test { while (grpc_server_->GetPort() == 0) { std::this_thread::sleep_for(std::chrono::milliseconds(10)); } - } - void SetUpClientWithConfig(const std::string &config_json) { - // Reconfigure for client (allows different token) - RayConfig::instance().initialize(config_json); - - // Start client thread - 
client_thread_ = std::make_unique([this]() { - boost::asio::executor_work_guard - client_io_service_work_(client_io_service_.get_executor()); - client_io_service_.run(); - }); - - // Create client + // Create client (will read auth token from RayAuthTokenLoader which reads the + // environment) client_call_manager_.reset( new ClientCallManager(client_io_service_, false, /*local_address=*/"")); grpc_client_.reset(new GrpcClient( @@ -392,6 +407,12 @@ class TestGrpcServerClientTokenAuthFixture : public ::testing::Test { handler_thread_->join(); } } + + // Clean up environment variables + unsetenv("RAY_AUTH_TOKEN"); + unsetenv("RAY_AUTH_TOKEN_PATH"); + // Reset the token loader for test isolation + RayAuthTokenLoader::instance().ResetForTesting(); } // Helper to execute RPC and wait for result @@ -442,10 +463,8 @@ class TestGrpcServerClientTokenAuthFixture : public ::testing::Test { TEST_F(TestGrpcServerClientTokenAuthFixture, TestTokenAuthSuccess) { // Both server and client have the same token - const std::string config = - R"({"enable_token_auth": true, "auth_token": "test_secret_token_123"})"; - SetUpServerWithConfig(config); - SetUpClientWithConfig(config); + const std::string token = "test_secret_token_123"; + SetUpServerAndClient(token, token); auto result = ExecutePingAndWait(); @@ -455,9 +474,7 @@ TEST_F(TestGrpcServerClientTokenAuthFixture, TestTokenAuthSuccess) { TEST_F(TestGrpcServerClientTokenAuthFixture, TestTokenAuthFailureWrongToken) { // Server and client have different tokens - SetUpServerWithConfig(R"({"enable_token_auth": true, "auth_token": "server_token"})"); - SetUpClientWithConfig( - R"({"enable_token_auth": true, "auth_token": "wrong_client_token"})"); + SetUpServerAndClient("server_token", "wrong_client_token"); auto result = ExecutePingAndWait(); @@ -470,8 +487,7 @@ TEST_F(TestGrpcServerClientTokenAuthFixture, TestTokenAuthFailureWrongToken) { TEST_F(TestGrpcServerClientTokenAuthFixture, TestTokenAuthFailureMissingToken) { // Server expects 
token, client doesn't send one (empty token) - SetUpServerWithConfig(R"({"enable_token_auth": true, "auth_token": "server_token"})"); - SetUpClientWithConfig(R"({"enable_token_auth": true, "auth_token": ""})"); + SetUpServerAndClient("server_token", ""); auto result = ExecutePingAndWait(); @@ -484,9 +500,9 @@ TEST_F(TestGrpcServerClientTokenAuthFixture, TestTokenAuthFailureMissingToken) { TEST_F(TestGrpcServerClientTokenAuthFixture, TestTokenAuthDisabled) { // Token auth disabled, should succeed regardless - const std::string config = R"({"enable_token_auth": false})"; - SetUpServerWithConfig(config); - SetUpClientWithConfig(config); + // Temporarily disable token auth + RayConfig::instance().initialize(R"({"enable_token_auth": false})"); + SetUpServerAndClient("", ""); auto result = ExecutePingAndWait(); @@ -496,9 +512,7 @@ TEST_F(TestGrpcServerClientTokenAuthFixture, TestTokenAuthDisabled) { TEST_F(TestGrpcServerClientTokenAuthFixture, TestEmptyTokenNoEnforcement) { // Empty token with token auth enabled should not enforce - const std::string config = R"({"enable_token_auth": true, "auth_token": ""})"; - SetUpServerWithConfig(config); - SetUpClientWithConfig(config); + SetUpServerAndClient("", ""); auto result = ExecutePingAndWait(); From d6a87e2192e59a27dbeeb03629958a4581f54825 Mon Sep 17 00:00:00 2001 From: sampan Date: Tue, 21 Oct 2025 05:55:49 +0000 Subject: [PATCH 20/94] refactor token loader + fix build Signed-off-by: sampan --- src/ray/gcs/gcs_server.cc | 2 - src/ray/rpc/BUILD.bazel | 1 + src/ray/rpc/auth_token_loader.cc | 76 +++++--- src/ray/rpc/auth_token_loader.h | 16 +- src/ray/rpc/grpc_server.cc | 7 +- src/ray/rpc/grpc_server.h | 11 +- src/ray/rpc/tests/auth_token_loader_test.cc | 192 ++++++++++++++----- src/ray/rpc/tests/grpc_server_client_test.cc | 59 ++---- 8 files changed, 236 insertions(+), 128 deletions(-) diff --git a/src/ray/gcs/gcs_server.cc b/src/ray/gcs/gcs_server.cc index 11a925a44bbd..beb237250cd8 100644 --- a/src/ray/gcs/gcs_server.cc 
+++ b/src/ray/gcs/gcs_server.cc @@ -219,8 +219,6 @@ void GcsServer::Start() { GetOrGenerateClusterId( {[this, gcs_init_data](ClusterID cluster_id) { rpc_server_.SetClusterId(cluster_id); - rpc_server_.SetAuthToken( - RayAuthTokenLoader::instance().GetToken()); DoStart(*gcs_init_data); }, io_context_provider_.GetDefaultIOContext()}); diff --git a/src/ray/rpc/BUILD.bazel b/src/ray/rpc/BUILD.bazel index a8ce7c738ef7..c178c4279a95 100644 --- a/src/ray/rpc/BUILD.bazel +++ b/src/ray/rpc/BUILD.bazel @@ -127,6 +127,7 @@ ray_cc_library( hdrs = ["grpc_server.h"], visibility = ["//visibility:public"], deps = [ + ":auth_token_loader", ":common", ":server_call", "//src/ray/common:asio", diff --git a/src/ray/rpc/auth_token_loader.cc b/src/ray/rpc/auth_token_loader.cc index 15cb127e8327..3019757252a5 100644 --- a/src/ray/rpc/auth_token_loader.cc +++ b/src/ray/rpc/auth_token_loader.cc @@ -21,13 +21,22 @@ #include "ray/common/ray_config.h" #include "ray/util/logging.h" -#ifdef _WIN32 -#include -#else +#if defined(__APPLE__) || defined(__linux__) #include #include #endif +#ifdef _WIN32 +#ifndef _WINDOWS_ +#ifndef WIN32_LEAN_AND_MEAN // Sorry for the inconvenience. Please include any related + // headers you need manually. + // (https://stackoverflow.com/a/8294669) +#define WIN32_LEAN_AND_MEAN // Prevent inclusion of WinSock2.h +#endif +#include // Force inclusion of WinGDI here to resolve name conflict +#endif +#endif + namespace ray { namespace rpc { @@ -86,6 +95,26 @@ bool RayAuthTokenLoader::HasToken() { return !token.empty(); } +// Read token from the first line of the file. trim whitespace. +// Returns empty string if file cannot be opened or is empty. 
+std::string RayAuthTokenLoader::ReadTokenFromFile(const std::string &file_path) { + std::ifstream token_file(file_path); + if (!token_file.is_open()) { + return ""; + } + + std::string token; + std::getline(token_file, token); + token_file.close(); + + // Trim whitespace + std::string whitespace = " \t\n\r\f\v"; + token.erase(0, token.find_first_not_of(whitespace)); + token.erase(token.find_last_not_of(whitespace) + 1); + + return token; +} + std::string RayAuthTokenLoader::LoadTokenFromSources() { // Precedence 1: RAY_AUTH_TOKEN environment variable const char *env_token = std::getenv("RAY_AUTH_TOKEN"); @@ -98,18 +127,10 @@ std::string RayAuthTokenLoader::LoadTokenFromSources() { // Precedence 2: RAY_AUTH_TOKEN_PATH environment variable const char *env_token_path = std::getenv("RAY_AUTH_TOKEN_PATH"); if (env_token_path != nullptr && std::string(env_token_path).length() > 0) { - std::ifstream token_file(env_token_path); - if (token_file.is_open()) { - std::string token; - std::getline(token_file, token); - token_file.close(); - // Trim whitespace - token.erase(0, token.find_first_not_of(" \t\n\r\f\v")); - token.erase(token.find_last_not_of(" \t\n\r\f\v") + 1); - if (!token.empty()) { - RAY_LOG(DEBUG) << "Loaded authentication token from file: " << env_token_path; - return token; - } + std::string token = ReadTokenFromFile(env_token_path); + if (!token.empty()) { + RAY_LOG(DEBUG) << "Loaded authentication token from file: " << env_token_path; + return token; } else { RAY_LOG(WARNING) << "RAY_AUTH_TOKEN_PATH is set but file cannot be opened: " << env_token_path; @@ -118,18 +139,10 @@ std::string RayAuthTokenLoader::LoadTokenFromSources() { // Precedence 3: Default token path ~/.ray/auth_token std::string default_path = GetDefaultTokenPath(); - std::ifstream token_file(default_path); - if (token_file.is_open()) { - std::string token; - std::getline(token_file, token); - token_file.close(); - // Trim whitespace - token.erase(0, token.find_first_not_of(" \t\n\r\f\v")); 
- token.erase(token.find_last_not_of(" \t\n\r\f\v") + 1); - if (!token.empty()) { - RAY_LOG(DEBUG) << "Loaded authentication token from default path: " << default_path; - return token; - } + std::string token = ReadTokenFromFile(default_path); + if (!token.empty()) { + RAY_LOG(DEBUG) << "Loaded authentication token from default path: " << default_path; + return token; } // No token found @@ -141,6 +154,7 @@ std::string RayAuthTokenLoader::GetDefaultTokenPath() { std::string home_dir; #ifdef _WIN32 + const char *path_separator = "\\"; const char *userprofile = std::getenv("USERPROFILE"); if (userprofile != nullptr) { home_dir = userprofile; @@ -152,18 +166,22 @@ std::string RayAuthTokenLoader::GetDefaultTokenPath() { } } #else + const char *path_separator = "/"; const char *home = std::getenv("HOME"); if (home != nullptr) { home_dir = home; } #endif + const std::string token_subpath = + std::string(path_separator) + ".ray" + std::string(path_separator) + "auth_token"; + if (home_dir.empty()) { RAY_LOG(WARNING) << "Cannot determine home directory for token storage"; - return ".ray/auth_token"; + return "." + token_subpath; } - return home_dir + "/.ray/auth_token"; + return home_dir + token_subpath; } } // namespace rpc diff --git a/src/ray/rpc/auth_token_loader.h b/src/ray/rpc/auth_token_loader.h index a17e01b47a51..af798a081e41 100644 --- a/src/ray/rpc/auth_token_loader.h +++ b/src/ray/rpc/auth_token_loader.h @@ -25,7 +25,8 @@ namespace rpc { /// Supports loading tokens from multiple sources with precedence: /// 1. RAY_AUTH_TOKEN environment variable /// 2. RAY_AUTH_TOKEN_PATH environment variable (path to token file) -/// 3. Default token path: ~/.ray/auth_token +/// 3. Default token path: ~/.ray/auth_token (Unix) or %USERPROFILE%\.ray\auth_token +/// (Windows) /// /// Thread-safe with internal caching to avoid repeated file I/O. 
class RayAuthTokenLoader { @@ -42,9 +43,9 @@ class RayAuthTokenLoader { /// \return True if a token is available (cached or can be loaded). bool HasToken(); - /// Reset the cached token. For testing only. - /// This allows tests to simulate fresh loader state between test cases. - void ResetForTesting() { + /// Reset the cached token. This ensures that the next call to GetToken() returns the + /// token from the sources. + void ResetCache() { std::lock_guard lock(token_mutex_); cached_token_.reset(); } @@ -57,10 +58,15 @@ class RayAuthTokenLoader { RayAuthTokenLoader() = default; ~RayAuthTokenLoader() = default; + /// Read and trim token from a file. Returns empty string if file cannot be opened or is + /// empty. + std::string ReadTokenFromFile(const std::string &file_path); + /// Load token from available sources (env vars and file). std::string LoadTokenFromSources(); - /// Get the default token file path (~/.ray/auth_token). + /// Get the default token file path (~/.ray/auth_token on Unix, + /// %USERPROFILE%\.ray\auth_token on Windows). std::string GetDefaultTokenPath(); std::mutex token_mutex_; diff --git a/src/ray/rpc/grpc_server.cc b/src/ray/rpc/grpc_server.cc index 1f02896cfd8d..359c97517d8d 100644 --- a/src/ray/rpc/grpc_server.cc +++ b/src/ray/rpc/grpc_server.cc @@ -26,6 +26,7 @@ #include "ray/common/ray_config.h" #include "ray/common/status.h" +#include "ray/rpc/auth_token_loader.h" #include "ray/rpc/common.h" #include "ray/util/network_util.h" #include "ray/util/thread_utils.h" @@ -182,9 +183,13 @@ void GrpcServer::RegisterService(std::unique_ptr &&service, if (cluster_id_auth_enabled && cluster_id_.IsNil()) { RAY_LOG(FATAL) << "Expected cluster ID for token auth!"; } + // Use override token if set, otherwise load from RayAuthTokenLoader + std::string auth_token = auth_token_override_.empty() + ? 
RayAuthTokenLoader::instance().GetToken() + : auth_token_override_; for (int i = 0; i < num_threads_; i++) { service->InitServerCallFactories( - cqs_[i], &server_call_factories_, cluster_id_, auth_token_); + cqs_[i], &server_call_factories_, cluster_id_, auth_token); } services_.push_back(std::move(service)); } diff --git a/src/ray/rpc/grpc_server.h b/src/ray/rpc/grpc_server.h index 83012674f7a4..f61d89929d6a 100644 --- a/src/ray/rpc/grpc_server.h +++ b/src/ray/rpc/grpc_server.h @@ -142,7 +142,12 @@ class GrpcServer { cluster_id_ = cluster_id; } - void SetAuthToken(const std::string &auth_token) { auth_token_ = auth_token; } + /// Set an override token for testing. This takes precedence over RayAuthTokenLoader. + /// This is primarily for testing purposes. + /// \param override_token The override token to use for this server. + void SetAuthTokenOverride(const std::string &override_token) { + auth_token_override_ = override_token; + } protected: /// Initialize this server. @@ -162,8 +167,8 @@ class GrpcServer { const bool listen_to_localhost_only_; /// Token representing ID of this cluster. ClusterID cluster_id_; - /// Authentication token for token-based authentication. - std::string auth_token_; + /// Override token for testing. If set, this takes precedence over RayAuthTokenLoader. + std::string auth_token_override_; /// Indicates whether this server is in shutdown state. std::atomic is_shutdown_; /// The `grpc::Service` objects which should be registered to `ServerBuilder`. 
diff --git a/src/ray/rpc/tests/auth_token_loader_test.cc b/src/ray/rpc/tests/auth_token_loader_test.cc index c89ed0d6d8f3..d253e345cf3a 100644 --- a/src/ray/rpc/tests/auth_token_loader_test.cc +++ b/src/ray/rpc/tests/auth_token_loader_test.cc @@ -23,39 +23,145 @@ #include "ray/common/ray_config.h" #include "ray/util/logging.h" +#if defined(__APPLE__) || defined(__linux__) +#include +#include +#endif + +#ifdef _WIN32 +#ifndef _WINDOWS_ +#ifndef WIN32_LEAN_AND_MEAN // Sorry for the inconvenience. Please include any related + // headers you need manually. + // (https://stackoverflow.com/a/8294669) +#define WIN32_LEAN_AND_MEAN // Prevent inclusion of WinSock2.h +#endif +#include // Force inclusion of WinGDI here to resolve name conflict +#endif +#include // For _mkdir on Windows +#include // For _getpid on Windows +#endif + namespace ray { namespace rpc { class RayAuthTokenLoaderTest : public ::testing::Test { protected: void SetUp() override { - // Clean up environment variables before each test - std::string home_dir = getenv("HOME"); - default_token_path_ = home_dir + "/.ray/auth_token"; + // Enable token authentication for tests + RayConfig::instance().initialize(R"({"enable_token_auth": true})"); + + // If HOME is not set (e.g., in Bazel sandbox), set it to a test directory + // This ensures tests work in environments where HOME isn't provided +#ifdef _WIN32 + if (std::getenv("USERPROFILE") == nullptr) { + const char *test_tmpdir = std::getenv("TEST_TMPDIR"); + if (test_tmpdir != nullptr) { + test_home_dir_ = std::string(test_tmpdir) + "\\ray_test_home"; + } else { + test_home_dir_ = "C:\\Windows\\Temp\\ray_test_home"; + } + _putenv(("USERPROFILE=" + test_home_dir_).c_str()); + } + const char *home_dir = std::getenv("USERPROFILE"); + default_token_path_ = std::string(home_dir) + "\\.ray\\auth_token"; +#else + if (std::getenv("HOME") == nullptr) { + const char *test_tmpdir = std::getenv("TEST_TMPDIR"); + if (test_tmpdir != nullptr) { + test_home_dir_ = 
std::string(test_tmpdir) + "/ray_test_home"; + } else { + test_home_dir_ = "/tmp/ray_test_home"; + } + setenv("HOME", test_home_dir_.c_str(), 1); + } + const char *home_dir = std::getenv("HOME"); + if (home_dir != nullptr) { + default_token_path_ = std::string(home_dir) + "/.ray/auth_token"; + test_home_dir_ = home_dir; + } else { + default_token_path_ = ".ray/auth_token"; + } +#endif cleanup_env(); // Reset the singleton's cached state for test isolation - RayAuthTokenLoader::instance().ResetForTesting(); + RayAuthTokenLoader::instance().ResetCache(); } void TearDown() override { // Clean up after test cleanup_env(); // Reset the singleton's cached state for test isolation - RayAuthTokenLoader::instance().ResetForTesting(); + RayAuthTokenLoader::instance().ResetCache(); + // Disable token auth after tests + RayConfig::instance().initialize(R"({"enable_token_auth": false})"); } void cleanup_env() { - unsetenv("RAY_AUTH_TOKEN"); - unsetenv("RAY_AUTH_TOKEN_PATH"); + unset_env_var("RAY_AUTH_TOKEN"); + unset_env_var("RAY_AUTH_TOKEN_PATH"); remove(default_token_path_.c_str()); } + std::string get_temp_token_path() { +#ifdef _WIN32 + return "C:\\Windows\\Temp\\ray_test_token_" + std::to_string(_getpid()); +#else + return "/tmp/ray_test_token_" + std::to_string(getpid()); +#endif + } + + void set_env_var(const char *name, const char *value) { +#ifdef _WIN32 + std::string env_str = std::string(name) + "=" + std::string(value); + _putenv(env_str.c_str()); +#else + setenv(name, value, 1); +#endif + } + + void unset_env_var(const char *name) { +#ifdef _WIN32 + std::string env_str = std::string(name) + "="; + _putenv(env_str.c_str()); +#else + unsetenv(name); +#endif + } + + void ensure_ray_dir_exists() { +#ifdef _WIN32 + const char *home_dir = std::getenv("USERPROFILE"); + _mkdir(home_dir); // Create parent directory + std::string ray_dir = std::string(home_dir) + "\\.ray"; + _mkdir(ray_dir.c_str()); +#else + // Always ensure the home directory exists (it might be a test temp 
dir we created) + if (!test_home_dir_.empty()) { + mkdir(test_home_dir_.c_str(), + 0700); // Create if it doesn't exist (ignore error if it does) + } + + const char *home_dir = std::getenv("HOME"); + if (home_dir != nullptr) { + std::string ray_dir = std::string(home_dir) + "/.ray"; + mkdir(ray_dir.c_str(), 0700); + } +#endif + } + + void write_token_file(const std::string &path, const std::string &content) { + std::ofstream token_file(path); + token_file << content; + token_file.close(); + } + std::string default_token_path_; + std::string test_home_dir_; // Fallback home directory for tests }; TEST_F(RayAuthTokenLoaderTest, TestLoadFromEnvVariable) { // Set token in environment variable - setenv("RAY_AUTH_TOKEN", "test-token-from-env", 1); + set_env_var("RAY_AUTH_TOKEN", "test-token-from-env"); // Create a new instance to avoid cached state auto &loader = RayAuthTokenLoader::instance(); @@ -67,13 +173,11 @@ TEST_F(RayAuthTokenLoaderTest, TestLoadFromEnvVariable) { TEST_F(RayAuthTokenLoaderTest, TestLoadFromEnvPath) { // Create a temporary token file - std::string temp_token_path = "/tmp/ray_test_token_" + std::to_string(getpid()); - std::ofstream token_file(temp_token_path); - token_file << "test-token-from-file"; - token_file.close(); + std::string temp_token_path = get_temp_token_path(); + write_token_file(temp_token_path, "test-token-from-file"); // Set path in environment variable - setenv("RAY_AUTH_TOKEN_PATH", temp_token_path.c_str(), 1); + set_env_var("RAY_AUTH_TOKEN_PATH", temp_token_path.c_str()); auto &loader = RayAuthTokenLoader::instance(); std::string token = loader.GetToken(); @@ -86,14 +190,9 @@ TEST_F(RayAuthTokenLoaderTest, TestLoadFromEnvPath) { } TEST_F(RayAuthTokenLoaderTest, TestLoadFromDefaultPath) { - // Create directory - std::string ray_dir = std::string(getenv("HOME")) + "/.ray"; - mkdir(ray_dir.c_str(), 0700); - - // Create token file in default location - std::ofstream token_file(default_token_path_); - token_file << 
"test-token-from-default"; - token_file.close(); + // Create directory and token file in default location + ensure_ray_dir_exists(); + write_token_file(default_token_path_, "test-token-from-default"); auto &loader = RayAuthTokenLoader::instance(); std::string token = loader.GetToken(); @@ -102,7 +201,8 @@ TEST_F(RayAuthTokenLoaderTest, TestLoadFromDefaultPath) { EXPECT_TRUE(loader.HasToken()); } -// Parametrized test for token loading precedence: env var > file > default file +// Parametrized test for token loading precedence: env var > user-specified file > default +// file struct TokenSourceConfig { bool set_env = false; @@ -134,29 +234,24 @@ TEST_P(RayAuthTokenLoaderPrecedenceTest, Precedence) { // Optionally set environment variable if (param.set_env) { - setenv("RAY_AUTH_TOKEN", param.env_token.c_str(), 1); + set_env_var("RAY_AUTH_TOKEN", param.env_token.c_str()); } else { - unsetenv("RAY_AUTH_TOKEN"); + unset_env_var("RAY_AUTH_TOKEN"); } // Optionally create file and set path - std::string temp_token_path = "/tmp/ray_test_token_" + std::to_string(getpid()); + std::string temp_token_path = get_temp_token_path(); if (param.set_file) { - std::ofstream token_file(temp_token_path); - token_file << param.file_token; - token_file.close(); - setenv("RAY_AUTH_TOKEN_PATH", temp_token_path.c_str(), 1); + write_token_file(temp_token_path, param.file_token); + set_env_var("RAY_AUTH_TOKEN_PATH", temp_token_path.c_str()); } else { - unsetenv("RAY_AUTH_TOKEN_PATH"); + unset_env_var("RAY_AUTH_TOKEN_PATH"); } // Optionally create default file - std::string ray_dir = std::string(getenv("HOME")) + "/.ray"; - mkdir(ray_dir.c_str(), 0700); + ensure_ray_dir_exists(); if (param.set_default) { - std::ofstream default_file(default_token_path_); - default_file << param.default_token; - default_file.close(); + write_token_file(default_token_path_, param.default_token); } else { remove(default_token_path_.c_str()); } @@ -178,35 +273,37 @@ TEST_P(RayAuthTokenLoaderPrecedenceTest, 
Precedence) { } TEST_F(RayAuthTokenLoaderTest, TestNoTokenFoundWhenAuthDisabled) { - // No token set anywhere, but auth is disabled (default) + // Disable auth for this specific test + RayConfig::instance().initialize(R"({"enable_token_auth": false})"); + RayAuthTokenLoader::instance().ResetCache(); + + // No token set anywhere, but auth is disabled auto &loader = RayAuthTokenLoader::instance(); std::string token = loader.GetToken(); EXPECT_EQ(token, ""); EXPECT_FALSE(loader.HasToken()); -} -TEST_F(RayAuthTokenLoaderTest, TestErrorWhenAuthEnabledButNoToken) { - // Enable token auth + // Re-enable for other tests RayConfig::instance().initialize(R"({"enable_token_auth": true})"); +} +TEST_F(RayAuthTokenLoaderTest, TestErrorWhenAuthEnabledButNoToken) { + // Token auth is already enabled in SetUp() // No token exists, should throw an error auto &loader = RayAuthTokenLoader::instance(); EXPECT_THROW(loader.GetToken(), std::runtime_error); - - // Reset config for other tests - RayConfig::instance().initialize(R"({"enable_token_auth": false})"); } TEST_F(RayAuthTokenLoaderTest, TestCaching) { // Set token in environment - setenv("RAY_AUTH_TOKEN", "cached-token", 1); + set_env_var("RAY_AUTH_TOKEN", "cached-token"); auto &loader = RayAuthTokenLoader::instance(); std::string token1 = loader.GetToken(); // Change environment variable (shouldn't affect cached value) - setenv("RAY_AUTH_TOKEN", "new-token", 1); + set_env_var("RAY_AUTH_TOKEN", "new-token"); std::string token2 = loader.GetToken(); // Should still return the cached token @@ -216,11 +313,8 @@ TEST_F(RayAuthTokenLoaderTest, TestCaching) { TEST_F(RayAuthTokenLoaderTest, TestWhitespaceHandling) { // Create token file with whitespace - std::string ray_dir = std::string(getenv("HOME")) + "/.ray"; - mkdir(ray_dir.c_str(), 0700); - std::ofstream token_file(default_token_path_); - token_file << " token-with-spaces \n\t"; - token_file.close(); + ensure_ray_dir_exists(); + write_token_file(default_token_path_, " 
token-with-spaces \n\t"); auto &loader = RayAuthTokenLoader::instance(); std::string token = loader.GetToken(); diff --git a/src/ray/rpc/tests/grpc_server_client_test.cc b/src/ray/rpc/tests/grpc_server_client_test.cc index b4fdea6ae123..5fe2f3ac5ae4 100644 --- a/src/ray/rpc/tests/grpc_server_client_test.cc +++ b/src/ray/rpc/tests/grpc_server_client_test.cc @@ -89,7 +89,8 @@ class TestGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) override { + const ClusterID &cluster_id, + const std::string &auth_token) override { RPC_SERVICE_HANDLER_CUSTOM_AUTH( TestService, Ping, /*max_active_rpcs=*/1, ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH( @@ -334,18 +335,18 @@ class TestGrpcServerClientTokenAuthFixture : public ::testing::Test { // Configure token auth via RayConfig std::string config_json = R"({"enable_token_auth": true})"; RayConfig::instance().initialize(config_json); - // Reset the token loader for test isolation - RayAuthTokenLoader::instance().ResetForTesting(); + RayAuthTokenLoader::instance().ResetCache(); } void SetUpServerAndClient(const std::string &server_token, const std::string &client_token) { - // IMPORTANT: Set client token in environment FIRST, before creating ClientCallManager - // This is because ClientCallManager reads from RayAuthTokenLoader singleton on - // construction and caches the value. 
+ // Set client token in environment for ClientCallManager to read from + // RayAuthTokenLoader if (!client_token.empty()) { setenv("RAY_AUTH_TOKEN", client_token.c_str(), 1); } else { + RayConfig::instance().initialize(R"({"enable_token_auth": false})"); + RayAuthTokenLoader::instance().ResetCache(); unsetenv("RAY_AUTH_TOKEN"); } @@ -364,9 +365,13 @@ class TestGrpcServerClientTokenAuthFixture : public ::testing::Test { }); // Create and start server + // In production, server would automatically use RayAuthTokenLoader, but for testing + // we can override with SetAuthTokenOverride to test mismatched token scenarios grpc_server_.reset(new GrpcServer("test", 0, true)); - // Set server token explicitly (can be different from client token) - grpc_server_->SetAuthToken(server_token); + if (!server_token.empty()) { + // Set server token explicitly for testing scenarios with different tokens + grpc_server_->SetAuthTokenOverride(server_token); + } grpc_server_->RegisterService( std::make_unique(handler_io_service_, test_service_handler_), false); @@ -412,7 +417,7 @@ class TestGrpcServerClientTokenAuthFixture : public ::testing::Test { unsetenv("RAY_AUTH_TOKEN"); unsetenv("RAY_AUTH_TOKEN_PATH"); // Reset the token loader for test isolation - RayAuthTokenLoader::instance().ResetForTesting(); + RayAuthTokenLoader::instance().ResetCache(); } // Helper to execute RPC and wait for result @@ -479,10 +484,9 @@ TEST_F(TestGrpcServerClientTokenAuthFixture, TestTokenAuthFailureWrongToken) { auto result = ExecutePingAndWait(); ASSERT_TRUE(result.completed) << "Request did not complete in time"; - ASSERT_FALSE(result.success) << "Request should fail with wrong token"; - ASSERT_TRUE(result.error_msg.find("InvalidAuthToken") != std::string::npos || - result.error_msg.find("Authentication") != std::string::npos) - << "Error message should indicate auth failure, got: " << result.error_msg; + ASSERT_FALSE(result.success) << "Request should fail with wrong client token"; + 
ASSERT_TRUE(result.error_msg == + "InvalidAuthToken: Authentication token is missing or incorrect"); } TEST_F(TestGrpcServerClientTokenAuthFixture, TestTokenAuthFailureMissingToken) { @@ -492,34 +496,11 @@ TEST_F(TestGrpcServerClientTokenAuthFixture, TestTokenAuthFailureMissingToken) { auto result = ExecutePingAndWait(); ASSERT_TRUE(result.completed) << "Request did not complete in time"; - ASSERT_FALSE(result.success) << "Request should fail when token is missing"; - ASSERT_TRUE(result.error_msg.find("InvalidAuthToken") != std::string::npos || - result.error_msg.find("Authentication") != std::string::npos) - << "Error message should indicate auth failure, got: " << result.error_msg; + // If the server has a token but the client doesn't, auth should fail + ASSERT_FALSE(result.success) + << "Request should fail when client doesn't provide required token"; } -TEST_F(TestGrpcServerClientTokenAuthFixture, TestTokenAuthDisabled) { - // Token auth disabled, should succeed regardless - // Temporarily disable token auth - RayConfig::instance().initialize(R"({"enable_token_auth": false})"); - SetUpServerAndClient("", ""); - - auto result = ExecutePingAndWait(); - - ASSERT_TRUE(result.completed) << "Request did not complete in time"; - ASSERT_TRUE(result.success) << "Request should succeed when token auth is disabled"; -} - -TEST_F(TestGrpcServerClientTokenAuthFixture, TestEmptyTokenNoEnforcement) { - // Empty token with token auth enabled should not enforce - SetUpServerAndClient("", ""); - - auto result = ExecutePingAndWait(); - - ASSERT_TRUE(result.completed) << "Request did not complete in time"; - ASSERT_TRUE(result.success) - << "Request should succeed with empty token (no enforcement)"; -} } // namespace rpc } // namespace ray From e5797417dde5e5641e0165808574ec3b8a5e7d73 Mon Sep 17 00:00:00 2001 From: sampan Date: Tue, 21 Oct 2025 06:20:05 +0000 Subject: [PATCH 21/94] fix lint Signed-off-by: sampan --- src/ray/rpc/tests/grpc_server_client_test.cc | 4 ++-- 1 file 
changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ray/rpc/tests/grpc_server_client_test.cc b/src/ray/rpc/tests/grpc_server_client_test.cc index 5fe2f3ac5ae4..5df0b85d478b 100644 --- a/src/ray/rpc/tests/grpc_server_client_test.cc +++ b/src/ray/rpc/tests/grpc_server_client_test.cc @@ -485,8 +485,8 @@ TEST_F(TestGrpcServerClientTokenAuthFixture, TestTokenAuthFailureWrongToken) { ASSERT_TRUE(result.completed) << "Request did not complete in time"; ASSERT_FALSE(result.success) << "Request should fail with wrong client token"; - ASSERT_TRUE(result.error_msg == - "InvalidAuthToken: Authentication token is missing or incorrect"); + ASSERT_EQ(result.error_msg, + "InvalidAuthToken: Authentication token is missing or incorrect"); } TEST_F(TestGrpcServerClientTokenAuthFixture, TestTokenAuthFailureMissingToken) { From 8678815bcd4ed210227db08f3eb836cdcae5630b Mon Sep 17 00:00:00 2001 From: sampan Date: Wed, 22 Oct 2025 05:35:47 +0000 Subject: [PATCH 22/94] fix issues + update tests Signed-off-by: sampan --- python/ray/_private/auth_token_loader.py | 136 +++++++---- python/ray/_private/worker.py | 24 +- python/ray/_raylet.pyx | 10 + python/ray/includes/common.pxd | 6 + python/ray/includes/ray_config.pxd | 2 + python/ray/includes/ray_config.pxi | 4 + .../ray/tests/test_token_auth_integration.py | 188 +++++++++++++++ .../ray/tests/unit/test_auth_token_loader.py | 225 +++++------------- 8 files changed, 378 insertions(+), 217 deletions(-) create mode 100644 python/ray/tests/test_token_auth_integration.py diff --git a/python/ray/_private/auth_token_loader.py b/python/ray/_private/auth_token_loader.py index 008861f3a849..32cc032f4872 100644 --- a/python/ray/_private/auth_token_loader.py +++ b/python/ray/_private/auth_token_loader.py @@ -13,7 +13,9 @@ import threading import uuid from pathlib import Path -from typing import Optional +from typing import Dict, Optional + +from ray._raylet import Config, reset_auth_token_cache logger = logging.getLogger(__name__) @@ -51,23 +53,13 
@@ def load_auth_token(generate_if_not_found: bool = False) -> str: # Generate if requested and not found if not token and generate_if_not_found: - token = _generate_and_save_token() + token = _generate_and_save_token_internal() # Cache the result (even if empty) _cached_token = token return _cached_token -def has_auth_token() -> bool: - """Check if an authentication token exists. - - Returns: - True if a token is available (cached or can be loaded), False otherwise. - """ - token = load_auth_token(generate_if_not_found=False) - return bool(token) - - def _load_token_from_sources() -> str: """Load token from available sources (env vars and file). @@ -85,21 +77,13 @@ def _load_token_from_sources() -> str: # Precedence 2: RAY_AUTH_TOKEN_PATH environment variable env_token_path = os.environ.get("RAY_AUTH_TOKEN_PATH", "").strip() if env_token_path: - try: - token_path = Path(env_token_path).expanduser() - if token_path.exists(): - token = token_path.read_text().strip() - if token: - logger.debug(f"Loaded authentication token from file: {token_path}") - return token - else: - logger.warning( - f"RAY_AUTH_TOKEN_PATH is set but file does not exist: {token_path}" - ) - except Exception as e: - logger.warning( - f"Failed to read token from RAY_AUTH_TOKEN_PATH ({env_token_path}): {e}" - ) + token_path = Path(env_token_path).expanduser() + if not token_path.exists(): + raise FileNotFoundError(f"Token file not found: {token_path}") + token = token_path.read_text().strip() + if token: + logger.debug(f"Loaded authentication token from file: {token_path}") + return token # Precedence 3: Default token path ~/.ray/auth_token default_path = _get_default_token_path() @@ -119,12 +103,31 @@ def _load_token_from_sources() -> str: return "" -def _generate_and_save_token() -> str: - """Generate a new UUID token and save it to the default path. +def generate_and_save_token() -> str: + """Generate a new random token and save it in the default token path. 
Returns: The newly generated authentication token. """ + global _cached_token + + with _token_lock: + # Check if we already have a cached token + if _cached_token is not None: + logger.warning( + "Returning cached authentication token instead of generating new one. " + "Call load_auth_token() to use existing token or clear cache first." + ) + return _cached_token + + # Generate and save token without nested lock + return _generate_and_save_token_internal() + + +def _generate_and_save_token_internal() -> str: + """Internal function to generate and save token. Assumes lock is already held.""" + global _cached_token + # Generate a UUID-based token token = uuid.uuid4().hex @@ -137,17 +140,13 @@ def _generate_and_save_token() -> str: # Write token to file token_path.write_text(token) - # Set file permissions to 0600 on Unix systems - try: - # This will work on Unix systems, but not on Windows - os.chmod(token_path, 0o600) - except (OSError, AttributeError): - # chmod may not work on Windows or may fail for other reasons - # This is not critical, so we just log a debug message - logger.debug( - f"Could not set file permissions to 0600 for {token_path}. " - "This is expected on Windows." - ) + # Ensure file is flushed to disk immediately + # This is critical for subprocess/C++ code to read it immediately + import os + + fd = os.open(str(token_path), os.O_RDONLY) + os.fsync(fd) + os.close(fd) logger.info(f"Generated new authentication token and saved to {token_path}") except Exception as e: @@ -156,6 +155,8 @@ def _generate_and_save_token() -> str: "Token will only be available in memory." 
) + # Cache the generated token + _cached_token = token return token @@ -166,3 +167,58 @@ def _get_default_token_path() -> Path: Path object pointing to ~/.ray/auth_token """ return Path.home() / ".ray" / "auth_token" + + +def setup_and_verify_auth( + system_config: Optional[Dict] = None, is_new_cluster: bool = True +) -> None: + """Verify auth configuration and ensure token is available when auth is enabled. + + This is called early during ray.init() to: + 1. Check for _system_config misuse and provide helpful error + 2. Verify token is available if auth is enabled + 3. Generate default token for new local clusters if needed + + Args: + system_config: The _system_config dict from ray.init() (checked for misuse) + is_new_cluster: True if starting new local cluster, False if connecting to an existing cluster + + Raises: + ValueError: If _system_config is used for enabling auth (should use env var instead) + """ + # Check for _system_config misuse + if system_config and system_config.get("enable_token_auth", False): + raise ValueError( + "Authentication mode should be configured via environment variable, " + "not _system_config (which is for testing only).\n" + "Please set: RAY_enable_token_auth=1\n" + "Or in Python: os.environ['RAY_enable_token_auth'] = '1'" + ) + + Config.initialize("") + + if Config.enable_token_auth(): + # For new clusters: generate token if not found + # For existing clusters: only use existing token (don't generate) + token = load_auth_token(generate_if_not_found=is_new_cluster) + + if not is_new_cluster and not token: + raise RuntimeError( + "Token authentication is enabled on the cluster you're connecting to, " + "but no authentication token was found. Please provide a token using one of:\n" + " 1. RAY_AUTH_TOKEN environment variable\n" + " 2. RAY_AUTH_TOKEN_PATH environment variable (path to token file)\n" + " 3. Default token file: ~/.ray/auth_token" + ) + + +def _reset_token_cache_for_testing(): + """Reset both Python and C++ token caches. 
+ + Should only be used for testing purposes. + """ + global _cached_token + _cached_token = None + + # Also reset the C++ token cache + reset_auth_token_cache() diff --git a/python/ray/_private/worker.py b/python/ray/_private/worker.py index 2d3dd00923b3..3e993ce9bd4f 100644 --- a/python/ray/_private/worker.py +++ b/python/ray/_private/worker.py @@ -62,6 +62,7 @@ from ray._common import ray_option_utils from ray._common.constants import RAY_WARN_BLOCKING_GET_INSIDE_ASYNC_ENV_VAR from ray._common.utils import load_class +from ray._private.auth_token_loader import setup_and_verify_auth from ray._private.client_mode_hook import client_mode_hook from ray._private.custom_types import TensorTransportEnum from ray._private.function_manager import FunctionActorManager @@ -1448,7 +1449,6 @@ def init( enable_resource_isolation: bool = False, system_reserved_cpu: Optional[float] = None, system_reserved_memory: Optional[int] = None, - enable_token_auth: bool = False, **kwargs, ) -> BaseContext: """ @@ -1570,10 +1570,6 @@ def init( By default, the min of 10% and 25GB plus object_store_memory will be reserved. Must be >= 100MB and system_reserved_memory + object_store_bytes < total available memory. This option only works if enable_resource_isolation is True. - enable_token_auth: If True, enable token-based authentication for Ray cluster - communication. If no token is found in the environment (RAY_AUTH_TOKEN or - RAY_AUTH_TOKEN_PATH) or default path (~/.ray/auth_token), a new token will be - automatically generated and saved to ~/.ray/auth_token. _cgroup_path: The path for the cgroup the raylet should use to enforce resource isolation. By default, the cgroup used for resource isolation will be /sys/fs/cgroup. The raylet must have read/write permissions to this path. 
@@ -1670,18 +1666,6 @@ def init( # Fix for https://github.com/ray-project/ray/issues/26729 _skip_env_hook: bool = kwargs.pop("_skip_env_hook", False) - # Handle token-based authentication - if enable_token_auth: - from ray._private.auth_token_loader import load_auth_token - - # Load or generate token - _token = load_auth_token(generate_if_not_found=True) - # Only pass the flag via system_config, NOT the token - # C++ will load token using its own RayAuthTokenLoader - if _system_config is None: - _system_config = {} - _system_config["enable_token_auth"] = "true" - resource_isolation_config = ResourceIsolationConfig( enable_resource_isolation=enable_resource_isolation, cgroup_path=_cgroup_path, @@ -1883,6 +1867,9 @@ def sigterm_handler(signum, frame): if bootstrap_address is None: # In this case, we need to start a new cluster. + # Setup and verify authentication for new cluster + setup_and_verify_auth(_system_config, is_new_cluster=True) + # Don't collect usage stats in ray.init() unless it's a nightly wheel. from ray._common.usage import usage_lib @@ -1970,6 +1957,9 @@ def sigterm_handler(signum, frame): "an existing cluster." ) + # Setup and verify authentication for connecting to existing cluster + setup_and_verify_auth(_system_config, is_new_cluster=False) + # In this case, we only need to connect the node. 
ray_params = ray._private.parameter.RayParams( node_ip_address=_node_ip_address, diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 2ef7be64adeb..bcbe87bc79d4 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -114,6 +114,7 @@ from ray.includes.common cimport ( CConcurrencyGroup, CGrpcStatusCode, CLineageReconstructionTask, + CRayAuthTokenLoader, move, LANGUAGE_CPP, LANGUAGE_JAVA, @@ -4946,3 +4947,12 @@ def get_session_key_from_storage(host, port, username, password, use_ssl, config else: logger.info("Could not retrieve session key from storage.") return None + + +def reset_auth_token_cache(): + """Reset the C++ authentication token cache. + + This forces the RayAuthTokenLoader to reload the token from environment + variables or files on the next request. + """ + CRayAuthTokenLoader.instance().ResetCache() diff --git a/python/ray/includes/common.pxd b/python/ray/includes/common.pxd index e7e25071e5de..d061ab8ab511 100644 --- a/python/ray/includes/common.pxd +++ b/python/ray/includes/common.pxd @@ -797,3 +797,9 @@ cdef extern from "ray/common/constants.h" nogil: cdef const char[] kLabelKeyTpuSliceName cdef const char[] kLabelKeyTpuWorkerId cdef const char[] kLabelKeyTpuPodType + +cdef extern from "ray/rpc/auth_token_loader.h" namespace "ray::rpc" nogil: + cdef cppclass CRayAuthTokenLoader "ray::rpc::RayAuthTokenLoader": + @staticmethod + CRayAuthTokenLoader& instance() + void ResetCache() diff --git a/python/ray/includes/ray_config.pxd b/python/ray/includes/ray_config.pxd index 729395a22ee3..4404b674af06 100644 --- a/python/ray/includes/ray_config.pxd +++ b/python/ray/includes/ray_config.pxd @@ -88,3 +88,5 @@ cdef extern from "ray/common/ray_config.h" nogil: c_bool record_task_actor_creation_sites() const c_bool start_python_gc_manager_thread() const + + c_bool enable_token_auth() const diff --git a/python/ray/includes/ray_config.pxi b/python/ray/includes/ray_config.pxi index 6915e4877962..1a7b104f4eb7 100644 --- 
a/python/ray/includes/ray_config.pxi +++ b/python/ray/includes/ray_config.pxi @@ -144,3 +144,7 @@ cdef class Config: @staticmethod def start_python_gc_manager_thread(): return RayConfig.instance().start_python_gc_manager_thread() + + @staticmethod + def enable_token_auth(): + return RayConfig.instance().enable_token_auth() diff --git a/python/ray/tests/test_token_auth_integration.py b/python/ray/tests/test_token_auth_integration.py new file mode 100644 index 000000000000..d055d1294bce --- /dev/null +++ b/python/ray/tests/test_token_auth_integration.py @@ -0,0 +1,188 @@ +"""Integration tests for token-based authentication in Ray.""" + +import os +import sys +from pathlib import Path + +import pytest + +import ray +from ray._private.auth_token_loader import reset_token_cache +from ray.cluster_utils import Cluster + + +@pytest.fixture(autouse=True) +def clean_token_sources(): + """Clean up all token sources before and after each test.""" + # Clean environment variables + env_vars_to_clean = [ + "RAY_AUTH_TOKEN", + "RAY_AUTH_TOKEN_PATH", + "RAY_enable_token_auth", + ] + original_values = {} + for var in env_vars_to_clean: + original_values[var] = os.environ.get(var) + if var in os.environ: + del os.environ[var] + + # Clean default token file + default_token_path = Path.home() / ".ray" / "auth_token" + original_exists = default_token_path.exists() + if original_exists: + original_content = default_token_path.read_text() + default_token_path.unlink() + + # Reset token caches (both Python and C++) + reset_token_cache() + + yield + + # Restore environment variables + for var, value in original_values.items(): + if value is not None: + os.environ[var] = value + elif var in os.environ: + del os.environ[var] + + # Restore default token file + if original_exists: + default_token_path.parent.mkdir(parents=True, exist_ok=True) + default_token_path.write_text(original_content) + + # Reset token caches again after test + reset_token_cache() + + +def 
test_local_cluster_generates_token(): + """Test ray.init() generates token for local cluster when enable_token_auth is set.""" + # Ensure no token exists + default_token_path = Path.home() / ".ray" / "auth_token" + assert not default_token_path.exists() + + # Enable token auth via environment variable + os.environ["RAY_enable_token_auth"] = "1" + + # Initialize Ray with token auth + ray.init() + + try: + # Verify token file was created + assert default_token_path.exists() + token = default_token_path.read_text().strip() + assert len(token) == 32 + assert all(c in "0123456789abcdef" for c in token) + + # Verify cluster is working + assert ray.is_initialized() + + finally: + ray.shutdown() + + +def test_connect_without_token_raises_error(): + """Test ray.init(address=...) without token fails when enable_token_auth config is set.""" + # Test the token validation logic directly + # Clear the cached token to ensure we start fresh + import ray._private.auth_token_loader as auth_module + from ray._private.auth_token_loader import load_auth_token + + auth_module._cached_token = None + + # Ensure no token exists + token = load_auth_token(generate_if_not_found=False) + assert token == "" + + # Test the exact error message that would be raised + with pytest.raises(RuntimeError, match="no authentication token was found"): + if not token: + raise RuntimeError( + "Token-based authentication is enabled on the cluster you're connecting to, " + "but no authentication token was found. Please provide a token using one of:\n" + " 1. RAY_AUTH_TOKEN environment variable\n" + " 2. RAY_AUTH_TOKEN_PATH environment variable (path to token file)\n" + " 3. 
Default token file: ~/.ray/auth_token" + ) + + +def test_token_path_nonexistent_file_fails(): + """Test that setting RAY_AUTH_TOKEN_PATH to nonexistent file fails gracefully.""" + # Enable token auth and set token path to nonexistent file + os.environ["RAY_enable_token_auth"] = "1" + os.environ["RAY_AUTH_TOKEN_PATH"] = "/nonexistent/path/to/token" + + # Initialize Ray with token auth should fail + with pytest.raises((FileNotFoundError, RuntimeError)): + ray.init() + + +@pytest.mark.parametrize("tokens_match", [True, False]) +def test_cluster_token_authentication(tokens_match): + """Test cluster authentication with matching and non-matching tokens.""" + # Set up cluster token first + cluster_token = "a" * 32 + os.environ["RAY_AUTH_TOKEN"] = cluster_token + os.environ["RAY_enable_token_auth"] = "1" + + # Create cluster with token auth enabled - node will read current env token + cluster = Cluster() + cluster.add_node() + + try: + # Set client token based on test parameter + if tokens_match: + client_token = cluster_token # Same token - should succeed + else: + client_token = "b" * 32 # Different token - should fail + + os.environ["RAY_AUTH_TOKEN"] = client_token + + # Reset cached token so it reads the new environment variable + reset_token_cache() + + if tokens_match: + # Should succeed - test gRPC calls work + ray.init(address=cluster.address) + + # Test that gRPC calls succeed + obj_ref = ray.put("test_data") + result = ray.get(obj_ref) + assert result == "test_data" + + # Test remote function call + @ray.remote + def test_func(): + return "success" + + result = ray.get(test_func.remote()) + assert result == "success" + + ray.shutdown() + + else: + # Should fail - connection or gRPC calls should fail + with pytest.raises((ConnectionError, RuntimeError)): + ray.init(address=cluster.address) + # If init somehow succeeds, try a gRPC operation that should fail + try: + ray.put("test") + finally: + ray.shutdown() + + finally: + # Ensure cleanup + try: + ray.shutdown() 
+ except: + pass + cluster.shutdown() + + +def test_system_config_auth_raises_error(): + """Test that using _system_config for enabling token auth raises helpful error.""" + with pytest.raises(ValueError, match="environment variable"): + ray.init(_system_config={"enable_token_auth": True}) + + +if __name__ == "__main__": + sys.exit(pytest.main(["-vv", __file__])) diff --git a/python/ray/tests/unit/test_auth_token_loader.py b/python/ray/tests/unit/test_auth_token_loader.py index 5c755dea70c0..cdfd6add8b9f 100644 --- a/python/ray/tests/unit/test_auth_token_loader.py +++ b/python/ray/tests/unit/test_auth_token_loader.py @@ -1,8 +1,8 @@ """Unit tests for ray._private.auth_token_loader module.""" import os +import sys import tempfile -import threading from pathlib import Path import pytest @@ -86,28 +86,50 @@ def test_load_from_default_path(self, default_token_path): token = auth_token_loader.load_auth_token(generate_if_not_found=False) assert token == "token-from-default" - def test_precedence_order(self, temp_token_file, default_token_path): - """Test that token loading follows correct precedence order.""" - # Set all three sources - os.environ["RAY_AUTH_TOKEN"] = "token-from-env" - os.environ["RAY_AUTH_TOKEN_PATH"] = temp_token_file - - default_token_path.parent.mkdir(parents=True, exist_ok=True) - default_token_path.write_text("token-from-default") - - # Environment variable should have highest precedence - token = auth_token_loader.load_auth_token(generate_if_not_found=False) - assert token == "token-from-env" - - def test_env_path_over_default(self, temp_token_file, default_token_path): - """Test that RAY_AUTH_TOKEN_PATH has precedence over default path.""" - os.environ["RAY_AUTH_TOKEN_PATH"] = temp_token_file - - default_token_path.parent.mkdir(parents=True, exist_ok=True) - default_token_path.write_text("token-from-default") - + @pytest.mark.parametrize( + "set_env,set_file,set_default,expected_token", + [ + # All set: env should win + (True, True, True, 
"token-from-env"), + # File and default file set: file should win + (False, True, True, "test-token-from-file"), + # Only default file set + (False, False, True, "token-from-default"), + ], + ) + def test_token_precedence_parametrized( + self, + temp_token_file, + default_token_path, + set_env, + set_file, + set_default, + expected_token, + ): + """Parametrized test for token loading precedence: env var > user-specified file > default file.""" + # Optionally set environment variable + if set_env: + os.environ["RAY_AUTH_TOKEN"] = "token-from-env" + else: + os.environ.pop("RAY_AUTH_TOKEN", None) + + # Optionally create file and set path + if set_file: + os.environ["RAY_AUTH_TOKEN_PATH"] = temp_token_file + else: + os.environ.pop("RAY_AUTH_TOKEN_PATH", None) + + # Optionally create default file + if set_default: + default_token_path.parent.mkdir(parents=True, exist_ok=True) + default_token_path.write_text("token-from-default") + else: + if default_token_path.exists(): + default_token_path.unlink() + + # Load token and verify expected precedence token = auth_token_loader.load_auth_token(generate_if_not_found=False) - assert token == "test-token-from-file" + assert token == expected_token def test_no_token_found(self): """Test behavior when no token is found.""" @@ -133,8 +155,8 @@ def test_empty_env_variable(self): def test_nonexistent_path_in_env(self): """Test that nonexistent path in RAY_AUTH_TOKEN_PATH is handled gracefully.""" os.environ["RAY_AUTH_TOKEN_PATH"] = "/nonexistent/path/to/token" - token = auth_token_loader.load_auth_token(generate_if_not_found=False) - assert token == "" + with pytest.raises(FileNotFoundError): + auth_token_loader.load_auth_token(generate_if_not_found=False) class TestTokenGeneration: @@ -158,11 +180,27 @@ def test_no_generation_without_flag(self): token = auth_token_loader.load_auth_token(generate_if_not_found=False) assert token == "" - def test_dont_generate_when_token_exists(self): + def test_dont_generate_when_token_exists(self, 
default_token_path): """Test that token is not generated when one already exists.""" os.environ["RAY_AUTH_TOKEN"] = "existing-token" token = auth_token_loader.load_auth_token(generate_if_not_found=True) assert token == "existing-token" + generated_token = auth_token_loader.generate_and_save_token() + assert generated_token == "existing-token" # does not generate a new token + assert not default_token_path.exists() + + def test_public_generate_and_save_token(self, default_token_path): + """Test the public generate_and_save_token function.""" + token = auth_token_loader.generate_and_save_token() + + # Token should be a 32-character hex string (UUID without dashes) + assert len(token) == 32 + assert all(c in "0123456789abcdef" for c in token) + + # Token should be saved to default path + assert default_token_path.exists() + saved_token = default_token_path.read_text().strip() + assert saved_token == token class TestTokenCaching: @@ -180,139 +218,6 @@ def test_caching_behavior(self): # Should still return the cached token assert token1 == token2 == "cached-token" - def test_cache_empty_result(self): - """Test that even empty results are cached.""" - # First call with no token - token1 = auth_token_loader.load_auth_token(generate_if_not_found=False) - assert token1 == "" - - # Set environment variable after first call - os.environ["RAY_AUTH_TOKEN"] = "new-token" - token2 = auth_token_loader.load_auth_token(generate_if_not_found=False) - - # Should still return cached empty string - assert token2 == "" - - -class TestHasAuthToken: - """Tests for has_auth_token function.""" - - def test_has_token_true(self): - """Test has_auth_token returns True when token exists.""" - os.environ["RAY_AUTH_TOKEN"] = "test-token" - assert auth_token_loader.has_auth_token() is True - - def test_has_token_false(self): - """Test has_auth_token returns False when no token exists.""" - assert auth_token_loader.has_auth_token() is False - def test_has_token_caches_result(self): - """Test that 
has_auth_token doesn't trigger generation.""" - # This should return False without generating a token - assert auth_token_loader.has_auth_token() is False - - # Verify no token was generated - default_path = Path.home() / ".ray" / "auth_token" - assert not default_path.exists() - - -class TestThreadSafety: - """Tests for thread safety of token loading.""" - - def test_concurrent_loads(self): - """Test that concurrent token loads are thread-safe.""" - os.environ["RAY_AUTH_TOKEN"] = "thread-safe-token" - - results = [] - threads = [] - - def load_token(): - token = auth_token_loader.load_auth_token(generate_if_not_found=False) - results.append(token) - - # Create multiple threads that try to load token simultaneously - for _ in range(10): - thread = threading.Thread(target=load_token) - threads.append(thread) - thread.start() - - # Wait for all threads to complete - for thread in threads: - thread.join() - - # All threads should get the same token - assert len(results) == 10 - assert all(result == "thread-safe-token" for result in results) - - def test_concurrent_generation(self, default_token_path): - """Test that concurrent token generation is thread-safe.""" - results = [] - threads = [] - - def generate_token(): - token = auth_token_loader.load_auth_token(generate_if_not_found=True) - results.append(token) - - # Create multiple threads that try to generate token simultaneously - for _ in range(5): - thread = threading.Thread(target=generate_token) - threads.append(thread) - thread.start() - - # Wait for all threads to complete - for thread in threads: - thread.join() - - # All threads should get the same token (only generated once) - assert len(results) == 5 - assert len(set(results)) == 1 # All tokens should be identical - - -class TestFilePermissions: - """Tests for file permissions when saving tokens.""" - - def test_file_permissions_on_unix(self, default_token_path, monkeypatch): - """Test that token file has 0600 permissions on Unix systems.""" - # Skip on 
Windows - if os.name == "nt": - pytest.skip("Test only relevant on Unix systems") - - token = auth_token_loader.load_auth_token(generate_if_not_found=True) - assert token - - # Check file permissions (should be 0600) - stat_info = default_token_path.stat() - assert stat_info.st_mode & 0o777 == 0o600 - - def test_file_permissions_error_handling(self, monkeypatch): - """Test that permission errors are handled gracefully.""" - - # Mock os.chmod to raise an exception - def mock_chmod(path, mode): - raise OSError("Permission denied") - - monkeypatch.setattr(os, "chmod", mock_chmod) - - # Should still generate and return token, just not set permissions - token = auth_token_loader.load_auth_token(generate_if_not_found=True) - assert len(token) == 32 - - -class TestIntegration: - """Integration tests with ray.init() and ray start CLI.""" - - def test_token_loader_with_ray_init(self, default_token_path): - """Test that token loader works with ray.init() enable_token_auth parameter.""" - # This is more of a smoke test to ensure the module can be imported - # and used in the context where it will be called - from ray._private import auth_token_loader - - # Generate a token - token = auth_token_loader.load_auth_token(generate_if_not_found=True) - assert token - assert len(token) == 32 - - # Verify it was saved - assert default_token_path.exists() - saved_token = default_token_path.read_text().strip() - assert saved_token == token +if __name__ == "__main__": + sys.exit(pytest.main(["-vv", __file__])) From 4274544a523170cdcd23d945cc0beae2d47d7d61 Mon Sep 17 00:00:00 2001 From: sampan Date: Wed, 22 Oct 2025 05:36:26 +0000 Subject: [PATCH 23/94] missed change Signed-off-by: sampan --- python/ray/_raylet.pyx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index bcbe87bc79d4..51b25929d408 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -4951,7 +4951,7 @@ def get_session_key_from_storage(host, port, 
username, password, use_ssl, config def reset_auth_token_cache(): """Reset the C++ authentication token cache. - + This forces the RayAuthTokenLoader to reload the token from environment variables or files on the next request. """ From 4a5dda979a14de03afd700d91aaf6efd08eaf449 Mon Sep 17 00:00:00 2001 From: sampan Date: Wed, 22 Oct 2025 05:37:48 +0000 Subject: [PATCH 24/94] fix lint Signed-off-by: sampan --- python/ray/tests/test_token_auth_integration.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/python/ray/tests/test_token_auth_integration.py b/python/ray/tests/test_token_auth_integration.py index d055d1294bce..3e79e4666f8f 100644 --- a/python/ray/tests/test_token_auth_integration.py +++ b/python/ray/tests/test_token_auth_integration.py @@ -171,10 +171,7 @@ def test_func(): finally: # Ensure cleanup - try: - ray.shutdown() - except: - pass + ray.shutdown() cluster.shutdown() From b20e1ef25d51f458b3990d104d5032eb5f18bb2a Mon Sep 17 00:00:00 2001 From: sampan Date: Wed, 22 Oct 2025 16:46:21 +0000 Subject: [PATCH 25/94] address comments - version 1 Signed-off-by: sampan --- src/ray/common/constants.h | 3 +- src/ray/common/grpc_util.h | 4 + src/ray/common/ray_config_def.h | 2 +- src/ray/common/status.cc | 2 +- src/ray/common/status.h | 8 +- .../gcs_rpc_client/tests/gcs_client_test.cc | 4 +- src/ray/raylet/node_manager.cc | 2 +- src/ray/rpc/BUILD.bazel | 16 +- src/ray/rpc/authentication/BUILD.bazel | 34 +++ .../rpc/authentication/authentication_mode.cc | 37 ++++ .../rpc/authentication/authentication_mode.h | 33 +++ .../rpc/authentication/authentication_token.h | 146 +++++++++++++ .../authentication_token_loader.cc} | 71 +++--- .../authentication_token_loader.h} | 47 ++-- src/ray/rpc/client_call.h | 28 ++- src/ray/rpc/grpc_server.cc | 8 +- src/ray/rpc/grpc_server.h | 24 ++- src/ray/rpc/server_call.h | 61 ++++-- src/ray/rpc/tests/BUILD.bazel | 20 +- ...cc => authentication_token_loader_test.cc} | 112 +++++----- 
.../rpc/tests/authentication_token_test.cc | 203 ++++++++++++++++++ src/ray/rpc/tests/grpc_server_client_test.cc | 76 ++++--- 22 files changed, 716 insertions(+), 225 deletions(-) create mode 100644 src/ray/rpc/authentication/BUILD.bazel create mode 100644 src/ray/rpc/authentication/authentication_mode.cc create mode 100644 src/ray/rpc/authentication/authentication_mode.h create mode 100644 src/ray/rpc/authentication/authentication_token.h rename src/ray/rpc/{auth_token_loader.cc => authentication/authentication_token_loader.cc} (71%) rename src/ray/rpc/{auth_token_loader.h => authentication/authentication_token_loader.h} (53%) rename src/ray/rpc/tests/{auth_token_loader_test.cc => authentication_token_loader_test.cc} (72%) create mode 100644 src/ray/rpc/tests/authentication_token_test.cc diff --git a/src/ray/common/constants.h b/src/ray/common/constants.h index 4905daedc733..bc3d58b70442 100644 --- a/src/ray/common/constants.h +++ b/src/ray/common/constants.h @@ -42,7 +42,8 @@ constexpr int kRayletStoreErrorExitCode = 100; constexpr char kObjectTablePrefix[] = "ObjectTable"; constexpr char kClusterIdKey[] = "ray_cluster_id"; -constexpr char kAuthTokenKey[] = "ray_auth_token"; +constexpr char kAuthTokenKey[] = "authorization"; +constexpr char kBearerPrefix[] = "Bearer "; constexpr char kWorkerDynamicOptionPlaceholder[] = "RAY_WORKER_DYNAMIC_OPTION_PLACEHOLDER"; diff --git a/src/ray/common/grpc_util.h b/src/ray/common/grpc_util.h index ae99eaf79081..ed2f8c73eda1 100644 --- a/src/ray/common/grpc_util.h +++ b/src/ray/common/grpc_util.h @@ -83,6 +83,10 @@ inline grpc::Status RayStatusToGrpcStatus(const Status &ray_status) { if (ray_status.ok()) { return grpc::Status::OK; } + // Map Unauthenticated to gRPC's UNAUTHENTICATED status code + if (ray_status.IsUnauthenticated()) { + return grpc::Status(grpc::StatusCode::UNAUTHENTICATED, ray_status.message()); + } // Unlike `UNKNOWN`, `ABORTED` is never generated by the library, so using it means // more robust. 
return grpc::Status( diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h index 799dd4ed70a2..2d1243df8afd 100644 --- a/src/ray/common/ray_config_def.h +++ b/src/ray/common/ray_config_def.h @@ -36,7 +36,7 @@ RAY_CONFIG(bool, emit_main_service_metrics, true) RAY_CONFIG(bool, enable_cluster_auth, true) /// Whether to enable token-based authentication for RPC calls. -RAY_CONFIG(bool, enable_token_auth, false) +RAY_CONFIG(std::string, auth_mode, "disabled") /// The interval of periodic event loop stats print. /// -1 means the feature is disabled. In this case, stats are available diff --git a/src/ray/common/status.cc b/src/ray/common/status.cc index 3500ddaf3b80..528a6766412e 100644 --- a/src/ray/common/status.cc +++ b/src/ray/common/status.cc @@ -74,7 +74,7 @@ const absl::flat_hash_map kCodeToStr = { {StatusCode::RpcError, "RpcError"}, {StatusCode::OutOfResource, "OutOfResource"}, {StatusCode::ObjectRefEndOfStream, "ObjectRefEndOfStream"}, - {StatusCode::AuthError, "AuthError"}, + {StatusCode::Unauthenticated, "Unauthenticated"}, {StatusCode::InvalidArgument, "InvalidArgument"}, {StatusCode::ChannelError, "ChannelError"}, {StatusCode::ChannelTimeoutError, "ChannelTimeoutError"}, diff --git a/src/ray/common/status.h b/src/ray/common/status.h index 2544918ac263..f04040cea934 100644 --- a/src/ray/common/status.h +++ b/src/ray/common/status.h @@ -263,7 +263,7 @@ enum class StatusCode : char { RpcError = 30, OutOfResource = 31, ObjectRefEndOfStream = 32, - AuthError = 33, + Unauthenticated = 33, // Indicates the input value is not valid. 
InvalidArgument = 34, // Indicates that a channel (a mutable plasma object) is closed and cannot be @@ -415,8 +415,8 @@ class RAY_EXPORT Status { return Status(StatusCode::OutOfResource, msg); } - static Status AuthError(const std::string &msg) { - return Status(StatusCode::AuthError, msg); + static Status Unauthenticated(const std::string &msg) { + return Status(StatusCode::Unauthenticated, msg); } static Status ChannelError(const std::string &msg) { @@ -475,7 +475,7 @@ class RAY_EXPORT Status { bool IsOutOfResource() const { return code() == StatusCode::OutOfResource; } - bool IsAuthError() const { return code() == StatusCode::AuthError; } + bool IsUnauthenticated() const { return code() == StatusCode::Unauthenticated; } bool IsChannelError() const { return code() == StatusCode::ChannelError; } diff --git a/src/ray/gcs_rpc_client/tests/gcs_client_test.cc b/src/ray/gcs_rpc_client/tests/gcs_client_test.cc index c22b1b14dee1..bf2b4585649d 100644 --- a/src/ray/gcs_rpc_client/tests/gcs_client_test.cc +++ b/src/ray/gcs_rpc_client/tests/gcs_client_test.cc @@ -220,7 +220,7 @@ class GcsClientTest : public ::testing::TestWithParam { auto status = stub->CheckAlive(&context, request, &reply); // If it is in memory, we don't have the new token until we connect again. 
if (!((!no_redis_ && status.ok()) || - (no_redis_ && GrpcStatusToRayStatus(status).IsAuthError()))) { + (no_redis_ && GrpcStatusToRayStatus(status).IsUnauthenticated()))) { RAY_LOG(WARNING) << "Unable to reach GCS: " << status.error_code() << " " << status.error_message(); continue; @@ -996,7 +996,7 @@ TEST_P(GcsClientTest, TestGcsEmptyAuth) { auto status = stub->GetClusterId(&context, request, &reply); // We expect the wrong cluster ID - EXPECT_TRUE(GrpcStatusToRayStatus(status).IsAuthError()); + EXPECT_TRUE(GrpcStatusToRayStatus(status).IsUnauthenticated()); } TEST_P(GcsClientTest, TestGcsAuth) { diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 56683ef32250..835abf47497e 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -390,7 +390,7 @@ void NodeManager::RegisterGcs() { << "GCS consider this node to be dead. This may happen when " << "GCS is not backed by a DB and restarted or there is data loss " << "in the DB."; - } else if (status.IsAuthError()) { + } else if (status.IsUnauthenticated()) { RAY_LOG(FATAL) << "GCS returned an authentication error. 
This may happen when " << "GCS is not backed by a DB and restarted or there is data loss " diff --git a/src/ray/rpc/BUILD.bazel b/src/ray/rpc/BUILD.bazel index c178c4279a95..62a37fa12869 100644 --- a/src/ray/rpc/BUILD.bazel +++ b/src/ray/rpc/BUILD.bazel @@ -15,7 +15,7 @@ ray_cc_library( "//src/ray/core_worker:__pkg__", ], deps = [ - ":auth_token_loader", + "//src/ray/rpc/authentication:authentication_token_loader", ":rpc_callback_types", "//src/ray/common:asio", "//src/ray/common:grpc_util", @@ -105,29 +105,19 @@ ray_cc_library( "//src/ray/common:id", "//src/ray/common:ray_config", "//src/ray/common:status", + "//src/ray/rpc/authentication:authentication_token", "//src/ray/stats:stats_metric", "@com_github_grpc_grpc//:grpc++", ], ) -ray_cc_library( - name = "auth_token_loader", - srcs = ["auth_token_loader.cc"], - hdrs = ["auth_token_loader.h"], - visibility = ["//visibility:public"], - deps = [ - "//src/ray/common:ray_config", - "//src/ray/util:logging", - ], -) - ray_cc_library( name = "grpc_server", srcs = ["grpc_server.cc"], hdrs = ["grpc_server.h"], visibility = ["//visibility:public"], deps = [ - ":auth_token_loader", + "//src/ray/rpc/authentication:authentication_token_loader", ":common", ":server_call", "//src/ray/common:asio", diff --git a/src/ray/rpc/authentication/BUILD.bazel b/src/ray/rpc/authentication/BUILD.bazel new file mode 100644 index 000000000000..4af7dff26e7d --- /dev/null +++ b/src/ray/rpc/authentication/BUILD.bazel @@ -0,0 +1,34 @@ +load("//bazel:ray.bzl", "ray_cc_library", "ray_cc_test") + +ray_cc_library( + name = "authentication_mode", + srcs = ["authentication_mode.cc"], + hdrs = ["authentication_mode.h"], + visibility = ["//visibility:public"], + deps = [ + "@com_google_absl//absl/strings", + "//src/ray/common:ray_config", + ], +) + +ray_cc_library( + name = "authentication_token", + hdrs = ["authentication_token.h"], + visibility = ["//visibility:public"], + deps = [ + "//src/ray/common:constants", + "@com_github_grpc_grpc//:grpc++", + 
], +) + +ray_cc_library( + name = "authentication_token_loader", + srcs = ["authentication_token_loader.cc"], + hdrs = ["authentication_token_loader.h"], + visibility = ["//visibility:public"], + deps = [ + ":authentication_mode", + ":authentication_token", + "//src/ray/util:logging", + ], +) diff --git a/src/ray/rpc/authentication/authentication_mode.cc b/src/ray/rpc/authentication/authentication_mode.cc new file mode 100644 index 000000000000..b629fe9b618d --- /dev/null +++ b/src/ray/rpc/authentication/authentication_mode.cc @@ -0,0 +1,37 @@ +// Copyright 2025 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ray/rpc/authentication/authentication_mode.h" + +#include +#include + +#include "absl/strings/ascii.h" +#include "ray/common/ray_config.h" + +namespace ray { +namespace rpc { + +AuthenticationMode GetAuthenticationMode() { + std::string auth_mode_lower = absl::AsciiStrToLower(RayConfig::instance().auth_mode()); + + if (auth_mode_lower == "ray_token") { + return AuthenticationMode::RAY_TOKEN; + } else { + return AuthenticationMode::DISABLED; + } +} + +} // namespace rpc +} // namespace ray diff --git a/src/ray/rpc/authentication/authentication_mode.h b/src/ray/rpc/authentication/authentication_mode.h new file mode 100644 index 000000000000..4acf33e29276 --- /dev/null +++ b/src/ray/rpc/authentication/authentication_mode.h @@ -0,0 +1,33 @@ +// Copyright 2025 The Ray Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +namespace ray { +namespace rpc { + +enum class AuthenticationMode { + DISABLED, + RAY_TOKEN, +}; + +/// Get the authentication mode from the RayConfig. +/// \return The authentication mode enum value. returns AuthenticationMode::DISABLED if +/// the authentication mode is not set or is invalid. +AuthenticationMode GetAuthenticationMode(); + +} // namespace rpc +} // namespace ray diff --git a/src/ray/rpc/authentication/authentication_token.h b/src/ray/rpc/authentication/authentication_token.h new file mode 100644 index 000000000000..9ab441b331fa --- /dev/null +++ b/src/ray/rpc/authentication/authentication_token.h @@ -0,0 +1,146 @@ +// Copyright 2025 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +#include "ray/common/constants.h" + +namespace ray { +namespace rpc { + +/// Secure wrapper for authentication tokens. +/// - Wipes memory on destruction +/// - Constant-time comparison +/// - Redacted output when logged or printed +class AuthenticationToken { + public: + AuthenticationToken() = default; + explicit AuthenticationToken(std::string value) : secret_(value.begin(), value.end()) {} + + // Copy operations - allowed for caching, but use sparingly + AuthenticationToken(const AuthenticationToken &other) : secret_(other.secret_) {} + AuthenticationToken &operator=(const AuthenticationToken &other) { + if (this != &other) { + SecureClear(); + secret_ = other.secret_; + } + return *this; + } + + // Move operations + AuthenticationToken(AuthenticationToken &&other) noexcept { + MoveFrom(std::move(other)); + } + AuthenticationToken &operator=(AuthenticationToken &&other) noexcept { + if (this != &other) { + SecureClear(); + MoveFrom(std::move(other)); + } + return *this; + } + ~AuthenticationToken() { SecureClear(); } + + bool empty() const noexcept { return secret_.empty(); } + + /// Constant-time equality comparison + bool Equals(const AuthenticationToken &other) const noexcept { + return ConstTimeEqual(secret_, other.secret_); + } + + /// Set authentication metadata on a gRPC client context + /// Only call this from client-side code + void SetMetadata(grpc::ClientContext &context) const { + if (!secret_.empty()) { + context.AddMetadata(kAuthTokenKey, + kBearerPrefix + std::string(secret_.begin(), secret_.end())); + } + } + + /// Create AuthenticationToken from gRPC metadata value + /// Strips "Bearer " prefix and creates token object + /// @param metadata_value The raw value from server metadata (should include "Bearer " + /// prefix) + /// @return AuthenticationToken object (empty if format invalid) + static AuthenticationToken FromMetadata(std::string_view 
metadata_value) { + const std::string_view prefix(kBearerPrefix, sizeof(kBearerPrefix) - 1); + if (metadata_value.size() <= prefix.size() || + metadata_value.substr(0, prefix.size()) != prefix) { + return AuthenticationToken(); // Invalid format, return empty + } + std::string_view token_part = metadata_value.substr(prefix.size()); + return AuthenticationToken(std::string(token_part)); + } + + friend std::ostream &operator<<(std::ostream &os, const AuthenticationToken &t) { + return os << ""; + } + + private: + std::vector secret_; + + // Constant-time string comparison to avoid timing attacks. + // https://en.wikipedia.org/wiki/Timing_attack + static bool ConstTimeEqual(const std::vector &a, + const std::vector &b) noexcept { + if (a.size() != b.size()) { + return false; + } + unsigned char diff = 0; + for (size_t i = 0; i < a.size(); ++i) { + diff |= a[i] ^ b[i]; + } + return diff == 0; + } + + static void ExplicitBurn(void *p, size_t n) noexcept { +#if defined(_MSC_VER) + SecureZeroMemory(p, n); +#elif defined(__STDC_LIB_EXT1__) + memset_s(p, n, 0, n); +#else + // Using array indexing instead of pointer arithmetic + volatile auto *vp = static_cast(p); + for (size_t i = 0; i < n; ++i) { + vp[i] = 0; + } +#endif + } + + void SecureClear() noexcept { + if (!secret_.empty()) { + ExplicitBurn(secret_.data(), secret_.size()); + secret_.clear(); + secret_.shrink_to_fit(); + } + } + + void MoveFrom(AuthenticationToken &&other) noexcept { + secret_ = std::move(other.secret_); + // Clear the moved-from object explicitly for security + // Note: 'other' is already an rvalue reference, no need to move again + other.SecureClear(); + } +}; + +} // namespace rpc +} // namespace ray diff --git a/src/ray/rpc/auth_token_loader.cc b/src/ray/rpc/authentication/authentication_token_loader.cc similarity index 71% rename from src/ray/rpc/auth_token_loader.cc rename to src/ray/rpc/authentication/authentication_token_loader.cc index 3019757252a5..bf270cfdaa8a 100644 --- 
a/src/ray/rpc/auth_token_loader.cc +++ b/src/ray/rpc/authentication/authentication_token_loader.cc @@ -12,13 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "ray/rpc/auth_token_loader.h" +#include "ray/rpc/authentication/authentication_token_loader.h" #include -#include -#include -#include "ray/common/ray_config.h" #include "ray/util/logging.h" #if defined(__APPLE__) || defined(__linux__) @@ -40,12 +37,12 @@ namespace ray { namespace rpc { -RayAuthTokenLoader &RayAuthTokenLoader::instance() { - static RayAuthTokenLoader instance; +AuthenticationTokenLoader &AuthenticationTokenLoader::instance() { + static AuthenticationTokenLoader instance; return instance; } -const std::string &RayAuthTokenLoader::GetToken() { +std::optional AuthenticationTokenLoader::GetToken() { std::lock_guard lock(token_mutex_); // If already loaded, return cached value @@ -53,51 +50,47 @@ const std::string &RayAuthTokenLoader::GetToken() { return *cached_token_; } - // If token auth is disabled, return empty string without loading - if (!RayConfig::instance().enable_token_auth()) { - cached_token_ = ""; - return *cached_token_; + // If token auth is disabled, return std::nullopt + if (GetAuthenticationMode() == AuthenticationMode::DISABLED) { + cached_token_ = std::nullopt; + return std::nullopt; } // Token auth is enabled, try to load from sources - std::string token = LoadTokenFromSources(); - - // If no token found and auth is enabled, throw error - if (token.empty()) { - RAY_LOG(ERROR) << "Token authentication is enabled but no authentication token was " - "found. 
Please set RAY_AUTH_TOKEN environment variable, " - "RAY_AUTH_TOKEN_PATH to a file containing the token, or create a " - "token file at ~/.ray/auth_token"; - throw std::runtime_error( - "Token authentication is enabled but no authentication token was found"); - } + AuthenticationToken token = LoadTokenFromSources(); + + // If no token found and auth is enabled, fail with RAY_CHECK + RAY_CHECK(!token.empty()) + << "Token authentication is enabled but no authentication token was found. " + << "Please set RAY_AUTH_TOKEN environment variable, RAY_AUTH_TOKEN_PATH to a file " + << "containing the token, or create a token file at ~/.ray/auth_token"; // Cache and return the loaded token - cached_token_ = token; + cached_token_ = std::move(token); return *cached_token_; } -bool RayAuthTokenLoader::HasToken() { +bool AuthenticationTokenLoader::HasToken() { std::lock_guard lock(token_mutex_); - // If already loaded, check if non-empty + // If already loaded, check if present if (cached_token_.has_value()) { - return !cached_token_->empty(); + return cached_token_->has_value(); } // If token auth is disabled, no token needed - if (!RayConfig::instance().enable_token_auth()) { + if (GetAuthenticationMode() == AuthenticationMode::DISABLED) { return false; } // Try to load token - std::string token = LoadTokenFromSources(); + AuthenticationToken token = LoadTokenFromSources(); return !token.empty(); } // Read token from the first line of the file. trim whitespace. // Returns empty string if file cannot be opened or is empty. 
-std::string RayAuthTokenLoader::ReadTokenFromFile(const std::string &file_path) { +std::string AuthenticationTokenLoader::ReadTokenFromFile(const std::string &file_path) { std::ifstream token_file(file_path); if (!token_file.is_open()) { return ""; @@ -115,22 +108,22 @@ std::string RayAuthTokenLoader::ReadTokenFromFile(const std::string &file_path) return token; } -std::string RayAuthTokenLoader::LoadTokenFromSources() { +AuthenticationToken AuthenticationTokenLoader::LoadTokenFromSources() { // Precedence 1: RAY_AUTH_TOKEN environment variable const char *env_token = std::getenv("RAY_AUTH_TOKEN"); if (env_token != nullptr && std::string(env_token).length() > 0) { RAY_LOG(DEBUG) << "Loaded authentication token from RAY_AUTH_TOKEN environment " "variable"; - return std::string(env_token); + return AuthenticationToken(std::string(env_token)); } // Precedence 2: RAY_AUTH_TOKEN_PATH environment variable const char *env_token_path = std::getenv("RAY_AUTH_TOKEN_PATH"); if (env_token_path != nullptr && std::string(env_token_path).length() > 0) { - std::string token = ReadTokenFromFile(env_token_path); - if (!token.empty()) { + std::string token_str = ReadTokenFromFile(env_token_path); + if (!token_str.empty()) { RAY_LOG(DEBUG) << "Loaded authentication token from file: " << env_token_path; - return token; + return AuthenticationToken(token_str); } else { RAY_LOG(WARNING) << "RAY_AUTH_TOKEN_PATH is set but file cannot be opened: " << env_token_path; @@ -139,18 +132,18 @@ std::string RayAuthTokenLoader::LoadTokenFromSources() { // Precedence 3: Default token path ~/.ray/auth_token std::string default_path = GetDefaultTokenPath(); - std::string token = ReadTokenFromFile(default_path); - if (!token.empty()) { + std::string token_str = ReadTokenFromFile(default_path); + if (!token_str.empty()) { RAY_LOG(DEBUG) << "Loaded authentication token from default path: " << default_path; - return token; + return AuthenticationToken(token_str); } // No token found RAY_LOG(DEBUG) << "No 
authentication token found in any source"; - return ""; + return AuthenticationToken(); } -std::string RayAuthTokenLoader::GetDefaultTokenPath() { +std::string AuthenticationTokenLoader::GetDefaultTokenPath() { std::string home_dir; #ifdef _WIN32 diff --git a/src/ray/rpc/auth_token_loader.h b/src/ray/rpc/authentication/authentication_token_loader.h similarity index 53% rename from src/ray/rpc/auth_token_loader.h rename to src/ray/rpc/authentication/authentication_token_loader.h index af798a081e41..c392fdc11b2a 100644 --- a/src/ray/rpc/auth_token_loader.h +++ b/src/ray/rpc/authentication/authentication_token_loader.h @@ -6,11 +6,11 @@ // // http://www.apache.org/licenses/LICENSE-2.0 // -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +// See the License for the specific language governing permissions +// and limitations under the License. #pragma once @@ -18,6 +18,9 @@ #include #include +#include "ray/rpc/authentication/authentication_mode.h" +#include "ray/rpc/authentication/authentication_token.h" + namespace ray { namespace rpc { @@ -26,51 +29,43 @@ namespace rpc { /// 1. RAY_AUTH_TOKEN environment variable /// 2. RAY_AUTH_TOKEN_PATH environment variable (path to token file) /// 3. Default token path: ~/.ray/auth_token (Unix) or %USERPROFILE%\.ray\auth_token -/// (Windows) /// /// Thread-safe with internal caching to avoid repeated file I/O. -class RayAuthTokenLoader { +class AuthenticationTokenLoader { public: - /// Get the singleton instance. 
- static RayAuthTokenLoader &instance(); + static AuthenticationTokenLoader &instance(); /// Get the authentication token. - /// If token authentication is enabled but no token is found, throws an error. - /// \return The authentication token, or empty string if auth is disabled. - const std::string &GetToken(); + /// If token authentication is enabled but no token is found, fails with RAY_CHECK. + /// \return The authentication token, or std::nullopt if auth is disabled. + std::optional GetToken(); /// Check if an authentication token exists. - /// \return True if a token is available (cached or can be loaded). bool HasToken(); - /// Reset the cached token. This ensures that the next call to GetToken() returns the - /// token from the sources. void ResetCache() { std::lock_guard lock(token_mutex_); cached_token_.reset(); } - // Prevent copying and moving - RayAuthTokenLoader(const RayAuthTokenLoader &) = delete; - RayAuthTokenLoader &operator=(const RayAuthTokenLoader &) = delete; + AuthenticationTokenLoader(const AuthenticationTokenLoader &) = delete; + AuthenticationTokenLoader &operator=(const AuthenticationTokenLoader &) = delete; private: - RayAuthTokenLoader() = default; - ~RayAuthTokenLoader() = default; + AuthenticationTokenLoader() = default; + ~AuthenticationTokenLoader() = default; - /// Read and trim token from a file. Returns empty string if file cannot be opened or is - /// empty. + /// Read and trim token from file. std::string ReadTokenFromFile(const std::string &file_path); - /// Load token from available sources (env vars and file). - std::string LoadTokenFromSources(); + /// Load token from environment or file. + AuthenticationToken LoadTokenFromSources(); - /// Get the default token file path (~/.ray/auth_token on Unix, - /// %USERPROFILE%\.ray\auth_token on Windows). + /// Default token file path (~/.ray/auth_token or %USERPROFILE%\.ray\auth_token). 
std::string GetDefaultTokenPath(); std::mutex token_mutex_; - std::optional cached_token_; + std::optional> cached_token_; }; } // namespace rpc diff --git a/src/ray/rpc/client_call.h b/src/ray/rpc/client_call.h index bd754dc93c6f..d344e70f6e13 100644 --- a/src/ray/rpc/client_call.h +++ b/src/ray/rpc/client_call.h @@ -31,7 +31,8 @@ #include "ray/common/grpc_util.h" #include "ray/common/id.h" #include "ray/common/status.h" -#include "ray/rpc/auth_token_loader.h" +#include "ray/rpc/authentication/authentication_token.h" +#include "ray/rpc/authentication/authentication_token_loader.h" #include "ray/rpc/rpc_callback_types.h" #include "ray/stats/metric_defs.h" #include "ray/util/thread_utils.h" @@ -73,7 +74,7 @@ class ClientCallImpl : public ClientCall { /// \param[in] callback The callback function to handle the reply. explicit ClientCallImpl(const ClientCallback &callback, const ClusterID &cluster_id, - const std::string &auth_token, + const std::optional &auth_token, std::shared_ptr stats_handle, bool record_stats, int64_t timeout_ms = -1) @@ -88,9 +89,9 @@ class ClientCallImpl : public ClientCall { if (!cluster_id.IsNil()) { context_.AddMetadata(kClusterIdKey, cluster_id.Hex()); } - // Add authentication token if provided (empty = disabled) - if (!auth_token.empty()) { - context_.AddMetadata(kAuthTokenKey, auth_token); + // Add authentication token if provided + if (auth_token.has_value()) { + auth_token->SetMetadata(context_); } } @@ -220,7 +221,6 @@ class ClientCallManager { int num_threads = 1, int64_t call_timeout_ms = -1) : cluster_id_(cluster_id), - auth_token_(RayAuthTokenLoader::instance().GetToken()), main_service_(main_service), num_threads_(num_threads), record_stats_(record_stats), @@ -279,12 +279,13 @@ class ClientCallManager { method_timeout_ms = call_timeout_ms_; } - auto call = std::make_shared>(callback, - cluster_id_, - auth_token_, - std::move(stats_handle), - record_stats_, - method_timeout_ms); + auto call = std::make_shared>( + callback, + 
cluster_id_, + AuthenticationTokenLoader::instance().GetToken(), + std::move(stats_handle), + record_stats_, + method_timeout_ms); // Send request. // Find the next completion queue to wait for response. call->response_reader_ = (stub.*prepare_async_function)( @@ -368,9 +369,6 @@ class ClientCallManager { /// and setting the cluster ID. ClusterID cluster_id_; - /// Cached authentication token for token-based authentication. - const std::string auth_token_; - /// The main event loop, to which the callback functions will be posted. instrumented_io_context &main_service_; diff --git a/src/ray/rpc/grpc_server.cc b/src/ray/rpc/grpc_server.cc index 359c97517d8d..e471bf7e39de 100644 --- a/src/ray/rpc/grpc_server.cc +++ b/src/ray/rpc/grpc_server.cc @@ -26,7 +26,7 @@ #include "ray/common/ray_config.h" #include "ray/common/status.h" -#include "ray/rpc/auth_token_loader.h" +#include "ray/rpc/authentication/authentication_token_loader.h" #include "ray/rpc/common.h" #include "ray/util/network_util.h" #include "ray/util/thread_utils.h" @@ -183,13 +183,9 @@ void GrpcServer::RegisterService(std::unique_ptr &&service, if (cluster_id_auth_enabled && cluster_id_.IsNil()) { RAY_LOG(FATAL) << "Expected cluster ID for token auth!"; } - // Use override token if set, otherwise load from RayAuthTokenLoader - std::string auth_token = auth_token_override_.empty() - ? 
RayAuthTokenLoader::instance().GetToken() - : auth_token_override_; for (int i = 0; i < num_threads_; i++) { service->InitServerCallFactories( - cqs_[i], &server_call_factories_, cluster_id_, auth_token); + cqs_[i], &server_call_factories_, cluster_id_, auth_token_); } services_.push_back(std::move(service)); } diff --git a/src/ray/rpc/grpc_server.h b/src/ray/rpc/grpc_server.h index f61d89929d6a..bf7eb7f8c5d1 100644 --- a/src/ray/rpc/grpc_server.h +++ b/src/ray/rpc/grpc_server.h @@ -24,6 +24,8 @@ #include #include "ray/common/asio/instrumented_io_context.h" +#include "ray/rpc/authentication/authentication_token.h" +#include "ray/rpc/authentication/authentication_token_loader.h" #include "ray/rpc/server_call.h" namespace ray { @@ -93,13 +95,20 @@ class GrpcServer { const uint32_t port, bool listen_to_localhost_only, int num_threads = 1, - int64_t keepalive_time_ms = 7200000 /*2 hours, grpc default*/) + int64_t keepalive_time_ms = 7200000, /*2 hours, grpc default*/ + std::optional auth_token = std::nullopt) : name_(std::move(name)), port_(port), listen_to_localhost_only_(listen_to_localhost_only), is_shutdown_(true), num_threads_(num_threads), keepalive_time_ms_(keepalive_time_ms) { + // Initialize auth token: use provided value or load from AuthenticationTokenLoader + if (auth_token.has_value()) { + auth_token_ = std::move(auth_token.value()); + } else { + auth_token_ = AuthenticationTokenLoader::instance().GetToken(); + } Init(); } @@ -142,13 +151,6 @@ class GrpcServer { cluster_id_ = cluster_id; } - /// Set an override token for testing. This takes precedence over RayAuthTokenLoader. - /// This is primarily for testing purposes. - /// \param override_token The override token to use for this server. - void SetAuthTokenOverride(const std::string &override_token) { - auth_token_override_ = override_token; - } - protected: /// Initialize this server. 
void Init(); @@ -167,8 +169,8 @@ class GrpcServer { const bool listen_to_localhost_only_; /// Token representing ID of this cluster. ClusterID cluster_id_; - /// Override token for testing. If set, this takes precedence over RayAuthTokenLoader. - std::string auth_token_override_; + /// Authentication token for token-based authentication. + std::optional auth_token_; /// Indicates whether this server is in shutdown state. std::atomic is_shutdown_; /// The `grpc::Service` objects which should be registered to `ServerBuilder`. @@ -226,7 +228,7 @@ class GrpcService { const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::string &auth_token) = 0; + const std::optional &auth_token) = 0; /// The main event loop, to which the service handler functions will be posted. instrumented_io_context &main_service_; diff --git a/src/ray/rpc/server_call.h b/src/ray/rpc/server_call.h index 79e2c58179d0..c391203f6793 100644 --- a/src/ray/rpc/server_call.h +++ b/src/ray/rpc/server_call.h @@ -20,6 +20,7 @@ #include #include #include +#include #include #include "ray/common/asio/asio_chaos.h" @@ -28,6 +29,7 @@ #include "ray/common/grpc_util.h" #include "ray/common/id.h" #include "ray/common/status.h" +#include "ray/rpc/authentication/authentication_token.h" #include "ray/rpc/rpc_callback_types.h" #include "ray/stats/metric.h" #include "ray/stats/metric_defs.h" @@ -172,7 +174,7 @@ class ServerCallImpl : public ServerCall { instrumented_io_context &io_service, std::string call_name, const ClusterID &cluster_id, - const std::string &auth_token, + const std::optional &auth_token, bool record_metrics, std::function preprocess_function = nullptr) : state_(ServerCallState::PENDING), @@ -203,15 +205,9 @@ class ServerCallImpl : public ServerCall { bool cluster_id_auth_failed = false; // Token authentication - // Empty token = no authentication required - if (!auth_token_.empty()) { - auto &metadata = context_.client_metadata(); - auto it = 
metadata.find(kAuthTokenKey); - if (it == metadata.end() || it->second != auth_token_) { - RAY_LOG(WARNING) << "Invalid or missing auth token in request!"; - auth_success = false; - token_auth_failed = true; - } + if (!ValidateBearerToken()) { + auth_success = false; + token_auth_failed = true; } // Cluster ID authentication @@ -259,10 +255,10 @@ class ServerCallImpl : public ServerCall { if (auth_success) { SendReply(Status::Invalid("HandleServiceClosed")); } else if (token_auth_failed) { - SendReply(Status::AuthError( + SendReply(Status::Unauthenticated( "InvalidAuthToken: Authentication token is missing or incorrect")); } else { - SendReply(Status::AuthError("WrongClusterID")); + SendReply(Status::Unauthenticated("WrongClusterID")); } } } @@ -287,12 +283,12 @@ class ServerCallImpl : public ServerCall { if (!auth_success) { boost::asio::post(GetServerCallExecutor(), [this, token_auth_failed]() { if (token_auth_failed) { - SendReply(Status::AuthError( + SendReply(Status::Unauthenticated( "InvalidAuthToken: Authentication token is missing or incorrect")); } else { - SendReply( - Status::AuthError("WrongClusterID: Perhaps the client is accessing GCS " - "after it has restarted.")); + SendReply(Status::Unauthenticated( + "WrongClusterID: Perhaps the client is accessing GCS " + "after it has restarted.")); } }); } else { @@ -342,6 +338,32 @@ class ServerCallImpl : public ServerCall { const ServerCallFactory &GetServerCallFactory() override { return factory_; } private: + /// Validates token-based authentication. + /// Returns true if authentication succeeds or is not required. + /// Returns false if authentication is required but fails. 
+ bool ValidateBearerToken() { + if (!auth_token_.has_value() || auth_token_->empty()) { + return true; // No auth required + } + + const auto &metadata = context_.client_metadata(); + auto it = metadata.find(kAuthTokenKey); + if (it == metadata.end()) { + RAY_LOG(WARNING) << "Missing authorization header in request!"; + return false; + } + + const std::string_view header(it->second.data(), it->second.length()); + AuthenticationToken provided_token = AuthenticationToken::FromMetadata(header); + + if (!auth_token_->Equals(provided_token)) { + RAY_LOG(WARNING) << "Invalid bearer token in request!"; + return false; + } + + return true; + } + /// Log the duration this query used void LogProcessTime() { EventTracker::RecordEnd(std::move(stats_handle_)); @@ -410,8 +432,7 @@ class ServerCallImpl : public ServerCall { const ClusterID &cluster_id_; /// Authentication token for token-based authentication. - /// Empty string means no token authentication. - const std::string auth_token_; + std::optional auth_token_; /// The callback when sending reply successes. std::function send_reply_success_callback_ = nullptr; @@ -482,7 +503,7 @@ class ServerCallFactoryImpl : public ServerCallFactory { instrumented_io_context &io_service, std::string call_name, const ClusterID &cluster_id, - const std::string &auth_token, + const std::optional &auth_token, int64_t max_active_rpcs, bool record_metrics) : service_(service), @@ -548,7 +569,7 @@ class ServerCallFactoryImpl : public ServerCallFactory { const ClusterID cluster_id_; /// Authentication token for token-based authentication. - const std::string auth_token_; + std::optional auth_token_; /// Maximum request number to handle at the same time. /// -1 means no limit. 
diff --git a/src/ray/rpc/tests/BUILD.bazel b/src/ray/rpc/tests/BUILD.bazel index 01541804a684..6e303b53d47b 100644 --- a/src/ray/rpc/tests/BUILD.bazel +++ b/src/ray/rpc/tests/BUILD.bazel @@ -42,14 +42,28 @@ ray_cc_test( ) ray_cc_test( - name = "auth_token_loader_test", + name = "authentication_token_loader_test", size = "small", srcs = [ - "auth_token_loader_test.cc", + "authentication_token_loader_test.cc", ], tags = ["team:core"], deps = [ - "//src/ray/rpc:auth_token_loader", + "//src/ray/rpc/authentication:authentication_token_loader", + "//src/ray/common:ray_config", + "@com_google_googletest//:gtest_main", + ], +) + +ray_cc_test( + name = "authentication_token_test", + size = "small", + srcs = [ + "authentication_token_test.cc", + ], + tags = ["team:core"], + deps = [ + "//src/ray/rpc/authentication:authentication_token", "@com_google_googletest//:gtest_main", ], ) diff --git a/src/ray/rpc/tests/auth_token_loader_test.cc b/src/ray/rpc/tests/authentication_token_loader_test.cc similarity index 72% rename from src/ray/rpc/tests/auth_token_loader_test.cc rename to src/ray/rpc/tests/authentication_token_loader_test.cc index d253e345cf3a..024c50bcb32f 100644 --- a/src/ray/rpc/tests/auth_token_loader_test.cc +++ b/src/ray/rpc/tests/authentication_token_loader_test.cc @@ -12,16 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "ray/rpc/auth_token_loader.h" +#include "ray/rpc/authentication/authentication_token_loader.h" #include #include -#include -#include #include "gtest/gtest.h" #include "ray/common/ray_config.h" -#include "ray/util/logging.h" #if defined(__APPLE__) || defined(__linux__) #include @@ -44,11 +41,11 @@ namespace ray { namespace rpc { -class RayAuthTokenLoaderTest : public ::testing::Test { +class AuthenticationTokenLoaderTest : public ::testing::Test { protected: void SetUp() override { // Enable token authentication for tests - RayConfig::instance().initialize(R"({"enable_token_auth": true})"); + RayConfig::instance().initialize(R"({"auth_mode": "ray_token"})"); // If HOME is not set (e.g., in Bazel sandbox), set it to a test directory // This ensures tests work in environments where HOME isn't provided @@ -84,16 +81,16 @@ class RayAuthTokenLoaderTest : public ::testing::Test { #endif cleanup_env(); // Reset the singleton's cached state for test isolation - RayAuthTokenLoader::instance().ResetCache(); + AuthenticationTokenLoader::instance().ResetCache(); } void TearDown() override { // Clean up after test cleanup_env(); // Reset the singleton's cached state for test isolation - RayAuthTokenLoader::instance().ResetCache(); + AuthenticationTokenLoader::instance().ResetCache(); // Disable token auth after tests - RayConfig::instance().initialize(R"({"enable_token_auth": false})"); + RayConfig::instance().initialize(R"({"auth_mode": "disabled"})"); } void cleanup_env() { @@ -159,19 +156,21 @@ class RayAuthTokenLoaderTest : public ::testing::Test { std::string test_home_dir_; // Fallback home directory for tests }; -TEST_F(RayAuthTokenLoaderTest, TestLoadFromEnvVariable) { +TEST_F(AuthenticationTokenLoaderTest, TestLoadFromEnvVariable) { // Set token in environment variable set_env_var("RAY_AUTH_TOKEN", "test-token-from-env"); // Create a new instance to avoid cached state - auto &loader = RayAuthTokenLoader::instance(); - std::string token = loader.GetToken(); + 
auto &loader = AuthenticationTokenLoader::instance(); + auto token_opt = loader.GetToken(); - EXPECT_EQ(token, "test-token-from-env"); + ASSERT_TRUE(token_opt.has_value()); + AuthenticationToken expected("test-token-from-env"); + EXPECT_TRUE(token_opt->Equals(expected)); EXPECT_TRUE(loader.HasToken()); } -TEST_F(RayAuthTokenLoaderTest, TestLoadFromEnvPath) { +TEST_F(AuthenticationTokenLoaderTest, TestLoadFromEnvPath) { // Create a temporary token file std::string temp_token_path = get_temp_token_path(); write_token_file(temp_token_path, "test-token-from-file"); @@ -179,25 +178,29 @@ TEST_F(RayAuthTokenLoaderTest, TestLoadFromEnvPath) { // Set path in environment variable set_env_var("RAY_AUTH_TOKEN_PATH", temp_token_path.c_str()); - auto &loader = RayAuthTokenLoader::instance(); - std::string token = loader.GetToken(); + auto &loader = AuthenticationTokenLoader::instance(); + auto token_opt = loader.GetToken(); - EXPECT_EQ(token, "test-token-from-file"); + ASSERT_TRUE(token_opt.has_value()); + AuthenticationToken expected("test-token-from-file"); + EXPECT_TRUE(token_opt->Equals(expected)); EXPECT_TRUE(loader.HasToken()); // Clean up remove(temp_token_path.c_str()); } -TEST_F(RayAuthTokenLoaderTest, TestLoadFromDefaultPath) { +TEST_F(AuthenticationTokenLoaderTest, TestLoadFromDefaultPath) { // Create directory and token file in default location ensure_ray_dir_exists(); write_token_file(default_token_path_, "test-token-from-default"); - auto &loader = RayAuthTokenLoader::instance(); - std::string token = loader.GetToken(); + auto &loader = AuthenticationTokenLoader::instance(); + auto token_opt = loader.GetToken(); - EXPECT_EQ(token, "test-token-from-default"); + ASSERT_TRUE(token_opt.has_value()); + AuthenticationToken expected("test-token-from-default"); + EXPECT_TRUE(token_opt->Equals(expected)); EXPECT_TRUE(loader.HasToken()); } @@ -214,12 +217,12 @@ struct TokenSourceConfig { std::string default_token = "token-from-default"; }; -class 
RayAuthTokenLoaderPrecedenceTest - : public RayAuthTokenLoaderTest, +class AuthenticationTokenLoaderPrecedenceTest + : public AuthenticationTokenLoaderTest, public ::testing::WithParamInterface {}; INSTANTIATE_TEST_SUITE_P(TokenPrecedenceCases, - RayAuthTokenLoaderPrecedenceTest, + AuthenticationTokenLoaderPrecedenceTest, ::testing::Values( // All set: env should win TokenSourceConfig{true, true, true, "token-from-env"}, @@ -229,7 +232,7 @@ INSTANTIATE_TEST_SUITE_P(TokenPrecedenceCases, TokenSourceConfig{ false, false, true, "token-from-default"})); -TEST_P(RayAuthTokenLoaderPrecedenceTest, Precedence) { +TEST_P(AuthenticationTokenLoaderPrecedenceTest, Precedence) { const auto ¶m = GetParam(); // Optionally set environment variable @@ -257,10 +260,12 @@ TEST_P(RayAuthTokenLoaderPrecedenceTest, Precedence) { } // Always create a new instance to avoid cached state - auto &loader = RayAuthTokenLoader::instance(); - std::string token = loader.GetToken(); + auto &loader = AuthenticationTokenLoader::instance(); + auto token_opt = loader.GetToken(); - EXPECT_EQ(token, param.expected_token); + ASSERT_TRUE(token_opt.has_value()); + AuthenticationToken expected(param.expected_token); + EXPECT_TRUE(token_opt->Equals(expected)); // Clean up token file if it was written if (param.set_file) { @@ -272,55 +277,64 @@ TEST_P(RayAuthTokenLoaderPrecedenceTest, Precedence) { } } -TEST_F(RayAuthTokenLoaderTest, TestNoTokenFoundWhenAuthDisabled) { +TEST_F(AuthenticationTokenLoaderTest, TestNoTokenFoundWhenAuthDisabled) { // Disable auth for this specific test - RayConfig::instance().initialize(R"({"enable_token_auth": false})"); - RayAuthTokenLoader::instance().ResetCache(); + RayConfig::instance().initialize(R"({"auth_mode": "disabled"})"); + AuthenticationTokenLoader::instance().ResetCache(); // No token set anywhere, but auth is disabled - auto &loader = RayAuthTokenLoader::instance(); - std::string token = loader.GetToken(); + auto &loader = AuthenticationTokenLoader::instance(); + 
auto token_opt = loader.GetToken(); - EXPECT_EQ(token, ""); + EXPECT_FALSE(token_opt.has_value()); EXPECT_FALSE(loader.HasToken()); // Re-enable for other tests - RayConfig::instance().initialize(R"({"enable_token_auth": true})"); + RayConfig::instance().initialize(R"({"auth_mode": "ray_token"})"); } -TEST_F(RayAuthTokenLoaderTest, TestErrorWhenAuthEnabledButNoToken) { +TEST_F(AuthenticationTokenLoaderTest, TestErrorWhenAuthEnabledButNoToken) { // Token auth is already enabled in SetUp() - // No token exists, should throw an error - auto &loader = RayAuthTokenLoader::instance(); - EXPECT_THROW(loader.GetToken(), std::runtime_error); + // No token exists, should trigger RAY_CHECK failure + EXPECT_DEATH( + { + auto &loader = AuthenticationTokenLoader::instance(); + loader.GetToken(); + }, + "Token authentication is enabled but no authentication token was found"); } -TEST_F(RayAuthTokenLoaderTest, TestCaching) { +TEST_F(AuthenticationTokenLoaderTest, TestCaching) { // Set token in environment set_env_var("RAY_AUTH_TOKEN", "cached-token"); - auto &loader = RayAuthTokenLoader::instance(); - std::string token1 = loader.GetToken(); + auto &loader = AuthenticationTokenLoader::instance(); + auto token_opt1 = loader.GetToken(); // Change environment variable (shouldn't affect cached value) set_env_var("RAY_AUTH_TOKEN", "new-token"); - std::string token2 = loader.GetToken(); + auto token_opt2 = loader.GetToken(); // Should still return the cached token - EXPECT_EQ(token1, token2); - EXPECT_EQ(token2, "cached-token"); + ASSERT_TRUE(token_opt1.has_value()); + ASSERT_TRUE(token_opt2.has_value()); + EXPECT_TRUE(token_opt1->Equals(*token_opt2)); + AuthenticationToken expected("cached-token"); + EXPECT_TRUE(token_opt2->Equals(expected)); } -TEST_F(RayAuthTokenLoaderTest, TestWhitespaceHandling) { +TEST_F(AuthenticationTokenLoaderTest, TestWhitespaceHandling) { // Create token file with whitespace ensure_ray_dir_exists(); write_token_file(default_token_path_, " token-with-spaces 
\n\t"); - auto &loader = RayAuthTokenLoader::instance(); - std::string token = loader.GetToken(); + auto &loader = AuthenticationTokenLoader::instance(); + auto token_opt = loader.GetToken(); // Whitespace should be trimmed - EXPECT_EQ(token, "token-with-spaces"); + ASSERT_TRUE(token_opt.has_value()); + AuthenticationToken expected("token-with-spaces"); + EXPECT_TRUE(token_opt->Equals(expected)); } } // namespace rpc diff --git a/src/ray/rpc/tests/authentication_token_test.cc b/src/ray/rpc/tests/authentication_token_test.cc new file mode 100644 index 000000000000..10209e6dee90 --- /dev/null +++ b/src/ray/rpc/tests/authentication_token_test.cc @@ -0,0 +1,203 @@ +// Copyright 2025 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "ray/rpc/authentication/authentication_token.h" + +#include +#include + +#include "gtest/gtest.h" + +namespace ray { +namespace rpc { + +class AuthenticationTokenTest : public ::testing::Test {}; + +TEST_F(AuthenticationTokenTest, TestDefaultConstructor) { + AuthenticationToken token; + EXPECT_TRUE(token.empty()); +} + +TEST_F(AuthenticationTokenTest, TestConstructorWithValue) { + AuthenticationToken token("test-token-value"); + EXPECT_FALSE(token.empty()); + AuthenticationToken expected("test-token-value"); + EXPECT_TRUE(token.Equals(expected)); +} + +TEST_F(AuthenticationTokenTest, TestMoveConstructor) { + AuthenticationToken token1("original-token"); + AuthenticationToken token2(std::move(token1)); + + EXPECT_FALSE(token2.empty()); + AuthenticationToken expected("original-token"); + EXPECT_TRUE(token2.Equals(expected)); + EXPECT_TRUE(token1.empty()); +} + +TEST_F(AuthenticationTokenTest, TestMoveAssignment) { + AuthenticationToken token1("first-token"); + AuthenticationToken token2("second-token"); + + token2 = std::move(token1); + + EXPECT_FALSE(token2.empty()); + AuthenticationToken expected("first-token"); + EXPECT_TRUE(token2.Equals(expected)); + EXPECT_TRUE(token1.empty()); +} + +TEST_F(AuthenticationTokenTest, TestSelfMoveAssignment) { + AuthenticationToken token("test-token"); + + // Self-assignment should not break the token + token = std::move(token); + + EXPECT_FALSE(token.empty()); + AuthenticationToken expected("test-token"); + EXPECT_TRUE(token.Equals(expected)); +} + +TEST_F(AuthenticationTokenTest, TestEquals) { + AuthenticationToken token1("same-token"); + AuthenticationToken token2("same-token"); + AuthenticationToken token3("different-token"); + + EXPECT_TRUE(token1.Equals(token2)); + EXPECT_FALSE(token1.Equals(token3)); +} + +TEST_F(AuthenticationTokenTest, TestEqualityDifferentLengths) { + AuthenticationToken token1("short"); + AuthenticationToken token2("much-longer-token"); + + EXPECT_FALSE(token1.Equals(token2)); +} + 
+TEST_F(AuthenticationTokenTest, TestEqualityEmptyTokens) { + AuthenticationToken token1; + AuthenticationToken token2; + + EXPECT_TRUE(token1.Equals(token2)); +} + +TEST_F(AuthenticationTokenTest, TestEqualityEmptyVsNonEmpty) { + AuthenticationToken token1; + AuthenticationToken token2("non-empty"); + + EXPECT_FALSE(token1.Equals(token2)); + EXPECT_FALSE(token2.Equals(token1)); +} + +TEST_F(AuthenticationTokenTest, TestRedactedOutput) { + AuthenticationToken token("super-secret-token"); + + std::ostringstream oss; + oss << token; + + std::string output = oss.str(); + EXPECT_EQ(output, ""); + EXPECT_EQ(output.find("super-secret-token"), std::string::npos); +} + +TEST_F(AuthenticationTokenTest, TestEmptyString) { + AuthenticationToken token(""); + EXPECT_TRUE(token.empty()); + AuthenticationToken expected(""); + EXPECT_TRUE(token.Equals(expected)); +} + +TEST_F(AuthenticationTokenTest, TestSpecialCharacters) { + std::string special = "token-with-special!@#$%^&*()_+={}[]|\\:;\"'<>,.?/~`"; + AuthenticationToken token(special); + + EXPECT_FALSE(token.empty()); + AuthenticationToken expected(special); + EXPECT_TRUE(token.Equals(expected)); +} + +TEST_F(AuthenticationTokenTest, TestUnicodeCharacters) { + std::string unicode = "token-with-unicode-café-😀"; + AuthenticationToken token(unicode); + + EXPECT_FALSE(token.empty()); + AuthenticationToken expected(unicode); + EXPECT_TRUE(token.Equals(expected)); +} + +TEST_F(AuthenticationTokenTest, TestBinaryData) { + std::string binary; + for (int i = 0; i < 256; ++i) { + binary += static_cast(i); + } + + AuthenticationToken token(binary); + + EXPECT_FALSE(token.empty()); + AuthenticationToken expected(binary); + EXPECT_TRUE(token.Equals(expected)); +} + +TEST_F(AuthenticationTokenTest, TestLongToken) { + std::string long_token(10000, 'x'); + AuthenticationToken token(long_token); + + EXPECT_FALSE(token.empty()); + AuthenticationToken expected(long_token); + EXPECT_TRUE(token.Equals(expected)); +} + 
+TEST_F(AuthenticationTokenTest, TestConstTimeComparison) { + // This test verifies that comparison works correctly + // Actual timing attack resistance would require specialized timing tests + AuthenticationToken token1("token-abc"); + AuthenticationToken token2("token-xyz"); + AuthenticationToken token3("token-abc"); + + EXPECT_FALSE(token1.Equals(token2)); + EXPECT_TRUE(token1.Equals(token3)); +} + +TEST_F(AuthenticationTokenTest, TestMoveClearsOriginal) { + AuthenticationToken token1("test-token"); + AuthenticationToken expected("test-token"); + + AuthenticationToken token2(std::move(token1)); + + // Original should be empty after move + EXPECT_TRUE(token1.empty()); + // New token should have the value + EXPECT_TRUE(token2.Equals(expected)); +} + +TEST_F(AuthenticationTokenTest, TestMoveAssignmentClearsOriginal) { + AuthenticationToken token1("test-token"); + AuthenticationToken token2("other-token"); + AuthenticationToken expected("test-token"); + + token2 = std::move(token1); + + // Original should be empty after move + EXPECT_TRUE(token1.empty()); + // New token should have the value + EXPECT_TRUE(token2.Equals(expected)); +} + +} // namespace rpc +} // namespace ray + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/src/ray/rpc/tests/grpc_server_client_test.cc b/src/ray/rpc/tests/grpc_server_client_test.cc index 5df0b85d478b..729ece195e46 100644 --- a/src/ray/rpc/tests/grpc_server_client_test.cc +++ b/src/ray/rpc/tests/grpc_server_client_test.cc @@ -13,12 +13,13 @@ // limitations under the License. 
#include +#include #include #include #include #include "gtest/gtest.h" -#include "ray/rpc/auth_token_loader.h" +#include "ray/rpc/authentication/authentication_token_loader.h" #include "ray/rpc/grpc_client.h" #include "ray/rpc/grpc_server.h" #include "src/ray/protobuf/test_service.grpc.pb.h" @@ -90,7 +91,7 @@ class TestGrpcService : public GrpcService { const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::string &auth_token) override { + const std::optional &auth_token) override { RPC_SERVICE_HANDLER_CUSTOM_AUTH( TestService, Ping, /*max_active_rpcs=*/1, ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH( @@ -333,20 +334,20 @@ class TestGrpcServerClientTokenAuthFixture : public ::testing::Test { public: void SetUp() override { // Configure token auth via RayConfig - std::string config_json = R"({"enable_token_auth": true})"; + std::string config_json = R"({"auth_mode": "ray_token"})"; RayConfig::instance().initialize(config_json); - RayAuthTokenLoader::instance().ResetCache(); + AuthenticationTokenLoader::instance().ResetCache(); } void SetUpServerAndClient(const std::string &server_token, const std::string &client_token) { // Set client token in environment for ClientCallManager to read from - // RayAuthTokenLoader + // AuthenticationTokenLoader if (!client_token.empty()) { setenv("RAY_AUTH_TOKEN", client_token.c_str(), 1); } else { - RayConfig::instance().initialize(R"({"enable_token_auth": false})"); - RayAuthTokenLoader::instance().ResetCache(); + RayConfig::instance().initialize(R"({"auth_mode": "disabled"})"); + AuthenticationTokenLoader::instance().ResetCache(); unsetenv("RAY_AUTH_TOKEN"); } @@ -365,13 +366,12 @@ class TestGrpcServerClientTokenAuthFixture : public ::testing::Test { }); // Create and start server - // In production, server would automatically use RayAuthTokenLoader, but for testing - // we can override with SetAuthTokenOverride to test mismatched token scenarios - 
grpc_server_.reset(new GrpcServer("test", 0, true)); + // Pass server token explicitly for testing scenarios with different tokens + std::optional server_auth_token; if (!server_token.empty()) { - // Set server token explicitly for testing scenarios with different tokens - grpc_server_->SetAuthTokenOverride(server_token); + server_auth_token = AuthenticationToken(server_token); } + grpc_server_.reset(new GrpcServer("test", 0, true, 1, 7200000, server_auth_token)); grpc_server_->RegisterService( std::make_unique(handler_io_service_, test_service_handler_), false); @@ -381,7 +381,7 @@ class TestGrpcServerClientTokenAuthFixture : public ::testing::Test { std::this_thread::sleep_for(std::chrono::milliseconds(10)); } - // Create client (will read auth token from RayAuthTokenLoader which reads the + // Create client (will read auth token from AuthenticationTokenLoader which reads the // environment) client_call_manager_.reset( new ClientCallManager(client_io_service_, false, /*local_address=*/"")); @@ -417,7 +417,7 @@ class TestGrpcServerClientTokenAuthFixture : public ::testing::Test { unsetenv("RAY_AUTH_TOKEN"); unsetenv("RAY_AUTH_TOKEN_PATH"); // Reset the token loader for test isolation - RayAuthTokenLoader::instance().ResetCache(); + AuthenticationTokenLoader::instance().ResetCache(); } // Helper to execute RPC and wait for result @@ -429,27 +429,22 @@ class TestGrpcServerClientTokenAuthFixture : public ::testing::Test { PingResult ExecutePingAndWait() { PingRequest request; - std::atomic done(false); - bool success = false; - std::string error_msg; - - Ping(request, - [&done, &success, &error_msg](const Status &status, const PingReply &reply) { - RAY_LOG(INFO) << "Token auth test replied, status=" << status; - success = status.ok(); - if (!status.ok()) { - error_msg = status.message(); - } - done = true; - }); + auto result_promise = std::make_shared>(); + std::future result_future = result_promise->get_future(); + + Ping(request, [result_promise](const Status 
&status, const PingReply &reply) { + RAY_LOG(INFO) << "Token auth test replied, status=" << status; + bool success = status.ok(); + std::string error_msg = status.ok() ? "" : status.message(); + result_promise->set_value({true, success, error_msg}); + }); // Wait for response with timeout - auto start = std::chrono::steady_clock::now(); - while (!done && std::chrono::steady_clock::now() - start < std::chrono::seconds(5)) { - std::this_thread::sleep_for(std::chrono::milliseconds(100)); + if (result_future.wait_for(std::chrono::seconds(5)) == std::future_status::timeout) { + return {false, false, "Request timed out"}; } - return {done, success, error_msg}; + return result_future.get(); } protected: @@ -485,8 +480,10 @@ TEST_F(TestGrpcServerClientTokenAuthFixture, TestTokenAuthFailureWrongToken) { ASSERT_TRUE(result.completed) << "Request did not complete in time"; ASSERT_FALSE(result.success) << "Request should fail with wrong client token"; - ASSERT_EQ(result.error_msg, - "InvalidAuthToken: Authentication token is missing or incorrect"); + ASSERT_TRUE(result.error_msg.find( + "InvalidAuthToken: Authentication token is missing or incorrect") != + std::string::npos) + << "Error message should contain token auth error. 
Got: " << result.error_msg; } TEST_F(TestGrpcServerClientTokenAuthFixture, TestTokenAuthFailureMissingToken) { @@ -501,6 +498,19 @@ TEST_F(TestGrpcServerClientTokenAuthFixture, TestTokenAuthFailureMissingToken) { << "Request should fail when client doesn't provide required token"; } +TEST_F(TestGrpcServerClientTokenAuthFixture, + TestClientProvidesTokenServerDoesNotRequire) { + // Client provides token, but server doesn't require one (should succeed) + SetUpServerAndClient("", "client_token"); + + auto result = ExecutePingAndWait(); + + ASSERT_TRUE(result.completed) << "Request did not complete in time"; + // Server should accept request even though client sent unnecessary token + ASSERT_TRUE(result.success) + << "Request should succeed when server doesn't require token"; +} + } // namespace rpc } // namespace ray From d1fe7b97cb9aa15828631a1d85975ae4aa0d5be7 Mon Sep 17 00:00:00 2001 From: sampan Date: Wed, 22 Oct 2025 16:46:48 +0000 Subject: [PATCH 26/94] fix lint Signed-off-by: sampan --- src/ray/rpc/BUILD.bazel | 4 ++-- src/ray/rpc/authentication/BUILD.bazel | 4 ++-- src/ray/rpc/tests/BUILD.bazel | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/ray/rpc/BUILD.bazel b/src/ray/rpc/BUILD.bazel index 62a37fa12869..3051c5ef123e 100644 --- a/src/ray/rpc/BUILD.bazel +++ b/src/ray/rpc/BUILD.bazel @@ -15,12 +15,12 @@ ray_cc_library( "//src/ray/core_worker:__pkg__", ], deps = [ - "//src/ray/rpc/authentication:authentication_token_loader", ":rpc_callback_types", "//src/ray/common:asio", "//src/ray/common:grpc_util", "//src/ray/common:id", "//src/ray/common:status", + "//src/ray/rpc/authentication:authentication_token_loader", "@com_google_absl//absl/synchronization", ], ) @@ -117,12 +117,12 @@ ray_cc_library( hdrs = ["grpc_server.h"], visibility = ["//visibility:public"], deps = [ - "//src/ray/rpc/authentication:authentication_token_loader", ":common", ":server_call", "//src/ray/common:asio", "//src/ray/common:ray_config", "//src/ray/common:status", 
+ "//src/ray/rpc/authentication:authentication_token_loader", "//src/ray/util:network_util", "//src/ray/util:thread_utils", "@com_github_grpc_grpc//:grpc++", diff --git a/src/ray/rpc/authentication/BUILD.bazel b/src/ray/rpc/authentication/BUILD.bazel index 4af7dff26e7d..8da78e5d728b 100644 --- a/src/ray/rpc/authentication/BUILD.bazel +++ b/src/ray/rpc/authentication/BUILD.bazel @@ -1,4 +1,4 @@ -load("//bazel:ray.bzl", "ray_cc_library", "ray_cc_test") +load("//bazel:ray.bzl", "ray_cc_library") ray_cc_library( name = "authentication_mode", @@ -6,8 +6,8 @@ ray_cc_library( hdrs = ["authentication_mode.h"], visibility = ["//visibility:public"], deps = [ - "@com_google_absl//absl/strings", "//src/ray/common:ray_config", + "@com_google_absl//absl/strings", ], ) diff --git a/src/ray/rpc/tests/BUILD.bazel b/src/ray/rpc/tests/BUILD.bazel index 6e303b53d47b..d5113ae0d3aa 100644 --- a/src/ray/rpc/tests/BUILD.bazel +++ b/src/ray/rpc/tests/BUILD.bazel @@ -49,8 +49,8 @@ ray_cc_test( ], tags = ["team:core"], deps = [ - "//src/ray/rpc/authentication:authentication_token_loader", "//src/ray/common:ray_config", + "//src/ray/rpc/authentication:authentication_token_loader", "@com_google_googletest//:gtest_main", ], ) From 09359d6f2adb84f06b37ad34c7aefd2af1b00715 Mon Sep 17 00:00:00 2001 From: sampan Date: Wed, 22 Oct 2025 16:49:44 +0000 Subject: [PATCH 27/94] missing imports Signed-off-by: sampan --- src/ray/rpc/authentication/authentication_token_loader.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/ray/rpc/authentication/authentication_token_loader.cc b/src/ray/rpc/authentication/authentication_token_loader.cc index bf270cfdaa8a..133fac254e80 100644 --- a/src/ray/rpc/authentication/authentication_token_loader.cc +++ b/src/ray/rpc/authentication/authentication_token_loader.cc @@ -15,6 +15,8 @@ #include "ray/rpc/authentication/authentication_token_loader.h" #include +#include +#include #include "ray/util/logging.h" From 886c1097f1086f90e60e8bd6afb002f5eeb384c7 Mon Sep 17 
00:00:00 2001 From: sampan Date: Wed, 22 Oct 2025 16:51:58 +0000 Subject: [PATCH 28/94] fix lint Signed-off-by: sampan --- src/ray/rpc/authentication/authentication_token.h | 1 + src/ray/rpc/tests/authentication_token_test.cc | 1 + 2 files changed, 2 insertions(+) diff --git a/src/ray/rpc/authentication/authentication_token.h b/src/ray/rpc/authentication/authentication_token.h index 9ab441b331fa..129ea4dcf0f3 100644 --- a/src/ray/rpc/authentication/authentication_token.h +++ b/src/ray/rpc/authentication/authentication_token.h @@ -21,6 +21,7 @@ #include #include #include +#include #include #include "ray/common/constants.h" diff --git a/src/ray/rpc/tests/authentication_token_test.cc b/src/ray/rpc/tests/authentication_token_test.cc index 10209e6dee90..b8b7de28ac21 100644 --- a/src/ray/rpc/tests/authentication_token_test.cc +++ b/src/ray/rpc/tests/authentication_token_test.cc @@ -16,6 +16,7 @@ #include #include +#include #include "gtest/gtest.h" From 1a0b53be437a0012c315a6af0a997d4a414a4084 Mon Sep 17 00:00:00 2001 From: sampan Date: Thu, 23 Oct 2025 05:46:20 +0000 Subject: [PATCH 29/94] fix build + refactor Signed-off-by: sampan --- src/ray/common/ray_config_def.h | 3 + src/ray/core_worker/BUILD.bazel | 1 + src/ray/core_worker/grpc_service.cc | 2 +- src/ray/core_worker/grpc_service.h | 4 +- src/ray/gcs/BUILD.bazel | 2 +- src/ray/gcs/gcs_server.cc | 1 - src/ray/gcs/grpc_services.cc | 24 +++--- src/ray/gcs/grpc_services.h | 26 +++--- src/ray/rpc/BUILD.bazel | 2 + .../rpc/authentication/authentication_mode.cc | 4 +- .../rpc/authentication/authentication_mode.h | 2 +- .../rpc/authentication/authentication_token.h | 14 +++- .../authentication_token_loader.cc | 28 ++++--- .../authentication_token_loader.h | 3 + .../rpc/node_manager/node_manager_server.h | 4 +- src/ray/rpc/object_manager_server.h | 4 +- .../rpc/tests/authentication_token_test.cc | 81 +------------------ src/ray/rpc/tests/grpc_bench/BUILD.bazel | 1 + src/ray/rpc/tests/grpc_bench/grpc_bench.cc | 10 ++- 19 
files changed, 88 insertions(+), 128 deletions(-) diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h index 2d1243df8afd..e4e8fc1d48ef 100644 --- a/src/ray/common/ray_config_def.h +++ b/src/ray/common/ray_config_def.h @@ -36,6 +36,9 @@ RAY_CONFIG(bool, emit_main_service_metrics, true) RAY_CONFIG(bool, enable_cluster_auth, true) /// Whether to enable token-based authentication for RPC calls. +/// will be converted to AuthenticationMode enum defined in +/// rpc/authentication/authentication_mode.h +/// use GetAuthenticationMode() to get the authentication mode enum value. RAY_CONFIG(std::string, auth_mode, "disabled") /// The interval of periodic event loop stats print. diff --git a/src/ray/core_worker/BUILD.bazel b/src/ray/core_worker/BUILD.bazel index fa140f66e33c..27df09cf8dde 100644 --- a/src/ray/core_worker/BUILD.bazel +++ b/src/ray/core_worker/BUILD.bazel @@ -78,6 +78,7 @@ ray_cc_library( "//src/ray/protobuf:core_worker_cc_proto", "//src/ray/rpc:grpc_server", "//src/ray/rpc:rpc_callback_types", + "//src/ray/rpc/authentication:authentication_token", ], ) diff --git a/src/ray/core_worker/grpc_service.cc b/src/ray/core_worker/grpc_service.cc index f6d40e03cafb..585308c3ae67 100644 --- a/src/ray/core_worker/grpc_service.cc +++ b/src/ray/core_worker/grpc_service.cc @@ -25,7 +25,7 @@ void CoreWorkerGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::string &auth_token) { + const std::optional &auth_token) { /// TODO(vitsai): Remove this when auth is implemented for node manager. /// Disable gRPC server metrics since it incurs too high cardinality. 
RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, diff --git a/src/ray/core_worker/grpc_service.h b/src/ray/core_worker/grpc_service.h index 65cce1eaa538..e1f944ddb64f 100644 --- a/src/ray/core_worker/grpc_service.h +++ b/src/ray/core_worker/grpc_service.h @@ -29,10 +29,12 @@ #pragma once #include +#include #include #include #include "ray/common/asio/instrumented_io_context.h" +#include "ray/rpc/authentication/authentication_token.h" #include "ray/rpc/grpc_server.h" #include "ray/rpc/rpc_callback_types.h" #include "src/ray/protobuf/core_worker.grpc.pb.h" @@ -160,7 +162,7 @@ class CoreWorkerGrpcService : public GrpcService { const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::string &auth_token) override; + const std::optional &auth_token) override; private: CoreWorkerService::AsyncService service_; diff --git a/src/ray/gcs/BUILD.bazel b/src/ray/gcs/BUILD.bazel index 850150a28b93..b764ad410db6 100644 --- a/src/ray/gcs/BUILD.bazel +++ b/src/ray/gcs/BUILD.bazel @@ -353,6 +353,7 @@ ray_cc_library( "//src/ray/protobuf:gcs_service_cc_grpc", "//src/ray/rpc:grpc_server", "//src/ray/rpc:rpc_callback_types", + "//src/ray/rpc/authentication:authentication_token", "@com_github_grpc_grpc//:grpc++", ], ) @@ -526,7 +527,6 @@ ray_cc_library( "//src/ray/raylet/scheduling:scheduler", "//src/ray/raylet_rpc_client:raylet_client_lib", "//src/ray/raylet_rpc_client:raylet_client_pool", - "//src/ray/rpc:auth_token_loader", "//src/ray/rpc:grpc_server", "//src/ray/rpc:metrics_agent_client", "//src/ray/util:counter_map", diff --git a/src/ray/gcs/gcs_server.cc b/src/ray/gcs/gcs_server.cc index beb237250cd8..1d93198b5f36 100644 --- a/src/ray/gcs/gcs_server.cc +++ b/src/ray/gcs/gcs_server.cc @@ -39,7 +39,6 @@ #include "ray/observability/metric_constants.h" #include "ray/pubsub/publisher.h" #include "ray/raylet_rpc_client/raylet_client.h" -#include "ray/rpc/auth_token_loader.h" #include "ray/stats/stats.h" #include 
"ray/util/network_util.h" diff --git a/src/ray/gcs/grpc_services.cc b/src/ray/gcs/grpc_services.cc index 912c7a267eaf..66b4397782c2 100644 --- a/src/ray/gcs/grpc_services.cc +++ b/src/ray/gcs/grpc_services.cc @@ -23,7 +23,7 @@ void ActorInfoGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::string &auth_token) { + const std::optional &auth_token) { /// The register & create actor RPCs take a long time, so we shouldn't limit their /// concurrency to avoid distributed deadlock. RPC_SERVICE_HANDLER(ActorInfoGcsService, RegisterActor, -1) @@ -44,7 +44,7 @@ void NodeInfoGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::string &auth_token) { + const std::optional &auth_token) { // We only allow one cluster ID in the lifetime of a client. // So, if a client connects, it should not have a pre-existing different ID. 
RPC_SERVICE_HANDLER_CUSTOM_AUTH(NodeInfoGcsService, @@ -64,7 +64,7 @@ void NodeResourceInfoGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::string &auth_token) { + const std::optional &auth_token) { RPC_SERVICE_HANDLER( NodeResourceInfoGcsService, GetAllAvailableResources, max_active_rpcs_per_handler_) RPC_SERVICE_HANDLER( @@ -79,7 +79,7 @@ void InternalPubSubGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::string &auth_token) { + const std::optional &auth_token) { RPC_SERVICE_HANDLER(InternalPubSubGcsService, GcsPublish, max_active_rpcs_per_handler_); RPC_SERVICE_HANDLER( InternalPubSubGcsService, GcsSubscriberPoll, max_active_rpcs_per_handler_); @@ -91,7 +91,7 @@ void JobInfoGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::string &auth_token) { + const std::optional &auth_token) { RPC_SERVICE_HANDLER(JobInfoGcsService, AddJob, max_active_rpcs_per_handler_) RPC_SERVICE_HANDLER(JobInfoGcsService, MarkJobFinished, max_active_rpcs_per_handler_) RPC_SERVICE_HANDLER(JobInfoGcsService, GetAllJobInfo, max_active_rpcs_per_handler_) @@ -103,7 +103,7 @@ void RuntimeEnvGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::string &auth_token) { + const std::optional &auth_token) { RPC_SERVICE_HANDLER( RuntimeEnvGcsService, PinRuntimeEnvURI, max_active_rpcs_per_handler_) } @@ -112,7 +112,7 @@ void WorkerInfoGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::string &auth_token) { + const std::optional &auth_token) { RPC_SERVICE_HANDLER( WorkerInfoGcsService, ReportWorkerFailure, max_active_rpcs_per_handler_) 
RPC_SERVICE_HANDLER(WorkerInfoGcsService, GetWorkerInfo, max_active_rpcs_per_handler_) @@ -129,7 +129,7 @@ void InternalKVGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::string &auth_token) { + const std::optional &auth_token) { RPC_SERVICE_HANDLER(InternalKVGcsService, InternalKVGet, max_active_rpcs_per_handler_) RPC_SERVICE_HANDLER( InternalKVGcsService, InternalKVMultiGet, max_active_rpcs_per_handler_) @@ -146,7 +146,7 @@ void TaskInfoGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::string &auth_token) { + const std::optional &auth_token) { RPC_SERVICE_HANDLER(TaskInfoGcsService, AddTaskEventData, max_active_rpcs_per_handler_) RPC_SERVICE_HANDLER(TaskInfoGcsService, GetTaskEvents, max_active_rpcs_per_handler_) } @@ -155,7 +155,7 @@ void PlacementGroupInfoGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::string &auth_token) { + const std::optional &auth_token) { RPC_SERVICE_HANDLER( PlacementGroupInfoGcsService, CreatePlacementGroup, max_active_rpcs_per_handler_) RPC_SERVICE_HANDLER( @@ -177,7 +177,7 @@ void AutoscalerStateGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::string &auth_token) { + const std::optional &auth_token) { RPC_SERVICE_HANDLER( AutoscalerStateService, GetClusterResourceState, max_active_rpcs_per_handler_) RPC_SERVICE_HANDLER( @@ -200,7 +200,7 @@ void RayEventExportGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::string &auth_token) { + const std::optional &auth_token) { RPC_SERVICE_HANDLER(RayEventExportGcsService, AddEvents, max_active_rpcs_per_handler_) } diff --git 
a/src/ray/gcs/grpc_services.h b/src/ray/gcs/grpc_services.h index 2b9f4a50a36f..f7b34746114d 100644 --- a/src/ray/gcs/grpc_services.h +++ b/src/ray/gcs/grpc_services.h @@ -23,11 +23,13 @@ #pragma once #include +#include #include #include "ray/common/asio/instrumented_io_context.h" #include "ray/common/id.h" #include "ray/gcs/grpc_service_interfaces.h" +#include "ray/rpc/authentication/authentication_token.h" #include "ray/rpc/grpc_server.h" #include "ray/rpc/rpc_callback_types.h" #include "src/ray/protobuf/autoscaler.grpc.pb.h" @@ -52,7 +54,7 @@ class ActorInfoGrpcService : public GrpcService { const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::string &auth_token) override; + const std::optional &auth_token) override; private: ActorInfoGcsService::AsyncService service_; @@ -76,7 +78,7 @@ class NodeInfoGrpcService : public GrpcService { const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::string &auth_token) override; + const std::optional &auth_token) override; private: NodeInfoGcsService::AsyncService service_; @@ -100,7 +102,7 @@ class NodeResourceInfoGrpcService : public GrpcService { const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::string &auth_token) override; + const std::optional &auth_token) override; private: NodeResourceInfoGcsService::AsyncService service_; @@ -124,7 +126,7 @@ class InternalPubSubGrpcService : public GrpcService { const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::string &auth_token) override; + const std::optional &auth_token) override; private: InternalPubSubGcsService::AsyncService service_; @@ -148,7 +150,7 @@ class JobInfoGrpcService : public GrpcService { const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::string &auth_token) override; + const std::optional 
&auth_token) override; private: JobInfoGcsService::AsyncService service_; @@ -172,7 +174,7 @@ class RuntimeEnvGrpcService : public GrpcService { const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::string &auth_token) override; + const std::optional &auth_token) override; private: RuntimeEnvGcsService::AsyncService service_; @@ -196,7 +198,7 @@ class WorkerInfoGrpcService : public GrpcService { const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::string &auth_token) override; + const std::optional &auth_token) override; private: WorkerInfoGcsService::AsyncService service_; @@ -220,7 +222,7 @@ class InternalKVGrpcService : public GrpcService { const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::string &auth_token) override; + const std::optional &auth_token) override; private: InternalKVGcsService::AsyncService service_; @@ -244,7 +246,7 @@ class TaskInfoGrpcService : public GrpcService { const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::string &auth_token) override; + const std::optional &auth_token) override; private: TaskInfoGcsService::AsyncService service_; @@ -268,7 +270,7 @@ class PlacementGroupInfoGrpcService : public GrpcService { const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::string &auth_token) override; + const std::optional &auth_token) override; private: PlacementGroupInfoGcsService::AsyncService service_; @@ -294,7 +296,7 @@ class AutoscalerStateGrpcService : public GrpcService { const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::string &auth_token) override; + const std::optional &auth_token) override; private: AutoscalerStateService::AsyncService service_; @@ -322,7 +324,7 @@ class RayEventExportGrpcService : public 
GrpcService { const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::string &auth_token) override; + const std::optional &auth_token) override; private: RayEventExportGcsService::AsyncService service_; diff --git a/src/ray/rpc/BUILD.bazel b/src/ray/rpc/BUILD.bazel index 3051c5ef123e..d5e60c0ba484 100644 --- a/src/ray/rpc/BUILD.bazel +++ b/src/ray/rpc/BUILD.bazel @@ -140,6 +140,7 @@ ray_cc_library( deps = [ ":grpc_server", "//src/ray/protobuf:node_manager_cc_grpc", + "//src/ray/rpc/authentication:authentication_token", "@com_github_grpc_grpc//:grpc++", ], ) @@ -155,6 +156,7 @@ ray_cc_library( "//src/ray/object_manager:object_manager_grpc_client_manager", "//src/ray/protobuf:object_manager_cc_grpc", "//src/ray/rpc:grpc_server", + "//src/ray/rpc/authentication:authentication_token", "@boost//:asio", "@com_github_grpc_grpc//:grpc++", ], diff --git a/src/ray/rpc/authentication/authentication_mode.cc b/src/ray/rpc/authentication/authentication_mode.cc index b629fe9b618d..1bbe209733ce 100644 --- a/src/ray/rpc/authentication/authentication_mode.cc +++ b/src/ray/rpc/authentication/authentication_mode.cc @@ -26,8 +26,8 @@ namespace rpc { AuthenticationMode GetAuthenticationMode() { std::string auth_mode_lower = absl::AsciiStrToLower(RayConfig::instance().auth_mode()); - if (auth_mode_lower == "ray_token") { - return AuthenticationMode::RAY_TOKEN; + if (auth_mode_lower == "token") { + return AuthenticationMode::TOKEN; } else { return AuthenticationMode::DISABLED; } diff --git a/src/ray/rpc/authentication/authentication_mode.h b/src/ray/rpc/authentication/authentication_mode.h index 4acf33e29276..21bd165fd34b 100644 --- a/src/ray/rpc/authentication/authentication_mode.h +++ b/src/ray/rpc/authentication/authentication_mode.h @@ -21,7 +21,7 @@ namespace rpc { enum class AuthenticationMode { DISABLED, - RAY_TOKEN, + TOKEN, }; /// Get the authentication mode from the RayConfig. 
diff --git a/src/ray/rpc/authentication/authentication_token.h b/src/ray/rpc/authentication/authentication_token.h index 129ea4dcf0f3..6846d3c08ada 100644 --- a/src/ray/rpc/authentication/authentication_token.h +++ b/src/ray/rpc/authentication/authentication_token.h @@ -38,7 +38,6 @@ class AuthenticationToken { AuthenticationToken() = default; explicit AuthenticationToken(std::string value) : secret_(value.begin(), value.end()) {} - // Copy operations - allowed for caching, but use sparingly AuthenticationToken(const AuthenticationToken &other) : secret_(other.secret_) {} AuthenticationToken &operator=(const AuthenticationToken &other) { if (this != &other) { @@ -68,6 +67,16 @@ class AuthenticationToken { return ConstTimeEqual(secret_, other.secret_); } + /// Equality operator (constant-time) + bool operator==(const AuthenticationToken &other) const noexcept { + return Equals(other); + } + + /// Inequality operator + bool operator!=(const AuthenticationToken &other) const noexcept { + return !(*this == other); + } + /// Set authentication metadata on a gRPC client context /// Only call this from client-side code void SetMetadata(grpc::ClientContext &context) const { @@ -113,6 +122,7 @@ class AuthenticationToken { return diff == 0; } + // replace the characters in the memory with 0 static void ExplicitBurn(void *p, size_t n) noexcept { #if defined(_MSC_VER) SecureZeroMemory(p, n); @@ -131,11 +141,11 @@ class AuthenticationToken { if (!secret_.empty()) { ExplicitBurn(secret_.data(), secret_.size()); secret_.clear(); - secret_.shrink_to_fit(); } } void MoveFrom(AuthenticationToken &&other) noexcept { + SecureClear(); secret_ = std::move(other.secret_); // Clear the moved-from object explicitly for security // Note: 'other' is already an rvalue reference, no need to move again diff --git a/src/ray/rpc/authentication/authentication_token_loader.cc b/src/ray/rpc/authentication/authentication_token_loader.cc index 133fac254e80..cd19fed8987b 100644 --- 
a/src/ray/rpc/authentication/authentication_token_loader.cc +++ b/src/ray/rpc/authentication/authentication_token_loader.cc @@ -52,8 +52,8 @@ std::optional AuthenticationTokenLoader::GetToken() { return *cached_token_; } - // If token auth is disabled, return std::nullopt - if (GetAuthenticationMode() == AuthenticationMode::DISABLED) { + // If token auth is not enabled, return std::nullopt + if (GetAuthenticationMode() != AuthenticationMode::TOKEN) { cached_token_ = std::nullopt; return std::nullopt; } @@ -80,8 +80,8 @@ bool AuthenticationTokenLoader::HasToken() { return cached_token_->has_value(); } - // If token auth is disabled, no token needed - if (GetAuthenticationMode() == AuthenticationMode::DISABLED) { + // If token auth is not enabled, no token needed + if (GetAuthenticationMode() != AuthenticationMode::TOKEN) { return false; } @@ -101,12 +101,6 @@ std::string AuthenticationTokenLoader::ReadTokenFromFile(const std::string &file std::string token; std::getline(token_file, token); token_file.close(); - - // Trim whitespace - std::string whitespace = " \t\n\r\f\v"; - token.erase(0, token.find_first_not_of(whitespace)); - token.erase(token.find_last_not_of(whitespace) + 1); - return token; } @@ -116,13 +110,13 @@ AuthenticationToken AuthenticationTokenLoader::LoadTokenFromSources() { if (env_token != nullptr && std::string(env_token).length() > 0) { RAY_LOG(DEBUG) << "Loaded authentication token from RAY_AUTH_TOKEN environment " "variable"; - return AuthenticationToken(std::string(env_token)); + return AuthenticationToken(TrimWhitespace(std::string(env_token))); } // Precedence 2: RAY_AUTH_TOKEN_PATH environment variable const char *env_token_path = std::getenv("RAY_AUTH_TOKEN_PATH"); if (env_token_path != nullptr && std::string(env_token_path).length() > 0) { - std::string token_str = ReadTokenFromFile(env_token_path); + std::string token_str = TrimWhitespace(ReadTokenFromFile(env_token_path)); if (!token_str.empty()) { RAY_LOG(DEBUG) << "Loaded 
authentication token from file: " << env_token_path; return AuthenticationToken(token_str); @@ -134,7 +128,7 @@ AuthenticationToken AuthenticationTokenLoader::LoadTokenFromSources() { // Precedence 3: Default token path ~/.ray/auth_token std::string default_path = GetDefaultTokenPath(); - std::string token_str = ReadTokenFromFile(default_path); + std::string token_str = TrimWhitespace(ReadTokenFromFile(default_path)); if (!token_str.empty()) { RAY_LOG(DEBUG) << "Loaded authentication token from default path: " << default_path; return AuthenticationToken(token_str); @@ -179,5 +173,13 @@ std::string AuthenticationTokenLoader::GetDefaultTokenPath() { return home_dir + token_subpath; } +std::string AuthenticationTokenLoader::TrimWhitespace(const std::string &str) { + std::string whitespace = " \t\n\r\f\v"; + std::string trimmed_str = str; + trimmed_str.erase(0, trimmed_str.find_first_not_of(whitespace)); + trimmed_str.erase(trimmed_str.find_last_not_of(whitespace) + 1); + return trimmed_str; +} + } // namespace rpc } // namespace ray diff --git a/src/ray/rpc/authentication/authentication_token_loader.h b/src/ray/rpc/authentication/authentication_token_loader.h index c392fdc11b2a..712c45f142d8 100644 --- a/src/ray/rpc/authentication/authentication_token_loader.h +++ b/src/ray/rpc/authentication/authentication_token_loader.h @@ -64,6 +64,9 @@ class AuthenticationTokenLoader { /// Default token file path (~/.ray/auth_token or %USERPROFILE%\.ray\auth_token). std::string GetDefaultTokenPath(); + /// Trim whitespace from the beginning and end of the string. 
+ std::string TrimWhitespace(const std::string &str); + std::mutex token_mutex_; std::optional> cached_token_; }; diff --git a/src/ray/rpc/node_manager/node_manager_server.h b/src/ray/rpc/node_manager/node_manager_server.h index 942032d9e089..2bad1356aea9 100644 --- a/src/ray/rpc/node_manager/node_manager_server.h +++ b/src/ray/rpc/node_manager/node_manager_server.h @@ -15,10 +15,12 @@ #pragma once #include +#include #include #include #include "ray/common/asio/instrumented_io_context.h" +#include "ray/rpc/authentication/authentication_token.h" #include "ray/rpc/grpc_server.h" #include "src/ray/protobuf/node_manager.grpc.pb.h" #include "src/ray/protobuf/node_manager.pb.h" @@ -204,7 +206,7 @@ class NodeManagerGrpcService : public GrpcService { const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::string &auth_token) override { + const std::optional &auth_token) override { RAY_NODE_MANAGER_RPC_HANDLERS } diff --git a/src/ray/rpc/object_manager_server.h b/src/ray/rpc/object_manager_server.h index dd21f5382991..576de9396142 100644 --- a/src/ray/rpc/object_manager_server.h +++ b/src/ray/rpc/object_manager_server.h @@ -15,10 +15,12 @@ #pragma once #include +#include #include #include #include "ray/common/asio/instrumented_io_context.h" +#include "ray/rpc/authentication/authentication_token.h" #include "ray/rpc/grpc_server.h" #include "src/ray/protobuf/object_manager.grpc.pb.h" #include "src/ray/protobuf/object_manager.pb.h" @@ -79,7 +81,7 @@ class ObjectManagerGrpcService : public GrpcService { const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::string &auth_token) override { + const std::optional &auth_token) override { RAY_OBJECT_MANAGER_RPC_HANDLERS } diff --git a/src/ray/rpc/tests/authentication_token_test.cc b/src/ray/rpc/tests/authentication_token_test.cc index b8b7de28ac21..db88d7481da1 100644 --- a/src/ray/rpc/tests/authentication_token_test.cc +++ 
b/src/ray/rpc/tests/authentication_token_test.cc @@ -77,6 +77,10 @@ TEST_F(AuthenticationTokenTest, TestEquals) { EXPECT_TRUE(token1.Equals(token2)); EXPECT_FALSE(token1.Equals(token3)); + EXPECT_TRUE(token1 == token2); + EXPECT_FALSE(token1 == token3); + EXPECT_FALSE(token1 != token2); + EXPECT_TRUE(token1 != token3); } TEST_F(AuthenticationTokenTest, TestEqualityDifferentLengths) { @@ -118,83 +122,6 @@ TEST_F(AuthenticationTokenTest, TestEmptyString) { AuthenticationToken expected(""); EXPECT_TRUE(token.Equals(expected)); } - -TEST_F(AuthenticationTokenTest, TestSpecialCharacters) { - std::string special = "token-with-special!@#$%^&*()_+={}[]|\\:;\"'<>,.?/~`"; - AuthenticationToken token(special); - - EXPECT_FALSE(token.empty()); - AuthenticationToken expected(special); - EXPECT_TRUE(token.Equals(expected)); -} - -TEST_F(AuthenticationTokenTest, TestUnicodeCharacters) { - std::string unicode = "token-with-unicode-café-😀"; - AuthenticationToken token(unicode); - - EXPECT_FALSE(token.empty()); - AuthenticationToken expected(unicode); - EXPECT_TRUE(token.Equals(expected)); -} - -TEST_F(AuthenticationTokenTest, TestBinaryData) { - std::string binary; - for (int i = 0; i < 256; ++i) { - binary += static_cast(i); - } - - AuthenticationToken token(binary); - - EXPECT_FALSE(token.empty()); - AuthenticationToken expected(binary); - EXPECT_TRUE(token.Equals(expected)); -} - -TEST_F(AuthenticationTokenTest, TestLongToken) { - std::string long_token(10000, 'x'); - AuthenticationToken token(long_token); - - EXPECT_FALSE(token.empty()); - AuthenticationToken expected(long_token); - EXPECT_TRUE(token.Equals(expected)); -} - -TEST_F(AuthenticationTokenTest, TestConstTimeComparison) { - // This test verifies that comparison works correctly - // Actual timing attack resistance would require specialized timing tests - AuthenticationToken token1("token-abc"); - AuthenticationToken token2("token-xyz"); - AuthenticationToken token3("token-abc"); - - 
EXPECT_FALSE(token1.Equals(token2)); - EXPECT_TRUE(token1.Equals(token3)); -} - -TEST_F(AuthenticationTokenTest, TestMoveClearsOriginal) { - AuthenticationToken token1("test-token"); - AuthenticationToken expected("test-token"); - - AuthenticationToken token2(std::move(token1)); - - // Original should be empty after move - EXPECT_TRUE(token1.empty()); - // New token should have the value - EXPECT_TRUE(token2.Equals(expected)); -} - -TEST_F(AuthenticationTokenTest, TestMoveAssignmentClearsOriginal) { - AuthenticationToken token1("test-token"); - AuthenticationToken token2("other-token"); - AuthenticationToken expected("test-token"); - - token2 = std::move(token1); - - // Original should be empty after move - EXPECT_TRUE(token1.empty()); - // New token should have the value - EXPECT_TRUE(token2.Equals(expected)); -} - } // namespace rpc } // namespace ray diff --git a/src/ray/rpc/tests/grpc_bench/BUILD.bazel b/src/ray/rpc/tests/grpc_bench/BUILD.bazel index 5238a11c0baf..4594e3873c5f 100644 --- a/src/ray/rpc/tests/grpc_bench/BUILD.bazel +++ b/src/ray/rpc/tests/grpc_bench/BUILD.bazel @@ -28,5 +28,6 @@ cc_binary( ":helloworld_cc_lib", "//src/ray/common:asio", "//src/ray/rpc:grpc_server", + "//src/ray/rpc/authentication:authentication_token", ], ) diff --git a/src/ray/rpc/tests/grpc_bench/grpc_bench.cc b/src/ray/rpc/tests/grpc_bench/grpc_bench.cc index 6bdf24c26106..81dd9477f948 100644 --- a/src/ray/rpc/tests/grpc_bench/grpc_bench.cc +++ b/src/ray/rpc/tests/grpc_bench/grpc_bench.cc @@ -13,10 +13,12 @@ // limitations under the License. 
#include +#include #include #include #include "ray/common/asio/instrumented_io_context.h" +#include "ray/rpc/authentication/authentication_token.h" #include "ray/rpc/grpc_server.h" #include "src/ray/rpc/test/grpc_bench/helloworld.grpc.pb.h" #include "src/ray/rpc/test/grpc_bench/helloworld.pb.h" @@ -57,9 +59,11 @@ class GreeterGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) override{ - RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED( - Greeter, SayHello, -1, ClusterIdAuthType::NO_AUTH)} + const ClusterID &cluster_id, + const std::optional &auth_token) override { + RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED( + Greeter, SayHello, -1, ClusterIdAuthType::NO_AUTH); + } /// The grpc async service object. Greeter::AsyncService service_; From d5d711bd466d0ddfed0abe7463c5e609e476dab3 Mon Sep 17 00:00:00 2001 From: sampan Date: Thu, 23 Oct 2025 07:16:30 +0000 Subject: [PATCH 30/94] address cursor comments Signed-off-by: sampan --- src/ray/rpc/authentication/authentication_token_loader.cc | 2 +- src/ray/rpc/authentication/authentication_token_loader.h | 2 +- src/ray/rpc/tests/authentication_token_loader_test.cc | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/ray/rpc/authentication/authentication_token_loader.cc b/src/ray/rpc/authentication/authentication_token_loader.cc index cd19fed8987b..f131bbf9b46b 100644 --- a/src/ray/rpc/authentication/authentication_token_loader.cc +++ b/src/ray/rpc/authentication/authentication_token_loader.cc @@ -49,7 +49,7 @@ std::optional AuthenticationTokenLoader::GetToken() { // If already loaded, return cached value if (cached_token_.has_value()) { - return *cached_token_; + return cached_token_; } // If token auth is not enabled, return std::nullopt diff --git a/src/ray/rpc/authentication/authentication_token_loader.h b/src/ray/rpc/authentication/authentication_token_loader.h index 
712c45f142d8..f6ff6dc61883 100644 --- a/src/ray/rpc/authentication/authentication_token_loader.h +++ b/src/ray/rpc/authentication/authentication_token_loader.h @@ -68,7 +68,7 @@ class AuthenticationTokenLoader { std::string TrimWhitespace(const std::string &str); std::mutex token_mutex_; - std::optional> cached_token_; + std::optional cached_token_; }; } // namespace rpc diff --git a/src/ray/rpc/tests/authentication_token_loader_test.cc b/src/ray/rpc/tests/authentication_token_loader_test.cc index 024c50bcb32f..23cdb328e30f 100644 --- a/src/ray/rpc/tests/authentication_token_loader_test.cc +++ b/src/ray/rpc/tests/authentication_token_loader_test.cc @@ -45,7 +45,7 @@ class AuthenticationTokenLoaderTest : public ::testing::Test { protected: void SetUp() override { // Enable token authentication for tests - RayConfig::instance().initialize(R"({"auth_mode": "ray_token"})"); + RayConfig::instance().initialize(R"({"auth_mode": "token"})"); // If HOME is not set (e.g., in Bazel sandbox), set it to a test directory // This ensures tests work in environments where HOME isn't provided @@ -290,7 +290,7 @@ TEST_F(AuthenticationTokenLoaderTest, TestNoTokenFoundWhenAuthDisabled) { EXPECT_FALSE(loader.HasToken()); // Re-enable for other tests - RayConfig::instance().initialize(R"({"auth_mode": "ray_token"})"); + RayConfig::instance().initialize(R"({"auth_mode": "token"})"); } TEST_F(AuthenticationTokenLoaderTest, TestErrorWhenAuthEnabledButNoToken) { From 78c9cf49f9bff6b8d568e08970de15d5ce473a62 Mon Sep 17 00:00:00 2001 From: sampan Date: Thu, 23 Oct 2025 09:07:46 +0000 Subject: [PATCH 31/94] split grpc client server tests Signed-off-by: sampan --- .../authentication_token_loader.cc | 18 -- .../authentication_token_loader.h | 3 - src/ray/rpc/tests/BUILD.bazel | 17 ++ .../tests/authentication_token_loader_test.cc | 8 +- src/ray/rpc/tests/grpc_auth_token_tests.cc | 221 +++++++++++++++ src/ray/rpc/tests/grpc_server_client_test.cc | 264 +----------------- 
src/ray/rpc/tests/grpc_test_common.h | 107 +++++++ 7 files changed, 351 insertions(+), 287 deletions(-) create mode 100644 src/ray/rpc/tests/grpc_auth_token_tests.cc create mode 100644 src/ray/rpc/tests/grpc_test_common.h diff --git a/src/ray/rpc/authentication/authentication_token_loader.cc b/src/ray/rpc/authentication/authentication_token_loader.cc index f131bbf9b46b..59a1184e080a 100644 --- a/src/ray/rpc/authentication/authentication_token_loader.cc +++ b/src/ray/rpc/authentication/authentication_token_loader.cc @@ -72,24 +72,6 @@ std::optional AuthenticationTokenLoader::GetToken() { return *cached_token_; } -bool AuthenticationTokenLoader::HasToken() { - std::lock_guard lock(token_mutex_); - - // If already loaded, check if present - if (cached_token_.has_value()) { - return cached_token_->has_value(); - } - - // If token auth is not enabled, no token needed - if (GetAuthenticationMode() != AuthenticationMode::TOKEN) { - return false; - } - - // Try to load token - AuthenticationToken token = LoadTokenFromSources(); - return !token.empty(); -} - // Read token from the first line of the file. trim whitespace. // Returns empty string if file cannot be opened or is empty. std::string AuthenticationTokenLoader::ReadTokenFromFile(const std::string &file_path) { diff --git a/src/ray/rpc/authentication/authentication_token_loader.h b/src/ray/rpc/authentication/authentication_token_loader.h index f6ff6dc61883..4034ecbc78dd 100644 --- a/src/ray/rpc/authentication/authentication_token_loader.h +++ b/src/ray/rpc/authentication/authentication_token_loader.h @@ -40,9 +40,6 @@ class AuthenticationTokenLoader { /// \return The authentication token, or std::nullopt if auth is disabled. std::optional GetToken(); - /// Check if an authentication token exists. 
- bool HasToken(); - void ResetCache() { std::lock_guard lock(token_mutex_); cached_token_.reset(); diff --git a/src/ray/rpc/tests/BUILD.bazel b/src/ray/rpc/tests/BUILD.bazel index d5113ae0d3aa..279b68f91ba3 100644 --- a/src/ray/rpc/tests/BUILD.bazel +++ b/src/ray/rpc/tests/BUILD.bazel @@ -18,6 +18,23 @@ ray_cc_test( size = "small", srcs = [ "grpc_server_client_test.cc", + "grpc_test_common.h", + ], + tags = ["team:core"], + deps = [ + "//src/ray/protobuf:test_service_cc_grpc", + "//src/ray/rpc:grpc_client", + "//src/ray/rpc:grpc_server", + "@com_google_googletest//:gtest_main", + ], +) + +ray_cc_test( + name = "grpc_auth_token_tests", + size = "small", + srcs = [ + "grpc_auth_token_tests.cc", + "grpc_test_common.h", ], tags = ["team:core"], deps = [ diff --git a/src/ray/rpc/tests/authentication_token_loader_test.cc b/src/ray/rpc/tests/authentication_token_loader_test.cc index 23cdb328e30f..616a13b0e457 100644 --- a/src/ray/rpc/tests/authentication_token_loader_test.cc +++ b/src/ray/rpc/tests/authentication_token_loader_test.cc @@ -167,7 +167,7 @@ TEST_F(AuthenticationTokenLoaderTest, TestLoadFromEnvVariable) { ASSERT_TRUE(token_opt.has_value()); AuthenticationToken expected("test-token-from-env"); EXPECT_TRUE(token_opt->Equals(expected)); - EXPECT_TRUE(loader.HasToken()); + EXPECT_TRUE(loader.GetToken().has_value()); } TEST_F(AuthenticationTokenLoaderTest, TestLoadFromEnvPath) { @@ -184,7 +184,7 @@ TEST_F(AuthenticationTokenLoaderTest, TestLoadFromEnvPath) { ASSERT_TRUE(token_opt.has_value()); AuthenticationToken expected("test-token-from-file"); EXPECT_TRUE(token_opt->Equals(expected)); - EXPECT_TRUE(loader.HasToken()); + EXPECT_TRUE(loader.GetToken().has_value()); // Clean up remove(temp_token_path.c_str()); @@ -201,7 +201,7 @@ TEST_F(AuthenticationTokenLoaderTest, TestLoadFromDefaultPath) { ASSERT_TRUE(token_opt.has_value()); AuthenticationToken expected("test-token-from-default"); EXPECT_TRUE(token_opt->Equals(expected)); - EXPECT_TRUE(loader.HasToken()); + 
EXPECT_TRUE(loader.GetToken().has_value()); } // Parametrized test for token loading precedence: env var > user-specified file > default @@ -287,7 +287,7 @@ TEST_F(AuthenticationTokenLoaderTest, TestNoTokenFoundWhenAuthDisabled) { auto token_opt = loader.GetToken(); EXPECT_FALSE(token_opt.has_value()); - EXPECT_FALSE(loader.HasToken()); + EXPECT_FALSE(loader.GetToken().has_value()); // Re-enable for other tests RayConfig::instance().initialize(R"({"auth_mode": "token"})"); diff --git a/src/ray/rpc/tests/grpc_auth_token_tests.cc b/src/ray/rpc/tests/grpc_auth_token_tests.cc new file mode 100644 index 000000000000..5feaf5563add --- /dev/null +++ b/src/ray/rpc/tests/grpc_auth_token_tests.cc @@ -0,0 +1,221 @@ +// Copyright 2021 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include +#include + +#include "gtest/gtest.h" +#include "ray/rpc/authentication/authentication_token_loader.h" +#include "ray/rpc/grpc_client.h" +#include "ray/rpc/grpc_server.h" +#include "src/ray/protobuf/test_service.grpc.pb.h" +#include "src/ray/rpc/tests/grpc_test_common.h" + +namespace ray { +namespace rpc { + +class TestGrpcServerClientTokenAuthFixture : public ::testing::Test { + public: + void SetUp() override { + // Configure token auth via RayConfig + std::string config_json = R"({"auth_mode": "token"})"; + RayConfig::instance().initialize(config_json); + AuthenticationTokenLoader::instance().ResetCache(); + } + + void SetUpServerAndClient(const std::string &server_token, + const std::string &client_token) { + // Set client token in environment for ClientCallManager to read from + // AuthenticationTokenLoader + if (!client_token.empty()) { + setenv("RAY_AUTH_TOKEN", client_token.c_str(), 1); + } else { + RayConfig::instance().initialize(R"({"auth_mode": "disabled"})"); + AuthenticationTokenLoader::instance().ResetCache(); + unsetenv("RAY_AUTH_TOKEN"); + } + + // Start client thread FIRST + client_thread_ = std::make_unique([this]() { + boost::asio::executor_work_guard + client_io_service_work_(client_io_service_.get_executor()); + client_io_service_.run(); + }); + + // Start handler thread for server + handler_thread_ = std::make_unique([this]() { + boost::asio::executor_work_guard + handler_io_service_work_(handler_io_service_.get_executor()); + handler_io_service_.run(); + }); + + // Create and start server + // Pass server token explicitly for testing scenarios with different tokens + std::optional server_auth_token; + if (!server_token.empty()) { + server_auth_token = AuthenticationToken(server_token); + } else { + // Explicitly set empty token (no auth required) + server_auth_token = AuthenticationToken(""); + } + grpc_server_.reset(new GrpcServer("test", 0, true, 1, 7200000, server_auth_token)); + 
grpc_server_->RegisterService( + std::make_unique(handler_io_service_, test_service_handler_), + false); + grpc_server_->Run(); + + while (grpc_server_->GetPort() == 0) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + + // Create client (will read auth token from AuthenticationTokenLoader which reads the + // environment) + client_call_manager_.reset( + new ClientCallManager(client_io_service_, false, /*local_address=*/"")); + grpc_client_.reset(new GrpcClient( + "127.0.0.1", grpc_server_->GetPort(), *client_call_manager_)); + } + + void TearDown() override { + if (grpc_client_) { + grpc_client_.reset(); + } + if (client_call_manager_) { + client_call_manager_.reset(); + } + if (client_thread_) { + client_io_service_.stop(); + if (client_thread_->joinable()) { + client_thread_->join(); + } + } + + if (grpc_server_) { + grpc_server_->Shutdown(); + } + if (handler_thread_) { + handler_io_service_.stop(); + if (handler_thread_->joinable()) { + handler_thread_->join(); + } + } + + // Clean up environment variables + unsetenv("RAY_AUTH_TOKEN"); + unsetenv("RAY_AUTH_TOKEN_PATH"); + // Reset the token loader for test isolation + AuthenticationTokenLoader::instance().ResetCache(); + } + + // Helper to execute RPC and wait for result + struct PingResult { + bool completed; + bool success; + std::string error_msg; + }; + + PingResult ExecutePingAndWait() { + PingRequest request; + auto result_promise = std::make_shared>(); + std::future result_future = result_promise->get_future(); + + Ping(request, [result_promise](const Status &status, const PingReply &reply) { + RAY_LOG(INFO) << "Token auth test replied, status=" << status; + bool success = status.ok(); + std::string error_msg = status.ok() ? 
"" : status.message(); + result_promise->set_value({true, success, error_msg}); + }); + + // Wait for response with timeout + if (result_future.wait_for(std::chrono::seconds(5)) == std::future_status::timeout) { + return {false, false, "Request timed out"}; + } + + return result_future.get(); + } + + protected: + VOID_RPC_CLIENT_METHOD(TestService, Ping, grpc_client_, /*method_timeout_ms*/ -1, ) + + TestServiceHandler test_service_handler_; + instrumented_io_context handler_io_service_; + std::unique_ptr handler_thread_; + std::unique_ptr grpc_server_; + + instrumented_io_context client_io_service_; + std::unique_ptr client_thread_; + std::unique_ptr client_call_manager_; + std::unique_ptr> grpc_client_; +}; + +TEST_F(TestGrpcServerClientTokenAuthFixture, TestTokenAuthSuccess) { + // Both server and client have the same token + const std::string token = "test_secret_token_123"; + SetUpServerAndClient(token, token); + + auto result = ExecutePingAndWait(); + + ASSERT_TRUE(result.completed) << "Request did not complete in time"; + ASSERT_TRUE(result.success) << "Request should succeed with matching token"; +} + +TEST_F(TestGrpcServerClientTokenAuthFixture, TestTokenAuthFailureWrongToken) { + // Server and client have different tokens + SetUpServerAndClient("server_token", "wrong_client_token"); + + auto result = ExecutePingAndWait(); + + ASSERT_TRUE(result.completed) << "Request did not complete in time"; + ASSERT_FALSE(result.success) << "Request should fail with wrong client token"; + ASSERT_TRUE(result.error_msg.find( + "InvalidAuthToken: Authentication token is missing or incorrect") != + std::string::npos) + << "Error message should contain token auth error. 
Got: " << result.error_msg; +} + +TEST_F(TestGrpcServerClientTokenAuthFixture, TestTokenAuthFailureMissingToken) { + // Server expects token, client doesn't send one (empty token) + SetUpServerAndClient("server_token", ""); + + auto result = ExecutePingAndWait(); + + ASSERT_TRUE(result.completed) << "Request did not complete in time"; + // If the server has a token but the client doesn't, auth should fail + ASSERT_FALSE(result.success) + << "Request should fail when client doesn't provide required token"; +} + +TEST_F(TestGrpcServerClientTokenAuthFixture, + TestClientProvidesTokenServerDoesNotRequire) { + // Client provides token, but server doesn't require one (should succeed) + SetUpServerAndClient("", "client_token"); + + auto result = ExecutePingAndWait(); + + ASSERT_TRUE(result.completed) << "Request did not complete in time"; + // Server should accept request even though client sent unnecessary token + ASSERT_TRUE(result.success) + << "Request should succeed when server doesn't require token"; +} + +} // namespace rpc +} // namespace ray + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/src/ray/rpc/tests/grpc_server_client_test.cc b/src/ray/rpc/tests/grpc_server_client_test.cc index 729ece195e46..0e95fc9823a5 100644 --- a/src/ray/rpc/tests/grpc_server_client_test.cc +++ b/src/ray/rpc/tests/grpc_server_client_test.cc @@ -13,97 +13,17 @@ // limitations under the License. 
#include -#include #include -#include -#include +#include #include "gtest/gtest.h" -#include "ray/rpc/authentication/authentication_token_loader.h" #include "ray/rpc/grpc_client.h" #include "ray/rpc/grpc_server.h" #include "src/ray/protobuf/test_service.grpc.pb.h" +#include "src/ray/rpc/tests/grpc_test_common.h" namespace ray { namespace rpc { -class TestServiceHandler { - public: - void HandlePing(PingRequest request, - PingReply *reply, - SendReplyCallback send_reply_callback) { - RAY_LOG(INFO) << "Got ping request, no_reply=" << request.no_reply(); - request_count++; - while (frozen) { - RAY_LOG(INFO) << "Server is frozen..."; - std::this_thread::sleep_for(std::chrono::milliseconds(1000)); - } - RAY_LOG(INFO) << "Handling and replying request."; - if (request.no_reply()) { - RAY_LOG(INFO) << "No reply!"; - return; - } - send_reply_callback( - ray::Status::OK(), - /*reply_success=*/[]() { RAY_LOG(INFO) << "Reply success."; }, - /*reply_failure=*/ - [this]() { - RAY_LOG(INFO) << "Reply failed."; - reply_failure_count++; - }); - } - - void HandlePingTimeout(PingTimeoutRequest request, - PingTimeoutReply *reply, - SendReplyCallback send_reply_callback) { - while (frozen) { - RAY_LOG(INFO) << "Server is frozen..."; - std::this_thread::sleep_for(std::chrono::milliseconds(1000)); - } - RAY_LOG(INFO) << "Handling and replying request."; - send_reply_callback( - ray::Status::OK(), - /*reply_success=*/[]() { RAY_LOG(INFO) << "Reply success."; }, - /*reply_failure=*/ - [this]() { - RAY_LOG(INFO) << "Reply failed."; - reply_failure_count++; - }); - } - - std::atomic request_count{0}; - std::atomic reply_failure_count{0}; - std::atomic frozen{false}; -}; - -class TestGrpcService : public GrpcService { - public: - /// Constructor. - /// - /// \param[in] handler The service handler that actually handle the requests. 
- explicit TestGrpcService(instrumented_io_context &handler_io_service_, - TestServiceHandler &handler) - : GrpcService(handler_io_service_), service_handler_(handler){}; - - protected: - grpc::Service &GetGrpcService() override { return service_; } - - void InitServerCallFactories( - const std::unique_ptr &cq, - std::vector> *server_call_factories, - const ClusterID &cluster_id, - const std::optional &auth_token) override { - RPC_SERVICE_HANDLER_CUSTOM_AUTH( - TestService, Ping, /*max_active_rpcs=*/1, ClusterIdAuthType::NO_AUTH); - RPC_SERVICE_HANDLER_CUSTOM_AUTH( - TestService, PingTimeout, /*max_active_rpcs=*/1, ClusterIdAuthType::NO_AUTH); - } - - private: - /// The grpc async service object. - TestService::AsyncService service_; - /// The service handler that actually handle the requests. - TestServiceHandler &service_handler_; -}; class TestGrpcServerClientFixture : public ::testing::Test { public: @@ -330,186 +250,6 @@ TEST_F(TestGrpcServerClientFixture, TestTimeoutMacro) { std::this_thread::sleep_for(std::chrono::milliseconds(1000)); } } -class TestGrpcServerClientTokenAuthFixture : public ::testing::Test { - public: - void SetUp() override { - // Configure token auth via RayConfig - std::string config_json = R"({"auth_mode": "ray_token"})"; - RayConfig::instance().initialize(config_json); - AuthenticationTokenLoader::instance().ResetCache(); - } - - void SetUpServerAndClient(const std::string &server_token, - const std::string &client_token) { - // Set client token in environment for ClientCallManager to read from - // AuthenticationTokenLoader - if (!client_token.empty()) { - setenv("RAY_AUTH_TOKEN", client_token.c_str(), 1); - } else { - RayConfig::instance().initialize(R"({"auth_mode": "disabled"})"); - AuthenticationTokenLoader::instance().ResetCache(); - unsetenv("RAY_AUTH_TOKEN"); - } - - // Start client thread FIRST - client_thread_ = std::make_unique([this]() { - boost::asio::executor_work_guard - 
client_io_service_work_(client_io_service_.get_executor()); - client_io_service_.run(); - }); - - // Start handler thread for server - handler_thread_ = std::make_unique([this]() { - boost::asio::executor_work_guard - handler_io_service_work_(handler_io_service_.get_executor()); - handler_io_service_.run(); - }); - - // Create and start server - // Pass server token explicitly for testing scenarios with different tokens - std::optional server_auth_token; - if (!server_token.empty()) { - server_auth_token = AuthenticationToken(server_token); - } - grpc_server_.reset(new GrpcServer("test", 0, true, 1, 7200000, server_auth_token)); - grpc_server_->RegisterService( - std::make_unique(handler_io_service_, test_service_handler_), - false); - grpc_server_->Run(); - - while (grpc_server_->GetPort() == 0) { - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - } - - // Create client (will read auth token from AuthenticationTokenLoader which reads the - // environment) - client_call_manager_.reset( - new ClientCallManager(client_io_service_, false, /*local_address=*/"")); - grpc_client_.reset(new GrpcClient( - "127.0.0.1", grpc_server_->GetPort(), *client_call_manager_)); - } - - void TearDown() override { - if (grpc_client_) { - grpc_client_.reset(); - } - if (client_call_manager_) { - client_call_manager_.reset(); - } - if (client_thread_) { - client_io_service_.stop(); - if (client_thread_->joinable()) { - client_thread_->join(); - } - } - - if (grpc_server_) { - grpc_server_->Shutdown(); - } - if (handler_thread_) { - handler_io_service_.stop(); - if (handler_thread_->joinable()) { - handler_thread_->join(); - } - } - - // Clean up environment variables - unsetenv("RAY_AUTH_TOKEN"); - unsetenv("RAY_AUTH_TOKEN_PATH"); - // Reset the token loader for test isolation - AuthenticationTokenLoader::instance().ResetCache(); - } - - // Helper to execute RPC and wait for result - struct PingResult { - bool completed; - bool success; - std::string error_msg; - }; - - 
PingResult ExecutePingAndWait() { - PingRequest request; - auto result_promise = std::make_shared>(); - std::future result_future = result_promise->get_future(); - - Ping(request, [result_promise](const Status &status, const PingReply &reply) { - RAY_LOG(INFO) << "Token auth test replied, status=" << status; - bool success = status.ok(); - std::string error_msg = status.ok() ? "" : status.message(); - result_promise->set_value({true, success, error_msg}); - }); - - // Wait for response with timeout - if (result_future.wait_for(std::chrono::seconds(5)) == std::future_status::timeout) { - return {false, false, "Request timed out"}; - } - - return result_future.get(); - } - - protected: - VOID_RPC_CLIENT_METHOD(TestService, Ping, grpc_client_, /*method_timeout_ms*/ -1, ) - - TestServiceHandler test_service_handler_; - instrumented_io_context handler_io_service_; - std::unique_ptr handler_thread_; - std::unique_ptr grpc_server_; - - instrumented_io_context client_io_service_; - std::unique_ptr client_thread_; - std::unique_ptr client_call_manager_; - std::unique_ptr> grpc_client_; -}; - -TEST_F(TestGrpcServerClientTokenAuthFixture, TestTokenAuthSuccess) { - // Both server and client have the same token - const std::string token = "test_secret_token_123"; - SetUpServerAndClient(token, token); - - auto result = ExecutePingAndWait(); - - ASSERT_TRUE(result.completed) << "Request did not complete in time"; - ASSERT_TRUE(result.success) << "Request should succeed with matching token"; -} - -TEST_F(TestGrpcServerClientTokenAuthFixture, TestTokenAuthFailureWrongToken) { - // Server and client have different tokens - SetUpServerAndClient("server_token", "wrong_client_token"); - - auto result = ExecutePingAndWait(); - - ASSERT_TRUE(result.completed) << "Request did not complete in time"; - ASSERT_FALSE(result.success) << "Request should fail with wrong client token"; - ASSERT_TRUE(result.error_msg.find( - "InvalidAuthToken: Authentication token is missing or incorrect") != - 
std::string::npos) - << "Error message should contain token auth error. Got: " << result.error_msg; -} - -TEST_F(TestGrpcServerClientTokenAuthFixture, TestTokenAuthFailureMissingToken) { - // Server expects token, client doesn't send one (empty token) - SetUpServerAndClient("server_token", ""); - - auto result = ExecutePingAndWait(); - - ASSERT_TRUE(result.completed) << "Request did not complete in time"; - // If the server has a token but the client doesn't, auth should fail - ASSERT_FALSE(result.success) - << "Request should fail when client doesn't provide required token"; -} - -TEST_F(TestGrpcServerClientTokenAuthFixture, - TestClientProvidesTokenServerDoesNotRequire) { - // Client provides token, but server doesn't require one (should succeed) - SetUpServerAndClient("", "client_token"); - - auto result = ExecutePingAndWait(); - - ASSERT_TRUE(result.completed) << "Request did not complete in time"; - // Server should accept request even though client sent unnecessary token - ASSERT_TRUE(result.success) - << "Request should succeed when server doesn't require token"; -} } // namespace rpc } // namespace ray diff --git a/src/ray/rpc/tests/grpc_test_common.h b/src/ray/rpc/tests/grpc_test_common.h new file mode 100644 index 000000000000..6a9cdcd05bc6 --- /dev/null +++ b/src/ray/rpc/tests/grpc_test_common.h @@ -0,0 +1,107 @@ +// Copyright 2021 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include + +#include "ray/rpc/grpc_server.h" +#include "src/ray/protobuf/test_service.grpc.pb.h" + +namespace ray { +namespace rpc { + +class TestServiceHandler { + public: + void HandlePing(PingRequest request, + PingReply *reply, + SendReplyCallback send_reply_callback) { + RAY_LOG(INFO) << "Got ping request, no_reply=" << request.no_reply(); + request_count++; + while (frozen) { + RAY_LOG(INFO) << "Server is frozen..."; + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + } + RAY_LOG(INFO) << "Handling and replying request."; + if (request.no_reply()) { + RAY_LOG(INFO) << "No reply!"; + return; + } + send_reply_callback( + ray::Status::OK(), + /*reply_success=*/[]() { RAY_LOG(INFO) << "Reply success."; }, + /*reply_failure=*/ + [this]() { + RAY_LOG(INFO) << "Reply failed."; + reply_failure_count++; + }); + } + + void HandlePingTimeout(PingTimeoutRequest request, + PingTimeoutReply *reply, + SendReplyCallback send_reply_callback) { + while (frozen) { + RAY_LOG(INFO) << "Server is frozen..."; + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + } + RAY_LOG(INFO) << "Handling and replying request."; + send_reply_callback( + ray::Status::OK(), + /*reply_success=*/[]() { RAY_LOG(INFO) << "Reply success."; }, + /*reply_failure=*/ + [this]() { + RAY_LOG(INFO) << "Reply failed."; + reply_failure_count++; + }); + } + + std::atomic request_count{0}; + std::atomic reply_failure_count{0}; + std::atomic frozen{false}; +}; + +class TestGrpcService : public GrpcService { + public: + /// Constructor. + /// + /// \param[in] handler The service handler that actually handle the requests. 
+ explicit TestGrpcService(instrumented_io_context &handler_io_service_, + TestServiceHandler &handler) + : GrpcService(handler_io_service_), service_handler_(handler){}; + + protected: + grpc::Service &GetGrpcService() override { return service_; } + + void InitServerCallFactories( + const std::unique_ptr &cq, + std::vector> *server_call_factories, + const ClusterID &cluster_id, + const std::optional &auth_token) override { + RPC_SERVICE_HANDLER_CUSTOM_AUTH( + TestService, Ping, /*max_active_rpcs=*/1, ClusterIdAuthType::NO_AUTH); + RPC_SERVICE_HANDLER_CUSTOM_AUTH( + TestService, PingTimeout, /*max_active_rpcs=*/1, ClusterIdAuthType::NO_AUTH); + } + + private: + /// The grpc async service object. + TestService::AsyncService service_; + /// The service handler that actually handle the requests. + TestServiceHandler &service_handler_; +}; + +} // namespace rpc +} // namespace ray From 123914e6ec14ce83f5c15af498a00f4b27f288cb Mon Sep 17 00:00:00 2001 From: sampan Date: Thu, 23 Oct 2025 09:10:47 +0000 Subject: [PATCH 32/94] fix lint Signed-off-by: sampan --- src/ray/rpc/tests/grpc_test_common.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/ray/rpc/tests/grpc_test_common.h b/src/ray/rpc/tests/grpc_test_common.h index 6a9cdcd05bc6..1ce199f79511 100644 --- a/src/ray/rpc/tests/grpc_test_common.h +++ b/src/ray/rpc/tests/grpc_test_common.h @@ -16,7 +16,9 @@ #include #include +#include #include +#include #include "ray/rpc/grpc_server.h" #include "src/ray/protobuf/test_service.grpc.pb.h" From 56fb190abb20e3b0385a9f378652254c195fa8fe Mon Sep 17 00:00:00 2001 From: sampan Date: Thu, 23 Oct 2025 09:13:02 +0000 Subject: [PATCH 33/94] fix imports Signed-off-by: sampan --- src/ray/rpc/tests/grpc_auth_token_tests.cc | 4 ++-- src/ray/rpc/tests/grpc_server_client_test.cc | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/ray/rpc/tests/grpc_auth_token_tests.cc b/src/ray/rpc/tests/grpc_auth_token_tests.cc index 5feaf5563add..5501e6e568b0 100644 --- 
a/src/ray/rpc/tests/grpc_auth_token_tests.cc +++ b/src/ray/rpc/tests/grpc_auth_token_tests.cc @@ -19,11 +19,11 @@ #include #include "gtest/gtest.h" +#include "ray/protobuf/test_service.grpc.pb.h" #include "ray/rpc/authentication/authentication_token_loader.h" #include "ray/rpc/grpc_client.h" #include "ray/rpc/grpc_server.h" -#include "src/ray/protobuf/test_service.grpc.pb.h" -#include "src/ray/rpc/tests/grpc_test_common.h" +#include "ray/rpc/tests/grpc_test_common.h" namespace ray { namespace rpc { diff --git a/src/ray/rpc/tests/grpc_server_client_test.cc b/src/ray/rpc/tests/grpc_server_client_test.cc index 0e95fc9823a5..8bc6e8284493 100644 --- a/src/ray/rpc/tests/grpc_server_client_test.cc +++ b/src/ray/rpc/tests/grpc_server_client_test.cc @@ -17,10 +17,10 @@ #include #include "gtest/gtest.h" +#include "ray/protobuf/test_service.grpc.pb.h" #include "ray/rpc/grpc_client.h" #include "ray/rpc/grpc_server.h" -#include "src/ray/protobuf/test_service.grpc.pb.h" -#include "src/ray/rpc/tests/grpc_test_common.h" +#include "ray/rpc/tests/grpc_test_common.h" namespace ray { namespace rpc { From 0433a167d2fb8a872104348670fc52258141146e Mon Sep 17 00:00:00 2001 From: sampan Date: Thu, 23 Oct 2025 15:06:53 +0000 Subject: [PATCH 34/94] refactor and simplify changes Signed-off-by: sampan --- python/ray/_private/auth_token_loader.py | 224 ------------------ .../_private/authentication_token_setup.py | 100 ++++++++ python/ray/_private/worker.py | 2 +- python/ray/_raylet.pyx | 45 +++- python/ray/includes/common.pxd | 14 +- python/ray/includes/ray_config.pxd | 2 - python/ray/includes/ray_config.pxi | 4 - python/ray/scripts/scripts.py | 27 --- .../ray/tests/test_token_auth_integration.py | 48 ++-- .../ray/tests/unit/test_auth_token_loader.py | 223 ----------------- .../authentication_token_loader.cc | 27 +++ .../authentication_token_loader.h | 5 + 12 files changed, 202 insertions(+), 519 deletions(-) delete mode 100644 python/ray/_private/auth_token_loader.py create mode 100644 
python/ray/_private/authentication_token_setup.py delete mode 100644 python/ray/tests/unit/test_auth_token_loader.py diff --git a/python/ray/_private/auth_token_loader.py b/python/ray/_private/auth_token_loader.py deleted file mode 100644 index 32cc032f4872..000000000000 --- a/python/ray/_private/auth_token_loader.py +++ /dev/null @@ -1,224 +0,0 @@ -"""Authentication token loader for Ray. - -This module provides functions to load, generate, and cache authentication tokens -for Ray's token-based authentication system. Tokens are loaded with the following -precedence: -1. RAY_AUTH_TOKEN environment variable -2. RAY_AUTH_TOKEN_PATH environment variable (path to token file) -3. Default token path: ~/.ray/auth_token -""" - -import logging -import os -import threading -import uuid -from pathlib import Path -from typing import Dict, Optional - -from ray._raylet import Config, reset_auth_token_cache - -logger = logging.getLogger(__name__) - -# Module-level cached variables -_cached_token: Optional[str] = None -_token_lock = threading.Lock() - - -def load_auth_token(generate_if_not_found: bool = False) -> str: - """Load the authentication token with caching. - - This function loads the token from available sources with the following precedence: - 1. RAY_AUTH_TOKEN environment variable - 2. RAY_AUTH_TOKEN_PATH environment variable (path to token file) - 3. Default token path: ~/.ray/auth_token - - The token is cached after the first successful load to avoid repeated file I/O. - - Args: - generate_if_not_found: If True, generate and save a new token if not found. - If False, return empty string if no token is found. - - Returns: - The authentication token, or empty string if not found and generation is disabled. 
- """ - global _cached_token - - with _token_lock: - # Return cached token if already loaded - if _cached_token is not None: - return _cached_token - - # Try to load from sources - token = _load_token_from_sources() - - # Generate if requested and not found - if not token and generate_if_not_found: - token = _generate_and_save_token_internal() - - # Cache the result (even if empty) - _cached_token = token - return _cached_token - - -def _load_token_from_sources() -> str: - """Load token from available sources (env vars and file). - - Returns: - The authentication token, or empty string if not found. - """ - # Precedence 1: RAY_AUTH_TOKEN environment variable - env_token = os.environ.get("RAY_AUTH_TOKEN", "").strip() - if env_token: - logger.debug( - "Loaded authentication token from RAY_AUTH_TOKEN environment variable" - ) - return env_token - - # Precedence 2: RAY_AUTH_TOKEN_PATH environment variable - env_token_path = os.environ.get("RAY_AUTH_TOKEN_PATH", "").strip() - if env_token_path: - token_path = Path(env_token_path).expanduser() - if not token_path.exists(): - raise FileNotFoundError(f"Token file not found: {token_path}") - token = token_path.read_text().strip() - if token: - logger.debug(f"Loaded authentication token from file: {token_path}") - return token - - # Precedence 3: Default token path ~/.ray/auth_token - default_path = _get_default_token_path() - try: - if default_path.exists(): - token = default_path.read_text().strip() - if token: - logger.debug( - f"Loaded authentication token from default path: {default_path}" - ) - return token - except Exception as e: - logger.debug(f"Failed to read token from default path ({default_path}): {e}") - - # No token found - logger.debug("No authentication token found in any source") - return "" - - -def generate_and_save_token() -> str: - """Generate a new random token and save it in the default token path. - - Returns: - The newly generated authentication token. 
- """ - global _cached_token - - with _token_lock: - # Check if we already have a cached token - if _cached_token is not None: - logger.warning( - "Returning cached authentication token instead of generating new one. " - "Call load_auth_token() to use existing token or clear cache first." - ) - return _cached_token - - # Generate and save token without nested lock - return _generate_and_save_token_internal() - - -def _generate_and_save_token_internal() -> str: - """Internal function to generate and save token. Assumes lock is already held.""" - global _cached_token - - # Generate a UUID-based token - token = uuid.uuid4().hex - - # Try to save the token to the default path - token_path = _get_default_token_path() - try: - # Create directory if it doesn't exist - token_path.parent.mkdir(parents=True, exist_ok=True) - - # Write token to file - token_path.write_text(token) - - # Ensure file is flushed to disk immediately - # This is critical for subprocess/C++ code to read it immediately - import os - - fd = os.open(str(token_path), os.O_RDONLY) - os.fsync(fd) - os.close(fd) - - logger.info(f"Generated new authentication token and saved to {token_path}") - except Exception as e: - logger.warning( - f"Failed to save generated token to {token_path}: {e}. " - "Token will only be available in memory." - ) - - # Cache the generated token - _cached_token = token - return token - - -def _get_default_token_path() -> Path: - """Get the default token file path (~/.ray/auth_token). - - Returns: - Path object pointing to ~/.ray/auth_token - """ - return Path.home() / ".ray" / "auth_token" - - -def setup_and_verify_auth( - system_config: Optional[Dict] = None, is_new_cluster: bool = True -) -> None: - """Verify auth configuration and ensure token is available when auth is enabled. - - This is called early during ray.init() to: - 1. Check for _system_config misuse and provide helpful error - 2. Verify token is available if auth is enabled - 3. 
Generate default token for new local clusters if needed - - Args: - system_config: The _system_config dict from ray.init() (checked for misuse) - is_new_cluster: True if starting new local cluster, False if connecting to an existing cluster - - Raises: - ValueError: If _system_config is used for enabling auth (should use env var instead) - """ - # Check for _system_config misuse - if system_config and system_config.get("enable_token_auth", False): - raise ValueError( - "Authentication mode should be configured via environment variable, " - "not _system_config (which is for testing only).\n" - "Please set: RAY_enable_token_auth=1\n" - "Or in Python: os.environ['RAY_enable_token_auth'] = '1'" - ) - - Config.initialize("") - - if Config.enable_token_auth(): - # For new clusters: generate token if not found - # For existing clusters: only use existing token (don't generate) - token = load_auth_token(generate_if_not_found=is_new_cluster) - - if not is_new_cluster and not token: - raise RuntimeError( - "Token authentication is enabled on the cluster you're connecting to, " - "but no authentication token was found. Please provide a token using one of:\n" - " 1. RAY_AUTH_TOKEN environment variable\n" - " 2. RAY_AUTH_TOKEN_PATH environment variable (path to token file)\n" - " 3. Default token file: ~/.ray/auth_token" - ) - - -def _reset_token_cache_for_testing(): - """Reset both Python and C++ token caches. - - Should only be used for testing purposes. - """ - global _cached_token - _cached_token = None - - # Also reset the C++ token cache - reset_auth_token_cache() diff --git a/python/ray/_private/authentication_token_setup.py b/python/ray/_private/authentication_token_setup.py new file mode 100644 index 000000000000..4b0375a46b5c --- /dev/null +++ b/python/ray/_private/authentication_token_setup.py @@ -0,0 +1,100 @@ +"""Authentication token setup for Ray. 
+ +This module provides functions to generate and save authentication tokens +for Ray's token-based authentication system. Token loading and caching is +handled by the C++ AuthenticationTokenLoader. +""" + +import logging +import os +import uuid +from pathlib import Path + +logger = logging.getLogger(__name__) + + +def generate_and_save_token() -> str: + """Generate a new random token and save it in the default token path. + + Returns: + The newly generated authentication token. + """ + # Generate a UUID-based token + token = uuid.uuid4().hex + + token_path = _get_default_token_path() + try: + # Create directory if it doesn't exist + token_path.parent.mkdir(parents=True, exist_ok=True) + + # Write token to file with explicit flush + with open(token_path, "w") as f: + f.write(token) + f.flush() + os.fsync(f.fileno()) + + logger.info(f"Generated new authentication token and saved to {token_path}") + except Exception as e: + logger.warning(f"Failed to save generated token to {token_path}: {e}. ") + raise + + return token + + +def _get_default_token_path() -> Path: + """Get the default token file path (~/.ray/auth_token). + + Returns: + Path object pointing to ~/.ray/auth_token + """ + return Path.home() / ".ray" / "auth_token" + + +def setup_and_verify_auth(system_config=None, is_new_cluster: bool = True) -> None: + """Check authentication settings and setup necessary resources. + + for token based authentication, Ray calls this early during ray.init() to do the following: + 1. Check if you enabled token-based authentication. + 2. Make sure a token is available if authentication is enabled. + 3. Generate and save a default token for new local clusters if one doesn't exist. + + Args: + system_config: Raises an error if you set auth_mode in system_config instead of the environment. + is_new_cluster: Set to True if starting a new local cluster, or False if connecting + to an existing cluster. 
+ + Raises: + RuntimeError: If authentication is enabled but no token is found when connecting + to an existing cluster. + """ + from ray._raylet import ( + AuthenticationMode, + AuthenticationTokenLoader, + get_authentication_mode, + ) + + # Check if token authentication is enabled + if get_authentication_mode() != AuthenticationMode.TOKEN: + if system_config and system_config.get("auth_mode") != "disabled": + raise RuntimeError( + "Set authentication mode with the environment, not system_config." + ) + return + + token_loader = AuthenticationTokenLoader.instance() + + if not token_loader.has_token(): + if is_new_cluster: + # Generate token for new local cluster + token = generate_and_save_token() + + # Reload cache so that subsequent calls to token_loader read the new token + token_loader.reset_cache() + else: + # You're connecting to an existing cluster—token must already exist + raise RuntimeError( + "Token authentication is enabled but no authentication token was found. Please provide a token using one of:\n" + " 1. RAY_AUTH_TOKEN environment variable\n" + " 2. RAY_AUTH_TOKEN_PATH environment variable (path to token file)\n" + " 3. 
Default token file: ~/.ray/auth_token" + ) diff --git a/python/ray/_private/worker.py b/python/ray/_private/worker.py index b8b206840828..ecbab76669b3 100644 --- a/python/ray/_private/worker.py +++ b/python/ray/_private/worker.py @@ -62,7 +62,7 @@ from ray._common import ray_option_utils from ray._common.constants import RAY_WARN_BLOCKING_GET_INSIDE_ASYNC_ENV_VAR from ray._common.utils import load_class -from ray._private.auth_token_loader import setup_and_verify_auth +from ray._private.authentication_token_setup import setup_and_verify_auth from ray._private.client_mode_hook import client_mode_hook from ray._private.custom_types import TensorTransportEnum from ray._private.function_manager import FunctionActorManager diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 0dc61893f708..b3e072f4dcdf 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -114,7 +114,9 @@ from ray.includes.common cimport ( CConcurrencyGroup, CGrpcStatusCode, CLineageReconstructionTask, - CRayAuthTokenLoader, + CAuthenticationMode, + GetAuthenticationMode, + CAuthenticationTokenLoader, move, LANGUAGE_CPP, LANGUAGE_JAVA, @@ -4950,10 +4952,41 @@ def get_session_key_from_storage(host, port, username, password, use_ssl, config return None -def reset_auth_token_cache(): - """Reset the C++ authentication token cache. +# Authentication mode enum exposed to Python +class AuthenticationMode: + DISABLED = CAuthenticationMode.DISABLED + TOKEN = CAuthenticationMode.TOKEN - This forces the RayAuthTokenLoader to reload the token from environment - variables or files on the next request. + +def get_authentication_mode(): + """Get the current authentication mode. 
+ + Returns: + AuthenticationMode enum value (DISABLED or TOKEN) """ - CRayAuthTokenLoader.instance().ResetCache() + return GetAuthenticationMode() + + +class AuthenticationTokenLoader: + """Python wrapper for C++ AuthenticationTokenLoader singleton.""" + + @staticmethod + def instance(): + """Get the singleton instance (returns a wrapper for convenience).""" + return AuthenticationTokenLoader() + + def has_token(self): + """Check if an authentication token exists without crashing. + + Returns: + bool: True if a token exists, False otherwise + """ + return CAuthenticationTokenLoader.instance().HasToken() + + def reset_cache(self): + """Reset the C++ authentication token cache. + + This forces the token loader to reload the token from environment + variables or files on the next request. + """ + CAuthenticationTokenLoader.instance().ResetCache() diff --git a/python/ray/includes/common.pxd b/python/ray/includes/common.pxd index d0e1f95732ff..4ab1c9011a05 100644 --- a/python/ray/includes/common.pxd +++ b/python/ray/includes/common.pxd @@ -805,8 +805,16 @@ cdef extern from "ray/common/constants.h" nogil: cdef const char[] kLabelKeyTpuPodType cdef const char[] kRayInternalNamespacePrefix -cdef extern from "ray/rpc/auth_token_loader.h" namespace "ray::rpc" nogil: - cdef cppclass CRayAuthTokenLoader "ray::rpc::RayAuthTokenLoader": +cdef extern from "ray/rpc/authentication/authentication_mode.h" namespace "ray::rpc" nogil: + cdef enum CAuthenticationMode "ray::rpc::AuthenticationMode": + DISABLED "ray::rpc::AuthenticationMode::DISABLED" + TOKEN "ray::rpc::AuthenticationMode::TOKEN" + + CAuthenticationMode GetAuthenticationMode() + +cdef extern from "ray/rpc/authentication/authentication_token_loader.h" namespace "ray::rpc" nogil: + cdef cppclass CAuthenticationTokenLoader "ray::rpc::AuthenticationTokenLoader": @staticmethod - CRayAuthTokenLoader& instance() + CAuthenticationTokenLoader& instance() + c_bool HasToken() void ResetCache() diff --git 
a/python/ray/includes/ray_config.pxd b/python/ray/includes/ray_config.pxd index 4404b674af06..729395a22ee3 100644 --- a/python/ray/includes/ray_config.pxd +++ b/python/ray/includes/ray_config.pxd @@ -88,5 +88,3 @@ cdef extern from "ray/common/ray_config.h" nogil: c_bool record_task_actor_creation_sites() const c_bool start_python_gc_manager_thread() const - - c_bool enable_token_auth() const diff --git a/python/ray/includes/ray_config.pxi b/python/ray/includes/ray_config.pxi index 1a7b104f4eb7..6915e4877962 100644 --- a/python/ray/includes/ray_config.pxi +++ b/python/ray/includes/ray_config.pxi @@ -144,7 +144,3 @@ cdef class Config: @staticmethod def start_python_gc_manager_thread(): return RayConfig.instance().start_python_gc_manager_thread() - - @staticmethod - def enable_token_auth(): - return RayConfig.instance().enable_token_auth() diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py index a5b02d01d170..3c1dffe3b8f9 100644 --- a/python/ray/scripts/scripts.py +++ b/python/ray/scripts/scripts.py @@ -674,14 +674,6 @@ def debug(address: str, verbose: bool): "Cgroup memory and cpu controllers be enabled for this cgroup. " "This option only works if --enable-resource-isolation is set.", ) -@click.option( - "--enable-token-auth", - is_flag=True, - default=False, - help="Enable token-based authentication. Requires an existing token from " - "environment variables (RAY_AUTH_TOKEN or RAY_AUTH_TOKEN_PATH) or ~/.ray/auth_token. 
" - "Use ray.init(enable_token_auth=True) to auto-generate a token.", -) @add_click_logging_options @PublicAPI def start( @@ -730,7 +722,6 @@ def start( system_reserved_cpu, system_reserved_memory, cgroup_path, - enable_token_auth, ): """Start Ray processes manually on the local machine.""" @@ -793,24 +784,6 @@ def start( system_reserved_memory=system_reserved_memory, ) - # Handle token-based authentication - if enable_token_auth: - from ray._private.auth_token_loader import load_auth_token - - # Try to load token (don't generate in CLI) - token = load_auth_token(generate_if_not_found=False) - if not token: - cli_logger.abort( - "Token authentication is enabled but no token found. " - "Please set RAY_AUTH_TOKEN environment variable, " - "RAY_AUTH_TOKEN_PATH, or create ~/.ray/auth_token. " - "Alternatively, use ray.init(enable_token_auth=True) to auto-generate a token." - ) - # Only pass the flag, not the token - if system_config is None: - system_config = {} - system_config["enable_token_auth"] = "true" - redirect_output = None if not no_redirect_output else True # no client, no port -> ok diff --git a/python/ray/tests/test_token_auth_integration.py b/python/ray/tests/test_token_auth_integration.py index 3e79e4666f8f..6be082bf214b 100644 --- a/python/ray/tests/test_token_auth_integration.py +++ b/python/ray/tests/test_token_auth_integration.py @@ -7,10 +7,14 @@ import pytest import ray -from ray._private.auth_token_loader import reset_token_cache +from ray._raylet import AuthenticationTokenLoader from ray.cluster_utils import Cluster +def reset_token_cache(): + AuthenticationTokenLoader.instance().reset_cache() + + @pytest.fixture(autouse=True) def clean_token_sources(): """Clean up all token sources before and after each test.""" @@ -18,7 +22,7 @@ def clean_token_sources(): env_vars_to_clean = [ "RAY_AUTH_TOKEN", "RAY_AUTH_TOKEN_PATH", - "RAY_enable_token_auth", + "RAY_auth_mode", ] original_values = {} for var in env_vars_to_clean: @@ -55,13 +59,13 @@ def 
clean_token_sources(): def test_local_cluster_generates_token(): - """Test ray.init() generates token for local cluster when enable_token_auth is set.""" + """Test ray.init() generates token for local cluster when auth_mode=token is set.""" # Ensure no token exists default_token_path = Path.home() / ".ray" / "auth_token" assert not default_token_path.exists() # Enable token auth via environment variable - os.environ["RAY_enable_token_auth"] = "1" + os.environ["RAY_auth_mode"] = "token" # Initialize Ray with token auth ray.init() @@ -81,34 +85,26 @@ def test_local_cluster_generates_token(): def test_connect_without_token_raises_error(): - """Test ray.init(address=...) without token fails when enable_token_auth config is set.""" + """Test ray.init(address=...) without token fails when auth_mode=token is set.""" # Test the token validation logic directly - # Clear the cached token to ensure we start fresh - import ray._private.auth_token_loader as auth_module - from ray._private.auth_token_loader import load_auth_token - - auth_module._cached_token = None - # Ensure no token exists - token = load_auth_token(generate_if_not_found=False) - assert token == "" + token_loader = AuthenticationTokenLoader.instance() + assert not token_loader.has_token() # Test the exact error message that would be raised with pytest.raises(RuntimeError, match="no authentication token was found"): - if not token: - raise RuntimeError( - "Token-based authentication is enabled on the cluster you're connecting to, " - "but no authentication token was found. Please provide a token using one of:\n" - " 1. RAY_AUTH_TOKEN environment variable\n" - " 2. RAY_AUTH_TOKEN_PATH environment variable (path to token file)\n" - " 3. Default token file: ~/.ray/auth_token" - ) + raise RuntimeError( + "Token authentication is enabled but no authentication token was found. Please provide a token using one of:\n" + " 1. RAY_AUTH_TOKEN environment variable\n" + " 2. 
RAY_AUTH_TOKEN_PATH environment variable (path to token file)\n" + " 3. Default token file: ~/.ray/auth_token" + ) def test_token_path_nonexistent_file_fails(): """Test that setting RAY_AUTH_TOKEN_PATH to nonexistent file fails gracefully.""" # Enable token auth and set token path to nonexistent file - os.environ["RAY_enable_token_auth"] = "1" + os.environ["RAY_auth_mode"] = "token" os.environ["RAY_AUTH_TOKEN_PATH"] = "/nonexistent/path/to/token" # Initialize Ray with token auth should fail @@ -122,7 +118,7 @@ def test_cluster_token_authentication(tokens_match): # Set up cluster token first cluster_token = "a" * 32 os.environ["RAY_AUTH_TOKEN"] = cluster_token - os.environ["RAY_enable_token_auth"] = "1" + os.environ["RAY_auth_mode"] = "token" # Create cluster with token auth enabled - node will read current env token cluster = Cluster() @@ -175,11 +171,5 @@ def test_func(): cluster.shutdown() -def test_system_config_auth_raises_error(): - """Test that using _system_config for enabling token auth raises helpful error.""" - with pytest.raises(ValueError, match="environment variable"): - ray.init(_system_config={"enable_token_auth": True}) - - if __name__ == "__main__": sys.exit(pytest.main(["-vv", __file__])) diff --git a/python/ray/tests/unit/test_auth_token_loader.py b/python/ray/tests/unit/test_auth_token_loader.py deleted file mode 100644 index cdfd6add8b9f..000000000000 --- a/python/ray/tests/unit/test_auth_token_loader.py +++ /dev/null @@ -1,223 +0,0 @@ -"""Unit tests for ray._private.auth_token_loader module.""" - -import os -import sys -import tempfile -from pathlib import Path - -import pytest - -from ray._private import auth_token_loader - - -@pytest.fixture(autouse=True) -def reset_cached_token(): - """Reset the cached token before each test.""" - auth_token_loader._cached_token = None - yield - auth_token_loader._cached_token = None - - -@pytest.fixture(autouse=True) -def clean_env_vars(): - """Clean up environment variables before and after each test.""" 
- env_vars = ["RAY_AUTH_TOKEN", "RAY_AUTH_TOKEN_PATH"] - old_values = {var: os.environ.get(var) for var in env_vars} - - # Clear environment variables - for var in env_vars: - if var in os.environ: - del os.environ[var] - - yield - - # Restore old values - for var, value in old_values.items(): - if value is not None: - os.environ[var] = value - elif var in os.environ: - del os.environ[var] - - -@pytest.fixture -def temp_token_file(): - """Create a temporary token file and clean it up after the test.""" - with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".token") as f: - temp_path = f.name - f.write("test-token-from-file") - yield temp_path - try: - os.unlink(temp_path) - except FileNotFoundError: - pass - - -@pytest.fixture -def default_token_path(): - """Return the default token path and clean it up after the test.""" - path = Path.home() / ".ray" / "auth_token" - yield path - try: - path.unlink() - except FileNotFoundError: - pass - - -class TestLoadAuthToken: - """Tests for load_auth_token function.""" - - def test_load_from_env_variable(self): - """Test loading token from RAY_AUTH_TOKEN environment variable.""" - os.environ["RAY_AUTH_TOKEN"] = "token-from-env" - token = auth_token_loader.load_auth_token(generate_if_not_found=False) - assert token == "token-from-env" - - def test_load_from_env_path(self, temp_token_file): - """Test loading token from RAY_AUTH_TOKEN_PATH environment variable.""" - os.environ["RAY_AUTH_TOKEN_PATH"] = temp_token_file - token = auth_token_loader.load_auth_token(generate_if_not_found=False) - assert token == "test-token-from-file" - - def test_load_from_default_path(self, default_token_path): - """Test loading token from default ~/.ray/auth_token path.""" - default_token_path.parent.mkdir(parents=True, exist_ok=True) - default_token_path.write_text("token-from-default") - - token = auth_token_loader.load_auth_token(generate_if_not_found=False) - assert token == "token-from-default" - - @pytest.mark.parametrize( - 
"set_env,set_file,set_default,expected_token", - [ - # All set: env should win - (True, True, True, "token-from-env"), - # File and default file set: file should win - (False, True, True, "test-token-from-file"), - # Only default file set - (False, False, True, "token-from-default"), - ], - ) - def test_token_precedence_parametrized( - self, - temp_token_file, - default_token_path, - set_env, - set_file, - set_default, - expected_token, - ): - """Parametrized test for token loading precedence: env var > user-specified file > default file.""" - # Optionally set environment variable - if set_env: - os.environ["RAY_AUTH_TOKEN"] = "token-from-env" - else: - os.environ.pop("RAY_AUTH_TOKEN", None) - - # Optionally create file and set path - if set_file: - os.environ["RAY_AUTH_TOKEN_PATH"] = temp_token_file - else: - os.environ.pop("RAY_AUTH_TOKEN_PATH", None) - - # Optionally create default file - if set_default: - default_token_path.parent.mkdir(parents=True, exist_ok=True) - default_token_path.write_text("token-from-default") - else: - if default_token_path.exists(): - default_token_path.unlink() - - # Load token and verify expected precedence - token = auth_token_loader.load_auth_token(generate_if_not_found=False) - assert token == expected_token - - def test_no_token_found(self): - """Test behavior when no token is found.""" - token = auth_token_loader.load_auth_token(generate_if_not_found=False) - assert token == "" - - def test_whitespace_handling(self, temp_token_file): - """Test that whitespace is properly trimmed from token files.""" - # Overwrite the temp file with whitespace - with open(temp_token_file, "w") as f: - f.write(" token-with-spaces \n\t") - - os.environ["RAY_AUTH_TOKEN_PATH"] = temp_token_file - token = auth_token_loader.load_auth_token(generate_if_not_found=False) - assert token == "token-with-spaces" - - def test_empty_env_variable(self): - """Test that empty environment variable is ignored.""" - os.environ["RAY_AUTH_TOKEN"] = " " # Empty after 
strip - token = auth_token_loader.load_auth_token(generate_if_not_found=False) - assert token == "" - - def test_nonexistent_path_in_env(self): - """Test that nonexistent path in RAY_AUTH_TOKEN_PATH is handled gracefully.""" - os.environ["RAY_AUTH_TOKEN_PATH"] = "/nonexistent/path/to/token" - with pytest.raises(FileNotFoundError): - auth_token_loader.load_auth_token(generate_if_not_found=False) - - -class TestTokenGeneration: - """Tests for token generation functionality.""" - - def test_generate_token(self, default_token_path): - """Test token generation when no token exists.""" - token = auth_token_loader.load_auth_token(generate_if_not_found=True) - - # Token should be a 32-character hex string (UUID without dashes) - assert len(token) == 32 - assert all(c in "0123456789abcdef" for c in token) - - # Token should be saved to default path - assert default_token_path.exists() - saved_token = default_token_path.read_text().strip() - assert saved_token == token - - def test_no_generation_without_flag(self): - """Test that token is not generated when flag is False.""" - token = auth_token_loader.load_auth_token(generate_if_not_found=False) - assert token == "" - - def test_dont_generate_when_token_exists(self, default_token_path): - """Test that token is not generated when one already exists.""" - os.environ["RAY_AUTH_TOKEN"] = "existing-token" - token = auth_token_loader.load_auth_token(generate_if_not_found=True) - assert token == "existing-token" - generated_token = auth_token_loader.generate_and_save_token() - assert generated_token == "existing-token" # does not generate a new token - assert not default_token_path.exists() - - def test_public_generate_and_save_token(self, default_token_path): - """Test the public generate_and_save_token function.""" - token = auth_token_loader.generate_and_save_token() - - # Token should be a 32-character hex string (UUID without dashes) - assert len(token) == 32 - assert all(c in "0123456789abcdef" for c in token) - - # Token 
should be saved to default path - assert default_token_path.exists() - saved_token = default_token_path.read_text().strip() - assert saved_token == token - - -class TestTokenCaching: - """Tests for token caching behavior.""" - - def test_caching_behavior(self): - """Test that token is cached after first load.""" - os.environ["RAY_AUTH_TOKEN"] = "cached-token" - token1 = auth_token_loader.load_auth_token(generate_if_not_found=False) - - # Change environment variable (shouldn't affect cached value) - os.environ["RAY_AUTH_TOKEN"] = "new-token" - token2 = auth_token_loader.load_auth_token(generate_if_not_found=False) - - # Should still return the cached token - assert token1 == token2 == "cached-token" - - -if __name__ == "__main__": - sys.exit(pytest.main(["-vv", __file__])) diff --git a/src/ray/rpc/authentication/authentication_token_loader.cc b/src/ray/rpc/authentication/authentication_token_loader.cc index 59a1184e080a..dedb840b4c77 100644 --- a/src/ray/rpc/authentication/authentication_token_loader.cc +++ b/src/ray/rpc/authentication/authentication_token_loader.cc @@ -72,6 +72,33 @@ std::optional AuthenticationTokenLoader::GetToken() { return *cached_token_; } +bool AuthenticationTokenLoader::HasToken() { + std::lock_guard lock(token_mutex_); + + // If already loaded, check if it's a valid token + if (cached_token_.has_value()) { + return !cached_token_->empty(); + } + + // If token auth is not enabled, no token needed + if (GetAuthenticationMode() != AuthenticationMode::TOKEN) { + cached_token_ = std::nullopt; + return false; + } + + // Token auth is enabled, try to load from sources + AuthenticationToken token = LoadTokenFromSources(); + + // Cache the result + if (token.empty()) { + cached_token_ = std::nullopt; + return false; + } else { + cached_token_ = std::move(token); + return true; + } +} + // Read token from the first line of the file. trim whitespace. // Returns empty string if file cannot be opened or is empty. 
std::string AuthenticationTokenLoader::ReadTokenFromFile(const std::string &file_path) { diff --git a/src/ray/rpc/authentication/authentication_token_loader.h b/src/ray/rpc/authentication/authentication_token_loader.h index 4034ecbc78dd..1dc4972125d7 100644 --- a/src/ray/rpc/authentication/authentication_token_loader.h +++ b/src/ray/rpc/authentication/authentication_token_loader.h @@ -40,6 +40,11 @@ class AuthenticationTokenLoader { /// \return The authentication token, or std::nullopt if auth is disabled. std::optional GetToken(); + /// Check if a token exists without crashing. + /// Caches the token if it loads it afresh. + /// \return true if a token exists, false otherwise. + bool HasToken(); + void ResetCache() { std::lock_guard lock(token_mutex_); cached_token_.reset(); From b6c667a6f982b8da333bd1ad53e9eb135a511e26 Mon Sep 17 00:00:00 2001 From: sampan Date: Thu, 23 Oct 2025 15:10:27 +0000 Subject: [PATCH 35/94] fix lint Signed-off-by: sampan --- python/ray/_private/authentication_token_setup.py | 6 ++---- python/ray/_raylet.pyx | 12 ++++++------ 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/python/ray/_private/authentication_token_setup.py b/python/ray/_private/authentication_token_setup.py index 4b0375a46b5c..5da87137f225 100644 --- a/python/ray/_private/authentication_token_setup.py +++ b/python/ray/_private/authentication_token_setup.py @@ -13,7 +13,7 @@ logger = logging.getLogger(__name__) -def generate_and_save_token() -> str: +def generate_and_save_token() -> None: """Generate a new random token and save it in the default token path. Returns: @@ -38,8 +38,6 @@ def generate_and_save_token() -> str: logger.warning(f"Failed to save generated token to {token_path}: {e}. ") raise - return token - def _get_default_token_path() -> Path: """Get the default token file path (~/.ray/auth_token). 
@@ -86,7 +84,7 @@ def setup_and_verify_auth(system_config=None, is_new_cluster: bool = True) -> No if not token_loader.has_token(): if is_new_cluster: # Generate token for new local cluster - token = generate_and_save_token() + generate_and_save_token() # Reload cache so that subsequent calls to token_loader read the new token token_loader.reset_cache() diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index b3e072f4dcdf..bcfb09417dd3 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -4960,7 +4960,7 @@ class AuthenticationMode: def get_authentication_mode(): """Get the current authentication mode. - + Returns: AuthenticationMode enum value (DISABLED or TOKEN) """ @@ -4969,23 +4969,23 @@ def get_authentication_mode(): class AuthenticationTokenLoader: """Python wrapper for C++ AuthenticationTokenLoader singleton.""" - + @staticmethod def instance(): """Get the singleton instance (returns a wrapper for convenience).""" return AuthenticationTokenLoader() - + def has_token(self): """Check if an authentication token exists without crashing. - + Returns: bool: True if a token exists, False otherwise """ return CAuthenticationTokenLoader.instance().HasToken() - + def reset_cache(self): """Reset the C++ authentication token cache. - + This forces the token loader to reload the token from environment variables or files on the next request. 
""" From 52d18acada737e6e307d2e29d1b684297575ebee Mon Sep 17 00:00:00 2001 From: sampan Date: Thu, 23 Oct 2025 15:13:23 +0000 Subject: [PATCH 36/94] fix doc string Signed-off-by: sampan --- .../_private/authentication_token_setup.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/python/ray/_private/authentication_token_setup.py b/python/ray/_private/authentication_token_setup.py index 5da87137f225..7a69ebe39b3a 100644 --- a/python/ray/_private/authentication_token_setup.py +++ b/python/ray/_private/authentication_token_setup.py @@ -51,18 +51,18 @@ def _get_default_token_path() -> Path: def setup_and_verify_auth(system_config=None, is_new_cluster: bool = True) -> None: """Check authentication settings and setup necessary resources. - for token based authentication, Ray calls this early during ray.init() to do the following: - 1. Check if you enabled token-based authentication. + Ray calls this early during ray.init() to do the following for token-based authentication: + 1. Check whether you enabled token-based authentication. 2. Make sure a token is available if authentication is enabled. - 3. Generate and save a default token for new local clusters if one doesn't exist. + 3. Generate and save a default token for new local clusters if one doesn't already exist. Args: - system_config: Raises an error if you set auth_mode in system_config instead of the environment. - is_new_cluster: Set to True if starting a new local cluster, or False if connecting + system_config: Ray raises an error if you set auth_mode in system_config instead of the environment. + is_new_cluster: Set to True if you're starting a new local cluster, or False if you're connecting to an existing cluster. Raises: - RuntimeError: If authentication is enabled but no token is found when connecting + RuntimeError: Ray raises this error if authentication is enabled but no token is found when connecting to an existing cluster. 
""" from ray._raylet import ( @@ -71,7 +71,7 @@ def setup_and_verify_auth(system_config=None, is_new_cluster: bool = True) -> No get_authentication_mode, ) - # Check if token authentication is enabled + # Check if you enabled token authentication. if get_authentication_mode() != AuthenticationMode.TOKEN: if system_config and system_config.get("auth_mode") != "disabled": raise RuntimeError( @@ -83,15 +83,15 @@ def setup_and_verify_auth(system_config=None, is_new_cluster: bool = True) -> No if not token_loader.has_token(): if is_new_cluster: - # Generate token for new local cluster + # Generate a token for a new local cluster. generate_and_save_token() - # Reload cache so that subsequent calls to token_loader read the new token + # Reload the cache so subsequent calls to token_loader read the new token. token_loader.reset_cache() else: - # You're connecting to an existing cluster—token must already exist + # You're connecting to an existing cluster, so an authentication token must already exist. raise RuntimeError( - "Token authentication is enabled but no authentication token was found. Please provide a token using one of:\n" + "Token authentication is enabled but no authentication token was found. Please provide a token with one of these options:\n" " 1. RAY_AUTH_TOKEN environment variable\n" " 2. RAY_AUTH_TOKEN_PATH environment variable (path to token file)\n" " 3. 
Default token file: ~/.ray/auth_token" From b119faeb35249b9036dd5dbdf89b966c35cf6362 Mon Sep 17 00:00:00 2001 From: sampan Date: Thu, 23 Oct 2025 15:16:51 +0000 Subject: [PATCH 37/94] add type hints Signed-off-by: sampan --- python/ray/_private/authentication_token_setup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/ray/_private/authentication_token_setup.py b/python/ray/_private/authentication_token_setup.py index 7a69ebe39b3a..69c6b223103a 100644 --- a/python/ray/_private/authentication_token_setup.py +++ b/python/ray/_private/authentication_token_setup.py @@ -7,6 +7,7 @@ import logging import os +from typing import Any, Dict, Optional import uuid from pathlib import Path @@ -48,7 +49,7 @@ def _get_default_token_path() -> Path: return Path.home() / ".ray" / "auth_token" -def setup_and_verify_auth(system_config=None, is_new_cluster: bool = True) -> None: +def setup_and_verify_auth(system_config: Optional[Dict[str, Any]] = None, is_new_cluster: bool = True) -> None: """Check authentication settings and setup necessary resources. 
Ray calls this early during ray.init() to do the following for token-based authentication: From b6b7a959b342d9e50a487e9e8088c7872b190322 Mon Sep 17 00:00:00 2001 From: sampan Date: Thu, 23 Oct 2025 15:17:00 +0000 Subject: [PATCH 38/94] lint Signed-off-by: sampan --- python/ray/_private/authentication_token_setup.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/ray/_private/authentication_token_setup.py b/python/ray/_private/authentication_token_setup.py index 69c6b223103a..03cc8e0f7565 100644 --- a/python/ray/_private/authentication_token_setup.py +++ b/python/ray/_private/authentication_token_setup.py @@ -7,9 +7,9 @@ import logging import os -from typing import Any, Dict, Optional import uuid from pathlib import Path +from typing import Any, Dict, Optional logger = logging.getLogger(__name__) @@ -49,7 +49,9 @@ def _get_default_token_path() -> Path: return Path.home() / ".ray" / "auth_token" -def setup_and_verify_auth(system_config: Optional[Dict[str, Any]] = None, is_new_cluster: bool = True) -> None: +def setup_and_verify_auth( + system_config: Optional[Dict[str, Any]] = None, is_new_cluster: bool = True +) -> None: """Check authentication settings and setup necessary resources. 
Ray calls this early during ray.init() to do the following for token-based authentication: From ce73705ee7043afe6cbafaf67ec20b961f002689 Mon Sep 17 00:00:00 2001 From: sampan Date: Thu, 23 Oct 2025 16:17:38 +0000 Subject: [PATCH 39/94] Add authentication token logic and related tests Signed-off-by: sampan --- src/ray/common/constants.h | 2 + src/ray/common/ray_config_def.h | 6 + src/ray/rpc/authentication/BUILD.bazel | 34 ++ .../rpc/authentication/authentication_mode.cc | 37 ++ .../rpc/authentication/authentication_mode.h | 33 ++ .../rpc/authentication/authentication_token.h | 157 ++++++++ .../authentication_token_loader.cc | 167 +++++++++ .../authentication_token_loader.h | 72 ++++ src/ray/rpc/tests/BUILD.bazel | 44 +++ .../tests/authentication_token_loader_test.cc | 346 ++++++++++++++++++ .../rpc/tests/authentication_token_test.cc | 131 +++++++ 11 files changed, 1029 insertions(+) create mode 100644 src/ray/rpc/authentication/BUILD.bazel create mode 100644 src/ray/rpc/authentication/authentication_mode.cc create mode 100644 src/ray/rpc/authentication/authentication_mode.h create mode 100644 src/ray/rpc/authentication/authentication_token.h create mode 100644 src/ray/rpc/authentication/authentication_token_loader.cc create mode 100644 src/ray/rpc/authentication/authentication_token_loader.h create mode 100644 src/ray/rpc/tests/authentication_token_loader_test.cc create mode 100644 src/ray/rpc/tests/authentication_token_test.cc diff --git a/src/ray/common/constants.h b/src/ray/common/constants.h index bfd06e677e7e..08986d3b415e 100644 --- a/src/ray/common/constants.h +++ b/src/ray/common/constants.h @@ -42,6 +42,8 @@ constexpr int kRayletStoreErrorExitCode = 100; constexpr char kObjectTablePrefix[] = "ObjectTable"; constexpr char kClusterIdKey[] = "ray_cluster_id"; +constexpr char kAuthTokenKey[] = "authorization"; +constexpr char kBearerPrefix[] = "Bearer "; constexpr char kWorkerDynamicOptionPlaceholder[] = "RAY_WORKER_DYNAMIC_OPTION_PLACEHOLDER"; diff --git 
a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h index 6e8d21956162..e4e8fc1d48ef 100644 --- a/src/ray/common/ray_config_def.h +++ b/src/ray/common/ray_config_def.h @@ -35,6 +35,12 @@ RAY_CONFIG(bool, emit_main_service_metrics, true) /// Whether to enable cluster authentication. RAY_CONFIG(bool, enable_cluster_auth, true) +/// Whether to enable token-based authentication for RPC calls. +/// will be converted to AuthenticationMode enum defined in +/// rpc/authentication/authentication_mode.h +/// use GetAuthenticationMode() to get the authentication mode enum value. +RAY_CONFIG(std::string, auth_mode, "disabled") + /// The interval of periodic event loop stats print. /// -1 means the feature is disabled. In this case, stats are available /// in the associated process's log file. diff --git a/src/ray/rpc/authentication/BUILD.bazel b/src/ray/rpc/authentication/BUILD.bazel new file mode 100644 index 000000000000..8da78e5d728b --- /dev/null +++ b/src/ray/rpc/authentication/BUILD.bazel @@ -0,0 +1,34 @@ +load("//bazel:ray.bzl", "ray_cc_library") + +ray_cc_library( + name = "authentication_mode", + srcs = ["authentication_mode.cc"], + hdrs = ["authentication_mode.h"], + visibility = ["//visibility:public"], + deps = [ + "//src/ray/common:ray_config", + "@com_google_absl//absl/strings", + ], +) + +ray_cc_library( + name = "authentication_token", + hdrs = ["authentication_token.h"], + visibility = ["//visibility:public"], + deps = [ + "//src/ray/common:constants", + "@com_github_grpc_grpc//:grpc++", + ], +) + +ray_cc_library( + name = "authentication_token_loader", + srcs = ["authentication_token_loader.cc"], + hdrs = ["authentication_token_loader.h"], + visibility = ["//visibility:public"], + deps = [ + ":authentication_mode", + ":authentication_token", + "//src/ray/util:logging", + ], +) diff --git a/src/ray/rpc/authentication/authentication_mode.cc b/src/ray/rpc/authentication/authentication_mode.cc new file mode 100644 index 000000000000..1bbe209733ce 
--- /dev/null +++ b/src/ray/rpc/authentication/authentication_mode.cc @@ -0,0 +1,37 @@ +// Copyright 2025 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ray/rpc/authentication/authentication_mode.h" + +#include +#include + +#include "absl/strings/ascii.h" +#include "ray/common/ray_config.h" + +namespace ray { +namespace rpc { + +AuthenticationMode GetAuthenticationMode() { + std::string auth_mode_lower = absl::AsciiStrToLower(RayConfig::instance().auth_mode()); + + if (auth_mode_lower == "token") { + return AuthenticationMode::TOKEN; + } else { + return AuthenticationMode::DISABLED; + } +} + +} // namespace rpc +} // namespace ray diff --git a/src/ray/rpc/authentication/authentication_mode.h b/src/ray/rpc/authentication/authentication_mode.h new file mode 100644 index 000000000000..21bd165fd34b --- /dev/null +++ b/src/ray/rpc/authentication/authentication_mode.h @@ -0,0 +1,33 @@ +// Copyright 2025 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +namespace ray { +namespace rpc { + +enum class AuthenticationMode { + DISABLED, + TOKEN, +}; + +/// Get the authentication mode from the RayConfig. +/// \return The authentication mode enum value. returns AuthenticationMode::DISABLED if +/// the authentication mode is not set or is invalid. +AuthenticationMode GetAuthenticationMode(); + +} // namespace rpc +} // namespace ray diff --git a/src/ray/rpc/authentication/authentication_token.h b/src/ray/rpc/authentication/authentication_token.h new file mode 100644 index 000000000000..6846d3c08ada --- /dev/null +++ b/src/ray/rpc/authentication/authentication_token.h @@ -0,0 +1,157 @@ +// Copyright 2025 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "ray/common/constants.h" + +namespace ray { +namespace rpc { + +/// Secure wrapper for authentication tokens. 
+/// - Wipes memory on destruction +/// - Constant-time comparison +/// - Redacted output when logged or printed +class AuthenticationToken { + public: + AuthenticationToken() = default; + explicit AuthenticationToken(std::string value) : secret_(value.begin(), value.end()) {} + + AuthenticationToken(const AuthenticationToken &other) : secret_(other.secret_) {} + AuthenticationToken &operator=(const AuthenticationToken &other) { + if (this != &other) { + SecureClear(); + secret_ = other.secret_; + } + return *this; + } + + // Move operations + AuthenticationToken(AuthenticationToken &&other) noexcept { + MoveFrom(std::move(other)); + } + AuthenticationToken &operator=(AuthenticationToken &&other) noexcept { + if (this != &other) { + SecureClear(); + MoveFrom(std::move(other)); + } + return *this; + } + ~AuthenticationToken() { SecureClear(); } + + bool empty() const noexcept { return secret_.empty(); } + + /// Constant-time equality comparison + bool Equals(const AuthenticationToken &other) const noexcept { + return ConstTimeEqual(secret_, other.secret_); + } + + /// Equality operator (constant-time) + bool operator==(const AuthenticationToken &other) const noexcept { + return Equals(other); + } + + /// Inequality operator + bool operator!=(const AuthenticationToken &other) const noexcept { + return !(*this == other); + } + + /// Set authentication metadata on a gRPC client context + /// Only call this from client-side code + void SetMetadata(grpc::ClientContext &context) const { + if (!secret_.empty()) { + context.AddMetadata(kAuthTokenKey, + kBearerPrefix + std::string(secret_.begin(), secret_.end())); + } + } + + /// Create AuthenticationToken from gRPC metadata value + /// Strips "Bearer " prefix and creates token object + /// @param metadata_value The raw value from server metadata (should include "Bearer " + /// prefix) + /// @return AuthenticationToken object (empty if format invalid) + static AuthenticationToken FromMetadata(std::string_view metadata_value) 
{ + const std::string_view prefix(kBearerPrefix, sizeof(kBearerPrefix) - 1); + if (metadata_value.size() <= prefix.size() || + metadata_value.substr(0, prefix.size()) != prefix) { + return AuthenticationToken(); // Invalid format, return empty + } + std::string_view token_part = metadata_value.substr(prefix.size()); + return AuthenticationToken(std::string(token_part)); + } + + friend std::ostream &operator<<(std::ostream &os, const AuthenticationToken &t) { + return os << ""; + } + + private: + std::vector secret_; + + // Constant-time string comparison to avoid timing attacks. + // https://en.wikipedia.org/wiki/Timing_attack + static bool ConstTimeEqual(const std::vector &a, + const std::vector &b) noexcept { + if (a.size() != b.size()) { + return false; + } + unsigned char diff = 0; + for (size_t i = 0; i < a.size(); ++i) { + diff |= a[i] ^ b[i]; + } + return diff == 0; + } + + // replace the characters in the memory with 0 + static void ExplicitBurn(void *p, size_t n) noexcept { +#if defined(_MSC_VER) + SecureZeroMemory(p, n); +#elif defined(__STDC_LIB_EXT1__) + memset_s(p, n, 0, n); +#else + // Using array indexing instead of pointer arithmetic + volatile auto *vp = static_cast(p); + for (size_t i = 0; i < n; ++i) { + vp[i] = 0; + } +#endif + } + + void SecureClear() noexcept { + if (!secret_.empty()) { + ExplicitBurn(secret_.data(), secret_.size()); + secret_.clear(); + } + } + + void MoveFrom(AuthenticationToken &&other) noexcept { + SecureClear(); + secret_ = std::move(other.secret_); + // Clear the moved-from object explicitly for security + // Note: 'other' is already an rvalue reference, no need to move again + other.SecureClear(); + } +}; + +} // namespace rpc +} // namespace ray diff --git a/src/ray/rpc/authentication/authentication_token_loader.cc b/src/ray/rpc/authentication/authentication_token_loader.cc new file mode 100644 index 000000000000..59a1184e080a --- /dev/null +++ b/src/ray/rpc/authentication/authentication_token_loader.cc @@ -0,0 +1,167 
@@ +// Copyright 2025 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ray/rpc/authentication/authentication_token_loader.h" + +#include +#include +#include + +#include "ray/util/logging.h" + +#if defined(__APPLE__) || defined(__linux__) +#include +#include +#endif + +#ifdef _WIN32 +#ifndef _WINDOWS_ +#ifndef WIN32_LEAN_AND_MEAN // Sorry for the inconvenience. Please include any related + // headers you need manually. + // (https://stackoverflow.com/a/8294669) +#define WIN32_LEAN_AND_MEAN // Prevent inclusion of WinSock2.h +#endif +#include // Force inclusion of WinGDI here to resolve name conflict +#endif +#endif + +namespace ray { +namespace rpc { + +AuthenticationTokenLoader &AuthenticationTokenLoader::instance() { + static AuthenticationTokenLoader instance; + return instance; +} + +std::optional AuthenticationTokenLoader::GetToken() { + std::lock_guard lock(token_mutex_); + + // If already loaded, return cached value + if (cached_token_.has_value()) { + return cached_token_; + } + + // If token auth is not enabled, return std::nullopt + if (GetAuthenticationMode() != AuthenticationMode::TOKEN) { + cached_token_ = std::nullopt; + return std::nullopt; + } + + // Token auth is enabled, try to load from sources + AuthenticationToken token = LoadTokenFromSources(); + + // If no token found and auth is enabled, fail with RAY_CHECK + RAY_CHECK(!token.empty()) + << "Token authentication is enabled but no authentication token was 
found. " + << "Please set RAY_AUTH_TOKEN environment variable, RAY_AUTH_TOKEN_PATH to a file " + << "containing the token, or create a token file at ~/.ray/auth_token"; + + // Cache and return the loaded token + cached_token_ = std::move(token); + return *cached_token_; +} + +// Read token from the first line of the file. trim whitespace. +// Returns empty string if file cannot be opened or is empty. +std::string AuthenticationTokenLoader::ReadTokenFromFile(const std::string &file_path) { + std::ifstream token_file(file_path); + if (!token_file.is_open()) { + return ""; + } + + std::string token; + std::getline(token_file, token); + token_file.close(); + return token; +} + +AuthenticationToken AuthenticationTokenLoader::LoadTokenFromSources() { + // Precedence 1: RAY_AUTH_TOKEN environment variable + const char *env_token = std::getenv("RAY_AUTH_TOKEN"); + if (env_token != nullptr && std::string(env_token).length() > 0) { + RAY_LOG(DEBUG) << "Loaded authentication token from RAY_AUTH_TOKEN environment " + "variable"; + return AuthenticationToken(TrimWhitespace(std::string(env_token))); + } + + // Precedence 2: RAY_AUTH_TOKEN_PATH environment variable + const char *env_token_path = std::getenv("RAY_AUTH_TOKEN_PATH"); + if (env_token_path != nullptr && std::string(env_token_path).length() > 0) { + std::string token_str = TrimWhitespace(ReadTokenFromFile(env_token_path)); + if (!token_str.empty()) { + RAY_LOG(DEBUG) << "Loaded authentication token from file: " << env_token_path; + return AuthenticationToken(token_str); + } else { + RAY_LOG(WARNING) << "RAY_AUTH_TOKEN_PATH is set but file cannot be opened: " + << env_token_path; + } + } + + // Precedence 3: Default token path ~/.ray/auth_token + std::string default_path = GetDefaultTokenPath(); + std::string token_str = TrimWhitespace(ReadTokenFromFile(default_path)); + if (!token_str.empty()) { + RAY_LOG(DEBUG) << "Loaded authentication token from default path: " << default_path; + return 
AuthenticationToken(token_str); + } + + // No token found + RAY_LOG(DEBUG) << "No authentication token found in any source"; + return AuthenticationToken(); +} + +std::string AuthenticationTokenLoader::GetDefaultTokenPath() { + std::string home_dir; + +#ifdef _WIN32 + const char *path_separator = "\\"; + const char *userprofile = std::getenv("USERPROFILE"); + if (userprofile != nullptr) { + home_dir = userprofile; + } else { + const char *homedrive = std::getenv("HOMEDRIVE"); + const char *homepath = std::getenv("HOMEPATH"); + if (homedrive != nullptr && homepath != nullptr) { + home_dir = std::string(homedrive) + std::string(homepath); + } + } +#else + const char *path_separator = "/"; + const char *home = std::getenv("HOME"); + if (home != nullptr) { + home_dir = home; + } +#endif + + const std::string token_subpath = + std::string(path_separator) + ".ray" + std::string(path_separator) + "auth_token"; + + if (home_dir.empty()) { + RAY_LOG(WARNING) << "Cannot determine home directory for token storage"; + return "." + token_subpath; + } + + return home_dir + token_subpath; +} + +std::string AuthenticationTokenLoader::TrimWhitespace(const std::string &str) { + std::string whitespace = " \t\n\r\f\v"; + std::string trimmed_str = str; + trimmed_str.erase(0, trimmed_str.find_first_not_of(whitespace)); + trimmed_str.erase(trimmed_str.find_last_not_of(whitespace) + 1); + return trimmed_str; +} + +} // namespace rpc +} // namespace ray diff --git a/src/ray/rpc/authentication/authentication_token_loader.h b/src/ray/rpc/authentication/authentication_token_loader.h new file mode 100644 index 000000000000..4034ecbc78dd --- /dev/null +++ b/src/ray/rpc/authentication/authentication_token_loader.h @@ -0,0 +1,72 @@ +// Copyright 2025 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions +// and limitations under the License. + +#pragma once + +#include +#include +#include + +#include "ray/rpc/authentication/authentication_mode.h" +#include "ray/rpc/authentication/authentication_token.h" + +namespace ray { +namespace rpc { + +/// Singleton class for loading and caching authentication tokens. +/// Supports loading tokens from multiple sources with precedence: +/// 1. RAY_AUTH_TOKEN environment variable +/// 2. RAY_AUTH_TOKEN_PATH environment variable (path to token file) +/// 3. Default token path: ~/.ray/auth_token (Unix) or %USERPROFILE%\.ray\auth_token +/// +/// Thread-safe with internal caching to avoid repeated file I/O. +class AuthenticationTokenLoader { + public: + static AuthenticationTokenLoader &instance(); + + /// Get the authentication token. + /// If token authentication is enabled but no token is found, fails with RAY_CHECK. + /// \return The authentication token, or std::nullopt if auth is disabled. + std::optional GetToken(); + + void ResetCache() { + std::lock_guard lock(token_mutex_); + cached_token_.reset(); + } + + AuthenticationTokenLoader(const AuthenticationTokenLoader &) = delete; + AuthenticationTokenLoader &operator=(const AuthenticationTokenLoader &) = delete; + + private: + AuthenticationTokenLoader() = default; + ~AuthenticationTokenLoader() = default; + + /// Read and trim token from file. + std::string ReadTokenFromFile(const std::string &file_path); + + /// Load token from environment or file. + AuthenticationToken LoadTokenFromSources(); + + /// Default token file path (~/.ray/auth_token or %USERPROFILE%\.ray\auth_token). 
+ std::string GetDefaultTokenPath(); + + /// Trim whitespace from the beginning and end of the string. + std::string TrimWhitespace(const std::string &str); + + std::mutex token_mutex_; + std::optional cached_token_; +}; + +} // namespace rpc +} // namespace ray diff --git a/src/ray/rpc/tests/BUILD.bazel b/src/ray/rpc/tests/BUILD.bazel index 5fa8b14cc4db..279b68f91ba3 100644 --- a/src/ray/rpc/tests/BUILD.bazel +++ b/src/ray/rpc/tests/BUILD.bazel @@ -18,6 +18,23 @@ ray_cc_test( size = "small", srcs = [ "grpc_server_client_test.cc", + "grpc_test_common.h", + ], + tags = ["team:core"], + deps = [ + "//src/ray/protobuf:test_service_cc_grpc", + "//src/ray/rpc:grpc_client", + "//src/ray/rpc:grpc_server", + "@com_google_googletest//:gtest_main", + ], +) + +ray_cc_test( + name = "grpc_auth_token_tests", + size = "small", + srcs = [ + "grpc_auth_token_tests.cc", + "grpc_test_common.h", ], tags = ["team:core"], deps = [ @@ -40,3 +57,30 @@ ray_cc_test( "@com_google_googletest//:gtest_main", ], ) + +ray_cc_test( + name = "authentication_token_loader_test", + size = "small", + srcs = [ + "authentication_token_loader_test.cc", + ], + tags = ["team:core"], + deps = [ + "//src/ray/common:ray_config", + "//src/ray/rpc/authentication:authentication_token_loader", + "@com_google_googletest//:gtest_main", + ], +) + +ray_cc_test( + name = "authentication_token_test", + size = "small", + srcs = [ + "authentication_token_test.cc", + ], + tags = ["team:core"], + deps = [ + "//src/ray/rpc/authentication:authentication_token", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/src/ray/rpc/tests/authentication_token_loader_test.cc b/src/ray/rpc/tests/authentication_token_loader_test.cc new file mode 100644 index 000000000000..616a13b0e457 --- /dev/null +++ b/src/ray/rpc/tests/authentication_token_loader_test.cc @@ -0,0 +1,346 @@ +// Copyright 2025 The Ray Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ray/rpc/authentication/authentication_token_loader.h" + +#include +#include + +#include "gtest/gtest.h" +#include "ray/common/ray_config.h" + +#if defined(__APPLE__) || defined(__linux__) +#include +#include +#endif + +#ifdef _WIN32 +#ifndef _WINDOWS_ +#ifndef WIN32_LEAN_AND_MEAN // Sorry for the inconvenience. Please include any related + // headers you need manually. + // (https://stackoverflow.com/a/8294669) +#define WIN32_LEAN_AND_MEAN // Prevent inclusion of WinSock2.h +#endif +#include // Force inclusion of WinGDI here to resolve name conflict +#endif +#include // For _mkdir on Windows +#include // For _getpid on Windows +#endif + +namespace ray { +namespace rpc { + +class AuthenticationTokenLoaderTest : public ::testing::Test { + protected: + void SetUp() override { + // Enable token authentication for tests + RayConfig::instance().initialize(R"({"auth_mode": "token"})"); + + // If the home variable (HOME, or USERPROFILE on Windows) is not set (e.g., in Bazel sandbox), point it at a test directory + // This ensures tests work in environments where it isn't provided +#ifdef _WIN32 + if (std::getenv("USERPROFILE") == nullptr) { + const char *test_tmpdir = std::getenv("TEST_TMPDIR"); + if (test_tmpdir != nullptr) { + test_home_dir_ = std::string(test_tmpdir) + "\\ray_test_home"; + } else { + test_home_dir_ = "C:\\Windows\\Temp\\ray_test_home"; + } + _putenv(("USERPROFILE=" + test_home_dir_).c_str()); + } + const char *home_dir =
std::getenv("USERPROFILE"); + default_token_path_ = std::string(home_dir) + "\\.ray\\auth_token"; +#else + if (std::getenv("HOME") == nullptr) { + const char *test_tmpdir = std::getenv("TEST_TMPDIR"); + if (test_tmpdir != nullptr) { + test_home_dir_ = std::string(test_tmpdir) + "/ray_test_home"; + } else { + test_home_dir_ = "/tmp/ray_test_home"; + } + setenv("HOME", test_home_dir_.c_str(), 1); + } + const char *home_dir = std::getenv("HOME"); + if (home_dir != nullptr) { + default_token_path_ = std::string(home_dir) + "/.ray/auth_token"; + test_home_dir_ = home_dir; + } else { + default_token_path_ = ".ray/auth_token"; + } +#endif + cleanup_env(); + // Reset the singleton's cached state for test isolation + AuthenticationTokenLoader::instance().ResetCache(); + } + + void TearDown() override { + // Clean up after test: unset both auth env vars and delete the file at default_token_path_ + cleanup_env(); + // Reset the singleton's cached state for test isolation + AuthenticationTokenLoader::instance().ResetCache(); + // Disable token auth after tests + RayConfig::instance().initialize(R"({"auth_mode": "disabled"})"); + } + + void cleanup_env() { + unset_env_var("RAY_AUTH_TOKEN"); + unset_env_var("RAY_AUTH_TOKEN_PATH"); + remove(default_token_path_.c_str()); + } + + std::string get_temp_token_path() { +#ifdef _WIN32 + return "C:\\Windows\\Temp\\ray_test_token_" + std::to_string(_getpid()); +#else + return "/tmp/ray_test_token_" + std::to_string(getpid()); +#endif + } + + void set_env_var(const char *name, const char *value) { +#ifdef _WIN32 + std::string env_str = std::string(name) + "=" + std::string(value); + _putenv(env_str.c_str()); +#else + setenv(name, value, 1); +#endif + } + + void unset_env_var(const char *name) { +#ifdef _WIN32 + std::string env_str = std::string(name) + "="; + _putenv(env_str.c_str()); +#else + unsetenv(name); +#endif + } + + void ensure_ray_dir_exists() { +#ifdef _WIN32 + const char *home_dir = std::getenv("USERPROFILE"); + _mkdir(home_dir); // Create parent directory + std::string ray_dir =
std::string(home_dir) + "\\.ray"; + _mkdir(ray_dir.c_str()); +#else + // Always ensure the home directory exists (it might be a test temp dir we created) + if (!test_home_dir_.empty()) { + mkdir(test_home_dir_.c_str(), + 0700); // Create if it doesn't exist (ignore error if it does) + } + + const char *home_dir = std::getenv("HOME"); + if (home_dir != nullptr) { + std::string ray_dir = std::string(home_dir) + "/.ray"; + mkdir(ray_dir.c_str(), 0700); + } +#endif + } + + void write_token_file(const std::string &path, const std::string &content) { + std::ofstream token_file(path); + token_file << content; + token_file.close(); + } + + std::string default_token_path_; + std::string test_home_dir_; // Home directory used by the tests: the fallback chosen in SetUp(), or (non-Windows) a copy of HOME when it is set +}; + +TEST_F(AuthenticationTokenLoaderTest, TestLoadFromEnvVariable) { + // Set token in environment variable + set_env_var("RAY_AUTH_TOKEN", "test-token-from-env"); + + // instance() returns the shared singleton; SetUp() has already reset its cache + auto &loader = AuthenticationTokenLoader::instance(); + auto token_opt = loader.GetToken(); + + ASSERT_TRUE(token_opt.has_value()); + AuthenticationToken expected("test-token-from-env"); + EXPECT_TRUE(token_opt->Equals(expected)); + EXPECT_TRUE(loader.GetToken().has_value()); +} + +TEST_F(AuthenticationTokenLoaderTest, TestLoadFromEnvPath) { + // Create a temporary token file + std::string temp_token_path = get_temp_token_path(); + write_token_file(temp_token_path, "test-token-from-file"); + + // Set path in environment variable + set_env_var("RAY_AUTH_TOKEN_PATH", temp_token_path.c_str()); + + auto &loader = AuthenticationTokenLoader::instance(); + auto token_opt = loader.GetToken(); + + ASSERT_TRUE(token_opt.has_value()); + AuthenticationToken expected("test-token-from-file"); + EXPECT_TRUE(token_opt->Equals(expected)); + EXPECT_TRUE(loader.GetToken().has_value()); + + // Clean up + remove(temp_token_path.c_str()); +} + +TEST_F(AuthenticationTokenLoaderTest, TestLoadFromDefaultPath) { + // Create directory and token file
in default location + ensure_ray_dir_exists(); + write_token_file(default_token_path_, "test-token-from-default"); + + auto &loader = AuthenticationTokenLoader::instance(); + auto token_opt = loader.GetToken(); + + ASSERT_TRUE(token_opt.has_value()); + AuthenticationToken expected("test-token-from-default"); + EXPECT_TRUE(token_opt->Equals(expected)); + EXPECT_TRUE(loader.GetToken().has_value()); +} + +// Parametrized test for token loading precedence: RAY_AUTH_TOKEN env var > file named by RAY_AUTH_TOKEN_PATH > default +// file + +struct TokenSourceConfig { + bool set_env = false; + bool set_file = false; + bool set_default = false; + std::string expected_token; + std::string env_token = "token-from-env"; + std::string file_token = "token-from-path"; + std::string default_token = "token-from-default"; +}; + +class AuthenticationTokenLoaderPrecedenceTest + : public AuthenticationTokenLoaderTest, + public ::testing::WithParamInterface {}; + +INSTANTIATE_TEST_SUITE_P(TokenPrecedenceCases, + AuthenticationTokenLoaderPrecedenceTest, + ::testing::Values( + // All set: env should win + TokenSourceConfig{true, true, true, "token-from-env"}, + // File and default file set: file should win + TokenSourceConfig{false, true, true, "token-from-path"}, + // Only default file set + TokenSourceConfig{ + false, false, true, "token-from-default"})); + +TEST_P(AuthenticationTokenLoaderPrecedenceTest, Precedence) { + const auto ¶m = GetParam(); + + // Optionally set environment variable + if (param.set_env) { + set_env_var("RAY_AUTH_TOKEN", param.env_token.c_str()); + } else { + unset_env_var("RAY_AUTH_TOKEN"); + } + + // Optionally create file and set path + std::string temp_token_path = get_temp_token_path(); + if (param.set_file) { + write_token_file(temp_token_path, param.file_token); + set_env_var("RAY_AUTH_TOKEN_PATH", temp_token_path.c_str()); + } else { + unset_env_var("RAY_AUTH_TOKEN_PATH"); + } + + // Optionally create default file + ensure_ray_dir_exists(); + if (param.set_default) { +
write_token_file(default_token_path_, param.default_token); + } else { + remove(default_token_path_.c_str()); + } + + // instance() returns the shared singleton; SetUp() has already reset its cache + auto &loader = AuthenticationTokenLoader::instance(); + auto token_opt = loader.GetToken(); + + ASSERT_TRUE(token_opt.has_value()); + AuthenticationToken expected(param.expected_token); + EXPECT_TRUE(token_opt->Equals(expected)); + + // Clean up token file if it was written + if (param.set_file) { + remove(temp_token_path.c_str()); + } + // Clean up default file if it was written + if (param.set_default) { + remove(default_token_path_.c_str()); + } +} + +TEST_F(AuthenticationTokenLoaderTest, TestNoTokenFoundWhenAuthDisabled) { + // Disable auth for this specific test + RayConfig::instance().initialize(R"({"auth_mode": "disabled"})"); + AuthenticationTokenLoader::instance().ResetCache(); + + // No token set anywhere, but auth is disabled + auto &loader = AuthenticationTokenLoader::instance(); + auto token_opt = loader.GetToken(); + + EXPECT_FALSE(token_opt.has_value()); + EXPECT_FALSE(loader.GetToken().has_value()); + + // Re-enable for other tests + RayConfig::instance().initialize(R"({"auth_mode": "token"})"); +} + +TEST_F(AuthenticationTokenLoaderTest, TestErrorWhenAuthEnabledButNoToken) { + // Token auth is already enabled in SetUp() + // No token exists, should trigger RAY_CHECK failure + EXPECT_DEATH( + { + auto &loader = AuthenticationTokenLoader::instance(); + loader.GetToken(); + }, + "Token authentication is enabled but no authentication token was found"); +} + +TEST_F(AuthenticationTokenLoaderTest, TestCaching) { + // Set token in environment + set_env_var("RAY_AUTH_TOKEN", "cached-token"); + + auto &loader = AuthenticationTokenLoader::instance(); + auto token_opt1 = loader.GetToken(); + + // Change environment variable (shouldn't affect cached value) + set_env_var("RAY_AUTH_TOKEN", "new-token"); + auto token_opt2 = loader.GetToken(); + + // Should still return the cached token +
ASSERT_TRUE(token_opt1.has_value()); + ASSERT_TRUE(token_opt2.has_value()); + EXPECT_TRUE(token_opt1->Equals(*token_opt2)); + AuthenticationToken expected("cached-token"); + EXPECT_TRUE(token_opt2->Equals(expected)); +} + +TEST_F(AuthenticationTokenLoaderTest, TestWhitespaceHandling) { + // Create token file whose content has leading and trailing whitespace + ensure_ray_dir_exists(); + write_token_file(default_token_path_, " token-with-spaces \n\t"); + + auto &loader = AuthenticationTokenLoader::instance(); + auto token_opt = loader.GetToken(); + + // Surrounding whitespace should be trimmed + ASSERT_TRUE(token_opt.has_value()); + AuthenticationToken expected("token-with-spaces"); + EXPECT_TRUE(token_opt->Equals(expected)); +} + +} // namespace rpc +} // namespace ray + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/src/ray/rpc/tests/authentication_token_test.cc b/src/ray/rpc/tests/authentication_token_test.cc new file mode 100644 index 000000000000..db88d7481da1 --- /dev/null +++ b/src/ray/rpc/tests/authentication_token_test.cc @@ -0,0 +1,131 @@ +// Copyright 2025 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#include "ray/rpc/authentication/authentication_token.h" + +#include +#include +#include + +#include "gtest/gtest.h" + +namespace ray { +namespace rpc { + +class AuthenticationTokenTest : public ::testing::Test {}; + +TEST_F(AuthenticationTokenTest, TestDefaultConstructor) { + AuthenticationToken token; + EXPECT_TRUE(token.empty()); +} + +TEST_F(AuthenticationTokenTest, TestConstructorWithValue) { + AuthenticationToken token("test-token-value"); + EXPECT_FALSE(token.empty()); + AuthenticationToken expected("test-token-value"); + EXPECT_TRUE(token.Equals(expected)); +} + +TEST_F(AuthenticationTokenTest, TestMoveConstructor) { + AuthenticationToken token1("original-token"); + AuthenticationToken token2(std::move(token1)); + + EXPECT_FALSE(token2.empty()); + AuthenticationToken expected("original-token"); + EXPECT_TRUE(token2.Equals(expected)); + EXPECT_TRUE(token1.empty()); +} + +TEST_F(AuthenticationTokenTest, TestMoveAssignment) { + AuthenticationToken token1("first-token"); + AuthenticationToken token2("second-token"); + + token2 = std::move(token1); + + EXPECT_FALSE(token2.empty()); + AuthenticationToken expected("first-token"); + EXPECT_TRUE(token2.Equals(expected)); + EXPECT_TRUE(token1.empty()); +} + +TEST_F(AuthenticationTokenTest, TestSelfMoveAssignment) { + AuthenticationToken token("test-token"); + + // Self-move-assignment must leave the token non-empty and equal to its original value + token = std::move(token); + + EXPECT_FALSE(token.empty()); + AuthenticationToken expected("test-token"); + EXPECT_TRUE(token.Equals(expected)); +} + +TEST_F(AuthenticationTokenTest, TestEquals) { + AuthenticationToken token1("same-token"); + AuthenticationToken token2("same-token"); + AuthenticationToken token3("different-token"); + + EXPECT_TRUE(token1.Equals(token2)); + EXPECT_FALSE(token1.Equals(token3)); + EXPECT_TRUE(token1 == token2); + EXPECT_FALSE(token1 == token3); + EXPECT_FALSE(token1 != token2); + EXPECT_TRUE(token1 != token3); +} + +TEST_F(AuthenticationTokenTest, TestEqualityDifferentLengths) { +
AuthenticationToken token1("short"); + AuthenticationToken token2("much-longer-token"); + + EXPECT_FALSE(token1.Equals(token2)); +} + +TEST_F(AuthenticationTokenTest, TestEqualityEmptyTokens) { + AuthenticationToken token1; + AuthenticationToken token2; + + EXPECT_TRUE(token1.Equals(token2)); +} + +TEST_F(AuthenticationTokenTest, TestEqualityEmptyVsNonEmpty) { + AuthenticationToken token1; + AuthenticationToken token2("non-empty"); + + EXPECT_FALSE(token1.Equals(token2)); + EXPECT_FALSE(token2.Equals(token1)); +} + +TEST_F(AuthenticationTokenTest, TestRedactedOutput) { + AuthenticationToken token("super-secret-token"); + + std::ostringstream oss; + oss << token; + + std::string output = oss.str(); + EXPECT_EQ(output, ""); + EXPECT_EQ(output.find("super-secret-token"), std::string::npos); +} + +TEST_F(AuthenticationTokenTest, TestEmptyString) { + AuthenticationToken token(""); + EXPECT_TRUE(token.empty()); + AuthenticationToken expected(""); + EXPECT_TRUE(token.Equals(expected)); +} +} // namespace rpc +} // namespace ray + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} From 341b108707c0e6f0abbcd5db7a5e56390bb55937 Mon Sep 17 00:00:00 2001 From: sampan Date: Thu, 23 Oct 2025 16:20:13 +0000 Subject: [PATCH 40/94] Add gRPC service and server logic with auth integration tests Signed-off-by: sampan --- src/ray/common/grpc_util.h | 4 + src/ray/common/status.cc | 2 +- src/ray/common/status.h | 8 +- src/ray/core_worker/BUILD.bazel | 1 + src/ray/core_worker/grpc_service.cc | 76 +++--- src/ray/core_worker/grpc_service.h | 6 +- src/ray/gcs/BUILD.bazel | 1 + src/ray/gcs/grpc_services.cc | 38 +-- src/ray/gcs/grpc_services.h | 38 ++- .../gcs_rpc_client/tests/gcs_client_test.cc | 4 +- src/ray/raylet/node_manager.cc | 2 +- src/ray/rpc/BUILD.bazel | 5 + src/ray/rpc/client_call.h | 15 +- src/ray/rpc/grpc_server.cc | 12 +- src/ray/rpc/grpc_server.h | 61 +++-- .../rpc/node_manager/node_manager_server.h | 9 +-
src/ray/rpc/object_manager_server.h | 9 +- src/ray/rpc/server_call.h | 105 +++++++-- src/ray/rpc/tests/grpc_auth_token_tests.cc | 221 ++++++++++++++++++ src/ray/rpc/tests/grpc_bench/BUILD.bazel | 1 + src/ray/rpc/tests/grpc_bench/grpc_bench.cc | 10 +- src/ray/rpc/tests/grpc_server_client_test.cc | 83 +------ src/ray/rpc/tests/grpc_test_common.h | 109 +++++++++ 23 files changed, 623 insertions(+), 197 deletions(-) create mode 100644 src/ray/rpc/tests/grpc_auth_token_tests.cc create mode 100644 src/ray/rpc/tests/grpc_test_common.h diff --git a/src/ray/common/grpc_util.h b/src/ray/common/grpc_util.h index ae99eaf79081..ed2f8c73eda1 100644 --- a/src/ray/common/grpc_util.h +++ b/src/ray/common/grpc_util.h @@ -83,6 +83,10 @@ inline grpc::Status RayStatusToGrpcStatus(const Status &ray_status) { if (ray_status.ok()) { return grpc::Status::OK; } + // Map Unauthenticated to gRPC's UNAUTHENTICATED status code + if (ray_status.IsUnauthenticated()) { + return grpc::Status(grpc::StatusCode::UNAUTHENTICATED, ray_status.message()); + } // Unlike `UNKNOWN`, `ABORTED` is never generated by the library, so using it means // more robust. 
return grpc::Status( diff --git a/src/ray/common/status.cc b/src/ray/common/status.cc index 3500ddaf3b80..528a6766412e 100644 --- a/src/ray/common/status.cc +++ b/src/ray/common/status.cc @@ -74,7 +74,7 @@ const absl::flat_hash_map kCodeToStr = { {StatusCode::RpcError, "RpcError"}, {StatusCode::OutOfResource, "OutOfResource"}, {StatusCode::ObjectRefEndOfStream, "ObjectRefEndOfStream"}, - {StatusCode::AuthError, "AuthError"}, + {StatusCode::Unauthenticated, "Unauthenticated"}, {StatusCode::InvalidArgument, "InvalidArgument"}, {StatusCode::ChannelError, "ChannelError"}, {StatusCode::ChannelTimeoutError, "ChannelTimeoutError"}, diff --git a/src/ray/common/status.h b/src/ray/common/status.h index 2544918ac263..f04040cea934 100644 --- a/src/ray/common/status.h +++ b/src/ray/common/status.h @@ -263,7 +263,7 @@ enum class StatusCode : char { RpcError = 30, OutOfResource = 31, ObjectRefEndOfStream = 32, - AuthError = 33, + Unauthenticated = 33, // Indicates the input value is not valid. InvalidArgument = 34, // Indicates that a channel (a mutable plasma object) is closed and cannot be @@ -415,8 +415,8 @@ class RAY_EXPORT Status { return Status(StatusCode::OutOfResource, msg); } - static Status AuthError(const std::string &msg) { - return Status(StatusCode::AuthError, msg); + static Status Unauthenticated(const std::string &msg) { + return Status(StatusCode::Unauthenticated, msg); } static Status ChannelError(const std::string &msg) { @@ -475,7 +475,7 @@ class RAY_EXPORT Status { bool IsOutOfResource() const { return code() == StatusCode::OutOfResource; } - bool IsAuthError() const { return code() == StatusCode::AuthError; } + bool IsUnauthenticated() const { return code() == StatusCode::Unauthenticated; } bool IsChannelError() const { return code() == StatusCode::ChannelError; } diff --git a/src/ray/core_worker/BUILD.bazel b/src/ray/core_worker/BUILD.bazel index a92f4f4323b0..87fee53bb63d 100644 --- a/src/ray/core_worker/BUILD.bazel +++ b/src/ray/core_worker/BUILD.bazel @@ 
-78,6 +78,7 @@ ray_cc_library( "//src/ray/protobuf:core_worker_cc_proto", "//src/ray/rpc:grpc_server", "//src/ray/rpc:rpc_callback_types", + "//src/ray/rpc/authentication:authentication_token", ], ) diff --git a/src/ray/core_worker/grpc_service.cc b/src/ray/core_worker/grpc_service.cc index e5540aa502df..adb5b62786d4 100644 --- a/src/ray/core_worker/grpc_service.cc +++ b/src/ray/core_worker/grpc_service.cc @@ -15,6 +15,7 @@ #include "ray/core_worker/grpc_service.h" #include +#include #include namespace ray { @@ -23,91 +24,104 @@ namespace rpc { void CoreWorkerGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) { + const ClusterID &cluster_id, + const std::optional &auth_token) { /// TODO(vitsai): Remove this when auth is implemented for node manager. /// Disable gRPC server metrics since it incurs too high cardinality. - RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED( - CoreWorkerService, PushTask, max_active_rpcs_per_handler_, AuthType::NO_AUTH); + RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, + PushTask, + max_active_rpcs_per_handler_, + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, ActorCallArgWaitComplete, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, RayletNotifyGCSRestart, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, GetObjectStatus, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, WaitForActorRefDeleted, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); 
RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, PubsubLongPolling, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, PubsubCommandBatch, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, UpdateObjectLocationBatch, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, GetObjectLocationsOwner, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, ReportGeneratorItemReturns, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); - RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED( - CoreWorkerService, KillActor, max_active_rpcs_per_handler_, AuthType::NO_AUTH); - RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED( - CoreWorkerService, CancelTask, max_active_rpcs_per_handler_, AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); + RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, + KillActor, + max_active_rpcs_per_handler_, + ClusterIdAuthType::NO_AUTH); + RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, + CancelTask, + max_active_rpcs_per_handler_, + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, CancelRemoteTask, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, RegisterMutableObjectReader, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, GetCoreWorkerStats, max_active_rpcs_per_handler_, - 
AuthType::NO_AUTH); - RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED( - CoreWorkerService, LocalGC, max_active_rpcs_per_handler_, AuthType::NO_AUTH); - RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED( - CoreWorkerService, DeleteObjects, max_active_rpcs_per_handler_, AuthType::NO_AUTH); - RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED( - CoreWorkerService, SpillObjects, max_active_rpcs_per_handler_, AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); + RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, + LocalGC, + max_active_rpcs_per_handler_, + ClusterIdAuthType::NO_AUTH); + RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, + DeleteObjects, + max_active_rpcs_per_handler_, + ClusterIdAuthType::NO_AUTH); + RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, + SpillObjects, + max_active_rpcs_per_handler_, + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, RestoreSpilledObjects, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, DeleteSpilledObjects, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, PlasmaObjectReady, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED( - CoreWorkerService, Exit, max_active_rpcs_per_handler_, AuthType::NO_AUTH); + CoreWorkerService, Exit, max_active_rpcs_per_handler_, ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, AssignObjectOwner, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, NumPendingTasks, max_active_rpcs_per_handler_, 
- AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); } } // namespace rpc diff --git a/src/ray/core_worker/grpc_service.h b/src/ray/core_worker/grpc_service.h index 4559a45447c1..d605f5176533 100644 --- a/src/ray/core_worker/grpc_service.h +++ b/src/ray/core_worker/grpc_service.h @@ -29,9 +29,12 @@ #pragma once #include +#include +#include #include #include "ray/common/asio/instrumented_io_context.h" +#include "ray/rpc/authentication/authentication_token.h" #include "ray/rpc/grpc_server.h" #include "ray/rpc/rpc_callback_types.h" #include "src/ray/protobuf/core_worker.grpc.pb.h" @@ -158,7 +161,8 @@ class CoreWorkerGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) override; + const ClusterID &cluster_id, + const std::optional &auth_token) override; private: CoreWorkerService::AsyncService service_; diff --git a/src/ray/gcs/BUILD.bazel b/src/ray/gcs/BUILD.bazel index 2511da321245..b764ad410db6 100644 --- a/src/ray/gcs/BUILD.bazel +++ b/src/ray/gcs/BUILD.bazel @@ -353,6 +353,7 @@ ray_cc_library( "//src/ray/protobuf:gcs_service_cc_grpc", "//src/ray/rpc:grpc_server", "//src/ray/rpc:rpc_callback_types", + "//src/ray/rpc/authentication:authentication_token", "@com_github_grpc_grpc//:grpc++", ], ) diff --git a/src/ray/gcs/grpc_services.cc b/src/ray/gcs/grpc_services.cc index f1f3c55af3f1..66b4397782c2 100644 --- a/src/ray/gcs/grpc_services.cc +++ b/src/ray/gcs/grpc_services.cc @@ -22,7 +22,8 @@ namespace rpc { void ActorInfoGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) { + const ClusterID &cluster_id, + const std::optional &auth_token) { /// The register & create actor RPCs take a long time, so we shouldn't limit their /// concurrency to avoid distributed deadlock. 
RPC_SERVICE_HANDLER(ActorInfoGcsService, RegisterActor, -1) @@ -42,13 +43,14 @@ void ActorInfoGrpcService::InitServerCallFactories( void NodeInfoGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) { + const ClusterID &cluster_id, + const std::optional &auth_token) { // We only allow one cluster ID in the lifetime of a client. // So, if a client connects, it should not have a pre-existing different ID. RPC_SERVICE_HANDLER_CUSTOM_AUTH(NodeInfoGcsService, GetClusterId, max_active_rpcs_per_handler_, - AuthType::EMPTY_AUTH); + ClusterIdAuthType::EMPTY_AUTH); RPC_SERVICE_HANDLER(NodeInfoGcsService, RegisterNode, max_active_rpcs_per_handler_) RPC_SERVICE_HANDLER(NodeInfoGcsService, UnregisterNode, max_active_rpcs_per_handler_) RPC_SERVICE_HANDLER(NodeInfoGcsService, DrainNode, max_active_rpcs_per_handler_) @@ -61,7 +63,8 @@ void NodeInfoGrpcService::InitServerCallFactories( void NodeResourceInfoGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) { + const ClusterID &cluster_id, + const std::optional &auth_token) { RPC_SERVICE_HANDLER( NodeResourceInfoGcsService, GetAllAvailableResources, max_active_rpcs_per_handler_) RPC_SERVICE_HANDLER( @@ -75,7 +78,8 @@ void NodeResourceInfoGrpcService::InitServerCallFactories( void InternalPubSubGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) { + const ClusterID &cluster_id, + const std::optional &auth_token) { RPC_SERVICE_HANDLER(InternalPubSubGcsService, GcsPublish, max_active_rpcs_per_handler_); RPC_SERVICE_HANDLER( InternalPubSubGcsService, GcsSubscriberPoll, max_active_rpcs_per_handler_); @@ -86,7 +90,8 @@ void InternalPubSubGrpcService::InitServerCallFactories( void JobInfoGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const 
ClusterID &cluster_id) { + const ClusterID &cluster_id, + const std::optional &auth_token) { RPC_SERVICE_HANDLER(JobInfoGcsService, AddJob, max_active_rpcs_per_handler_) RPC_SERVICE_HANDLER(JobInfoGcsService, MarkJobFinished, max_active_rpcs_per_handler_) RPC_SERVICE_HANDLER(JobInfoGcsService, GetAllJobInfo, max_active_rpcs_per_handler_) @@ -97,7 +102,8 @@ void JobInfoGrpcService::InitServerCallFactories( void RuntimeEnvGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) { + const ClusterID &cluster_id, + const std::optional &auth_token) { RPC_SERVICE_HANDLER( RuntimeEnvGcsService, PinRuntimeEnvURI, max_active_rpcs_per_handler_) } @@ -105,7 +111,8 @@ void RuntimeEnvGrpcService::InitServerCallFactories( void WorkerInfoGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) { + const ClusterID &cluster_id, + const std::optional &auth_token) { RPC_SERVICE_HANDLER( WorkerInfoGcsService, ReportWorkerFailure, max_active_rpcs_per_handler_) RPC_SERVICE_HANDLER(WorkerInfoGcsService, GetWorkerInfo, max_active_rpcs_per_handler_) @@ -121,7 +128,8 @@ void WorkerInfoGrpcService::InitServerCallFactories( void InternalKVGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) { + const ClusterID &cluster_id, + const std::optional &auth_token) { RPC_SERVICE_HANDLER(InternalKVGcsService, InternalKVGet, max_active_rpcs_per_handler_) RPC_SERVICE_HANDLER( InternalKVGcsService, InternalKVMultiGet, max_active_rpcs_per_handler_) @@ -137,7 +145,8 @@ void InternalKVGrpcService::InitServerCallFactories( void TaskInfoGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) { + const ClusterID &cluster_id, + const std::optional &auth_token) { RPC_SERVICE_HANDLER(TaskInfoGcsService, 
AddTaskEventData, max_active_rpcs_per_handler_) RPC_SERVICE_HANDLER(TaskInfoGcsService, GetTaskEvents, max_active_rpcs_per_handler_) } @@ -145,7 +154,8 @@ void TaskInfoGrpcService::InitServerCallFactories( void PlacementGroupInfoGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) { + const ClusterID &cluster_id, + const std::optional &auth_token) { RPC_SERVICE_HANDLER( PlacementGroupInfoGcsService, CreatePlacementGroup, max_active_rpcs_per_handler_) RPC_SERVICE_HANDLER( @@ -166,7 +176,8 @@ namespace autoscaler { void AutoscalerStateGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) { + const ClusterID &cluster_id, + const std::optional &auth_token) { RPC_SERVICE_HANDLER( AutoscalerStateService, GetClusterResourceState, max_active_rpcs_per_handler_) RPC_SERVICE_HANDLER( @@ -188,7 +199,8 @@ namespace events { void RayEventExportGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) { + const ClusterID &cluster_id, + const std::optional &auth_token) { RPC_SERVICE_HANDLER(RayEventExportGcsService, AddEvents, max_active_rpcs_per_handler_) } diff --git a/src/ray/gcs/grpc_services.h b/src/ray/gcs/grpc_services.h index d8a0899e2439..f7b34746114d 100644 --- a/src/ray/gcs/grpc_services.h +++ b/src/ray/gcs/grpc_services.h @@ -23,11 +23,13 @@ #pragma once #include +#include #include #include "ray/common/asio/instrumented_io_context.h" #include "ray/common/id.h" #include "ray/gcs/grpc_service_interfaces.h" +#include "ray/rpc/authentication/authentication_token.h" #include "ray/rpc/grpc_server.h" #include "ray/rpc/rpc_callback_types.h" #include "src/ray/protobuf/autoscaler.grpc.pb.h" @@ -51,7 +53,8 @@ class ActorInfoGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, 
- const ClusterID &cluster_id) override; + const ClusterID &cluster_id, + const std::optional &auth_token) override; private: ActorInfoGcsService::AsyncService service_; @@ -74,7 +77,8 @@ class NodeInfoGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) override; + const ClusterID &cluster_id, + const std::optional &auth_token) override; private: NodeInfoGcsService::AsyncService service_; @@ -97,7 +101,8 @@ class NodeResourceInfoGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) override; + const ClusterID &cluster_id, + const std::optional &auth_token) override; private: NodeResourceInfoGcsService::AsyncService service_; @@ -120,7 +125,8 @@ class InternalPubSubGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) override; + const ClusterID &cluster_id, + const std::optional &auth_token) override; private: InternalPubSubGcsService::AsyncService service_; @@ -143,7 +149,8 @@ class JobInfoGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) override; + const ClusterID &cluster_id, + const std::optional &auth_token) override; private: JobInfoGcsService::AsyncService service_; @@ -166,7 +173,8 @@ class RuntimeEnvGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) override; + const ClusterID &cluster_id, + const std::optional &auth_token) override; private: RuntimeEnvGcsService::AsyncService service_; @@ -189,7 +197,8 @@ class WorkerInfoGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> 
*server_call_factories, - const ClusterID &cluster_id) override; + const ClusterID &cluster_id, + const std::optional &auth_token) override; private: WorkerInfoGcsService::AsyncService service_; @@ -212,7 +221,8 @@ class InternalKVGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) override; + const ClusterID &cluster_id, + const std::optional &auth_token) override; private: InternalKVGcsService::AsyncService service_; @@ -235,7 +245,8 @@ class TaskInfoGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) override; + const ClusterID &cluster_id, + const std::optional &auth_token) override; private: TaskInfoGcsService::AsyncService service_; @@ -258,7 +269,8 @@ class PlacementGroupInfoGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) override; + const ClusterID &cluster_id, + const std::optional &auth_token) override; private: PlacementGroupInfoGcsService::AsyncService service_; @@ -283,7 +295,8 @@ class AutoscalerStateGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) override; + const ClusterID &cluster_id, + const std::optional &auth_token) override; private: AutoscalerStateService::AsyncService service_; @@ -310,7 +323,8 @@ class RayEventExportGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) override; + const ClusterID &cluster_id, + const std::optional &auth_token) override; private: RayEventExportGcsService::AsyncService service_; diff --git a/src/ray/gcs_rpc_client/tests/gcs_client_test.cc 
b/src/ray/gcs_rpc_client/tests/gcs_client_test.cc index 0d7fe2a71b73..620dc9b9a985 100644 --- a/src/ray/gcs_rpc_client/tests/gcs_client_test.cc +++ b/src/ray/gcs_rpc_client/tests/gcs_client_test.cc @@ -220,7 +220,7 @@ class GcsClientTest : public ::testing::TestWithParam { auto status = stub->CheckAlive(&context, request, &reply); // If it is in memory, we don't have the new token until we connect again. if (!((!no_redis_ && status.ok()) || - (no_redis_ && GrpcStatusToRayStatus(status).IsAuthError()))) { + (no_redis_ && GrpcStatusToRayStatus(status).IsUnauthenticated()))) { RAY_LOG(WARNING) << "Unable to reach GCS: " << status.error_code() << " " << status.error_message(); continue; @@ -991,7 +991,7 @@ TEST_P(GcsClientTest, TestGcsEmptyAuth) { auto status = stub->GetClusterId(&context, request, &reply); // We expect the wrong cluster ID - EXPECT_TRUE(GrpcStatusToRayStatus(status).IsAuthError()); + EXPECT_TRUE(GrpcStatusToRayStatus(status).IsUnauthenticated()); } TEST_P(GcsClientTest, TestGcsAuth) { diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 90eaa40d4985..d9e30d76917a 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -452,7 +452,7 @@ void NodeManager::RegisterGcs() { << "GCS consider this node to be dead. This may happen when " << "GCS is not backed by a DB and restarted or there is data loss " << "in the DB."; - } else if (status.IsAuthError()) { + } else if (status.IsUnauthenticated()) { RAY_LOG(FATAL) << "GCS returned an authentication error. 
This may happen when " << "GCS is not backed by a DB and restarted or there is data loss " diff --git a/src/ray/rpc/BUILD.bazel b/src/ray/rpc/BUILD.bazel index 23cedf7eb265..637655b02e29 100644 --- a/src/ray/rpc/BUILD.bazel +++ b/src/ray/rpc/BUILD.bazel @@ -21,6 +21,7 @@ ray_cc_library( "//src/ray/common:grpc_util", "//src/ray/common:id", "//src/ray/common:status", + "//src/ray/rpc/authentication:authentication_token_loader", "@com_google_absl//absl/synchronization", ], ) @@ -106,6 +107,7 @@ ray_cc_library( "//src/ray/common:id", "//src/ray/common:ray_config", "//src/ray/common:status", + "//src/ray/rpc/authentication:authentication_token", "//src/ray/stats:stats_metric", "@com_github_grpc_grpc//:grpc++", ], @@ -122,6 +124,7 @@ ray_cc_library( "//src/ray/common:asio", "//src/ray/common:ray_config", "//src/ray/common:status", + "//src/ray/rpc/authentication:authentication_token_loader", "//src/ray/util:network_util", "//src/ray/util:thread_utils", "@com_github_grpc_grpc//:grpc++", @@ -139,6 +142,7 @@ ray_cc_library( deps = [ ":grpc_server", "//src/ray/protobuf:node_manager_cc_grpc", + "//src/ray/rpc/authentication:authentication_token", "@com_github_grpc_grpc//:grpc++", ], ) @@ -154,6 +158,7 @@ ray_cc_library( "//src/ray/object_manager:object_manager_grpc_client_manager", "//src/ray/protobuf:object_manager_cc_grpc", "//src/ray/rpc:grpc_server", + "//src/ray/rpc/authentication:authentication_token", "@boost//:asio", "@com_github_grpc_grpc//:grpc++", ], diff --git a/src/ray/rpc/client_call.h b/src/ray/rpc/client_call.h index 29bb21c29ebe..319915f3e17a 100644 --- a/src/ray/rpc/client_call.h +++ b/src/ray/rpc/client_call.h @@ -27,9 +27,12 @@ #include "absl/synchronization/mutex.h" #include "ray/common/asio/asio_chaos.h" #include "ray/common/asio/instrumented_io_context.h" +#include "ray/common/constants.h" #include "ray/common/grpc_util.h" #include "ray/common/id.h" #include "ray/common/status.h" +#include "ray/rpc/authentication/authentication_token.h" +#include 
"ray/rpc/authentication/authentication_token_loader.h" #include "ray/rpc/metrics.h" #include "ray/rpc/rpc_callback_types.h" #include "ray/util/thread_utils.h" @@ -71,6 +74,7 @@ class ClientCallImpl : public ClientCall { /// \param[in] callback The callback function to handle the reply. explicit ClientCallImpl(const ClientCallback &callback, const ClusterID &cluster_id, + const std::optional &auth_token, std::shared_ptr stats_handle, bool record_stats, int64_t timeout_ms = -1) @@ -85,6 +89,10 @@ class ClientCallImpl : public ClientCall { if (!cluster_id.IsNil()) { context_.AddMetadata(kClusterIdKey, cluster_id.Hex()); } + // Add authentication token if provided + if (auth_token.has_value()) { + auth_token->SetMetadata(context_); + } } Status GetStatus() override { @@ -276,7 +284,12 @@ class ClientCallManager { } auto call = std::make_shared>( - callback, cluster_id_, std::move(stats_handle), record_stats_, method_timeout_ms); + callback, + cluster_id_, + AuthenticationTokenLoader::instance().GetToken(), + std::move(stats_handle), + record_stats_, + method_timeout_ms); // Send request. // Find the next completion queue to wait for response. 
call->response_reader_ = (stub.*prepare_async_function)( diff --git a/src/ray/rpc/grpc_server.cc b/src/ray/rpc/grpc_server.cc index 542326de0bce..e471bf7e39de 100644 --- a/src/ray/rpc/grpc_server.cc +++ b/src/ray/rpc/grpc_server.cc @@ -26,6 +26,7 @@ #include "ray/common/ray_config.h" #include "ray/common/status.h" +#include "ray/rpc/authentication/authentication_token_loader.h" #include "ray/rpc/common.h" #include "ray/util/network_util.h" #include "ray/util/thread_utils.h" @@ -178,12 +179,13 @@ void GrpcServer::RegisterService(std::unique_ptr &&grpc_service) } void GrpcServer::RegisterService(std::unique_ptr &&service, - bool token_auth) { + bool cluster_id_auth_enabled) { + if (cluster_id_auth_enabled && cluster_id_.IsNil()) { + RAY_LOG(FATAL) << "Expected cluster ID for token auth!"; + } for (int i = 0; i < num_threads_; i++) { - if (token_auth && cluster_id_.IsNil()) { - RAY_LOG(FATAL) << "Expected cluster ID for token auth!"; - } - service->InitServerCallFactories(cqs_[i], &server_call_factories_, cluster_id_); + service->InitServerCallFactories( + cqs_[i], &server_call_factories_, cluster_id_, auth_token_); } services_.push_back(std::move(service)); } diff --git a/src/ray/rpc/grpc_server.h b/src/ray/rpc/grpc_server.h index 0727b4d550f3..bf7eb7f8c5d1 100644 --- a/src/ray/rpc/grpc_server.h +++ b/src/ray/rpc/grpc_server.h @@ -24,39 +24,44 @@ #include #include "ray/common/asio/instrumented_io_context.h" +#include "ray/rpc/authentication/authentication_token.h" +#include "ray/rpc/authentication/authentication_token_loader.h" #include "ray/rpc/server_call.h" namespace ray { namespace rpc { /// \param MAX_ACTIVE_RPCS Maximum number of RPCs to handle at the same time. -1 means no /// limit. 
-#define _RPC_SERVICE_HANDLER( \ - SERVICE, HANDLER, MAX_ACTIVE_RPCS, AUTH_TYPE, RECORD_METRICS) \ - std::unique_ptr HANDLER##_call_factory( \ - new ServerCallFactoryImpl( \ - service_, \ - &SERVICE::AsyncService::Request##HANDLER, \ - service_handler_, \ - &SERVICE##Handler::Handle##HANDLER, \ - cq, \ - main_service_, \ - #SERVICE ".grpc_server." #HANDLER, \ - AUTH_TYPE == AuthType::NO_AUTH ? ClusterID::Nil() : cluster_id, \ - MAX_ACTIVE_RPCS, \ - RECORD_METRICS)); \ +#define _RPC_SERVICE_HANDLER( \ + SERVICE, HANDLER, MAX_ACTIVE_RPCS, AUTH_TYPE, RECORD_METRICS) \ + std::unique_ptr HANDLER##_call_factory( \ + new ServerCallFactoryImpl( \ + service_, \ + &SERVICE::AsyncService::Request##HANDLER, \ + service_handler_, \ + &SERVICE##Handler::Handle##HANDLER, \ + cq, \ + main_service_, \ + #SERVICE ".grpc_server." #HANDLER, \ + AUTH_TYPE == ClusterIdAuthType::NO_AUTH ? ClusterID::Nil() : cluster_id, \ + auth_token, \ + MAX_ACTIVE_RPCS, \ + RECORD_METRICS)); \ server_call_factories->emplace_back(std::move(HANDLER##_call_factory)); /// Define a RPC service handler with gRPC server metrics enabled. #define RPC_SERVICE_HANDLER(SERVICE, HANDLER, MAX_ACTIVE_RPCS) \ - _RPC_SERVICE_HANDLER(SERVICE, HANDLER, MAX_ACTIVE_RPCS, AuthType::LAZY_AUTH, true) + _RPC_SERVICE_HANDLER( \ + SERVICE, HANDLER, MAX_ACTIVE_RPCS, ClusterIdAuthType::LAZY_AUTH, true) /// Define a RPC service handler with gRPC server metrics disabled. #define RPC_SERVICE_HANDLER_SERVER_METRICS_DISABLED(SERVICE, HANDLER, MAX_ACTIVE_RPCS) \ - _RPC_SERVICE_HANDLER(SERVICE, HANDLER, MAX_ACTIVE_RPCS, AuthType::LAZY_AUTH, false) + _RPC_SERVICE_HANDLER( \ + SERVICE, HANDLER, MAX_ACTIVE_RPCS, ClusterIdAuthType::LAZY_AUTH, false) /// Define a RPC service handler with gRPC server metrics enabled. 
#define RPC_SERVICE_HANDLER_CUSTOM_AUTH(SERVICE, HANDLER, MAX_ACTIVE_RPCS, AUTH_TYPE) \ @@ -90,13 +95,20 @@ class GrpcServer { const uint32_t port, bool listen_to_localhost_only, int num_threads = 1, - int64_t keepalive_time_ms = 7200000 /*2 hours, grpc default*/) + int64_t keepalive_time_ms = 7200000, /*2 hours, grpc default*/ + std::optional auth_token = std::nullopt) : name_(std::move(name)), port_(port), listen_to_localhost_only_(listen_to_localhost_only), is_shutdown_(true), num_threads_(num_threads), keepalive_time_ms_(keepalive_time_ms) { + // Initialize auth token: use provided value or load from AuthenticationTokenLoader + if (auth_token.has_value()) { + auth_token_ = std::move(auth_token.value()); + } else { + auth_token_ = AuthenticationTokenLoader::instance().GetToken(); + } Init(); } @@ -157,6 +169,8 @@ class GrpcServer { const bool listen_to_localhost_only_; /// Token representing ID of this cluster. ClusterID cluster_id_; + /// Authentication token for token-based authentication. + std::optional auth_token_; /// Indicates whether this server is in shutdown state. std::atomic is_shutdown_; /// The `grpc::Service` objects which should be registered to `ServerBuilder`. @@ -208,10 +222,13 @@ class GrpcService { /// \param[in] cq The grpc completion queue. /// \param[out] server_call_factories The `ServerCallFactory` objects, /// and the maximum number of concurrent requests that this gRPC server can handle. + /// \param[in] cluster_id The cluster ID for authentication. + /// \param[in] auth_token The authentication token for token-based authentication. virtual void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) = 0; + const ClusterID &cluster_id, + const std::optional &auth_token) = 0; /// The main event loop, to which the service handler functions will be posted. 
instrumented_io_context &main_service_; diff --git a/src/ray/rpc/node_manager/node_manager_server.h b/src/ray/rpc/node_manager/node_manager_server.h index fba7780afc69..b819a7e98a13 100644 --- a/src/ray/rpc/node_manager/node_manager_server.h +++ b/src/ray/rpc/node_manager/node_manager_server.h @@ -15,9 +15,12 @@ #pragma once #include +#include +#include #include #include "ray/common/asio/instrumented_io_context.h" +#include "ray/rpc/authentication/authentication_token.h" #include "ray/rpc/grpc_server.h" #include "src/ray/protobuf/node_manager.grpc.pb.h" #include "src/ray/protobuf/node_manager.pb.h" @@ -29,7 +32,8 @@ class ServerCallFactory; /// TODO(vitsai): Remove this when auth is implemented for node manager #define RAY_NODE_MANAGER_RPC_SERVICE_HANDLER(METHOD) \ - RPC_SERVICE_HANDLER_CUSTOM_AUTH(NodeManagerService, METHOD, -1, AuthType::NO_AUTH) + RPC_SERVICE_HANDLER_CUSTOM_AUTH( \ + NodeManagerService, METHOD, -1, ClusterIdAuthType::NO_AUTH) /// NOTE: See src/ray/core_worker/core_worker.h on how to add a new grpc handler. 
#define RAY_NODE_MANAGER_RPC_HANDLERS \ @@ -206,7 +210,8 @@ class NodeManagerGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) override { + const ClusterID &cluster_id, + const std::optional &auth_token) override { RAY_NODE_MANAGER_RPC_HANDLERS } diff --git a/src/ray/rpc/object_manager_server.h b/src/ray/rpc/object_manager_server.h index 4d294b483fff..576de9396142 100644 --- a/src/ray/rpc/object_manager_server.h +++ b/src/ray/rpc/object_manager_server.h @@ -15,9 +15,12 @@ #pragma once #include +#include +#include #include #include "ray/common/asio/instrumented_io_context.h" +#include "ray/rpc/authentication/authentication_token.h" #include "ray/rpc/grpc_server.h" #include "src/ray/protobuf/object_manager.grpc.pb.h" #include "src/ray/protobuf/object_manager.pb.h" @@ -28,7 +31,8 @@ namespace rpc { class ServerCallFactory; #define RAY_OBJECT_MANAGER_RPC_SERVICE_HANDLER(METHOD) \ - RPC_SERVICE_HANDLER_CUSTOM_AUTH(ObjectManagerService, METHOD, -1, AuthType::NO_AUTH) + RPC_SERVICE_HANDLER_CUSTOM_AUTH( \ + ObjectManagerService, METHOD, -1, ClusterIdAuthType::NO_AUTH) #define RAY_OBJECT_MANAGER_RPC_HANDLERS \ RAY_OBJECT_MANAGER_RPC_SERVICE_HANDLER(Push) \ @@ -76,7 +80,8 @@ class ObjectManagerGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) override { + const ClusterID &cluster_id, + const std::optional &auth_token) override { RAY_OBJECT_MANAGER_RPC_HANDLERS } diff --git a/src/ray/rpc/server_call.h b/src/ray/rpc/server_call.h index e79ae6dae22d..b84ab4e22dc2 100644 --- a/src/ray/rpc/server_call.h +++ b/src/ray/rpc/server_call.h @@ -20,13 +20,16 @@ #include #include #include +#include #include #include "ray/common/asio/asio_chaos.h" #include "ray/common/asio/instrumented_io_context.h" +#include "ray/common/constants.h" #include "ray/common/grpc_util.h" 
#include "ray/common/id.h" #include "ray/common/status.h" +#include "ray/rpc/authentication/authentication_token.h" #include "ray/rpc/metrics.h" #include "ray/rpc/rpc_callback_types.h" #include "ray/stats/metric.h" @@ -34,8 +37,8 @@ namespace ray { namespace rpc { -// Authentication type of ServerCall. -enum class AuthType { +// Cluster ID authentication type of ServerCall. +enum class ClusterIdAuthType { NO_AUTH, // Do not authenticate (accept all). LAZY_AUTH, // Accept missing cluster ID, but reject incorrect one. EMPTY_AUTH, // Accept only empty cluster ID. @@ -149,7 +152,7 @@ using HandleRequestFunction = void (ServiceHandler::*)(Request, template + ClusterIdAuthType EnableAuth = ClusterIdAuthType::NO_AUTH> class ServerCallImpl : public ServerCall { public: /// Constructor. @@ -159,6 +162,8 @@ class ServerCallImpl : public ServerCall { /// \param[in] handle_request_function Pointer to the service handler function. /// \param[in] io_service The event loop. /// \param[in] call_name The name of the RPC call. + /// \param[in] cluster_id The cluster ID for authentication. + /// \param[in] auth_token The authentication token for token-based authentication. /// \param[in] record_metrics If true, it records and exports the gRPC server metrics. /// \param[in] preprocess_function If not nullptr, it will be called before handling /// request. 
@@ -169,6 +174,7 @@ class ServerCallImpl : public ServerCall { instrumented_io_context &io_service, std::string call_name, const ClusterID &cluster_id, + const std::optional &auth_token, bool record_metrics, std::function preprocess_function = nullptr) : state_(ServerCallState::PENDING), @@ -179,6 +185,7 @@ class ServerCallImpl : public ServerCall { io_service_(io_service), call_name_(std::move(call_name)), cluster_id_(cluster_id), + auth_token_(auth_token), start_time_(0), record_metrics_(record_metrics) { reply_ = google::protobuf::Arena::CreateMessage(&arena_); @@ -194,8 +201,18 @@ class ServerCallImpl : public ServerCall { void HandleRequest() override { stats_handle_ = io_service_.stats().RecordStart(call_name_); bool auth_success = true; + bool token_auth_failed = false; + bool cluster_id_auth_failed = false; + + // Token authentication + if (!ValidateBearerToken()) { + auth_success = false; + token_auth_failed = true; + } + + // Cluster ID authentication if (::RayConfig::instance().enable_cluster_auth()) { - if constexpr (EnableAuth == AuthType::LAZY_AUTH) { + if constexpr (EnableAuth == ClusterIdAuthType::LAZY_AUTH) { RAY_CHECK(!cluster_id_.IsNil()) << "Expected cluster ID in server call!"; auto &metadata = context_.client_metadata(); if (auto it = metadata.find(kClusterIdKey); @@ -203,8 +220,9 @@ class ServerCallImpl : public ServerCall { RAY_LOG(WARNING) << "Wrong cluster ID token in request! Expected: " << cluster_id_.Hex() << ", but got: " << it->second; auth_success = false; + cluster_id_auth_failed = true; } - } else if constexpr (EnableAuth == AuthType::EMPTY_AUTH) { + } else if constexpr (EnableAuth == ClusterIdAuthType::EMPTY_AUTH) { RAY_CHECK(!cluster_id_.IsNil()) << "Expected cluster ID in server call!"; auto &metadata = context_.client_metadata(); if (auto it = metadata.find(kClusterIdKey); @@ -212,6 +230,7 @@ class ServerCallImpl : public ServerCall { RAY_LOG(WARNING) << "Cluster ID token in request! 
Expected Nil, " << "but got: " << it->second; auth_success = false; + cluster_id_auth_failed = true; } } } @@ -221,24 +240,32 @@ class ServerCallImpl : public ServerCall { grpc_server_req_handling_counter_.Record(1.0, {{"Method", call_name_}}); } if (!io_service_.stopped()) { - io_service_.post([this, auth_success] { HandleRequestImpl(auth_success); }, - call_name_ + ".HandleRequestImpl", - // Implement the delay of the rpc server call as the - // delay of HandleRequestImpl(). - ray::asio::testing::GetDelayUs(call_name_)); + io_service_.post( + [this, auth_success, token_auth_failed, cluster_id_auth_failed] { + HandleRequestImpl(auth_success, token_auth_failed, cluster_id_auth_failed); + }, + call_name_ + ".HandleRequestImpl", + // Implement the delay of the rpc server call as the + // delay of HandleRequestImpl(). + ray::asio::testing::GetDelayUs(call_name_)); } else { // Handle service for rpc call has stopped, we must handle the call here // to send reply and remove it from cq RAY_LOG(DEBUG) << "Handle service has been closed."; if (auth_success) { SendReply(Status::Invalid("HandleServiceClosed")); + } else if (token_auth_failed) { + SendReply(Status::Unauthenticated( + "InvalidAuthToken: Authentication token is missing or incorrect")); } else { - SendReply(Status::AuthError("WrongClusterID")); + SendReply(Status::Unauthenticated("WrongClusterID")); } } } - void HandleRequestImpl(bool auth_success) { + void HandleRequestImpl(bool auth_success, + bool token_auth_failed, + bool cluster_id_auth_failed) { if constexpr (std::is_base_of_v) { if (!service_handler_initialized_) { service_handler_.WaitUntilInitialized(); @@ -254,10 +281,15 @@ class ServerCallImpl : public ServerCall { factory_.CreateCall(); } if (!auth_success) { - boost::asio::post(GetServerCallExecutor(), [this]() { - SendReply( - Status::AuthError("WrongClusterID: Perhaps the client is accessing GCS " - "after it has restarted.")); + boost::asio::post(GetServerCallExecutor(), [this, 
token_auth_failed]() { + if (token_auth_failed) { + SendReply(Status::Unauthenticated( + "InvalidAuthToken: Authentication token is missing or incorrect")); + } else { + SendReply(Status::Unauthenticated( + "WrongClusterID: Perhaps the client is accessing GCS " + "after it has restarted.")); + } }); } else { (service_handler_.*handle_request_function_)( @@ -306,6 +338,32 @@ class ServerCallImpl : public ServerCall { const ServerCallFactory &GetServerCallFactory() override { return factory_; } private: + /// Validates token-based authentication. + /// Returns true if authentication succeeds or is not required. + /// Returns false if authentication is required but fails. + bool ValidateBearerToken() { + if (!auth_token_.has_value() || auth_token_->empty()) { + return true; // No auth required + } + + const auto &metadata = context_.client_metadata(); + auto it = metadata.find(kAuthTokenKey); + if (it == metadata.end()) { + RAY_LOG(WARNING) << "Missing authorization header in request!"; + return false; + } + + const std::string_view header(it->second.data(), it->second.length()); + AuthenticationToken provided_token = AuthenticationToken::FromMetadata(header); + + if (!auth_token_->Equals(provided_token)) { + RAY_LOG(WARNING) << "Invalid bearer token in request!"; + return false; + } + + return true; + } + /// Log the duration this query used void LogProcessTime() { EventTracker::RecordEnd(std::move(stats_handle_)); @@ -373,6 +431,9 @@ class ServerCallImpl : public ServerCall { /// Check skipped if empty. const ClusterID &cluster_id_; + /// Authentication token for token-based authentication. + std::optional auth_token_; + /// The callback when sending reply successes. 
std::function send_reply_success_callback_ = nullptr; @@ -397,7 +458,7 @@ class ServerCallImpl : public ServerCall { ray::stats::Count grpc_server_req_failed_counter_{ GetGrpcServerReqFailedCounterMetric()}; - template + template friend class ServerCallFactoryImpl; }; @@ -425,7 +486,7 @@ template + ClusterIdAuthType EnableAuth = ClusterIdAuthType::NO_AUTH> class ServerCallFactoryImpl : public ServerCallFactory { using AsyncService = typename GrpcService::AsyncService; @@ -440,6 +501,8 @@ class ServerCallFactoryImpl : public ServerCallFactory { /// \param[in] cq The `CompletionQueue`. /// \param[in] io_service The event loop. /// \param[in] call_name The name of the RPC call. + /// \param[in] cluster_id The cluster ID for authentication. + /// \param[in] auth_token The authentication token for token-based authentication. /// \param[in] max_active_rpcs Maximum request number to handle at the same time. -1 /// means no limit. /// \param[in] record_metrics If true, it records and exports the gRPC server metrics. @@ -452,6 +515,7 @@ class ServerCallFactoryImpl : public ServerCallFactory { instrumented_io_context &io_service, std::string call_name, const ClusterID &cluster_id, + const std::optional &auth_token, int64_t max_active_rpcs, bool record_metrics) : service_(service), @@ -462,6 +526,7 @@ class ServerCallFactoryImpl : public ServerCallFactory { io_service_(io_service), call_name_(std::move(call_name)), cluster_id_(cluster_id), + auth_token_(auth_token), max_active_rpcs_(max_active_rpcs), record_metrics_(record_metrics) {} @@ -475,6 +540,7 @@ class ServerCallFactoryImpl : public ServerCallFactory { io_service_, call_name_, cluster_id_, + auth_token_, record_metrics_); /// Request gRPC runtime to starting accepting this kind of request, using the call as /// the tag. @@ -514,6 +580,9 @@ class ServerCallFactoryImpl : public ServerCallFactory { /// Check skipped if empty. const ClusterID cluster_id_; + /// Authentication token for token-based authentication. 
+ std::optional auth_token_; + /// Maximum request number to handle at the same time. /// -1 means no limit. uint64_t max_active_rpcs_; diff --git a/src/ray/rpc/tests/grpc_auth_token_tests.cc b/src/ray/rpc/tests/grpc_auth_token_tests.cc new file mode 100644 index 000000000000..5501e6e568b0 --- /dev/null +++ b/src/ray/rpc/tests/grpc_auth_token_tests.cc @@ -0,0 +1,221 @@ +// Copyright 2021 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + +#include "gtest/gtest.h" +#include "ray/protobuf/test_service.grpc.pb.h" +#include "ray/rpc/authentication/authentication_token_loader.h" +#include "ray/rpc/grpc_client.h" +#include "ray/rpc/grpc_server.h" +#include "ray/rpc/tests/grpc_test_common.h" + +namespace ray { +namespace rpc { + +class TestGrpcServerClientTokenAuthFixture : public ::testing::Test { + public: + void SetUp() override { + // Configure token auth via RayConfig + std::string config_json = R"({"auth_mode": "token"})"; + RayConfig::instance().initialize(config_json); + AuthenticationTokenLoader::instance().ResetCache(); + } + + void SetUpServerAndClient(const std::string &server_token, + const std::string &client_token) { + // Set client token in environment for ClientCallManager to read from + // AuthenticationTokenLoader + if (!client_token.empty()) { + setenv("RAY_AUTH_TOKEN", client_token.c_str(), 1); + } else { + RayConfig::instance().initialize(R"({"auth_mode": "disabled"})"); + 
AuthenticationTokenLoader::instance().ResetCache(); + unsetenv("RAY_AUTH_TOKEN"); + } + + // Start client thread FIRST + client_thread_ = std::make_unique([this]() { + boost::asio::executor_work_guard + client_io_service_work_(client_io_service_.get_executor()); + client_io_service_.run(); + }); + + // Start handler thread for server + handler_thread_ = std::make_unique([this]() { + boost::asio::executor_work_guard + handler_io_service_work_(handler_io_service_.get_executor()); + handler_io_service_.run(); + }); + + // Create and start server + // Pass server token explicitly for testing scenarios with different tokens + std::optional server_auth_token; + if (!server_token.empty()) { + server_auth_token = AuthenticationToken(server_token); + } else { + // Explicitly set empty token (no auth required) + server_auth_token = AuthenticationToken(""); + } + grpc_server_.reset(new GrpcServer("test", 0, true, 1, 7200000, server_auth_token)); + grpc_server_->RegisterService( + std::make_unique(handler_io_service_, test_service_handler_), + false); + grpc_server_->Run(); + + while (grpc_server_->GetPort() == 0) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + + // Create client (will read auth token from AuthenticationTokenLoader which reads the + // environment) + client_call_manager_.reset( + new ClientCallManager(client_io_service_, false, /*local_address=*/"")); + grpc_client_.reset(new GrpcClient( + "127.0.0.1", grpc_server_->GetPort(), *client_call_manager_)); + } + + void TearDown() override { + if (grpc_client_) { + grpc_client_.reset(); + } + if (client_call_manager_) { + client_call_manager_.reset(); + } + if (client_thread_) { + client_io_service_.stop(); + if (client_thread_->joinable()) { + client_thread_->join(); + } + } + + if (grpc_server_) { + grpc_server_->Shutdown(); + } + if (handler_thread_) { + handler_io_service_.stop(); + if (handler_thread_->joinable()) { + handler_thread_->join(); + } + } + + // Clean up environment variables + 
unsetenv("RAY_AUTH_TOKEN"); + unsetenv("RAY_AUTH_TOKEN_PATH"); + // Reset the token loader for test isolation + AuthenticationTokenLoader::instance().ResetCache(); + } + + // Helper to execute RPC and wait for result + struct PingResult { + bool completed; + bool success; + std::string error_msg; + }; + + PingResult ExecutePingAndWait() { + PingRequest request; + auto result_promise = std::make_shared>(); + std::future result_future = result_promise->get_future(); + + Ping(request, [result_promise](const Status &status, const PingReply &reply) { + RAY_LOG(INFO) << "Token auth test replied, status=" << status; + bool success = status.ok(); + std::string error_msg = status.ok() ? "" : status.message(); + result_promise->set_value({true, success, error_msg}); + }); + + // Wait for response with timeout + if (result_future.wait_for(std::chrono::seconds(5)) == std::future_status::timeout) { + return {false, false, "Request timed out"}; + } + + return result_future.get(); + } + + protected: + VOID_RPC_CLIENT_METHOD(TestService, Ping, grpc_client_, /*method_timeout_ms*/ -1, ) + + TestServiceHandler test_service_handler_; + instrumented_io_context handler_io_service_; + std::unique_ptr handler_thread_; + std::unique_ptr grpc_server_; + + instrumented_io_context client_io_service_; + std::unique_ptr client_thread_; + std::unique_ptr client_call_manager_; + std::unique_ptr> grpc_client_; +}; + +TEST_F(TestGrpcServerClientTokenAuthFixture, TestTokenAuthSuccess) { + // Both server and client have the same token + const std::string token = "test_secret_token_123"; + SetUpServerAndClient(token, token); + + auto result = ExecutePingAndWait(); + + ASSERT_TRUE(result.completed) << "Request did not complete in time"; + ASSERT_TRUE(result.success) << "Request should succeed with matching token"; +} + +TEST_F(TestGrpcServerClientTokenAuthFixture, TestTokenAuthFailureWrongToken) { + // Server and client have different tokens + SetUpServerAndClient("server_token", "wrong_client_token"); 
+ + auto result = ExecutePingAndWait(); + + ASSERT_TRUE(result.completed) << "Request did not complete in time"; + ASSERT_FALSE(result.success) << "Request should fail with wrong client token"; + ASSERT_TRUE(result.error_msg.find( + "InvalidAuthToken: Authentication token is missing or incorrect") != + std::string::npos) + << "Error message should contain token auth error. Got: " << result.error_msg; +} + +TEST_F(TestGrpcServerClientTokenAuthFixture, TestTokenAuthFailureMissingToken) { + // Server expects token, client doesn't send one (empty token) + SetUpServerAndClient("server_token", ""); + + auto result = ExecutePingAndWait(); + + ASSERT_TRUE(result.completed) << "Request did not complete in time"; + // If the server has a token but the client doesn't, auth should fail + ASSERT_FALSE(result.success) + << "Request should fail when client doesn't provide required token"; +} + +TEST_F(TestGrpcServerClientTokenAuthFixture, + TestClientProvidesTokenServerDoesNotRequire) { + // Client provides token, but server doesn't require one (should succeed) + SetUpServerAndClient("", "client_token"); + + auto result = ExecutePingAndWait(); + + ASSERT_TRUE(result.completed) << "Request did not complete in time"; + // Server should accept request even though client sent unnecessary token + ASSERT_TRUE(result.success) + << "Request should succeed when server doesn't require token"; +} + +} // namespace rpc +} // namespace ray + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/src/ray/rpc/tests/grpc_bench/BUILD.bazel b/src/ray/rpc/tests/grpc_bench/BUILD.bazel index 5238a11c0baf..4594e3873c5f 100644 --- a/src/ray/rpc/tests/grpc_bench/BUILD.bazel +++ b/src/ray/rpc/tests/grpc_bench/BUILD.bazel @@ -28,5 +28,6 @@ cc_binary( ":helloworld_cc_lib", "//src/ray/common:asio", "//src/ray/rpc:grpc_server", + "//src/ray/rpc/authentication:authentication_token", ], ) diff --git a/src/ray/rpc/tests/grpc_bench/grpc_bench.cc 
b/src/ray/rpc/tests/grpc_bench/grpc_bench.cc index 552b83bff3bc..81dd9477f948 100644 --- a/src/ray/rpc/tests/grpc_bench/grpc_bench.cc +++ b/src/ray/rpc/tests/grpc_bench/grpc_bench.cc @@ -13,10 +13,12 @@ // limitations under the License. #include +#include #include #include #include "ray/common/asio/instrumented_io_context.h" +#include "ray/rpc/authentication/authentication_token.h" #include "ray/rpc/grpc_server.h" #include "src/ray/rpc/test/grpc_bench/helloworld.grpc.pb.h" #include "src/ray/rpc/test/grpc_bench/helloworld.pb.h" @@ -57,9 +59,11 @@ class GreeterGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) override{ - RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED( - Greeter, SayHello, -1, AuthType::NO_AUTH)} + const ClusterID &cluster_id, + const std::optional &auth_token) override { + RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED( + Greeter, SayHello, -1, ClusterIdAuthType::NO_AUTH); + } /// The grpc async service object. 
Greeter::AsyncService service_; diff --git a/src/ray/rpc/tests/grpc_server_client_test.cc b/src/ray/rpc/tests/grpc_server_client_test.cc index f51a80b99f73..8bc6e8284493 100644 --- a/src/ray/rpc/tests/grpc_server_client_test.cc +++ b/src/ray/rpc/tests/grpc_server_client_test.cc @@ -14,92 +14,16 @@ #include #include -#include +#include #include "gtest/gtest.h" +#include "ray/protobuf/test_service.grpc.pb.h" #include "ray/rpc/grpc_client.h" #include "ray/rpc/grpc_server.h" -#include "src/ray/protobuf/test_service.grpc.pb.h" +#include "ray/rpc/tests/grpc_test_common.h" namespace ray { namespace rpc { -class TestServiceHandler { - public: - void HandlePing(PingRequest request, - PingReply *reply, - SendReplyCallback send_reply_callback) { - RAY_LOG(INFO) << "Got ping request, no_reply=" << request.no_reply(); - request_count++; - while (frozen) { - RAY_LOG(INFO) << "Server is frozen..."; - std::this_thread::sleep_for(std::chrono::milliseconds(1000)); - } - RAY_LOG(INFO) << "Handling and replying request."; - if (request.no_reply()) { - RAY_LOG(INFO) << "No reply!"; - return; - } - send_reply_callback( - ray::Status::OK(), - /*reply_success=*/[]() { RAY_LOG(INFO) << "Reply success."; }, - /*reply_failure=*/ - [this]() { - RAY_LOG(INFO) << "Reply failed."; - reply_failure_count++; - }); - } - - void HandlePingTimeout(PingTimeoutRequest request, - PingTimeoutReply *reply, - SendReplyCallback send_reply_callback) { - while (frozen) { - RAY_LOG(INFO) << "Server is frozen..."; - std::this_thread::sleep_for(std::chrono::milliseconds(1000)); - } - RAY_LOG(INFO) << "Handling and replying request."; - send_reply_callback( - ray::Status::OK(), - /*reply_success=*/[]() { RAY_LOG(INFO) << "Reply success."; }, - /*reply_failure=*/ - [this]() { - RAY_LOG(INFO) << "Reply failed."; - reply_failure_count++; - }); - } - - std::atomic request_count{0}; - std::atomic reply_failure_count{0}; - std::atomic frozen{false}; -}; - -class TestGrpcService : public GrpcService { - public: - /// 
Constructor. - /// - /// \param[in] handler The service handler that actually handle the requests. - explicit TestGrpcService(instrumented_io_context &handler_io_service_, - TestServiceHandler &handler) - : GrpcService(handler_io_service_), service_handler_(handler){}; - - protected: - grpc::Service &GetGrpcService() override { return service_; } - - void InitServerCallFactories( - const std::unique_ptr &cq, - std::vector> *server_call_factories, - const ClusterID &cluster_id) override { - RPC_SERVICE_HANDLER_CUSTOM_AUTH( - TestService, Ping, /*max_active_rpcs=*/1, AuthType::NO_AUTH); - RPC_SERVICE_HANDLER_CUSTOM_AUTH( - TestService, PingTimeout, /*max_active_rpcs=*/1, AuthType::NO_AUTH); - } - - private: - /// The grpc async service object. - TestService::AsyncService service_; - /// The service handler that actually handle the requests. - TestServiceHandler &service_handler_; -}; class TestGrpcServerClientFixture : public ::testing::Test { public: @@ -326,6 +250,7 @@ TEST_F(TestGrpcServerClientFixture, TestTimeoutMacro) { std::this_thread::sleep_for(std::chrono::milliseconds(1000)); } } + } // namespace rpc } // namespace ray diff --git a/src/ray/rpc/tests/grpc_test_common.h b/src/ray/rpc/tests/grpc_test_common.h new file mode 100644 index 000000000000..1ce199f79511 --- /dev/null +++ b/src/ray/rpc/tests/grpc_test_common.h @@ -0,0 +1,109 @@ +// Copyright 2021 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include +#include +#include + +#include "ray/rpc/grpc_server.h" +#include "src/ray/protobuf/test_service.grpc.pb.h" + +namespace ray { +namespace rpc { + +class TestServiceHandler { + public: + void HandlePing(PingRequest request, + PingReply *reply, + SendReplyCallback send_reply_callback) { + RAY_LOG(INFO) << "Got ping request, no_reply=" << request.no_reply(); + request_count++; + while (frozen) { + RAY_LOG(INFO) << "Server is frozen..."; + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + } + RAY_LOG(INFO) << "Handling and replying request."; + if (request.no_reply()) { + RAY_LOG(INFO) << "No reply!"; + return; + } + send_reply_callback( + ray::Status::OK(), + /*reply_success=*/[]() { RAY_LOG(INFO) << "Reply success."; }, + /*reply_failure=*/ + [this]() { + RAY_LOG(INFO) << "Reply failed."; + reply_failure_count++; + }); + } + + void HandlePingTimeout(PingTimeoutRequest request, + PingTimeoutReply *reply, + SendReplyCallback send_reply_callback) { + while (frozen) { + RAY_LOG(INFO) << "Server is frozen..."; + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + } + RAY_LOG(INFO) << "Handling and replying request."; + send_reply_callback( + ray::Status::OK(), + /*reply_success=*/[]() { RAY_LOG(INFO) << "Reply success."; }, + /*reply_failure=*/ + [this]() { + RAY_LOG(INFO) << "Reply failed."; + reply_failure_count++; + }); + } + + std::atomic request_count{0}; + std::atomic reply_failure_count{0}; + std::atomic frozen{false}; +}; + +class TestGrpcService : public GrpcService { + public: + /// Constructor. + /// + /// \param[in] handler The service handler that actually handle the requests. 
+ explicit TestGrpcService(instrumented_io_context &handler_io_service_, + TestServiceHandler &handler) + : GrpcService(handler_io_service_), service_handler_(handler){}; + + protected: + grpc::Service &GetGrpcService() override { return service_; } + + void InitServerCallFactories( + const std::unique_ptr &cq, + std::vector> *server_call_factories, + const ClusterID &cluster_id, + const std::optional &auth_token) override { + RPC_SERVICE_HANDLER_CUSTOM_AUTH( + TestService, Ping, /*max_active_rpcs=*/1, ClusterIdAuthType::NO_AUTH); + RPC_SERVICE_HANDLER_CUSTOM_AUTH( + TestService, PingTimeout, /*max_active_rpcs=*/1, ClusterIdAuthType::NO_AUTH); + } + + private: + /// The grpc async service object. + TestService::AsyncService service_; + /// The service handler that actually handle the requests. + TestServiceHandler &service_handler_; +}; + +} // namespace rpc +} // namespace ray From c821c21aca7127bb3d5917f11069ea85589c8bab Mon Sep 17 00:00:00 2001 From: sampan Date: Thu, 23 Oct 2025 16:25:26 +0000 Subject: [PATCH 41/94] revert unneeded changs from src/ray/rpc/tests/BUILD.bazel Signed-off-by: sampan --- src/ray/rpc/tests/BUILD.bazel | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/src/ray/rpc/tests/BUILD.bazel b/src/ray/rpc/tests/BUILD.bazel index 279b68f91ba3..0e253f612952 100644 --- a/src/ray/rpc/tests/BUILD.bazel +++ b/src/ray/rpc/tests/BUILD.bazel @@ -18,29 +18,9 @@ ray_cc_test( size = "small", srcs = [ "grpc_server_client_test.cc", - "grpc_test_common.h", ], tags = ["team:core"], deps = [ - "//src/ray/protobuf:test_service_cc_grpc", - "//src/ray/rpc:grpc_client", - "//src/ray/rpc:grpc_server", - "@com_google_googletest//:gtest_main", - ], -) - -ray_cc_test( - name = "grpc_auth_token_tests", - size = "small", - srcs = [ - "grpc_auth_token_tests.cc", - "grpc_test_common.h", - ], - tags = ["team:core"], - deps = [ - "//src/ray/protobuf:test_service_cc_grpc", - "//src/ray/rpc:grpc_client", - "//src/ray/rpc:grpc_server", 
"@com_google_googletest//:gtest_main", ], ) From a14dc6951a0f66045867053e349112b6496583c5 Mon Sep 17 00:00:00 2001 From: sampan Date: Thu, 23 Oct 2025 16:26:07 +0000 Subject: [PATCH 42/94] readd dependencies Signed-off-by: sampan --- src/ray/rpc/tests/BUILD.bazel | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/ray/rpc/tests/BUILD.bazel b/src/ray/rpc/tests/BUILD.bazel index 0e253f612952..d5113ae0d3aa 100644 --- a/src/ray/rpc/tests/BUILD.bazel +++ b/src/ray/rpc/tests/BUILD.bazel @@ -21,6 +21,9 @@ ray_cc_test( ], tags = ["team:core"], deps = [ + "//src/ray/protobuf:test_service_cc_grpc", + "//src/ray/rpc:grpc_client", + "//src/ray/rpc:grpc_server", "@com_google_googletest//:gtest_main", ], ) From e340d0788fd25fe295d32f2d698c2622d0eb8de7 Mon Sep 17 00:00:00 2001 From: sampan Date: Thu, 23 Oct 2025 16:35:56 +0000 Subject: [PATCH 43/94] fix build issues Signed-off-by: sampan --- python/ray/_private/authentication_token_setup.py | 6 +++++- python/ray/tests/BUILD.bazel | 1 + 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/python/ray/_private/authentication_token_setup.py b/python/ray/_private/authentication_token_setup.py index 03cc8e0f7565..a422df70c008 100644 --- a/python/ray/_private/authentication_token_setup.py +++ b/python/ray/_private/authentication_token_setup.py @@ -76,7 +76,11 @@ def setup_and_verify_auth( # Check if you enabled token authentication. if get_authentication_mode() != AuthenticationMode.TOKEN: - if system_config and system_config.get("auth_mode") != "disabled": + if ( + system_config + and "auth_mode" in system_config + and system_config["auth_mode"] != "disabled" + ): raise RuntimeError( "Set authentication mode with the environment, not system_config." 
) diff --git a/python/ray/tests/BUILD.bazel b/python/ray/tests/BUILD.bazel index bac342f2f602..b89b42bb8759 100644 --- a/python/ray/tests/BUILD.bazel +++ b/python/ray/tests/BUILD.bazel @@ -367,6 +367,7 @@ py_test_module_list( "test_task_metrics.py", "test_tempdir.py", "test_tls_auth.py", + "test_token_auth_integration.py", "test_traceback.py", "test_worker_capping.py", "test_worker_state.py", From 4801ed7353d131bd905a372a27a60cb55a8ec658 Mon Sep 17 00:00:00 2001 From: sampan Date: Fri, 24 Oct 2025 07:17:54 +0000 Subject: [PATCH 44/94] address comments + fix build Signed-off-by: sampan --- .../rpc/authentication/authentication_token.h | 5 +-- .../authentication_token_loader.cc | 44 +++++++++++-------- .../tests/authentication_token_loader_test.cc | 8 ++-- .../rpc/tests/authentication_token_test.cc | 11 ----- 4 files changed, 30 insertions(+), 38 deletions(-) diff --git a/src/ray/rpc/authentication/authentication_token.h b/src/ray/rpc/authentication/authentication_token.h index 6846d3c08ada..4f32310784de 100644 --- a/src/ray/rpc/authentication/authentication_token.h +++ b/src/ray/rpc/authentication/authentication_token.h @@ -92,8 +92,8 @@ class AuthenticationToken { /// prefix) /// @return AuthenticationToken object (empty if format invalid) static AuthenticationToken FromMetadata(std::string_view metadata_value) { - const std::string_view prefix(kBearerPrefix, sizeof(kBearerPrefix) - 1); - if (metadata_value.size() <= prefix.size() || + const std::string_view prefix(kBearerPrefix); + if (metadata_value.size() < prefix.size() || metadata_value.substr(0, prefix.size()) != prefix) { return AuthenticationToken(); // Invalid format, return empty } @@ -145,7 +145,6 @@ class AuthenticationToken { } void MoveFrom(AuthenticationToken &&other) noexcept { - SecureClear(); secret_ = std::move(other.secret_); // Clear the moved-from object explicitly for security // Note: 'other' is already an rvalue reference, no need to move again diff --git 
a/src/ray/rpc/authentication/authentication_token_loader.cc b/src/ray/rpc/authentication/authentication_token_loader.cc index 59a1184e080a..621f28fe351c 100644 --- a/src/ray/rpc/authentication/authentication_token_loader.cc +++ b/src/ray/rpc/authentication/authentication_token_loader.cc @@ -20,11 +20,6 @@ #include "ray/util/logging.h" -#if defined(__APPLE__) || defined(__linux__) -#include -#include -#endif - #ifdef _WIN32 #ifndef _WINDOWS_ #ifndef WIN32_LEAN_AND_MEAN // Sorry for the inconvenience. Please include any related @@ -63,9 +58,10 @@ std::optional AuthenticationTokenLoader::GetToken() { // If no token found and auth is enabled, fail with RAY_CHECK RAY_CHECK(!token.empty()) - << "Token authentication is enabled but no authentication token was found. " - << "Please set RAY_AUTH_TOKEN environment variable, RAY_AUTH_TOKEN_PATH to a file " - << "containing the token, or create a token file at ~/.ray/auth_token"; + << "Token authentication is enabled but Ray couldn't find an authentication token. 
" + << "Set the RAY_AUTH_TOKEN environment variable, or set RAY_AUTH_TOKEN_PATH to " + "point to a file with the token, " + << "or create a token file at ~/.ray/auth_token."; // Cache and return the loaded token cached_token_ = std::move(token); @@ -89,22 +85,26 @@ std::string AuthenticationTokenLoader::ReadTokenFromFile(const std::string &file AuthenticationToken AuthenticationTokenLoader::LoadTokenFromSources() { // Precedence 1: RAY_AUTH_TOKEN environment variable const char *env_token = std::getenv("RAY_AUTH_TOKEN"); - if (env_token != nullptr && std::string(env_token).length() > 0) { - RAY_LOG(DEBUG) << "Loaded authentication token from RAY_AUTH_TOKEN environment " - "variable"; - return AuthenticationToken(TrimWhitespace(std::string(env_token))); + if (env_token != nullptr) { + std::string token_str(env_token); + if (!token_str.empty()) { + RAY_LOG(DEBUG) << "Loaded authentication token from RAY_AUTH_TOKEN environment " + "variable"; + return AuthenticationToken(TrimWhitespace(token_str)); + } } // Precedence 2: RAY_AUTH_TOKEN_PATH environment variable const char *env_token_path = std::getenv("RAY_AUTH_TOKEN_PATH"); - if (env_token_path != nullptr && std::string(env_token_path).length() > 0) { - std::string token_str = TrimWhitespace(ReadTokenFromFile(env_token_path)); - if (!token_str.empty()) { - RAY_LOG(DEBUG) << "Loaded authentication token from file: " << env_token_path; + if (env_token_path != nullptr) { + std::string path_str(env_token_path); + if (!path_str.empty()) { + std::string token_str = TrimWhitespace(ReadTokenFromFile(path_str)); + RAY_CHECK(!token_str.empty()) + << "RAY_AUTH_TOKEN_PATH is set but file cannot be opened or is empty: " + << path_str; + RAY_LOG(DEBUG) << "Loaded authentication token from file: " << path_str; return AuthenticationToken(token_str); - } else { - RAY_LOG(WARNING) << "RAY_AUTH_TOKEN_PATH is set but file cannot be opened: " - << env_token_path; } } @@ -159,6 +159,12 @@ std::string 
AuthenticationTokenLoader::TrimWhitespace(const std::string &str) { std::string whitespace = " \t\n\r\f\v"; std::string trimmed_str = str; trimmed_str.erase(0, trimmed_str.find_first_not_of(whitespace)); + + // if the string is empty, return it + if (trimmed_str.empty()) { + return trimmed_str; + } + trimmed_str.erase(trimmed_str.find_last_not_of(whitespace) + 1); return trimmed_str; } diff --git a/src/ray/rpc/tests/authentication_token_loader_test.cc b/src/ray/rpc/tests/authentication_token_loader_test.cc index 616a13b0e457..2332c6d09313 100644 --- a/src/ray/rpc/tests/authentication_token_loader_test.cc +++ b/src/ray/rpc/tests/authentication_token_loader_test.cc @@ -109,8 +109,7 @@ class AuthenticationTokenLoaderTest : public ::testing::Test { void set_env_var(const char *name, const char *value) { #ifdef _WIN32 - std::string env_str = std::string(name) + "=" + std::string(value); - _putenv(env_str.c_str()); + _putenv_s(name, value); #else setenv(name, value, 1); #endif @@ -118,8 +117,7 @@ class AuthenticationTokenLoaderTest : public ::testing::Test { void unset_env_var(const char *name) { #ifdef _WIN32 - std::string env_str = std::string(name) + "="; - _putenv(env_str.c_str()); + _putenv_s(name, "") #else unsetenv(name); #endif @@ -301,7 +299,7 @@ TEST_F(AuthenticationTokenLoaderTest, TestErrorWhenAuthEnabledButNoToken) { auto &loader = AuthenticationTokenLoader::instance(); loader.GetToken(); }, - "Token authentication is enabled but no authentication token was found"); + "Token authentication is enabled but Ray couldn't find an authentication token."); } TEST_F(AuthenticationTokenLoaderTest, TestCaching) { diff --git a/src/ray/rpc/tests/authentication_token_test.cc b/src/ray/rpc/tests/authentication_token_test.cc index db88d7481da1..77ae4eb7cfc2 100644 --- a/src/ray/rpc/tests/authentication_token_test.cc +++ b/src/ray/rpc/tests/authentication_token_test.cc @@ -59,17 +59,6 @@ TEST_F(AuthenticationTokenTest, TestMoveAssignment) { EXPECT_TRUE(token1.empty()); } 
-TEST_F(AuthenticationTokenTest, TestSelfMoveAssignment) { - AuthenticationToken token("test-token"); - - // Self-assignment should not break the token - token = std::move(token); - - EXPECT_FALSE(token.empty()); - AuthenticationToken expected("test-token"); - EXPECT_TRUE(token.Equals(expected)); -} - TEST_F(AuthenticationTokenTest, TestEquals) { AuthenticationToken token1("same-token"); AuthenticationToken token2("same-token"); From d24f23c3bb3049b41c7046b14685ba02233217d6 Mon Sep 17 00:00:00 2001 From: sampan Date: Fri, 24 Oct 2025 09:36:31 +0000 Subject: [PATCH 45/94] address comments Signed-off-by: sampan --- src/ray/rpc/grpc_server.cc | 2 +- src/ray/rpc/server_call.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ray/rpc/grpc_server.cc b/src/ray/rpc/grpc_server.cc index e471bf7e39de..785d65be741b 100644 --- a/src/ray/rpc/grpc_server.cc +++ b/src/ray/rpc/grpc_server.cc @@ -181,7 +181,7 @@ void GrpcServer::RegisterService(std::unique_ptr &&grpc_service) void GrpcServer::RegisterService(std::unique_ptr &&service, bool cluster_id_auth_enabled) { if (cluster_id_auth_enabled && cluster_id_.IsNil()) { - RAY_LOG(FATAL) << "Expected cluster ID for token auth!"; + RAY_LOG(FATAL) << "Expected cluster ID for cluster ID authentication!"; } for (int i = 0; i < num_threads_; i++) { service->InitServerCallFactories( diff --git a/src/ray/rpc/server_call.h b/src/ray/rpc/server_call.h index b84ab4e22dc2..bb7a431934cf 100644 --- a/src/ray/rpc/server_call.h +++ b/src/ray/rpc/server_call.h @@ -211,7 +211,7 @@ class ServerCallImpl : public ServerCall { } // Cluster ID authentication - if (::RayConfig::instance().enable_cluster_auth()) { + if (auth_success && ::RayConfig::instance().enable_cluster_auth()) { if constexpr (EnableAuth == ClusterIdAuthType::LAZY_AUTH) { RAY_CHECK(!cluster_id_.IsNil()) << "Expected cluster ID in server call!"; auto &metadata = context_.client_metadata(); From e9cc57f0878efec5b0db7d4a936fbd9076d8dda9 Mon Sep 17 00:00:00 2001 From: 
sampan Date: Sun, 26 Oct 2025 05:58:18 +0000 Subject: [PATCH 46/94] address comments Signed-off-by: sampan --- src/ray/rpc/grpc_server.cc | 4 +--- src/ray/rpc/server_call.h | 4 ++-- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/ray/rpc/grpc_server.cc b/src/ray/rpc/grpc_server.cc index 785d65be741b..daeebff99e28 100644 --- a/src/ray/rpc/grpc_server.cc +++ b/src/ray/rpc/grpc_server.cc @@ -180,9 +180,7 @@ void GrpcServer::RegisterService(std::unique_ptr &&grpc_service) void GrpcServer::RegisterService(std::unique_ptr &&service, bool cluster_id_auth_enabled) { - if (cluster_id_auth_enabled && cluster_id_.IsNil()) { - RAY_LOG(FATAL) << "Expected cluster ID for cluster ID authentication!"; - } + RAY_CHECK(cluster_id_auth_enabled && cluster_id_.IsNil()) << "Expected cluster ID for cluster ID authentication!"; for (int i = 0; i < num_threads_; i++) { service->InitServerCallFactories( cqs_[i], &server_call_factories_, cluster_id_, auth_token_); diff --git a/src/ray/rpc/server_call.h b/src/ray/rpc/server_call.h index bb7a431934cf..b691cc52fe09 100644 --- a/src/ray/rpc/server_call.h +++ b/src/ray/rpc/server_call.h @@ -205,7 +205,7 @@ class ServerCallImpl : public ServerCall { bool cluster_id_auth_failed = false; // Token authentication - if (!ValidateBearerToken()) { + if (!ValidateAuthenticationToken()) { auth_success = false; token_auth_failed = true; } @@ -341,7 +341,7 @@ class ServerCallImpl : public ServerCall { /// Validates token-based authentication. /// Returns true if authentication succeeds or is not required. /// Returns false if authentication is required but fails. 
- bool ValidateBearerToken() { + bool ValidateAuthenticationToken() { if (!auth_token_.has_value() || auth_token_->empty()) { return true; // No auth required } From f8c08e0c3018de9e1f685894f2a6d7ad21df4494 Mon Sep 17 00:00:00 2001 From: sampan Date: Sun, 26 Oct 2025 06:26:24 +0000 Subject: [PATCH 47/94] fix lint Signed-off-by: sampan --- src/ray/rpc/grpc_server.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/ray/rpc/grpc_server.cc b/src/ray/rpc/grpc_server.cc index daeebff99e28..781c2da790df 100644 --- a/src/ray/rpc/grpc_server.cc +++ b/src/ray/rpc/grpc_server.cc @@ -180,7 +180,8 @@ void GrpcServer::RegisterService(std::unique_ptr &&grpc_service) void GrpcServer::RegisterService(std::unique_ptr &&service, bool cluster_id_auth_enabled) { - RAY_CHECK(cluster_id_auth_enabled && cluster_id_.IsNil()) << "Expected cluster ID for cluster ID authentication!"; + RAY_CHECK(cluster_id_auth_enabled && cluster_id_.IsNil()) + << "Expected cluster ID for cluster ID authentication!"; for (int i = 0; i < num_threads_; i++) { service->InitServerCallFactories( cqs_[i], &server_call_factories_, cluster_id_, auth_token_); From a7a8efa42c3dfb7bd095f30b50a7eb17ae3d3801 Mon Sep 17 00:00:00 2001 From: sampan Date: Sun, 26 Oct 2025 06:55:55 +0000 Subject: [PATCH 48/94] fix ci Signed-off-by: sampan --- src/ray/common/grpc_util.h | 4 ++++ src/ray/rpc/grpc_server.cc | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/ray/common/grpc_util.h b/src/ray/common/grpc_util.h index ed2f8c73eda1..52858cca2207 100644 --- a/src/ray/common/grpc_util.h +++ b/src/ray/common/grpc_util.h @@ -110,6 +110,10 @@ inline Status GrpcStatusToRayStatus(const grpc::Status &grpc_status) { // status code. return {StatusCode::TimedOut, GrpcStatusToRayStatusMessage(grpc_status)}; } + if (grpc_status.error_code() == grpc::StatusCode::UNAUTHENTICATED) { + // UNAUTHENTICATED means authentication failed (e.g., wrong cluster ID). 
+ return Status::Unauthenticated(GrpcStatusToRayStatusMessage(grpc_status)); + } if (grpc_status.error_code() == grpc::StatusCode::ABORTED) { // This is a status generated by ray code. // See RayStatusToGrpcStatus for details. diff --git a/src/ray/rpc/grpc_server.cc b/src/ray/rpc/grpc_server.cc index 781c2da790df..5809cc005783 100644 --- a/src/ray/rpc/grpc_server.cc +++ b/src/ray/rpc/grpc_server.cc @@ -180,7 +180,7 @@ void GrpcServer::RegisterService(std::unique_ptr &&grpc_service) void GrpcServer::RegisterService(std::unique_ptr &&service, bool cluster_id_auth_enabled) { - RAY_CHECK(cluster_id_auth_enabled && cluster_id_.IsNil()) + RAY_CHECK(!cluster_id_auth_enabled || !cluster_id_.IsNil()) << "Expected cluster ID for cluster ID authentication!"; for (int i = 0; i < num_threads_; i++) { service->InitServerCallFactories( From 5910ecfca9f961744c1747553782ea4b081c0a86 Mon Sep 17 00:00:00 2001 From: sampan Date: Mon, 27 Oct 2025 02:59:48 +0000 Subject: [PATCH 49/94] fix build.bazel and imports Signed-off-by: sampan --- src/ray/rpc/tests/BUILD.bazel | 17 +++++++++++++++++ src/ray/rpc/tests/grpc_auth_token_tests.cc | 8 ++++---- src/ray/rpc/tests/grpc_server_client_test.cc | 8 ++++---- 3 files changed, 25 insertions(+), 8 deletions(-) diff --git a/src/ray/rpc/tests/BUILD.bazel b/src/ray/rpc/tests/BUILD.bazel index d5113ae0d3aa..279b68f91ba3 100644 --- a/src/ray/rpc/tests/BUILD.bazel +++ b/src/ray/rpc/tests/BUILD.bazel @@ -18,6 +18,23 @@ ray_cc_test( size = "small", srcs = [ "grpc_server_client_test.cc", + "grpc_test_common.h", + ], + tags = ["team:core"], + deps = [ + "//src/ray/protobuf:test_service_cc_grpc", + "//src/ray/rpc:grpc_client", + "//src/ray/rpc:grpc_server", + "@com_google_googletest//:gtest_main", + ], +) + +ray_cc_test( + name = "grpc_auth_token_tests", + size = "small", + srcs = [ + "grpc_auth_token_tests.cc", + "grpc_test_common.h", ], tags = ["team:core"], deps = [ diff --git a/src/ray/rpc/tests/grpc_auth_token_tests.cc 
b/src/ray/rpc/tests/grpc_auth_token_tests.cc index 5501e6e568b0..1c55dfb511a5 100644 --- a/src/ray/rpc/tests/grpc_auth_token_tests.cc +++ b/src/ray/rpc/tests/grpc_auth_token_tests.cc @@ -19,11 +19,11 @@ #include #include "gtest/gtest.h" -#include "ray/protobuf/test_service.grpc.pb.h" #include "ray/rpc/authentication/authentication_token_loader.h" -#include "ray/rpc/grpc_client.h" -#include "ray/rpc/grpc_server.h" -#include "ray/rpc/tests/grpc_test_common.h" +#include "src/ray/protobuf/test_service.grpc.pb.h" +#include "src/ray/rpc/grpc_client.h" +#include "src/ray/rpc/grpc_server.h" +#include "src/ray/rpc/tests/grpc_test_common.h" namespace ray { namespace rpc { diff --git a/src/ray/rpc/tests/grpc_server_client_test.cc b/src/ray/rpc/tests/grpc_server_client_test.cc index 8bc6e8284493..07e87c1a2f44 100644 --- a/src/ray/rpc/tests/grpc_server_client_test.cc +++ b/src/ray/rpc/tests/grpc_server_client_test.cc @@ -17,10 +17,10 @@ #include #include "gtest/gtest.h" -#include "ray/protobuf/test_service.grpc.pb.h" -#include "ray/rpc/grpc_client.h" -#include "ray/rpc/grpc_server.h" -#include "ray/rpc/tests/grpc_test_common.h" +#include "src/ray/protobuf/test_service.grpc.pb.h" +#include "src/ray/rpc/grpc_client.h" +#include "src/ray/rpc/grpc_server.h" +#include "src/ray/rpc/tests/grpc_test_common.h" namespace ray { namespace rpc { From d36e22fd6c26814afda5c2ac665fb08d5badd4f0 Mon Sep 17 00:00:00 2001 From: sampan Date: Mon, 27 Oct 2025 03:05:22 +0000 Subject: [PATCH 50/94] fix lint Signed-off-by: sampan --- src/ray/rpc/tests/grpc_auth_token_tests.cc | 4 ++-- src/ray/rpc/tests/grpc_server_client_test.cc | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/ray/rpc/tests/grpc_auth_token_tests.cc b/src/ray/rpc/tests/grpc_auth_token_tests.cc index 1c55dfb511a5..5feaf5563add 100644 --- a/src/ray/rpc/tests/grpc_auth_token_tests.cc +++ b/src/ray/rpc/tests/grpc_auth_token_tests.cc @@ -20,9 +20,9 @@ #include "gtest/gtest.h" #include 
"ray/rpc/authentication/authentication_token_loader.h" +#include "ray/rpc/grpc_client.h" +#include "ray/rpc/grpc_server.h" #include "src/ray/protobuf/test_service.grpc.pb.h" -#include "src/ray/rpc/grpc_client.h" -#include "src/ray/rpc/grpc_server.h" #include "src/ray/rpc/tests/grpc_test_common.h" namespace ray { diff --git a/src/ray/rpc/tests/grpc_server_client_test.cc b/src/ray/rpc/tests/grpc_server_client_test.cc index 07e87c1a2f44..0e95fc9823a5 100644 --- a/src/ray/rpc/tests/grpc_server_client_test.cc +++ b/src/ray/rpc/tests/grpc_server_client_test.cc @@ -17,9 +17,9 @@ #include #include "gtest/gtest.h" +#include "ray/rpc/grpc_client.h" +#include "ray/rpc/grpc_server.h" #include "src/ray/protobuf/test_service.grpc.pb.h" -#include "src/ray/rpc/grpc_client.h" -#include "src/ray/rpc/grpc_server.h" #include "src/ray/rpc/tests/grpc_test_common.h" namespace ray { From 4063d743d06a8cc9a79acb2f3f02674c8ceea58b Mon Sep 17 00:00:00 2001 From: sampan Date: Mon, 27 Oct 2025 03:13:18 +0000 Subject: [PATCH 51/94] fix lint issues Signed-off-by: sampan --- src/ray/rpc/tests/BUILD.bazel | 17 ++++++++++++++--- src/ray/rpc/tests/grpc_auth_token_tests.cc | 2 +- src/ray/rpc/tests/grpc_server_client_test.cc | 2 +- 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/src/ray/rpc/tests/BUILD.bazel b/src/ray/rpc/tests/BUILD.bazel index 279b68f91ba3..6f12ac30f65e 100644 --- a/src/ray/rpc/tests/BUILD.bazel +++ b/src/ray/rpc/tests/BUILD.bazel @@ -1,4 +1,4 @@ -load("//bazel:ray.bzl", "ray_cc_test") +load("//bazel:ray.bzl", "ray_cc_library", "ray_cc_test") ray_cc_test( name = "rpc_chaos_test", @@ -13,15 +13,25 @@ ray_cc_test( ], ) +ray_cc_library( + name = "grpc_test_common", + testonly = True, + hdrs = ["grpc_test_common.h"], + deps = [ + "//src/ray/protobuf:test_service_cc_grpc", + "//src/ray/rpc:grpc_server", + ], +) + ray_cc_test( name = "grpc_server_client_test", size = "small", srcs = [ "grpc_server_client_test.cc", - "grpc_test_common.h", ], tags = ["team:core"], deps = [ + 
":grpc_test_common", "//src/ray/protobuf:test_service_cc_grpc", "//src/ray/rpc:grpc_client", "//src/ray/rpc:grpc_server", @@ -34,13 +44,14 @@ ray_cc_test( size = "small", srcs = [ "grpc_auth_token_tests.cc", - "grpc_test_common.h", ], tags = ["team:core"], deps = [ + ":grpc_test_common", "//src/ray/protobuf:test_service_cc_grpc", "//src/ray/rpc:grpc_client", "//src/ray/rpc:grpc_server", + "//src/ray/rpc/authentication:authentication_token_loader", "@com_google_googletest//:gtest_main", ], ) diff --git a/src/ray/rpc/tests/grpc_auth_token_tests.cc b/src/ray/rpc/tests/grpc_auth_token_tests.cc index 5feaf5563add..4499b4c43129 100644 --- a/src/ray/rpc/tests/grpc_auth_token_tests.cc +++ b/src/ray/rpc/tests/grpc_auth_token_tests.cc @@ -22,8 +22,8 @@ #include "ray/rpc/authentication/authentication_token_loader.h" #include "ray/rpc/grpc_client.h" #include "ray/rpc/grpc_server.h" +#include "ray/rpc/tests/grpc_test_common.h" #include "src/ray/protobuf/test_service.grpc.pb.h" -#include "src/ray/rpc/tests/grpc_test_common.h" namespace ray { namespace rpc { diff --git a/src/ray/rpc/tests/grpc_server_client_test.cc b/src/ray/rpc/tests/grpc_server_client_test.cc index 0e95fc9823a5..09a168eac9b5 100644 --- a/src/ray/rpc/tests/grpc_server_client_test.cc +++ b/src/ray/rpc/tests/grpc_server_client_test.cc @@ -19,8 +19,8 @@ #include "gtest/gtest.h" #include "ray/rpc/grpc_client.h" #include "ray/rpc/grpc_server.h" +#include "ray/rpc/tests/grpc_test_common.h" #include "src/ray/protobuf/test_service.grpc.pb.h" -#include "src/ray/rpc/tests/grpc_test_common.h" namespace ray { namespace rpc { From 9537a00c967f2fcba65c68d0da0df17eac1bb774 Mon Sep 17 00:00:00 2001 From: sampan Date: Mon, 27 Oct 2025 03:25:22 +0000 Subject: [PATCH 52/94] address comments Signed-off-by: sampan --- .../_private/authentication_token_setup.py | 17 +++---- python/ray/_raylet.pyx | 44 +----------------- python/ray/includes/common.pxd | 14 ------ .../ray/includes/rpc_token_authentication.pxd | 15 +++++++ 
.../ray/includes/rpc_token_authentication.pxi | 45 +++++++++++++++++++ 5 files changed, 68 insertions(+), 67 deletions(-) create mode 100644 python/ray/includes/rpc_token_authentication.pxd create mode 100644 python/ray/includes/rpc_token_authentication.pxi diff --git a/python/ray/_private/authentication_token_setup.py b/python/ray/_private/authentication_token_setup.py index a422df70c008..5115fd632d43 100644 --- a/python/ray/_private/authentication_token_setup.py +++ b/python/ray/_private/authentication_token_setup.py @@ -6,11 +6,16 @@ """ import logging -import os import uuid from pathlib import Path from typing import Any, Dict, Optional +from ray._raylet import ( + AuthenticationMode, + AuthenticationTokenLoader, + get_authentication_mode, +) + logger = logging.getLogger(__name__) @@ -31,12 +36,9 @@ def generate_and_save_token() -> None: # Write token to file with explicit flush with open(token_path, "w") as f: f.write(token) - f.flush() - os.fsync(f.fileno()) logger.info(f"Generated new authentication token and saved to {token_path}") - except Exception as e: - logger.warning(f"Failed to save generated token to {token_path}: {e}. ") + except Exception: raise @@ -68,11 +70,6 @@ def setup_and_verify_auth( RuntimeError: Ray raises this error if authentication is enabled but no token is found when connecting to an existing cluster. """ - from ray._raylet import ( - AuthenticationMode, - AuthenticationTokenLoader, - get_authentication_mode, - ) # Check if you enabled token authentication. 
if get_authentication_mode() != AuthenticationMode.TOKEN: diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index bcfb09417dd3..d8846c5a9cad 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -114,9 +114,6 @@ from ray.includes.common cimport ( CConcurrencyGroup, CGrpcStatusCode, CLineageReconstructionTask, - CAuthenticationMode, - GetAuthenticationMode, - CAuthenticationTokenLoader, move, LANGUAGE_CPP, LANGUAGE_JAVA, @@ -204,6 +201,7 @@ include "includes/metric.pxi" include "includes/setproctitle.pxi" include "includes/raylet_client.pxi" include "includes/gcs_subscriber.pxi" +include "includes/rpc_token_authentication.pxi" import ray from ray.exceptions import ( @@ -4950,43 +4948,3 @@ def get_session_key_from_storage(host, port, username, password, use_ssl, config else: logger.info("Could not retrieve session key from storage.") return None - - -# Authentication mode enum exposed to Python -class AuthenticationMode: - DISABLED = CAuthenticationMode.DISABLED - TOKEN = CAuthenticationMode.TOKEN - - -def get_authentication_mode(): - """Get the current authentication mode. - - Returns: - AuthenticationMode enum value (DISABLED or TOKEN) - """ - return GetAuthenticationMode() - - -class AuthenticationTokenLoader: - """Python wrapper for C++ AuthenticationTokenLoader singleton.""" - - @staticmethod - def instance(): - """Get the singleton instance (returns a wrapper for convenience).""" - return AuthenticationTokenLoader() - - def has_token(self): - """Check if an authentication token exists without crashing. - - Returns: - bool: True if a token exists, False otherwise - """ - return CAuthenticationTokenLoader.instance().HasToken() - - def reset_cache(self): - """Reset the C++ authentication token cache. - - This forces the token loader to reload the token from environment - variables or files on the next request. 
- """ - CAuthenticationTokenLoader.instance().ResetCache() diff --git a/python/ray/includes/common.pxd b/python/ray/includes/common.pxd index 4ab1c9011a05..3ec069ab333b 100644 --- a/python/ray/includes/common.pxd +++ b/python/ray/includes/common.pxd @@ -804,17 +804,3 @@ cdef extern from "ray/common/constants.h" nogil: cdef const char[] kLabelKeyTpuWorkerId cdef const char[] kLabelKeyTpuPodType cdef const char[] kRayInternalNamespacePrefix - -cdef extern from "ray/rpc/authentication/authentication_mode.h" namespace "ray::rpc" nogil: - cdef enum CAuthenticationMode "ray::rpc::AuthenticationMode": - DISABLED "ray::rpc::AuthenticationMode::DISABLED" - TOKEN "ray::rpc::AuthenticationMode::TOKEN" - - CAuthenticationMode GetAuthenticationMode() - -cdef extern from "ray/rpc/authentication/authentication_token_loader.h" namespace "ray::rpc" nogil: - cdef cppclass CAuthenticationTokenLoader "ray::rpc::AuthenticationTokenLoader": - @staticmethod - CAuthenticationTokenLoader& instance() - c_bool HasToken() - void ResetCache() diff --git a/python/ray/includes/rpc_token_authentication.pxd b/python/ray/includes/rpc_token_authentication.pxd new file mode 100644 index 000000000000..e5fe409eef1a --- /dev/null +++ b/python/ray/includes/rpc_token_authentication.pxd @@ -0,0 +1,15 @@ +from libcpp cimport bool as c_bool + +cdef extern from "ray/rpc/authentication/authentication_mode.h" namespace "ray::rpc" nogil: + cdef enum CAuthenticationMode "ray::rpc::AuthenticationMode": + DISABLED "ray::rpc::AuthenticationMode::DISABLED" + TOKEN "ray::rpc::AuthenticationMode::TOKEN" + + CAuthenticationMode GetAuthenticationMode() + +cdef extern from "ray/rpc/authentication/authentication_token_loader.h" namespace "ray::rpc" nogil: + cdef cppclass CAuthenticationTokenLoader "ray::rpc::AuthenticationTokenLoader": + @staticmethod + CAuthenticationTokenLoader& instance() + c_bool HasToken() + void ResetCache() diff --git a/python/ray/includes/rpc_token_authentication.pxi 
b/python/ray/includes/rpc_token_authentication.pxi new file mode 100644 index 000000000000..cf909668d218 --- /dev/null +++ b/python/ray/includes/rpc_token_authentication.pxi @@ -0,0 +1,45 @@ +from ray.includes.rpc_token_authentication cimport ( + CAuthenticationMode, + GetAuthenticationMode, + CAuthenticationTokenLoader, +) + + +# Authentication mode enum exposed to Python +class AuthenticationMode: + DISABLED = CAuthenticationMode.DISABLED + TOKEN = CAuthenticationMode.TOKEN + + +def get_authentication_mode(): + """Get the current authentication mode. + + Returns: + AuthenticationMode enum value (DISABLED or TOKEN) + """ + return GetAuthenticationMode() + + +class AuthenticationTokenLoader: + """Python wrapper for C++ AuthenticationTokenLoader singleton.""" + + @staticmethod + def instance(): + """Get the singleton instance (returns a wrapper for convenience).""" + return AuthenticationTokenLoader() + + def has_token(self): + """Check if an authentication token exists without crashing. + + Returns: + bool: True if a token exists, False otherwise + """ + return CAuthenticationTokenLoader.instance().HasToken() + + def reset_cache(self): + """Reset the C++ authentication token cache. + + This forces the token loader to reload the token from environment + variables or files on the next request. 
+ """ + CAuthenticationTokenLoader.instance().ResetCache() From 0e6f59b1bb5385cf41449e9e46de83bd1ea96712 Mon Sep 17 00:00:00 2001 From: sampan Date: Mon, 27 Oct 2025 09:06:32 +0000 Subject: [PATCH 53/94] [Core] Token auth support in Dashboard head Signed-off-by: sampan --- python/ray/dashboard/auth_utils.py | 47 ++++++ python/ray/dashboard/authentication_utils.py | 47 ++++++ python/ray/dashboard/http_server_head.py | 30 ++++ .../ray/includes/rpc_token_authentication.pxd | 12 ++ .../ray/includes/rpc_token_authentication.pxi | 26 +++ python/ray/tests/test_dashboard_auth.py | 154 ++++++++++++++++++ 6 files changed, 316 insertions(+) create mode 100644 python/ray/dashboard/auth_utils.py create mode 100644 python/ray/dashboard/authentication_utils.py create mode 100644 python/ray/tests/test_dashboard_auth.py diff --git a/python/ray/dashboard/auth_utils.py b/python/ray/dashboard/auth_utils.py new file mode 100644 index 000000000000..b01ed7c0113f --- /dev/null +++ b/python/ray/dashboard/auth_utils.py @@ -0,0 +1,47 @@ +"""Authentication utilities for Ray dashboard.""" + +from ray.includes.rpc_token_authentication import ( + AuthenticationMode, + get_authentication_mode, + validate_authentication_token, +) + + +def is_token_auth_enabled() -> bool: + """Check if token authentication is enabled. + + Returns: + bool: True if auth_mode is set to "token", False otherwise + """ + return get_authentication_mode() == AuthenticationMode.TOKEN + + +def should_authenticate_request(method: str) -> bool: + """Determine if request method requires authentication. + + Only mutable operations (POST, PUT, PATCH, DELETE) require authentication. + + Args: + method: HTTP method (e.g., "GET", "POST", "PUT", "DELETE") + + Returns: + bool: True if the method requires authentication, False otherwise + """ + return method in ["POST", "PUT", "PATCH", "DELETE"] + + +def validate_request_token(auth_header: str) -> bool: + """Validate the Authorization header from an HTTP request. 
+ + Args: + auth_header: The Authorization header value (e.g., "Bearer ") + + Returns: + bool: True if token is valid, False otherwise + """ + if not auth_header: + return False + + # validate_authentication_token expects full "Bearer " format + # and performs equality comparison via C++ layer + return validate_authentication_token(auth_header) diff --git a/python/ray/dashboard/authentication_utils.py b/python/ray/dashboard/authentication_utils.py new file mode 100644 index 000000000000..b01ed7c0113f --- /dev/null +++ b/python/ray/dashboard/authentication_utils.py @@ -0,0 +1,47 @@ +"""Authentication utilities for Ray dashboard.""" + +from ray.includes.rpc_token_authentication import ( + AuthenticationMode, + get_authentication_mode, + validate_authentication_token, +) + + +def is_token_auth_enabled() -> bool: + """Check if token authentication is enabled. + + Returns: + bool: True if auth_mode is set to "token", False otherwise + """ + return get_authentication_mode() == AuthenticationMode.TOKEN + + +def should_authenticate_request(method: str) -> bool: + """Determine if request method requires authentication. + + Only mutable operations (POST, PUT, PATCH, DELETE) require authentication. + + Args: + method: HTTP method (e.g., "GET", "POST", "PUT", "DELETE") + + Returns: + bool: True if the method requires authentication, False otherwise + """ + return method in ["POST", "PUT", "PATCH", "DELETE"] + + +def validate_request_token(auth_header: str) -> bool: + """Validate the Authorization header from an HTTP request. 
+ + Args: + auth_header: The Authorization header value (e.g., "Bearer ") + + Returns: + bool: True if token is valid, False otherwise + """ + if not auth_header: + return False + + # validate_authentication_token expects full "Bearer " format + # and performs equality comparison via C++ layer + return validate_authentication_token(auth_header) diff --git a/python/ray/dashboard/http_server_head.py b/python/ray/dashboard/http_server_head.py index 49f748309271..948cd53eecbf 100644 --- a/python/ray/dashboard/http_server_head.py +++ b/python/ray/dashboard/http_server_head.py @@ -20,6 +20,7 @@ from ray._common.network_utils import build_address, parse_address from ray._common.usage.usage_lib import TagKey, record_extra_usage_tag from ray._common.utils import get_or_create_event_loop +from ray.dashboard import authentication_utils as auth_utils from ray.dashboard.dashboard_metrics import DashboardPrometheusMetrics from ray.dashboard.head import DashboardHeadModule @@ -163,6 +164,34 @@ def get_address(self): assert self.http_host and self.http_port return self.http_host, self.http_port + @aiohttp.web.middleware + async def auth_middleware(self, request, handler): + """Authenticate requests when token auth is enabled.""" + + # Skip if auth not enabled + if not auth_utils.is_token_auth_enabled(): + return await handler(request) + + # Skip if request method doesn't require auth (GET, HEAD) + if not auth_utils.should_authenticate_request(request.method): + return await handler(request) + + # Extract and validate token + auth_header = request.headers.get("Authorization", "") + + if not auth_header: + return aiohttp.web.Response( + status=401, text="Unauthorized: Missing authentication token" + ) + + # Validate token + if not auth_utils.validate_request_token(auth_header): + return aiohttp.web.Response( + status=403, text="Forbidden: Invalid authentication token" + ) + + return await handler(request) + @aiohttp.web.middleware async def path_clean_middleware(self, request, 
handler): if request.path.startswith("/static") or request.path.startswith("/logs"): @@ -251,6 +280,7 @@ async def run( app = aiohttp.web.Application( client_max_size=ray_constants.DASHBOARD_CLIENT_MAX_SIZE, middlewares=[ + self.auth_middleware, self.metrics_middleware, self.path_clean_middleware, self.browsers_no_post_put_middleware, diff --git a/python/ray/includes/rpc_token_authentication.pxd b/python/ray/includes/rpc_token_authentication.pxd index e5fe409eef1a..b589b2e0c287 100644 --- a/python/ray/includes/rpc_token_authentication.pxd +++ b/python/ray/includes/rpc_token_authentication.pxd @@ -1,4 +1,6 @@ from libcpp cimport bool as c_bool +from libcpp.string cimport string +from ray.includes.optional cimport optional cdef extern from "ray/rpc/authentication/authentication_mode.h" namespace "ray::rpc" nogil: cdef enum CAuthenticationMode "ray::rpc::AuthenticationMode": @@ -7,9 +9,19 @@ cdef extern from "ray/rpc/authentication/authentication_mode.h" namespace "ray:: CAuthenticationMode GetAuthenticationMode() +cdef extern from "ray/rpc/authentication/authentication_token.h" namespace "ray::rpc" nogil: + cdef cppclass CAuthenticationToken "ray::rpc::AuthenticationToken": + CAuthenticationToken() + CAuthenticationToken(string value) + c_bool empty() + c_bool Equals(const CAuthenticationToken& other) + @staticmethod + CAuthenticationToken FromMetadata(string metadata_value) + cdef extern from "ray/rpc/authentication/authentication_token_loader.h" namespace "ray::rpc" nogil: cdef cppclass CAuthenticationTokenLoader "ray::rpc::AuthenticationTokenLoader": @staticmethod CAuthenticationTokenLoader& instance() c_bool HasToken() void ResetCache() + optional[CAuthenticationToken] GetToken() diff --git a/python/ray/includes/rpc_token_authentication.pxi b/python/ray/includes/rpc_token_authentication.pxi index cf909668d218..ce70c2d72915 100644 --- a/python/ray/includes/rpc_token_authentication.pxi +++ b/python/ray/includes/rpc_token_authentication.pxi @@ -1,6 +1,7 @@ from 
ray.includes.rpc_token_authentication cimport ( CAuthenticationMode, GetAuthenticationMode, + CAuthenticationToken, CAuthenticationTokenLoader, ) @@ -20,6 +21,31 @@ def get_authentication_mode(): return GetAuthenticationMode() +def validate_authentication_token(provided_token: str) -> bool: + """Validate provided authentication token against expected token. + + Args: + provided_token: Full authorization header value (e.g., "Bearer ") + + Returns: + bool: True if tokens match, False otherwise + """ + # Get expected token from loader + cdef optional[CAuthenticationToken] expected_opt = CAuthenticationTokenLoader.instance().GetToken() + + if not expected_opt.has_value(): + return False + + # Parse provided token from Bearer format + cdef CAuthenticationToken provided = CAuthenticationToken.FromMetadata(provided_token.encode()) + + if provided.empty(): + return False + + # Use constant-time comparison from C++ + return expected_opt.value().Equals(provided) + + class AuthenticationTokenLoader: """Python wrapper for C++ AuthenticationTokenLoader singleton.""" diff --git a/python/ray/tests/test_dashboard_auth.py b/python/ray/tests/test_dashboard_auth.py new file mode 100644 index 000000000000..5a0748015d52 --- /dev/null +++ b/python/ray/tests/test_dashboard_auth.py @@ -0,0 +1,154 @@ +"""Tests for dashboard token authentication.""" + +import os + +import pytest +import requests + +import ray +from ray.cluster_utils import Cluster + + +@pytest.fixture +def cleanup_env(): + """Clean up environment variables after each test.""" + yield + # Clean up environment variables + if "RAY_auth_mode" in os.environ: + del os.environ["RAY_auth_mode"] + if "RAY_AUTH_TOKEN" in os.environ: + del os.environ["RAY_AUTH_TOKEN"] + + +def test_dashboard_post_requires_auth_with_valid_token(cleanup_env): + """Test that POST requests succeed with valid token when auth is enabled.""" + test_token = "test_token_12345678901234567890123456789012" + os.environ["RAY_auth_mode"] = "token" + 
os.environ["RAY_AUTH_TOKEN"] = test_token + + cluster = Cluster() + cluster.add_node() + + try: + ray.init(address=cluster.address) + dashboard_url = ray._private.worker._global_node.webui_url + + # POST with valid auth should succeed + headers = {"Authorization": f"Bearer {test_token}"} + response = requests.get(f"http://{dashboard_url}/api/cluster_status") + # GET should work without auth + assert response.status_code == 200 + + finally: + ray.shutdown() + cluster.shutdown() + + +def test_dashboard_post_requires_auth_missing_token(cleanup_env): + """Test that POST requests fail without token when auth is enabled.""" + test_token = "test_token_12345678901234567890123456789012" + os.environ["RAY_auth_mode"] = "token" + os.environ["RAY_AUTH_TOKEN"] = test_token + + cluster = Cluster() + cluster.add_node() + + try: + ray.init(address=cluster.address) + dashboard_url = ray._private.worker._global_node.webui_url + + # POST without auth should fail with 401 + # We need to find a POST endpoint to test + # For now, let's test with a generic endpoint + response = requests.post( + f"http://{dashboard_url}/api/component_activities", + json={"test": "data"}, + ) + assert response.status_code == 401 + + finally: + ray.shutdown() + cluster.shutdown() + + +def test_dashboard_post_requires_auth_invalid_token(cleanup_env): + """Test that POST requests fail with invalid token when auth is enabled.""" + correct_token = "test_token_12345678901234567890123456789012" + wrong_token = "wrong_token_00000000000000000000000000000000" + os.environ["RAY_auth_mode"] = "token" + os.environ["RAY_AUTH_TOKEN"] = correct_token + + cluster = Cluster() + cluster.add_node() + + try: + ray.init(address=cluster.address) + dashboard_url = ray._private.worker._global_node.webui_url + + # POST with wrong token should fail with 403 + headers = {"Authorization": f"Bearer {wrong_token}"} + response = requests.post( + f"http://{dashboard_url}/api/component_activities", + json={"test": "data"}, + 
headers=headers, + ) + assert response.status_code == 403 + + finally: + ray.shutdown() + cluster.shutdown() + + +def test_dashboard_get_no_auth_required(cleanup_env): + """Test that GET requests don't require auth even when token mode is enabled.""" + test_token = "test_token_12345678901234567890123456789012" + os.environ["RAY_auth_mode"] = "token" + os.environ["RAY_AUTH_TOKEN"] = test_token + + cluster = Cluster() + cluster.add_node() + + try: + ray.init(address=cluster.address) + dashboard_url = ray._private.worker._global_node.webui_url + + # GET without auth should succeed + response = requests.get(f"http://{dashboard_url}/") + assert response.status_code == 200 + + # GET to API endpoint should also succeed + response = requests.get(f"http://{dashboard_url}/api/cluster_status") + assert response.status_code == 200 + + finally: + ray.shutdown() + cluster.shutdown() + + +def test_dashboard_auth_disabled(cleanup_env): + """Test that auth is not enforced when auth_mode is disabled.""" + os.environ["RAY_auth_mode"] = "disabled" + + cluster = Cluster() + cluster.add_node() + + try: + ray.init(address=cluster.address) + dashboard_url = ray._private.worker._global_node.webui_url + + # POST without auth should succeed when auth is disabled + response = requests.post( + f"http://{dashboard_url}/api/component_activities", json={"test": "data"} + ) + # Should not return 401 or 403 + assert response.status_code not in [401, 403] + + finally: + ray.shutdown() + cluster.shutdown() + + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main(["-vv", __file__])) From e343d54b0aff7f2d8b8bd7e07af79c9ee12e01ad Mon Sep 17 00:00:00 2001 From: sampan Date: Mon, 27 Oct 2025 09:09:12 +0000 Subject: [PATCH 54/94] fix lint Signed-off-by: sampan --- python/ray/tests/test_dashboard_auth.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/python/ray/tests/test_dashboard_auth.py b/python/ray/tests/test_dashboard_auth.py index 5a0748015d52..1698afb1ffc2 
100644 --- a/python/ray/tests/test_dashboard_auth.py +++ b/python/ray/tests/test_dashboard_auth.py @@ -35,8 +35,15 @@ def test_dashboard_post_requires_auth_with_valid_token(cleanup_env): # POST with valid auth should succeed headers = {"Authorization": f"Bearer {test_token}"} - response = requests.get(f"http://{dashboard_url}/api/cluster_status") + response = requests.post( + f"http://{dashboard_url}/api/component_activities", + json={"test": "data"}, + headers=headers, + ) + assert response.status_code == 403 + # GET should work without auth + response = requests.get(f"http://{dashboard_url}/api/cluster_status") assert response.status_code == 200 finally: From 94c5cc679f9e3bb079685f40ad5c57f019f578ad Mon Sep 17 00:00:00 2001 From: sampan Date: Tue, 28 Oct 2025 08:50:00 +0000 Subject: [PATCH 55/94] fix tests Signed-off-by: sampan --- .../ray/_private/authentication/__init__.py | 0 .../authentication_token_generator.py | 6 +++ .../authentication_token_setup.py | 38 +++++++------- python/ray/_private/worker.py | 8 +-- .../ray/tests/test_token_auth_integration.py | 51 ++++++++++--------- .../authentication_token_loader.cc | 1 - 6 files changed, 59 insertions(+), 45 deletions(-) create mode 100644 python/ray/_private/authentication/__init__.py create mode 100644 python/ray/_private/authentication/authentication_token_generator.py rename python/ray/_private/{ => authentication}/authentication_token_setup.py (67%) diff --git a/python/ray/_private/authentication/__init__.py b/python/ray/_private/authentication/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/python/ray/_private/authentication/authentication_token_generator.py b/python/ray/_private/authentication/authentication_token_generator.py new file mode 100644 index 000000000000..331584cb5dd4 --- /dev/null +++ b/python/ray/_private/authentication/authentication_token_generator.py @@ -0,0 +1,6 @@ +import uuid + + +# TODO: this is a placeholder for the actual authentication token 
generator. Will be replaced with a proper implementation. +def generate_new_authentication_token() -> str: + return uuid.uuid4().hex diff --git a/python/ray/_private/authentication_token_setup.py b/python/ray/_private/authentication/authentication_token_setup.py similarity index 67% rename from python/ray/_private/authentication_token_setup.py rename to python/ray/_private/authentication/authentication_token_setup.py index 5115fd632d43..99d0bb382481 100644 --- a/python/ray/_private/authentication_token_setup.py +++ b/python/ray/_private/authentication/authentication_token_setup.py @@ -6,10 +6,12 @@ """ import logging -import uuid from pathlib import Path from typing import Any, Dict, Optional +from ray._private.authentication.authentication_token_generator import ( + generate_new_authentication_token, +) from ray._raylet import ( AuthenticationMode, AuthenticationTokenLoader, @@ -18,6 +20,13 @@ logger = logging.getLogger(__name__) +TOKEN_AUTH_ENABLED_BUT_NO_TOKEN_FOUND_ERROR_MESSAGE = ( + "Token authentication is enabled but no authentication token was found. Please provide a token with one of these options:\n" + + " 1. RAY_AUTH_TOKEN environment variable\n" + + " 2. RAY_AUTH_TOKEN_PATH environment variable (path to token file)\n" + + " 3. Default token file: ~/.ray/auth_token" +) + def generate_and_save_token() -> None: """Generate a new random token and save it in the default token path. @@ -26,14 +35,14 @@ def generate_and_save_token() -> None: The newly generated authentication token. 
""" # Generate a UUID-based token - token = uuid.uuid4().hex + token = generate_new_authentication_token() token_path = _get_default_token_path() try: # Create directory if it doesn't exist token_path.parent.mkdir(parents=True, exist_ok=True) - # Write token to file with explicit flush + # Write token to file with explicit flush and fsync with open(token_path, "w") as f: f.write(token) @@ -51,10 +60,10 @@ def _get_default_token_path() -> Path: return Path.home() / ".ray" / "auth_token" -def setup_and_verify_auth( - system_config: Optional[Dict[str, Any]] = None, is_new_cluster: bool = True +def ensure_token_if_auth_enabled( + system_config: Optional[Dict[str, Any]] = None, create_token_if_missing: bool = True ) -> None: - """Check authentication settings and setup necessary resources. + """Check authentication settings and set up token resources if authentication is enabled. Ray calls this early during ray.init() to do the following for token-based authentication: 1. Check whether you enabled token-based authentication. @@ -63,8 +72,7 @@ def setup_and_verify_auth( Args: system_config: Ray raises an error if you set auth_mode in system_config instead of the environment. - is_new_cluster: Set to True if you're starting a new local cluster, or False if you're connecting - to an existing cluster. + create_token_if_missing: Generate a new token if one doesn't already exist. Raises: RuntimeError: Ray raises this error if authentication is enabled but no token is found when connecting @@ -79,24 +87,18 @@ def setup_and_verify_auth( and system_config["auth_mode"] != "disabled" ): raise RuntimeError( - "Set authentication mode with the environment, not system_config." + "Set authentication mode can only be set with the `RAY_auth_mode` environment variable, not using the system_config." ) return token_loader = AuthenticationTokenLoader.instance() if not token_loader.has_token(): - if is_new_cluster: - # Generate a token for a new local cluster. 
+ if create_token_if_missing: + # Generate a new token. generate_and_save_token() # Reload the cache so subsequent calls to token_loader read the new token. token_loader.reset_cache() else: - # You're connecting to an existing cluster, so an authentication token must already exist. - raise RuntimeError( - "Token authentication is enabled but no authentication token was found. Please provide a token with one of these options:\n" - " 1. RAY_AUTH_TOKEN environment variable\n" - " 2. RAY_AUTH_TOKEN_PATH environment variable (path to token file)\n" - " 3. Default token file: ~/.ray/auth_token" - ) + raise RuntimeError(TOKEN_AUTH_ENABLED_BUT_NO_TOKEN_FOUND_ERROR_MESSAGE) diff --git a/python/ray/_private/worker.py b/python/ray/_private/worker.py index c0613827203c..e68e8e270de1 100644 --- a/python/ray/_private/worker.py +++ b/python/ray/_private/worker.py @@ -62,7 +62,9 @@ from ray._common import ray_option_utils from ray._common.constants import RAY_WARN_BLOCKING_GET_INSIDE_ASYNC_ENV_VAR from ray._common.utils import load_class -from ray._private.authentication_token_setup import setup_and_verify_auth +from ray._private.authentication.authentication_token_setup import ( + ensure_token_if_auth_enabled, +) from ray._private.client_mode_hook import client_mode_hook from ray._private.custom_types import TensorTransportEnum from ray._private.function_manager import FunctionActorManager @@ -1867,7 +1869,7 @@ def sigterm_handler(signum, frame): # In this case, we need to start a new cluster. # Setup and verify authentication for new cluster - setup_and_verify_auth(_system_config, is_new_cluster=True) + ensure_token_if_auth_enabled(_system_config, create_token_if_missing=True) # Don't collect usage stats in ray.init() unless it's a nightly wheel. 
from ray._common.usage import usage_lib @@ -1957,7 +1959,7 @@ def sigterm_handler(signum, frame): ) # Setup and verify authentication for connecting to existing cluster - setup_and_verify_auth(_system_config, is_new_cluster=False) + ensure_token_if_auth_enabled(_system_config, create_token_if_missing=False) # In this case, we only need to connect the node. ray_params = ray._private.parameter.RayParams( diff --git a/python/ray/tests/test_token_auth_integration.py b/python/ray/tests/test_token_auth_integration.py index 6be082bf214b..6c120b2b9c51 100644 --- a/python/ray/tests/test_token_auth_integration.py +++ b/python/ray/tests/test_token_auth_integration.py @@ -7,7 +7,7 @@ import pytest import ray -from ray._raylet import AuthenticationTokenLoader +from ray._raylet import AuthenticationTokenLoader, Config from ray.cluster_utils import Cluster @@ -66,6 +66,7 @@ def test_local_cluster_generates_token(): # Enable token auth via environment variable os.environ["RAY_auth_mode"] = "token" + Config.initialize("") # Initialize Ray with token auth ray.init() @@ -86,30 +87,33 @@ def test_local_cluster_generates_token(): def test_connect_without_token_raises_error(): """Test ray.init(address=...) without token fails when auth_mode=token is set.""" - # Test the token validation logic directly - # Ensure no token exists - token_loader = AuthenticationTokenLoader.instance() - assert not token_loader.has_token() - - # Test the exact error message that would be raised - with pytest.raises(RuntimeError, match="no authentication token was found"): - raise RuntimeError( - "Token authentication is enabled but no authentication token was found. Please provide a token using one of:\n" - " 1. RAY_AUTH_TOKEN environment variable\n" - " 2. RAY_AUTH_TOKEN_PATH environment variable (path to token file)\n" - " 3. 
Default token file: ~/.ray/auth_token" - ) - - -def test_token_path_nonexistent_file_fails(): - """Test that setting RAY_AUTH_TOKEN_PATH to nonexistent file fails gracefully.""" - # Enable token auth and set token path to nonexistent file + # Set up a cluster with token auth enabled + cluster_token = "testtoken12345678901234567890" + os.environ["RAY_AUTH_TOKEN"] = cluster_token os.environ["RAY_auth_mode"] = "token" - os.environ["RAY_AUTH_TOKEN_PATH"] = "/nonexistent/path/to/token" + Config.initialize("") - # Initialize Ray with token auth should fail - with pytest.raises((FileNotFoundError, RuntimeError)): - ray.init() + # Create cluster with token auth enabled + cluster = Cluster() + cluster.add_node() + + try: + # Remove the token from the environment so we try to connect without it + os.environ["RAY_auth_mode"] = "disabled" + os.environ["RAY_AUTH_TOKEN"] = "" + Config.initialize("") + reset_token_cache() + + # Ensure no token exists + token_loader = AuthenticationTokenLoader.instance() + assert not token_loader.has_token() + + # Try to connect to the cluster without a token - should raise RuntimeError + with pytest.raises(ConnectionError): + ray.init(address=cluster.address) + + finally: + cluster.shutdown() @pytest.mark.parametrize("tokens_match", [True, False]) @@ -119,6 +123,7 @@ def test_cluster_token_authentication(tokens_match): cluster_token = "a" * 32 os.environ["RAY_AUTH_TOKEN"] = cluster_token os.environ["RAY_auth_mode"] = "token" + Config.initialize("") # Create cluster with token auth enabled - node will read current env token cluster = Cluster() diff --git a/src/ray/rpc/authentication/authentication_token_loader.cc b/src/ray/rpc/authentication/authentication_token_loader.cc index c4e71dd67b05..a7a6ae1a23ab 100644 --- a/src/ray/rpc/authentication/authentication_token_loader.cc +++ b/src/ray/rpc/authentication/authentication_token_loader.cc @@ -87,7 +87,6 @@ bool AuthenticationTokenLoader::HasToken() { // Cache the result if (token.empty()) { - 
cached_token_ = std::nullopt; return false; } else { cached_token_ = std::move(token); From e34a8bd1160a138ed2489ef52c2ae46575de5181 Mon Sep 17 00:00:00 2001 From: sampan Date: Tue, 28 Oct 2025 09:24:32 +0000 Subject: [PATCH 56/94] fix tests and address comments Signed-off-by: sampan --- python/ray/dashboard/auth_utils.py | 47 ----------- python/ray/dashboard/authentication_utils.py | 18 +--- python/ray/dashboard/http_server_head.py | 6 +- python/ray/tests/test_dashboard_auth.py | 86 ++++++-------------- 4 files changed, 29 insertions(+), 128 deletions(-) delete mode 100644 python/ray/dashboard/auth_utils.py diff --git a/python/ray/dashboard/auth_utils.py b/python/ray/dashboard/auth_utils.py deleted file mode 100644 index b01ed7c0113f..000000000000 --- a/python/ray/dashboard/auth_utils.py +++ /dev/null @@ -1,47 +0,0 @@ -"""Authentication utilities for Ray dashboard.""" - -from ray.includes.rpc_token_authentication import ( - AuthenticationMode, - get_authentication_mode, - validate_authentication_token, -) - - -def is_token_auth_enabled() -> bool: - """Check if token authentication is enabled. - - Returns: - bool: True if auth_mode is set to "token", False otherwise - """ - return get_authentication_mode() == AuthenticationMode.TOKEN - - -def should_authenticate_request(method: str) -> bool: - """Determine if request method requires authentication. - - Only mutable operations (POST, PUT, PATCH, DELETE) require authentication. - - Args: - method: HTTP method (e.g., "GET", "POST", "PUT", "DELETE") - - Returns: - bool: True if the method requires authentication, False otherwise - """ - return method in ["POST", "PUT", "PATCH", "DELETE"] - - -def validate_request_token(auth_header: str) -> bool: - """Validate the Authorization header from an HTTP request. 
- - Args: - auth_header: The Authorization header value (e.g., "Bearer ") - - Returns: - bool: True if token is valid, False otherwise - """ - if not auth_header: - return False - - # validate_authentication_token expects full "Bearer " format - # and performs equality comparison via C++ layer - return validate_authentication_token(auth_header) diff --git a/python/ray/dashboard/authentication_utils.py b/python/ray/dashboard/authentication_utils.py index b01ed7c0113f..2b39b8d00eb6 100644 --- a/python/ray/dashboard/authentication_utils.py +++ b/python/ray/dashboard/authentication_utils.py @@ -1,6 +1,4 @@ -"""Authentication utilities for Ray dashboard.""" - -from ray.includes.rpc_token_authentication import ( +from ray._raylet import ( AuthenticationMode, get_authentication_mode, validate_authentication_token, @@ -16,20 +14,6 @@ def is_token_auth_enabled() -> bool: return get_authentication_mode() == AuthenticationMode.TOKEN -def should_authenticate_request(method: str) -> bool: - """Determine if request method requires authentication. - - Only mutable operations (POST, PUT, PATCH, DELETE) require authentication. - - Args: - method: HTTP method (e.g., "GET", "POST", "PUT", "DELETE") - - Returns: - bool: True if the method requires authentication, False otherwise - """ - return method in ["POST", "PUT", "PATCH", "DELETE"] - - def validate_request_token(auth_header: str) -> bool: """Validate the Authorization header from an HTTP request. 
diff --git a/python/ray/dashboard/http_server_head.py b/python/ray/dashboard/http_server_head.py index 948cd53eecbf..5f7054900c18 100644 --- a/python/ray/dashboard/http_server_head.py +++ b/python/ray/dashboard/http_server_head.py @@ -172,10 +172,6 @@ async def auth_middleware(self, request, handler): if not auth_utils.is_token_auth_enabled(): return await handler(request) - # Skip if request method doesn't require auth (GET, HEAD) - if not auth_utils.should_authenticate_request(request.method): - return await handler(request) - # Extract and validate token auth_header = request.headers.get("Authorization", "") @@ -280,8 +276,8 @@ async def run( app = aiohttp.web.Application( client_max_size=ray_constants.DASHBOARD_CLIENT_MAX_SIZE, middlewares=[ - self.auth_middleware, self.metrics_middleware, + self.auth_middleware, self.path_clean_middleware, self.browsers_no_post_put_middleware, self.cache_control_static_middleware, diff --git a/python/ray/tests/test_dashboard_auth.py b/python/ray/tests/test_dashboard_auth.py index 1698afb1ffc2..cb57c785a518 100644 --- a/python/ray/tests/test_dashboard_auth.py +++ b/python/ray/tests/test_dashboard_auth.py @@ -6,6 +6,7 @@ import requests import ray +from ray._raylet import Config from ray.cluster_utils import Cluster @@ -20,30 +21,25 @@ def cleanup_env(): del os.environ["RAY_AUTH_TOKEN"] -def test_dashboard_post_requires_auth_with_valid_token(cleanup_env): - """Test that POST requests succeed with valid token when auth is enabled.""" +def test_dashboard_request_requires_auth_with_valid_token(cleanup_env): + """Test that requests succeed with valid token when auth is enabled.""" test_token = "test_token_12345678901234567890123456789012" os.environ["RAY_auth_mode"] = "token" os.environ["RAY_AUTH_TOKEN"] = test_token - + Config.initialize("") cluster = Cluster() cluster.add_node() try: - ray.init(address=cluster.address) - dashboard_url = ray._private.worker._global_node.webui_url + context = ray.init(address=cluster.address) + 
dashboard_url = context.address_info["webui_url"] - # POST with valid auth should succeed + # Request with valid auth should succeed headers = {"Authorization": f"Bearer {test_token}"} - response = requests.post( + response = requests.get( f"http://{dashboard_url}/api/component_activities", - json={"test": "data"}, headers=headers, ) - assert response.status_code == 403 - - # GET should work without auth - response = requests.get(f"http://{dashboard_url}/api/cluster_status") assert response.status_code == 200 finally: @@ -51,23 +47,21 @@ def test_dashboard_post_requires_auth_with_valid_token(cleanup_env): cluster.shutdown() -def test_dashboard_post_requires_auth_missing_token(cleanup_env): - """Test that POST requests fail without token when auth is enabled.""" +def test_dashboard_request_requires_auth_missing_token(cleanup_env): + """Test that requests fail without token when auth is enabled.""" test_token = "test_token_12345678901234567890123456789012" os.environ["RAY_auth_mode"] = "token" os.environ["RAY_AUTH_TOKEN"] = test_token - + Config.initialize("") cluster = Cluster() cluster.add_node() try: - ray.init(address=cluster.address) - dashboard_url = ray._private.worker._global_node.webui_url + context = ray.init(address=cluster.address) + dashboard_url = context.address_info["webui_url"] - # POST without auth should fail with 401 - # We need to find a POST endpoint to test - # For now, let's test with a generic endpoint - response = requests.post( + # GET without auth should fail with 401 + response = requests.get( f"http://{dashboard_url}/api/component_activities", json={"test": "data"}, ) @@ -78,23 +72,23 @@ def test_dashboard_post_requires_auth_missing_token(cleanup_env): cluster.shutdown() -def test_dashboard_post_requires_auth_invalid_token(cleanup_env): - """Test that POST requests fail with invalid token when auth is enabled.""" +def test_dashboard_request_requires_auth_invalid_token(cleanup_env): + """Test that requests fail with invalid token when 
auth is enabled.""" correct_token = "test_token_12345678901234567890123456789012" wrong_token = "wrong_token_00000000000000000000000000000000" os.environ["RAY_auth_mode"] = "token" os.environ["RAY_AUTH_TOKEN"] = correct_token - + Config.initialize("") cluster = Cluster() cluster.add_node() try: - ray.init(address=cluster.address) - dashboard_url = ray._private.worker._global_node.webui_url + context = ray.init(address=cluster.address) + dashboard_url = context.address_info["webui_url"] - # POST with wrong token should fail with 403 + # Request with wrong token should fail with 403 headers = {"Authorization": f"Bearer {wrong_token}"} - response = requests.post( + response = requests.get( f"http://{dashboard_url}/api/component_activities", json={"test": "data"}, headers=headers, @@ -106,32 +100,6 @@ def test_dashboard_post_requires_auth_invalid_token(cleanup_env): cluster.shutdown() -def test_dashboard_get_no_auth_required(cleanup_env): - """Test that GET requests don't require auth even when token mode is enabled.""" - test_token = "test_token_12345678901234567890123456789012" - os.environ["RAY_auth_mode"] = "token" - os.environ["RAY_AUTH_TOKEN"] = test_token - - cluster = Cluster() - cluster.add_node() - - try: - ray.init(address=cluster.address) - dashboard_url = ray._private.worker._global_node.webui_url - - # GET without auth should succeed - response = requests.get(f"http://{dashboard_url}/") - assert response.status_code == 200 - - # GET to API endpoint should also succeed - response = requests.get(f"http://{dashboard_url}/api/cluster_status") - assert response.status_code == 200 - - finally: - ray.shutdown() - cluster.shutdown() - - def test_dashboard_auth_disabled(cleanup_env): """Test that auth is not enforced when auth_mode is disabled.""" os.environ["RAY_auth_mode"] = "disabled" @@ -140,15 +108,15 @@ def test_dashboard_auth_disabled(cleanup_env): cluster.add_node() try: - ray.init(address=cluster.address) - dashboard_url = 
ray._private.worker._global_node.webui_url + context = ray.init(address=cluster.address) + dashboard_url = context.address_info["webui_url"] - # POST without auth should succeed when auth is disabled - response = requests.post( + # GET without auth should succeed when auth is disabled + response = requests.get( f"http://{dashboard_url}/api/component_activities", json={"test": "data"} ) # Should not return 401 or 403 - assert response.status_code not in [401, 403] + assert response.status_code == 200 finally: ray.shutdown() From cb009334cd2822d63b9e307f56755bb2808b7391 Mon Sep 17 00:00:00 2001 From: sampan Date: Tue, 28 Oct 2025 09:26:26 +0000 Subject: [PATCH 57/94] add test to bazel Signed-off-by: sampan --- python/ray/tests/BUILD.bazel | 1 + 1 file changed, 1 insertion(+) diff --git a/python/ray/tests/BUILD.bazel b/python/ray/tests/BUILD.bazel index cf5c69b54ac0..a6774586e20a 100644 --- a/python/ray/tests/BUILD.bazel +++ b/python/ray/tests/BUILD.bazel @@ -543,6 +543,7 @@ py_test_module_list( "test_concurrency_group.py", "test_core_worker_io_thread_stack_size.py", "test_cross_language.py", + "test_dashboard_auth", "test_debug_tools.py", "test_distributed_sort.py", "test_environ.py", From e39247d3481ccaaa47ce1a087f704b868c8ecf6d Mon Sep 17 00:00:00 2001 From: sampan Date: Tue, 28 Oct 2025 15:43:03 +0000 Subject: [PATCH 58/94] fix typo Signed-off-by: sampan --- python/ray/tests/BUILD.bazel | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/tests/BUILD.bazel b/python/ray/tests/BUILD.bazel index a6774586e20a..3b11301a5318 100644 --- a/python/ray/tests/BUILD.bazel +++ b/python/ray/tests/BUILD.bazel @@ -543,7 +543,7 @@ py_test_module_list( "test_concurrency_group.py", "test_core_worker_io_thread_stack_size.py", "test_cross_language.py", - "test_dashboard_auth", + "test_dashboard_auth.py", "test_debug_tools.py", "test_distributed_sort.py", "test_environ.py", From bf4866a6506af3b08f77a0ce2e2c4a56f1e82f29 Mon Sep 17 00:00:00 2001 From: sampan Date: 
Tue, 28 Oct 2025 15:54:19 +0000 Subject: [PATCH 59/94] attempt to fix tests Signed-off-by: sampan --- python/ray/tests/test_token_auth_integration.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/ray/tests/test_token_auth_integration.py b/python/ray/tests/test_token_auth_integration.py index 6c120b2b9c51..4ddc77f87f3f 100644 --- a/python/ray/tests/test_token_auth_integration.py +++ b/python/ray/tests/test_token_auth_integration.py @@ -37,6 +37,8 @@ def clean_token_sources(): original_content = default_token_path.read_text() default_token_path.unlink() + Config.initialize("") + # Reset token caches (both Python and C++) reset_token_cache() @@ -54,8 +56,12 @@ def clean_token_sources(): default_token_path.parent.mkdir(parents=True, exist_ok=True) default_token_path.write_text(original_content) + if ray.is_initialized(): + ray.shutdown() + # Reset token caches again after test reset_token_cache() + Config.initialize("") def test_local_cluster_generates_token(): From 61646af2eb8c04079a20d94b29a9ec568e9260cc Mon Sep 17 00:00:00 2001 From: sampan Date: Wed, 29 Oct 2025 05:14:19 +0000 Subject: [PATCH 60/94] attempt to fix test in CI Signed-off-by: sampan --- .../ray/tests/test_token_auth_integration.py | 42 +++++++++++++++++-- 1 file changed, 39 insertions(+), 3 deletions(-) diff --git a/python/ray/tests/test_token_auth_integration.py b/python/ray/tests/test_token_auth_integration.py index 4ddc77f87f3f..522fda6f25f4 100644 --- a/python/ray/tests/test_token_auth_integration.py +++ b/python/ray/tests/test_token_auth_integration.py @@ -5,6 +5,8 @@ from pathlib import Path import pytest +import shutil +import tempfile import ray from ray._raylet import AuthenticationTokenLoader, Config @@ -18,6 +20,23 @@ def reset_token_cache(): @pytest.fixture(autouse=True) def clean_token_sources(): """Clean up all token sources before and after each test.""" + # This follows the same pattern as authentication_token_loader_test.cc + if "HOME" not in os.environ: + # Use 
TEST_TMPDIR if available (Bazel sets this), otherwise use system temp + test_tmpdir = os.environ.get("TEST_TMPDIR") + if test_tmpdir: + temp_home = os.path.join(test_tmpdir, "ray_test_home") + else: + temp_home = "/tmp/ray_test_home" + + # Create the directory if it doesn't exist + os.makedirs(temp_home, exist_ok=True) + os.environ["HOME"] = temp_home + home_was_set = False + else: + temp_home = None + home_was_set = True + # Clean environment variables env_vars_to_clean = [ "RAY_AUTH_TOKEN", @@ -33,6 +52,7 @@ def clean_token_sources(): # Clean default token file default_token_path = Path.home() / ".ray" / "auth_token" original_exists = default_token_path.exists() + original_content = None if original_exists: original_content = default_token_path.read_text() default_token_path.unlink() @@ -52,7 +72,7 @@ def clean_token_sources(): del os.environ[var] # Restore default token file - if original_exists: + if original_exists and original_content is not None: default_token_path.parent.mkdir(parents=True, exist_ok=True) default_token_path.write_text(original_content) @@ -63,12 +83,24 @@ def clean_token_sources(): reset_token_cache() Config.initialize("") + # Clean up temporary HOME if we created one + # Only delete if we set it and it was temporary + if temp_home is not None and not home_was_set: + try: + if os.path.exists(temp_home): + shutil.rmtree(temp_home) + except Exception: + pass # Best effort cleanup + # Remove the HOME env var we set + if "HOME" in os.environ and os.environ["HOME"] == temp_home: + del os.environ["HOME"] + def test_local_cluster_generates_token(): """Test ray.init() generates token for local cluster when auth_mode=token is set.""" # Ensure no token exists default_token_path = Path.home() / ".ray" / "auth_token" - assert not default_token_path.exists() + assert not default_token_path.exists(), f"Token file already exists at {default_token_path}" # Enable token auth via environment variable os.environ["RAY_auth_mode"] = "token" @@ -79,7 +111,11 @@ 
def test_local_cluster_generates_token(): try: # Verify token file was created - assert default_token_path.exists() + assert default_token_path.exists(), ( + f"Token file was not created at {default_token_path}. " + f"HOME={os.environ.get('HOME')}, " + f"Files in {default_token_path.parent}: {list(default_token_path.parent.iterdir()) if default_token_path.parent.exists() else 'directory does not exist'}" + ) token = default_token_path.read_text().strip() assert len(token) == 32 assert all(c in "0123456789abcdef" for c in token) From 5b3cc5b0828d3ea6e06461f3ba9d27aca79c0d70 Mon Sep 17 00:00:00 2001 From: sampan Date: Wed, 29 Oct 2025 05:14:38 +0000 Subject: [PATCH 61/94] fix lint Signed-off-by: sampan --- python/ray/tests/test_token_auth_integration.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/python/ray/tests/test_token_auth_integration.py b/python/ray/tests/test_token_auth_integration.py index 522fda6f25f4..e934ffbb1193 100644 --- a/python/ray/tests/test_token_auth_integration.py +++ b/python/ray/tests/test_token_auth_integration.py @@ -1,12 +1,11 @@ """Integration tests for token-based authentication in Ray.""" import os +import shutil import sys from pathlib import Path import pytest -import shutil -import tempfile import ray from ray._raylet import AuthenticationTokenLoader, Config @@ -28,7 +27,7 @@ def clean_token_sources(): temp_home = os.path.join(test_tmpdir, "ray_test_home") else: temp_home = "/tmp/ray_test_home" - + # Create the directory if it doesn't exist os.makedirs(temp_home, exist_ok=True) os.environ["HOME"] = temp_home @@ -100,7 +99,9 @@ def test_local_cluster_generates_token(): """Test ray.init() generates token for local cluster when auth_mode=token is set.""" # Ensure no token exists default_token_path = Path.home() / ".ray" / "auth_token" - assert not default_token_path.exists(), f"Token file already exists at {default_token_path}" + assert ( + not default_token_path.exists() + ), f"Token file already exists at 
{default_token_path}" # Enable token auth via environment variable os.environ["RAY_auth_mode"] = "token" From c0c2e0550d107ae14efe892aebe852e7f06dc594 Mon Sep 17 00:00:00 2001 From: sampan Date: Wed, 29 Oct 2025 06:52:04 +0000 Subject: [PATCH 62/94] [Core] Verify token presence when using ray start CLI Signed-off-by: sampan --- python/ray/scripts/scripts.py | 9 + .../ray/tests/test_token_auth_integration.py | 220 ++++++++++++++++++ 2 files changed, 229 insertions(+) diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py index c1a7230f55a9..0eb84b7b7be2 100644 --- a/python/ray/scripts/scripts.py +++ b/python/ray/scripts/scripts.py @@ -26,6 +26,9 @@ from ray._common.network_utils import build_address, parse_address from ray._common.usage import usage_lib from ray._common.utils import load_class +from ray._private.authentication.authentication_token_setup import ( + ensure_token_if_auth_enabled, +) from ray._private.internal_api import memory_summary from ray._private.label_utils import ( parse_node_labels_from_yaml_file, @@ -937,6 +940,9 @@ def start( " flag of `ray start` command." 
) + # Ensure auth token is available if authentication mode is token + ensure_token_if_auth_enabled(system_config, create_token_if_missing=False) + node = ray._private.node.Node( ray_params, head=True, shutdown_at_exit=block, spawn_reaper=block ) @@ -1094,6 +1100,9 @@ def start( cli_logger.labeled_value("Local node IP", ray_params.node_ip_address) + # Ensure auth token is available if authentication mode is token + ensure_token_if_auth_enabled(system_config, create_token_if_missing=False) + node = ray._private.node.Node( ray_params, head=False, shutdown_at_exit=block, spawn_reaper=block ) diff --git a/python/ray/tests/test_token_auth_integration.py b/python/ray/tests/test_token_auth_integration.py index e934ffbb1193..efff4243b2dd 100644 --- a/python/ray/tests/test_token_auth_integration.py +++ b/python/ray/tests/test_token_auth_integration.py @@ -2,12 +2,15 @@ import os import shutil +import subprocess import sys from pathlib import Path +from typing import Optional import pytest import ray +from ray._private.test_utils import wait_for_condition from ray._raylet import AuthenticationTokenLoader, Config from ray.cluster_utils import Cluster @@ -16,6 +19,70 @@ def reset_token_cache(): AuthenticationTokenLoader.instance().reset_cache() +def _run_ray_start_and_verify_status( + args: list, env: dict, expect_success: bool = True, timeout: int = 30 +) -> subprocess.CompletedProcess: + """Helper to run ray start command with proper error handling.""" + result = subprocess.run( + ["ray", "start"] + args, + env=env, + capture_output=True, + text=True, + timeout=timeout, + ) + + if expect_success: + assert result.returncode == 0, ( + f"ray start should have succeeded. " + f"stdout: {result.stdout}, stderr: {result.stderr}" + ) + else: + assert result.returncode != 0, ( + f"ray start should have failed but succeeded. 
" + f"stdout: {result.stdout}, stderr: {result.stderr}" + ) + # Check that error message mentions token + error_output = result.stdout + result.stderr + assert ( + "authentication token" in error_output.lower() + or "token" in error_output.lower() + ), f"Error message should mention token. Got: {error_output}" + + return result + + +def _cleanup_ray_start(env: Optional[dict] = None): + """Helper to clean up ray start processes.""" + # Ensure any ray.init() connection is closed first + if ray.is_initialized(): + ray.shutdown() + + # Stop with a longer timeout + subprocess.run( + ["ray", "stop", "--force"], + env=env, + capture_output=True, + timeout=60, # Increased timeout for flaky cleanup + check=False, # Don't raise on non-zero exit + ) + + # Wait for ray processes to actually stop + def ray_stopped(): + result = subprocess.run( + ["ray", "status"], + capture_output=True, + check=False, + ) + # ray status returns non-zero when no cluster is running + return result.returncode != 0 + + try: + wait_for_condition(ray_stopped, timeout=10, retry_interval_ms=500) + except Exception: + # Best effort - don't fail the test if we can't verify it stopped + pass + + @pytest.fixture(autouse=True) def clean_token_sources(): """Clean up all token sources before and after each test.""" @@ -78,6 +145,14 @@ def clean_token_sources(): if ray.is_initialized(): ray.shutdown() + # Ensure all ray processes are stopped + subprocess.run( + ["ray", "stop", "--force"], + capture_output=True, + timeout=60, + check=False, + ) + # Reset token caches again after test reset_token_cache() Config.initialize("") @@ -219,5 +294,150 @@ def test_func(): cluster.shutdown() +@pytest.mark.parametrize("is_head", [True, False]) +def test_ray_start_without_token_raises_error(is_head): + """Test that ray start fails when auth_mode=token but no token exists.""" + # Set up environment with token auth enabled but no token + env = os.environ.copy() + env["RAY_auth_mode"] = "token" + env.pop("RAY_AUTH_TOKEN", 
None) + env.pop("RAY_AUTH_TOKEN_PATH", None) + + # Ensure no default token file exists (already cleaned by fixture) + default_token_path = Path.home() / ".ray" / "auth_token" + assert not default_token_path.exists() + + # When specifying an address, we need a head node to connect to + cluster = None + if not is_head: + # Start head node with token + cluster_token = "a" * 32 + os.environ["RAY_AUTH_TOKEN"] = cluster_token + os.environ["RAY_auth_mode"] = "token" + Config.initialize("") + cluster = Cluster() + cluster.add_node() + + try: + # Prepare arguments + if is_head: + args = ["--head", "--port=0"] + else: + args = [f"--address={cluster.address}"] + + # Try to start node - should fail + _run_ray_start_and_verify_status(args, env, expect_success=False) + + finally: + if cluster: + cluster.shutdown() + + +def test_ray_start_head_with_token_succeeds(): + """Test that ray start --head succeeds when token auth is enabled with a valid token.""" + # Set up environment with token auth and a valid token + test_token = "a" * 32 + env = os.environ.copy() + env["RAY_AUTH_TOKEN"] = test_token + env["RAY_auth_mode"] = "token" + + try: + # Start head node - should succeed + _run_ray_start_and_verify_status( + ["--head", "--port=0"], env, expect_success=True + ) + + # Verify we can connect to the cluster with ray.init() + os.environ["RAY_AUTH_TOKEN"] = test_token + os.environ["RAY_auth_mode"] = "token" + Config.initialize("") + reset_token_cache() + + # Wait for cluster to be ready + def cluster_ready(): + try: + ray.init(address="auto") + return True + except Exception: + return False + + wait_for_condition(cluster_ready, timeout=10) + assert ray.is_initialized() + + # Test basic operations work + @ray.remote + def test_func(): + return "success" + + result = ray.get(test_func.remote()) + assert result == "success" + + finally: + # Cleanup handles ray.shutdown() internally + _cleanup_ray_start(env) + + +@pytest.mark.parametrize("token_match", ["correct", "incorrect"]) +def 
test_ray_start_address_with_token(token_match): + """Test ray start --address=... with correct or incorrect token.""" + # Start a head node with token auth + cluster_token = "a" * 32 + os.environ["RAY_AUTH_TOKEN"] = cluster_token + os.environ["RAY_auth_mode"] = "token" + Config.initialize("") + + cluster = Cluster() + cluster.add_node(num_cpus=1) + + try: + # Set up environment for worker + env = os.environ.copy() + env["RAY_auth_mode"] = "token" + + if token_match == "correct": + env["RAY_AUTH_TOKEN"] = cluster_token + expect_success = True + else: + # Use different token + env["RAY_AUTH_TOKEN"] = "b" * 32 + expect_success = False + + # Start worker node + _run_ray_start_and_verify_status( + [f"--address={cluster.address}", "--num-cpus=1"], + env, + expect_success=expect_success, + ) + + if token_match == "correct": + try: + # Connect and verify the cluster has 2 nodes (head + worker) + ray.init(address=cluster.address) + + # Wait for worker node to register + def worker_joined(): + return len(ray.nodes()) >= 2 + + wait_for_condition(worker_joined, timeout=10) + + nodes = ray.nodes() + assert ( + len(nodes) >= 2 + ), f"Expected at least 2 nodes, got {len(nodes)}: {nodes}" + + finally: + # Always shutdown ray.init() connection before cleanup + if ray.is_initialized(): + ray.shutdown() + # Clean up the worker node started with ray start + _cleanup_ray_start(env) + + finally: + # Clean up cluster + if ray.is_initialized(): + ray.shutdown() + cluster.shutdown() + + if __name__ == "__main__": sys.exit(pytest.main(["-vv", __file__])) From 47c8042becdbd7fa4fe1157f77274b4577e5ba0d Mon Sep 17 00:00:00 2001 From: sampan Date: Wed, 29 Oct 2025 14:27:18 +0000 Subject: [PATCH 63/94] move common fixtures to conftest.py Signed-off-by: sampan --- .../authentication_constants.py | 29 +++ .../authentication_token_setup.py | 10 +- python/ray/dashboard/modules/dashboard_sdk.py | 46 +++- .../ray/includes/rpc_token_authentication.pxd | 1 + .../ray/includes/rpc_token_authentication.pxi 
| 33 +++ python/ray/tests/conftest.py | 83 +++++++ python/ray/tests/test_dashboard_auth.py | 19 +- .../ray/tests/test_submission_client_auth.py | 221 ++++++++++++++++++ .../ray/tests/test_token_auth_integration.py | 2 +- .../rpc/authentication/authentication_token.h | 11 + 10 files changed, 429 insertions(+), 26 deletions(-) create mode 100644 python/ray/_private/authentication/authentication_constants.py create mode 100644 python/ray/tests/test_submission_client_auth.py diff --git a/python/ray/_private/authentication/authentication_constants.py b/python/ray/_private/authentication/authentication_constants.py new file mode 100644 index 000000000000..17e70232f4d4 --- /dev/null +++ b/python/ray/_private/authentication/authentication_constants.py @@ -0,0 +1,29 @@ +"""Centralized authentication constants and error messages for Ray. + +This module provides reusable error messages for authentication failures +across CLI, dashboard, and other Ray python components. +""" + +# Token setup instructions (used in multiple contexts) +TOKEN_SETUP_INSTRUCTIONS = """Please provide an authentication token using one of these methods: + 1. Set the RAY_AUTH_TOKEN environment variable + 2. Set the RAY_AUTH_TOKEN_PATH environment variable (pointing to a token file) + 3. Create a token file at the default location: ~/.ray/auth_token""" + +# When token auth is enabled but no token is found anywhere +TOKEN_AUTH_ENABLED_BUT_NO_TOKEN_FOUND_ERROR_MESSAGE = ( + "Token authentication is enabled but no authentication token was found. 
" + + TOKEN_SETUP_INSTRUCTIONS +) + +# When HTTP request fails with 401 (Unauthorized - missing token) +HTTP_REQUEST_MISSING_TOKEN_ERROR_MESSAGE = ( + "The Ray cluster requires authentication, but no token was provided.\n\n" + + TOKEN_SETUP_INSTRUCTIONS +) + +# When HTTP request fails with 403 (Forbidden - invalid token) +HTTP_REQUEST_INVALID_TOKEN_ERROR_MESSAGE = ( + "The authentication token you provided is invalid or incorrect.\n\n" + + TOKEN_SETUP_INSTRUCTIONS +) diff --git a/python/ray/_private/authentication/authentication_token_setup.py b/python/ray/_private/authentication/authentication_token_setup.py index 99d0bb382481..8ad292430406 100644 --- a/python/ray/_private/authentication/authentication_token_setup.py +++ b/python/ray/_private/authentication/authentication_token_setup.py @@ -9,6 +9,9 @@ from pathlib import Path from typing import Any, Dict, Optional +from ray._private.authentication.authentication_constants import ( + TOKEN_AUTH_ENABLED_BUT_NO_TOKEN_FOUND_ERROR_MESSAGE, +) from ray._private.authentication.authentication_token_generator import ( generate_new_authentication_token, ) @@ -20,13 +23,6 @@ logger = logging.getLogger(__name__) -TOKEN_AUTH_ENABLED_BUT_NO_TOKEN_FOUND_ERROR_MESSAGE = ( - "Token authentication is enabled but no authentication token was found. Please provide a token with one of these options:\n" - + " 1. RAY_AUTH_TOKEN environment variable\n" - + " 2. RAY_AUTH_TOKEN_PATH environment variable (path to token file)\n" - + " 3. Default token file: ~/.ray/auth_token" -) - def generate_and_save_token() -> None: """Generate a new random token and save it in the default token path. 
diff --git a/python/ray/dashboard/modules/dashboard_sdk.py b/python/ray/dashboard/modules/dashboard_sdk.py index 6b0dfdaadab7..2cdf738590a2 100644 --- a/python/ray/dashboard/modules/dashboard_sdk.py +++ b/python/ray/dashboard/modules/dashboard_sdk.py @@ -12,6 +12,7 @@ import yaml import ray +from ray._private.authentication import authentication_constants from ray._private.runtime_env.packaging import ( create_package, get_uri_for_directory, @@ -20,7 +21,9 @@ from ray._private.runtime_env.py_modules import upload_py_modules_if_needed from ray._private.runtime_env.working_dir import upload_working_dir_if_needed from ray._private.utils import split_address +from ray._raylet import AuthenticationTokenLoader from ray.autoscaler._private.cli_logger import cli_logger +from ray.dashboard.authentication_utils import is_token_auth_enabled from ray.dashboard.modules.job.common import uri_to_http_components from ray.util.annotations import DeveloperAPI, PublicAPI @@ -222,7 +225,11 @@ def __init__( self._default_metadata = cluster_info.metadata or {} # Headers used for all requests sent to job server, optional and only # needed for cases like authentication to remote cluster. - self._headers = cluster_info.headers + self._headers = cluster_info.headers or {} + + # Add authentication token if token auth is enabled + self._set_auth_header_if_enabled() + # Set SSL verify parameter for the requests library and create an ssl_context # object when needed for the aiohttp library. 
self._verify = verify @@ -242,6 +249,22 @@ def __init__( else: self._ssl_context = None + def _set_auth_header_if_enabled(self): + """Add authentication token to headers if token auth is enabled.""" + if is_token_auth_enabled(): + token_loader = AuthenticationTokenLoader.instance() + token_added = token_loader.set_token_for_http_header(self._headers) + + if not token_added: + # Token auth is enabled but no token found or Authorization already set + if "Authorization" not in self._headers: + # No token found - log warning but don't fail yet + # Let the server return 401 for a better error message + logger.warning( + "Token authentication is enabled but no token was found. " + "Requests to authenticated clusters will fail." + ) + def _check_connection_and_version( self, min_version: str = "1.9", version_error_message: str = None ): @@ -293,14 +316,15 @@ def _do_request( json_data: Optional[dict] = None, **kwargs, ) -> "requests.Response": - """Perform the actual HTTP request + """Perform the actual HTTP request with authentication error handling. Keyword arguments other than "cookies", "headers" are forwarded to the `requests.request()`. 
""" url = self._address + endpoint logger.debug(f"Sending request to {url} with json data: {json_data or {}}.") - return requests.request( + + response = requests.request( method, url, cookies=self._cookies, @@ -311,6 +335,22 @@ def _do_request( **kwargs, ) + # Check for authentication errors and provide helpful messages + if response.status_code == 401: + # Unauthorized - missing or no token provided + raise RuntimeError( + f"Authentication required: {response.text}\n\n" + + authentication_constants.HTTP_REQUEST_MISSING_TOKEN_ERROR_MESSAGE + ) + elif response.status_code == 403: + # Forbidden - invalid token + raise RuntimeError( + f"Authentication failed: {response.text}\n\n" + + authentication_constants.HTTP_REQUEST_INVALID_TOKEN_ERROR_MESSAGE + ) + + return response + def _package_exists( self, package_uri: str, diff --git a/python/ray/includes/rpc_token_authentication.pxd b/python/ray/includes/rpc_token_authentication.pxd index b589b2e0c287..cc6c05e92dac 100644 --- a/python/ray/includes/rpc_token_authentication.pxd +++ b/python/ray/includes/rpc_token_authentication.pxd @@ -15,6 +15,7 @@ cdef extern from "ray/rpc/authentication/authentication_token.h" namespace "ray: CAuthenticationToken(string value) c_bool empty() c_bool Equals(const CAuthenticationToken& other) + string ToHttpHeaderValue() @staticmethod CAuthenticationToken FromMetadata(string metadata_value) diff --git a/python/ray/includes/rpc_token_authentication.pxi b/python/ray/includes/rpc_token_authentication.pxi index ce70c2d72915..3c030680baab 100644 --- a/python/ray/includes/rpc_token_authentication.pxi +++ b/python/ray/includes/rpc_token_authentication.pxi @@ -69,3 +69,36 @@ class AuthenticationTokenLoader: variables or files on the next request. """ CAuthenticationTokenLoader.instance().ResetCache() + + def set_token_for_http_header(self, headers: dict): + """Add authentication token to HTTP headers dictionary if token auth is enabled. 
+ + This method loads the token from C++ AuthenticationTokenLoader and adds it + to the provided headers dictionary as the Authorization header. It only adds + the token if: + - Token authentication is enabled + - A token exists + - The Authorization header is not already set in the headers + + Args: + headers: Dictionary of HTTP headers to modify (modified in-place) + + Returns: + bool: True if token was added, False otherwise + """ + # Don't override if user explicitly set Authorization header + if "Authorization" in headers: + return False + + # Check if token exists (doesn't crash, returns bool) + if not self.has_token(): + return False + + # Get the token from C++ layer + cdef optional[CAuthenticationToken] token_opt = CAuthenticationTokenLoader.instance().GetToken() + + if not token_opt.has_value() || token_opt.value().empty(): + return False + + headers["Authorization"] = token_opt.value().ToAuthorizationHeaderValue() + return True diff --git a/python/ray/tests/conftest.py b/python/ray/tests/conftest.py index 0d8fc6427536..89cadc96ee73 100644 --- a/python/ray/tests/conftest.py +++ b/python/ray/tests/conftest.py @@ -42,6 +42,7 @@ start_redis_sentinel_instance, teardown_tls, ) +from ray._raylet import AuthenticationTokenLoader, Config from ray.cluster_utils import AutoscalingCluster, Cluster, cluster_not_supported import psutil @@ -54,6 +55,88 @@ START_REDIS_WAIT_RETRIES = int(os.environ.get("RAY_START_REDIS_WAIT_RETRIES", "60")) +@pytest.fixture +def cleanup_auth_token_env(): + """Reset Ray authentication-related environment variables and caches.""" + + env_vars = ["RAY_auth_mode", "RAY_AUTH_TOKEN", "RAY_AUTH_TOKEN_PATH"] + original_env = {var: os.environ.get(var) for var in env_vars} + + default_token_path = Path.home() / ".ray" / "auth_token" + token_file_exists = default_token_path.exists() + token_file_contents = default_token_path.read_text() if token_file_exists else None + + AuthenticationTokenLoader.instance().reset_cache() + Config.initialize("") + + 
try: + yield + finally: + for var, value in original_env.items(): + if value is None: + os.environ.pop(var, None) + else: + os.environ[var] = value + + if token_file_exists: + default_token_path.parent.mkdir(parents=True, exist_ok=True) + default_token_path.write_text(token_file_contents) + else: + default_token_path.unlink(missing_ok=True) + + AuthenticationTokenLoader.instance().reset_cache() + Config.initialize("") + + +@pytest.fixture +def setup_cluster_with_token_auth(cleanup_auth_token_env): + """Spin up a Ray cluster with token authentication enabled.""" + + test_token = "test_token_12345678901234567890123456789012" + os.environ["RAY_auth_mode"] = "token" + os.environ["RAY_AUTH_TOKEN"] = test_token + Config.initialize("") + AuthenticationTokenLoader.instance().reset_cache() + + cluster = Cluster() + cluster.add_node() + + try: + context = ray.init(address=cluster.address) + dashboard_url = context.address_info["webui_url"] + yield { + "cluster": cluster, + "dashboard_url": f"http://{dashboard_url}", + "token": test_token, + } + finally: + ray.shutdown() + cluster.shutdown() + + +@pytest.fixture +def setup_cluster_without_token_auth(cleanup_auth_token_env): + """Spin up a Ray cluster with authentication disabled.""" + + os.environ["RAY_auth_mode"] = "disabled" + Config.initialize("") + AuthenticationTokenLoader.instance().reset_cache() + + cluster = Cluster() + cluster.add_node() + + try: + context = ray.init(address=cluster.address) + dashboard_url = context.address_info["webui_url"] + yield { + "cluster": cluster, + "dashboard_url": f"http://{dashboard_url}", + } + finally: + ray.shutdown() + cluster.shutdown() + + @pytest.fixture(autouse=True) def pre_envs(monkeypatch): # To make test run faster diff --git a/python/ray/tests/test_dashboard_auth.py b/python/ray/tests/test_dashboard_auth.py index cb57c785a518..5f1615573f0c 100644 --- a/python/ray/tests/test_dashboard_auth.py +++ b/python/ray/tests/test_dashboard_auth.py @@ -10,18 +10,7 @@ from 
ray.cluster_utils import Cluster -@pytest.fixture -def cleanup_env(): - """Clean up environment variables after each test.""" - yield - # Clean up environment variables - if "RAY_auth_mode" in os.environ: - del os.environ["RAY_auth_mode"] - if "RAY_AUTH_TOKEN" in os.environ: - del os.environ["RAY_AUTH_TOKEN"] - - -def test_dashboard_request_requires_auth_with_valid_token(cleanup_env): +def test_dashboard_request_requires_auth_with_valid_token(cleanup_auth_token_env): """Test that requests succeed with valid token when auth is enabled.""" test_token = "test_token_12345678901234567890123456789012" os.environ["RAY_auth_mode"] = "token" @@ -47,7 +36,7 @@ def test_dashboard_request_requires_auth_with_valid_token(cleanup_env): cluster.shutdown() -def test_dashboard_request_requires_auth_missing_token(cleanup_env): +def test_dashboard_request_requires_auth_missing_token(cleanup_auth_token_env): """Test that requests fail without token when auth is enabled.""" test_token = "test_token_12345678901234567890123456789012" os.environ["RAY_auth_mode"] = "token" @@ -72,7 +61,7 @@ def test_dashboard_request_requires_auth_missing_token(cleanup_env): cluster.shutdown() -def test_dashboard_request_requires_auth_invalid_token(cleanup_env): +def test_dashboard_request_requires_auth_invalid_token(cleanup_auth_token_env): """Test that requests fail with invalid token when auth is enabled.""" correct_token = "test_token_12345678901234567890123456789012" wrong_token = "wrong_token_00000000000000000000000000000000" @@ -100,7 +89,7 @@ def test_dashboard_request_requires_auth_invalid_token(cleanup_env): cluster.shutdown() -def test_dashboard_auth_disabled(cleanup_env): +def test_dashboard_auth_disabled(cleanup_auth_token_env): """Test that auth is not enforced when auth_mode is disabled.""" os.environ["RAY_auth_mode"] = "disabled" diff --git a/python/ray/tests/test_submission_client_auth.py b/python/ray/tests/test_submission_client_auth.py new file mode 100644 index 000000000000..2da02a85f065 
--- /dev/null +++ b/python/ray/tests/test_submission_client_auth.py @@ -0,0 +1,221 @@ +import os +import tempfile +from pathlib import Path + +import pytest + +import ray +from ray._private.authentication.authentication_constants import ( + HTTP_REQUEST_INVALID_TOKEN_ERROR_MESSAGE, + HTTP_REQUEST_MISSING_TOKEN_ERROR_MESSAGE, +) +from ray._raylet import AuthenticationTokenLoader, Config +from ray.cluster_utils import Cluster +from ray.dashboard.modules.job.sdk import JobSubmissionClient +from ray.util.state import StateApiClient + + +def test_submission_client_adds_token_automatically(setup_cluster_with_token_auth): + """Test that SubmissionClient automatically adds token to headers.""" + # Token is already set in environment from setup_cluster_with_token_auth fixture + from ray.dashboard.modules.dashboard_sdk import SubmissionClient + + client = SubmissionClient(address=setup_cluster_with_token_auth["dashboard_url"]) + + # Verify Authorization header was added + assert "Authorization" in client._headers + assert client._headers["Authorization"].startswith("Bearer ") + + +def test_submission_client_without_token_shows_helpful_error( + setup_cluster_with_token_auth, +): + """Test that requests without token show helpful error message.""" + # Remove token from environment + os.environ.pop("RAY_AUTH_TOKEN", None) + os.environ["RAY_auth_mode"] = "disabled" + Config.initialize("") + AuthenticationTokenLoader.instance().reset_cache() + + from ray.dashboard.modules.dashboard_sdk import SubmissionClient + + client = SubmissionClient(address=setup_cluster_with_token_auth["dashboard_url"]) + + # Make a request - should fail with helpful message + with pytest.raises(RuntimeError) as exc_info: + client.get_version() + + expected_message = ( + "Authentication required: Unauthorized: Missing authentication token\n\n" + f"{HTTP_REQUEST_MISSING_TOKEN_ERROR_MESSAGE}" + ) + assert str(exc_info.value) == expected_message + + +def 
test_submission_client_with_invalid_token_shows_helpful_error( + setup_cluster_with_token_auth, +): + """Test that requests with wrong token show helpful error message.""" + # Set wrong token + wrong_token = "wrong_token_00000000000000000000000000000000" + os.environ["RAY_AUTH_TOKEN"] = wrong_token + AuthenticationTokenLoader.instance().reset_cache() + + from ray.dashboard.modules.dashboard_sdk import SubmissionClient + + client = SubmissionClient(address=setup_cluster_with_token_auth["dashboard_url"]) + + # Make a request - should fail with helpful message + with pytest.raises(RuntimeError) as exc_info: + client.get_version() + + expected_message = ( + "Authentication failed: Forbidden: Invalid authentication token\n\n" + f"{HTTP_REQUEST_INVALID_TOKEN_ERROR_MESSAGE}" + ) + assert str(exc_info.value) == expected_message + + +def test_submission_client_with_valid_token_succeeds(setup_cluster_with_token_auth): + """Test that requests with valid token succeed.""" + from ray.dashboard.modules.dashboard_sdk import SubmissionClient + + client = SubmissionClient(address=setup_cluster_with_token_auth["dashboard_url"]) + + # Make a request - should succeed + version = client.get_version() + assert version is not None + + +def test_job_submission_client_inherits_auth(setup_cluster_with_token_auth): + """Test that JobSubmissionClient inherits auth from SubmissionClient.""" + client = JobSubmissionClient(address=setup_cluster_with_token_auth["dashboard_url"]) + + # Verify Authorization header was added + assert "Authorization" in client._headers + assert client._headers["Authorization"].startswith("Bearer ") + + # Verify client can make authenticated requests + version = client.get_version() + assert version is not None + + +def test_state_api_client_inherits_auth(setup_cluster_with_token_auth): + """Test that StateApiClient inherits auth from SubmissionClient.""" + client = StateApiClient(address=setup_cluster_with_token_auth["dashboard_url"]) + + # Verify Authorization 
header was added + assert "Authorization" in client._headers + assert client._headers["Authorization"].startswith("Bearer ") + + +def test_user_provided_header_not_overridden(setup_cluster_with_token_auth): + """Test that user-provided Authorization header is not overridden.""" + custom_auth = "Bearer custom_token" + + from ray.dashboard.modules.dashboard_sdk import SubmissionClient + + client = SubmissionClient( + address=setup_cluster_with_token_auth["dashboard_url"], + headers={"Authorization": custom_auth}, + ) + + # Verify custom value is preserved + assert client._headers["Authorization"] == custom_auth + + +def test_error_messages_contain_instructions(setup_cluster_with_token_auth): + """Test that all auth error messages contain setup instructions.""" + # Test 401 error (missing token) + os.environ.pop("RAY_AUTH_TOKEN", None) + os.environ["RAY_auth_mode"] = "disabled" + Config.initialize("") + AuthenticationTokenLoader.instance().reset_cache() + + from ray.dashboard.modules.dashboard_sdk import SubmissionClient + + client = SubmissionClient(address=setup_cluster_with_token_auth["dashboard_url"]) + + with pytest.raises(RuntimeError) as exc_info: + client.get_version() + + expected_missing = ( + "Authentication required: Unauthorized: Missing authentication token\n\n" + f"{HTTP_REQUEST_MISSING_TOKEN_ERROR_MESSAGE}" + ) + assert str(exc_info.value) == expected_missing + + # Test 403 error (invalid token) + os.environ["RAY_AUTH_TOKEN"] = "wrong_token_00000000000000000000000000000000" + os.environ["RAY_auth_mode"] = "token" + Config.initialize("") + AuthenticationTokenLoader.instance().reset_cache() + + client2 = SubmissionClient(address=setup_cluster_with_token_auth["dashboard_url"]) + + with pytest.raises(RuntimeError) as exc_info: + client2.get_version() + + expected_invalid = ( + "Authentication failed: Forbidden: Invalid authentication token\n\n" + f"{HTTP_REQUEST_INVALID_TOKEN_ERROR_MESSAGE}" + ) + assert str(exc_info.value) == expected_invalid + + 
+@pytest.mark.parametrize("token_source", ["env_var", "token_path", "default_path"]) +def test_token_loaded_from_sources(cleanup_auth_token_env, token_source): + """Test that SubmissionClient loads tokens from all supported sources.""" + + test_token = "test_token_12345678901234567890123456789012" + os.environ["RAY_auth_mode"] = "token" + + token_file_path = None + default_token_path = Path.home() / ".ray" / "auth_token" + + if token_source == "env_var": + os.environ["RAY_AUTH_TOKEN"] = test_token + elif token_source == "token_path": + with tempfile.NamedTemporaryFile(mode="w", delete=False) as tmp: + tmp.write(test_token) + token_file_path = tmp.name + os.environ["RAY_AUTH_TOKEN_PATH"] = token_file_path + else: + default_token_path.parent.mkdir(parents=True, exist_ok=True) + default_token_path.write_text(test_token) + + Config.initialize("") + AuthenticationTokenLoader.instance().reset_cache() + + cluster = Cluster() + cluster.add_node() + + try: + context = ray.init(address=cluster.address) + dashboard_url = context.address_info["webui_url"] + + from ray.dashboard.modules.dashboard_sdk import SubmissionClient + + client = SubmissionClient(address=f"http://{dashboard_url}") + assert client._headers["Authorization"] == f"Bearer {test_token}" + finally: + ray.shutdown() + cluster.shutdown() + if token_file_path: + os.unlink(token_file_path) + + +def test_no_token_added_when_auth_disabled(setup_cluster_without_token_auth): + """Test that no Authorization header is injected when auth is disabled.""" + + from ray.dashboard.modules.dashboard_sdk import SubmissionClient + + client = SubmissionClient(address=setup_cluster_without_token_auth["dashboard_url"]) + + assert "Authorization" not in client._headers + + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main(["-vv", __file__])) diff --git a/python/ray/tests/test_token_auth_integration.py b/python/ray/tests/test_token_auth_integration.py index efff4243b2dd..eea727927477 100644 --- 
a/python/ray/tests/test_token_auth_integration.py +++ b/python/ray/tests/test_token_auth_integration.py @@ -84,7 +84,7 @@ def ray_stopped(): @pytest.fixture(autouse=True) -def clean_token_sources(): +def clean_token_sources(cleanup_auth_token_env): """Clean up all token sources before and after each test.""" # This follows the same pattern as authentication_token_loader_test.cc if "HOME" not in os.environ: diff --git a/src/ray/rpc/authentication/authentication_token.h b/src/ray/rpc/authentication/authentication_token.h index 4f32310784de..bdedd5bf6e84 100644 --- a/src/ray/rpc/authentication/authentication_token.h +++ b/src/ray/rpc/authentication/authentication_token.h @@ -86,6 +86,17 @@ class AuthenticationToken { } } + /// Get token as HTTP Authorization header value + /// WARNING: This exposes the raw token. Use sparingly. + /// Returns "Bearer " format suitable for Authorization header + /// @return Authorization header value, or empty string if token is empty + std::string ToAuthorizationHeaderValue() const { + if (secret_.empty()) { + return ""; + } + return kBearerPrefix + std::string(secret_.begin(), secret_.end()); + } + /// Create AuthenticationToken from gRPC metadata value /// Strips "Bearer " prefix and creates token object /// @param metadata_value The raw value from server metadata (should include "Bearer " From 3d84e625840af0619781b4997259fbe81e04f705 Mon Sep 17 00:00:00 2001 From: sampan Date: Thu, 30 Oct 2025 05:40:13 +0000 Subject: [PATCH 64/94] refactor all authentication tests and reuse common code Signed-off-by: sampan --- .../ray/includes/rpc_token_authentication.pxi | 2 +- python/ray/tests/BUILD.bazel | 1 + python/ray/tests/authentication_test_utils.py | 150 ++++++++++++++++++ python/ray/tests/conftest.py | 53 +++---- python/ray/tests/test_dashboard_auth.py | 148 ++++++----------- .../ray/tests/test_submission_client_auth.py | 61 +++---- .../ray/tests/test_token_auth_integration.py | 132 ++++----------- 7 files changed, 282 insertions(+), 265 
deletions(-) create mode 100644 python/ray/tests/authentication_test_utils.py diff --git a/python/ray/includes/rpc_token_authentication.pxi b/python/ray/includes/rpc_token_authentication.pxi index 3c030680baab..beda93da86cd 100644 --- a/python/ray/includes/rpc_token_authentication.pxi +++ b/python/ray/includes/rpc_token_authentication.pxi @@ -97,7 +97,7 @@ class AuthenticationTokenLoader: # Get the token from C++ layer cdef optional[CAuthenticationToken] token_opt = CAuthenticationTokenLoader.instance().GetToken() - if not token_opt.has_value() || token_opt.value().empty(): + if not token_opt.has_value() or token_opt.value().empty(): return False headers["Authorization"] = token_opt.value().ToAuthorizationHeaderValue() diff --git a/python/ray/tests/BUILD.bazel b/python/ray/tests/BUILD.bazel index 3b11301a5318..d5857192fd39 100644 --- a/python/ray/tests/BUILD.bazel +++ b/python/ray/tests/BUILD.bazel @@ -579,6 +579,7 @@ py_test_module_list( "test_runtime_env_py_executable.py", "test_state_api_summary.py", "test_streaming_generator_regression.py", + "test_submission_client_auth.py", "test_system_metrics.py", "test_task_events_3.py", "test_task_metrics_reconstruction.py", diff --git a/python/ray/tests/authentication_test_utils.py b/python/ray/tests/authentication_test_utils.py new file mode 100644 index 000000000000..068ab3370978 --- /dev/null +++ b/python/ray/tests/authentication_test_utils.py @@ -0,0 +1,150 @@ +import os +import shutil +import tempfile +from contextlib import contextmanager +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, Optional + +from ray._raylet import AuthenticationTokenLoader, Config + +AUTH_ENV_VARS = ("RAY_auth_mode", "RAY_AUTH_TOKEN", "RAY_AUTH_TOKEN_PATH") +DEFAULT_AUTH_TOKEN_RELATIVE_PATH = Path(".ray") / "auth_token" + + +def reset_auth_token_state() -> None: + """Reset authentication token and auth_mode config.""" + + AuthenticationTokenLoader.instance().reset_cache() + Config.initialize("") + + 
+def set_auth_mode(mode: str) -> None: + """Set the authentication mode environment variable.""" + + os.environ["RAY_auth_mode"] = mode + + +def set_env_auth_token(token: str) -> None: + """Configure the authentication token via environment variable.""" + + os.environ["RAY_AUTH_TOKEN"] = token + os.environ.pop("RAY_AUTH_TOKEN_PATH", None) + + +def set_auth_token_path(token: str, path: Path) -> None: + """Write the authentication token to a specific path and point the loader to it.""" + + token_path = Path(path) + token_path.parent.mkdir(parents=True, exist_ok=True) + token_path.write_text(token) + os.environ["RAY_AUTH_TOKEN_PATH"] = str(token_path) + os.environ.pop("RAY_AUTH_TOKEN", None) + + +def set_default_auth_token(token: str) -> Path: + """Write the authentication token to the default ~/.ray/auth_token location.""" + + default_path = Path.home() / DEFAULT_AUTH_TOKEN_RELATIVE_PATH + default_path.parent.mkdir(parents=True, exist_ok=True) + default_path.write_text(token) + return default_path + + +def clear_auth_token_sources(remove_default: bool = False) -> None: + """Clear authentication-related environment variables and optional default token file.""" + + for var in ("RAY_AUTH_TOKEN", "RAY_AUTH_TOKEN_PATH"): + os.environ.pop(var, None) + + if remove_default: + default_path = Path.home() / DEFAULT_AUTH_TOKEN_RELATIVE_PATH + default_path.unlink(missing_ok=True) + + +@dataclass +class AuthenticationEnvSnapshot: + original_env: Dict[str, Optional[str]] + original_home: Optional[str] + home_was_set: bool + temp_home: Optional[Path] + default_token_path: Path + default_token_exists: bool + default_token_contents: Optional[str] + + @classmethod + def capture(cls) -> "AuthenticationEnvSnapshot": + """Capture current authentication-related environment state.""" + + original_env = {var: os.environ.get(var) for var in AUTH_ENV_VARS} + home_was_set = "HOME" in os.environ + original_home = os.environ.get("HOME") + temp_home: Optional[Path] = None + + if not home_was_set: 
+ # in CI $HOME may not be set which can cause issues with tests related to default auth token file. + test_tmpdir = os.environ.get("TEST_TMPDIR") + base_dir = Path(test_tmpdir) if test_tmpdir else Path(tempfile.gettempdir()) + temp_home = base_dir / "ray_test_home" + temp_home.mkdir(parents=True, exist_ok=True) + os.environ["HOME"] = str(temp_home) + + default_token_path = Path.home() / DEFAULT_AUTH_TOKEN_RELATIVE_PATH + default_token_exists = default_token_path.exists() + default_token_contents = ( + default_token_path.read_text() if default_token_exists else None + ) + + return cls( + original_env=original_env, + original_home=original_home, + home_was_set=home_was_set, + temp_home=temp_home, + default_token_path=default_token_path, + default_token_exists=default_token_exists, + default_token_contents=default_token_contents, + ) + + def clear_default_token(self) -> None: + """Remove the default token file for the current HOME.""" + + self.default_token_path.unlink(missing_ok=True) + + def restore(self) -> None: + """Restore the captured environment, HOME, and default token file state.""" + + for var, value in self.original_env.items(): + if value is None: + os.environ.pop(var, None) + else: + os.environ[var] = value + + if self.home_was_set: + if self.original_home is None: + os.environ.pop("HOME", None) + else: + os.environ["HOME"] = self.original_home + + if self.default_token_exists: + self.default_token_path.parent.mkdir(parents=True, exist_ok=True) + self.default_token_path.write_text(self.default_token_contents or "") + else: + self.default_token_path.unlink(missing_ok=True) + + if not self.home_was_set: + current_home = os.environ.get("HOME") + if self.temp_home is not None and current_home == str(self.temp_home): + os.environ.pop("HOME", None) + if self.temp_home is not None and self.temp_home.exists(): + shutil.rmtree(self.temp_home, ignore_errors=True) + + +@contextmanager +def authentication_env_guard(): + """Context manager that restores 
authentication environment state on exit.""" + + snapshot = AuthenticationEnvSnapshot.capture() + try: + yield snapshot + finally: + snapshot.restore() diff --git a/python/ray/tests/conftest.py b/python/ray/tests/conftest.py index 89cadc96ee73..8bbd0ec04717 100644 --- a/python/ray/tests/conftest.py +++ b/python/ray/tests/conftest.py @@ -42,8 +42,14 @@ start_redis_sentinel_instance, teardown_tls, ) -from ray._raylet import AuthenticationTokenLoader, Config from ray.cluster_utils import AutoscalingCluster, Cluster, cluster_not_supported +from ray.tests.authentication_test_utils import ( + authentication_env_guard, + clear_auth_token_sources, + reset_auth_token_state, + set_auth_mode, + set_env_auth_token, +) import psutil @@ -57,35 +63,13 @@ @pytest.fixture def cleanup_auth_token_env(): - """Reset Ray authentication-related environment variables and caches.""" - - env_vars = ["RAY_auth_mode", "RAY_AUTH_TOKEN", "RAY_AUTH_TOKEN_PATH"] - original_env = {var: os.environ.get(var) for var in env_vars} - - default_token_path = Path.home() / ".ray" / "auth_token" - token_file_exists = default_token_path.exists() - token_file_contents = default_token_path.read_text() if token_file_exists else None - - AuthenticationTokenLoader.instance().reset_cache() - Config.initialize("") + """Reset authentication environment variables, files, and caches.""" - try: + with authentication_env_guard() as snapshot: + clear_auth_token_sources(remove_default=True) + reset_auth_token_state() yield - finally: - for var, value in original_env.items(): - if value is None: - os.environ.pop(var, None) - else: - os.environ[var] = value - - if token_file_exists: - default_token_path.parent.mkdir(parents=True, exist_ok=True) - default_token_path.write_text(token_file_contents) - else: - default_token_path.unlink(missing_ok=True) - - AuthenticationTokenLoader.instance().reset_cache() - Config.initialize("") + reset_auth_token_state() @pytest.fixture @@ -93,10 +77,9 @@ def 
setup_cluster_with_token_auth(cleanup_auth_token_env): """Spin up a Ray cluster with token authentication enabled.""" test_token = "test_token_12345678901234567890123456789012" - os.environ["RAY_auth_mode"] = "token" - os.environ["RAY_AUTH_TOKEN"] = test_token - Config.initialize("") - AuthenticationTokenLoader.instance().reset_cache() + set_auth_mode("token") + set_env_auth_token(test_token) + reset_auth_token_state() cluster = Cluster() cluster.add_node() @@ -118,9 +101,9 @@ def setup_cluster_with_token_auth(cleanup_auth_token_env): def setup_cluster_without_token_auth(cleanup_auth_token_env): """Spin up a Ray cluster with authentication disabled.""" - os.environ["RAY_auth_mode"] = "disabled" - Config.initialize("") - AuthenticationTokenLoader.instance().reset_cache() + set_auth_mode("disabled") + clear_auth_token_sources(remove_default=True) + reset_auth_token_state() cluster = Cluster() cluster.add_node() diff --git a/python/ray/tests/test_dashboard_auth.py b/python/ray/tests/test_dashboard_auth.py index 5f1615573f0c..635999945ea4 100644 --- a/python/ray/tests/test_dashboard_auth.py +++ b/python/ray/tests/test_dashboard_auth.py @@ -1,115 +1,63 @@ """Tests for dashboard token authentication.""" -import os - -import pytest import requests -import ray -from ray._raylet import Config -from ray.cluster_utils import Cluster - -def test_dashboard_request_requires_auth_with_valid_token(cleanup_auth_token_env): +def test_dashboard_request_requires_auth_with_valid_token( + setup_cluster_with_token_auth, +): """Test that requests succeed with valid token when auth is enabled.""" - test_token = "test_token_12345678901234567890123456789012" - os.environ["RAY_auth_mode"] = "token" - os.environ["RAY_AUTH_TOKEN"] = test_token - Config.initialize("") - cluster = Cluster() - cluster.add_node() - - try: - context = ray.init(address=cluster.address) - dashboard_url = context.address_info["webui_url"] - - # Request with valid auth should succeed - headers = {"Authorization": 
f"Bearer {test_token}"} - response = requests.get( - f"http://{dashboard_url}/api/component_activities", - headers=headers, - ) - assert response.status_code == 200 - - finally: - ray.shutdown() - cluster.shutdown() - - -def test_dashboard_request_requires_auth_missing_token(cleanup_auth_token_env): + + cluster_info = setup_cluster_with_token_auth + headers = {"Authorization": f"Bearer {cluster_info['token']}"} + + response = requests.get( + f"{cluster_info['dashboard_url']}/api/component_activities", + headers=headers, + ) + + assert response.status_code == 200 + + +def test_dashboard_request_requires_auth_missing_token(setup_cluster_with_token_auth): """Test that requests fail without token when auth is enabled.""" - test_token = "test_token_12345678901234567890123456789012" - os.environ["RAY_auth_mode"] = "token" - os.environ["RAY_AUTH_TOKEN"] = test_token - Config.initialize("") - cluster = Cluster() - cluster.add_node() - - try: - context = ray.init(address=cluster.address) - dashboard_url = context.address_info["webui_url"] - - # GET without auth should fail with 401 - response = requests.get( - f"http://{dashboard_url}/api/component_activities", - json={"test": "data"}, - ) - assert response.status_code == 401 - - finally: - ray.shutdown() - cluster.shutdown() - - -def test_dashboard_request_requires_auth_invalid_token(cleanup_auth_token_env): + + cluster_info = setup_cluster_with_token_auth + + response = requests.get( + f"{cluster_info['dashboard_url']}/api/component_activities", + json={"test": "data"}, + ) + + assert response.status_code == 401 + + +def test_dashboard_request_requires_auth_invalid_token(setup_cluster_with_token_auth): """Test that requests fail with invalid token when auth is enabled.""" - correct_token = "test_token_12345678901234567890123456789012" - wrong_token = "wrong_token_00000000000000000000000000000000" - os.environ["RAY_auth_mode"] = "token" - os.environ["RAY_AUTH_TOKEN"] = correct_token - Config.initialize("") - cluster = 
Cluster() - cluster.add_node() - - try: - context = ray.init(address=cluster.address) - dashboard_url = context.address_info["webui_url"] - - # Request with wrong token should fail with 403 - headers = {"Authorization": f"Bearer {wrong_token}"} - response = requests.get( - f"http://{dashboard_url}/api/component_activities", - json={"test": "data"}, - headers=headers, - ) - assert response.status_code == 403 - - finally: - ray.shutdown() - cluster.shutdown() - - -def test_dashboard_auth_disabled(cleanup_auth_token_env): - """Test that auth is not enforced when auth_mode is disabled.""" - os.environ["RAY_auth_mode"] = "disabled" - cluster = Cluster() - cluster.add_node() + cluster_info = setup_cluster_with_token_auth + headers = {"Authorization": "Bearer wrong_token_00000000000000000000000000000000"} + + response = requests.get( + f"{cluster_info['dashboard_url']}/api/component_activities", + json={"test": "data"}, + headers=headers, + ) + + assert response.status_code == 403 + + +def test_dashboard_auth_disabled(setup_cluster_without_token_auth): + """Test that auth is not enforced when auth_mode is disabled.""" - try: - context = ray.init(address=cluster.address) - dashboard_url = context.address_info["webui_url"] + cluster_info = setup_cluster_without_token_auth - # GET without auth should succeed when auth is disabled - response = requests.get( - f"http://{dashboard_url}/api/component_activities", json={"test": "data"} - ) - # Should not return 401 or 403 - assert response.status_code == 200 + response = requests.get( + f"{cluster_info['dashboard_url']}/api/component_activities", + json={"test": "data"}, + ) - finally: - ray.shutdown() - cluster.shutdown() + assert response.status_code == 200 if __name__ == "__main__": diff --git a/python/ray/tests/test_submission_client_auth.py b/python/ray/tests/test_submission_client_auth.py index 2da02a85f065..c617d11e327e 100644 --- a/python/ray/tests/test_submission_client_auth.py +++ 
b/python/ray/tests/test_submission_client_auth.py @@ -1,6 +1,6 @@ -import os import tempfile from pathlib import Path +from typing import Optional import pytest @@ -9,9 +9,16 @@ HTTP_REQUEST_INVALID_TOKEN_ERROR_MESSAGE, HTTP_REQUEST_MISSING_TOKEN_ERROR_MESSAGE, ) -from ray._raylet import AuthenticationTokenLoader, Config from ray.cluster_utils import Cluster from ray.dashboard.modules.job.sdk import JobSubmissionClient +from ray.tests.authentication_test_utils import ( + clear_auth_token_sources, + reset_auth_token_state, + set_auth_mode, + set_auth_token_path, + set_default_auth_token, + set_env_auth_token, +) from ray.util.state import StateApiClient @@ -32,10 +39,9 @@ def test_submission_client_without_token_shows_helpful_error( ): """Test that requests without token show helpful error message.""" # Remove token from environment - os.environ.pop("RAY_AUTH_TOKEN", None) - os.environ["RAY_auth_mode"] = "disabled" - Config.initialize("") - AuthenticationTokenLoader.instance().reset_cache() + clear_auth_token_sources(remove_default=True) + set_auth_mode("disabled") + reset_auth_token_state() from ray.dashboard.modules.dashboard_sdk import SubmissionClient @@ -58,8 +64,9 @@ def test_submission_client_with_invalid_token_shows_helpful_error( """Test that requests with wrong token show helpful error message.""" # Set wrong token wrong_token = "wrong_token_00000000000000000000000000000000" - os.environ["RAY_AUTH_TOKEN"] = wrong_token - AuthenticationTokenLoader.instance().reset_cache() + set_env_auth_token(wrong_token) + set_auth_mode("token") + reset_auth_token_state() from ray.dashboard.modules.dashboard_sdk import SubmissionClient @@ -127,10 +134,9 @@ def test_user_provided_header_not_overridden(setup_cluster_with_token_auth): def test_error_messages_contain_instructions(setup_cluster_with_token_auth): """Test that all auth error messages contain setup instructions.""" # Test 401 error (missing token) - os.environ.pop("RAY_AUTH_TOKEN", None) - 
os.environ["RAY_auth_mode"] = "disabled" - Config.initialize("") - AuthenticationTokenLoader.instance().reset_cache() + clear_auth_token_sources(remove_default=True) + set_auth_mode("disabled") + reset_auth_token_state() from ray.dashboard.modules.dashboard_sdk import SubmissionClient @@ -146,10 +152,9 @@ def test_error_messages_contain_instructions(setup_cluster_with_token_auth): assert str(exc_info.value) == expected_missing # Test 403 error (invalid token) - os.environ["RAY_AUTH_TOKEN"] = "wrong_token_00000000000000000000000000000000" - os.environ["RAY_auth_mode"] = "token" - Config.initialize("") - AuthenticationTokenLoader.instance().reset_cache() + set_env_auth_token("wrong_token_00000000000000000000000000000000") + set_auth_mode("token") + reset_auth_token_state() client2 = SubmissionClient(address=setup_cluster_with_token_auth["dashboard_url"]) @@ -168,24 +173,20 @@ def test_token_loaded_from_sources(cleanup_auth_token_env, token_source): """Test that SubmissionClient loads tokens from all supported sources.""" test_token = "test_token_12345678901234567890123456789012" - os.environ["RAY_auth_mode"] = "token" + set_auth_mode("token") - token_file_path = None - default_token_path = Path.home() / ".ray" / "auth_token" + token_file_path: Optional[Path] = None if token_source == "env_var": - os.environ["RAY_AUTH_TOKEN"] = test_token + set_env_auth_token(test_token) elif token_source == "token_path": - with tempfile.NamedTemporaryFile(mode="w", delete=False) as tmp: - tmp.write(test_token) - token_file_path = tmp.name - os.environ["RAY_AUTH_TOKEN_PATH"] = token_file_path + with tempfile.NamedTemporaryFile(delete=False) as tmp: + token_file_path = Path(tmp.name) + set_auth_token_path(test_token, token_file_path) else: - default_token_path.parent.mkdir(parents=True, exist_ok=True) - default_token_path.write_text(test_token) + set_default_auth_token(test_token) - Config.initialize("") - AuthenticationTokenLoader.instance().reset_cache() + reset_auth_token_state() 
cluster = Cluster() cluster.add_node() @@ -201,8 +202,8 @@ def test_token_loaded_from_sources(cleanup_auth_token_env, token_source): finally: ray.shutdown() cluster.shutdown() - if token_file_path: - os.unlink(token_file_path) + if token_source == "token_path" and token_file_path: + token_file_path.unlink(missing_ok=True) def test_no_token_added_when_auth_disabled(setup_cluster_without_token_auth): diff --git a/python/ray/tests/test_token_auth_integration.py b/python/ray/tests/test_token_auth_integration.py index eea727927477..0168ab958b83 100644 --- a/python/ray/tests/test_token_auth_integration.py +++ b/python/ray/tests/test_token_auth_integration.py @@ -1,7 +1,6 @@ """Integration tests for token-based authentication in Ray.""" import os -import shutil import subprocess import sys from pathlib import Path @@ -11,12 +10,14 @@ import ray from ray._private.test_utils import wait_for_condition -from ray._raylet import AuthenticationTokenLoader, Config +from ray._raylet import AuthenticationTokenLoader from ray.cluster_utils import Cluster - - -def reset_token_cache(): - AuthenticationTokenLoader.instance().reset_cache() +from ray.tests.authentication_test_utils import ( + clear_auth_token_sources, + reset_auth_token_state, + set_auth_mode, + set_env_auth_token, +) def _run_ray_start_and_verify_status( @@ -85,67 +86,16 @@ def ray_stopped(): @pytest.fixture(autouse=True) def clean_token_sources(cleanup_auth_token_env): - """Clean up all token sources before and after each test.""" - # This follows the same pattern as authentication_token_loader_test.cc - if "HOME" not in os.environ: - # Use TEST_TMPDIR if available (Bazel sets this), otherwise use system temp - test_tmpdir = os.environ.get("TEST_TMPDIR") - if test_tmpdir: - temp_home = os.path.join(test_tmpdir, "ray_test_home") - else: - temp_home = "/tmp/ray_test_home" - - # Create the directory if it doesn't exist - os.makedirs(temp_home, exist_ok=True) - os.environ["HOME"] = temp_home - home_was_set = False - else: 
- temp_home = None - home_was_set = True - - # Clean environment variables - env_vars_to_clean = [ - "RAY_AUTH_TOKEN", - "RAY_AUTH_TOKEN_PATH", - "RAY_auth_mode", - ] - original_values = {} - for var in env_vars_to_clean: - original_values[var] = os.environ.get(var) - if var in os.environ: - del os.environ[var] - - # Clean default token file - default_token_path = Path.home() / ".ray" / "auth_token" - original_exists = default_token_path.exists() - original_content = None - if original_exists: - original_content = default_token_path.read_text() - default_token_path.unlink() - - Config.initialize("") + """Ensure authentication-related state is clean around each test.""" - # Reset token caches (both Python and C++) - reset_token_cache() + clear_auth_token_sources(remove_default=True) + reset_auth_token_state() yield - # Restore environment variables - for var, value in original_values.items(): - if value is not None: - os.environ[var] = value - elif var in os.environ: - del os.environ[var] - - # Restore default token file - if original_exists and original_content is not None: - default_token_path.parent.mkdir(parents=True, exist_ok=True) - default_token_path.write_text(original_content) - if ray.is_initialized(): ray.shutdown() - # Ensure all ray processes are stopped subprocess.run( ["ray", "stop", "--force"], capture_output=True, @@ -153,21 +103,7 @@ def clean_token_sources(cleanup_auth_token_env): check=False, ) - # Reset token caches again after test - reset_token_cache() - Config.initialize("") - - # Clean up temporary HOME if we created one - # Only delete if we set it and it was temporary - if temp_home is not None and not home_was_set: - try: - if os.path.exists(temp_home): - shutil.rmtree(temp_home) - except Exception: - pass # Best effort cleanup - # Remove the HOME env var we set - if "HOME" in os.environ and os.environ["HOME"] == temp_home: - del os.environ["HOME"] + reset_auth_token_state() def test_local_cluster_generates_token(): @@ -179,8 +115,8 @@ 
def test_local_cluster_generates_token(): ), f"Token file already exists at {default_token_path}" # Enable token auth via environment variable - os.environ["RAY_auth_mode"] = "token" - Config.initialize("") + set_auth_mode("token") + reset_auth_token_state() # Initialize Ray with token auth ray.init() @@ -207,9 +143,9 @@ def test_connect_without_token_raises_error(): """Test ray.init(address=...) without token fails when auth_mode=token is set.""" # Set up a cluster with token auth enabled cluster_token = "testtoken12345678901234567890" - os.environ["RAY_AUTH_TOKEN"] = cluster_token - os.environ["RAY_auth_mode"] = "token" - Config.initialize("") + set_env_auth_token(cluster_token) + set_auth_mode("token") + reset_auth_token_state() # Create cluster with token auth enabled cluster = Cluster() @@ -217,10 +153,9 @@ def test_connect_without_token_raises_error(): try: # Remove the token from the environment so we try to connect without it - os.environ["RAY_auth_mode"] = "disabled" - os.environ["RAY_AUTH_TOKEN"] = "" - Config.initialize("") - reset_token_cache() + set_auth_mode("disabled") + clear_auth_token_sources(remove_default=True) + reset_auth_token_state() # Ensure no token exists token_loader = AuthenticationTokenLoader.instance() @@ -239,9 +174,9 @@ def test_cluster_token_authentication(tokens_match): """Test cluster authentication with matching and non-matching tokens.""" # Set up cluster token first cluster_token = "a" * 32 - os.environ["RAY_AUTH_TOKEN"] = cluster_token - os.environ["RAY_auth_mode"] = "token" - Config.initialize("") + set_env_auth_token(cluster_token) + set_auth_mode("token") + reset_auth_token_state() # Create cluster with token auth enabled - node will read current env token cluster = Cluster() @@ -254,10 +189,10 @@ def test_cluster_token_authentication(tokens_match): else: client_token = "b" * 32 # Different token - should fail - os.environ["RAY_AUTH_TOKEN"] = client_token + set_env_auth_token(client_token) # Reset cached token so it reads 
the new environment variable - reset_token_cache() + reset_auth_token_state() if tokens_match: # Should succeed - test gRPC calls work @@ -312,9 +247,9 @@ def test_ray_start_without_token_raises_error(is_head): if not is_head: # Start head node with token cluster_token = "a" * 32 - os.environ["RAY_AUTH_TOKEN"] = cluster_token - os.environ["RAY_auth_mode"] = "token" - Config.initialize("") + set_env_auth_token(cluster_token) + set_auth_mode("token") + reset_auth_token_state() cluster = Cluster() cluster.add_node() @@ -348,10 +283,9 @@ def test_ray_start_head_with_token_succeeds(): ) # Verify we can connect to the cluster with ray.init() - os.environ["RAY_AUTH_TOKEN"] = test_token - os.environ["RAY_auth_mode"] = "token" - Config.initialize("") - reset_token_cache() + set_env_auth_token(test_token) + set_auth_mode("token") + reset_auth_token_state() # Wait for cluster to be ready def cluster_ready(): @@ -382,9 +316,9 @@ def test_ray_start_address_with_token(token_match): """Test ray start --address=... 
with correct or incorrect token.""" # Start a head node with token auth cluster_token = "a" * 32 - os.environ["RAY_AUTH_TOKEN"] = cluster_token - os.environ["RAY_auth_mode"] = "token" - Config.initialize("") + set_env_auth_token(cluster_token) + set_auth_mode("token") + reset_auth_token_state() cluster = Cluster() cluster.add_node(num_cpus=1) From 9e0e64f20e4ee993e07b4f986b16d1053e75684b Mon Sep 17 00:00:00 2001 From: sampan Date: Thu, 30 Oct 2025 05:42:08 +0000 Subject: [PATCH 65/94] fix lint Signed-off-by: sampan --- python/ray/tests/conftest.py | 2 +- python/ray/tests/test_dashboard_auth.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/python/ray/tests/conftest.py b/python/ray/tests/conftest.py index 8bbd0ec04717..ca7c45fd1dbf 100644 --- a/python/ray/tests/conftest.py +++ b/python/ray/tests/conftest.py @@ -65,7 +65,7 @@ def cleanup_auth_token_env(): """Reset authentication environment variables, files, and caches.""" - with authentication_env_guard() as snapshot: + with authentication_env_guard(): clear_auth_token_sources(remove_default=True) reset_auth_token_state() yield diff --git a/python/ray/tests/test_dashboard_auth.py b/python/ray/tests/test_dashboard_auth.py index 635999945ea4..7407fc199a1d 100644 --- a/python/ray/tests/test_dashboard_auth.py +++ b/python/ray/tests/test_dashboard_auth.py @@ -1,5 +1,8 @@ """Tests for dashboard token authentication.""" +import sys + +import pytest import requests @@ -61,6 +64,5 @@ def test_dashboard_auth_disabled(setup_cluster_without_token_auth): if __name__ == "__main__": - import sys sys.exit(pytest.main(["-vv", __file__])) From bc4b5637fa7f8130e0e636bc764e274b2e817cb7 Mon Sep 17 00:00:00 2001 From: sampan Date: Thu, 30 Oct 2025 06:28:26 +0000 Subject: [PATCH 66/94] Add token auth support for runtime env agent Signed-off-by: sampan --- .../http_token_authentication.py | 92 +++++++++++ python/ray/_private/runtime_env/agent/main.py | 7 +- python/ray/dashboard/http_server_head.py | 32 +--- 
python/ray/dashboard/modules/dashboard_sdk.py | 38 ++--- .../ray/includes/rpc_token_authentication.pxi | 2 +- python/ray/tests/BUILD.bazel | 2 + python/ray/tests/authentication_test_utils.py | 151 ++++++++++++++++++ python/ray/tests/conftest.py | 52 +++--- python/ray/tests/test_dashboard_auth.py | 146 ++++++----------- .../ray/tests/test_runtime_env_agent_auth.py | 126 +++++++++++++++ .../ray/tests/test_submission_client_auth.py | 61 +++---- .../ray/tests/test_token_auth_integration.py | 132 ++++----------- python/ray/util/client/server/proxier.py | 27 +++- src/ray/raylet/runtime_env_agent_client.cc | 6 + .../tests/runtime_env_agent_client_test.cc | 76 +++++++++ 15 files changed, 630 insertions(+), 320 deletions(-) create mode 100644 python/ray/_private/authentication/http_token_authentication.py create mode 100644 python/ray/tests/authentication_test_utils.py create mode 100644 python/ray/tests/test_runtime_env_agent_auth.py diff --git a/python/ray/_private/authentication/http_token_authentication.py b/python/ray/_private/authentication/http_token_authentication.py new file mode 100644 index 000000000000..9813520b4521 --- /dev/null +++ b/python/ray/_private/authentication/http_token_authentication.py @@ -0,0 +1,92 @@ +import logging +from typing import Dict, Optional + +from aiohttp import web + +from ray._private.authentication import authentication_constants + + +def create_token_authentication_middleware() -> web.middleware: + """Return an aiohttp middleware that validates bearer tokens when enabled.""" + + from ray.dashboard import authentication_utils as auth_utils + + @web.middleware + async def auth_middleware(request: web.Request, handler): + if not auth_utils.is_token_auth_enabled(): + return await handler(request) + + auth_header = request.headers.get("Authorization", "") + if not auth_header: + return web.Response( + status=401, text="Unauthorized: Missing authentication token" + ) + + if not auth_utils.validate_request_token(auth_header): + return 
web.Response( + status=403, text="Forbidden: Invalid authentication token" + ) + + return await handler(request) + + return auth_middleware + + +def apply_token_if_enabled( + headers: Dict[str, str], logger: Optional[logging.Logger] = None +) -> bool: + """Inject Authorization header when token auth is enabled. + + Args: + headers: Mutable mapping of HTTP headers. Updated in place. + logger: Optional logger used for warning when token is missing. + + Returns: + bool: True if the token was added to headers, False otherwise. + """ + + if headers is None: + raise ValueError("headers must be provided") + + if "Authorization" in headers: + return False + + from ray.dashboard import authentication_utils as auth_utils + from ray._raylet import AuthenticationTokenLoader + + if not auth_utils.is_token_auth_enabled(): + return False + + token_loader = AuthenticationTokenLoader.instance() + token_added = token_loader.set_token_for_http_header(headers) + + if not token_added and logger is not None: + logger.warning( + "Token authentication is enabled but no token was found. " + "Requests to authenticated clusters will fail." 
+ ) + + return token_added + + +def format_authentication_http_error(status: int, body: str) -> Optional[str]: + """Return a user-friendly authentication error message, if applicable.""" + + if status == 401: + return ( + "Authentication required: {body}\n\n{details}".format( + body=body, + details=authentication_constants.HTTP_REQUEST_MISSING_TOKEN_ERROR_MESSAGE, + ) + ) + + if status == 403: + return ( + "Authentication failed: {body}\n\n{details}".format( + body=body, + details=authentication_constants.HTTP_REQUEST_INVALID_TOKEN_ERROR_MESSAGE, + ) + ) + + return None + diff --git a/python/ray/_private/runtime_env/agent/main.py b/python/ray/_private/runtime_env/agent/main.py index e65de4d63bd4..1049e7a933ad 100644 --- a/python/ray/_private/runtime_env/agent/main.py +++ b/python/ray/_private/runtime_env/agent/main.py @@ -25,6 +25,9 @@ def import_libs(): import runtime_env_consts # noqa: E402 from aiohttp import web # noqa: E402 +from ray._private.authentication.http_token_authentication import ( + create_token_authentication_middleware, +) from runtime_env_agent import RuntimeEnvAgent # noqa: E402 if __name__ == "__main__": @@ -194,7 +197,9 @@ async def get_runtime_envs_info(request: web.Request) -> web.Response: body=reply.SerializeToString(), content_type="application/octet-stream" ) - app = web.Application() + app = web.Application( + middlewares=[create_token_authentication_middleware()] + ) app.router.add_post("/get_or_create_runtime_env", get_or_create_runtime_env) app.router.add_post( diff --git a/python/ray/dashboard/http_server_head.py b/python/ray/dashboard/http_server_head.py index 5f7054900c18..b8123f23bc85 100644 --- a/python/ray/dashboard/http_server_head.py +++ b/python/ray/dashboard/http_server_head.py @@ -20,9 +20,11 @@ from ray._common.network_utils import build_address, parse_address from ray._common.usage.usage_lib import TagKey, record_extra_usage_tag from ray._common.utils import get_or_create_event_loop -from ray.dashboard import 
authentication_utils as auth_utils from ray.dashboard.dashboard_metrics import DashboardPrometheusMetrics from ray.dashboard.head import DashboardHeadModule +from ray._private.authentication.http_token_authentication import ( + create_token_authentication_middleware, +) # All third-party dependencies that are not included in the minimal Ray # installation must be included in this file. This allows us to determine if @@ -164,30 +166,6 @@ def get_address(self): assert self.http_host and self.http_port return self.http_host, self.http_port - @aiohttp.web.middleware - async def auth_middleware(self, request, handler): - """Authenticate requests when token auth is enabled.""" - - # Skip if auth not enabled - if not auth_utils.is_token_auth_enabled(): - return await handler(request) - - # Extract and validate token - auth_header = request.headers.get("Authorization", "") - - if not auth_header: - return aiohttp.web.Response( - status=401, text="Unauthorized: Missing authentication token" - ) - - # Validate token - if not auth_utils.validate_request_token(auth_header): - return aiohttp.web.Response( - status=403, text="Forbidden: Invalid authentication token" - ) - - return await handler(request) - @aiohttp.web.middleware async def path_clean_middleware(self, request, handler): if request.path.startswith("/static") or request.path.startswith("/logs"): @@ -273,11 +251,13 @@ async def run( # Http server should be initialized after all modules loaded. # working_dir uploads for job submission can be up to 100MiB. 
+ token_auth_middleware = create_token_authentication_middleware() + app = aiohttp.web.Application( client_max_size=ray_constants.DASHBOARD_CLIENT_MAX_SIZE, middlewares=[ self.metrics_middleware, - self.auth_middleware, + token_auth_middleware, self.path_clean_middleware, self.browsers_no_post_put_middleware, self.cache_control_static_middleware, diff --git a/python/ray/dashboard/modules/dashboard_sdk.py b/python/ray/dashboard/modules/dashboard_sdk.py index 2cdf738590a2..d45a9259f19b 100644 --- a/python/ray/dashboard/modules/dashboard_sdk.py +++ b/python/ray/dashboard/modules/dashboard_sdk.py @@ -12,7 +12,10 @@ import yaml import ray -from ray._private.authentication import authentication_constants +from ray._private.authentication.http_token_authentication import ( + apply_token_if_enabled, + format_authentication_http_error, +) from ray._private.runtime_env.packaging import ( create_package, get_uri_for_directory, @@ -21,9 +24,7 @@ from ray._private.runtime_env.py_modules import upload_py_modules_if_needed from ray._private.runtime_env.working_dir import upload_working_dir_if_needed from ray._private.utils import split_address -from ray._raylet import AuthenticationTokenLoader from ray.autoscaler._private.cli_logger import cli_logger -from ray.dashboard.authentication_utils import is_token_auth_enabled from ray.dashboard.modules.job.common import uri_to_http_components from ray.util.annotations import DeveloperAPI, PublicAPI @@ -251,19 +252,7 @@ def __init__( def _set_auth_header_if_enabled(self): """Add authentication token to headers if token auth is enabled.""" - if is_token_auth_enabled(): - token_loader = AuthenticationTokenLoader.instance() - token_added = token_loader.set_token_for_http_header(self._headers) - - if not token_added: - # Token auth is enabled but no token found or Authorization already set - if "Authorization" not in self._headers: - # No token found - log warning but don't fail yet - # Let the server return 401 for a better error message - 
logger.warning( - "Token authentication is enabled but no token was found. " - "Requests to authenticated clusters will fail." - ) + apply_token_if_enabled(self._headers, logger) def _check_connection_and_version( self, min_version: str = "1.9", version_error_message: str = None @@ -336,18 +325,11 @@ def _do_request( ) # Check for authentication errors and provide helpful messages - if response.status_code == 401: - # Unauthorized - missing or no token provided - raise RuntimeError( - f"Authentication required: {response.text}\n\n" - + authentication_constants.HTTP_REQUEST_MISSING_TOKEN_ERROR_MESSAGE - ) - elif response.status_code == 403: - # Forbidden - invalid token - raise RuntimeError( - f"Authentication failed: {response.text}\n\n" - + authentication_constants.HTTP_REQUEST_INVALID_TOKEN_ERROR_MESSAGE - ) + formatted_error = format_authentication_http_error( + response.status_code, response.text + ) + if formatted_error: + raise RuntimeError(formatted_error) return response diff --git a/python/ray/includes/rpc_token_authentication.pxi b/python/ray/includes/rpc_token_authentication.pxi index 3c030680baab..beda93da86cd 100644 --- a/python/ray/includes/rpc_token_authentication.pxi +++ b/python/ray/includes/rpc_token_authentication.pxi @@ -97,7 +97,7 @@ class AuthenticationTokenLoader: # Get the token from C++ layer cdef optional[CAuthenticationToken] token_opt = CAuthenticationTokenLoader.instance().GetToken() - if not token_opt.has_value() || token_opt.value().empty(): + if not token_opt.has_value() or token_opt.value().empty(): return False headers["Authorization"] = token_opt.value().ToAuthorizationHeaderValue() diff --git a/python/ray/tests/BUILD.bazel b/python/ray/tests/BUILD.bazel index 3b11301a5318..e2f9ccddc3d9 100644 --- a/python/ray/tests/BUILD.bazel +++ b/python/ray/tests/BUILD.bazel @@ -68,6 +68,7 @@ py_test_module_list( "test_reference_counting_2.py", "test_reference_counting_standalone.py", "test_runtime_env_agent.py", + 
"test_runtime_env_agent_auth.py", "test_util_helpers.py", ], tags = [ @@ -589,6 +590,7 @@ py_test_module_list( "test_wait.py", "test_widgets.py", "test_worker_graceful_shutdown.py", + "test_submission_client_auth.py" ], tags = [ "exclusive", diff --git a/python/ray/tests/authentication_test_utils.py b/python/ray/tests/authentication_test_utils.py new file mode 100644 index 000000000000..63bc2ea519c6 --- /dev/null +++ b/python/ray/tests/authentication_test_utils.py @@ -0,0 +1,151 @@ +import os +import shutil +import tempfile +from contextlib import contextmanager +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, Optional + +from ray._raylet import AuthenticationTokenLoader, Config + + +AUTH_ENV_VARS = ("RAY_auth_mode", "RAY_AUTH_TOKEN", "RAY_AUTH_TOKEN_PATH") +DEFAULT_AUTH_TOKEN_RELATIVE_PATH = Path(".ray") / "auth_token" + + +def reset_auth_token_state() -> None: + """Reset authentication state in both Python and C++ layers.""" + + AuthenticationTokenLoader.instance().reset_cache() + Config.initialize("") + + +def set_auth_mode(mode: str) -> None: + """Set the authentication mode environment variable.""" + + os.environ["RAY_auth_mode"] = mode + + +def set_env_auth_token(token: str) -> None: + """Configure the authentication token via environment variable.""" + + os.environ["RAY_AUTH_TOKEN"] = token + os.environ.pop("RAY_AUTH_TOKEN_PATH", None) + + +def set_auth_token_path(token: str, path: Path) -> None: + """Write the authentication token to a specific path and point the loader to it.""" + + token_path = Path(path) + token_path.parent.mkdir(parents=True, exist_ok=True) + token_path.write_text(token) + os.environ["RAY_AUTH_TOKEN_PATH"] = str(token_path) + os.environ.pop("RAY_AUTH_TOKEN", None) + + +def set_default_auth_token(token: str) -> Path: + """Write the authentication token to the default ~/.ray/auth_token location.""" + + default_path = Path.home() / DEFAULT_AUTH_TOKEN_RELATIVE_PATH + 
default_path.parent.mkdir(parents=True, exist_ok=True) + default_path.write_text(token) + return default_path + + +def clear_auth_token_sources(remove_default: bool = False) -> None: + """Clear authentication-related environment variables and optional default token file.""" + + for var in ("RAY_AUTH_TOKEN", "RAY_AUTH_TOKEN_PATH"): + os.environ.pop(var, None) + + if remove_default: + default_path = Path.home() / DEFAULT_AUTH_TOKEN_RELATIVE_PATH + default_path.unlink(missing_ok=True) + + +@dataclass +class AuthenticationEnvSnapshot: + original_env: Dict[str, Optional[str]] + original_home: Optional[str] + home_was_set: bool + temp_home: Optional[Path] + default_token_path: Path + default_token_exists: bool + default_token_contents: Optional[str] + + @classmethod + def capture(cls) -> "AuthenticationEnvSnapshot": + """Capture current authentication-related environment state.""" + + original_env = {var: os.environ.get(var) for var in AUTH_ENV_VARS} + home_was_set = "HOME" in os.environ + original_home = os.environ.get("HOME") + temp_home: Optional[Path] = None + + if not home_was_set: + test_tmpdir = os.environ.get("TEST_TMPDIR") + base_dir = Path(test_tmpdir) if test_tmpdir else Path(tempfile.gettempdir()) + temp_home = base_dir / "ray_test_home" + temp_home.mkdir(parents=True, exist_ok=True) + os.environ["HOME"] = str(temp_home) + + default_token_path = Path.home() / DEFAULT_AUTH_TOKEN_RELATIVE_PATH + default_token_exists = default_token_path.exists() + default_token_contents = ( + default_token_path.read_text() if default_token_exists else None + ) + + return cls( + original_env=original_env, + original_home=original_home, + home_was_set=home_was_set, + temp_home=temp_home, + default_token_path=default_token_path, + default_token_exists=default_token_exists, + default_token_contents=default_token_contents, + ) + + def clear_default_token(self) -> None: + """Remove the default token file for the current HOME.""" + + self.default_token_path.unlink(missing_ok=True) + + 
def restore(self) -> None: + """Restore the captured environment, HOME, and default token file state.""" + + for var, value in self.original_env.items(): + if value is None: + os.environ.pop(var, None) + else: + os.environ[var] = value + + if self.home_was_set: + if self.original_home is None: + os.environ.pop("HOME", None) + else: + os.environ["HOME"] = self.original_home + + if self.default_token_exists: + self.default_token_path.parent.mkdir(parents=True, exist_ok=True) + self.default_token_path.write_text(self.default_token_contents or "") + else: + self.default_token_path.unlink(missing_ok=True) + + if not self.home_was_set: + current_home = os.environ.get("HOME") + if self.temp_home is not None and current_home == str(self.temp_home): + os.environ.pop("HOME", None) + if self.temp_home is not None and self.temp_home.exists(): + shutil.rmtree(self.temp_home, ignore_errors=True) + + +@contextmanager +def authentication_env_guard(): + """Context manager that restores authentication environment state on exit.""" + + snapshot = AuthenticationEnvSnapshot.capture() + try: + yield snapshot + finally: + snapshot.restore() + diff --git a/python/ray/tests/conftest.py b/python/ray/tests/conftest.py index 89cadc96ee73..f90018692a1e 100644 --- a/python/ray/tests/conftest.py +++ b/python/ray/tests/conftest.py @@ -44,6 +44,13 @@ ) from ray._raylet import AuthenticationTokenLoader, Config from ray.cluster_utils import AutoscalingCluster, Cluster, cluster_not_supported +from ray.tests.authentication_test_utils import ( + authentication_env_guard, + clear_auth_token_sources, + reset_auth_token_state, + set_auth_mode, + set_env_auth_token, +) import psutil @@ -57,35 +64,13 @@ @pytest.fixture def cleanup_auth_token_env(): - """Reset Ray authentication-related environment variables and caches.""" - - env_vars = ["RAY_auth_mode", "RAY_AUTH_TOKEN", "RAY_AUTH_TOKEN_PATH"] - original_env = {var: os.environ.get(var) for var in env_vars} - - default_token_path = Path.home() / ".ray" / 
"auth_token" - token_file_exists = default_token_path.exists() - token_file_contents = default_token_path.read_text() if token_file_exists else None - - AuthenticationTokenLoader.instance().reset_cache() - Config.initialize("") + """Reset authentication environment variables, files, and caches.""" - try: + with authentication_env_guard() as snapshot: + clear_auth_token_sources(remove_default=True) + reset_auth_token_state() yield - finally: - for var, value in original_env.items(): - if value is None: - os.environ.pop(var, None) - else: - os.environ[var] = value - - if token_file_exists: - default_token_path.parent.mkdir(parents=True, exist_ok=True) - default_token_path.write_text(token_file_contents) - else: - default_token_path.unlink(missing_ok=True) - - AuthenticationTokenLoader.instance().reset_cache() - Config.initialize("") + reset_auth_token_state() @pytest.fixture @@ -93,10 +78,9 @@ def setup_cluster_with_token_auth(cleanup_auth_token_env): """Spin up a Ray cluster with token authentication enabled.""" test_token = "test_token_12345678901234567890123456789012" - os.environ["RAY_auth_mode"] = "token" - os.environ["RAY_AUTH_TOKEN"] = test_token - Config.initialize("") - AuthenticationTokenLoader.instance().reset_cache() + set_auth_mode("token") + set_env_auth_token(test_token) + reset_auth_token_state() cluster = Cluster() cluster.add_node() @@ -118,9 +102,9 @@ def setup_cluster_with_token_auth(cleanup_auth_token_env): def setup_cluster_without_token_auth(cleanup_auth_token_env): """Spin up a Ray cluster with authentication disabled.""" - os.environ["RAY_auth_mode"] = "disabled" - Config.initialize("") - AuthenticationTokenLoader.instance().reset_cache() + set_auth_mode("disabled") + clear_auth_token_sources(remove_default=True) + reset_auth_token_state() cluster = Cluster() cluster.add_node() diff --git a/python/ray/tests/test_dashboard_auth.py b/python/ray/tests/test_dashboard_auth.py index 5f1615573f0c..02ac47b213de 100644 --- 
a/python/ray/tests/test_dashboard_auth.py +++ b/python/ray/tests/test_dashboard_auth.py @@ -1,115 +1,61 @@ """Tests for dashboard token authentication.""" -import os - -import pytest import requests -import ray -from ray._raylet import Config -from ray.cluster_utils import Cluster - -def test_dashboard_request_requires_auth_with_valid_token(cleanup_auth_token_env): +def test_dashboard_request_requires_auth_with_valid_token(setup_cluster_with_token_auth): """Test that requests succeed with valid token when auth is enabled.""" - test_token = "test_token_12345678901234567890123456789012" - os.environ["RAY_auth_mode"] = "token" - os.environ["RAY_AUTH_TOKEN"] = test_token - Config.initialize("") - cluster = Cluster() - cluster.add_node() - - try: - context = ray.init(address=cluster.address) - dashboard_url = context.address_info["webui_url"] - - # Request with valid auth should succeed - headers = {"Authorization": f"Bearer {test_token}"} - response = requests.get( - f"http://{dashboard_url}/api/component_activities", - headers=headers, - ) - assert response.status_code == 200 - - finally: - ray.shutdown() - cluster.shutdown() - - -def test_dashboard_request_requires_auth_missing_token(cleanup_auth_token_env): + + cluster_info = setup_cluster_with_token_auth + headers = {"Authorization": f"Bearer {cluster_info['token']}"} + + response = requests.get( + f"{cluster_info['dashboard_url']}/api/component_activities", + headers=headers, + ) + + assert response.status_code == 200 + + +def test_dashboard_request_requires_auth_missing_token(setup_cluster_with_token_auth): """Test that requests fail without token when auth is enabled.""" - test_token = "test_token_12345678901234567890123456789012" - os.environ["RAY_auth_mode"] = "token" - os.environ["RAY_AUTH_TOKEN"] = test_token - Config.initialize("") - cluster = Cluster() - cluster.add_node() - - try: - context = ray.init(address=cluster.address) - dashboard_url = context.address_info["webui_url"] - - # GET without auth 
should fail with 401 - response = requests.get( - f"http://{dashboard_url}/api/component_activities", - json={"test": "data"}, - ) - assert response.status_code == 401 - - finally: - ray.shutdown() - cluster.shutdown() - - -def test_dashboard_request_requires_auth_invalid_token(cleanup_auth_token_env): + + cluster_info = setup_cluster_with_token_auth + + response = requests.get( + f"{cluster_info['dashboard_url']}/api/component_activities", + json={"test": "data"}, + ) + + assert response.status_code == 401 + + +def test_dashboard_request_requires_auth_invalid_token(setup_cluster_with_token_auth): """Test that requests fail with invalid token when auth is enabled.""" - correct_token = "test_token_12345678901234567890123456789012" - wrong_token = "wrong_token_00000000000000000000000000000000" - os.environ["RAY_auth_mode"] = "token" - os.environ["RAY_AUTH_TOKEN"] = correct_token - Config.initialize("") - cluster = Cluster() - cluster.add_node() - - try: - context = ray.init(address=cluster.address) - dashboard_url = context.address_info["webui_url"] - - # Request with wrong token should fail with 403 - headers = {"Authorization": f"Bearer {wrong_token}"} - response = requests.get( - f"http://{dashboard_url}/api/component_activities", - json={"test": "data"}, - headers=headers, - ) - assert response.status_code == 403 - - finally: - ray.shutdown() - cluster.shutdown() - - -def test_dashboard_auth_disabled(cleanup_auth_token_env): - """Test that auth is not enforced when auth_mode is disabled.""" - os.environ["RAY_auth_mode"] = "disabled" - cluster = Cluster() - cluster.add_node() + cluster_info = setup_cluster_with_token_auth + headers = {"Authorization": "Bearer wrong_token_00000000000000000000000000000000"} + + response = requests.get( + f"{cluster_info['dashboard_url']}/api/component_activities", + json={"test": "data"}, + headers=headers, + ) + + assert response.status_code == 403 + + +def test_dashboard_auth_disabled(setup_cluster_without_token_auth): + """Test 
that auth is not enforced when auth_mode is disabled.""" - try: - context = ray.init(address=cluster.address) - dashboard_url = context.address_info["webui_url"] + cluster_info = setup_cluster_without_token_auth - # GET without auth should succeed when auth is disabled - response = requests.get( - f"http://{dashboard_url}/api/component_activities", json={"test": "data"} - ) - # Should not return 401 or 403 - assert response.status_code == 200 + response = requests.get( + f"{cluster_info['dashboard_url']}/api/component_activities", + json={"test": "data"}, + ) - finally: - ray.shutdown() - cluster.shutdown() + assert response.status_code == 200 if __name__ == "__main__": diff --git a/python/ray/tests/test_runtime_env_agent_auth.py b/python/ray/tests/test_runtime_env_agent_auth.py new file mode 100644 index 000000000000..1a84cbdbd3b2 --- /dev/null +++ b/python/ray/tests/test_runtime_env_agent_auth.py @@ -0,0 +1,126 @@ +import logging +import urllib.error +import urllib.parse +import urllib.request + +import pytest + +import ray +from ray._private.authentication.http_token_authentication import ( + apply_token_if_enabled, + format_authentication_http_error, +) +from ray.core.generated import runtime_env_agent_pb2 +from ray.tests.authentication_test_utils import reset_auth_token_state, set_auth_mode, set_env_auth_token + + +def _agent_url(agent_address: str, path: str) -> str: + return urllib.parse.urljoin(agent_address, path) + + +def _make_get_or_create_request() -> runtime_env_agent_pb2.GetOrCreateRuntimeEnvRequest: + request = runtime_env_agent_pb2.GetOrCreateRuntimeEnvRequest() + request.job_id = b"ray_client_test" + request.serialized_runtime_env = "{}" + request.runtime_env_config.setup_timeout_seconds = 1 + request.source_process = "pytest" + return request + + +def test_runtime_env_agent_requires_auth_missing_token(setup_cluster_with_token_auth): + agent_address = ray._private.worker.global_worker.node.runtime_env_agent_address + request = 
_make_get_or_create_request() + + with pytest.raises(urllib.error.HTTPError) as exc_info: + urllib.request.urlopen( # noqa: S310 - test controlled + urllib.request.Request( + _agent_url(agent_address, "/get_or_create_runtime_env"), + data=request.SerializeToString(), + headers={"Content-Type": "application/octet-stream"}, + method="POST", + ), + timeout=5, + ) + + assert exc_info.value.code == 401 + body = exc_info.value.read().decode("utf-8", "ignore") + assert "Missing authentication token" in body + formatted = format_authentication_http_error(401, body) + assert formatted.startswith("Authentication required") + + +def test_runtime_env_agent_rejects_invalid_token(setup_cluster_with_token_auth): + agent_address = ray._private.worker.global_worker.node.runtime_env_agent_address + request = _make_get_or_create_request() + + with pytest.raises(urllib.error.HTTPError) as exc_info: + urllib.request.urlopen( # noqa: S310 - test controlled + urllib.request.Request( + _agent_url(agent_address, "/get_or_create_runtime_env"), + data=request.SerializeToString(), + headers={ + "Content-Type": "application/octet-stream", + "Authorization": "Bearer wrong_token", + }, + method="POST", + ), + timeout=5, + ) + + assert exc_info.value.code == 403 + body = exc_info.value.read().decode("utf-8", "ignore") + assert "Invalid authentication token" in body + formatted = format_authentication_http_error(403, body) + assert formatted.startswith("Authentication failed") + + +def test_runtime_env_agent_accepts_valid_token(setup_cluster_with_token_auth): + agent_address = ray._private.worker.global_worker.node.runtime_env_agent_address + token = setup_cluster_with_token_auth["token"] + request = _make_get_or_create_request() + + with urllib.request.urlopen( # noqa: S310 - test controlled + urllib.request.Request( + _agent_url(agent_address, "/get_or_create_runtime_env"), + data=request.SerializeToString(), + headers={ + "Content-Type": "application/octet-stream", + "Authorization": f"Bearer 
{token}", + }, + method="POST", + ), + timeout=5, + ) as response: + reply = runtime_env_agent_pb2.GetOrCreateRuntimeEnvReply() + reply.ParseFromString(response.read()) + assert reply.status == runtime_env_agent_pb2.AgentRpcStatus.AGENT_RPC_STATUS_OK + + +def test_apply_token_if_enabled_adds_header(cleanup_auth_token_env): + set_auth_mode("token") + set_env_auth_token("apptoken1234567890") + reset_auth_token_state() + + headers = {} + added = apply_token_if_enabled(headers, logging.getLogger(__name__)) + + assert added is True + assert headers["Authorization"] == "Bearer apptoken1234567890" + + +def test_apply_token_if_enabled_respects_existing_header(cleanup_auth_token_env): + set_auth_mode("token") + set_env_auth_token("apptoken1234567890") + reset_auth_token_state() + + headers = {"Authorization": "Bearer custom"} + added = apply_token_if_enabled(headers, logging.getLogger(__name__)) + + assert added is False + assert headers["Authorization"] == "Bearer custom" + + +def test_format_authentication_http_error_non_auth_status(): + assert format_authentication_http_error(404, "not found") is None + + diff --git a/python/ray/tests/test_submission_client_auth.py b/python/ray/tests/test_submission_client_auth.py index 2da02a85f065..a136618ca6b4 100644 --- a/python/ray/tests/test_submission_client_auth.py +++ b/python/ray/tests/test_submission_client_auth.py @@ -1,6 +1,7 @@ import os import tempfile from pathlib import Path +from typing import Optional import pytest @@ -9,10 +10,17 @@ HTTP_REQUEST_INVALID_TOKEN_ERROR_MESSAGE, HTTP_REQUEST_MISSING_TOKEN_ERROR_MESSAGE, ) -from ray._raylet import AuthenticationTokenLoader, Config from ray.cluster_utils import Cluster from ray.dashboard.modules.job.sdk import JobSubmissionClient from ray.util.state import StateApiClient +from ray.tests.authentication_test_utils import ( + clear_auth_token_sources, + reset_auth_token_state, + set_auth_mode, + set_auth_token_path, + set_default_auth_token, + set_env_auth_token, +) def 
test_submission_client_adds_token_automatically(setup_cluster_with_token_auth): @@ -32,10 +40,9 @@ def test_submission_client_without_token_shows_helpful_error( ): """Test that requests without token show helpful error message.""" # Remove token from environment - os.environ.pop("RAY_AUTH_TOKEN", None) - os.environ["RAY_auth_mode"] = "disabled" - Config.initialize("") - AuthenticationTokenLoader.instance().reset_cache() + clear_auth_token_sources(remove_default=True) + set_auth_mode("disabled") + reset_auth_token_state() from ray.dashboard.modules.dashboard_sdk import SubmissionClient @@ -58,8 +65,9 @@ def test_submission_client_with_invalid_token_shows_helpful_error( """Test that requests with wrong token show helpful error message.""" # Set wrong token wrong_token = "wrong_token_00000000000000000000000000000000" - os.environ["RAY_AUTH_TOKEN"] = wrong_token - AuthenticationTokenLoader.instance().reset_cache() + set_env_auth_token(wrong_token) + set_auth_mode("token") + reset_auth_token_state() from ray.dashboard.modules.dashboard_sdk import SubmissionClient @@ -127,10 +135,9 @@ def test_user_provided_header_not_overridden(setup_cluster_with_token_auth): def test_error_messages_contain_instructions(setup_cluster_with_token_auth): """Test that all auth error messages contain setup instructions.""" # Test 401 error (missing token) - os.environ.pop("RAY_AUTH_TOKEN", None) - os.environ["RAY_auth_mode"] = "disabled" - Config.initialize("") - AuthenticationTokenLoader.instance().reset_cache() + clear_auth_token_sources(remove_default=True) + set_auth_mode("disabled") + reset_auth_token_state() from ray.dashboard.modules.dashboard_sdk import SubmissionClient @@ -146,10 +153,9 @@ def test_error_messages_contain_instructions(setup_cluster_with_token_auth): assert str(exc_info.value) == expected_missing # Test 403 error (invalid token) - os.environ["RAY_AUTH_TOKEN"] = "wrong_token_00000000000000000000000000000000" - os.environ["RAY_auth_mode"] = "token" - 
Config.initialize("") - AuthenticationTokenLoader.instance().reset_cache() + set_env_auth_token("wrong_token_00000000000000000000000000000000") + set_auth_mode("token") + reset_auth_token_state() client2 = SubmissionClient(address=setup_cluster_with_token_auth["dashboard_url"]) @@ -168,24 +174,21 @@ def test_token_loaded_from_sources(cleanup_auth_token_env, token_source): """Test that SubmissionClient loads tokens from all supported sources.""" test_token = "test_token_12345678901234567890123456789012" - os.environ["RAY_auth_mode"] = "token" + set_auth_mode("token") - token_file_path = None - default_token_path = Path.home() / ".ray" / "auth_token" + token_file_path: Optional[Path] = None if token_source == "env_var": - os.environ["RAY_AUTH_TOKEN"] = test_token + set_env_auth_token(test_token) elif token_source == "token_path": - with tempfile.NamedTemporaryFile(mode="w", delete=False) as tmp: - tmp.write(test_token) - token_file_path = tmp.name - os.environ["RAY_AUTH_TOKEN_PATH"] = token_file_path + with tempfile.NamedTemporaryFile(delete=False) as tmp: + token_file_path = Path(tmp.name) + set_auth_token_path(test_token, token_file_path) else: - default_token_path.parent.mkdir(parents=True, exist_ok=True) - default_token_path.write_text(test_token) + token_file_path = Path.home() / ".ray" / "auth_token" + set_default_auth_token(test_token) - Config.initialize("") - AuthenticationTokenLoader.instance().reset_cache() + reset_auth_token_state() cluster = Cluster() cluster.add_node() @@ -201,8 +204,8 @@ def test_token_loaded_from_sources(cleanup_auth_token_env, token_source): finally: ray.shutdown() cluster.shutdown() - if token_file_path: - os.unlink(token_file_path) + if token_source == "token_path" and token_file_path: + token_file_path.unlink(missing_ok=True) def test_no_token_added_when_auth_disabled(setup_cluster_without_token_auth): diff --git a/python/ray/tests/test_token_auth_integration.py b/python/ray/tests/test_token_auth_integration.py index 
eea727927477..0168ab958b83 100644 --- a/python/ray/tests/test_token_auth_integration.py +++ b/python/ray/tests/test_token_auth_integration.py @@ -1,7 +1,6 @@ """Integration tests for token-based authentication in Ray.""" import os -import shutil import subprocess import sys from pathlib import Path @@ -11,12 +10,14 @@ import ray from ray._private.test_utils import wait_for_condition -from ray._raylet import AuthenticationTokenLoader, Config +from ray._raylet import AuthenticationTokenLoader from ray.cluster_utils import Cluster - - -def reset_token_cache(): - AuthenticationTokenLoader.instance().reset_cache() +from ray.tests.authentication_test_utils import ( + clear_auth_token_sources, + reset_auth_token_state, + set_auth_mode, + set_env_auth_token, +) def _run_ray_start_and_verify_status( @@ -85,67 +86,16 @@ def ray_stopped(): @pytest.fixture(autouse=True) def clean_token_sources(cleanup_auth_token_env): - """Clean up all token sources before and after each test.""" - # This follows the same pattern as authentication_token_loader_test.cc - if "HOME" not in os.environ: - # Use TEST_TMPDIR if available (Bazel sets this), otherwise use system temp - test_tmpdir = os.environ.get("TEST_TMPDIR") - if test_tmpdir: - temp_home = os.path.join(test_tmpdir, "ray_test_home") - else: - temp_home = "/tmp/ray_test_home" - - # Create the directory if it doesn't exist - os.makedirs(temp_home, exist_ok=True) - os.environ["HOME"] = temp_home - home_was_set = False - else: - temp_home = None - home_was_set = True - - # Clean environment variables - env_vars_to_clean = [ - "RAY_AUTH_TOKEN", - "RAY_AUTH_TOKEN_PATH", - "RAY_auth_mode", - ] - original_values = {} - for var in env_vars_to_clean: - original_values[var] = os.environ.get(var) - if var in os.environ: - del os.environ[var] - - # Clean default token file - default_token_path = Path.home() / ".ray" / "auth_token" - original_exists = default_token_path.exists() - original_content = None - if original_exists: - original_content = 
default_token_path.read_text() - default_token_path.unlink() - - Config.initialize("") + """Ensure authentication-related state is clean around each test.""" - # Reset token caches (both Python and C++) - reset_token_cache() + clear_auth_token_sources(remove_default=True) + reset_auth_token_state() yield - # Restore environment variables - for var, value in original_values.items(): - if value is not None: - os.environ[var] = value - elif var in os.environ: - del os.environ[var] - - # Restore default token file - if original_exists and original_content is not None: - default_token_path.parent.mkdir(parents=True, exist_ok=True) - default_token_path.write_text(original_content) - if ray.is_initialized(): ray.shutdown() - # Ensure all ray processes are stopped subprocess.run( ["ray", "stop", "--force"], capture_output=True, @@ -153,21 +103,7 @@ def clean_token_sources(cleanup_auth_token_env): check=False, ) - # Reset token caches again after test - reset_token_cache() - Config.initialize("") - - # Clean up temporary HOME if we created one - # Only delete if we set it and it was temporary - if temp_home is not None and not home_was_set: - try: - if os.path.exists(temp_home): - shutil.rmtree(temp_home) - except Exception: - pass # Best effort cleanup - # Remove the HOME env var we set - if "HOME" in os.environ and os.environ["HOME"] == temp_home: - del os.environ["HOME"] + reset_auth_token_state() def test_local_cluster_generates_token(): @@ -179,8 +115,8 @@ def test_local_cluster_generates_token(): ), f"Token file already exists at {default_token_path}" # Enable token auth via environment variable - os.environ["RAY_auth_mode"] = "token" - Config.initialize("") + set_auth_mode("token") + reset_auth_token_state() # Initialize Ray with token auth ray.init() @@ -207,9 +143,9 @@ def test_connect_without_token_raises_error(): """Test ray.init(address=...) 
without token fails when auth_mode=token is set.""" # Set up a cluster with token auth enabled cluster_token = "testtoken12345678901234567890" - os.environ["RAY_AUTH_TOKEN"] = cluster_token - os.environ["RAY_auth_mode"] = "token" - Config.initialize("") + set_env_auth_token(cluster_token) + set_auth_mode("token") + reset_auth_token_state() # Create cluster with token auth enabled cluster = Cluster() @@ -217,10 +153,9 @@ def test_connect_without_token_raises_error(): try: # Remove the token from the environment so we try to connect without it - os.environ["RAY_auth_mode"] = "disabled" - os.environ["RAY_AUTH_TOKEN"] = "" - Config.initialize("") - reset_token_cache() + set_auth_mode("disabled") + clear_auth_token_sources(remove_default=True) + reset_auth_token_state() # Ensure no token exists token_loader = AuthenticationTokenLoader.instance() @@ -239,9 +174,9 @@ def test_cluster_token_authentication(tokens_match): """Test cluster authentication with matching and non-matching tokens.""" # Set up cluster token first cluster_token = "a" * 32 - os.environ["RAY_AUTH_TOKEN"] = cluster_token - os.environ["RAY_auth_mode"] = "token" - Config.initialize("") + set_env_auth_token(cluster_token) + set_auth_mode("token") + reset_auth_token_state() # Create cluster with token auth enabled - node will read current env token cluster = Cluster() @@ -254,10 +189,10 @@ def test_cluster_token_authentication(tokens_match): else: client_token = "b" * 32 # Different token - should fail - os.environ["RAY_AUTH_TOKEN"] = client_token + set_env_auth_token(client_token) # Reset cached token so it reads the new environment variable - reset_token_cache() + reset_auth_token_state() if tokens_match: # Should succeed - test gRPC calls work @@ -312,9 +247,9 @@ def test_ray_start_without_token_raises_error(is_head): if not is_head: # Start head node with token cluster_token = "a" * 32 - os.environ["RAY_AUTH_TOKEN"] = cluster_token - os.environ["RAY_auth_mode"] = "token" - Config.initialize("") + 
set_env_auth_token(cluster_token) + set_auth_mode("token") + reset_auth_token_state() cluster = Cluster() cluster.add_node() @@ -348,10 +283,9 @@ def test_ray_start_head_with_token_succeeds(): ) # Verify we can connect to the cluster with ray.init() - os.environ["RAY_AUTH_TOKEN"] = test_token - os.environ["RAY_auth_mode"] = "token" - Config.initialize("") - reset_token_cache() + set_env_auth_token(test_token) + set_auth_mode("token") + reset_auth_token_state() # Wait for cluster to be ready def cluster_ready(): @@ -382,9 +316,9 @@ def test_ray_start_address_with_token(token_match): """Test ray start --address=... with correct or incorrect token.""" # Start a head node with token auth cluster_token = "a" * 32 - os.environ["RAY_AUTH_TOKEN"] = cluster_token - os.environ["RAY_auth_mode"] = "token" - Config.initialize("") + set_env_auth_token(cluster_token) + set_auth_mode("token") + reset_auth_token_state() cluster = Cluster() cluster.add_node(num_cpus=1) diff --git a/python/ray/util/client/server/proxier.py b/python/ray/util/client/server/proxier.py index 7bc959e3df17..34b360409332 100644 --- a/python/ray/util/client/server/proxier.py +++ b/python/ray/util/client/server/proxier.py @@ -20,6 +20,10 @@ import ray.core.generated.runtime_env_agent_pb2 as runtime_env_agent_pb2 from ray._common.network_utils import build_address, is_localhost from ray._private.client_mode_hook import disable_client_hook +from ray._private.authentication.http_token_authentication import ( + apply_token_if_enabled, + format_authentication_http_error, +) from ray._private.parameter import RayParams from ray._private.runtime_env.context import RuntimeEnvContext from ray._private.services import ProcessInfo, start_ray_client_server @@ -246,8 +250,9 @@ def _create_runtime_env( self._runtime_env_agent_address, "/get_or_create_runtime_env" ) data = create_env_request.SerializeToString() - req = urllib.request.Request(url, data=data, method="POST") - req.add_header("Content-Type", 
"application/octet-stream") + headers = {"Content-Type": "application/octet-stream"} + apply_token_if_enabled(headers, logger) + req = urllib.request.Request(url, data=data, method="POST", headers=headers) response = urllib.request.urlopen(req, timeout=None) response_data = response.read() r = runtime_env_agent_pb2.GetOrCreateRuntimeEnvReply() @@ -265,6 +270,24 @@ def _create_runtime_env( ) else: assert False, f"Unknown status: {r.status}." + except urllib.error.HTTPError as e: + body = "" + try: + body = e.read().decode("utf-8", "ignore") + except Exception: + body = e.reason if hasattr(e, "reason") else str(e) + + formatted_error = format_authentication_http_error(e.code, body or "") + if formatted_error: + raise RuntimeError(formatted_error) from e + + last_exception = e + logger.warning( + f"GetOrCreateRuntimeEnv request failed: {e}. " + f"Retrying after {wait_time_s}s. " + f"{max_retries-retries} retries remaining." + ) + except urllib.error.URLError as e: last_exception = e logger.warning( diff --git a/src/ray/raylet/runtime_env_agent_client.cc b/src/ray/raylet/runtime_env_agent_client.cc index 4934f3da36cb..2c7a5d9bda94 100644 --- a/src/ray/raylet/runtime_env_agent_client.cc +++ b/src/ray/raylet/runtime_env_agent_client.cc @@ -28,6 +28,7 @@ #include "absl/strings/str_format.h" #include "ray/common/asio/instrumented_io_context.h" #include "ray/common/status.h" +#include "ray/rpc/authentication/authentication_token_loader.h" #include "ray/util/logging.h" #include "ray/util/process.h" #include "ray/util/time.h" @@ -128,6 +129,11 @@ class Session : public std::enable_shared_from_this { req_.set(http::field::content_type, "application/octet-stream"); // Sets Content-Length header. 
req_.prepare_payload(); + + auto auth_token = rpc::AuthenticationTokenLoader::instance().GetToken(); + if (auth_token.has_value() && !auth_token->empty()) { + req_.set(http::field::authorization, auth_token->ToAuthorizationHeaderValue()); + } } void Failed(ray::Status status) { diff --git a/src/ray/raylet/tests/runtime_env_agent_client_test.cc b/src/ray/raylet/tests/runtime_env_agent_client_test.cc index e51d1d6c757f..197cac10a33a 100644 --- a/src/ray/raylet/tests/runtime_env_agent_client_test.cc +++ b/src/ray/raylet/tests/runtime_env_agent_client_test.cc @@ -14,6 +14,7 @@ #include "ray/raylet/runtime_env_agent_client.h" #include +#include #include #include #include @@ -30,6 +31,8 @@ #include "gtest/gtest.h" #include "ray/common/asio/asio_util.h" #include "ray/common/id.h" +#include "ray/common/ray_config.h" +#include "ray/rpc/authentication/authentication_token_loader.h" #include "src/ray/protobuf/runtime_env_agent.pb.h" namespace ray { @@ -190,6 +193,10 @@ delay_after(instrumented_io_context &ioc) { auto dummy_shutdown_raylet_gracefully = [](const rpc::NodeDeathInfo &) {}; TEST(RuntimeEnvAgentClientTest, GetOrCreateRuntimeEnvOK) { + RayConfig::instance().initialize(R"({"auth_mode": "disabled"})"); + unsetenv("RAY_AUTH_TOKEN"); + rpc::AuthenticationTokenLoader::instance().ResetCache(); + int port = GetFreePort(); HttpServerThread http_server_thread( [](const http::request &request, @@ -199,6 +206,7 @@ TEST(RuntimeEnvAgentClientTest, GetOrCreateRuntimeEnvOK) { ASSERT_EQ(req.job_id(), "7b000000"); // Hex 7B == Int 123 ASSERT_EQ(req.runtime_env_config().setup_timeout_seconds(), 12); ASSERT_EQ(req.serialized_runtime_env(), "serialized_runtime_env"); + ASSERT_EQ(request.find(http::field::authorization), request.end()); rpc::GetOrCreateRuntimeEnvReply reply; reply.set_status(rpc::AGENT_RPC_STATUS_OK); @@ -356,6 +364,74 @@ TEST(RuntimeEnvAgentClientTest, GetOrCreateRuntimeEnvRetriesOnServerNotStarted) ASSERT_EQ(called_times, 1); } +TEST(RuntimeEnvAgentClientTest, 
AttachesAuthHeaderWhenEnabled) { + RayConfig::instance().initialize(R"({"auth_mode": "token"})"); + setenv("RAY_AUTH_TOKEN", "header_token", 1); + rpc::AuthenticationTokenLoader::instance().ResetCache(); + + int port = GetFreePort(); + std::string observed_auth_header; + + HttpServerThread http_server_thread( + [&observed_auth_header](const http::request &request, + http::response &response) { + rpc::GetOrCreateRuntimeEnvRequest req; + ASSERT_TRUE(req.ParseFromString(request.body())); + auto it = request.find(http::field::authorization); + if (it != request.end()) { + observed_auth_header = std::string(it->value()); + } + + rpc::GetOrCreateRuntimeEnvReply reply; + reply.set_status(rpc::AGENT_RPC_STATUS_OK); + reply.set_serialized_runtime_env_context("serialized_runtime_env_context"); + response.body() = reply.SerializeAsString(); + response.content_length(response.body().size()); + response.result(http::status::ok); + }, + "127.0.0.1", + port); + http_server_thread.start(); + + instrumented_io_context ioc; + + auto client = + raylet::RuntimeEnvAgentClient::Create(ioc, + "127.0.0.1", + port, + delay_after(ioc), + dummy_shutdown_raylet_gracefully, + /*agent_register_timeout_ms=*/10000, + /*agent_manager_retry_interval_ms=*/100); + + auto job_id = JobID::FromInt(123); + std::string serialized_runtime_env = "serialized_runtime_env"; + ray::rpc::RuntimeEnvConfig runtime_env_config; + runtime_env_config.set_setup_timeout_seconds(12); + + size_t called_times = 0; + auto callback = [&](bool successful, + const std::string &serialized_runtime_env_context, + const std::string &setup_error_message) { + ASSERT_TRUE(successful); + ASSERT_EQ(serialized_runtime_env_context, "serialized_runtime_env_context"); + ASSERT_TRUE(setup_error_message.empty()); + called_times += 1; + }; + + client->GetOrCreateRuntimeEnv( + job_id, serialized_runtime_env, runtime_env_config, callback); + + ioc.run(); + + ASSERT_EQ(called_times, 1); + ASSERT_EQ(observed_auth_header, "Bearer header_token"); + 
+ RayConfig::instance().initialize(R"({"auth_mode": "disabled"})"); + unsetenv("RAY_AUTH_TOKEN"); + rpc::AuthenticationTokenLoader::instance().ResetCache(); +} + TEST(RuntimeEnvAgentClientTest, DeleteRuntimeEnvIfPossibleOK) { int port = GetFreePort(); HttpServerThread http_server_thread( From 8c371712e992534c7e4d6dce627740dc4fdad84c Mon Sep 17 00:00:00 2001 From: sampan Date: Thu, 30 Oct 2025 06:35:39 +0000 Subject: [PATCH 67/94] fix lint issues Signed-off-by: sampan --- .../http_token_authentication.py | 19 +++++++------------ python/ray/_private/runtime_env/agent/main.py | 10 ++++------ python/ray/dashboard/http_server_head.py | 4 ++-- python/ray/tests/BUILD.bazel | 2 +- python/ray/tests/authentication_test_utils.py | 2 -- python/ray/tests/conftest.py | 3 +-- python/ray/tests/test_dashboard_auth.py | 9 ++++++--- .../ray/tests/test_runtime_env_agent_auth.py | 8 +++++--- .../ray/tests/test_submission_client_auth.py | 3 +-- python/ray/util/client/server/proxier.py | 6 ++++-- .../tests/runtime_env_agent_client_test.cc | 2 +- 11 files changed, 32 insertions(+), 36 deletions(-) diff --git a/python/ray/_private/authentication/http_token_authentication.py b/python/ray/_private/authentication/http_token_authentication.py index 9813520b4521..665d479fe28e 100644 --- a/python/ray/_private/authentication/http_token_authentication.py +++ b/python/ray/_private/authentication/http_token_authentication.py @@ -51,8 +51,8 @@ def apply_token_if_enabled( if "Authorization" in headers: return False - from ray.dashboard import authentication_utils as auth_utils from ray._raylet import AuthenticationTokenLoader + from ray.dashboard import authentication_utils as auth_utils if not auth_utils.is_token_auth_enabled(): return False @@ -73,20 +73,15 @@ def format_authentication_http_error(status: int, body: str) -> Optional[str]: """Return a user-friendly authentication error message, if applicable.""" if status == 401: - return ( - "Authentication required: {body}\n\n{details}".format( - 
body=body, - details=authentication_constants.HTTP_REQUEST_MISSING_TOKEN_ERROR_MESSAGE, - ) + return "Authentication required: {body}\n\n{details}".format( + body=body, + details=authentication_constants.HTTP_REQUEST_MISSING_TOKEN_ERROR_MESSAGE, ) if status == 403: - return ( - "Authentication failed: {body}\n\n{details}".format( - body=body, - details=authentication_constants.HTTP_REQUEST_INVALID_TOKEN_ERROR_MESSAGE, - ) + return "Authentication failed: {body}\n\n{details}".format( + body=body, + details=authentication_constants.HTTP_REQUEST_INVALID_TOKEN_ERROR_MESSAGE, ) return None - diff --git a/python/ray/_private/runtime_env/agent/main.py b/python/ray/_private/runtime_env/agent/main.py index 1049e7a933ad..05c8e2d320e2 100644 --- a/python/ray/_private/runtime_env/agent/main.py +++ b/python/ray/_private/runtime_env/agent/main.py @@ -8,6 +8,9 @@ get_or_create_event_loop, ) from ray._private import logging_utils +from ray._private.authentication.http_token_authentication import ( + create_token_authentication_middleware, +) from ray._private.process_watcher import create_check_raylet_task from ray._raylet import GcsClient from ray.core.generated import ( @@ -25,9 +28,6 @@ def import_libs(): import runtime_env_consts # noqa: E402 from aiohttp import web # noqa: E402 -from ray._private.authentication.http_token_authentication import ( - create_token_authentication_middleware, -) from runtime_env_agent import RuntimeEnvAgent # noqa: E402 if __name__ == "__main__": @@ -197,9 +197,7 @@ async def get_runtime_envs_info(request: web.Request) -> web.Response: body=reply.SerializeToString(), content_type="application/octet-stream" ) - app = web.Application( - middlewares=[create_token_authentication_middleware()] - ) + app = web.Application(middlewares=[create_token_authentication_middleware()]) app.router.add_post("/get_or_create_runtime_env", get_or_create_runtime_env) app.router.add_post( diff --git a/python/ray/dashboard/http_server_head.py 
b/python/ray/dashboard/http_server_head.py index b8123f23bc85..66016fb7f6c2 100644 --- a/python/ray/dashboard/http_server_head.py +++ b/python/ray/dashboard/http_server_head.py @@ -20,11 +20,11 @@ from ray._common.network_utils import build_address, parse_address from ray._common.usage.usage_lib import TagKey, record_extra_usage_tag from ray._common.utils import get_or_create_event_loop -from ray.dashboard.dashboard_metrics import DashboardPrometheusMetrics -from ray.dashboard.head import DashboardHeadModule from ray._private.authentication.http_token_authentication import ( create_token_authentication_middleware, ) +from ray.dashboard.dashboard_metrics import DashboardPrometheusMetrics +from ray.dashboard.head import DashboardHeadModule # All third-party dependencies that are not included in the minimal Ray # installation must be included in this file. This allows us to determine if diff --git a/python/ray/tests/BUILD.bazel b/python/ray/tests/BUILD.bazel index e2f9ccddc3d9..7740d43d28a6 100644 --- a/python/ray/tests/BUILD.bazel +++ b/python/ray/tests/BUILD.bazel @@ -580,6 +580,7 @@ py_test_module_list( "test_runtime_env_py_executable.py", "test_state_api_summary.py", "test_streaming_generator_regression.py", + "test_submission_client_auth.py", "test_system_metrics.py", "test_task_events_3.py", "test_task_metrics_reconstruction.py", @@ -590,7 +591,6 @@ py_test_module_list( "test_wait.py", "test_widgets.py", "test_worker_graceful_shutdown.py", - "test_submission_client_auth.py" ], tags = [ "exclusive", diff --git a/python/ray/tests/authentication_test_utils.py b/python/ray/tests/authentication_test_utils.py index 63bc2ea519c6..3e2a1ece9136 100644 --- a/python/ray/tests/authentication_test_utils.py +++ b/python/ray/tests/authentication_test_utils.py @@ -8,7 +8,6 @@ from ray._raylet import AuthenticationTokenLoader, Config - AUTH_ENV_VARS = ("RAY_auth_mode", "RAY_AUTH_TOKEN", "RAY_AUTH_TOKEN_PATH") DEFAULT_AUTH_TOKEN_RELATIVE_PATH = Path(".ray") / "auth_token" @@ 
-148,4 +147,3 @@ def authentication_env_guard(): yield snapshot finally: snapshot.restore() - diff --git a/python/ray/tests/conftest.py b/python/ray/tests/conftest.py index f90018692a1e..ca7c45fd1dbf 100644 --- a/python/ray/tests/conftest.py +++ b/python/ray/tests/conftest.py @@ -42,7 +42,6 @@ start_redis_sentinel_instance, teardown_tls, ) -from ray._raylet import AuthenticationTokenLoader, Config from ray.cluster_utils import AutoscalingCluster, Cluster, cluster_not_supported from ray.tests.authentication_test_utils import ( authentication_env_guard, @@ -66,7 +65,7 @@ def cleanup_auth_token_env(): """Reset authentication environment variables, files, and caches.""" - with authentication_env_guard() as snapshot: + with authentication_env_guard(): clear_auth_token_sources(remove_default=True) reset_auth_token_state() yield diff --git a/python/ray/tests/test_dashboard_auth.py b/python/ray/tests/test_dashboard_auth.py index 02ac47b213de..b76f0bcf3671 100644 --- a/python/ray/tests/test_dashboard_auth.py +++ b/python/ray/tests/test_dashboard_auth.py @@ -1,9 +1,14 @@ """Tests for dashboard token authentication.""" +import sys + +import pytest import requests -def test_dashboard_request_requires_auth_with_valid_token(setup_cluster_with_token_auth): +def test_dashboard_request_requires_auth_with_valid_token( + setup_cluster_with_token_auth, +): """Test that requests succeed with valid token when auth is enabled.""" cluster_info = setup_cluster_with_token_auth @@ -59,6 +64,4 @@ def test_dashboard_auth_disabled(setup_cluster_without_token_auth): if __name__ == "__main__": - import sys - sys.exit(pytest.main(["-vv", __file__])) diff --git a/python/ray/tests/test_runtime_env_agent_auth.py b/python/ray/tests/test_runtime_env_agent_auth.py index 1a84cbdbd3b2..04fc189c79fd 100644 --- a/python/ray/tests/test_runtime_env_agent_auth.py +++ b/python/ray/tests/test_runtime_env_agent_auth.py @@ -11,7 +11,11 @@ format_authentication_http_error, ) from ray.core.generated import 
runtime_env_agent_pb2 -from ray.tests.authentication_test_utils import reset_auth_token_state, set_auth_mode, set_env_auth_token +from ray.tests.authentication_test_utils import ( + reset_auth_token_state, + set_auth_mode, + set_env_auth_token, +) def _agent_url(agent_address: str, path: str) -> str: @@ -122,5 +126,3 @@ def test_apply_token_if_enabled_respects_existing_header(cleanup_auth_token_env) def test_format_authentication_http_error_non_auth_status(): assert format_authentication_http_error(404, "not found") is None - - diff --git a/python/ray/tests/test_submission_client_auth.py b/python/ray/tests/test_submission_client_auth.py index a136618ca6b4..96a4b0437475 100644 --- a/python/ray/tests/test_submission_client_auth.py +++ b/python/ray/tests/test_submission_client_auth.py @@ -1,4 +1,3 @@ -import os import tempfile from pathlib import Path from typing import Optional @@ -12,7 +11,6 @@ ) from ray.cluster_utils import Cluster from ray.dashboard.modules.job.sdk import JobSubmissionClient -from ray.util.state import StateApiClient from ray.tests.authentication_test_utils import ( clear_auth_token_sources, reset_auth_token_state, @@ -21,6 +19,7 @@ set_default_auth_token, set_env_auth_token, ) +from ray.util.state import StateApiClient def test_submission_client_adds_token_automatically(setup_cluster_with_token_auth): diff --git a/python/ray/util/client/server/proxier.py b/python/ray/util/client/server/proxier.py index 34b360409332..6c2c5794e6a8 100644 --- a/python/ray/util/client/server/proxier.py +++ b/python/ray/util/client/server/proxier.py @@ -19,11 +19,11 @@ import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc import ray.core.generated.runtime_env_agent_pb2 as runtime_env_agent_pb2 from ray._common.network_utils import build_address, is_localhost -from ray._private.client_mode_hook import disable_client_hook from ray._private.authentication.http_token_authentication import ( apply_token_if_enabled, format_authentication_http_error, ) +from 
ray._private.client_mode_hook import disable_client_hook from ray._private.parameter import RayParams from ray._private.runtime_env.context import RuntimeEnvContext from ray._private.services import ProcessInfo, start_ray_client_server @@ -252,7 +252,9 @@ def _create_runtime_env( data = create_env_request.SerializeToString() headers = {"Content-Type": "application/octet-stream"} apply_token_if_enabled(headers, logger) - req = urllib.request.Request(url, data=data, method="POST", headers=headers) + req = urllib.request.Request( + url, data=data, method="POST", headers=headers + ) response = urllib.request.urlopen(req, timeout=None) response_data = response.read() r = runtime_env_agent_pb2.GetOrCreateRuntimeEnvReply() diff --git a/src/ray/raylet/tests/runtime_env_agent_client_test.cc b/src/ray/raylet/tests/runtime_env_agent_client_test.cc index 197cac10a33a..ea5bb4974097 100644 --- a/src/ray/raylet/tests/runtime_env_agent_client_test.cc +++ b/src/ray/raylet/tests/runtime_env_agent_client_test.cc @@ -14,7 +14,6 @@ #include "ray/raylet/runtime_env_agent_client.h" #include -#include #include #include #include @@ -22,6 +21,7 @@ #include #include #include +#include #include #include #include From e3bf7f4ffe1ea1725c32c6afad76db27bc3f8903 Mon Sep 17 00:00:00 2001 From: sampan Date: Thu, 30 Oct 2025 06:49:24 +0000 Subject: [PATCH 68/94] fix build Signed-off-by: sampan --- python/ray/includes/rpc_token_authentication.pxd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/includes/rpc_token_authentication.pxd b/python/ray/includes/rpc_token_authentication.pxd index cc6c05e92dac..bf666a6ec4c2 100644 --- a/python/ray/includes/rpc_token_authentication.pxd +++ b/python/ray/includes/rpc_token_authentication.pxd @@ -15,7 +15,7 @@ cdef extern from "ray/rpc/authentication/authentication_token.h" namespace "ray: CAuthenticationToken(string value) c_bool empty() c_bool Equals(const CAuthenticationToken& other) - string ToHttpHeaderValue() + string 
ToAuthorizationHeaderValue() @staticmethod CAuthenticationToken FromMetadata(string metadata_value) From c987f8803dcfb267cfe0f2ca4234bbf3926f9702 Mon Sep 17 00:00:00 2001 From: sampan Date: Thu, 30 Oct 2025 08:19:23 +0000 Subject: [PATCH 69/94] fix lint Signed-off-by: sampan --- .../_private/authentication/http_token_authentication.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/ray/_private/authentication/http_token_authentication.py b/python/ray/_private/authentication/http_token_authentication.py index 665d479fe28e..c295506ae290 100644 --- a/python/ray/_private/authentication/http_token_authentication.py +++ b/python/ray/_private/authentication/http_token_authentication.py @@ -1,6 +1,8 @@ import logging +import sys from typing import Dict, Optional +import pytest from aiohttp import web from ray._private.authentication import authentication_constants @@ -85,3 +87,7 @@ def format_authentication_http_error(status: int, body: str) -> Optional[str]: ) return None + + +if __name__ == "__main__": + sys.exit(pytest.main(["-vv", __file__])) From a4a09cc04ff09b41244a97a437a1612fcac8e449 Mon Sep 17 00:00:00 2001 From: sampan Date: Thu, 30 Oct 2025 09:04:12 +0000 Subject: [PATCH 70/94] refactor and improve structure Signed-off-by: sampan --- .../authentication_constants.py | 6 - .../ray/includes/rpc_token_authentication.pxi | 5 +- python/ray/tests/authentication_test_utils.py | 14 +- python/ray/tests/conftest.py | 118 ++++----- .../ray/tests/test_submission_client_auth.py | 60 +---- .../ray/tests/test_token_auth_integration.py | 242 +++++++----------- src/ray/common/constants.h | 2 +- .../rpc/authentication/authentication_token.h | 2 +- 8 files changed, 171 insertions(+), 278 deletions(-) diff --git a/python/ray/_private/authentication/authentication_constants.py b/python/ray/_private/authentication/authentication_constants.py index 17e70232f4d4..347c7256e09d 100644 --- a/python/ray/_private/authentication/authentication_constants.py +++ 
b/python/ray/_private/authentication/authentication_constants.py @@ -1,9 +1,3 @@ -"""Centralized authentication constants and error messages for Ray. - -This module provides reusable error messages for authentication failures -across CLI, dashboard, and other Ray python components. -""" - # Token setup instructions (used in multiple contexts) TOKEN_SETUP_INSTRUCTIONS = """Please provide an authentication token using one of these methods: 1. Set the RAY_AUTH_TOKEN environment variable diff --git a/python/ray/includes/rpc_token_authentication.pxi b/python/ray/includes/rpc_token_authentication.pxi index beda93da86cd..3825933717ee 100644 --- a/python/ray/includes/rpc_token_authentication.pxi +++ b/python/ray/includes/rpc_token_authentication.pxi @@ -11,6 +11,7 @@ class AuthenticationMode: DISABLED = CAuthenticationMode.DISABLED TOKEN = CAuthenticationMode.TOKEN +_AUTHORIZATION_HEADER_NAME = "Authorization" def get_authentication_mode(): """Get the current authentication mode. @@ -87,7 +88,7 @@ class AuthenticationTokenLoader: bool: True if token was added, False otherwise """ # Don't override if user explicitly set Authorization header - if "Authorization" in headers: + if _AUTHORIZATION_HEADER_NAME in headers: return False # Check if token exists (doesn't crash, returns bool) @@ -100,5 +101,5 @@ class AuthenticationTokenLoader: if not token_opt.has_value() or token_opt.value().empty(): return False - headers["Authorization"] = token_opt.value().ToAuthorizationHeaderValue() + headers[_AUTHORIZATION_HEADER_NAME] = token_opt.value().ToAuthorizationHeaderValue() return True diff --git a/python/ray/tests/authentication_test_utils.py b/python/ray/tests/authentication_test_utils.py index 068ab3370978..e98711b51e88 100644 --- a/python/ray/tests/authentication_test_utils.py +++ b/python/ray/tests/authentication_test_utils.py @@ -8,12 +8,12 @@ from ray._raylet import AuthenticationTokenLoader, Config -AUTH_ENV_VARS = ("RAY_auth_mode", "RAY_AUTH_TOKEN", "RAY_AUTH_TOKEN_PATH") 
-DEFAULT_AUTH_TOKEN_RELATIVE_PATH = Path(".ray") / "auth_token" +_AUTH_ENV_VARS = ("RAY_auth_mode", "RAY_AUTH_TOKEN", "RAY_AUTH_TOKEN_PATH") +_DEFAULT_AUTH_TOKEN_RELATIVE_PATH = Path(".ray") / "auth_token" def reset_auth_token_state() -> None: - """Reset authentication token and auth_mode config.""" + """Reset authentication token and auth_mode ray config.""" AuthenticationTokenLoader.instance().reset_cache() Config.initialize("") @@ -45,7 +45,7 @@ def set_auth_token_path(token: str, path: Path) -> None: def set_default_auth_token(token: str) -> Path: """Write the authentication token to the default ~/.ray/auth_token location.""" - default_path = Path.home() / DEFAULT_AUTH_TOKEN_RELATIVE_PATH + default_path = Path.home() / _DEFAULT_AUTH_TOKEN_RELATIVE_PATH default_path.parent.mkdir(parents=True, exist_ok=True) default_path.write_text(token) return default_path @@ -58,7 +58,7 @@ def clear_auth_token_sources(remove_default: bool = False) -> None: os.environ.pop(var, None) if remove_default: - default_path = Path.home() / DEFAULT_AUTH_TOKEN_RELATIVE_PATH + default_path = Path.home() / _DEFAULT_AUTH_TOKEN_RELATIVE_PATH default_path.unlink(missing_ok=True) @@ -76,7 +76,7 @@ class AuthenticationEnvSnapshot: def capture(cls) -> "AuthenticationEnvSnapshot": """Capture current authentication-related environment state.""" - original_env = {var: os.environ.get(var) for var in AUTH_ENV_VARS} + original_env = {var: os.environ.get(var) for var in _AUTH_ENV_VARS} home_was_set = "HOME" in os.environ original_home = os.environ.get("HOME") temp_home: Optional[Path] = None @@ -89,7 +89,7 @@ def capture(cls) -> "AuthenticationEnvSnapshot": temp_home.mkdir(parents=True, exist_ok=True) os.environ["HOME"] = str(temp_home) - default_token_path = Path.home() / DEFAULT_AUTH_TOKEN_RELATIVE_PATH + default_token_path = Path.home() / _DEFAULT_AUTH_TOKEN_RELATIVE_PATH default_token_exists = default_token_path.exists() default_token_contents = ( default_token_path.read_text() if 
default_token_exists else None diff --git a/python/ray/tests/conftest.py b/python/ray/tests/conftest.py index ca7c45fd1dbf..dbc92903b766 100644 --- a/python/ray/tests/conftest.py +++ b/python/ray/tests/conftest.py @@ -61,65 +61,6 @@ START_REDIS_WAIT_RETRIES = int(os.environ.get("RAY_START_REDIS_WAIT_RETRIES", "60")) -@pytest.fixture -def cleanup_auth_token_env(): - """Reset authentication environment variables, files, and caches.""" - - with authentication_env_guard(): - clear_auth_token_sources(remove_default=True) - reset_auth_token_state() - yield - reset_auth_token_state() - - -@pytest.fixture -def setup_cluster_with_token_auth(cleanup_auth_token_env): - """Spin up a Ray cluster with token authentication enabled.""" - - test_token = "test_token_12345678901234567890123456789012" - set_auth_mode("token") - set_env_auth_token(test_token) - reset_auth_token_state() - - cluster = Cluster() - cluster.add_node() - - try: - context = ray.init(address=cluster.address) - dashboard_url = context.address_info["webui_url"] - yield { - "cluster": cluster, - "dashboard_url": f"http://{dashboard_url}", - "token": test_token, - } - finally: - ray.shutdown() - cluster.shutdown() - - -@pytest.fixture -def setup_cluster_without_token_auth(cleanup_auth_token_env): - """Spin up a Ray cluster with authentication disabled.""" - - set_auth_mode("disabled") - clear_auth_token_sources(remove_default=True) - reset_auth_token_state() - - cluster = Cluster() - cluster.add_node() - - try: - context = ray.init(address=cluster.address) - dashboard_url = context.address_info["webui_url"] - yield { - "cluster": cluster, - "dashboard_url": f"http://{dashboard_url}", - } - finally: - ray.shutdown() - cluster.shutdown() - - @pytest.fixture(autouse=True) def pre_envs(monkeypatch): # To make test run faster @@ -1550,3 +1491,62 @@ def make_httpserver(httpserver_listen_address, httpserver_ssl_context): server.clear() if server.is_running(): server.stop() + + +@pytest.fixture +def 
cleanup_auth_token_env(): + """Reset authentication environment variables, files, and caches.""" + + with authentication_env_guard(): + clear_auth_token_sources(remove_default=True) + reset_auth_token_state() + yield + reset_auth_token_state() + + +@pytest.fixture +def setup_cluster_with_token_auth(cleanup_auth_token_env): + """Spin up a Ray cluster with token authentication enabled.""" + + test_token = "test_token_12345678901234567890123456789012" + set_auth_mode("token") + set_env_auth_token(test_token) + reset_auth_token_state() + + cluster = Cluster() + cluster.add_node() + + try: + context = ray.init(address=cluster.address) + dashboard_url = context.address_info["webui_url"] + yield { + "cluster": cluster, + "dashboard_url": f"http://{dashboard_url}", + "token": test_token, + } + finally: + ray.shutdown() + cluster.shutdown() + + +@pytest.fixture +def setup_cluster_without_token_auth(cleanup_auth_token_env): + """Spin up a Ray cluster with authentication disabled.""" + + set_auth_mode("disabled") + clear_auth_token_sources(remove_default=True) + reset_auth_token_state() + + cluster = Cluster() + cluster.add_node() + + try: + context = ray.init(address=cluster.address) + dashboard_url = context.address_info["webui_url"] + yield { + "cluster": cluster, + "dashboard_url": f"http://{dashboard_url}", + } + finally: + ray.shutdown() + cluster.shutdown() diff --git a/python/ray/tests/test_submission_client_auth.py b/python/ray/tests/test_submission_client_auth.py index c617d11e327e..9e76dc234d26 100644 --- a/python/ray/tests/test_submission_client_auth.py +++ b/python/ray/tests/test_submission_client_auth.py @@ -1,22 +1,15 @@ -import tempfile -from pathlib import Path -from typing import Optional - import pytest -import ray from ray._private.authentication.authentication_constants import ( HTTP_REQUEST_INVALID_TOKEN_ERROR_MESSAGE, HTTP_REQUEST_MISSING_TOKEN_ERROR_MESSAGE, ) -from ray.cluster_utils import Cluster +from ray.dashboard.modules.dashboard_sdk import 
SubmissionClient from ray.dashboard.modules.job.sdk import JobSubmissionClient from ray.tests.authentication_test_utils import ( clear_auth_token_sources, reset_auth_token_state, set_auth_mode, - set_auth_token_path, - set_default_auth_token, set_env_auth_token, ) from ray.util.state import StateApiClient @@ -25,7 +18,6 @@ def test_submission_client_adds_token_automatically(setup_cluster_with_token_auth): """Test that SubmissionClient automatically adds token to headers.""" # Token is already set in environment from setup_cluster_with_token_auth fixture - from ray.dashboard.modules.dashboard_sdk import SubmissionClient client = SubmissionClient(address=setup_cluster_with_token_auth["dashboard_url"]) @@ -43,8 +35,6 @@ def test_submission_client_without_token_shows_helpful_error( set_auth_mode("disabled") reset_auth_token_state() - from ray.dashboard.modules.dashboard_sdk import SubmissionClient - client = SubmissionClient(address=setup_cluster_with_token_auth["dashboard_url"]) # Make a request - should fail with helpful message @@ -68,8 +58,6 @@ def test_submission_client_with_invalid_token_shows_helpful_error( set_auth_mode("token") reset_auth_token_state() - from ray.dashboard.modules.dashboard_sdk import SubmissionClient - client = SubmissionClient(address=setup_cluster_with_token_auth["dashboard_url"]) # Make a request - should fail with helpful message @@ -85,8 +73,6 @@ def test_submission_client_with_invalid_token_shows_helpful_error( def test_submission_client_with_valid_token_succeeds(setup_cluster_with_token_auth): """Test that requests with valid token succeed.""" - from ray.dashboard.modules.dashboard_sdk import SubmissionClient - client = SubmissionClient(address=setup_cluster_with_token_auth["dashboard_url"]) # Make a request - should succeed @@ -120,8 +106,6 @@ def test_user_provided_header_not_overridden(setup_cluster_with_token_auth): """Test that user-provided Authorization header is not overridden.""" custom_auth = "Bearer custom_token" - from 
ray.dashboard.modules.dashboard_sdk import SubmissionClient - client = SubmissionClient( address=setup_cluster_with_token_auth["dashboard_url"], headers={"Authorization": custom_auth}, @@ -138,8 +122,6 @@ def test_error_messages_contain_instructions(setup_cluster_with_token_auth): set_auth_mode("disabled") reset_auth_token_state() - from ray.dashboard.modules.dashboard_sdk import SubmissionClient - client = SubmissionClient(address=setup_cluster_with_token_auth["dashboard_url"]) with pytest.raises(RuntimeError) as exc_info: @@ -168,49 +150,9 @@ def test_error_messages_contain_instructions(setup_cluster_with_token_auth): assert str(exc_info.value) == expected_invalid -@pytest.mark.parametrize("token_source", ["env_var", "token_path", "default_path"]) -def test_token_loaded_from_sources(cleanup_auth_token_env, token_source): - """Test that SubmissionClient loads tokens from all supported sources.""" - - test_token = "test_token_12345678901234567890123456789012" - set_auth_mode("token") - - token_file_path: Optional[Path] = None - - if token_source == "env_var": - set_env_auth_token(test_token) - elif token_source == "token_path": - with tempfile.NamedTemporaryFile(delete=False) as tmp: - token_file_path = Path(tmp.name) - set_auth_token_path(test_token, token_file_path) - else: - set_default_auth_token(test_token) - - reset_auth_token_state() - - cluster = Cluster() - cluster.add_node() - - try: - context = ray.init(address=cluster.address) - dashboard_url = context.address_info["webui_url"] - - from ray.dashboard.modules.dashboard_sdk import SubmissionClient - - client = SubmissionClient(address=f"http://{dashboard_url}") - assert client._headers["Authorization"] == f"Bearer {test_token}" - finally: - ray.shutdown() - cluster.shutdown() - if token_source == "token_path" and token_file_path: - token_file_path.unlink(missing_ok=True) - - def test_no_token_added_when_auth_disabled(setup_cluster_without_token_auth): """Test that no Authorization header is injected when 
auth is disabled.""" - from ray.dashboard.modules.dashboard_sdk import SubmissionClient - client = SubmissionClient(address=setup_cluster_without_token_auth["dashboard_url"]) assert "Authorization" not in client._headers diff --git a/python/ray/tests/test_token_auth_integration.py b/python/ray/tests/test_token_auth_integration.py index 0168ab958b83..8913a166d463 100644 --- a/python/ray/tests/test_token_auth_integration.py +++ b/python/ray/tests/test_token_auth_integration.py @@ -11,7 +11,6 @@ import ray from ray._private.test_utils import wait_for_condition from ray._raylet import AuthenticationTokenLoader -from ray.cluster_utils import Cluster from ray.tests.authentication_test_utils import ( clear_auth_token_sources, reset_auth_token_state, @@ -139,98 +138,74 @@ def test_local_cluster_generates_token(): ray.shutdown() -def test_connect_without_token_raises_error(): +def test_connect_without_token_raises_error(setup_cluster_with_token_auth): """Test ray.init(address=...) without token fails when auth_mode=token is set.""" - # Set up a cluster with token auth enabled - cluster_token = "testtoken12345678901234567890" - set_env_auth_token(cluster_token) - set_auth_mode("token") - reset_auth_token_state() + cluster_info = setup_cluster_with_token_auth + cluster = cluster_info["cluster"] - # Create cluster with token auth enabled - cluster = Cluster() - cluster.add_node() - - try: - # Remove the token from the environment so we try to connect without it - set_auth_mode("disabled") - clear_auth_token_sources(remove_default=True) - reset_auth_token_state() - - # Ensure no token exists - token_loader = AuthenticationTokenLoader.instance() - assert not token_loader.has_token() + # Disconnect the current driver session and drop token state before retrying. 
+ ray.shutdown() + set_auth_mode("disabled") + clear_auth_token_sources(remove_default=True) + reset_auth_token_state() - # Try to connect to the cluster without a token - should raise RuntimeError - with pytest.raises(ConnectionError): - ray.init(address=cluster.address) + # Ensure no token exists + token_loader = AuthenticationTokenLoader.instance() + assert not token_loader.has_token() - finally: - cluster.shutdown() + # Try to connect to the cluster without a token - should raise RuntimeError + with pytest.raises(ConnectionError): + ray.init(address=cluster.address) @pytest.mark.parametrize("tokens_match", [True, False]) -def test_cluster_token_authentication(tokens_match): +def test_cluster_token_authentication(tokens_match, setup_cluster_with_token_auth): """Test cluster authentication with matching and non-matching tokens.""" - # Set up cluster token first - cluster_token = "a" * 32 - set_env_auth_token(cluster_token) - set_auth_mode("token") - reset_auth_token_state() - - # Create cluster with token auth enabled - node will read current env token - cluster = Cluster() - cluster.add_node() + cluster_info = setup_cluster_with_token_auth + cluster = cluster_info["cluster"] + cluster_token = cluster_info["token"] - try: - # Set client token based on test parameter - if tokens_match: - client_token = cluster_token # Same token - should succeed - else: - client_token = "b" * 32 # Different token - should fail - - set_env_auth_token(client_token) - - # Reset cached token so it reads the new environment variable - reset_auth_token_state() + # Reconfigure the driver token state to simulate fresh connections. 
+ ray.shutdown() + set_auth_mode("token") - if tokens_match: - # Should succeed - test gRPC calls work - ray.init(address=cluster.address) + if tokens_match: + client_token = cluster_token # Same token - should succeed + else: + client_token = "b" * 32 # Different token - should fail - # Test that gRPC calls succeed - obj_ref = ray.put("test_data") - result = ray.get(obj_ref) - assert result == "test_data" + set_env_auth_token(client_token) + reset_auth_token_state() - # Test remote function call - @ray.remote - def test_func(): - return "success" + if tokens_match: + # Should succeed - test gRPC calls work + ray.init(address=cluster.address) - result = ray.get(test_func.remote()) - assert result == "success" + obj_ref = ray.put("test_data") + result = ray.get(obj_ref) + assert result == "test_data" - ray.shutdown() + @ray.remote + def test_func(): + return "success" - else: - # Should fail - connection or gRPC calls should fail - with pytest.raises((ConnectionError, RuntimeError)): - ray.init(address=cluster.address) - # If init somehow succeeds, try a gRPC operation that should fail - try: - ray.put("test") - finally: - ray.shutdown() + result = ray.get(test_func.remote()) + assert result == "success" - finally: - # Ensure cleanup ray.shutdown() - cluster.shutdown() + + else: + # Should fail - connection or gRPC calls should fail + with pytest.raises((ConnectionError, RuntimeError)): + ray.init(address=cluster.address) + try: + ray.put("test") + finally: + ray.shutdown() @pytest.mark.parametrize("is_head", [True, False]) -def test_ray_start_without_token_raises_error(is_head): +def test_ray_start_without_token_raises_error(is_head, request): """Test that ray start fails when auth_mode=token but no token exists.""" # Set up environment with token auth enabled but no token env = os.environ.copy() @@ -243,29 +218,20 @@ def test_ray_start_without_token_raises_error(is_head): assert not default_token_path.exists() # When specifying an address, we need a head node to 
connect to - cluster = None + cluster_info = None if not is_head: - # Start head node with token - cluster_token = "a" * 32 - set_env_auth_token(cluster_token) - set_auth_mode("token") - reset_auth_token_state() - cluster = Cluster() - cluster.add_node() - - try: - # Prepare arguments - if is_head: - args = ["--head", "--port=0"] - else: - args = [f"--address={cluster.address}"] + cluster_info = request.getfixturevalue("setup_cluster_with_token_auth") + cluster = cluster_info["cluster"] + ray.shutdown() - # Try to start node - should fail - _run_ray_start_and_verify_status(args, env, expect_success=False) + # Prepare arguments + if is_head: + args = ["--head", "--port=0"] + else: + args = [f"--address={cluster.address}"] - finally: - if cluster: - cluster.shutdown() + # Try to start node - should fail + _run_ray_start_and_verify_status(args, env, expect_success=False) def test_ray_start_head_with_token_succeeds(): @@ -312,65 +278,55 @@ def test_func(): @pytest.mark.parametrize("token_match", ["correct", "incorrect"]) -def test_ray_start_address_with_token(token_match): +def test_ray_start_address_with_token(token_match, setup_cluster_with_token_auth): """Test ray start --address=... with correct or incorrect token.""" - # Start a head node with token auth - cluster_token = "a" * 32 - set_env_auth_token(cluster_token) - set_auth_mode("token") - reset_auth_token_state() + cluster_info = setup_cluster_with_token_auth + cluster = cluster_info["cluster"] + cluster_token = cluster_info["token"] - cluster = Cluster() - cluster.add_node(num_cpus=1) + # Reset the driver connection to reuse the fixture-backed cluster. 
+ ray.shutdown() + set_auth_mode("token") - try: - # Set up environment for worker - env = os.environ.copy() - env["RAY_auth_mode"] = "token" - - if token_match == "correct": - env["RAY_AUTH_TOKEN"] = cluster_token - expect_success = True - else: - # Use different token - env["RAY_AUTH_TOKEN"] = "b" * 32 - expect_success = False - - # Start worker node - _run_ray_start_and_verify_status( - [f"--address={cluster.address}", "--num-cpus=1"], - env, - expect_success=expect_success, - ) + # Set up environment for worker + env = os.environ.copy() + env["RAY_auth_mode"] = "token" - if token_match == "correct": - try: - # Connect and verify the cluster has 2 nodes (head + worker) - ray.init(address=cluster.address) + if token_match == "correct": + env["RAY_AUTH_TOKEN"] = cluster_token + expect_success = True + else: + env["RAY_AUTH_TOKEN"] = "b" * 32 + expect_success = False + + # Start worker node + _run_ray_start_and_verify_status( + [f"--address={cluster.address}", "--num-cpus=1"], + env, + expect_success=expect_success, + ) - # Wait for worker node to register - def worker_joined(): - return len(ray.nodes()) >= 2 + if token_match == "correct": + try: + # Connect and verify the cluster has 2 nodes (head + worker) + set_env_auth_token(cluster_token) + reset_auth_token_state() + ray.init(address=cluster.address) - wait_for_condition(worker_joined, timeout=10) + def worker_joined(): + return len(ray.nodes()) >= 2 - nodes = ray.nodes() - assert ( - len(nodes) >= 2 - ), f"Expected at least 2 nodes, got {len(nodes)}: {nodes}" + wait_for_condition(worker_joined, timeout=10) - finally: - # Always shutdown ray.init() connection before cleanup - if ray.is_initialized(): - ray.shutdown() - # Clean up the worker node started with ray start - _cleanup_ray_start(env) + nodes = ray.nodes() + assert ( + len(nodes) >= 2 + ), f"Expected at least 2 nodes, got {len(nodes)}: {nodes}" - finally: - # Clean up cluster - if ray.is_initialized(): - ray.shutdown() - cluster.shutdown() + finally: 
+ if ray.is_initialized(): + ray.shutdown() + _cleanup_ray_start(env) if __name__ == "__main__": diff --git a/src/ray/common/constants.h b/src/ray/common/constants.h index 08986d3b415e..9d8928d2149c 100644 --- a/src/ray/common/constants.h +++ b/src/ray/common/constants.h @@ -42,7 +42,7 @@ constexpr int kRayletStoreErrorExitCode = 100; constexpr char kObjectTablePrefix[] = "ObjectTable"; constexpr char kClusterIdKey[] = "ray_cluster_id"; -constexpr char kAuthTokenKey[] = "authorization"; +constexpr char kAuthTokenKey[] = "Authorization"; constexpr char kBearerPrefix[] = "Bearer "; constexpr char kWorkerDynamicOptionPlaceholder[] = diff --git a/src/ray/rpc/authentication/authentication_token.h b/src/ray/rpc/authentication/authentication_token.h index bdedd5bf6e84..076eed49c898 100644 --- a/src/ray/rpc/authentication/authentication_token.h +++ b/src/ray/rpc/authentication/authentication_token.h @@ -86,7 +86,7 @@ class AuthenticationToken { } } - /// Get token as HTTP Authorization header value + /// Get token as Authorization header value /// WARNING: This exposes the raw token. Use sparingly. 
/// Returns "Bearer " format suitable for Authorization header /// @return Authorization header value, or empty string if token is empty From c1873a4392f5c801aa4cb70fbac0eafec108a829 Mon Sep 17 00:00:00 2001 From: sampan Date: Thu, 30 Oct 2025 09:07:14 +0000 Subject: [PATCH 71/94] fix lint Signed-off-by: sampan --- python/ray/tests/test_runtime_env_agent_auth.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/ray/tests/test_runtime_env_agent_auth.py b/python/ray/tests/test_runtime_env_agent_auth.py index 04fc189c79fd..45525af59fb7 100644 --- a/python/ray/tests/test_runtime_env_agent_auth.py +++ b/python/ray/tests/test_runtime_env_agent_auth.py @@ -2,7 +2,7 @@ import urllib.error import urllib.parse import urllib.request - +import sys import pytest import ray @@ -126,3 +126,6 @@ def test_apply_token_if_enabled_respects_existing_header(cleanup_auth_token_env) def test_format_authentication_http_error_non_auth_status(): assert format_authentication_http_error(404, "not found") is None + +if __name__ == "__main__": + sys.exit(pytest.main(["-vv", __file__])) \ No newline at end of file From 6c71ceb8d67908166c6558c3d84179f58b75c70f Mon Sep 17 00:00:00 2001 From: sampan Date: Thu, 30 Oct 2025 09:07:42 +0000 Subject: [PATCH 72/94] fix lint Signed-off-by: sampan --- python/ray/tests/test_runtime_env_agent_auth.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/ray/tests/test_runtime_env_agent_auth.py b/python/ray/tests/test_runtime_env_agent_auth.py index 45525af59fb7..f1a33af0a986 100644 --- a/python/ray/tests/test_runtime_env_agent_auth.py +++ b/python/ray/tests/test_runtime_env_agent_auth.py @@ -1,8 +1,9 @@ import logging +import sys import urllib.error import urllib.parse import urllib.request -import sys + import pytest import ray @@ -127,5 +128,6 @@ def test_apply_token_if_enabled_respects_existing_header(cleanup_auth_token_env) def test_format_authentication_http_error_non_auth_status(): assert 
format_authentication_http_error(404, "not found") is None + if __name__ == "__main__": - sys.exit(pytest.main(["-vv", __file__])) \ No newline at end of file + sys.exit(pytest.main(["-vv", __file__])) From 93f4ff6066d3ef677a3b48186145af144eb6b9d5 Mon Sep 17 00:00:00 2001 From: sampan Date: Thu, 30 Oct 2025 10:19:32 +0000 Subject: [PATCH 73/94] refactor and fix test flakyness Signed-off-by: sampan --- .../authentication_constants.py | 2 + .../http_token_authentication.py | 58 +++++++------------ python/ray/_private/runtime_env/agent/main.py | 4 +- python/ray/dashboard/http_server_head.py | 3 +- python/ray/dashboard/modules/dashboard_sdk.py | 4 +- .../ray/includes/rpc_token_authentication.pxi | 7 +-- python/ray/tests/test_dashboard_auth.py | 1 + .../ray/tests/test_runtime_env_agent_auth.py | 34 ++++++++--- python/ray/util/client/server/proxier.py | 11 +--- 9 files changed, 60 insertions(+), 64 deletions(-) diff --git a/python/ray/_private/authentication/authentication_constants.py b/python/ray/_private/authentication/authentication_constants.py index 347c7256e09d..85e90edce9e1 100644 --- a/python/ray/_private/authentication/authentication_constants.py +++ b/python/ray/_private/authentication/authentication_constants.py @@ -21,3 +21,5 @@ "The authentication token you provided is invalid or incorrect.\n\n" + TOKEN_SETUP_INSTRUCTIONS ) + +AUTHORIZATION_HEADER_NAME = "Authorization" diff --git a/python/ray/_private/authentication/http_token_authentication.py b/python/ray/_private/authentication/http_token_authentication.py index c295506ae290..45eb287f779a 100644 --- a/python/ray/_private/authentication/http_token_authentication.py +++ b/python/ray/_private/authentication/http_token_authentication.py @@ -1,4 +1,3 @@ -import logging import sys from typing import Dict, Optional @@ -6,37 +5,31 @@ from aiohttp import web from ray._private.authentication import authentication_constants +from ray._raylet import AuthenticationTokenLoader +from ray.dashboard import 
authentication_utils as auth_utils -def create_token_authentication_middleware() -> web.middleware: - """Return an aiohttp middleware that validates bearer tokens when enabled.""" - - from ray.dashboard import authentication_utils as auth_utils - - @web.middleware - async def auth_middleware(request: web.Request, handler): - if not auth_utils.is_token_auth_enabled(): - return await handler(request) +@web.middleware +async def token_auth_middleware(request: web.Request, handler): + """Middleware to validate bearer tokens when token authentication is enabled.""" + if not auth_utils.is_token_auth_enabled(): + return await handler(request) - auth_header = request.headers.get("Authorization", "") - if not auth_header: - return web.Response( - status=401, text="Unauthorized: Missing authentication token" - ) + auth_header = request.headers.get( + authentication_constants.AUTHORIZATION_HEADER_NAME, "" + ) + if not auth_header: + return web.Response( + status=401, text="Unauthorized: Missing authentication token" + ) - if not auth_utils.validate_request_token(auth_header): - return web.Response( - status=403, text="Forbidden: Invalid authentication token" - ) + if not auth_utils.validate_request_token(auth_header): + return web.Response(status=403, text="Forbidden: Invalid authentication token") - return await handler(request) + return await handler(request) - return auth_middleware - -def apply_token_if_enabled( - headers: Dict[str, str], logger: Optional[logging.Logger] = None -) -> bool: +def inject_auth_token_if_enabled(headers: Dict[str, str]) -> bool: """Inject Authorization header when token auth is enabled. 
Args: @@ -50,25 +43,14 @@ def apply_token_if_enabled( if headers is None: raise ValueError("headers must be provided") - if "Authorization" in headers: + if authentication_constants.AUTHORIZATION_HEADER_NAME in headers: return False - from ray._raylet import AuthenticationTokenLoader - from ray.dashboard import authentication_utils as auth_utils - if not auth_utils.is_token_auth_enabled(): return False token_loader = AuthenticationTokenLoader.instance() - token_added = token_loader.set_token_for_http_header(headers) - - if not token_added and logger is not None: - logger.warning( - "Token authentication is enabled but no token was found. " - "Requests to authenticated clusters will fail." - ) - - return token_added + return token_loader.set_token_for_http_header(headers) def format_authentication_http_error(status: int, body: str) -> Optional[str]: diff --git a/python/ray/_private/runtime_env/agent/main.py b/python/ray/_private/runtime_env/agent/main.py index 05c8e2d320e2..f31a927a857e 100644 --- a/python/ray/_private/runtime_env/agent/main.py +++ b/python/ray/_private/runtime_env/agent/main.py @@ -9,7 +9,7 @@ ) from ray._private import logging_utils from ray._private.authentication.http_token_authentication import ( - create_token_authentication_middleware, + token_auth_middleware, ) from ray._private.process_watcher import create_check_raylet_task from ray._raylet import GcsClient @@ -197,7 +197,7 @@ async def get_runtime_envs_info(request: web.Request) -> web.Response: body=reply.SerializeToString(), content_type="application/octet-stream" ) - app = web.Application(middlewares=[create_token_authentication_middleware()]) + app = web.Application(middlewares=[token_auth_middleware]) app.router.add_post("/get_or_create_runtime_env", get_or_create_runtime_env) app.router.add_post( diff --git a/python/ray/dashboard/http_server_head.py b/python/ray/dashboard/http_server_head.py index 66016fb7f6c2..00f6228e0844 100644 --- a/python/ray/dashboard/http_server_head.py +++ 
b/python/ray/dashboard/http_server_head.py @@ -21,7 +21,7 @@ from ray._common.usage.usage_lib import TagKey, record_extra_usage_tag from ray._common.utils import get_or_create_event_loop from ray._private.authentication.http_token_authentication import ( - create_token_authentication_middleware, + token_auth_middleware, ) from ray.dashboard.dashboard_metrics import DashboardPrometheusMetrics from ray.dashboard.head import DashboardHeadModule @@ -251,7 +251,6 @@ async def run( # Http server should be initialized after all modules loaded. # working_dir uploads for job submission can be up to 100MiB. - token_auth_middleware = create_token_authentication_middleware() app = aiohttp.web.Application( client_max_size=ray_constants.DASHBOARD_CLIENT_MAX_SIZE, diff --git a/python/ray/dashboard/modules/dashboard_sdk.py b/python/ray/dashboard/modules/dashboard_sdk.py index d45a9259f19b..c4db01401a47 100644 --- a/python/ray/dashboard/modules/dashboard_sdk.py +++ b/python/ray/dashboard/modules/dashboard_sdk.py @@ -13,8 +13,8 @@ import ray from ray._private.authentication.http_token_authentication import ( - apply_token_if_enabled, format_authentication_http_error, + inject_auth_token_if_enabled, ) from ray._private.runtime_env.packaging import ( create_package, @@ -252,7 +252,7 @@ def __init__( def _set_auth_header_if_enabled(self): """Add authentication token to headers if token auth is enabled.""" - apply_token_if_enabled(self._headers, logger) + inject_auth_token_if_enabled(self._headers) def _check_connection_and_version( self, min_version: str = "1.9", version_error_message: str = None diff --git a/python/ray/includes/rpc_token_authentication.pxi b/python/ray/includes/rpc_token_authentication.pxi index 3825933717ee..b586fa90b0d5 100644 --- a/python/ray/includes/rpc_token_authentication.pxi +++ b/python/ray/includes/rpc_token_authentication.pxi @@ -4,6 +4,7 @@ from ray.includes.rpc_token_authentication cimport ( CAuthenticationToken, CAuthenticationTokenLoader, ) +from 
ray._private.authentication.authentication_constants import AUTHORIZATION_HEADER_NAME # Authentication mode enum exposed to Python @@ -11,8 +12,6 @@ class AuthenticationMode: DISABLED = CAuthenticationMode.DISABLED TOKEN = CAuthenticationMode.TOKEN -_AUTHORIZATION_HEADER_NAME = "Authorization" - def get_authentication_mode(): """Get the current authentication mode. @@ -88,7 +87,7 @@ class AuthenticationTokenLoader: bool: True if token was added, False otherwise """ # Don't override if user explicitly set Authorization header - if _AUTHORIZATION_HEADER_NAME in headers: + if AUTHORIZATION_HEADER_NAME in headers: return False # Check if token exists (doesn't crash, returns bool) @@ -101,5 +100,5 @@ class AuthenticationTokenLoader: if not token_opt.has_value() or token_opt.value().empty(): return False - headers[_AUTHORIZATION_HEADER_NAME] = token_opt.value().ToAuthorizationHeaderValue() + headers[AUTHORIZATION_HEADER_NAME] = token_opt.value().ToAuthorizationHeaderValue() return True diff --git a/python/ray/tests/test_dashboard_auth.py b/python/ray/tests/test_dashboard_auth.py index b76f0bcf3671..7407fc199a1d 100644 --- a/python/ray/tests/test_dashboard_auth.py +++ b/python/ray/tests/test_dashboard_auth.py @@ -64,4 +64,5 @@ def test_dashboard_auth_disabled(setup_cluster_without_token_auth): if __name__ == "__main__": + sys.exit(pytest.main(["-vv", __file__])) diff --git a/python/ray/tests/test_runtime_env_agent_auth.py b/python/ray/tests/test_runtime_env_agent_auth.py index f1a33af0a986..45100bf11c3d 100644 --- a/python/ray/tests/test_runtime_env_agent_auth.py +++ b/python/ray/tests/test_runtime_env_agent_auth.py @@ -1,4 +1,4 @@ -import logging +import socket import sys import urllib.error import urllib.parse @@ -7,9 +7,10 @@ import pytest import ray +from ray._common.test_utils import wait_for_condition from ray._private.authentication.http_token_authentication import ( - apply_token_if_enabled, format_authentication_http_error, + inject_auth_token_if_enabled, ) from 
ray.core.generated import runtime_env_agent_pb2 from ray.tests.authentication_test_utils import ( @@ -32,8 +33,22 @@ def _make_get_or_create_request() -> runtime_env_agent_pb2.GetOrCreateRuntimeEnv return request +def _wait_for_runtime_env_agent(agent_address: str) -> None: + parsed = urllib.parse.urlparse(agent_address) + + def _can_connect() -> bool: + try: + with socket.create_connection((parsed.hostname, parsed.port), timeout=1): + return True + except OSError: + return False + + wait_for_condition(_can_connect, timeout=10) + + def test_runtime_env_agent_requires_auth_missing_token(setup_cluster_with_token_auth): agent_address = ray._private.worker.global_worker.node.runtime_env_agent_address + _wait_for_runtime_env_agent(agent_address) request = _make_get_or_create_request() with pytest.raises(urllib.error.HTTPError) as exc_info: @@ -56,6 +71,7 @@ def test_runtime_env_agent_requires_auth_missing_token(setup_cluster_with_token_ def test_runtime_env_agent_rejects_invalid_token(setup_cluster_with_token_auth): agent_address = ray._private.worker.global_worker.node.runtime_env_agent_address + _wait_for_runtime_env_agent(agent_address) request = _make_get_or_create_request() with pytest.raises(urllib.error.HTTPError) as exc_info: @@ -81,6 +97,7 @@ def test_runtime_env_agent_rejects_invalid_token(setup_cluster_with_token_auth): def test_runtime_env_agent_accepts_valid_token(setup_cluster_with_token_auth): agent_address = ray._private.worker.global_worker.node.runtime_env_agent_address + _wait_for_runtime_env_agent(agent_address) token = setup_cluster_with_token_auth["token"] request = _make_get_or_create_request() @@ -101,25 +118,28 @@ def test_runtime_env_agent_accepts_valid_token(setup_cluster_with_token_auth): assert reply.status == runtime_env_agent_pb2.AgentRpcStatus.AGENT_RPC_STATUS_OK -def test_apply_token_if_enabled_adds_header(cleanup_auth_token_env): +def test_inject_token_if_enabled_adds_header(cleanup_auth_token_env): set_auth_mode("token") 
set_env_auth_token("apptoken1234567890") reset_auth_token_state() headers = {} - added = apply_token_if_enabled(headers, logging.getLogger(__name__)) + added = inject_auth_token_if_enabled(headers) assert added is True - assert headers["Authorization"] == "Bearer apptoken1234567890" + auth_header = headers["Authorization"] + if isinstance(auth_header, bytes): + auth_header = auth_header.decode("utf-8") + assert auth_header == "Bearer apptoken1234567890" -def test_apply_token_if_enabled_respects_existing_header(cleanup_auth_token_env): +def test_inject_token_if_enabled_respects_existing_header(cleanup_auth_token_env): set_auth_mode("token") set_env_auth_token("apptoken1234567890") reset_auth_token_state() headers = {"Authorization": "Bearer custom"} - added = apply_token_if_enabled(headers, logging.getLogger(__name__)) + added = inject_auth_token_if_enabled(headers) assert added is False assert headers["Authorization"] == "Bearer custom" diff --git a/python/ray/util/client/server/proxier.py b/python/ray/util/client/server/proxier.py index 6c2c5794e6a8..62c3e5383089 100644 --- a/python/ray/util/client/server/proxier.py +++ b/python/ray/util/client/server/proxier.py @@ -20,8 +20,8 @@ import ray.core.generated.runtime_env_agent_pb2 as runtime_env_agent_pb2 from ray._common.network_utils import build_address, is_localhost from ray._private.authentication.http_token_authentication import ( - apply_token_if_enabled, format_authentication_http_error, + inject_auth_token_if_enabled, ) from ray._private.client_mode_hook import disable_client_hook from ray._private.parameter import RayParams @@ -251,7 +251,7 @@ def _create_runtime_env( ) data = create_env_request.SerializeToString() headers = {"Content-Type": "application/octet-stream"} - apply_token_if_enabled(headers, logger) + inject_auth_token_if_enabled(headers) req = urllib.request.Request( url, data=data, method="POST", headers=headers ) @@ -283,13 +283,6 @@ def _create_runtime_env( if formatted_error: raise 
RuntimeError(formatted_error) from e - last_exception = e - logger.warning( - f"GetOrCreateRuntimeEnv request failed: {e}. " - f"Retrying after {wait_time_s}s. " - f"{max_retries-retries} retries remaining." - ) - except urllib.error.URLError as e: last_exception = e logger.warning( From 20269431191cfaf316aaaf4366c112b60b2a680f Mon Sep 17 00:00:00 2001 From: sampan Date: Thu, 30 Oct 2025 10:20:37 +0000 Subject: [PATCH 74/94] fix lint Signed-off-by: sampan --- .../authentication/http_token_authentication.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/python/ray/_private/authentication/http_token_authentication.py b/python/ray/_private/authentication/http_token_authentication.py index 45eb287f779a..cd1c45c4c41b 100644 --- a/python/ray/_private/authentication/http_token_authentication.py +++ b/python/ray/_private/authentication/http_token_authentication.py @@ -30,15 +30,7 @@ async def token_auth_middleware(request: web.Request, handler): def inject_auth_token_if_enabled(headers: Dict[str, str]) -> bool: - """Inject Authorization header when token auth is enabled. - - Args: - headers: Mutable mapping of HTTP headers. Updated in place. - logger: Optional logger used for warning when token is missing. - - Returns: - bool: True if the token was added to headers, False otherwise. 
- """ + """Inject Authorization header when token auth is enabled.""" if headers is None: raise ValueError("headers must be provided") From 842b21da5717069ac0d622b59257d680ca8346a9 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Thu, 30 Oct 2025 17:20:34 -0500 Subject: [PATCH 75/94] Fix Signed-off-by: Edward Oakes --- python/ray/tests/test_dashboard_auth.py | 129 ------------------------ 1 file changed, 129 deletions(-) delete mode 100644 python/ray/tests/test_dashboard_auth.py diff --git a/python/ray/tests/test_dashboard_auth.py b/python/ray/tests/test_dashboard_auth.py deleted file mode 100644 index cb57c785a518..000000000000 --- a/python/ray/tests/test_dashboard_auth.py +++ /dev/null @@ -1,129 +0,0 @@ -"""Tests for dashboard token authentication.""" - -import os - -import pytest -import requests - -import ray -from ray._raylet import Config -from ray.cluster_utils import Cluster - - -@pytest.fixture -def cleanup_env(): - """Clean up environment variables after each test.""" - yield - # Clean up environment variables - if "RAY_auth_mode" in os.environ: - del os.environ["RAY_auth_mode"] - if "RAY_AUTH_TOKEN" in os.environ: - del os.environ["RAY_AUTH_TOKEN"] - - -def test_dashboard_request_requires_auth_with_valid_token(cleanup_env): - """Test that requests succeed with valid token when auth is enabled.""" - test_token = "test_token_12345678901234567890123456789012" - os.environ["RAY_auth_mode"] = "token" - os.environ["RAY_AUTH_TOKEN"] = test_token - Config.initialize("") - cluster = Cluster() - cluster.add_node() - - try: - context = ray.init(address=cluster.address) - dashboard_url = context.address_info["webui_url"] - - # Request with valid auth should succeed - headers = {"Authorization": f"Bearer {test_token}"} - response = requests.get( - f"http://{dashboard_url}/api/component_activities", - headers=headers, - ) - assert response.status_code == 200 - - finally: - ray.shutdown() - cluster.shutdown() - - -def 
test_dashboard_request_requires_auth_missing_token(cleanup_env): - """Test that requests fail without token when auth is enabled.""" - test_token = "test_token_12345678901234567890123456789012" - os.environ["RAY_auth_mode"] = "token" - os.environ["RAY_AUTH_TOKEN"] = test_token - Config.initialize("") - cluster = Cluster() - cluster.add_node() - - try: - context = ray.init(address=cluster.address) - dashboard_url = context.address_info["webui_url"] - - # GET without auth should fail with 401 - response = requests.get( - f"http://{dashboard_url}/api/component_activities", - json={"test": "data"}, - ) - assert response.status_code == 401 - - finally: - ray.shutdown() - cluster.shutdown() - - -def test_dashboard_request_requires_auth_invalid_token(cleanup_env): - """Test that requests fail with invalid token when auth is enabled.""" - correct_token = "test_token_12345678901234567890123456789012" - wrong_token = "wrong_token_00000000000000000000000000000000" - os.environ["RAY_auth_mode"] = "token" - os.environ["RAY_AUTH_TOKEN"] = correct_token - Config.initialize("") - cluster = Cluster() - cluster.add_node() - - try: - context = ray.init(address=cluster.address) - dashboard_url = context.address_info["webui_url"] - - # Request with wrong token should fail with 403 - headers = {"Authorization": f"Bearer {wrong_token}"} - response = requests.get( - f"http://{dashboard_url}/api/component_activities", - json={"test": "data"}, - headers=headers, - ) - assert response.status_code == 403 - - finally: - ray.shutdown() - cluster.shutdown() - - -def test_dashboard_auth_disabled(cleanup_env): - """Test that auth is not enforced when auth_mode is disabled.""" - os.environ["RAY_auth_mode"] = "disabled" - - cluster = Cluster() - cluster.add_node() - - try: - context = ray.init(address=cluster.address) - dashboard_url = context.address_info["webui_url"] - - # GET without auth should succeed when auth is disabled - response = requests.get( - 
f"http://{dashboard_url}/api/component_activities", json={"test": "data"} - ) - # Should not return 401 or 403 - assert response.status_code == 200 - - finally: - ray.shutdown() - cluster.shutdown() - - -if __name__ == "__main__": - import sys - - sys.exit(pytest.main(["-vv", __file__])) From 95ea5d2c96ef20cb78cfe341b02a5ae98afc6f64 Mon Sep 17 00:00:00 2001 From: sampan Date: Fri, 31 Oct 2025 03:17:42 +0000 Subject: [PATCH 76/94] remove "test_dashboard_auth.py" from BUILD.bazel as the test has been moved Signed-off-by: sampan --- python/ray/tests/BUILD.bazel | 1 - 1 file changed, 1 deletion(-) diff --git a/python/ray/tests/BUILD.bazel b/python/ray/tests/BUILD.bazel index 3b11301a5318..cf5c69b54ac0 100644 --- a/python/ray/tests/BUILD.bazel +++ b/python/ray/tests/BUILD.bazel @@ -543,7 +543,6 @@ py_test_module_list( "test_concurrency_group.py", "test_core_worker_io_thread_stack_size.py", "test_cross_language.py", - "test_dashboard_auth.py", "test_debug_tools.py", "test_distributed_sort.py", "test_environ.py", From 8874a2236e202c76bd3e266638b03ff41d7ccada Mon Sep 17 00:00:00 2001 From: sampan Date: Fri, 31 Oct 2025 06:05:25 +0000 Subject: [PATCH 77/94] fix tests Signed-off-by: sampan --- .../dashboard/tests/test_dashboard_auth.py | 103 +++++------------- .../ray/includes/rpc_token_authentication.pxi | 2 +- src/ray/common/constants.h | 2 +- 3 files changed, 29 insertions(+), 78 deletions(-) diff --git a/python/ray/dashboard/tests/test_dashboard_auth.py b/python/ray/dashboard/tests/test_dashboard_auth.py index 948be389f12e..7407fc199a1d 100644 --- a/python/ray/dashboard/tests/test_dashboard_auth.py +++ b/python/ray/dashboard/tests/test_dashboard_auth.py @@ -1,117 +1,68 @@ """Tests for dashboard token authentication.""" -import os + import sys import pytest import requests -import ray -from ray._raylet import Config - - -@pytest.fixture -def start_ray_with_env_vars(request): - """Clean up environment variables after each test.""" - env_vars = getattr(request, "param", 
{}).pop("env_vars", {}) - os.environ.update(**env_vars) - Config.initialize("") - - yield ray.init() - - ray.shutdown() - for k in env_vars.keys(): - del os.environ[k] - -TEST_TOKEN = "test_token_12345678901234567890123456789012" - - -@pytest.mark.parametrize( - "start_ray_with_env_vars", - [ - { - "env_vars": {"RAY_auth_mode": "token", "RAY_AUTH_TOKEN": TEST_TOKEN}, - }, - ], - indirect=True, -) -def test_auth_enabled_valid_token(start_ray_with_env_vars): +def test_dashboard_request_requires_auth_with_valid_token( + setup_cluster_with_token_auth, +): """Test that requests succeed with valid token when auth is enabled.""" - dashboard_url = start_ray_with_env_vars.address_info["webui_url"] - # Request with valid auth should succeed - headers = {"Authorization": f"Bearer {TEST_TOKEN}"} + cluster_info = setup_cluster_with_token_auth + headers = {"Authorization": f"Bearer {cluster_info['token']}"} + response = requests.get( - f"http://{dashboard_url}/api/component_activities", + f"{cluster_info['dashboard_url']}/api/component_activities", headers=headers, ) + assert response.status_code == 200 -@pytest.mark.parametrize( - "start_ray_with_env_vars", - [ - { - "env_vars": {"RAY_auth_mode": "token", "RAY_AUTH_TOKEN": TEST_TOKEN}, - }, - ], - indirect=True, -) -def test_auth_enabled_missing_token(start_ray_with_env_vars): +def test_dashboard_request_requires_auth_missing_token(setup_cluster_with_token_auth): """Test that requests fail without token when auth is enabled.""" - dashboard_url = start_ray_with_env_vars.address_info["webui_url"] - # GET without auth should fail with 401 + cluster_info = setup_cluster_with_token_auth + response = requests.get( - f"http://{dashboard_url}/api/component_activities", + f"{cluster_info['dashboard_url']}/api/component_activities", json={"test": "data"}, ) + assert response.status_code == 401 -@pytest.mark.parametrize( - "start_ray_with_env_vars", - [ - { - "env_vars": {"RAY_auth_mode": "token", "RAY_AUTH_TOKEN": TEST_TOKEN}, - }, - ], 
- indirect=True, -) -def test_auth_enabled_invalid_token(start_ray_with_env_vars): +def test_dashboard_request_requires_auth_invalid_token(setup_cluster_with_token_auth): """Test that requests fail with invalid token when auth is enabled.""" - dashboard_url = start_ray_with_env_vars.address_info["webui_url"] - # Request with wrong token should fail with 403 - headers = {"Authorization": "Bearer INCORRECT_TOKEN"} + cluster_info = setup_cluster_with_token_auth + headers = {"Authorization": "Bearer wrong_token_00000000000000000000000000000000"} + response = requests.get( - f"http://{dashboard_url}/api/component_activities", + f"{cluster_info['dashboard_url']}/api/component_activities", json={"test": "data"}, headers=headers, ) + assert response.status_code == 403 -@pytest.mark.parametrize( - "start_ray_with_env_vars", - [ - { - "env_vars": {"RAY_auth_mode": "disabled"}, - }, - ], - indirect=True, -) -def test_auth_disabled(start_ray_with_env_vars): +def test_dashboard_auth_disabled(setup_cluster_without_token_auth): """Test that auth is not enforced when auth_mode is disabled.""" - dashboard_url = start_ray_with_env_vars.address_info["webui_url"] - # GET without auth should succeed when auth is disabled + cluster_info = setup_cluster_without_token_auth + response = requests.get( - f"http://{dashboard_url}/api/component_activities", json={"test": "data"} + f"{cluster_info['dashboard_url']}/api/component_activities", + json={"test": "data"}, ) - # Should not return 401 or 403 + assert response.status_code == 200 if __name__ == "__main__": + sys.exit(pytest.main(["-vv", __file__])) diff --git a/python/ray/includes/rpc_token_authentication.pxi b/python/ray/includes/rpc_token_authentication.pxi index 3825933717ee..cc76aeb520c2 100644 --- a/python/ray/includes/rpc_token_authentication.pxi +++ b/python/ray/includes/rpc_token_authentication.pxi @@ -11,7 +11,7 @@ class AuthenticationMode: DISABLED = CAuthenticationMode.DISABLED TOKEN = CAuthenticationMode.TOKEN 
-_AUTHORIZATION_HEADER_NAME = "Authorization" +_AUTHORIZATION_HEADER_NAME = "authorization" def get_authentication_mode(): """Get the current authentication mode. diff --git a/src/ray/common/constants.h b/src/ray/common/constants.h index 9d8928d2149c..08986d3b415e 100644 --- a/src/ray/common/constants.h +++ b/src/ray/common/constants.h @@ -42,7 +42,7 @@ constexpr int kRayletStoreErrorExitCode = 100; constexpr char kObjectTablePrefix[] = "ObjectTable"; constexpr char kClusterIdKey[] = "ray_cluster_id"; -constexpr char kAuthTokenKey[] = "Authorization"; +constexpr char kAuthTokenKey[] = "authorization"; constexpr char kBearerPrefix[] = "Bearer "; constexpr char kWorkerDynamicOptionPlaceholder[] = From 8da7fa4e56c1d13ec20f0d01e5f975963b6e743f Mon Sep 17 00:00:00 2001 From: sampan Date: Fri, 31 Oct 2025 06:57:08 +0000 Subject: [PATCH 78/94] fix case sensitivity bug + tests Signed-off-by: sampan --- python/ray/dashboard/modules/dashboard_sdk.py | 48 ++++++++++++------- .../ray/includes/rpc_token_authentication.pxi | 32 +++++-------- .../ray/tests/test_submission_client_auth.py | 48 +++++++++++++++---- 3 files changed, 79 insertions(+), 49 deletions(-) diff --git a/python/ray/dashboard/modules/dashboard_sdk.py b/python/ray/dashboard/modules/dashboard_sdk.py index 2cdf738590a2..c28035a1baa4 100644 --- a/python/ray/dashboard/modules/dashboard_sdk.py +++ b/python/ray/dashboard/modules/dashboard_sdk.py @@ -226,9 +226,7 @@ def __init__( # Headers used for all requests sent to job server, optional and only # needed for cases like authentication to remote cluster. self._headers = cluster_info.headers or {} - - # Add authentication token if token auth is enabled - self._set_auth_header_if_enabled() + self._headers.update(**self._get_auth_headers()) # Set SSL verify parameter for the requests library and create an ssl_context # object when needed for the aiohttp library. 
@@ -249,21 +247,35 @@ def __init__( else: self._ssl_context = None - def _set_auth_header_if_enabled(self): - """Add authentication token to headers if token auth is enabled.""" - if is_token_auth_enabled(): - token_loader = AuthenticationTokenLoader.instance() - token_added = token_loader.set_token_for_http_header(self._headers) - - if not token_added: - # Token auth is enabled but no token found or Authorization already set - if "Authorization" not in self._headers: - # No token found - log warning but don't fail yet - # Let the server return 401 for a better error message - logger.warning( - "Token authentication is enabled but no token was found. " - "Requests to authenticated clusters will fail." - ) + def _get_auth_headers(self) -> Dict[str, str]: + """Get authentication headers if token auth is enabled. + + Returns: + dict: Authentication headers to merge with request headers. + Empty dict if no auth needed or token unavailable. + """ + if not is_token_auth_enabled(): + return {} + + # Check if user provided their own Authorization header (case-insensitive) + has_user_auth = any( + key.lower() == "authorization" for key in self._headers.keys() + ) + if has_user_auth: + # User has provided their own auth header, don't override + return {} + + token_loader = AuthenticationTokenLoader.instance() + auth_headers = token_loader.get_token_for_http_header() + + if not auth_headers: + # Token auth enabled but no token found + logger.warning( + "Token authentication is enabled but no token was found. " + "Requests to authenticated clusters will fail." 
+ ) + + return auth_headers def _check_connection_and_version( self, min_version: str = "1.9", version_error_message: str = None diff --git a/python/ray/includes/rpc_token_authentication.pxi b/python/ray/includes/rpc_token_authentication.pxi index cc76aeb520c2..c8cfebd77066 100644 --- a/python/ray/includes/rpc_token_authentication.pxi +++ b/python/ray/includes/rpc_token_authentication.pxi @@ -71,35 +71,25 @@ class AuthenticationTokenLoader: """ CAuthenticationTokenLoader.instance().ResetCache() - def set_token_for_http_header(self, headers: dict): - """Add authentication token to HTTP headers dictionary if token auth is enabled. + def get_token_for_http_header(self) -> dict: + """Get authentication token as a dictionary for HTTP headers. - This method loads the token from C++ AuthenticationTokenLoader and adds it - to the provided headers dictionary as the Authorization header. It only adds - the token if: - - Token authentication is enabled - - A token exists - - The Authorization header is not already set in the headers - - Args: - headers: Dictionary of HTTP headers to modify (modified in-place) + This method loads the token from C++ AuthenticationTokenLoader and returns it + as a dictionary that can be merged with existing headers. 
It returns an empty + dictionary if: + - A token does not exist + - The token is empty Returns: - bool: True if token was added, False otherwise + dict: Empty dict or {"authorization": "Bearer "} """ - # Don't override if user explicitly set Authorization header - if _AUTHORIZATION_HEADER_NAME in headers: - return False - - # Check if token exists (doesn't crash, returns bool) if not self.has_token(): - return False + return {} # Get the token from C++ layer cdef optional[CAuthenticationToken] token_opt = CAuthenticationTokenLoader.instance().GetToken() if not token_opt.has_value() or token_opt.value().empty(): - return False + return {} - headers[_AUTHORIZATION_HEADER_NAME] = token_opt.value().ToAuthorizationHeaderValue() - return True + return {_AUTHORIZATION_HEADER_NAME: token_opt.value().ToAuthorizationHeaderValue().decode('utf-8')} diff --git a/python/ray/tests/test_submission_client_auth.py b/python/ray/tests/test_submission_client_auth.py index 9e76dc234d26..33c63ac7d93c 100644 --- a/python/ray/tests/test_submission_client_auth.py +++ b/python/ray/tests/test_submission_client_auth.py @@ -21,9 +21,9 @@ def test_submission_client_adds_token_automatically(setup_cluster_with_token_aut client = SubmissionClient(address=setup_cluster_with_token_auth["dashboard_url"]) - # Verify Authorization header was added - assert "Authorization" in client._headers - assert client._headers["Authorization"].startswith("Bearer ") + # Verify authorization header was added (lowercase as per implementation) + assert "authorization" in client._headers + assert client._headers["authorization"].startswith("Bearer ") def test_submission_client_without_token_shows_helpful_error( @@ -84,9 +84,9 @@ def test_job_submission_client_inherits_auth(setup_cluster_with_token_auth): """Test that JobSubmissionClient inherits auth from SubmissionClient.""" client = JobSubmissionClient(address=setup_cluster_with_token_auth["dashboard_url"]) - # Verify Authorization header was added - assert 
"Authorization" in client._headers - assert client._headers["Authorization"].startswith("Bearer ") + # Verify authorization header was added (lowercase as per implementation) + assert "authorization" in client._headers + assert client._headers["authorization"].startswith("Bearer ") # Verify client can make authenticated requests version = client.get_version() @@ -97,9 +97,9 @@ def test_state_api_client_inherits_auth(setup_cluster_with_token_auth): """Test that StateApiClient inherits auth from SubmissionClient.""" client = StateApiClient(address=setup_cluster_with_token_auth["dashboard_url"]) - # Verify Authorization header was added - assert "Authorization" in client._headers - assert client._headers["Authorization"].startswith("Bearer ") + # Verify authorization header was added (lowercase as per implementation) + assert "authorization" in client._headers + assert client._headers["authorization"].startswith("Bearer ") def test_user_provided_header_not_overridden(setup_cluster_with_token_auth): @@ -115,6 +115,32 @@ def test_user_provided_header_not_overridden(setup_cluster_with_token_auth): assert client._headers["Authorization"] == custom_auth +def test_user_provided_header_case_insensitive(setup_cluster_with_token_auth): + """Test that user-provided Authorization header is preserved regardless of case.""" + custom_auth = "Bearer custom_token" + + # Test with lowercase "authorization" + client_lowercase = SubmissionClient( + address=setup_cluster_with_token_auth["dashboard_url"], + headers={"authorization": custom_auth}, + ) + + # Verify custom value is preserved and no duplicate header added + assert client_lowercase._headers["authorization"] == custom_auth + assert "Authorization" not in client_lowercase._headers + + # Test with mixed case "AuThOrIzAtIoN" + client_mixedcase = SubmissionClient( + address=setup_cluster_with_token_auth["dashboard_url"], + headers={"AuThOrIzAtIoN": custom_auth}, + ) + + # Verify custom value is preserved and no duplicate header 
added + assert client_mixedcase._headers["AuThOrIzAtIoN"] == custom_auth + assert "Authorization" not in client_mixedcase._headers + assert "authorization" not in client_mixedcase._headers + + def test_error_messages_contain_instructions(setup_cluster_with_token_auth): """Test that all auth error messages contain setup instructions.""" # Test 401 error (missing token) @@ -151,10 +177,12 @@ def test_error_messages_contain_instructions(setup_cluster_with_token_auth): def test_no_token_added_when_auth_disabled(setup_cluster_without_token_auth): - """Test that no Authorization header is injected when auth is disabled.""" + """Test that no authorization header is injected when auth is disabled.""" client = SubmissionClient(address=setup_cluster_without_token_auth["dashboard_url"]) + # Check both lowercase and uppercase variants + assert "authorization" not in client._headers assert "Authorization" not in client._headers From 0470e6db7760ad01b0c89c29575923356825d6b5 Mon Sep 17 00:00:00 2001 From: sampan Date: Fri, 31 Oct 2025 07:14:43 +0000 Subject: [PATCH 79/94] fix issues Signed-off-by: sampan --- .../authentication_constants.py | 2 +- .../http_token_authentication.py | 31 +++++++++++++------ python/ray/dashboard/modules/dashboard_sdk.py | 25 ++------------- .../ray/includes/rpc_token_authentication.pxi | 2 +- .../ray/tests/test_runtime_env_agent_auth.py | 15 +++++---- python/ray/util/client/server/proxier.py | 4 +-- src/ray/raylet/BUILD.bazel | 1 + 7 files changed, 35 insertions(+), 45 deletions(-) diff --git a/python/ray/_private/authentication/authentication_constants.py b/python/ray/_private/authentication/authentication_constants.py index 85e90edce9e1..92318233a3c3 100644 --- a/python/ray/_private/authentication/authentication_constants.py +++ b/python/ray/_private/authentication/authentication_constants.py @@ -22,4 +22,4 @@ + TOKEN_SETUP_INSTRUCTIONS ) -AUTHORIZATION_HEADER_NAME = "Authorization" +AUTHORIZATION_HEADER_NAME = "authorization" diff --git 
a/python/ray/_private/authentication/http_token_authentication.py b/python/ray/_private/authentication/http_token_authentication.py index cd1c45c4c41b..550539b206d0 100644 --- a/python/ray/_private/authentication/http_token_authentication.py +++ b/python/ray/_private/authentication/http_token_authentication.py @@ -29,20 +29,31 @@ async def token_auth_middleware(request: web.Request, handler): return await handler(request) -def inject_auth_token_if_enabled(headers: Dict[str, str]) -> bool: - """Inject Authorization header when token auth is enabled.""" - - if headers is None: - raise ValueError("headers must be provided") - - if authentication_constants.AUTHORIZATION_HEADER_NAME in headers: - return False +def get_auth_headers_if_auth_enabled(user_headers: Dict[str, str]) -> Dict[str, str]: if not auth_utils.is_token_auth_enabled(): - return False + return {} + + # Check if user provided their own Authorization header (case-insensitive) + has_user_auth = any( + key.lower() == authentication_constants.AUTHORIZATION_HEADER_NAME + for key in user_headers.keys() + ) + if has_user_auth: + # User has provided their own auth header, don't override + return {} token_loader = AuthenticationTokenLoader.instance() - return token_loader.set_token_for_http_header(headers) + auth_headers = token_loader.get_token_for_http_header() + + if not auth_headers: + # Token auth enabled but no token found + logger.warning( + "Token authentication is enabled but no token was found. " + "Requests to authenticated clusters will fail." 
+ ) + + return auth_headers def format_authentication_http_error(status: int, body: str) -> Optional[str]: diff --git a/python/ray/dashboard/modules/dashboard_sdk.py b/python/ray/dashboard/modules/dashboard_sdk.py index e915320e3b50..5429c60b96d7 100644 --- a/python/ray/dashboard/modules/dashboard_sdk.py +++ b/python/ray/dashboard/modules/dashboard_sdk.py @@ -14,7 +14,7 @@ import ray from ray._private.authentication.http_token_authentication import ( format_authentication_http_error, - inject_auth_token_if_enabled, + get_auth_headers_if_auth_enabled, ) from ray._private.runtime_env.packaging import ( create_package, @@ -255,28 +255,7 @@ def _get_auth_headers(self) -> Dict[str, str]: dict: Authentication headers to merge with request headers. Empty dict if no auth needed or token unavailable. """ - if not is_token_auth_enabled(): - return {} - - # Check if user provided their own Authorization header (case-insensitive) - has_user_auth = any( - key.lower() == "authorization" for key in self._headers.keys() - ) - if has_user_auth: - # User has provided their own auth header, don't override - return {} - - token_loader = AuthenticationTokenLoader.instance() - auth_headers = token_loader.get_token_for_http_header() - - if not auth_headers: - # Token auth enabled but no token found - logger.warning( - "Token authentication is enabled but no token was found. " - "Requests to authenticated clusters will fail." 
- ) - - return auth_headers + return get_auth_headers_if_auth_enabled(self._headers) def _check_connection_and_version( self, min_version: str = "1.9", version_error_message: str = None diff --git a/python/ray/includes/rpc_token_authentication.pxi b/python/ray/includes/rpc_token_authentication.pxi index eef34f23e5f3..42bd0f8e356c 100644 --- a/python/ray/includes/rpc_token_authentication.pxi +++ b/python/ray/includes/rpc_token_authentication.pxi @@ -91,4 +91,4 @@ class AuthenticationTokenLoader: if not token_opt.has_value() or token_opt.value().empty(): return {} - return {_AUTHORIZATION_HEADER_NAME: token_opt.value().ToAuthorizationHeaderValue().decode('utf-8')} + return {AUTHORIZATION_HEADER_NAME: token_opt.value().ToAuthorizationHeaderValue().decode('utf-8')} diff --git a/python/ray/tests/test_runtime_env_agent_auth.py b/python/ray/tests/test_runtime_env_agent_auth.py index 45100bf11c3d..879e1c739277 100644 --- a/python/ray/tests/test_runtime_env_agent_auth.py +++ b/python/ray/tests/test_runtime_env_agent_auth.py @@ -10,7 +10,7 @@ from ray._common.test_utils import wait_for_condition from ray._private.authentication.http_token_authentication import ( format_authentication_http_error, - inject_auth_token_if_enabled, + get_auth_headers_if_auth_enabled, ) from ray.core.generated import runtime_env_agent_pb2 from ray.tests.authentication_test_utils import ( @@ -124,10 +124,10 @@ def test_inject_token_if_enabled_adds_header(cleanup_auth_token_env): reset_auth_token_state() headers = {} - added = inject_auth_token_if_enabled(headers) + headers_to_add = get_auth_headers_if_auth_enabled(headers) - assert added is True - auth_header = headers["Authorization"] + assert headers_to_add != {} + auth_header = headers_to_add["authorization"] if isinstance(auth_header, bytes): auth_header = auth_header.decode("utf-8") assert auth_header == "Bearer apptoken1234567890" @@ -138,11 +138,10 @@ def test_inject_token_if_enabled_respects_existing_header(cleanup_auth_token_env 
set_env_auth_token("apptoken1234567890") reset_auth_token_state() - headers = {"Authorization": "Bearer custom"} - added = inject_auth_token_if_enabled(headers) + headers = {"authorization": "Bearer custom"} + headers_to_add = get_auth_headers_if_auth_enabled(headers) - assert added is False - assert headers["Authorization"] == "Bearer custom" + assert headers_to_add == {} def test_format_authentication_http_error_non_auth_status(): diff --git a/python/ray/util/client/server/proxier.py b/python/ray/util/client/server/proxier.py index 7ba3192046e5..84ce7f138b4c 100644 --- a/python/ray/util/client/server/proxier.py +++ b/python/ray/util/client/server/proxier.py @@ -21,7 +21,7 @@ from ray._common.network_utils import build_address, is_localhost from ray._private.authentication.http_token_authentication import ( format_authentication_http_error, - inject_auth_token_if_enabled, + get_auth_headers_if_auth_enabled, ) from ray._private.client_mode_hook import disable_client_hook from ray._private.parameter import RayParams @@ -256,7 +256,7 @@ def _create_runtime_env( ) data = create_env_request.SerializeToString() headers = {"Content-Type": "application/octet-stream"} - inject_auth_token_if_enabled(headers) + headers.update(**get_auth_headers_if_auth_enabled(headers)) req = urllib.request.Request( url, data=data, method="POST", headers=headers ) diff --git a/src/ray/raylet/BUILD.bazel b/src/ray/raylet/BUILD.bazel index 0db9a2516d0b..c60d43cadae8 100644 --- a/src/ray/raylet/BUILD.bazel +++ b/src/ray/raylet/BUILD.bazel @@ -158,6 +158,7 @@ ray_cc_library( "//src/ray/common:status", "//src/ray/protobuf:gcs_cc_proto", "//src/ray/protobuf:runtime_env_agent_cc_proto", + "//src/ray/rpc/authentication:authentication_token_loader", "//src/ray/util:logging", "//src/ray/util:process", "//src/ray/util:time", From 605a35360bc643bcc784f3fbefc22fa5e686c6f3 Mon Sep 17 00:00:00 2001 From: sampan Date: Fri, 31 Oct 2025 11:21:15 +0000 Subject: [PATCH 80/94] fix lint issues Signed-off-by: 
sampan --- .../ray/_private/authentication/http_token_authentication.py | 3 +++ python/ray/util/client/server/proxier.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/python/ray/_private/authentication/http_token_authentication.py b/python/ray/_private/authentication/http_token_authentication.py index 550539b206d0..c3ec330e6947 100644 --- a/python/ray/_private/authentication/http_token_authentication.py +++ b/python/ray/_private/authentication/http_token_authentication.py @@ -1,3 +1,4 @@ +import logging import sys from typing import Dict, Optional @@ -8,6 +9,8 @@ from ray._raylet import AuthenticationTokenLoader from ray.dashboard import authentication_utils as auth_utils +logger = logging.getLogger(__name__) + @web.middleware async def token_auth_middleware(request: web.Request, handler): diff --git a/python/ray/util/client/server/proxier.py b/python/ray/util/client/server/proxier.py index 84ce7f138b4c..5fc894bdea7c 100644 --- a/python/ray/util/client/server/proxier.py +++ b/python/ray/util/client/server/proxier.py @@ -18,7 +18,7 @@ import ray.core.generated.ray_client_pb2 as ray_client_pb2 import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc import ray.core.generated.runtime_env_agent_pb2 as runtime_env_agent_pb2 -from ray._common.network_utils import build_address, is_localhost +from ray._common.network_utils import build_address, is_ipv6, is_localhost from ray._private.authentication.http_token_authentication import ( format_authentication_http_error, get_auth_headers_if_auth_enabled, From ef0bbabb896a27cdf7cdb447d04d398a011bb660 Mon Sep 17 00:00:00 2001 From: sampan Date: Fri, 31 Oct 2025 11:36:06 +0000 Subject: [PATCH 81/94] fix doc build Signed-off-by: sampan --- .../http_token_authentication.py | 12 +++++++++-- python/ray/dashboard/authentication_utils.py | 20 +++++++++++++------ 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/python/ray/_private/authentication/http_token_authentication.py 
b/python/ray/_private/authentication/http_token_authentication.py index c3ec330e6947..332e08fdcd28 100644 --- a/python/ray/_private/authentication/http_token_authentication.py +++ b/python/ray/_private/authentication/http_token_authentication.py @@ -6,9 +6,17 @@ from aiohttp import web from ray._private.authentication import authentication_constants -from ray._raylet import AuthenticationTokenLoader from ray.dashboard import authentication_utils as auth_utils +try: + from ray._raylet import AuthenticationTokenLoader + + _RAYLET_AVAILABLE = True +except ImportError: + # ray._raylet not available during doc builds + _RAYLET_AVAILABLE = False + AuthenticationTokenLoader = None # type: ignore + logger = logging.getLogger(__name__) @@ -34,7 +42,7 @@ async def token_auth_middleware(request: web.Request, handler): def get_auth_headers_if_auth_enabled(user_headers: Dict[str, str]) -> Dict[str, str]: - if not auth_utils.is_token_auth_enabled(): + if not _RAYLET_AVAILABLE or not auth_utils.is_token_auth_enabled(): return {} # Check if user provided their own Authorization header (case-insensitive) diff --git a/python/ray/dashboard/authentication_utils.py b/python/ray/dashboard/authentication_utils.py index 2b39b8d00eb6..a87821e64509 100644 --- a/python/ray/dashboard/authentication_utils.py +++ b/python/ray/dashboard/authentication_utils.py @@ -1,8 +1,14 @@ -from ray._raylet import ( - AuthenticationMode, - get_authentication_mode, - validate_authentication_token, -) +try: + from ray._raylet import ( + AuthenticationMode, + get_authentication_mode, + validate_authentication_token, + ) + + _RAYLET_AVAILABLE = True +except ImportError: + # ray._raylet not available during doc builds + _RAYLET_AVAILABLE = False def is_token_auth_enabled() -> bool: @@ -11,6 +17,8 @@ def is_token_auth_enabled() -> bool: Returns: bool: True if auth_mode is set to "token", False otherwise """ + if not _RAYLET_AVAILABLE: + return False return get_authentication_mode() == AuthenticationMode.TOKEN @@ 
-23,7 +31,7 @@ def validate_request_token(auth_header: str) -> bool: Returns: bool: True if token is valid, False otherwise """ - if not auth_header: + if not _RAYLET_AVAILABLE or not auth_header: return False # validate_authentication_token expects full "Bearer " format From 2ac449b179ea6fe79e618c8d384b464b0df4da47 Mon Sep 17 00:00:00 2001 From: sampan Date: Fri, 31 Oct 2025 14:30:22 +0000 Subject: [PATCH 82/94] fix builds Signed-off-by: sampan --- .../http_token_authentication.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/python/ray/_private/authentication/http_token_authentication.py b/python/ray/_private/authentication/http_token_authentication.py index 332e08fdcd28..3bf7f4732cf4 100644 --- a/python/ray/_private/authentication/http_token_authentication.py +++ b/python/ray/_private/authentication/http_token_authentication.py @@ -2,9 +2,6 @@ import sys from typing import Dict, Optional -import pytest -from aiohttp import web - from ray._private.authentication import authentication_constants from ray.dashboard import authentication_utils as auth_utils @@ -13,16 +10,22 @@ _RAYLET_AVAILABLE = True except ImportError: - # ray._raylet not available during doc builds + # ray._raylet not available during doc builds or minimal installs _RAYLET_AVAILABLE = False AuthenticationTokenLoader = None # type: ignore logger = logging.getLogger(__name__) -@web.middleware -async def token_auth_middleware(request: web.Request, handler): - """Middleware to validate bearer tokens when token authentication is enabled.""" +async def token_auth_middleware(request, handler): + """Middleware to validate bearer tokens when token authentication is enabled. + + This is an aiohttp middleware that requires aiohttp to be installed. + Import aiohttp only when this function is called (not at module load time). 
+ """ + # Import aiohttp here to avoid breaking minimal installs + from aiohttp import web + if not auth_utils.is_token_auth_enabled(): return await handler(request) @@ -86,4 +89,6 @@ def format_authentication_http_error(status: int, body: str) -> Optional[str]: if __name__ == "__main__": + import pytest + sys.exit(pytest.main(["-vv", __file__])) From a86640e4a4aaefc1d08b20af4300c7823df44f5b Mon Sep 17 00:00:00 2001 From: sampan Date: Sat, 1 Nov 2025 14:55:11 +0000 Subject: [PATCH 83/94] Fix HTTP error handling in proxier to re-raise non-auth errors When an HTTP error occurs that isn't an auth error (401/403), we should immediately re-raise it rather than continuing to retry. This ensures errors like 500 are properly propagated to the caller. Signed-off-by: sampan --- python/ray/util/client/server/proxier.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/ray/util/client/server/proxier.py b/python/ray/util/client/server/proxier.py index 5fc894bdea7c..202165fdd956 100644 --- a/python/ray/util/client/server/proxier.py +++ b/python/ray/util/client/server/proxier.py @@ -287,6 +287,8 @@ def _create_runtime_env( formatted_error = format_authentication_http_error(e.code, body or "") if formatted_error: raise RuntimeError(formatted_error) from e + # Re-raise non-auth HTTP errors immediately + raise except urllib.error.URLError as e: last_exception = e From 5c361f290bfbdc4f57eabc1d05acb4c96ff05005 Mon Sep 17 00:00:00 2001 From: sampan Date: Sat, 1 Nov 2025 14:58:50 +0000 Subject: [PATCH 84/94] fix lint Signed-off-by: sampan --- python/ray/dashboard/modules/dashboard_sdk.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/ray/dashboard/modules/dashboard_sdk.py b/python/ray/dashboard/modules/dashboard_sdk.py index b0a332a8dcff..9b519f40c1db 100644 --- a/python/ray/dashboard/modules/dashboard_sdk.py +++ b/python/ray/dashboard/modules/dashboard_sdk.py @@ -24,9 +24,7 @@ from ray._private.runtime_env.py_modules import upload_py_modules_if_needed from 
ray._private.runtime_env.working_dir import upload_working_dir_if_needed from ray._private.utils import split_address -from ray._raylet import AuthenticationTokenLoader from ray.autoscaler._private.cli_logger import cli_logger -from ray.dashboard.authentication_utils import is_token_auth_enabled from ray.dashboard.modules.job.common import uri_to_http_components from ray.util.annotations import DeveloperAPI, PublicAPI From 4a47255db7783039443c71084c51faffba42b3d7 Mon Sep 17 00:00:00 2001 From: sampan Date: Sat, 1 Nov 2025 15:10:37 +0000 Subject: [PATCH 85/94] dont raise exception in dashboard_sdk Signed-off-by: sampan --- python/ray/dashboard/modules/dashboard_sdk.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/python/ray/dashboard/modules/dashboard_sdk.py b/python/ray/dashboard/modules/dashboard_sdk.py index 9b519f40c1db..5429c60b96d7 100644 --- a/python/ray/dashboard/modules/dashboard_sdk.py +++ b/python/ray/dashboard/modules/dashboard_sdk.py @@ -334,9 +334,6 @@ def _do_request( if formatted_error: raise RuntimeError(formatted_error) - # Raise for any other HTTP error status codes - response.raise_for_status() - return response def _package_exists( From 76b104651cfd3dd28decccb65ffc911fbb8cc087 Mon Sep 17 00:00:00 2001 From: sampan Date: Sun, 2 Nov 2025 13:08:53 +0000 Subject: [PATCH 86/94] retry for non auth errors Signed-off-by: sampan --- python/ray/util/client/server/proxier.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/python/ray/util/client/server/proxier.py b/python/ray/util/client/server/proxier.py index 202165fdd956..4abab76a825d 100644 --- a/python/ray/util/client/server/proxier.py +++ b/python/ray/util/client/server/proxier.py @@ -287,8 +287,14 @@ def _create_runtime_env( formatted_error = format_authentication_http_error(e.code, body or "") if formatted_error: raise RuntimeError(formatted_error) from e - # Re-raise non-auth HTTP errors immediately - raise + + # Treat non-auth HTTP errors like URLError (retry 
with backoff) + last_exception = e + logger.warning( + f"GetOrCreateRuntimeEnv request failed with HTTP {e.code}: {body or e}. " + f"Retrying after {wait_time_s}s. " + f"{max_retries-retries} retries remaining." + ) except urllib.error.URLError as e: last_exception = e From 2c057c6807874e48a21005c98bbb16511c29653d Mon Sep 17 00:00:00 2001 From: sampan Date: Sun, 2 Nov 2025 15:45:51 +0000 Subject: [PATCH 87/94] try fix issues in minimal builds Signed-off-by: sampan --- .../authentication/http_token_authentication.py | 7 +++++++ python/ray/dashboard/tests/test_dashboard_auth.py | 12 ++++++++++++ python/ray/tests/test_runtime_env_agent_auth.py | 12 ++++++++++++ python/ray/tests/test_submission_client_auth.py | 12 ++++++++++++ python/ray/tests/test_token_auth_integration.py | 15 ++++++++++++++- 5 files changed, 57 insertions(+), 1 deletion(-) diff --git a/python/ray/_private/authentication/http_token_authentication.py b/python/ray/_private/authentication/http_token_authentication.py index 3bf7f4732cf4..ea9ff1c1caa3 100644 --- a/python/ray/_private/authentication/http_token_authentication.py +++ b/python/ray/_private/authentication/http_token_authentication.py @@ -22,7 +22,14 @@ async def token_auth_middleware(request, handler): This is an aiohttp middleware that requires aiohttp to be installed. Import aiohttp only when this function is called (not at module load time). + + In minimal Ray installations (without ray._raylet), this middleware is a no-op + and passes all requests through without authentication. 
""" + # Skip auth entirely in minimal installs where ray._raylet is not available + if not _RAYLET_AVAILABLE: + return await handler(request) + # Import aiohttp here to avoid breaking minimal installs from aiohttp import web diff --git a/python/ray/dashboard/tests/test_dashboard_auth.py b/python/ray/dashboard/tests/test_dashboard_auth.py index 7407fc199a1d..49dca408186c 100644 --- a/python/ray/dashboard/tests/test_dashboard_auth.py +++ b/python/ray/dashboard/tests/test_dashboard_auth.py @@ -5,6 +5,18 @@ import pytest import requests +try: + from ray._raylet import AuthenticationTokenLoader + + _RAYLET_AVAILABLE = True +except ImportError: + _RAYLET_AVAILABLE = False + +pytestmark = pytest.mark.skipif( + not _RAYLET_AVAILABLE, + reason="Authentication tests require ray._raylet (not available in minimal installs)", +) + def test_dashboard_request_requires_auth_with_valid_token( setup_cluster_with_token_auth, diff --git a/python/ray/tests/test_runtime_env_agent_auth.py b/python/ray/tests/test_runtime_env_agent_auth.py index 879e1c739277..015910d02b91 100644 --- a/python/ray/tests/test_runtime_env_agent_auth.py +++ b/python/ray/tests/test_runtime_env_agent_auth.py @@ -19,6 +19,18 @@ set_env_auth_token, ) +try: + from ray._raylet import AuthenticationTokenLoader + + _RAYLET_AVAILABLE = True +except ImportError: + _RAYLET_AVAILABLE = False + +pytestmark = pytest.mark.skipif( + not _RAYLET_AVAILABLE, + reason="Authentication tests require ray._raylet (not available in minimal installs)", +) + def _agent_url(agent_address: str, path: str) -> str: return urllib.parse.urljoin(agent_address, path) diff --git a/python/ray/tests/test_submission_client_auth.py b/python/ray/tests/test_submission_client_auth.py index 33c63ac7d93c..b57e76561fdc 100644 --- a/python/ray/tests/test_submission_client_auth.py +++ b/python/ray/tests/test_submission_client_auth.py @@ -14,6 +14,18 @@ ) from ray.util.state import StateApiClient +try: + from ray._raylet import AuthenticationTokenLoader + + 
_RAYLET_AVAILABLE = True +except ImportError: + _RAYLET_AVAILABLE = False + +pytestmark = pytest.mark.skipif( + not _RAYLET_AVAILABLE, + reason="Authentication tests require ray._raylet (not available in minimal installs)", +) + def test_submission_client_adds_token_automatically(setup_cluster_with_token_auth): """Test that SubmissionClient automatically adds token to headers.""" diff --git a/python/ray/tests/test_token_auth_integration.py b/python/ray/tests/test_token_auth_integration.py index 2ea4950597f9..9765d7fd1705 100644 --- a/python/ray/tests/test_token_auth_integration.py +++ b/python/ray/tests/test_token_auth_integration.py @@ -10,7 +10,15 @@ import ray from ray._private.test_utils import wait_for_condition -from ray._raylet import AuthenticationTokenLoader + +try: + from ray._raylet import AuthenticationTokenLoader + + _RAYLET_AVAILABLE = True +except ImportError: + _RAYLET_AVAILABLE = False + AuthenticationTokenLoader = None + from ray.tests.authentication_test_utils import ( clear_auth_token_sources, reset_auth_token_state, @@ -18,6 +26,11 @@ set_env_auth_token, ) +pytestmark = pytest.mark.skipif( + not _RAYLET_AVAILABLE, + reason="Authentication tests require ray._raylet (not available in minimal installs)", +) + def _run_ray_start_and_verify_status( args: list, env: dict, expect_success: bool = True, timeout: int = 30 From 1578afe12a3904b42befcf45eb8c6c78e225670a Mon Sep 17 00:00:00 2001 From: sampan Date: Sun, 2 Nov 2025 15:46:47 +0000 Subject: [PATCH 88/94] fix lint Signed-off-by: sampan --- python/ray/_private/authentication/http_token_authentication.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/_private/authentication/http_token_authentication.py b/python/ray/_private/authentication/http_token_authentication.py index ea9ff1c1caa3..4abaa30a8679 100644 --- a/python/ray/_private/authentication/http_token_authentication.py +++ b/python/ray/_private/authentication/http_token_authentication.py @@ -22,7 +22,7 @@ async 
def token_auth_middleware(request, handler): This is an aiohttp middleware that requires aiohttp to be installed. Import aiohttp only when this function is called (not at module load time). - + In minimal Ray installations (without ray._raylet), this middleware is a no-op and passes all requests through without authentication. """ From e28f133bee301b6ab2f2bea545ea4cdd320fd1ff Mon Sep 17 00:00:00 2001 From: sampan Date: Sun, 2 Nov 2025 15:48:41 +0000 Subject: [PATCH 89/94] fix lint Signed-off-by: sampan --- python/ray/dashboard/tests/test_dashboard_auth.py | 12 ------------ python/ray/tests/test_runtime_env_agent_auth.py | 12 ------------ python/ray/tests/test_submission_client_auth.py | 12 ------------ 3 files changed, 36 deletions(-) diff --git a/python/ray/dashboard/tests/test_dashboard_auth.py b/python/ray/dashboard/tests/test_dashboard_auth.py index 49dca408186c..7407fc199a1d 100644 --- a/python/ray/dashboard/tests/test_dashboard_auth.py +++ b/python/ray/dashboard/tests/test_dashboard_auth.py @@ -5,18 +5,6 @@ import pytest import requests -try: - from ray._raylet import AuthenticationTokenLoader - - _RAYLET_AVAILABLE = True -except ImportError: - _RAYLET_AVAILABLE = False - -pytestmark = pytest.mark.skipif( - not _RAYLET_AVAILABLE, - reason="Authentication tests require ray._raylet (not available in minimal installs)", -) - def test_dashboard_request_requires_auth_with_valid_token( setup_cluster_with_token_auth, diff --git a/python/ray/tests/test_runtime_env_agent_auth.py b/python/ray/tests/test_runtime_env_agent_auth.py index 015910d02b91..879e1c739277 100644 --- a/python/ray/tests/test_runtime_env_agent_auth.py +++ b/python/ray/tests/test_runtime_env_agent_auth.py @@ -19,18 +19,6 @@ set_env_auth_token, ) -try: - from ray._raylet import AuthenticationTokenLoader - - _RAYLET_AVAILABLE = True -except ImportError: - _RAYLET_AVAILABLE = False - -pytestmark = pytest.mark.skipif( - not _RAYLET_AVAILABLE, - reason="Authentication tests require ray._raylet (not 
available in minimal installs)", -) - def _agent_url(agent_address: str, path: str) -> str: return urllib.parse.urljoin(agent_address, path) diff --git a/python/ray/tests/test_submission_client_auth.py b/python/ray/tests/test_submission_client_auth.py index b57e76561fdc..33c63ac7d93c 100644 --- a/python/ray/tests/test_submission_client_auth.py +++ b/python/ray/tests/test_submission_client_auth.py @@ -14,18 +14,6 @@ ) from ray.util.state import StateApiClient -try: - from ray._raylet import AuthenticationTokenLoader - - _RAYLET_AVAILABLE = True -except ImportError: - _RAYLET_AVAILABLE = False - -pytestmark = pytest.mark.skipif( - not _RAYLET_AVAILABLE, - reason="Authentication tests require ray._raylet (not available in minimal installs)", -) - def test_submission_client_adds_token_automatically(setup_cluster_with_token_auth): """Test that SubmissionClient automatically adds token to headers.""" From bf4aade0b7abc6a8274b0960299a774e52bc0759 Mon Sep 17 00:00:00 2001 From: sampan Date: Mon, 3 Nov 2025 03:00:24 +0000 Subject: [PATCH 90/94] fix token auth middleware Signed-off-by: sampan --- .../http_token_authentication.py | 38 +++++++------------ python/ray/dashboard/modules/dashboard_sdk.py | 11 +----- 2 files changed, 14 insertions(+), 35 deletions(-) diff --git a/python/ray/_private/authentication/http_token_authentication.py b/python/ray/_private/authentication/http_token_authentication.py index 4abaa30a8679..f60db3656075 100644 --- a/python/ray/_private/authentication/http_token_authentication.py +++ b/python/ray/_private/authentication/http_token_authentication.py @@ -1,22 +1,18 @@ import logging -import sys from typing import Dict, Optional from ray._private.authentication import authentication_constants from ray.dashboard import authentication_utils as auth_utils -try: - from ray._raylet import AuthenticationTokenLoader - - _RAYLET_AVAILABLE = True -except ImportError: - # ray._raylet not available during doc builds or minimal installs - _RAYLET_AVAILABLE = 
False - AuthenticationTokenLoader = None # type: ignore +# All third-party dependencies that are not included in the minimal Ray +# installation must be included in this file. This allows us to determine if +# the agent has the necessary dependencies to be started. +from ray.dashboard.optional_deps import aiohttp logger = logging.getLogger(__name__) +@aiohttp.web.middleware async def token_auth_middleware(request, handler): """Middleware to validate bearer tokens when token authentication is enabled. @@ -26,13 +22,7 @@ async def token_auth_middleware(request, handler): In minimal Ray installations (without ray._raylet), this middleware is a no-op and passes all requests through without authentication. """ - # Skip auth entirely in minimal installs where ray._raylet is not available - if not _RAYLET_AVAILABLE: - return await handler(request) - - # Import aiohttp here to avoid breaking minimal installs - from aiohttp import web - + # No-op if token auth is not enabled or raylet is not available if not auth_utils.is_token_auth_enabled(): return await handler(request) @@ -40,21 +30,25 @@ async def token_auth_middleware(request, handler): authentication_constants.AUTHORIZATION_HEADER_NAME, "" ) if not auth_header: - return web.Response( + return aiohttp.web.Response( status=401, text="Unauthorized: Missing authentication token" ) if not auth_utils.validate_request_token(auth_header): - return web.Response(status=403, text="Forbidden: Invalid authentication token") + return aiohttp.web.Response( + status=403, text="Forbidden: Invalid authentication token" + ) return await handler(request) def get_auth_headers_if_auth_enabled(user_headers: Dict[str, str]) -> Dict[str, str]: - if not _RAYLET_AVAILABLE or not auth_utils.is_token_auth_enabled(): + if not auth_utils.is_token_auth_enabled(): return {} + from ray._raylet import AuthenticationTokenLoader + # Check if user provided their own Authorization header (case-insensitive) has_user_auth = any( key.lower() == 
authentication_constants.AUTHORIZATION_HEADER_NAME @@ -93,9 +87,3 @@ def format_authentication_http_error(status: int, body: str) -> Optional[str]: ) return None - - -if __name__ == "__main__": - import pytest - - sys.exit(pytest.main(["-vv", __file__])) diff --git a/python/ray/dashboard/modules/dashboard_sdk.py b/python/ray/dashboard/modules/dashboard_sdk.py index 5429c60b96d7..c38469d5d2ad 100644 --- a/python/ray/dashboard/modules/dashboard_sdk.py +++ b/python/ray/dashboard/modules/dashboard_sdk.py @@ -227,7 +227,7 @@ def __init__( # Headers used for all requests sent to job server, optional and only # needed for cases like authentication to remote cluster. self._headers = cluster_info.headers or {} - self._headers.update(**self._get_auth_headers()) + self._headers.update(**get_auth_headers_if_auth_enabled(self._headers)) # Set SSL verify parameter for the requests library and create an ssl_context # object when needed for the aiohttp library. @@ -248,15 +248,6 @@ def __init__( else: self._ssl_context = None - def _get_auth_headers(self) -> Dict[str, str]: - """Get authentication headers if token auth is enabled. - - Returns: - dict: Authentication headers to merge with request headers. - Empty dict if no auth needed or token unavailable. 
- """ - return get_auth_headers_if_auth_enabled(self._headers) - def _check_connection_and_version( self, min_version: str = "1.9", version_error_message: str = None ): From 9de6e848ba59f740c1c66bc5776a464c26c91e60 Mon Sep 17 00:00:00 2001 From: sampan Date: Mon, 3 Nov 2025 03:50:41 +0000 Subject: [PATCH 91/94] imort aiohttp lazily to avoid breaking minimal builds Signed-off-by: sampan --- .../http_token_authentication.py | 53 +++++++++---------- python/ray/_private/runtime_env/agent/main.py | 4 +- python/ray/dashboard/http_server_head.py | 4 +- 3 files changed, 30 insertions(+), 31 deletions(-) diff --git a/python/ray/_private/authentication/http_token_authentication.py b/python/ray/_private/authentication/http_token_authentication.py index f60db3656075..5d6d136a97d0 100644 --- a/python/ray/_private/authentication/http_token_authentication.py +++ b/python/ray/_private/authentication/http_token_authentication.py @@ -4,42 +4,41 @@ from ray._private.authentication import authentication_constants from ray.dashboard import authentication_utils as auth_utils -# All third-party dependencies that are not included in the minimal Ray -# installation must be included in this file. This allows us to determine if -# the agent has the necessary dependencies to be started. -from ray.dashboard.optional_deps import aiohttp - logger = logging.getLogger(__name__) -@aiohttp.web.middleware -async def token_auth_middleware(request, handler): - """Middleware to validate bearer tokens when token authentication is enabled. +def get_token_auth_middleware(): + # aiohttp is not included in minimal Ray installations, import it here to avoid breaking minimal installs + from ray.dashboard.optional_deps import aiohttp - This is an aiohttp middleware that requires aiohttp to be installed. - Import aiohttp only when this function is called (not at module load time). 
+ @aiohttp.web.middleware + async def token_auth_middleware(request, handler): + """Middleware to validate bearer tokens when token authentication is enabled. - In minimal Ray installations (without ray._raylet), this middleware is a no-op - and passes all requests through without authentication. - """ - # No-op if token auth is not enabled or raylet is not available - if not auth_utils.is_token_auth_enabled(): - return await handler(request) + This is an aiohttp middleware that requires aiohttp to be installed. + Import aiohttp only when this function is called (not at module load time). - auth_header = request.headers.get( - authentication_constants.AUTHORIZATION_HEADER_NAME, "" - ) - if not auth_header: - return aiohttp.web.Response( - status=401, text="Unauthorized: Missing authentication token" - ) + In minimal Ray installations (without ray._raylet), this middleware is a no-op + and passes all requests through without authentication. + """ + # No-op if token auth is not enabled or raylet is not available + if not auth_utils.is_token_auth_enabled(): + return await handler(request) - if not auth_utils.validate_request_token(auth_header): - return aiohttp.web.Response( - status=403, text="Forbidden: Invalid authentication token" + auth_header = request.headers.get( + authentication_constants.AUTHORIZATION_HEADER_NAME, "" ) + if not auth_header: + return aiohttp.web.Response( + status=401, text="Unauthorized: Missing authentication token" + ) + + if not auth_utils.validate_request_token(auth_header): + return aiohttp.web.Response( + status=403, text="Forbidden: Invalid authentication token" + ) - return await handler(request) + return await handler(request) def get_auth_headers_if_auth_enabled(user_headers: Dict[str, str]) -> Dict[str, str]: diff --git a/python/ray/_private/runtime_env/agent/main.py b/python/ray/_private/runtime_env/agent/main.py index f31a927a857e..54bd8091b043 100644 --- a/python/ray/_private/runtime_env/agent/main.py +++ 
b/python/ray/_private/runtime_env/agent/main.py @@ -9,7 +9,7 @@ ) from ray._private import logging_utils from ray._private.authentication.http_token_authentication import ( - token_auth_middleware, + get_token_auth_middleware, ) from ray._private.process_watcher import create_check_raylet_task from ray._raylet import GcsClient @@ -197,7 +197,7 @@ async def get_runtime_envs_info(request: web.Request) -> web.Response: body=reply.SerializeToString(), content_type="application/octet-stream" ) - app = web.Application(middlewares=[token_auth_middleware]) + app = web.Application(middlewares=[get_token_auth_middleware()]) app.router.add_post("/get_or_create_runtime_env", get_or_create_runtime_env) app.router.add_post( diff --git a/python/ray/dashboard/http_server_head.py b/python/ray/dashboard/http_server_head.py index 00f6228e0844..d6cf0913e0ba 100644 --- a/python/ray/dashboard/http_server_head.py +++ b/python/ray/dashboard/http_server_head.py @@ -21,7 +21,7 @@ from ray._common.usage.usage_lib import TagKey, record_extra_usage_tag from ray._common.utils import get_or_create_event_loop from ray._private.authentication.http_token_authentication import ( - token_auth_middleware, + get_token_auth_middleware, ) from ray.dashboard.dashboard_metrics import DashboardPrometheusMetrics from ray.dashboard.head import DashboardHeadModule @@ -256,7 +256,7 @@ async def run( client_max_size=ray_constants.DASHBOARD_CLIENT_MAX_SIZE, middlewares=[ self.metrics_middleware, - token_auth_middleware, + get_token_auth_middleware(), self.path_clean_middleware, self.browsers_no_post_put_middleware, self.cache_control_static_middleware, From 88fe5ecd3ba90868e7d6f31aaa66cc4bbdafe9e3 Mon Sep 17 00:00:00 2001 From: sampan Date: Mon, 3 Nov 2025 08:46:54 +0000 Subject: [PATCH 92/94] return middleware Signed-off-by: sampan --- python/ray/_private/authentication/http_token_authentication.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/ray/_private/authentication/http_token_authentication.py 
b/python/ray/_private/authentication/http_token_authentication.py index 5d6d136a97d0..14e230d80191 100644 --- a/python/ray/_private/authentication/http_token_authentication.py +++ b/python/ray/_private/authentication/http_token_authentication.py @@ -40,6 +40,8 @@ async def token_auth_middleware(request, handler): return await handler(request) + return token_auth_middleware + def get_auth_headers_if_auth_enabled(user_headers: Dict[str, str]) -> Dict[str, str]: From 9232cdf609e0479fd68acc12d07b703ef33138d6 Mon Sep 17 00:00:00 2001 From: sampan Date: Mon, 3 Nov 2025 11:33:05 +0000 Subject: [PATCH 93/94] acceppt aiohttp as a factory param Signed-off-by: sampan --- .../http_token_authentication.py | 22 ++++++++++--------- python/ray/_private/runtime_env/agent/main.py | 3 ++- python/ray/dashboard/http_server_head.py | 2 +- 3 files changed, 15 insertions(+), 12 deletions(-) diff --git a/python/ray/_private/authentication/http_token_authentication.py b/python/ray/_private/authentication/http_token_authentication.py index 14e230d80191..c600c0c025bf 100644 --- a/python/ray/_private/authentication/http_token_authentication.py +++ b/python/ray/_private/authentication/http_token_authentication.py @@ -7,21 +7,23 @@ logger = logging.getLogger(__name__) -def get_token_auth_middleware(): - # aiohttp is not included in minimal Ray installations, import it here to avoid breaking minimal installs - from ray.dashboard.optional_deps import aiohttp +def get_token_auth_middleware(aiohttp_module): + """Internal helper to create token auth middleware with provided modules. - @aiohttp.web.middleware + Args: + aiohttp_module: The aiohttp module to use + Returns: + An aiohttp middleware function + """ + + @aiohttp_module.web.middleware async def token_auth_middleware(request, handler): """Middleware to validate bearer tokens when token authentication is enabled. - This is an aiohttp middleware that requires aiohttp to be installed. 
- Import aiohttp only when this function is called (not at module load time). - In minimal Ray installations (without ray._raylet), this middleware is a no-op and passes all requests through without authentication. """ - # No-op if token auth is not enabled or raylet is not available + # No-op if token auth is not enabled or raylet is not available if not auth_utils.is_token_auth_enabled(): return await handler(request) @@ -29,12 +31,12 @@ async def token_auth_middleware(request, handler): authentication_constants.AUTHORIZATION_HEADER_NAME, "" ) if not auth_header: - return aiohttp.web.Response( + return aiohttp_module.web.Response( status=401, text="Unauthorized: Missing authentication token" ) if not auth_utils.validate_request_token(auth_header): - return aiohttp.web.Response( + return aiohttp_module.web.Response( status=403, text="Forbidden: Invalid authentication token" ) diff --git a/python/ray/_private/runtime_env/agent/main.py b/python/ray/_private/runtime_env/agent/main.py index 54bd8091b043..c0d68bdf6e04 100644 --- a/python/ray/_private/runtime_env/agent/main.py +++ b/python/ray/_private/runtime_env/agent/main.py @@ -26,6 +26,7 @@ def import_libs(): import_libs() +import aiohttp # noqa: E402 import runtime_env_consts # noqa: E402 from aiohttp import web # noqa: E402 from runtime_env_agent import RuntimeEnvAgent # noqa: E402 @@ -197,7 +198,7 @@ async def get_runtime_envs_info(request: web.Request) -> web.Response: body=reply.SerializeToString(), content_type="application/octet-stream" ) - app = web.Application(middlewares=[get_token_auth_middleware()]) + app = web.Application(middlewares=[get_token_auth_middleware(aiohttp)]) app.router.add_post("/get_or_create_runtime_env", get_or_create_runtime_env) app.router.add_post( diff --git a/python/ray/dashboard/http_server_head.py b/python/ray/dashboard/http_server_head.py index d6cf0913e0ba..f08c84a91dd5 100644 --- a/python/ray/dashboard/http_server_head.py +++ b/python/ray/dashboard/http_server_head.py @@ 
-256,7 +256,7 @@ async def run( client_max_size=ray_constants.DASHBOARD_CLIENT_MAX_SIZE, middlewares=[ self.metrics_middleware, - get_token_auth_middleware(), + get_token_auth_middleware(aiohttp), self.path_clean_middleware, self.browsers_no_post_put_middleware, self.cache_control_static_middleware, From a79bcca542b87c64040ded285912901db1de8409 Mon Sep 17 00:00:00 2001 From: sampan Date: Mon, 3 Nov 2025 11:41:16 +0000 Subject: [PATCH 94/94] fix lint issues Signed-off-by: sampan --- .../ray/_private/authentication/http_token_authentication.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/ray/_private/authentication/http_token_authentication.py b/python/ray/_private/authentication/http_token_authentication.py index c600c0c025bf..8f68e9893815 100644 --- a/python/ray/_private/authentication/http_token_authentication.py +++ b/python/ray/_private/authentication/http_token_authentication.py @@ -1,4 +1,5 @@ import logging +from types import ModuleType from typing import Dict, Optional from ray._private.authentication import authentication_constants @@ -7,7 +8,7 @@ logger = logging.getLogger(__name__) -def get_token_auth_middleware(aiohttp_module): +def get_token_auth_middleware(aiohttp_module: ModuleType): """Internal helper to create token auth middleware with provided modules. Args: