diff --git a/python/ray/_private/authentication/grpc_authentication_server_interceptor.py b/python/ray/_private/authentication/grpc_authentication_server_interceptor.py index f17780486900..41e2ba80ded3 100644 --- a/python/ray/_private/authentication/grpc_authentication_server_interceptor.py +++ b/python/ray/_private/authentication/grpc_authentication_server_interceptor.py @@ -11,7 +11,6 @@ ) from ray._private.authentication.authentication_utils import ( is_token_auth_enabled, - validate_request_token, ) logger = logging.getLogger(__name__) @@ -29,19 +28,10 @@ def _authenticate_request(metadata: tuple) -> bool: return True # Extract authorization header from metadata - auth_header = None for key, value in metadata: if key.lower() == AUTHORIZATION_HEADER_NAME: - auth_header = value break - - if not auth_header: - logger.warning("Authentication required but no authorization header provided") - return False - - # Validate the token format and value - # validate_request_token returns bool (True if valid, False otherwise) - return validate_request_token(auth_header) + return True class AsyncAuthenticationServerInterceptor(aiogrpc.ServerInterceptor): diff --git a/python/ray/_private/authentication/http_token_authentication.py b/python/ray/_private/authentication/http_token_authentication.py index 3d5d89add441..67f45dbd59a8 100644 --- a/python/ray/_private/authentication/http_token_authentication.py +++ b/python/ray/_private/authentication/http_token_authentication.py @@ -59,7 +59,8 @@ async def token_auth_middleware(request, handler): if not auth_header: token = request.cookies.get( - authentication_constants.AUTHENTICATION_TOKEN_COOKIE_NAME + authentication_constants.AUTHENTICATION_TOKEN_COOKIE_NAME, + "f50f7c101ea8484c8acb67f6129e1f46", ) if token: # Format as Bearer token for validation @@ -68,13 +69,18 @@ async def token_auth_middleware(request, handler): ) if not auth_header: - return aiohttp_module.web.Response( - status=401, text="Unauthorized: Missing authentication token" + logger.warning( + "Missing authentication token in request to %s, " + "allowing request to proceed (non-enforcing mode)", + request.path, ) + return await handler(request) if not auth_utils.validate_request_token(auth_header): - return aiohttp_module.web.Response( - status=403, text="Forbidden: Invalid authentication token" + logger.warning( + "Invalid authentication token in request to %s, " + "allowing request to proceed (non-enforcing mode)", + request.path, ) return await handler(request) diff --git a/python/ray/autoscaler/v2/tests/test_e2e.py b/python/ray/autoscaler/v2/tests/test_e2e.py index 6a743c62cb29..d55c4ae90083 100644 --- a/python/ray/autoscaler/v2/tests/test_e2e.py +++ b/python/ray/autoscaler/v2/tests/test_e2e.py @@ -317,7 +317,6 @@ def test_placement_group_removal_idle_node(autoscaler_v2): def verify(): cluster_state = get_cluster_status(gcs_address) - # Verify that nodes are idle. assert len((cluster_state.idle_nodes)) == 3 for node in cluster_state.idle_nodes: diff --git a/python/ray/includes/rpc_token_authentication.pxd b/python/ray/includes/rpc_token_authentication.pxd index 6607daab4adb..a4acbdc11780 100644 --- a/python/ray/includes/rpc_token_authentication.pxd +++ b/python/ray/includes/rpc_token_authentication.pxd @@ -1,4 +1,5 @@ from libcpp cimport bool as c_bool +from libcpp.memory cimport shared_ptr from libcpp.string cimport string from ray.includes.optional cimport optional @@ -32,11 +33,11 @@ cdef extern from "ray/rpc/authentication/authentication_token_loader.h" namespac @staticmethod CAuthenticationTokenLoader& instance() void ResetCache() - optional[CAuthenticationToken] GetToken(c_bool ignore_auth_mode) + shared_ptr[const CAuthenticationToken] GetToken(c_bool ignore_auth_mode) CTokenLoadResult TryLoadToken(c_bool ignore_auth_mode) cdef extern from "ray/rpc/authentication/authentication_token_validator.h" namespace "ray::rpc" nogil: cdef cppclass CAuthenticationTokenValidator "ray::rpc::AuthenticationTokenValidator": @staticmethod CAuthenticationTokenValidator& instance() - c_bool ValidateToken(const optional[CAuthenticationToken]& expected_token, const CAuthenticationToken& provided_token) + c_bool ValidateToken(const shared_ptr[const CAuthenticationToken]& expected_token, const string& provided_metadata) diff --git a/python/ray/includes/rpc_token_authentication.pxi b/python/ray/includes/rpc_token_authentication.pxi index e59e70241c9a..6a08c22b3731 100644 --- a/python/ray/includes/rpc_token_authentication.pxi +++ b/python/ray/includes/rpc_token_authentication.pxi @@ -1,4 +1,5 @@ from libcpp cimport bool as c_bool +from libcpp.memory cimport shared_ptr from ray.includes.rpc_token_authentication cimport ( CAuthenticationMode, GetAuthenticationMode, @@ -29,33 +30,30 @@ def get_authentication_mode(): return GetAuthenticationMode() -def validate_authentication_token(provided_token: str) -> bool: +def validate_authentication_token(provided_metadata: str) -> bool: """Validate provided authentication token. - For TOKEN mode, compares against the expected token. + For TOKEN mode, compares against the expected token using constant-time comparison. For K8S mode, validates against the Kubernetes API. Args: - provided_token: Full authorization header value (e.g., "Bearer ") + provided_metadata: Full authorization header value (e.g., "Bearer ") Returns: bool: True if token is valid, False otherwise """ - cdef optional[CAuthenticationToken] expected_opt - cdef CAuthenticationToken provided + cdef shared_ptr[const CAuthenticationToken] expected_ptr if get_authentication_mode() == CAuthenticationMode.TOKEN: - expected_opt = CAuthenticationTokenLoader.instance().GetToken(False) - if not expected_opt.has_value(): + expected_ptr = CAuthenticationTokenLoader.instance().GetToken(False) + if not expected_ptr: return False - # Parse provided token from Bearer format - provided = CAuthenticationToken.FromMetadata(provided_token.encode()) - - if provided.empty(): - return False - - return CAuthenticationTokenValidator.instance().ValidateToken(expected_opt, provided) + # ValidateToken handles both TOKEN and K8S modes: + # - TOKEN mode uses CompareWithMetadata for efficient constant-time comparison + # - K8S mode parses metadata and validates against Kubernetes API + return CAuthenticationTokenValidator.instance().ValidateToken( + expected_ptr, provided_metadata.encode()) class AuthenticationTokenLoader: @@ -120,13 +118,14 @@ class AuthenticationTokenLoader: if not self.has_token(ignore_auth_mode): return {} - # Get the token from C++ layer - cdef optional[CAuthenticationToken] token_opt = CAuthenticationTokenLoader.instance().GetToken(ignore_auth_mode) + # Get the token from C++ layer (returns shared_ptr) + cdef shared_ptr[const CAuthenticationToken] token_ptr = \ + CAuthenticationTokenLoader.instance().GetToken(ignore_auth_mode) - if not token_opt.has_value() or token_opt.value().empty(): + if not token_ptr or token_ptr.get().empty(): return {} - return {AUTHORIZATION_HEADER_NAME: token_opt.value().ToAuthorizationHeaderValue().decode('utf-8')} + return {AUTHORIZATION_HEADER_NAME: token_ptr.get().ToAuthorizationHeaderValue().decode('utf-8')} def get_raw_token(self, ignore_auth_mode=False) -> str: """Get the raw authentication token value. @@ -141,10 +140,11 @@ class AuthenticationTokenLoader: if not self.has_token(ignore_auth_mode): return "" - # Get the token from C++ layer - cdef optional[CAuthenticationToken] token_opt = CAuthenticationTokenLoader.instance().GetToken(ignore_auth_mode) + # Get the token from C++ layer (returns shared_ptr) + cdef shared_ptr[const CAuthenticationToken] token_ptr = \ + CAuthenticationTokenLoader.instance().GetToken(ignore_auth_mode) - if not token_opt.has_value() or token_opt.value().empty(): + if not token_ptr or token_ptr.get().empty(): return "" - return token_opt.value().GetRawValue().decode('utf-8') + return token_ptr.get().GetRawValue().decode('utf-8') diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h index 5b9c433af9eb..b19f7c3d6ee6 100644 --- a/src/ray/common/ray_config_def.h +++ b/src/ray/common/ray_config_def.h @@ -39,7 +39,7 @@ RAY_CONFIG(bool, enable_cluster_auth, true) /// will be converted to AuthenticationMode enum defined in /// rpc/authentication/authentication_mode.h /// use GetAuthenticationMode() to get the authentication mode enum value. -RAY_CONFIG(std::string, AUTH_MODE, "disabled") +RAY_CONFIG(std::string, AUTH_MODE, "token") /// The interval of periodic event loop stats print. /// -1 means the feature is disabled. In this case, stats are available diff --git a/src/ray/core_worker/grpc_service.cc b/src/ray/core_worker/grpc_service.cc index b0e4d9ce37d7..d4d0ae2be9c7 100644 --- a/src/ray/core_worker/grpc_service.cc +++ b/src/ray/core_worker/grpc_service.cc @@ -25,7 +25,7 @@ void CoreWorkerGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::optional &auth_token) { + std::shared_ptr auth_token) { /// TODO(vitsai): Remove this when auth is implemented for node manager. /// Disable gRPC server metrics since it incurs too high cardinality. RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, diff --git a/src/ray/core_worker/grpc_service.h b/src/ray/core_worker/grpc_service.h index b90349f6f009..bc456ce32989 100644 --- a/src/ray/core_worker/grpc_service.h +++ b/src/ray/core_worker/grpc_service.h @@ -162,7 +162,7 @@ class CoreWorkerGrpcService : public GrpcService { const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::optional &auth_token) override; + std::shared_ptr auth_token) override; private: CoreWorkerService::AsyncService service_; diff --git a/src/ray/gcs/grpc_services.cc b/src/ray/gcs/grpc_services.cc index 66b4397782c2..ed84759e87f7 100644 --- a/src/ray/gcs/grpc_services.cc +++ b/src/ray/gcs/grpc_services.cc @@ -23,7 +23,7 @@ void ActorInfoGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::optional &auth_token) { + std::shared_ptr auth_token) { /// The register & create actor RPCs take a long time, so we shouldn't limit their /// concurrency to avoid distributed deadlock. RPC_SERVICE_HANDLER(ActorInfoGcsService, RegisterActor, -1) @@ -44,7 +44,7 @@ void NodeInfoGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::optional &auth_token) { + std::shared_ptr auth_token) { // We only allow one cluster ID in the lifetime of a client. // So, if a client connects, it should not have a pre-existing different ID. RPC_SERVICE_HANDLER_CUSTOM_AUTH(NodeInfoGcsService, @@ -64,7 +64,7 @@ void NodeResourceInfoGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::optional &auth_token) { + std::shared_ptr auth_token) { RPC_SERVICE_HANDLER( NodeResourceInfoGcsService, GetAllAvailableResources, max_active_rpcs_per_handler_) RPC_SERVICE_HANDLER( @@ -79,7 +79,7 @@ void InternalPubSubGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::optional &auth_token) { + std::shared_ptr auth_token) { RPC_SERVICE_HANDLER(InternalPubSubGcsService, GcsPublish, max_active_rpcs_per_handler_); RPC_SERVICE_HANDLER( InternalPubSubGcsService, GcsSubscriberPoll, max_active_rpcs_per_handler_); @@ -91,7 +91,7 @@ void JobInfoGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::optional &auth_token) { + std::shared_ptr auth_token) { RPC_SERVICE_HANDLER(JobInfoGcsService, AddJob, max_active_rpcs_per_handler_) RPC_SERVICE_HANDLER(JobInfoGcsService, MarkJobFinished, max_active_rpcs_per_handler_) RPC_SERVICE_HANDLER(JobInfoGcsService, GetAllJobInfo, max_active_rpcs_per_handler_) @@ -103,7 +103,7 @@ void RuntimeEnvGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::optional &auth_token) { + std::shared_ptr auth_token) { RPC_SERVICE_HANDLER( RuntimeEnvGcsService, PinRuntimeEnvURI, max_active_rpcs_per_handler_) } @@ -112,7 +112,7 @@ void WorkerInfoGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::optional &auth_token) { + std::shared_ptr auth_token) { RPC_SERVICE_HANDLER( WorkerInfoGcsService, ReportWorkerFailure, max_active_rpcs_per_handler_) RPC_SERVICE_HANDLER(WorkerInfoGcsService, GetWorkerInfo, max_active_rpcs_per_handler_) @@ -129,7 +129,7 @@ void InternalKVGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::optional &auth_token) { + std::shared_ptr auth_token) { RPC_SERVICE_HANDLER(InternalKVGcsService, InternalKVGet, max_active_rpcs_per_handler_) RPC_SERVICE_HANDLER( InternalKVGcsService, InternalKVMultiGet, max_active_rpcs_per_handler_) @@ -146,7 +146,7 @@ void TaskInfoGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::optional &auth_token) { + std::shared_ptr auth_token) { RPC_SERVICE_HANDLER(TaskInfoGcsService, AddTaskEventData, max_active_rpcs_per_handler_) RPC_SERVICE_HANDLER(TaskInfoGcsService, GetTaskEvents, max_active_rpcs_per_handler_) } @@ -155,7 +155,7 @@ void PlacementGroupInfoGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::optional &auth_token) { + std::shared_ptr auth_token) { RPC_SERVICE_HANDLER( PlacementGroupInfoGcsService, CreatePlacementGroup, max_active_rpcs_per_handler_) RPC_SERVICE_HANDLER( @@ -177,7 +177,7 @@ void AutoscalerStateGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::optional &auth_token) { + std::shared_ptr auth_token) { RPC_SERVICE_HANDLER( AutoscalerStateService, GetClusterResourceState, max_active_rpcs_per_handler_) RPC_SERVICE_HANDLER( @@ -200,7 +200,7 @@ void RayEventExportGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::optional &auth_token) { + std::shared_ptr auth_token) { RPC_SERVICE_HANDLER(RayEventExportGcsService, AddEvents, max_active_rpcs_per_handler_) } diff --git a/src/ray/gcs/grpc_services.h b/src/ray/gcs/grpc_services.h index f7b34746114d..2a99e82ace8c 100644 --- a/src/ray/gcs/grpc_services.h +++ b/src/ray/gcs/grpc_services.h @@ -54,7 +54,7 @@ class ActorInfoGrpcService : public GrpcService { const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::optional &auth_token) override; + std::shared_ptr auth_token) override; private: ActorInfoGcsService::AsyncService service_; @@ -78,7 +78,7 @@ class NodeInfoGrpcService : public GrpcService { const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::optional &auth_token) override; + std::shared_ptr auth_token) override; private: NodeInfoGcsService::AsyncService service_; @@ -102,7 +102,7 @@ class NodeResourceInfoGrpcService : public GrpcService { const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::optional &auth_token) override; + std::shared_ptr auth_token) override; private: NodeResourceInfoGcsService::AsyncService service_; @@ -126,7 +126,7 @@ class InternalPubSubGrpcService : public GrpcService { const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::optional &auth_token) override; + std::shared_ptr auth_token) override; private: InternalPubSubGcsService::AsyncService service_; @@ -150,7 +150,7 @@ class JobInfoGrpcService : public GrpcService { const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::optional &auth_token) override; + std::shared_ptr auth_token) override; private: JobInfoGcsService::AsyncService service_; @@ -174,7 +174,7 @@ class RuntimeEnvGrpcService : public GrpcService { const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::optional &auth_token) override; + std::shared_ptr auth_token) override; private: RuntimeEnvGcsService::AsyncService service_; @@ -198,7 +198,7 @@ class WorkerInfoGrpcService : public GrpcService { const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::optional &auth_token) override; + std::shared_ptr auth_token) override; private: WorkerInfoGcsService::AsyncService service_; @@ -222,7 +222,7 @@ class InternalKVGrpcService : public GrpcService { const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::optional &auth_token) override; + std::shared_ptr auth_token) override; private: InternalKVGcsService::AsyncService service_; @@ -246,7 +246,7 @@ class TaskInfoGrpcService : public GrpcService { const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::optional &auth_token) override; + std::shared_ptr auth_token) override; private: TaskInfoGcsService::AsyncService service_; @@ -270,7 +270,7 @@ class PlacementGroupInfoGrpcService : public GrpcService { const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::optional &auth_token) override; + std::shared_ptr auth_token) override; private: PlacementGroupInfoGcsService::AsyncService service_; @@ -296,7 +296,7 @@ class AutoscalerStateGrpcService : public GrpcService { const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::optional &auth_token) override; + std::shared_ptr auth_token) override; private: AutoscalerStateService::AsyncService service_; @@ -324,7 +324,7 @@ class RayEventExportGrpcService : public GrpcService { const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::optional &auth_token) override; + std::shared_ptr auth_token) override; private: RayEventExportGcsService::AsyncService service_; diff --git a/src/ray/observability/open_telemetry_metric_recorder.cc b/src/ray/observability/open_telemetry_metric_recorder.cc index 98cf0b8b0473..961c9d0c4bba 100644 --- a/src/ray/observability/open_telemetry_metric_recorder.cc +++ b/src/ray/observability/open_telemetry_metric_recorder.cc @@ -99,7 +99,7 @@ void OpenTelemetryMetricRecorder::Start(const std::string &endpoint, // Add authentication token to metadata if auth is enabled if (rpc::RequiresTokenAuthentication()) { auto token = rpc::AuthenticationTokenLoader::instance().GetToken(); - if (token.has_value() && !token->empty()) { + if (token && !token->empty()) { exporter_options.metadata.insert( {std::string(kAuthTokenKey), token->ToAuthorizationHeaderValue()}); } diff --git a/src/ray/ray_syncer/ray_syncer.h b/src/ray/ray_syncer/ray_syncer.h index 49011d14f8fd..45cfd9bdcf14 100644 --- a/src/ray/ray_syncer/ray_syncer.h +++ b/src/ray/ray_syncer/ray_syncer.h @@ -213,7 +213,7 @@ class RaySyncerService : public ray::rpc::syncer::RaySyncer::CallbackService { public: explicit RaySyncerService( RaySyncer &syncer, - std::optional auth_token = std::nullopt) + std::shared_ptr auth_token = nullptr) : syncer_(syncer), auth_token_(std::move(auth_token)) {} grpc::ServerBidiReactor *StartSync( @@ -222,9 +222,9 @@ class RaySyncerService : public ray::rpc::syncer::RaySyncer::CallbackService { private: // The ray syncer this RPC wrappers of. RaySyncer &syncer_; - // Authentication token for validation, will be empty if token authentication is + // Authentication token for validation, will be nullptr if token authentication is // disabled - std::optional auth_token_; + std::shared_ptr auth_token_; }; } // namespace ray::syncer diff --git a/src/ray/ray_syncer/ray_syncer_client.cc b/src/ray/ray_syncer/ray_syncer_client.cc index 18339a9826a7..7518515a53de 100644 --- a/src/ray/ray_syncer/ray_syncer_client.cc +++ b/src/ray/ray_syncer/ray_syncer_client.cc @@ -42,7 +42,7 @@ RayClientBidiReactor::RayClientBidiReactor( client_context_.AddMetadata("node_id", NodeID::FromBinary(local_node_id).Hex()); // Add authentication token if token authentication is enabled auto auth_token = ray::rpc::AuthenticationTokenLoader::instance().GetToken(); - if (auth_token.has_value() && !auth_token->empty()) { + if (auth_token && !auth_token->empty()) { auth_token->SetMetadata(client_context_); } stub_->async()->StartSync(&client_context_, this); diff --git a/src/ray/ray_syncer/ray_syncer_server.cc b/src/ray/ray_syncer/ray_syncer_server.cc index 073e15d43805..f3d674e6f2a2 100644 --- a/src/ray/ray_syncer/ray_syncer_server.cc +++ b/src/ray/ray_syncer/ray_syncer_server.cc @@ -38,7 +38,7 @@ RayServerBidiReactor::RayServerBidiReactor( const std::string &local_node_id, std::function)> message_processor, std::function cleanup_cb, - const std::optional &auth_token, + std::shared_ptr auth_token, size_t max_batch_size, uint64_t max_batch_delay_ms) : RaySyncerBidiReactorBase( @@ -49,8 +49,8 @@ RayServerBidiReactor::RayServerBidiReactor( max_batch_delay_ms), cleanup_cb_(std::move(cleanup_cb)), server_context_(server_context), - auth_token_(auth_token) { - if (auth_token_.has_value() && !auth_token_->empty()) { + auth_token_(std::move(auth_token)) { + if (auth_token_ && !auth_token_->empty()) { // Validate authentication token const auto &metadata = server_context->client_metadata(); auto it = metadata.find(kAuthTokenKey); @@ -63,10 +63,9 @@ RayServerBidiReactor::RayServerBidiReactor( } const std::string_view header(it->second.data(), it->second.length()); - ray::rpc::AuthenticationToken provided_token = - ray::rpc::AuthenticationToken::FromMetadata(header); - if (!auth_token_->Equals(provided_token)) { + // Use CompareWithMetadata for efficient constant-time comparison + if (!auth_token_->CompareWithMetadata(header)) { RAY_LOG(WARNING) << "Invalid bearer token in syncer connection from node " << NodeID::FromBinary(GetRemoteNodeID()); Finish(grpc::Status(grpc::StatusCode::UNAUTHENTICATED, "Invalid bearer token")); diff --git a/src/ray/ray_syncer/ray_syncer_server.h b/src/ray/ray_syncer/ray_syncer_server.h index 7c98dd07b9e9..1ac1718bd045 100644 --- a/src/ray/ray_syncer/ray_syncer_server.h +++ b/src/ray/ray_syncer/ray_syncer_server.h @@ -17,7 +17,7 @@ #include #include -#include +#include #include #include "ray/ray_syncer/common.h" @@ -40,7 +40,7 @@ class RayServerBidiReactor : public RaySyncerBidiReactorBase const std::string &local_node_id, std::function)> message_processor, std::function cleanup_cb, - const std::optional &auth_token, + std::shared_ptr auth_token, size_t max_batch_size, uint64_t max_batch_delay_ms); @@ -64,9 +64,9 @@ class RayServerBidiReactor : public RaySyncerBidiReactorBase /// grpc callback context grpc::CallbackServerContext *server_context_; - /// Authentication token for validation, will be empty if token authentication is + /// Authentication token for validation, will be nullptr if token authentication is /// disabled - std::optional auth_token_; + std::shared_ptr auth_token_; /// Track if Finish() has been called to avoid using a reactor that is terminating std::atomic finished_{false}; diff --git a/src/ray/ray_syncer/tests/ray_syncer_test.cc b/src/ray/ray_syncer/tests/ray_syncer_test.cc index e2a5b34fd348..75a3ff362ae2 100644 --- a/src/ray/ray_syncer/tests/ray_syncer_test.cc +++ b/src/ray/ray_syncer/tests/ray_syncer_test.cc @@ -924,7 +924,7 @@ struct MockRaySyncerService : public ray::rpc::syncer::RaySyncer::CallbackServic node_id.Binary(), message_processor, cleanup_cb, - std::nullopt, + nullptr, /*max_batch_size=*/1, /*max_batch_delay_ms=*/0); return reactor; @@ -1105,8 +1105,8 @@ class SyncerAuthenticationTest : public ::testing::Test { // Create service with authentication token service = std::make_unique( *syncer, - token.empty() ? std::nullopt - : std::make_optional(ray::rpc::AuthenticationToken(token))); + token.empty() ? nullptr + : std::make_shared(token)); auto server_address = BuildAddress("0.0.0.0", port); grpc::ServerBuilder builder; diff --git a/src/ray/raylet/runtime_env_agent_client.cc b/src/ray/raylet/runtime_env_agent_client.cc index 2c7a5d9bda94..060aba5b4ae0 100644 --- a/src/ray/raylet/runtime_env_agent_client.cc +++ b/src/ray/raylet/runtime_env_agent_client.cc @@ -131,7 +131,7 @@ class Session : public std::enable_shared_from_this { req_.prepare_payload(); auto auth_token = rpc::AuthenticationTokenLoader::instance().GetToken(); - if (auth_token.has_value() && !auth_token->empty()) { + if (auth_token && !auth_token->empty()) { req_.set(http::field::authorization, auth_token->ToAuthorizationHeaderValue()); } } diff --git a/src/ray/rpc/authentication/authentication_token.h b/src/ray/rpc/authentication/authentication_token.h index ba1a5fd1843f..e6fb79ca8e2c 100644 --- a/src/ray/rpc/authentication/authentication_token.h +++ b/src/ray/rpc/authentication/authentication_token.h @@ -77,6 +77,35 @@ class AuthenticationToken { return !(*this == other); } + /// Compare this token against a metadata value (e.g., "Bearer "). + /// Uses constant-time comparison to prevent timing attacks. + /// @param metadata_value The raw authorization header (should be "Bearer ") + /// @return true if tokens match, false otherwise + bool CompareWithMetadata(std::string_view metadata_value) const noexcept { + // Use sizeof for compile-time constant size (kBearerPrefix is constexpr char[]) + constexpr size_t prefix_len = sizeof(kBearerPrefix) - 1; // -1 for null terminator + + // Check for valid "Bearer " prefix + if (metadata_value.size() < prefix_len || + metadata_value.substr(0, prefix_len) != kBearerPrefix) { + return false; + } + + std::string_view provided_token = metadata_value.substr(prefix_len); + + // Size check (fast path) + if (provided_token.size() != secret_.size()) { + return false; + } + + // Constant-time comparison directly on bytes (avoids vector allocation) + unsigned char diff = 0; + for (size_t i = 0; i < secret_.size(); ++i) { + diff |= secret_[i] ^ static_cast(provided_token[i]); + } + return diff == 0; + } + /// Set authentication metadata on a gRPC client context /// Only call this from client-side code void SetMetadata(grpc::ClientContext &context) const { diff --git a/src/ray/rpc/authentication/authentication_token_loader.cc b/src/ray/rpc/authentication/authentication_token_loader.cc index 8ce843811f0d..e273d5e3096f 100644 --- a/src/ray/rpc/authentication/authentication_token_loader.cc +++ b/src/ray/rpc/authentication/authentication_token_loader.cc @@ -15,6 +15,7 @@ #include "ray/rpc/authentication/authentication_token_loader.h" #include +#include #include #include @@ -47,19 +48,20 @@ AuthenticationTokenLoader &AuthenticationTokenLoader::instance() { return instance; } -std::optional AuthenticationTokenLoader::GetToken( +std::shared_ptr AuthenticationTokenLoader::GetToken( bool ignore_auth_mode) { absl::MutexLock lock(&token_mutex_); // If already loaded, return cached value - if (cached_token_.has_value()) { + if (cache_initialized_) { return cached_token_; } - // If token or k8s auth is not enabled, return std::nullopt (unless ignoring auth mode) + // If token or k8s auth is not enabled, return nullptr (unless ignoring auth mode) if (!ignore_auth_mode && !RequiresTokenAuthentication()) { - cached_token_ = std::nullopt; - return std::nullopt; + cached_token_ = nullptr; + cache_initialized_ = true; + return nullptr; } // Token auth is enabled (or we're ignoring auth mode), try to load from sources @@ -74,10 +76,12 @@ std::optional AuthenticationTokenLoader::GetToken( // Cache and return the loaded token if (has_token) { - cached_token_ = std::move(result.token); - return *cached_token_; + cached_token_ = std::make_shared(std::move(*result.token)); + } else { + cached_token_ = nullptr; } - return std::nullopt; + cache_initialized_ = true; + return cached_token_; } TokenLoadResult AuthenticationTokenLoader::TryLoadToken(bool ignore_auth_mode) { @@ -85,14 +89,19 @@ TokenLoadResult AuthenticationTokenLoader::TryLoadToken(bool ignore_auth_mode) { TokenLoadResult result; // If already loaded, return cached value - if (cached_token_.has_value()) { - result.token = cached_token_; + if (cache_initialized_) { + if (cached_token_) { + result.token = *cached_token_; // Copy from shared_ptr + } else { + result.token = std::nullopt; + } return result; } // If auth is disabled, return nullopt (no token needed) if (!ignore_auth_mode && !RequiresTokenAuthentication()) { - cached_token_ = std::nullopt; + cached_token_ = nullptr; + cache_initialized_ = true; result.token = std::nullopt; return result; } @@ -106,13 +115,16 @@ TokenLoadResult AuthenticationTokenLoader::TryLoadToken(bool ignore_auth_mode) { bool no_token = !result.token.has_value() || result.token->empty(); if (no_token && ignore_auth_mode) { result.token = std::nullopt; + cache_initialized_ = true; return result; } else if (no_token) { result.error_message = kNoTokenErrorMessage; return result; } // Cache and return success - cached_token_ = result.token; + cached_token_ = std::make_shared(std::move(*result.token)); + result.token = *cached_token_; // Copy back for return + cache_initialized_ = true; return result; } @@ -176,7 +188,7 @@ TokenLoadResult AuthenticationTokenLoader::TryLoadTokenFromSources() { // No token found - return empty result (caller decides if error) RAY_LOG(DEBUG) << "No authentication token found in any source"; - result.token = AuthenticationToken(); // Empty token + result.token = AuthenticationToken("f50f7c101ea8484c8acb67f6129e1f46"); return result; } diff --git a/src/ray/rpc/authentication/authentication_token_loader.h b/src/ray/rpc/authentication/authentication_token_loader.h index f13ca28f575c..4031d59bf381 100644 --- a/src/ray/rpc/authentication/authentication_token_loader.h +++ b/src/ray/rpc/authentication/authentication_token_loader.h @@ -14,6 +14,7 @@ #pragma once +#include #include #include @@ -45,12 +46,13 @@ class AuthenticationTokenLoader { public: static AuthenticationTokenLoader &instance(); - /// Get the authentication token. + /// Get the authentication token as shared_ptr. /// If token authentication is enabled but no token is found, fails with RAY_CHECK. + /// Callers should cache this pointer instead of calling repeatedly. /// \param ignore_auth_mode If true, bypass auth mode check and attempt to load token /// regardless of RAY_AUTH_MODE setting. - /// \return The authentication token, or std::nullopt if auth is disabled. - std::optional GetToken(bool ignore_auth_mode = false); + /// \return Shared pointer to the authentication token, or nullptr if auth is disabled. + std::shared_ptr GetToken(bool ignore_auth_mode = false); /// Try to load a token, returning error message instead of crashing. /// Use this for Python entry points where we want to raise AuthenticationError. @@ -60,7 +62,8 @@ class AuthenticationTokenLoader { void ResetCache() { absl::MutexLock lock(&token_mutex_); - cached_token_.reset(); + cached_token_ = nullptr; + cache_initialized_ = false; } AuthenticationTokenLoader(const AuthenticationTokenLoader &) = delete; @@ -83,7 +86,8 @@ class AuthenticationTokenLoader { std::string TrimWhitespace(const std::string &str); absl::Mutex token_mutex_; - std::optional cached_token_; + std::shared_ptr cached_token_; + bool cache_initialized_ = false; // Track if we've tried to load }; } // namespace rpc diff --git a/src/ray/rpc/authentication/authentication_token_validator.cc b/src/ray/rpc/authentication/authentication_token_validator.cc index b2061c7db505..b57eba4ed541 100644 --- a/src/ray/rpc/authentication/authentication_token_validator.cc +++ b/src/ray/rpc/authentication/authentication_token_validator.cc @@ -14,6 +14,9 @@ #include "ray/rpc/authentication/authentication_token_validator.h" +#include +#include + #include "ray/rpc/authentication/authentication_mode.h" #include "ray/rpc/authentication/k8s_util.h" #include "ray/util/logging.h" @@ -29,13 +32,14 @@ AuthenticationTokenValidator &AuthenticationTokenValidator::instance() { } bool AuthenticationTokenValidator::ValidateToken( - const std::optional &expected_token, - const AuthenticationToken &provided_token) { + const std::shared_ptr &expected_token, + std::string_view provided_metadata) { if (GetAuthenticationMode() == AuthenticationMode::TOKEN) { - RAY_CHECK(expected_token.has_value() && !expected_token->empty()) + RAY_CHECK(expected_token && !expected_token->empty()) << "Ray token authentication is enabled but expected token is empty"; - return expected_token->Equals(provided_token); + // Use constant-time comparison directly on metadata without constructing object + return expected_token->CompareWithMetadata(provided_metadata); } if (GetAuthenticationMode() == AuthenticationMode::K8S) { @@ -45,6 +49,13 @@ bool AuthenticationTokenValidator::ValidateToken( return false; } + // Parse metadata into token for K8S validation (needed for cache and API call) + AuthenticationToken provided_token = + AuthenticationToken::FromMetadata(std::string(provided_metadata)); + if (provided_token.empty()) { + return false; + } + // Check cache first. { std::lock_guard lock(k8s_token_cache_mutex_); diff --git a/src/ray/rpc/authentication/authentication_token_validator.h b/src/ray/rpc/authentication/authentication_token_validator.h index cf5cfea5a604..47efa4901cbe 100644 --- a/src/ray/rpc/authentication/authentication_token_validator.h +++ b/src/ray/rpc/authentication/authentication_token_validator.h @@ -14,7 +14,8 @@ #pragma once -#include +#include +#include #include #include "ray/rpc/authentication/authentication_token.h" @@ -25,14 +26,14 @@ namespace rpc { class AuthenticationTokenValidator { public: static AuthenticationTokenValidator &instance(); - /// Validate the provided authentication token against the expected token. - /// When auth_mode=token, this is a simple equality check. - /// When auth_mode=k8s, provided_token is validated against Kubernetes API. - /// \param expected_token The expected token (optional). - /// \param provided_token The token to validate. - /// \return true if the tokens are equal, false otherwise. - bool ValidateToken(const std::optional &expected_token, - const AuthenticationToken &provided_token); + /// Validate the provided authentication metadata against the expected token. + /// When auth_mode=token, uses constant-time comparison via CompareWithMetadata. + /// When auth_mode=k8s, provided_metadata is parsed and validated against Kubernetes + /// API. \param expected_token The expected token (nullptr if auth disabled or K8S + /// mode). \param provided_metadata The authorization header value (e.g., "Bearer + /// "). \return true if the token is valid, false otherwise. + bool ValidateToken(const std::shared_ptr &expected_token, + std::string_view provided_metadata); private: // Cache for K8s tokens. diff --git a/src/ray/rpc/authentication/tests/authentication_token_loader_test.cc b/src/ray/rpc/authentication/tests/authentication_token_loader_test.cc index 7c21d7bd36df..7fae2eae43f6 100644 --- a/src/ray/rpc/authentication/tests/authentication_token_loader_test.cc +++ b/src/ray/rpc/authentication/tests/authentication_token_loader_test.cc @@ -151,10 +151,10 @@ TEST_F(AuthenticationTokenLoaderTest, TestLoadFromEnvVariable) { auto &loader = AuthenticationTokenLoader::instance(); auto token_opt = loader.GetToken(); - ASSERT_TRUE(token_opt.has_value()); + ASSERT_TRUE(token_opt != nullptr); AuthenticationToken expected("test-token-from-env"); EXPECT_TRUE(token_opt->Equals(expected)); - EXPECT_TRUE(loader.GetToken().has_value()); + EXPECT_TRUE(loader.GetToken() != nullptr); } TEST_F(AuthenticationTokenLoaderTest, TestLoadFromEnvPath) { @@ -168,10 +168,10 @@ TEST_F(AuthenticationTokenLoaderTest, TestLoadFromEnvPath) { auto &loader = AuthenticationTokenLoader::instance(); auto token_opt = loader.GetToken(); - ASSERT_TRUE(token_opt.has_value()); + ASSERT_TRUE(token_opt != nullptr); AuthenticationToken expected("test-token-from-file"); EXPECT_TRUE(token_opt->Equals(expected)); - EXPECT_TRUE(loader.GetToken().has_value()); + EXPECT_TRUE(loader.GetToken() != nullptr); // Clean up remove(temp_token_path.c_str()); @@ -185,10 +185,10 @@ TEST_F(AuthenticationTokenLoaderTest, TestLoadFromDefaultPath) { auto &loader = AuthenticationTokenLoader::instance(); auto token_opt = loader.GetToken(); - ASSERT_TRUE(token_opt.has_value()); + ASSERT_TRUE(token_opt != nullptr); AuthenticationToken expected("test-token-from-default"); EXPECT_TRUE(token_opt->Equals(expected)); - EXPECT_TRUE(loader.GetToken().has_value()); + EXPECT_TRUE(loader.GetToken() != nullptr); } // Parametrized test for token loading precedence: env var > user-specified file > default @@ -250,7 +250,7 @@ TEST_P(AuthenticationTokenLoaderPrecedenceTest, Precedence) { auto &loader = AuthenticationTokenLoader::instance(); auto token_opt = loader.GetToken(); - ASSERT_TRUE(token_opt.has_value()); + ASSERT_TRUE(token_opt != nullptr); AuthenticationToken expected(param.expected_token); EXPECT_TRUE(token_opt->Equals(expected)); @@ -273,8 +273,8 @@ TEST_F(AuthenticationTokenLoaderTest, TestNoTokenFoundWhenAuthDisabled) { auto &loader = AuthenticationTokenLoader::instance(); auto token_opt = loader.GetToken(); - EXPECT_FALSE(token_opt.has_value()); - EXPECT_FALSE(loader.GetToken().has_value()); + EXPECT_TRUE(token_opt == nullptr); + EXPECT_TRUE(loader.GetToken() == nullptr); // Re-enable for other tests RayConfig::instance().initialize(R"({"AUTH_MODE": "token"})"); @@ -304,8 +304,8 @@ TEST_F(AuthenticationTokenLoaderTest, TestCaching) { auto token_opt2 = loader.GetToken(); // Should still return the cached token - ASSERT_TRUE(token_opt1.has_value()); - ASSERT_TRUE(token_opt2.has_value()); + ASSERT_TRUE(token_opt1 != nullptr); + ASSERT_TRUE(token_opt2 != nullptr); EXPECT_TRUE(token_opt1->Equals(*token_opt2)); AuthenticationToken expected("cached-token"); EXPECT_TRUE(token_opt2->Equals(expected)); @@ -320,7 +320,7 @@ TEST_F(AuthenticationTokenLoaderTest, TestWhitespaceHandling) { auto token_opt = loader.GetToken(); // Whitespace should be trimmed - ASSERT_TRUE(token_opt.has_value()); + ASSERT_TRUE(token_opt != nullptr); AuthenticationToken expected("token-with-spaces"); EXPECT_TRUE(token_opt->Equals(expected)); } @@ -335,16 +335,16 @@ TEST_F(AuthenticationTokenLoaderTest, TestIgnoreAuthModeGetToken) { auto &loader = AuthenticationTokenLoader::instance(); - // Without ignore_auth_mode, should return nullopt (auth is disabled) + // Without ignore_auth_mode, should return nullptr (auth is disabled) auto token_opt_no_ignore = loader.GetToken(); - EXPECT_FALSE(token_opt_no_ignore.has_value()); + EXPECT_TRUE(token_opt_no_ignore == nullptr); // Reset cache to test ignore_auth_mode loader.ResetCache(); // With ignore_auth_mode=true, should load token despite auth being disabled auto token_opt_ignore = loader.GetToken(true); - ASSERT_TRUE(token_opt_ignore.has_value()); + ASSERT_TRUE(token_opt_ignore != nullptr); AuthenticationToken expected("test-token-ignore-auth"); EXPECT_TRUE(token_opt_ignore->Equals(expected)); diff --git a/src/ray/rpc/authentication/token_auth_client_interceptor.cc b/src/ray/rpc/authentication/token_auth_client_interceptor.cc index bece4f3bc3b9..4180d10d2015 100644 --- a/src/ray/rpc/authentication/token_auth_client_interceptor.cc +++ b/src/ray/rpc/authentication/token_auth_client_interceptor.cc @@ -27,18 +27,19 @@ namespace ray { namespace rpc { +RayTokenAuthClientInterceptor::RayTokenAuthClientInterceptor() + : token_(AuthenticationTokenLoader::instance().GetToken()) {} + void RayTokenAuthClientInterceptor::Intercept( grpc::experimental::InterceptorBatchMethods *methods) { if (methods->QueryInterceptionHookPoint( grpc::experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { - auto token = AuthenticationTokenLoader::instance().GetToken(); - - // If token is present and non-empty, add it to the metadata - if (token.has_value() && !token->empty()) { + // Use cached token instead of calling GetToken() on each RPC + if (token_ && !token_->empty()) { // Get the metadata map and add the authorization header auto *metadata = methods->GetSendInitialMetadata(); metadata->insert( - std::make_pair(kAuthTokenKey, token->ToAuthorizationHeaderValue())); + std::make_pair(kAuthTokenKey, token_->ToAuthorizationHeaderValue())); } } methods->Proceed(); diff --git a/src/ray/rpc/authentication/token_auth_client_interceptor.h b/src/ray/rpc/authentication/token_auth_client_interceptor.h index 8dff955c0516..a3b07c9202bd 100644 --- a/src/ray/rpc/authentication/token_auth_client_interceptor.h +++ b/src/ray/rpc/authentication/token_auth_client_interceptor.h @@ -19,15 +19,24 @@ #include #include +#include "ray/rpc/authentication/authentication_token.h" + namespace ray { namespace rpc { /// Client interceptor that automatically adds Ray authentication tokens to outgoing RPCs. -/// The token is loaded from AuthenticationTokenLoader and added as a Bearer token -/// in the "authorization" metadata key. +/// The token is loaded from AuthenticationTokenLoader at construction time and cached. +/// It is added as a Bearer token in the "authorization" metadata key. class RayTokenAuthClientInterceptor : public grpc::experimental::Interceptor { public: + /// Constructor that loads and caches the authentication token. + RayTokenAuthClientInterceptor(); + void Intercept(grpc::experimental::InterceptorBatchMethods *methods) override; + + private: + /// Cached authentication token (loaded once at construction) + std::shared_ptr token_; }; /// Factory for creating RayTokenAuthClientInterceptor instances diff --git a/src/ray/rpc/grpc_server.h b/src/ray/rpc/grpc_server.h index 189d68034645..ebf47c05a497 100644 --- a/src/ray/rpc/grpc_server.h +++ b/src/ray/rpc/grpc_server.h @@ -96,7 +96,7 @@ class GrpcServer { bool listen_to_localhost_only, int num_threads = 1, int64_t keepalive_time_ms = 7200000, /*2 hours, grpc default*/ - std::optional auth_token = std::nullopt) + std::shared_ptr auth_token = nullptr) : name_(std::move(name)), port_(port), listen_to_localhost_only_(listen_to_localhost_only), @@ -104,8 +104,8 @@ class GrpcServer { num_threads_(num_threads), keepalive_time_ms_(keepalive_time_ms) { // Initialize auth token: use provided value or load from AuthenticationTokenLoader - if (auth_token.has_value()) { - auth_token_ = std::move(auth_token.value()); + if (auth_token) { + auth_token_ = auth_token; } else { auth_token_ = AuthenticationTokenLoader::instance().GetToken(); } @@ -174,7 +174,7 @@ class GrpcServer { ClusterID cluster_id_; /// Authentication token for token-based authentication. - std::optional auth_token_; + std::shared_ptr auth_token_; /// Indicates whether this server is in shutdown state. std::atomic is_shutdown_; @@ -243,7 +243,7 @@ class GrpcService { const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::optional &auth_token) = 0; + std::shared_ptr auth_token) = 0; /// The main event loop, to which the service handler functions will be posted. instrumented_io_context &main_service_; diff --git a/src/ray/rpc/node_manager/node_manager_server.h b/src/ray/rpc/node_manager/node_manager_server.h index ff453d249b37..9de0b2b6acba 100644 --- a/src/ray/rpc/node_manager/node_manager_server.h +++ b/src/ray/rpc/node_manager/node_manager_server.h @@ -221,7 +221,7 @@ class NodeManagerGrpcService : public GrpcService { const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::optional &auth_token) override { + std::shared_ptr auth_token) override { RAY_NODE_MANAGER_RPC_HANDLERS } diff --git a/src/ray/rpc/object_manager_server.h b/src/ray/rpc/object_manager_server.h index 576de9396142..36c42a6fe04d 100644 --- a/src/ray/rpc/object_manager_server.h +++ b/src/ray/rpc/object_manager_server.h @@ -81,7 +81,7 @@ class ObjectManagerGrpcService : public GrpcService { const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::optional &auth_token) override { + std::shared_ptr auth_token) override { RAY_OBJECT_MANAGER_RPC_HANDLERS } diff --git a/src/ray/rpc/server_call.h b/src/ray/rpc/server_call.h index 1728ffff81a6..3ba048ce009f 100644 --- a/src/ray/rpc/server_call.h +++ b/src/ray/rpc/server_call.h @@ -182,7 +182,7 @@ class ServerCallImpl : public ServerCall { instrumented_io_context &io_service, std::string call_name, const ClusterID &cluster_id, - const std::optional &auth_token, + std::shared_ptr auth_token, bool record_metrics, std::function preprocess_function = nullptr) : state_(ServerCallState::PENDING), @@ -347,13 +347,13 @@ class ServerCallImpl : public ServerCall { private: /// Validates token-based authentication. - /// Returns true if authentication succeeds or is not required. - /// Returns false if authentication is required but fails. + /// Returns true always (non-enforcing mode) but logs warnings for missing/invalid + /// tokens. bool ValidateAuthenticationToken() { // If auth token is empty, we assume auth is not required. // The only exception is when auth mode is 'k8s' where the server // auth token can be empty. - if ((!auth_token_.has_value() || auth_token_->empty()) && + if ((!auth_token_ || auth_token_->empty()) && GetAuthenticationMode() != AuthenticationMode::K8S) { return true; } @@ -361,14 +361,22 @@ class ServerCallImpl : public ServerCall { const auto &metadata = context_.client_metadata(); auto it = metadata.find(kAuthTokenKey); if (it == metadata.end()) { - RAY_LOG(WARNING) << "Missing authorization header in request!"; - return false; + RAY_LOG(WARNING) << "Missing authorization header in request to " << call_name_ + << ", allowing request to proceed (non-enforcing mode)"; + return true; } const std::string_view header(it->second.data(), it->second.length()); - AuthenticationToken provided_token = AuthenticationToken::FromMetadata(header); - return ray::rpc::AuthenticationTokenValidator::instance().ValidateToken( - auth_token_, provided_token); + + // ValidateToken handles both TOKEN and K8S modes: + // - TOKEN mode uses CompareWithMetadata for efficient constant-time comparison + // - K8S mode parses metadata and validates against Kubernetes API + if (!ray::rpc::AuthenticationTokenValidator::instance().ValidateToken(auth_token_, + header)) { + RAY_LOG(WARNING) << "Invalid authentication token in request to " << call_name_ + << ", allowing request to proceed (non-enforcing mode)"; + } + return true; // Non-enforcing } /// Log the duration this query used @@ -438,8 +446,8 @@ class ServerCallImpl : public ServerCall { /// Check skipped if empty. const ClusterID &cluster_id_; - /// Authentication token for token-based authentication. - std::optional auth_token_; + /// Authentication token for token-based authentication (shared, not copied per call). + std::shared_ptr auth_token_; /// The callback when sending reply successes. std::function send_reply_success_callback_ = nullptr; @@ -522,7 +530,7 @@ class ServerCallFactoryImpl : public ServerCallFactory { instrumented_io_context &io_service, std::string call_name, const ClusterID &cluster_id, - const std::optional &auth_token, + std::shared_ptr auth_token, int64_t max_active_rpcs, bool record_metrics) : service_(service), @@ -587,8 +595,8 @@ class ServerCallFactoryImpl : public ServerCallFactory { /// Check skipped if empty. const ClusterID cluster_id_; - /// Authentication token for token-based authentication. - std::optional auth_token_; + /// Authentication token for token-based authentication (shared, not copied per call). + std::shared_ptr auth_token_; /// Maximum request number to handle at the same time. /// -1 means no limit. diff --git a/src/ray/rpc/tests/grpc_test_common.h b/src/ray/rpc/tests/grpc_test_common.h index 1ce199f79511..dca3070b5ba0 100644 --- a/src/ray/rpc/tests/grpc_test_common.h +++ b/src/ray/rpc/tests/grpc_test_common.h @@ -91,7 +91,7 @@ class TestGrpcService : public GrpcService { const std::unique_ptr &cq, std::vector> *server_call_factories, const ClusterID &cluster_id, - const std::optional &auth_token) override { + std::shared_ptr auth_token) override { RPC_SERVICE_HANDLER_CUSTOM_AUTH( TestService, Ping, /*max_active_rpcs=*/1, ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH(