From ce73705ee7043afe6cbafaf67ec20b961f002689 Mon Sep 17 00:00:00 2001 From: sampan Date: Thu, 23 Oct 2025 16:17:38 +0000 Subject: [PATCH 01/48] Add authentication token logic and related tests Signed-off-by: sampan --- src/ray/common/constants.h | 2 + src/ray/common/ray_config_def.h | 6 + src/ray/rpc/authentication/BUILD.bazel | 34 ++ .../rpc/authentication/authentication_mode.cc | 37 ++ .../rpc/authentication/authentication_mode.h | 33 ++ .../rpc/authentication/authentication_token.h | 157 ++++++++ .../authentication_token_loader.cc | 167 +++++++++ .../authentication_token_loader.h | 72 ++++ src/ray/rpc/tests/BUILD.bazel | 44 +++ .../tests/authentication_token_loader_test.cc | 346 ++++++++++++++++++ .../rpc/tests/authentication_token_test.cc | 131 +++++++ 11 files changed, 1029 insertions(+) create mode 100644 src/ray/rpc/authentication/BUILD.bazel create mode 100644 src/ray/rpc/authentication/authentication_mode.cc create mode 100644 src/ray/rpc/authentication/authentication_mode.h create mode 100644 src/ray/rpc/authentication/authentication_token.h create mode 100644 src/ray/rpc/authentication/authentication_token_loader.cc create mode 100644 src/ray/rpc/authentication/authentication_token_loader.h create mode 100644 src/ray/rpc/tests/authentication_token_loader_test.cc create mode 100644 src/ray/rpc/tests/authentication_token_test.cc diff --git a/src/ray/common/constants.h b/src/ray/common/constants.h index bfd06e677e7e..08986d3b415e 100644 --- a/src/ray/common/constants.h +++ b/src/ray/common/constants.h @@ -42,6 +42,8 @@ constexpr int kRayletStoreErrorExitCode = 100; constexpr char kObjectTablePrefix[] = "ObjectTable"; constexpr char kClusterIdKey[] = "ray_cluster_id"; +constexpr char kAuthTokenKey[] = "authorization"; +constexpr char kBearerPrefix[] = "Bearer "; constexpr char kWorkerDynamicOptionPlaceholder[] = "RAY_WORKER_DYNAMIC_OPTION_PLACEHOLDER"; diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h index 6e8d21956162..e4e8fc1d48ef 100644 --- a/src/ray/common/ray_config_def.h +++ b/src/ray/common/ray_config_def.h @@ -35,6 +35,12 @@ RAY_CONFIG(bool, emit_main_service_metrics, true) /// Whether to enable cluster authentication. RAY_CONFIG(bool, enable_cluster_auth, true) +/// Whether to enable token-based authentication for RPC calls. +/// will be converted to AuthenticationMode enum defined in +/// rpc/authentication/authentication_mode.h +/// use GetAuthenticationMode() to get the authentication mode enum value. +RAY_CONFIG(std::string, auth_mode, "disabled") + /// The interval of periodic event loop stats print. /// -1 means the feature is disabled. In this case, stats are available /// in the associated process's log file. diff --git a/src/ray/rpc/authentication/BUILD.bazel b/src/ray/rpc/authentication/BUILD.bazel new file mode 100644 index 000000000000..8da78e5d728b --- /dev/null +++ b/src/ray/rpc/authentication/BUILD.bazel @@ -0,0 +1,34 @@ +load("//bazel:ray.bzl", "ray_cc_library") + +ray_cc_library( + name = "authentication_mode", + srcs = ["authentication_mode.cc"], + hdrs = ["authentication_mode.h"], + visibility = ["//visibility:public"], + deps = [ + "//src/ray/common:ray_config", + "@com_google_absl//absl/strings", + ], +) + +ray_cc_library( + name = "authentication_token", + hdrs = ["authentication_token.h"], + visibility = ["//visibility:public"], + deps = [ + "//src/ray/common:constants", + "@com_github_grpc_grpc//:grpc++", + ], +) + +ray_cc_library( + name = "authentication_token_loader", + srcs = ["authentication_token_loader.cc"], + hdrs = ["authentication_token_loader.h"], + visibility = ["//visibility:public"], + deps = [ + ":authentication_mode", + ":authentication_token", + "//src/ray/util:logging", + ], +) diff --git a/src/ray/rpc/authentication/authentication_mode.cc b/src/ray/rpc/authentication/authentication_mode.cc new file mode 100644 index 000000000000..1bbe209733ce --- /dev/null +++ b/src/ray/rpc/authentication/authentication_mode.cc @@ -0,0 +1,37 @@ +// Copyright 2025 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ray/rpc/authentication/authentication_mode.h" + +#include +#include + +#include "absl/strings/ascii.h" +#include "ray/common/ray_config.h" + +namespace ray { +namespace rpc { + +AuthenticationMode GetAuthenticationMode() { + std::string auth_mode_lower = absl::AsciiStrToLower(RayConfig::instance().auth_mode()); + + if (auth_mode_lower == "token") { + return AuthenticationMode::TOKEN; + } else { + return AuthenticationMode::DISABLED; + } +} + +} // namespace rpc +} // namespace ray diff --git a/src/ray/rpc/authentication/authentication_mode.h b/src/ray/rpc/authentication/authentication_mode.h new file mode 100644 index 000000000000..21bd165fd34b --- /dev/null +++ b/src/ray/rpc/authentication/authentication_mode.h @@ -0,0 +1,33 @@ +// Copyright 2025 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +namespace ray { +namespace rpc { + +enum class AuthenticationMode { + DISABLED, + TOKEN, +}; + +/// Get the authentication mode from the RayConfig. +/// \return The authentication mode enum value. returns AuthenticationMode::DISABLED if +/// the authentication mode is not set or is invalid. +AuthenticationMode GetAuthenticationMode(); + +} // namespace rpc +} // namespace ray diff --git a/src/ray/rpc/authentication/authentication_token.h b/src/ray/rpc/authentication/authentication_token.h new file mode 100644 index 000000000000..6846d3c08ada --- /dev/null +++ b/src/ray/rpc/authentication/authentication_token.h @@ -0,0 +1,157 @@ +// Copyright 2025 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "ray/common/constants.h" + +namespace ray { +namespace rpc { + +/// Secure wrapper for authentication tokens. +/// - Wipes memory on destruction +/// - Constant-time comparison +/// - Redacted output when logged or printed +class AuthenticationToken { + public: + AuthenticationToken() = default; + explicit AuthenticationToken(std::string value) : secret_(value.begin(), value.end()) {} + + AuthenticationToken(const AuthenticationToken &other) : secret_(other.secret_) {} + AuthenticationToken &operator=(const AuthenticationToken &other) { + if (this != &other) { + SecureClear(); + secret_ = other.secret_; + } + return *this; + } + + // Move operations + AuthenticationToken(AuthenticationToken &&other) noexcept { + MoveFrom(std::move(other)); + } + AuthenticationToken &operator=(AuthenticationToken &&other) noexcept { + if (this != &other) { + SecureClear(); + MoveFrom(std::move(other)); + } + return *this; + } + ~AuthenticationToken() { SecureClear(); } + + bool empty() const noexcept { return secret_.empty(); } + + /// Constant-time equality comparison + bool Equals(const AuthenticationToken &other) const noexcept { + return ConstTimeEqual(secret_, other.secret_); + } + + /// Equality operator (constant-time) + bool operator==(const AuthenticationToken &other) const noexcept { + return Equals(other); + } + + /// Inequality operator + bool operator!=(const AuthenticationToken &other) const noexcept { + return !(*this == other); + } + + /// Set authentication metadata on a gRPC client context + /// Only call this from client-side code + void SetMetadata(grpc::ClientContext &context) const { + if (!secret_.empty()) { + context.AddMetadata(kAuthTokenKey, + kBearerPrefix + std::string(secret_.begin(), secret_.end())); + } + } + + /// Create AuthenticationToken from gRPC metadata value + /// Strips "Bearer " prefix and creates token object + /// @param metadata_value The raw value from server metadata (should include "Bearer " + /// prefix) + /// @return AuthenticationToken object (empty if format invalid) + static AuthenticationToken FromMetadata(std::string_view metadata_value) { + const std::string_view prefix(kBearerPrefix, sizeof(kBearerPrefix) - 1); + if (metadata_value.size() <= prefix.size() || + metadata_value.substr(0, prefix.size()) != prefix) { + return AuthenticationToken(); // Invalid format, return empty + } + std::string_view token_part = metadata_value.substr(prefix.size()); + return AuthenticationToken(std::string(token_part)); + } + + friend std::ostream &operator<<(std::ostream &os, const AuthenticationToken &t) { + return os << ""; + } + + private: + std::vector secret_; + + // Constant-time string comparison to avoid timing attacks. + // https://en.wikipedia.org/wiki/Timing_attack + static bool ConstTimeEqual(const std::vector &a, + const std::vector &b) noexcept { + if (a.size() != b.size()) { + return false; + } + unsigned char diff = 0; + for (size_t i = 0; i < a.size(); ++i) { + diff |= a[i] ^ b[i]; + } + return diff == 0; + } + + // replace the characters in the memory with 0 + static void ExplicitBurn(void *p, size_t n) noexcept { +#if defined(_MSC_VER) + SecureZeroMemory(p, n); +#elif defined(__STDC_LIB_EXT1__) + memset_s(p, n, 0, n); +#else + // Using array indexing instead of pointer arithmetic + volatile auto *vp = static_cast(p); + for (size_t i = 0; i < n; ++i) { + vp[i] = 0; + } +#endif + } + + void SecureClear() noexcept { + if (!secret_.empty()) { + ExplicitBurn(secret_.data(), secret_.size()); + secret_.clear(); + } + } + + void MoveFrom(AuthenticationToken &&other) noexcept { + SecureClear(); + secret_ = std::move(other.secret_); + // Clear the moved-from object explicitly for security + // Note: 'other' is already an rvalue reference, no need to move again + other.SecureClear(); + } +}; + +} // namespace rpc +} // namespace ray diff --git a/src/ray/rpc/authentication/authentication_token_loader.cc b/src/ray/rpc/authentication/authentication_token_loader.cc new file mode 100644 index 000000000000..59a1184e080a --- /dev/null +++ b/src/ray/rpc/authentication/authentication_token_loader.cc @@ -0,0 +1,167 @@ +// Copyright 2025 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ray/rpc/authentication/authentication_token_loader.h" + +#include +#include +#include + +#include "ray/util/logging.h" + +#if defined(__APPLE__) || defined(__linux__) +#include +#include +#endif + +#ifdef _WIN32 +#ifndef _WINDOWS_ +#ifndef WIN32_LEAN_AND_MEAN // Sorry for the inconvenience. Please include any related + // headers you need manually. + // (https://stackoverflow.com/a/8294669) +#define WIN32_LEAN_AND_MEAN // Prevent inclusion of WinSock2.h +#endif +#include // Force inclusion of WinGDI here to resolve name conflict +#endif +#endif + +namespace ray { +namespace rpc { + +AuthenticationTokenLoader &AuthenticationTokenLoader::instance() { + static AuthenticationTokenLoader instance; + return instance; +} + +std::optional AuthenticationTokenLoader::GetToken() { + std::lock_guard lock(token_mutex_); + + // If already loaded, return cached value + if (cached_token_.has_value()) { + return cached_token_; + } + + // If token auth is not enabled, return std::nullopt + if (GetAuthenticationMode() != AuthenticationMode::TOKEN) { + cached_token_ = std::nullopt; + return std::nullopt; + } + + // Token auth is enabled, try to load from sources + AuthenticationToken token = LoadTokenFromSources(); + + // If no token found and auth is enabled, fail with RAY_CHECK + RAY_CHECK(!token.empty()) + << "Token authentication is enabled but no authentication token was found. " + << "Please set RAY_AUTH_TOKEN environment variable, RAY_AUTH_TOKEN_PATH to a file " + << "containing the token, or create a token file at ~/.ray/auth_token"; + + // Cache and return the loaded token + cached_token_ = std::move(token); + return *cached_token_; +} + +// Read token from the first line of the file. trim whitespace. +// Returns empty string if file cannot be opened or is empty. +std::string AuthenticationTokenLoader::ReadTokenFromFile(const std::string &file_path) { + std::ifstream token_file(file_path); + if (!token_file.is_open()) { + return ""; + } + + std::string token; + std::getline(token_file, token); + token_file.close(); + return token; +} + +AuthenticationToken AuthenticationTokenLoader::LoadTokenFromSources() { + // Precedence 1: RAY_AUTH_TOKEN environment variable + const char *env_token = std::getenv("RAY_AUTH_TOKEN"); + if (env_token != nullptr && std::string(env_token).length() > 0) { + RAY_LOG(DEBUG) << "Loaded authentication token from RAY_AUTH_TOKEN environment " + "variable"; + return AuthenticationToken(TrimWhitespace(std::string(env_token))); + } + + // Precedence 2: RAY_AUTH_TOKEN_PATH environment variable + const char *env_token_path = std::getenv("RAY_AUTH_TOKEN_PATH"); + if (env_token_path != nullptr && std::string(env_token_path).length() > 0) { + std::string token_str = TrimWhitespace(ReadTokenFromFile(env_token_path)); + if (!token_str.empty()) { + RAY_LOG(DEBUG) << "Loaded authentication token from file: " << env_token_path; + return AuthenticationToken(token_str); + } else { + RAY_LOG(WARNING) << "RAY_AUTH_TOKEN_PATH is set but file cannot be opened: " + << env_token_path; + } + } + + // Precedence 3: Default token path ~/.ray/auth_token + std::string default_path = GetDefaultTokenPath(); + std::string token_str = TrimWhitespace(ReadTokenFromFile(default_path)); + if (!token_str.empty()) { + RAY_LOG(DEBUG) << "Loaded authentication token from default path: " << default_path; + return AuthenticationToken(token_str); + } + + // No token found + RAY_LOG(DEBUG) << "No authentication token found in any source"; + return AuthenticationToken(); +} + +std::string AuthenticationTokenLoader::GetDefaultTokenPath() { + std::string home_dir; + +#ifdef _WIN32 + const char *path_separator = "\\"; + const char *userprofile = std::getenv("USERPROFILE"); + if (userprofile != nullptr) { + home_dir = userprofile; + } else { + const char *homedrive = std::getenv("HOMEDRIVE"); + const char *homepath = std::getenv("HOMEPATH"); + if (homedrive != nullptr && homepath != nullptr) { + home_dir = std::string(homedrive) + std::string(homepath); + } + } +#else + const char *path_separator = "/"; + const char *home = std::getenv("HOME"); + if (home != nullptr) { + home_dir = home; + } +#endif + + const std::string token_subpath = + std::string(path_separator) + ".ray" + std::string(path_separator) + "auth_token"; + + if (home_dir.empty()) { + RAY_LOG(WARNING) << "Cannot determine home directory for token storage"; + return "." + token_subpath; + } + + return home_dir + token_subpath; +} + +std::string AuthenticationTokenLoader::TrimWhitespace(const std::string &str) { + std::string whitespace = " \t\n\r\f\v"; + std::string trimmed_str = str; + trimmed_str.erase(0, trimmed_str.find_first_not_of(whitespace)); + trimmed_str.erase(trimmed_str.find_last_not_of(whitespace) + 1); + return trimmed_str; +} + +} // namespace rpc +} // namespace ray diff --git a/src/ray/rpc/authentication/authentication_token_loader.h b/src/ray/rpc/authentication/authentication_token_loader.h new file mode 100644 index 000000000000..4034ecbc78dd --- /dev/null +++ b/src/ray/rpc/authentication/authentication_token_loader.h @@ -0,0 +1,72 @@ +// Copyright 2025 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions +// and limitations under the License. + +#pragma once + +#include +#include +#include + +#include "ray/rpc/authentication/authentication_mode.h" +#include "ray/rpc/authentication/authentication_token.h" + +namespace ray { +namespace rpc { + +/// Singleton class for loading and caching authentication tokens. +/// Supports loading tokens from multiple sources with precedence: +/// 1. RAY_AUTH_TOKEN environment variable +/// 2. RAY_AUTH_TOKEN_PATH environment variable (path to token file) +/// 3. Default token path: ~/.ray/auth_token (Unix) or %USERPROFILE%\.ray\auth_token +/// +/// Thread-safe with internal caching to avoid repeated file I/O. +class AuthenticationTokenLoader { + public: + static AuthenticationTokenLoader &instance(); + + /// Get the authentication token. + /// If token authentication is enabled but no token is found, fails with RAY_CHECK. + /// \return The authentication token, or std::nullopt if auth is disabled. + std::optional GetToken(); + + void ResetCache() { + std::lock_guard lock(token_mutex_); + cached_token_.reset(); + } + + AuthenticationTokenLoader(const AuthenticationTokenLoader &) = delete; + AuthenticationTokenLoader &operator=(const AuthenticationTokenLoader &) = delete; + + private: + AuthenticationTokenLoader() = default; + ~AuthenticationTokenLoader() = default; + + /// Read and trim token from file. + std::string ReadTokenFromFile(const std::string &file_path); + + /// Load token from environment or file. + AuthenticationToken LoadTokenFromSources(); + + /// Default token file path (~/.ray/auth_token or %USERPROFILE%\.ray\auth_token). + std::string GetDefaultTokenPath(); + + /// Trim whitespace from the beginning and end of the string. + std::string TrimWhitespace(const std::string &str); + + std::mutex token_mutex_; + std::optional cached_token_; +}; + +} // namespace rpc +} // namespace ray diff --git a/src/ray/rpc/tests/BUILD.bazel b/src/ray/rpc/tests/BUILD.bazel index 5fa8b14cc4db..279b68f91ba3 100644 --- a/src/ray/rpc/tests/BUILD.bazel +++ b/src/ray/rpc/tests/BUILD.bazel @@ -18,6 +18,23 @@ ray_cc_test( size = "small", srcs = [ "grpc_server_client_test.cc", + "grpc_test_common.h", + ], + tags = ["team:core"], + deps = [ + "//src/ray/protobuf:test_service_cc_grpc", + "//src/ray/rpc:grpc_client", + "//src/ray/rpc:grpc_server", + "@com_google_googletest//:gtest_main", + ], +) + +ray_cc_test( + name = "grpc_auth_token_tests", + size = "small", + srcs = [ + "grpc_auth_token_tests.cc", + "grpc_test_common.h", ], tags = ["team:core"], deps = [ @@ -40,3 +57,30 @@ ray_cc_test( "@com_google_googletest//:gtest_main", ], ) + +ray_cc_test( + name = "authentication_token_loader_test", + size = "small", + srcs = [ + "authentication_token_loader_test.cc", + ], + tags = ["team:core"], + deps = [ + "//src/ray/common:ray_config", + "//src/ray/rpc/authentication:authentication_token_loader", + "@com_google_googletest//:gtest_main", + ], +) + +ray_cc_test( + name = "authentication_token_test", + size = "small", + srcs = [ + "authentication_token_test.cc", + ], + tags = ["team:core"], + deps = [ + "//src/ray/rpc/authentication:authentication_token", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/src/ray/rpc/tests/authentication_token_loader_test.cc b/src/ray/rpc/tests/authentication_token_loader_test.cc new file mode 100644 index 000000000000..616a13b0e457 --- /dev/null +++ b/src/ray/rpc/tests/authentication_token_loader_test.cc @@ -0,0 +1,346 @@ +// Copyright 2025 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ray/rpc/authentication/authentication_token_loader.h" + +#include +#include + +#include "gtest/gtest.h" +#include "ray/common/ray_config.h" + +#if defined(__APPLE__) || defined(__linux__) +#include +#include +#endif + +#ifdef _WIN32 +#ifndef _WINDOWS_ +#ifndef WIN32_LEAN_AND_MEAN // Sorry for the inconvenience. Please include any related + // headers you need manually. + // (https://stackoverflow.com/a/8294669) +#define WIN32_LEAN_AND_MEAN // Prevent inclusion of WinSock2.h +#endif +#include // Force inclusion of WinGDI here to resolve name conflict +#endif +#include // For _mkdir on Windows +#include // For _getpid on Windows +#endif + +namespace ray { +namespace rpc { + +class AuthenticationTokenLoaderTest : public ::testing::Test { + protected: + void SetUp() override { + // Enable token authentication for tests + RayConfig::instance().initialize(R"({"auth_mode": "token"})"); + + // If HOME is not set (e.g., in Bazel sandbox), set it to a test directory + // This ensures tests work in environments where HOME isn't provided +#ifdef _WIN32 + if (std::getenv("USERPROFILE") == nullptr) { + const char *test_tmpdir = std::getenv("TEST_TMPDIR"); + if (test_tmpdir != nullptr) { + test_home_dir_ = std::string(test_tmpdir) + "\\ray_test_home"; + } else { + test_home_dir_ = "C:\\Windows\\Temp\\ray_test_home"; + } + _putenv(("USERPROFILE=" + test_home_dir_).c_str()); + } + const char *home_dir = std::getenv("USERPROFILE"); + default_token_path_ = std::string(home_dir) + "\\.ray\\auth_token"; +#else + if (std::getenv("HOME") == nullptr) { + const char *test_tmpdir = std::getenv("TEST_TMPDIR"); + if (test_tmpdir != nullptr) { + test_home_dir_ = std::string(test_tmpdir) + "/ray_test_home"; + } else { + test_home_dir_ = "/tmp/ray_test_home"; + } + setenv("HOME", test_home_dir_.c_str(), 1); + } + const char *home_dir = std::getenv("HOME"); + if (home_dir != nullptr) { + default_token_path_ = std::string(home_dir) + "/.ray/auth_token"; + test_home_dir_ = home_dir; + } else { + default_token_path_ = ".ray/auth_token"; + } +#endif + cleanup_env(); + // Reset the singleton's cached state for test isolation + AuthenticationTokenLoader::instance().ResetCache(); + } + + void TearDown() override { + // Clean up after test + cleanup_env(); + // Reset the singleton's cached state for test isolation + AuthenticationTokenLoader::instance().ResetCache(); + // Disable token auth after tests + RayConfig::instance().initialize(R"({"auth_mode": "disabled"})"); + } + + void cleanup_env() { + unset_env_var("RAY_AUTH_TOKEN"); + unset_env_var("RAY_AUTH_TOKEN_PATH"); + remove(default_token_path_.c_str()); + } + + std::string get_temp_token_path() { +#ifdef _WIN32 + return "C:\\Windows\\Temp\\ray_test_token_" + std::to_string(_getpid()); +#else + return "/tmp/ray_test_token_" + std::to_string(getpid()); +#endif + } + + void set_env_var(const char *name, const char *value) { +#ifdef _WIN32 + std::string env_str = std::string(name) + "=" + std::string(value); + _putenv(env_str.c_str()); +#else + setenv(name, value, 1); +#endif + } + + void unset_env_var(const char *name) { +#ifdef _WIN32 + std::string env_str = std::string(name) + "="; + _putenv(env_str.c_str()); +#else + unsetenv(name); +#endif + } + + void ensure_ray_dir_exists() { +#ifdef _WIN32 + const char *home_dir = std::getenv("USERPROFILE"); + _mkdir(home_dir); // Create parent directory + std::string ray_dir = std::string(home_dir) + "\\.ray"; + _mkdir(ray_dir.c_str()); +#else + // Always ensure the home directory exists (it might be a test temp dir we created) + if (!test_home_dir_.empty()) { + mkdir(test_home_dir_.c_str(), + 0700); // Create if it doesn't exist (ignore error if it does) + } + + const char *home_dir = std::getenv("HOME"); + if (home_dir != nullptr) { + std::string ray_dir = std::string(home_dir) + "/.ray"; + mkdir(ray_dir.c_str(), 0700); + } +#endif + } + + void write_token_file(const std::string &path, const std::string &content) { + std::ofstream token_file(path); + token_file << content; + token_file.close(); + } + + std::string default_token_path_; + std::string test_home_dir_; // Fallback home directory for tests +}; + +TEST_F(AuthenticationTokenLoaderTest, TestLoadFromEnvVariable) { + // Set token in environment variable + set_env_var("RAY_AUTH_TOKEN", "test-token-from-env"); + + // Create a new instance to avoid cached state + auto &loader = AuthenticationTokenLoader::instance(); + auto token_opt = loader.GetToken(); + + ASSERT_TRUE(token_opt.has_value()); + AuthenticationToken expected("test-token-from-env"); + EXPECT_TRUE(token_opt->Equals(expected)); + EXPECT_TRUE(loader.GetToken().has_value()); +} + +TEST_F(AuthenticationTokenLoaderTest, TestLoadFromEnvPath) { + // Create a temporary token file + std::string temp_token_path = get_temp_token_path(); + write_token_file(temp_token_path, "test-token-from-file"); + + // Set path in environment variable + set_env_var("RAY_AUTH_TOKEN_PATH", temp_token_path.c_str()); + + auto &loader = AuthenticationTokenLoader::instance(); + auto token_opt = loader.GetToken(); + + ASSERT_TRUE(token_opt.has_value()); + AuthenticationToken expected("test-token-from-file"); + EXPECT_TRUE(token_opt->Equals(expected)); + EXPECT_TRUE(loader.GetToken().has_value()); + + // Clean up + remove(temp_token_path.c_str()); +} + +TEST_F(AuthenticationTokenLoaderTest, TestLoadFromDefaultPath) { + // Create directory and token file in default location + ensure_ray_dir_exists(); + write_token_file(default_token_path_, "test-token-from-default"); + + auto &loader = AuthenticationTokenLoader::instance(); + auto token_opt = loader.GetToken(); + + ASSERT_TRUE(token_opt.has_value()); + AuthenticationToken expected("test-token-from-default"); + EXPECT_TRUE(token_opt->Equals(expected)); + EXPECT_TRUE(loader.GetToken().has_value()); +} + +// Parametrized test for token loading precedence: env var > user-specified file > default +// file + +struct TokenSourceConfig { + bool set_env = false; + bool set_file = false; + bool set_default = false; + std::string expected_token; + std::string env_token = "token-from-env"; + std::string file_token = "token-from-path"; + std::string default_token = "token-from-default"; +}; + +class AuthenticationTokenLoaderPrecedenceTest + : public AuthenticationTokenLoaderTest, + public ::testing::WithParamInterface {}; + +INSTANTIATE_TEST_SUITE_P(TokenPrecedenceCases, + AuthenticationTokenLoaderPrecedenceTest, + ::testing::Values( + // All set: env should win + TokenSourceConfig{true, true, true, "token-from-env"}, + // File and default file set: file should win + TokenSourceConfig{false, true, true, "token-from-path"}, + // Only default file set + TokenSourceConfig{ + false, false, true, "token-from-default"})); + +TEST_P(AuthenticationTokenLoaderPrecedenceTest, Precedence) { + const auto ¶m = GetParam(); + + // Optionally set environment variable + if (param.set_env) { + set_env_var("RAY_AUTH_TOKEN", param.env_token.c_str()); + } else { + unset_env_var("RAY_AUTH_TOKEN"); + } + + // Optionally create file and set path + std::string temp_token_path = get_temp_token_path(); + if (param.set_file) { + write_token_file(temp_token_path, param.file_token); + set_env_var("RAY_AUTH_TOKEN_PATH", temp_token_path.c_str()); + } else { + unset_env_var("RAY_AUTH_TOKEN_PATH"); + } + + // Optionally create default file + ensure_ray_dir_exists(); + if (param.set_default) { + write_token_file(default_token_path_, param.default_token); + } else { + remove(default_token_path_.c_str()); + } + + // Always create a new instance to avoid cached state + auto &loader = AuthenticationTokenLoader::instance(); + auto token_opt = loader.GetToken(); + + ASSERT_TRUE(token_opt.has_value()); + AuthenticationToken expected(param.expected_token); + EXPECT_TRUE(token_opt->Equals(expected)); + + // Clean up token file if it was written + if (param.set_file) { + remove(temp_token_path.c_str()); + } + // Clean up default file if it was written + if (param.set_default) { + remove(default_token_path_.c_str()); + } +} + +TEST_F(AuthenticationTokenLoaderTest, TestNoTokenFoundWhenAuthDisabled) { + // Disable auth for this specific test + RayConfig::instance().initialize(R"({"auth_mode": "disabled"})"); + AuthenticationTokenLoader::instance().ResetCache(); + + // No token set anywhere, but auth is disabled + auto &loader = AuthenticationTokenLoader::instance(); + auto token_opt = loader.GetToken(); + + EXPECT_FALSE(token_opt.has_value()); + EXPECT_FALSE(loader.GetToken().has_value()); + + // Re-enable for other tests + RayConfig::instance().initialize(R"({"auth_mode": "token"})"); +} + +TEST_F(AuthenticationTokenLoaderTest, TestErrorWhenAuthEnabledButNoToken) { + // Token auth is already enabled in SetUp() + // No token exists, should trigger RAY_CHECK failure + EXPECT_DEATH( + { + auto &loader = AuthenticationTokenLoader::instance(); + loader.GetToken(); + }, + "Token authentication is enabled but no authentication token was found"); +} + +TEST_F(AuthenticationTokenLoaderTest, TestCaching) { + // Set token in environment + set_env_var("RAY_AUTH_TOKEN", "cached-token"); + + auto &loader = AuthenticationTokenLoader::instance(); + auto token_opt1 = loader.GetToken(); + + // Change environment variable (shouldn't affect cached value) + set_env_var("RAY_AUTH_TOKEN", "new-token"); + auto token_opt2 = loader.GetToken(); + + // Should still return the cached token + ASSERT_TRUE(token_opt1.has_value()); + ASSERT_TRUE(token_opt2.has_value()); + EXPECT_TRUE(token_opt1->Equals(*token_opt2)); + AuthenticationToken expected("cached-token"); + EXPECT_TRUE(token_opt2->Equals(expected)); +} + +TEST_F(AuthenticationTokenLoaderTest, TestWhitespaceHandling) { + // Create token file with whitespace + ensure_ray_dir_exists(); + write_token_file(default_token_path_, " token-with-spaces \n\t"); + + auto &loader = AuthenticationTokenLoader::instance(); + auto token_opt = loader.GetToken(); + + // Whitespace should be trimmed + ASSERT_TRUE(token_opt.has_value()); + AuthenticationToken expected("token-with-spaces"); + EXPECT_TRUE(token_opt->Equals(expected)); +} + +} // namespace rpc +} // namespace ray + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/src/ray/rpc/tests/authentication_token_test.cc b/src/ray/rpc/tests/authentication_token_test.cc new file mode 100644 index 000000000000..db88d7481da1 --- /dev/null +++ b/src/ray/rpc/tests/authentication_token_test.cc @@ -0,0 +1,131 @@ +// Copyright 2025 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ray/rpc/authentication/authentication_token.h" + +#include +#include +#include + +#include "gtest/gtest.h" + +namespace ray { +namespace rpc { + +class AuthenticationTokenTest : public ::testing::Test {}; + +TEST_F(AuthenticationTokenTest, TestDefaultConstructor) { + AuthenticationToken token; + EXPECT_TRUE(token.empty()); +} + +TEST_F(AuthenticationTokenTest, TestConstructorWithValue) { + AuthenticationToken token("test-token-value"); + EXPECT_FALSE(token.empty()); + AuthenticationToken expected("test-token-value"); + EXPECT_TRUE(token.Equals(expected)); +} + +TEST_F(AuthenticationTokenTest, TestMoveConstructor) { + AuthenticationToken token1("original-token"); + AuthenticationToken token2(std::move(token1)); + + EXPECT_FALSE(token2.empty()); + AuthenticationToken expected("original-token"); + EXPECT_TRUE(token2.Equals(expected)); + EXPECT_TRUE(token1.empty()); +} + +TEST_F(AuthenticationTokenTest, TestMoveAssignment) { + AuthenticationToken token1("first-token"); + AuthenticationToken token2("second-token"); + + token2 = std::move(token1); + + EXPECT_FALSE(token2.empty()); + AuthenticationToken expected("first-token"); + EXPECT_TRUE(token2.Equals(expected)); + EXPECT_TRUE(token1.empty()); +} + +TEST_F(AuthenticationTokenTest, TestSelfMoveAssignment) { + AuthenticationToken token("test-token"); + + // Self-assignment should not break the token + token = std::move(token); + + EXPECT_FALSE(token.empty()); + AuthenticationToken expected("test-token"); + EXPECT_TRUE(token.Equals(expected)); +} + +TEST_F(AuthenticationTokenTest, TestEquals) { + AuthenticationToken token1("same-token"); + AuthenticationToken token2("same-token"); + AuthenticationToken token3("different-token"); + + EXPECT_TRUE(token1.Equals(token2)); + EXPECT_FALSE(token1.Equals(token3)); + EXPECT_TRUE(token1 == token2); + EXPECT_FALSE(token1 == token3); + EXPECT_FALSE(token1 != token2); + EXPECT_TRUE(token1 != token3); +} + +TEST_F(AuthenticationTokenTest, TestEqualityDifferentLengths) { + AuthenticationToken token1("short"); + AuthenticationToken token2("much-longer-token"); + + EXPECT_FALSE(token1.Equals(token2)); +} + +TEST_F(AuthenticationTokenTest, TestEqualityEmptyTokens) { + AuthenticationToken token1; + AuthenticationToken token2; + + EXPECT_TRUE(token1.Equals(token2)); +} + +TEST_F(AuthenticationTokenTest, TestEqualityEmptyVsNonEmpty) { + AuthenticationToken token1; + AuthenticationToken token2("non-empty"); + + EXPECT_FALSE(token1.Equals(token2)); + EXPECT_FALSE(token2.Equals(token1)); +} + +TEST_F(AuthenticationTokenTest, TestRedactedOutput) { + AuthenticationToken token("super-secret-token"); + + std::ostringstream oss; + oss << token; + + std::string output = oss.str(); + EXPECT_EQ(output, ""); + EXPECT_EQ(output.find("super-secret-token"), std::string::npos); +} + +TEST_F(AuthenticationTokenTest, TestEmptyString) { + AuthenticationToken token(""); + EXPECT_TRUE(token.empty()); + AuthenticationToken expected(""); + EXPECT_TRUE(token.Equals(expected)); +} +} // namespace rpc +} // namespace ray + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} From 341b108707c0e6f0abbcd5db7a5e56390bb55937 Mon Sep 17 00:00:00 2001 From: sampan Date: Thu, 23 Oct 2025 16:20:13 +0000 Subject: [PATCH 02/48] Add gRPC service and server logic with auth integration tests Signed-off-by: sampan --- src/ray/common/grpc_util.h | 4 + src/ray/common/status.cc | 2 +- src/ray/common/status.h | 8 +- src/ray/core_worker/BUILD.bazel | 1 + src/ray/core_worker/grpc_service.cc | 76 +++--- src/ray/core_worker/grpc_service.h | 6 +- src/ray/gcs/BUILD.bazel | 1 + src/ray/gcs/grpc_services.cc | 38 +-- src/ray/gcs/grpc_services.h | 38 ++- .../gcs_rpc_client/tests/gcs_client_test.cc | 4 +- src/ray/raylet/node_manager.cc | 2 +- src/ray/rpc/BUILD.bazel | 5 + src/ray/rpc/client_call.h | 15 +- src/ray/rpc/grpc_server.cc | 12 +- src/ray/rpc/grpc_server.h | 61 +++-- .../rpc/node_manager/node_manager_server.h | 9 +- src/ray/rpc/object_manager_server.h | 9 +- src/ray/rpc/server_call.h | 105 +++++++-- src/ray/rpc/tests/grpc_auth_token_tests.cc | 221 ++++++++++++++++++ src/ray/rpc/tests/grpc_bench/BUILD.bazel | 1 + src/ray/rpc/tests/grpc_bench/grpc_bench.cc | 10 +- src/ray/rpc/tests/grpc_server_client_test.cc | 83 +------ src/ray/rpc/tests/grpc_test_common.h | 109 +++++++++ 23 files changed, 623 insertions(+), 197 deletions(-) create mode 100644 src/ray/rpc/tests/grpc_auth_token_tests.cc create mode 100644 src/ray/rpc/tests/grpc_test_common.h diff --git a/src/ray/common/grpc_util.h b/src/ray/common/grpc_util.h index ae99eaf79081..ed2f8c73eda1 100644 --- a/src/ray/common/grpc_util.h +++ b/src/ray/common/grpc_util.h @@ -83,6 +83,10 @@ inline grpc::Status RayStatusToGrpcStatus(const Status &ray_status) { if (ray_status.ok()) { return grpc::Status::OK; } + // Map Unauthenticated to gRPC's UNAUTHENTICATED status code + if (ray_status.IsUnauthenticated()) { + return grpc::Status(grpc::StatusCode::UNAUTHENTICATED, ray_status.message()); + } // Unlike `UNKNOWN`, `ABORTED` is never generated by the library, so using it means // more robust. return grpc::Status( diff --git a/src/ray/common/status.cc b/src/ray/common/status.cc index 3500ddaf3b80..528a6766412e 100644 --- a/src/ray/common/status.cc +++ b/src/ray/common/status.cc @@ -74,7 +74,7 @@ const absl::flat_hash_map kCodeToStr = { {StatusCode::RpcError, "RpcError"}, {StatusCode::OutOfResource, "OutOfResource"}, {StatusCode::ObjectRefEndOfStream, "ObjectRefEndOfStream"}, - {StatusCode::AuthError, "AuthError"}, + {StatusCode::Unauthenticated, "Unauthenticated"}, {StatusCode::InvalidArgument, "InvalidArgument"}, {StatusCode::ChannelError, "ChannelError"}, {StatusCode::ChannelTimeoutError, "ChannelTimeoutError"}, diff --git a/src/ray/common/status.h b/src/ray/common/status.h index 2544918ac263..f04040cea934 100644 --- a/src/ray/common/status.h +++ b/src/ray/common/status.h @@ -263,7 +263,7 @@ enum class StatusCode : char { RpcError = 30, OutOfResource = 31, ObjectRefEndOfStream = 32, - AuthError = 33, + Unauthenticated = 33, // Indicates the input value is not valid. InvalidArgument = 34, // Indicates that a channel (a mutable plasma object) is closed and cannot be @@ -415,8 +415,8 @@ class RAY_EXPORT Status { return Status(StatusCode::OutOfResource, msg); } - static Status AuthError(const std::string &msg) { - return Status(StatusCode::AuthError, msg); + static Status Unauthenticated(const std::string &msg) { + return Status(StatusCode::Unauthenticated, msg); } static Status ChannelError(const std::string &msg) { @@ -475,7 +475,7 @@ class RAY_EXPORT Status { bool IsOutOfResource() const { return code() == StatusCode::OutOfResource; } - bool IsAuthError() const { return code() == StatusCode::AuthError; } + bool IsUnauthenticated() const { return code() == StatusCode::Unauthenticated; } bool IsChannelError() const { return code() == StatusCode::ChannelError; } diff --git a/src/ray/core_worker/BUILD.bazel b/src/ray/core_worker/BUILD.bazel index a92f4f4323b0..87fee53bb63d 100644 --- a/src/ray/core_worker/BUILD.bazel +++ b/src/ray/core_worker/BUILD.bazel @@ -78,6 +78,7 @@ ray_cc_library( "//src/ray/protobuf:core_worker_cc_proto", "//src/ray/rpc:grpc_server", "//src/ray/rpc:rpc_callback_types", + "//src/ray/rpc/authentication:authentication_token", ], ) diff --git a/src/ray/core_worker/grpc_service.cc b/src/ray/core_worker/grpc_service.cc index e5540aa502df..adb5b62786d4 100644 --- a/src/ray/core_worker/grpc_service.cc +++ b/src/ray/core_worker/grpc_service.cc @@ -15,6 +15,7 @@ #include "ray/core_worker/grpc_service.h" #include +#include #include namespace ray { @@ -23,91 +24,104 @@ namespace rpc { void CoreWorkerGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) { + const ClusterID &cluster_id, + const std::optional &auth_token) { /// TODO(vitsai): Remove this when auth is implemented for node manager. /// Disable gRPC server metrics since it incurs too high cardinality. - RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED( - CoreWorkerService, PushTask, max_active_rpcs_per_handler_, AuthType::NO_AUTH); + RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, + PushTask, + max_active_rpcs_per_handler_, + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, ActorCallArgWaitComplete, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, RayletNotifyGCSRestart, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, GetObjectStatus, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, WaitForActorRefDeleted, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, PubsubLongPolling, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, PubsubCommandBatch, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, UpdateObjectLocationBatch, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, GetObjectLocationsOwner, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, ReportGeneratorItemReturns, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); - RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED( - CoreWorkerService, KillActor, max_active_rpcs_per_handler_, AuthType::NO_AUTH); - RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED( - CoreWorkerService, CancelTask, max_active_rpcs_per_handler_, AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); + RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, + KillActor, + max_active_rpcs_per_handler_, + ClusterIdAuthType::NO_AUTH); + RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, + CancelTask, + max_active_rpcs_per_handler_, + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, CancelRemoteTask, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, RegisterMutableObjectReader, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, GetCoreWorkerStats, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); - RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED( - CoreWorkerService, LocalGC, max_active_rpcs_per_handler_, AuthType::NO_AUTH); - RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED( - CoreWorkerService, DeleteObjects, max_active_rpcs_per_handler_, AuthType::NO_AUTH); - RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED( - CoreWorkerService, SpillObjects, max_active_rpcs_per_handler_, AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); + RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, + LocalGC, + max_active_rpcs_per_handler_, + ClusterIdAuthType::NO_AUTH); + RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, + DeleteObjects, + max_active_rpcs_per_handler_, + ClusterIdAuthType::NO_AUTH); + RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, + SpillObjects, + max_active_rpcs_per_handler_, + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, RestoreSpilledObjects, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, DeleteSpilledObjects, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, PlasmaObjectReady, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED( - CoreWorkerService, Exit, max_active_rpcs_per_handler_, AuthType::NO_AUTH); + CoreWorkerService, Exit, max_active_rpcs_per_handler_, ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, AssignObjectOwner, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService, NumPendingTasks, max_active_rpcs_per_handler_, - AuthType::NO_AUTH); + ClusterIdAuthType::NO_AUTH); } } // namespace rpc diff --git a/src/ray/core_worker/grpc_service.h b/src/ray/core_worker/grpc_service.h index 4559a45447c1..d605f5176533 100644 --- a/src/ray/core_worker/grpc_service.h +++ b/src/ray/core_worker/grpc_service.h @@ -29,9 +29,12 @@ #pragma once #include +#include +#include #include #include "ray/common/asio/instrumented_io_context.h" +#include "ray/rpc/authentication/authentication_token.h" #include "ray/rpc/grpc_server.h" #include "ray/rpc/rpc_callback_types.h" #include "src/ray/protobuf/core_worker.grpc.pb.h" @@ -158,7 +161,8 @@ class CoreWorkerGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) override; + const ClusterID &cluster_id, + const std::optional &auth_token) override; private: CoreWorkerService::AsyncService service_; diff --git a/src/ray/gcs/BUILD.bazel b/src/ray/gcs/BUILD.bazel index 2511da321245..b764ad410db6 100644 --- a/src/ray/gcs/BUILD.bazel +++ b/src/ray/gcs/BUILD.bazel @@ -353,6 +353,7 @@ ray_cc_library( "//src/ray/protobuf:gcs_service_cc_grpc", "//src/ray/rpc:grpc_server", "//src/ray/rpc:rpc_callback_types", + "//src/ray/rpc/authentication:authentication_token", "@com_github_grpc_grpc//:grpc++", ], ) diff --git a/src/ray/gcs/grpc_services.cc b/src/ray/gcs/grpc_services.cc index f1f3c55af3f1..66b4397782c2 100644 --- a/src/ray/gcs/grpc_services.cc +++ b/src/ray/gcs/grpc_services.cc @@ -22,7 +22,8 @@ namespace rpc { void ActorInfoGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) { + const ClusterID &cluster_id, + const std::optional &auth_token) { /// The register & create actor RPCs take a long time, so we shouldn't limit their /// concurrency to avoid distributed deadlock. RPC_SERVICE_HANDLER(ActorInfoGcsService, RegisterActor, -1) @@ -42,13 +43,14 @@ void ActorInfoGrpcService::InitServerCallFactories( void NodeInfoGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) { + const ClusterID &cluster_id, + const std::optional &auth_token) { // We only allow one cluster ID in the lifetime of a client. // So, if a client connects, it should not have a pre-existing different ID. RPC_SERVICE_HANDLER_CUSTOM_AUTH(NodeInfoGcsService, GetClusterId, max_active_rpcs_per_handler_, - AuthType::EMPTY_AUTH); + ClusterIdAuthType::EMPTY_AUTH); RPC_SERVICE_HANDLER(NodeInfoGcsService, RegisterNode, max_active_rpcs_per_handler_) RPC_SERVICE_HANDLER(NodeInfoGcsService, UnregisterNode, max_active_rpcs_per_handler_) RPC_SERVICE_HANDLER(NodeInfoGcsService, DrainNode, max_active_rpcs_per_handler_) @@ -61,7 +63,8 @@ void NodeInfoGrpcService::InitServerCallFactories( void NodeResourceInfoGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) { + const ClusterID &cluster_id, + const std::optional &auth_token) { RPC_SERVICE_HANDLER( NodeResourceInfoGcsService, GetAllAvailableResources, max_active_rpcs_per_handler_) RPC_SERVICE_HANDLER( @@ -75,7 +78,8 @@ void NodeResourceInfoGrpcService::InitServerCallFactories( void InternalPubSubGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) { + const ClusterID &cluster_id, + const std::optional &auth_token) { RPC_SERVICE_HANDLER(InternalPubSubGcsService, GcsPublish, max_active_rpcs_per_handler_); RPC_SERVICE_HANDLER( InternalPubSubGcsService, GcsSubscriberPoll, max_active_rpcs_per_handler_); @@ -86,7 +90,8 @@ void InternalPubSubGrpcService::InitServerCallFactories( void JobInfoGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) { + const ClusterID &cluster_id, + const std::optional &auth_token) { RPC_SERVICE_HANDLER(JobInfoGcsService, AddJob, max_active_rpcs_per_handler_) RPC_SERVICE_HANDLER(JobInfoGcsService, MarkJobFinished, max_active_rpcs_per_handler_) RPC_SERVICE_HANDLER(JobInfoGcsService, GetAllJobInfo, max_active_rpcs_per_handler_) @@ -97,7 +102,8 @@ void JobInfoGrpcService::InitServerCallFactories( void RuntimeEnvGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) { + const ClusterID &cluster_id, + const std::optional &auth_token) { RPC_SERVICE_HANDLER( RuntimeEnvGcsService, PinRuntimeEnvURI, max_active_rpcs_per_handler_) } @@ -105,7 +111,8 @@ void RuntimeEnvGrpcService::InitServerCallFactories( void WorkerInfoGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) { + const ClusterID &cluster_id, + const std::optional &auth_token) { RPC_SERVICE_HANDLER( WorkerInfoGcsService, ReportWorkerFailure, max_active_rpcs_per_handler_) RPC_SERVICE_HANDLER(WorkerInfoGcsService, GetWorkerInfo, max_active_rpcs_per_handler_) @@ -121,7 +128,8 @@ void WorkerInfoGrpcService::InitServerCallFactories( void InternalKVGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) { + const ClusterID &cluster_id, + const std::optional &auth_token) { RPC_SERVICE_HANDLER(InternalKVGcsService, InternalKVGet, max_active_rpcs_per_handler_) RPC_SERVICE_HANDLER( InternalKVGcsService, InternalKVMultiGet, max_active_rpcs_per_handler_) @@ -137,7 +145,8 @@ void InternalKVGrpcService::InitServerCallFactories( void TaskInfoGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) { + const ClusterID &cluster_id, + const std::optional &auth_token) { RPC_SERVICE_HANDLER(TaskInfoGcsService, AddTaskEventData, max_active_rpcs_per_handler_) RPC_SERVICE_HANDLER(TaskInfoGcsService, GetTaskEvents, max_active_rpcs_per_handler_) } @@ -145,7 +154,8 @@ void TaskInfoGrpcService::InitServerCallFactories( void PlacementGroupInfoGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) { + const ClusterID &cluster_id, + const std::optional &auth_token) { RPC_SERVICE_HANDLER( PlacementGroupInfoGcsService, CreatePlacementGroup, max_active_rpcs_per_handler_) RPC_SERVICE_HANDLER( @@ -166,7 +176,8 @@ namespace autoscaler { void AutoscalerStateGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) { + const ClusterID &cluster_id, + const std::optional &auth_token) { RPC_SERVICE_HANDLER( AutoscalerStateService, GetClusterResourceState, max_active_rpcs_per_handler_) RPC_SERVICE_HANDLER( @@ -188,7 +199,8 @@ namespace events { void RayEventExportGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) { + const ClusterID &cluster_id, + const std::optional &auth_token) { RPC_SERVICE_HANDLER(RayEventExportGcsService, AddEvents, max_active_rpcs_per_handler_) } diff --git a/src/ray/gcs/grpc_services.h b/src/ray/gcs/grpc_services.h index d8a0899e2439..f7b34746114d 100644 --- a/src/ray/gcs/grpc_services.h +++ b/src/ray/gcs/grpc_services.h @@ -23,11 +23,13 @@ #pragma once #include +#include #include #include "ray/common/asio/instrumented_io_context.h" #include "ray/common/id.h" #include "ray/gcs/grpc_service_interfaces.h" +#include "ray/rpc/authentication/authentication_token.h" #include "ray/rpc/grpc_server.h" #include "ray/rpc/rpc_callback_types.h" #include "src/ray/protobuf/autoscaler.grpc.pb.h" @@ -51,7 +53,8 @@ class ActorInfoGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) override; + const ClusterID &cluster_id, + const std::optional &auth_token) override; private: ActorInfoGcsService::AsyncService service_; @@ -74,7 +77,8 @@ class NodeInfoGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) override; + const ClusterID &cluster_id, + const std::optional &auth_token) override; private: NodeInfoGcsService::AsyncService service_; @@ -97,7 +101,8 @@ class NodeResourceInfoGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) override; + const ClusterID &cluster_id, + const std::optional &auth_token) override; private: NodeResourceInfoGcsService::AsyncService service_; @@ -120,7 +125,8 @@ class InternalPubSubGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) override; + const ClusterID &cluster_id, + const std::optional &auth_token) override; private: InternalPubSubGcsService::AsyncService service_; @@ -143,7 +149,8 @@ class JobInfoGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) override; + const ClusterID &cluster_id, + const std::optional &auth_token) override; private: JobInfoGcsService::AsyncService service_; @@ -166,7 +173,8 @@ class RuntimeEnvGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) override; + const ClusterID &cluster_id, + const std::optional &auth_token) override; private: RuntimeEnvGcsService::AsyncService service_; @@ -189,7 +197,8 @@ class WorkerInfoGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) override; + const ClusterID &cluster_id, + const std::optional &auth_token) override; private: WorkerInfoGcsService::AsyncService service_; @@ -212,7 +221,8 @@ class InternalKVGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) override; + const ClusterID &cluster_id, + const std::optional &auth_token) override; private: InternalKVGcsService::AsyncService service_; @@ -235,7 +245,8 @@ class TaskInfoGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) override; + const ClusterID &cluster_id, + const std::optional &auth_token) override; private: TaskInfoGcsService::AsyncService service_; @@ -258,7 +269,8 @@ class PlacementGroupInfoGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) override; + const ClusterID &cluster_id, + const std::optional &auth_token) override; private: PlacementGroupInfoGcsService::AsyncService service_; @@ -283,7 +295,8 @@ class AutoscalerStateGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) override; + const ClusterID &cluster_id, + const std::optional &auth_token) override; private: AutoscalerStateService::AsyncService service_; @@ -310,7 +323,8 @@ class RayEventExportGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) override; + const ClusterID &cluster_id, + const std::optional &auth_token) override; private: RayEventExportGcsService::AsyncService service_; diff --git a/src/ray/gcs_rpc_client/tests/gcs_client_test.cc b/src/ray/gcs_rpc_client/tests/gcs_client_test.cc index 0d7fe2a71b73..620dc9b9a985 100644 --- a/src/ray/gcs_rpc_client/tests/gcs_client_test.cc +++ b/src/ray/gcs_rpc_client/tests/gcs_client_test.cc @@ -220,7 +220,7 @@ class GcsClientTest : public ::testing::TestWithParam { auto status = stub->CheckAlive(&context, request, &reply); // If it is in memory, we don't have the new token until we connect again. if (!((!no_redis_ && status.ok()) || - (no_redis_ && GrpcStatusToRayStatus(status).IsAuthError()))) { + (no_redis_ && GrpcStatusToRayStatus(status).IsUnauthenticated()))) { RAY_LOG(WARNING) << "Unable to reach GCS: " << status.error_code() << " " << status.error_message(); continue; @@ -991,7 +991,7 @@ TEST_P(GcsClientTest, TestGcsEmptyAuth) { auto status = stub->GetClusterId(&context, request, &reply); // We expect the wrong cluster ID - EXPECT_TRUE(GrpcStatusToRayStatus(status).IsAuthError()); + EXPECT_TRUE(GrpcStatusToRayStatus(status).IsUnauthenticated()); } TEST_P(GcsClientTest, TestGcsAuth) { diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 90eaa40d4985..d9e30d76917a 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -452,7 +452,7 @@ void NodeManager::RegisterGcs() { << "GCS consider this node to be dead. This may happen when " << "GCS is not backed by a DB and restarted or there is data loss " << "in the DB."; - } else if (status.IsAuthError()) { + } else if (status.IsUnauthenticated()) { RAY_LOG(FATAL) << "GCS returned an authentication error. This may happen when " << "GCS is not backed by a DB and restarted or there is data loss " diff --git a/src/ray/rpc/BUILD.bazel b/src/ray/rpc/BUILD.bazel index 23cedf7eb265..637655b02e29 100644 --- a/src/ray/rpc/BUILD.bazel +++ b/src/ray/rpc/BUILD.bazel @@ -21,6 +21,7 @@ ray_cc_library( "//src/ray/common:grpc_util", "//src/ray/common:id", "//src/ray/common:status", + "//src/ray/rpc/authentication:authentication_token_loader", "@com_google_absl//absl/synchronization", ], ) @@ -106,6 +107,7 @@ ray_cc_library( "//src/ray/common:id", "//src/ray/common:ray_config", "//src/ray/common:status", + "//src/ray/rpc/authentication:authentication_token", "//src/ray/stats:stats_metric", "@com_github_grpc_grpc//:grpc++", ], @@ -122,6 +124,7 @@ ray_cc_library( "//src/ray/common:asio", "//src/ray/common:ray_config", "//src/ray/common:status", + "//src/ray/rpc/authentication:authentication_token_loader", "//src/ray/util:network_util", "//src/ray/util:thread_utils", "@com_github_grpc_grpc//:grpc++", @@ -139,6 +142,7 @@ ray_cc_library( deps = [ ":grpc_server", "//src/ray/protobuf:node_manager_cc_grpc", + "//src/ray/rpc/authentication:authentication_token", "@com_github_grpc_grpc//:grpc++", ], ) @@ -154,6 +158,7 @@ ray_cc_library( "//src/ray/object_manager:object_manager_grpc_client_manager", "//src/ray/protobuf:object_manager_cc_grpc", "//src/ray/rpc:grpc_server", + "//src/ray/rpc/authentication:authentication_token", "@boost//:asio", "@com_github_grpc_grpc//:grpc++", ], diff --git a/src/ray/rpc/client_call.h b/src/ray/rpc/client_call.h index 29bb21c29ebe..319915f3e17a 100644 --- a/src/ray/rpc/client_call.h +++ b/src/ray/rpc/client_call.h @@ -27,9 +27,12 @@ #include "absl/synchronization/mutex.h" #include "ray/common/asio/asio_chaos.h" #include "ray/common/asio/instrumented_io_context.h" +#include "ray/common/constants.h" #include "ray/common/grpc_util.h" #include "ray/common/id.h" #include "ray/common/status.h" +#include "ray/rpc/authentication/authentication_token.h" +#include "ray/rpc/authentication/authentication_token_loader.h" #include "ray/rpc/metrics.h" #include "ray/rpc/rpc_callback_types.h" #include "ray/util/thread_utils.h" @@ -71,6 +74,7 @@ class ClientCallImpl : public ClientCall { /// \param[in] callback The callback function to handle the reply. explicit ClientCallImpl(const ClientCallback &callback, const ClusterID &cluster_id, + const std::optional &auth_token, std::shared_ptr stats_handle, bool record_stats, int64_t timeout_ms = -1) @@ -85,6 +89,10 @@ class ClientCallImpl : public ClientCall { if (!cluster_id.IsNil()) { context_.AddMetadata(kClusterIdKey, cluster_id.Hex()); } + // Add authentication token if provided + if (auth_token.has_value()) { + auth_token->SetMetadata(context_); + } } Status GetStatus() override { @@ -276,7 +284,12 @@ class ClientCallManager { } auto call = std::make_shared>( - callback, cluster_id_, std::move(stats_handle), record_stats_, method_timeout_ms); + callback, + cluster_id_, + AuthenticationTokenLoader::instance().GetToken(), + std::move(stats_handle), + record_stats_, + method_timeout_ms); // Send request. // Find the next completion queue to wait for response. call->response_reader_ = (stub.*prepare_async_function)( diff --git a/src/ray/rpc/grpc_server.cc b/src/ray/rpc/grpc_server.cc index 542326de0bce..e471bf7e39de 100644 --- a/src/ray/rpc/grpc_server.cc +++ b/src/ray/rpc/grpc_server.cc @@ -26,6 +26,7 @@ #include "ray/common/ray_config.h" #include "ray/common/status.h" +#include "ray/rpc/authentication/authentication_token_loader.h" #include "ray/rpc/common.h" #include "ray/util/network_util.h" #include "ray/util/thread_utils.h" @@ -178,12 +179,13 @@ void GrpcServer::RegisterService(std::unique_ptr &&grpc_service) } void GrpcServer::RegisterService(std::unique_ptr &&service, - bool token_auth) { + bool cluster_id_auth_enabled) { + if (cluster_id_auth_enabled && cluster_id_.IsNil()) { + RAY_LOG(FATAL) << "Expected cluster ID for token auth!"; + } for (int i = 0; i < num_threads_; i++) { - if (token_auth && cluster_id_.IsNil()) { - RAY_LOG(FATAL) << "Expected cluster ID for token auth!"; - } - service->InitServerCallFactories(cqs_[i], &server_call_factories_, cluster_id_); + service->InitServerCallFactories( + cqs_[i], &server_call_factories_, cluster_id_, auth_token_); } services_.push_back(std::move(service)); } diff --git a/src/ray/rpc/grpc_server.h b/src/ray/rpc/grpc_server.h index 0727b4d550f3..bf7eb7f8c5d1 100644 --- a/src/ray/rpc/grpc_server.h +++ b/src/ray/rpc/grpc_server.h @@ -24,39 +24,44 @@ #include #include "ray/common/asio/instrumented_io_context.h" +#include "ray/rpc/authentication/authentication_token.h" +#include "ray/rpc/authentication/authentication_token_loader.h" #include "ray/rpc/server_call.h" namespace ray { namespace rpc { /// \param MAX_ACTIVE_RPCS Maximum number of RPCs to handle at the same time. -1 means no /// limit. -#define _RPC_SERVICE_HANDLER( \ - SERVICE, HANDLER, MAX_ACTIVE_RPCS, AUTH_TYPE, RECORD_METRICS) \ - std::unique_ptr HANDLER##_call_factory( \ - new ServerCallFactoryImpl( \ - service_, \ - &SERVICE::AsyncService::Request##HANDLER, \ - service_handler_, \ - &SERVICE##Handler::Handle##HANDLER, \ - cq, \ - main_service_, \ - #SERVICE ".grpc_server." #HANDLER, \ - AUTH_TYPE == AuthType::NO_AUTH ? ClusterID::Nil() : cluster_id, \ - MAX_ACTIVE_RPCS, \ - RECORD_METRICS)); \ +#define _RPC_SERVICE_HANDLER( \ + SERVICE, HANDLER, MAX_ACTIVE_RPCS, AUTH_TYPE, RECORD_METRICS) \ + std::unique_ptr HANDLER##_call_factory( \ + new ServerCallFactoryImpl( \ + service_, \ + &SERVICE::AsyncService::Request##HANDLER, \ + service_handler_, \ + &SERVICE##Handler::Handle##HANDLER, \ + cq, \ + main_service_, \ + #SERVICE ".grpc_server." #HANDLER, \ + AUTH_TYPE == ClusterIdAuthType::NO_AUTH ? ClusterID::Nil() : cluster_id, \ + auth_token, \ + MAX_ACTIVE_RPCS, \ + RECORD_METRICS)); \ server_call_factories->emplace_back(std::move(HANDLER##_call_factory)); /// Define a RPC service handler with gRPC server metrics enabled. #define RPC_SERVICE_HANDLER(SERVICE, HANDLER, MAX_ACTIVE_RPCS) \ - _RPC_SERVICE_HANDLER(SERVICE, HANDLER, MAX_ACTIVE_RPCS, AuthType::LAZY_AUTH, true) + _RPC_SERVICE_HANDLER( \ + SERVICE, HANDLER, MAX_ACTIVE_RPCS, ClusterIdAuthType::LAZY_AUTH, true) /// Define a RPC service handler with gRPC server metrics disabled. #define RPC_SERVICE_HANDLER_SERVER_METRICS_DISABLED(SERVICE, HANDLER, MAX_ACTIVE_RPCS) \ - _RPC_SERVICE_HANDLER(SERVICE, HANDLER, MAX_ACTIVE_RPCS, AuthType::LAZY_AUTH, false) + _RPC_SERVICE_HANDLER( \ + SERVICE, HANDLER, MAX_ACTIVE_RPCS, ClusterIdAuthType::LAZY_AUTH, false) /// Define a RPC service handler with gRPC server metrics enabled. #define RPC_SERVICE_HANDLER_CUSTOM_AUTH(SERVICE, HANDLER, MAX_ACTIVE_RPCS, AUTH_TYPE) \ @@ -90,13 +95,20 @@ class GrpcServer { const uint32_t port, bool listen_to_localhost_only, int num_threads = 1, - int64_t keepalive_time_ms = 7200000 /*2 hours, grpc default*/) + int64_t keepalive_time_ms = 7200000, /*2 hours, grpc default*/ + std::optional auth_token = std::nullopt) : name_(std::move(name)), port_(port), listen_to_localhost_only_(listen_to_localhost_only), is_shutdown_(true), num_threads_(num_threads), keepalive_time_ms_(keepalive_time_ms) { + // Initialize auth token: use provided value or load from AuthenticationTokenLoader + if (auth_token.has_value()) { + auth_token_ = std::move(auth_token.value()); + } else { + auth_token_ = AuthenticationTokenLoader::instance().GetToken(); + } Init(); } @@ -157,6 +169,8 @@ class GrpcServer { const bool listen_to_localhost_only_; /// Token representing ID of this cluster. ClusterID cluster_id_; + /// Authentication token for token-based authentication. + std::optional auth_token_; /// Indicates whether this server is in shutdown state. std::atomic is_shutdown_; /// The `grpc::Service` objects which should be registered to `ServerBuilder`. @@ -208,10 +222,13 @@ class GrpcService { /// \param[in] cq The grpc completion queue. /// \param[out] server_call_factories The `ServerCallFactory` objects, /// and the maximum number of concurrent requests that this gRPC server can handle. + /// \param[in] cluster_id The cluster ID for authentication. + /// \param[in] auth_token The authentication token for token-based authentication. virtual void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) = 0; + const ClusterID &cluster_id, + const std::optional &auth_token) = 0; /// The main event loop, to which the service handler functions will be posted. instrumented_io_context &main_service_; diff --git a/src/ray/rpc/node_manager/node_manager_server.h b/src/ray/rpc/node_manager/node_manager_server.h index fba7780afc69..b819a7e98a13 100644 --- a/src/ray/rpc/node_manager/node_manager_server.h +++ b/src/ray/rpc/node_manager/node_manager_server.h @@ -15,9 +15,12 @@ #pragma once #include +#include +#include #include #include "ray/common/asio/instrumented_io_context.h" +#include "ray/rpc/authentication/authentication_token.h" #include "ray/rpc/grpc_server.h" #include "src/ray/protobuf/node_manager.grpc.pb.h" #include "src/ray/protobuf/node_manager.pb.h" @@ -29,7 +32,8 @@ class ServerCallFactory; /// TODO(vitsai): Remove this when auth is implemented for node manager #define RAY_NODE_MANAGER_RPC_SERVICE_HANDLER(METHOD) \ - RPC_SERVICE_HANDLER_CUSTOM_AUTH(NodeManagerService, METHOD, -1, AuthType::NO_AUTH) + RPC_SERVICE_HANDLER_CUSTOM_AUTH( \ + NodeManagerService, METHOD, -1, ClusterIdAuthType::NO_AUTH) /// NOTE: See src/ray/core_worker/core_worker.h on how to add a new grpc handler. #define RAY_NODE_MANAGER_RPC_HANDLERS \ @@ -206,7 +210,8 @@ class NodeManagerGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) override { + const ClusterID &cluster_id, + const std::optional &auth_token) override { RAY_NODE_MANAGER_RPC_HANDLERS } diff --git a/src/ray/rpc/object_manager_server.h b/src/ray/rpc/object_manager_server.h index 4d294b483fff..576de9396142 100644 --- a/src/ray/rpc/object_manager_server.h +++ b/src/ray/rpc/object_manager_server.h @@ -15,9 +15,12 @@ #pragma once #include +#include +#include #include #include "ray/common/asio/instrumented_io_context.h" +#include "ray/rpc/authentication/authentication_token.h" #include "ray/rpc/grpc_server.h" #include "src/ray/protobuf/object_manager.grpc.pb.h" #include "src/ray/protobuf/object_manager.pb.h" @@ -28,7 +31,8 @@ namespace rpc { class ServerCallFactory; #define RAY_OBJECT_MANAGER_RPC_SERVICE_HANDLER(METHOD) \ - RPC_SERVICE_HANDLER_CUSTOM_AUTH(ObjectManagerService, METHOD, -1, AuthType::NO_AUTH) + RPC_SERVICE_HANDLER_CUSTOM_AUTH( \ + ObjectManagerService, METHOD, -1, ClusterIdAuthType::NO_AUTH) #define RAY_OBJECT_MANAGER_RPC_HANDLERS \ RAY_OBJECT_MANAGER_RPC_SERVICE_HANDLER(Push) \ @@ -76,7 +80,8 @@ class ObjectManagerGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) override { + const ClusterID &cluster_id, + const std::optional &auth_token) override { RAY_OBJECT_MANAGER_RPC_HANDLERS } diff --git a/src/ray/rpc/server_call.h b/src/ray/rpc/server_call.h index e79ae6dae22d..b84ab4e22dc2 100644 --- a/src/ray/rpc/server_call.h +++ b/src/ray/rpc/server_call.h @@ -20,13 +20,16 @@ #include #include #include +#include #include #include "ray/common/asio/asio_chaos.h" #include "ray/common/asio/instrumented_io_context.h" +#include "ray/common/constants.h" #include "ray/common/grpc_util.h" #include "ray/common/id.h" #include "ray/common/status.h" +#include "ray/rpc/authentication/authentication_token.h" #include "ray/rpc/metrics.h" #include "ray/rpc/rpc_callback_types.h" #include "ray/stats/metric.h" @@ -34,8 +37,8 @@ namespace ray { namespace rpc { -// Authentication type of ServerCall. -enum class AuthType { +// Cluster ID authentication type of ServerCall. +enum class ClusterIdAuthType { NO_AUTH, // Do not authenticate (accept all). LAZY_AUTH, // Accept missing cluster ID, but reject incorrect one. EMPTY_AUTH, // Accept only empty cluster ID. @@ -149,7 +152,7 @@ using HandleRequestFunction = void (ServiceHandler::*)(Request, template + ClusterIdAuthType EnableAuth = ClusterIdAuthType::NO_AUTH> class ServerCallImpl : public ServerCall { public: /// Constructor. @@ -159,6 +162,8 @@ class ServerCallImpl : public ServerCall { /// \param[in] handle_request_function Pointer to the service handler function. /// \param[in] io_service The event loop. /// \param[in] call_name The name of the RPC call. + /// \param[in] cluster_id The cluster ID for authentication. + /// \param[in] auth_token The authentication token for token-based authentication. /// \param[in] record_metrics If true, it records and exports the gRPC server metrics. /// \param[in] preprocess_function If not nullptr, it will be called before handling /// request. @@ -169,6 +174,7 @@ class ServerCallImpl : public ServerCall { instrumented_io_context &io_service, std::string call_name, const ClusterID &cluster_id, + const std::optional &auth_token, bool record_metrics, std::function preprocess_function = nullptr) : state_(ServerCallState::PENDING), @@ -179,6 +185,7 @@ class ServerCallImpl : public ServerCall { io_service_(io_service), call_name_(std::move(call_name)), cluster_id_(cluster_id), + auth_token_(auth_token), start_time_(0), record_metrics_(record_metrics) { reply_ = google::protobuf::Arena::CreateMessage(&arena_); @@ -194,8 +201,18 @@ class ServerCallImpl : public ServerCall { void HandleRequest() override { stats_handle_ = io_service_.stats().RecordStart(call_name_); bool auth_success = true; + bool token_auth_failed = false; + bool cluster_id_auth_failed = false; + + // Token authentication + if (!ValidateBearerToken()) { + auth_success = false; + token_auth_failed = true; + } + + // Cluster ID authentication if (::RayConfig::instance().enable_cluster_auth()) { - if constexpr (EnableAuth == AuthType::LAZY_AUTH) { + if constexpr (EnableAuth == ClusterIdAuthType::LAZY_AUTH) { RAY_CHECK(!cluster_id_.IsNil()) << "Expected cluster ID in server call!"; auto &metadata = context_.client_metadata(); if (auto it = metadata.find(kClusterIdKey); @@ -203,8 +220,9 @@ class ServerCallImpl : public ServerCall { RAY_LOG(WARNING) << "Wrong cluster ID token in request! Expected: " << cluster_id_.Hex() << ", but got: " << it->second; auth_success = false; + cluster_id_auth_failed = true; } - } else if constexpr (EnableAuth == AuthType::EMPTY_AUTH) { + } else if constexpr (EnableAuth == ClusterIdAuthType::EMPTY_AUTH) { RAY_CHECK(!cluster_id_.IsNil()) << "Expected cluster ID in server call!"; auto &metadata = context_.client_metadata(); if (auto it = metadata.find(kClusterIdKey); @@ -212,6 +230,7 @@ class ServerCallImpl : public ServerCall { RAY_LOG(WARNING) << "Cluster ID token in request! Expected Nil, " << "but got: " << it->second; auth_success = false; + cluster_id_auth_failed = true; } } } @@ -221,24 +240,32 @@ class ServerCallImpl : public ServerCall { grpc_server_req_handling_counter_.Record(1.0, {{"Method", call_name_}}); } if (!io_service_.stopped()) { - io_service_.post([this, auth_success] { HandleRequestImpl(auth_success); }, - call_name_ + ".HandleRequestImpl", - // Implement the delay of the rpc server call as the - // delay of HandleRequestImpl(). - ray::asio::testing::GetDelayUs(call_name_)); + io_service_.post( + [this, auth_success, token_auth_failed, cluster_id_auth_failed] { + HandleRequestImpl(auth_success, token_auth_failed, cluster_id_auth_failed); + }, + call_name_ + ".HandleRequestImpl", + // Implement the delay of the rpc server call as the + // delay of HandleRequestImpl(). + ray::asio::testing::GetDelayUs(call_name_)); } else { // Handle service for rpc call has stopped, we must handle the call here // to send reply and remove it from cq RAY_LOG(DEBUG) << "Handle service has been closed."; if (auth_success) { SendReply(Status::Invalid("HandleServiceClosed")); + } else if (token_auth_failed) { + SendReply(Status::Unauthenticated( + "InvalidAuthToken: Authentication token is missing or incorrect")); } else { - SendReply(Status::AuthError("WrongClusterID")); + SendReply(Status::Unauthenticated("WrongClusterID")); } } } - void HandleRequestImpl(bool auth_success) { + void HandleRequestImpl(bool auth_success, + bool token_auth_failed, + bool cluster_id_auth_failed) { if constexpr (std::is_base_of_v) { if (!service_handler_initialized_) { service_handler_.WaitUntilInitialized(); @@ -254,10 +281,15 @@ class ServerCallImpl : public ServerCall { factory_.CreateCall(); } if (!auth_success) { - boost::asio::post(GetServerCallExecutor(), [this]() { - SendReply( - Status::AuthError("WrongClusterID: Perhaps the client is accessing GCS " - "after it has restarted.")); + boost::asio::post(GetServerCallExecutor(), [this, token_auth_failed]() { + if (token_auth_failed) { + SendReply(Status::Unauthenticated( + "InvalidAuthToken: Authentication token is missing or incorrect")); + } else { + SendReply(Status::Unauthenticated( + "WrongClusterID: Perhaps the client is accessing GCS " + "after it has restarted.")); + } }); } else { (service_handler_.*handle_request_function_)( @@ -306,6 +338,32 @@ class ServerCallImpl : public ServerCall { const ServerCallFactory &GetServerCallFactory() override { return factory_; } private: + /// Validates token-based authentication. + /// Returns true if authentication succeeds or is not required. + /// Returns false if authentication is required but fails. + bool ValidateBearerToken() { + if (!auth_token_.has_value() || auth_token_->empty()) { + return true; // No auth required + } + + const auto &metadata = context_.client_metadata(); + auto it = metadata.find(kAuthTokenKey); + if (it == metadata.end()) { + RAY_LOG(WARNING) << "Missing authorization header in request!"; + return false; + } + + const std::string_view header(it->second.data(), it->second.length()); + AuthenticationToken provided_token = AuthenticationToken::FromMetadata(header); + + if (!auth_token_->Equals(provided_token)) { + RAY_LOG(WARNING) << "Invalid bearer token in request!"; + return false; + } + + return true; + } + /// Log the duration this query used void LogProcessTime() { EventTracker::RecordEnd(std::move(stats_handle_)); @@ -373,6 +431,9 @@ class ServerCallImpl : public ServerCall { /// Check skipped if empty. const ClusterID &cluster_id_; + /// Authentication token for token-based authentication. + std::optional auth_token_; + /// The callback when sending reply successes. std::function send_reply_success_callback_ = nullptr; @@ -397,7 +458,7 @@ class ServerCallImpl : public ServerCall { ray::stats::Count grpc_server_req_failed_counter_{ GetGrpcServerReqFailedCounterMetric()}; - template + template friend class ServerCallFactoryImpl; }; @@ -425,7 +486,7 @@ template + ClusterIdAuthType EnableAuth = ClusterIdAuthType::NO_AUTH> class ServerCallFactoryImpl : public ServerCallFactory { using AsyncService = typename GrpcService::AsyncService; @@ -440,6 +501,8 @@ class ServerCallFactoryImpl : public ServerCallFactory { /// \param[in] cq The `CompletionQueue`. /// \param[in] io_service The event loop. /// \param[in] call_name The name of the RPC call. + /// \param[in] cluster_id The cluster ID for authentication. + /// \param[in] auth_token The authentication token for token-based authentication. /// \param[in] max_active_rpcs Maximum request number to handle at the same time. -1 /// means no limit. /// \param[in] record_metrics If true, it records and exports the gRPC server metrics. @@ -452,6 +515,7 @@ class ServerCallFactoryImpl : public ServerCallFactory { instrumented_io_context &io_service, std::string call_name, const ClusterID &cluster_id, + const std::optional &auth_token, int64_t max_active_rpcs, bool record_metrics) : service_(service), @@ -462,6 +526,7 @@ class ServerCallFactoryImpl : public ServerCallFactory { io_service_(io_service), call_name_(std::move(call_name)), cluster_id_(cluster_id), + auth_token_(auth_token), max_active_rpcs_(max_active_rpcs), record_metrics_(record_metrics) {} @@ -475,6 +540,7 @@ class ServerCallFactoryImpl : public ServerCallFactory { io_service_, call_name_, cluster_id_, + auth_token_, record_metrics_); /// Request gRPC runtime to starting accepting this kind of request, using the call as /// the tag. @@ -514,6 +580,9 @@ class ServerCallFactoryImpl : public ServerCallFactory { /// Check skipped if empty. const ClusterID cluster_id_; + /// Authentication token for token-based authentication. + std::optional auth_token_; + /// Maximum request number to handle at the same time. /// -1 means no limit. uint64_t max_active_rpcs_; diff --git a/src/ray/rpc/tests/grpc_auth_token_tests.cc b/src/ray/rpc/tests/grpc_auth_token_tests.cc new file mode 100644 index 000000000000..5501e6e568b0 --- /dev/null +++ b/src/ray/rpc/tests/grpc_auth_token_tests.cc @@ -0,0 +1,221 @@ +// Copyright 2021 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + +#include "gtest/gtest.h" +#include "ray/protobuf/test_service.grpc.pb.h" +#include "ray/rpc/authentication/authentication_token_loader.h" +#include "ray/rpc/grpc_client.h" +#include "ray/rpc/grpc_server.h" +#include "ray/rpc/tests/grpc_test_common.h" + +namespace ray { +namespace rpc { + +class TestGrpcServerClientTokenAuthFixture : public ::testing::Test { + public: + void SetUp() override { + // Configure token auth via RayConfig + std::string config_json = R"({"auth_mode": "token"})"; + RayConfig::instance().initialize(config_json); + AuthenticationTokenLoader::instance().ResetCache(); + } + + void SetUpServerAndClient(const std::string &server_token, + const std::string &client_token) { + // Set client token in environment for ClientCallManager to read from + // AuthenticationTokenLoader + if (!client_token.empty()) { + setenv("RAY_AUTH_TOKEN", client_token.c_str(), 1); + } else { + RayConfig::instance().initialize(R"({"auth_mode": "disabled"})"); + AuthenticationTokenLoader::instance().ResetCache(); + unsetenv("RAY_AUTH_TOKEN"); + } + + // Start client thread FIRST + client_thread_ = std::make_unique([this]() { + boost::asio::executor_work_guard + client_io_service_work_(client_io_service_.get_executor()); + client_io_service_.run(); + }); + + // Start handler thread for server + handler_thread_ = std::make_unique([this]() { + boost::asio::executor_work_guard + handler_io_service_work_(handler_io_service_.get_executor()); + handler_io_service_.run(); + }); + + // Create and start server + // Pass server token explicitly for testing scenarios with different tokens + std::optional server_auth_token; + if (!server_token.empty()) { + server_auth_token = AuthenticationToken(server_token); + } else { + // Explicitly set empty token (no auth required) + server_auth_token = AuthenticationToken(""); + } + grpc_server_.reset(new GrpcServer("test", 0, true, 1, 7200000, server_auth_token)); + grpc_server_->RegisterService( + std::make_unique(handler_io_service_, test_service_handler_), + false); + grpc_server_->Run(); + + while (grpc_server_->GetPort() == 0) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + + // Create client (will read auth token from AuthenticationTokenLoader which reads the + // environment) + client_call_manager_.reset( + new ClientCallManager(client_io_service_, false, /*local_address=*/"")); + grpc_client_.reset(new GrpcClient( + "127.0.0.1", grpc_server_->GetPort(), *client_call_manager_)); + } + + void TearDown() override { + if (grpc_client_) { + grpc_client_.reset(); + } + if (client_call_manager_) { + client_call_manager_.reset(); + } + if (client_thread_) { + client_io_service_.stop(); + if (client_thread_->joinable()) { + client_thread_->join(); + } + } + + if (grpc_server_) { + grpc_server_->Shutdown(); + } + if (handler_thread_) { + handler_io_service_.stop(); + if (handler_thread_->joinable()) { + handler_thread_->join(); + } + } + + // Clean up environment variables + unsetenv("RAY_AUTH_TOKEN"); + unsetenv("RAY_AUTH_TOKEN_PATH"); + // Reset the token loader for test isolation + AuthenticationTokenLoader::instance().ResetCache(); + } + + // Helper to execute RPC and wait for result + struct PingResult { + bool completed; + bool success; + std::string error_msg; + }; + + PingResult ExecutePingAndWait() { + PingRequest request; + auto result_promise = std::make_shared>(); + std::future result_future = result_promise->get_future(); + + Ping(request, [result_promise](const Status &status, const PingReply &reply) { + RAY_LOG(INFO) << "Token auth test replied, status=" << status; + bool success = status.ok(); + std::string error_msg = status.ok() ? "" : status.message(); + result_promise->set_value({true, success, error_msg}); + }); + + // Wait for response with timeout + if (result_future.wait_for(std::chrono::seconds(5)) == std::future_status::timeout) { + return {false, false, "Request timed out"}; + } + + return result_future.get(); + } + + protected: + VOID_RPC_CLIENT_METHOD(TestService, Ping, grpc_client_, /*method_timeout_ms*/ -1, ) + + TestServiceHandler test_service_handler_; + instrumented_io_context handler_io_service_; + std::unique_ptr handler_thread_; + std::unique_ptr grpc_server_; + + instrumented_io_context client_io_service_; + std::unique_ptr client_thread_; + std::unique_ptr client_call_manager_; + std::unique_ptr> grpc_client_; +}; + +TEST_F(TestGrpcServerClientTokenAuthFixture, TestTokenAuthSuccess) { + // Both server and client have the same token + const std::string token = "test_secret_token_123"; + SetUpServerAndClient(token, token); + + auto result = ExecutePingAndWait(); + + ASSERT_TRUE(result.completed) << "Request did not complete in time"; + ASSERT_TRUE(result.success) << "Request should succeed with matching token"; +} + +TEST_F(TestGrpcServerClientTokenAuthFixture, TestTokenAuthFailureWrongToken) { + // Server and client have different tokens + SetUpServerAndClient("server_token", "wrong_client_token"); + + auto result = ExecutePingAndWait(); + + ASSERT_TRUE(result.completed) << "Request did not complete in time"; + ASSERT_FALSE(result.success) << "Request should fail with wrong client token"; + ASSERT_TRUE(result.error_msg.find( + "InvalidAuthToken: Authentication token is missing or incorrect") != + std::string::npos) + << "Error message should contain token auth error. Got: " << result.error_msg; +} + +TEST_F(TestGrpcServerClientTokenAuthFixture, TestTokenAuthFailureMissingToken) { + // Server expects token, client doesn't send one (empty token) + SetUpServerAndClient("server_token", ""); + + auto result = ExecutePingAndWait(); + + ASSERT_TRUE(result.completed) << "Request did not complete in time"; + // If the server has a token but the client doesn't, auth should fail + ASSERT_FALSE(result.success) + << "Request should fail when client doesn't provide required token"; +} + +TEST_F(TestGrpcServerClientTokenAuthFixture, + TestClientProvidesTokenServerDoesNotRequire) { + // Client provides token, but server doesn't require one (should succeed) + SetUpServerAndClient("", "client_token"); + + auto result = ExecutePingAndWait(); + + ASSERT_TRUE(result.completed) << "Request did not complete in time"; + // Server should accept request even though client sent unnecessary token + ASSERT_TRUE(result.success) + << "Request should succeed when server doesn't require token"; +} + +} // namespace rpc +} // namespace ray + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/src/ray/rpc/tests/grpc_bench/BUILD.bazel b/src/ray/rpc/tests/grpc_bench/BUILD.bazel index 5238a11c0baf..4594e3873c5f 100644 --- a/src/ray/rpc/tests/grpc_bench/BUILD.bazel +++ b/src/ray/rpc/tests/grpc_bench/BUILD.bazel @@ -28,5 +28,6 @@ cc_binary( ":helloworld_cc_lib", "//src/ray/common:asio", "//src/ray/rpc:grpc_server", + "//src/ray/rpc/authentication:authentication_token", ], ) diff --git a/src/ray/rpc/tests/grpc_bench/grpc_bench.cc b/src/ray/rpc/tests/grpc_bench/grpc_bench.cc index 552b83bff3bc..81dd9477f948 100644 --- a/src/ray/rpc/tests/grpc_bench/grpc_bench.cc +++ b/src/ray/rpc/tests/grpc_bench/grpc_bench.cc @@ -13,10 +13,12 @@ // limitations under the License. #include +#include #include #include #include "ray/common/asio/instrumented_io_context.h" +#include "ray/rpc/authentication/authentication_token.h" #include "ray/rpc/grpc_server.h" #include "src/ray/rpc/test/grpc_bench/helloworld.grpc.pb.h" #include "src/ray/rpc/test/grpc_bench/helloworld.pb.h" @@ -57,9 +59,11 @@ class GreeterGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector> *server_call_factories, - const ClusterID &cluster_id) override{ - RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED( - Greeter, SayHello, -1, AuthType::NO_AUTH)} + const ClusterID &cluster_id, + const std::optional &auth_token) override { + RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED( + Greeter, SayHello, -1, ClusterIdAuthType::NO_AUTH); + } /// The grpc async service object. Greeter::AsyncService service_; diff --git a/src/ray/rpc/tests/grpc_server_client_test.cc b/src/ray/rpc/tests/grpc_server_client_test.cc index f51a80b99f73..8bc6e8284493 100644 --- a/src/ray/rpc/tests/grpc_server_client_test.cc +++ b/src/ray/rpc/tests/grpc_server_client_test.cc @@ -14,92 +14,16 @@ #include #include -#include +#include #include "gtest/gtest.h" +#include "ray/protobuf/test_service.grpc.pb.h" #include "ray/rpc/grpc_client.h" #include "ray/rpc/grpc_server.h" -#include "src/ray/protobuf/test_service.grpc.pb.h" +#include "ray/rpc/tests/grpc_test_common.h" namespace ray { namespace rpc { -class TestServiceHandler { - public: - void HandlePing(PingRequest request, - PingReply *reply, - SendReplyCallback send_reply_callback) { - RAY_LOG(INFO) << "Got ping request, no_reply=" << request.no_reply(); - request_count++; - while (frozen) { - RAY_LOG(INFO) << "Server is frozen..."; - std::this_thread::sleep_for(std::chrono::milliseconds(1000)); - } - RAY_LOG(INFO) << "Handling and replying request."; - if (request.no_reply()) { - RAY_LOG(INFO) << "No reply!"; - return; - } - send_reply_callback( - ray::Status::OK(), - /*reply_success=*/[]() { RAY_LOG(INFO) << "Reply success."; }, - /*reply_failure=*/ - [this]() { - RAY_LOG(INFO) << "Reply failed."; - reply_failure_count++; - }); - } - - void HandlePingTimeout(PingTimeoutRequest request, - PingTimeoutReply *reply, - SendReplyCallback send_reply_callback) { - while (frozen) { - RAY_LOG(INFO) << "Server is frozen..."; - std::this_thread::sleep_for(std::chrono::milliseconds(1000)); - } - RAY_LOG(INFO) << "Handling and replying request."; - send_reply_callback( - ray::Status::OK(), - /*reply_success=*/[]() { RAY_LOG(INFO) << "Reply success."; }, - /*reply_failure=*/ - [this]() { - RAY_LOG(INFO) << "Reply failed."; - reply_failure_count++; - }); - } - - std::atomic request_count{0}; - std::atomic reply_failure_count{0}; - std::atomic frozen{false}; -}; - -class TestGrpcService : public GrpcService { - public: - /// Constructor. - /// - /// \param[in] handler The service handler that actually handle the requests. - explicit TestGrpcService(instrumented_io_context &handler_io_service_, - TestServiceHandler &handler) - : GrpcService(handler_io_service_), service_handler_(handler){}; - - protected: - grpc::Service &GetGrpcService() override { return service_; } - - void InitServerCallFactories( - const std::unique_ptr &cq, - std::vector> *server_call_factories, - const ClusterID &cluster_id) override { - RPC_SERVICE_HANDLER_CUSTOM_AUTH( - TestService, Ping, /*max_active_rpcs=*/1, AuthType::NO_AUTH); - RPC_SERVICE_HANDLER_CUSTOM_AUTH( - TestService, PingTimeout, /*max_active_rpcs=*/1, AuthType::NO_AUTH); - } - - private: - /// The grpc async service object. - TestService::AsyncService service_; - /// The service handler that actually handle the requests. - TestServiceHandler &service_handler_; -}; class TestGrpcServerClientFixture : public ::testing::Test { public: @@ -326,6 +250,7 @@ TEST_F(TestGrpcServerClientFixture, TestTimeoutMacro) { std::this_thread::sleep_for(std::chrono::milliseconds(1000)); } } + } // namespace rpc } // namespace ray diff --git a/src/ray/rpc/tests/grpc_test_common.h b/src/ray/rpc/tests/grpc_test_common.h new file mode 100644 index 000000000000..1ce199f79511 --- /dev/null +++ b/src/ray/rpc/tests/grpc_test_common.h @@ -0,0 +1,109 @@ +// Copyright 2021 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ray/rpc/grpc_server.h" +#include "src/ray/protobuf/test_service.grpc.pb.h" + +namespace ray { +namespace rpc { + +class TestServiceHandler { + public: + void HandlePing(PingRequest request, + PingReply *reply, + SendReplyCallback send_reply_callback) { + RAY_LOG(INFO) << "Got ping request, no_reply=" << request.no_reply(); + request_count++; + while (frozen) { + RAY_LOG(INFO) << "Server is frozen..."; + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + } + RAY_LOG(INFO) << "Handling and replying request."; + if (request.no_reply()) { + RAY_LOG(INFO) << "No reply!"; + return; + } + send_reply_callback( + ray::Status::OK(), + /*reply_success=*/[]() { RAY_LOG(INFO) << "Reply success."; }, + /*reply_failure=*/ + [this]() { + RAY_LOG(INFO) << "Reply failed."; + reply_failure_count++; + }); + } + + void HandlePingTimeout(PingTimeoutRequest request, + PingTimeoutReply *reply, + SendReplyCallback send_reply_callback) { + while (frozen) { + RAY_LOG(INFO) << "Server is frozen..."; + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + } + RAY_LOG(INFO) << "Handling and replying request."; + send_reply_callback( + ray::Status::OK(), + /*reply_success=*/[]() { RAY_LOG(INFO) << "Reply success."; }, + /*reply_failure=*/ + [this]() { + RAY_LOG(INFO) << "Reply failed."; + reply_failure_count++; + }); + } + + std::atomic request_count{0}; + std::atomic reply_failure_count{0}; + std::atomic frozen{false}; +}; + +class TestGrpcService : public GrpcService { + public: + /// Constructor. + /// + /// \param[in] handler The service handler that actually handle the requests. + explicit TestGrpcService(instrumented_io_context &handler_io_service_, + TestServiceHandler &handler) + : GrpcService(handler_io_service_), service_handler_(handler){}; + + protected: + grpc::Service &GetGrpcService() override { return service_; } + + void InitServerCallFactories( + const std::unique_ptr &cq, + std::vector> *server_call_factories, + const ClusterID &cluster_id, + const std::optional &auth_token) override { + RPC_SERVICE_HANDLER_CUSTOM_AUTH( + TestService, Ping, /*max_active_rpcs=*/1, ClusterIdAuthType::NO_AUTH); + RPC_SERVICE_HANDLER_CUSTOM_AUTH( + TestService, PingTimeout, /*max_active_rpcs=*/1, ClusterIdAuthType::NO_AUTH); + } + + private: + /// The grpc async service object. + TestService::AsyncService service_; + /// The service handler that actually handle the requests. + TestServiceHandler &service_handler_; +}; + +} // namespace rpc +} // namespace ray From c821c21aca7127bb3d5917f11069ea85589c8bab Mon Sep 17 00:00:00 2001 From: sampan Date: Thu, 23 Oct 2025 16:25:26 +0000 Subject: [PATCH 03/48] revert unneeded changs from src/ray/rpc/tests/BUILD.bazel Signed-off-by: sampan --- src/ray/rpc/tests/BUILD.bazel | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/src/ray/rpc/tests/BUILD.bazel b/src/ray/rpc/tests/BUILD.bazel index 279b68f91ba3..0e253f612952 100644 --- a/src/ray/rpc/tests/BUILD.bazel +++ b/src/ray/rpc/tests/BUILD.bazel @@ -18,29 +18,9 @@ ray_cc_test( size = "small", srcs = [ "grpc_server_client_test.cc", - "grpc_test_common.h", ], tags = ["team:core"], deps = [ - "//src/ray/protobuf:test_service_cc_grpc", - "//src/ray/rpc:grpc_client", - "//src/ray/rpc:grpc_server", - "@com_google_googletest//:gtest_main", - ], -) - -ray_cc_test( - name = "grpc_auth_token_tests", - size = "small", - srcs = [ - "grpc_auth_token_tests.cc", - "grpc_test_common.h", - ], - tags = ["team:core"], - deps = [ - "//src/ray/protobuf:test_service_cc_grpc", - "//src/ray/rpc:grpc_client", - "//src/ray/rpc:grpc_server", "@com_google_googletest//:gtest_main", ], ) From a14dc6951a0f66045867053e349112b6496583c5 Mon Sep 17 00:00:00 2001 From: sampan Date: Thu, 23 Oct 2025 16:26:07 +0000 Subject: [PATCH 04/48] readd dependencies Signed-off-by: sampan --- src/ray/rpc/tests/BUILD.bazel | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/ray/rpc/tests/BUILD.bazel b/src/ray/rpc/tests/BUILD.bazel index 0e253f612952..d5113ae0d3aa 100644 --- a/src/ray/rpc/tests/BUILD.bazel +++ b/src/ray/rpc/tests/BUILD.bazel @@ -21,6 +21,9 @@ ray_cc_test( ], tags = ["team:core"], deps = [ + "//src/ray/protobuf:test_service_cc_grpc", + "//src/ray/rpc:grpc_client", + "//src/ray/rpc:grpc_server", "@com_google_googletest//:gtest_main", ], ) From 4801ed7353d131bd905a372a27a60cb55a8ec658 Mon Sep 17 00:00:00 2001 From: sampan Date: Fri, 24 Oct 2025 07:17:54 +0000 Subject: [PATCH 05/48] address comments + fix build Signed-off-by: sampan --- .../rpc/authentication/authentication_token.h | 5 +-- .../authentication_token_loader.cc | 44 +++++++++++-------- .../tests/authentication_token_loader_test.cc | 8 ++-- .../rpc/tests/authentication_token_test.cc | 11 ----- 4 files changed, 30 insertions(+), 38 deletions(-) diff --git a/src/ray/rpc/authentication/authentication_token.h b/src/ray/rpc/authentication/authentication_token.h index 6846d3c08ada..4f32310784de 100644 --- a/src/ray/rpc/authentication/authentication_token.h +++ b/src/ray/rpc/authentication/authentication_token.h @@ -92,8 +92,8 @@ class AuthenticationToken { /// prefix) /// @return AuthenticationToken object (empty if format invalid) static AuthenticationToken FromMetadata(std::string_view metadata_value) { - const std::string_view prefix(kBearerPrefix, sizeof(kBearerPrefix) - 1); - if (metadata_value.size() <= prefix.size() || + const std::string_view prefix(kBearerPrefix); + if (metadata_value.size() < prefix.size() || metadata_value.substr(0, prefix.size()) != prefix) { return AuthenticationToken(); // Invalid format, return empty } @@ -145,7 +145,6 @@ class AuthenticationToken { } void MoveFrom(AuthenticationToken &&other) noexcept { - SecureClear(); secret_ = std::move(other.secret_); // Clear the moved-from object explicitly for security // Note: 'other' is already an rvalue reference, no need to move again diff --git a/src/ray/rpc/authentication/authentication_token_loader.cc b/src/ray/rpc/authentication/authentication_token_loader.cc index 59a1184e080a..621f28fe351c 100644 --- a/src/ray/rpc/authentication/authentication_token_loader.cc +++ b/src/ray/rpc/authentication/authentication_token_loader.cc @@ -20,11 +20,6 @@ #include "ray/util/logging.h" -#if defined(__APPLE__) || defined(__linux__) -#include -#include -#endif - #ifdef _WIN32 #ifndef _WINDOWS_ #ifndef WIN32_LEAN_AND_MEAN // Sorry for the inconvenience. Please include any related @@ -63,9 +58,10 @@ std::optional AuthenticationTokenLoader::GetToken() { // If no token found and auth is enabled, fail with RAY_CHECK RAY_CHECK(!token.empty()) - << "Token authentication is enabled but no authentication token was found. " - << "Please set RAY_AUTH_TOKEN environment variable, RAY_AUTH_TOKEN_PATH to a file " - << "containing the token, or create a token file at ~/.ray/auth_token"; + << "Token authentication is enabled but Ray couldn't find an authentication token. " + << "Set the RAY_AUTH_TOKEN environment variable, or set RAY_AUTH_TOKEN_PATH to " + "point to a file with the token, " + << "or create a token file at ~/.ray/auth_token."; // Cache and return the loaded token cached_token_ = std::move(token); @@ -89,22 +85,26 @@ std::string AuthenticationTokenLoader::ReadTokenFromFile(const std::string &file AuthenticationToken AuthenticationTokenLoader::LoadTokenFromSources() { // Precedence 1: RAY_AUTH_TOKEN environment variable const char *env_token = std::getenv("RAY_AUTH_TOKEN"); - if (env_token != nullptr && std::string(env_token).length() > 0) { - RAY_LOG(DEBUG) << "Loaded authentication token from RAY_AUTH_TOKEN environment " - "variable"; - return AuthenticationToken(TrimWhitespace(std::string(env_token))); + if (env_token != nullptr) { + std::string token_str(env_token); + if (!token_str.empty()) { + RAY_LOG(DEBUG) << "Loaded authentication token from RAY_AUTH_TOKEN environment " + "variable"; + return AuthenticationToken(TrimWhitespace(token_str)); + } } // Precedence 2: RAY_AUTH_TOKEN_PATH environment variable const char *env_token_path = std::getenv("RAY_AUTH_TOKEN_PATH"); - if (env_token_path != nullptr && std::string(env_token_path).length() > 0) { - std::string token_str = TrimWhitespace(ReadTokenFromFile(env_token_path)); - if (!token_str.empty()) { - RAY_LOG(DEBUG) << "Loaded authentication token from file: " << env_token_path; + if (env_token_path != nullptr) { + std::string path_str(env_token_path); + if (!path_str.empty()) { + std::string token_str = TrimWhitespace(ReadTokenFromFile(path_str)); + RAY_CHECK(!token_str.empty()) + << "RAY_AUTH_TOKEN_PATH is set but file cannot be opened or is empty: " + << path_str; + RAY_LOG(DEBUG) << "Loaded authentication token from file: " << path_str; return AuthenticationToken(token_str); - } else { - RAY_LOG(WARNING) << "RAY_AUTH_TOKEN_PATH is set but file cannot be opened: " - << env_token_path; } } @@ -159,6 +159,12 @@ std::string AuthenticationTokenLoader::TrimWhitespace(const std::string &str) { std::string whitespace = " \t\n\r\f\v"; std::string trimmed_str = str; trimmed_str.erase(0, trimmed_str.find_first_not_of(whitespace)); + + // if the string is empty, return it + if (trimmed_str.empty()) { + return trimmed_str; + } + trimmed_str.erase(trimmed_str.find_last_not_of(whitespace) + 1); return trimmed_str; } diff --git a/src/ray/rpc/tests/authentication_token_loader_test.cc b/src/ray/rpc/tests/authentication_token_loader_test.cc index 616a13b0e457..2332c6d09313 100644 --- a/src/ray/rpc/tests/authentication_token_loader_test.cc +++ b/src/ray/rpc/tests/authentication_token_loader_test.cc @@ -109,8 +109,7 @@ class AuthenticationTokenLoaderTest : public ::testing::Test { void set_env_var(const char *name, const char *value) { #ifdef _WIN32 - std::string env_str = std::string(name) + "=" + std::string(value); - _putenv(env_str.c_str()); + _putenv_s(name, value); #else setenv(name, value, 1); #endif @@ -118,8 +117,7 @@ class AuthenticationTokenLoaderTest : public ::testing::Test { void unset_env_var(const char *name) { #ifdef _WIN32 - std::string env_str = std::string(name) + "="; - _putenv(env_str.c_str()); + _putenv_s(name, "") #else unsetenv(name); #endif @@ -301,7 +299,7 @@ TEST_F(AuthenticationTokenLoaderTest, TestErrorWhenAuthEnabledButNoToken) { auto &loader = AuthenticationTokenLoader::instance(); loader.GetToken(); }, - "Token authentication is enabled but no authentication token was found"); + "Token authentication is enabled but Ray couldn't find an authentication token."); } TEST_F(AuthenticationTokenLoaderTest, TestCaching) { diff --git a/src/ray/rpc/tests/authentication_token_test.cc b/src/ray/rpc/tests/authentication_token_test.cc index db88d7481da1..77ae4eb7cfc2 100644 --- a/src/ray/rpc/tests/authentication_token_test.cc +++ b/src/ray/rpc/tests/authentication_token_test.cc @@ -59,17 +59,6 @@ TEST_F(AuthenticationTokenTest, TestMoveAssignment) { EXPECT_TRUE(token1.empty()); } -TEST_F(AuthenticationTokenTest, TestSelfMoveAssignment) { - AuthenticationToken token("test-token"); - - // Self-assignment should not break the token - token = std::move(token); - - EXPECT_FALSE(token.empty()); - AuthenticationToken expected("test-token"); - EXPECT_TRUE(token.Equals(expected)); -} - TEST_F(AuthenticationTokenTest, TestEquals) { AuthenticationToken token1("same-token"); AuthenticationToken token2("same-token"); From d24f23c3bb3049b41c7046b14685ba02233217d6 Mon Sep 17 00:00:00 2001 From: sampan Date: Fri, 24 Oct 2025 09:36:31 +0000 Subject: [PATCH 06/48] address comments Signed-off-by: sampan --- src/ray/rpc/grpc_server.cc | 2 +- src/ray/rpc/server_call.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ray/rpc/grpc_server.cc b/src/ray/rpc/grpc_server.cc index e471bf7e39de..785d65be741b 100644 --- a/src/ray/rpc/grpc_server.cc +++ b/src/ray/rpc/grpc_server.cc @@ -181,7 +181,7 @@ void GrpcServer::RegisterService(std::unique_ptr &&grpc_service) void GrpcServer::RegisterService(std::unique_ptr &&service, bool cluster_id_auth_enabled) { if (cluster_id_auth_enabled && cluster_id_.IsNil()) { - RAY_LOG(FATAL) << "Expected cluster ID for token auth!"; + RAY_LOG(FATAL) << "Expected cluster ID for cluster ID authentication!"; } for (int i = 0; i < num_threads_; i++) { service->InitServerCallFactories( diff --git a/src/ray/rpc/server_call.h b/src/ray/rpc/server_call.h index b84ab4e22dc2..bb7a431934cf 100644 --- a/src/ray/rpc/server_call.h +++ b/src/ray/rpc/server_call.h @@ -211,7 +211,7 @@ class ServerCallImpl : public ServerCall { } // Cluster ID authentication - if (::RayConfig::instance().enable_cluster_auth()) { + if (auth_success && ::RayConfig::instance().enable_cluster_auth()) { if constexpr (EnableAuth == ClusterIdAuthType::LAZY_AUTH) { RAY_CHECK(!cluster_id_.IsNil()) << "Expected cluster ID in server call!"; auto &metadata = context_.client_metadata(); From f6017a0ab28d3a39213b58e5c3af65de7b8372dd Mon Sep 17 00:00:00 2001 From: sampan Date: Sun, 26 Oct 2025 05:51:19 +0000 Subject: [PATCH 07/48] [Core] token auth support in bidi-syncer and pubsub rpc Signed-off-by: sampan --- src/ray/common/ray_syncer/ray_syncer.cc | 3 +- src/ray/common/ray_syncer/ray_syncer.h | 7 +- .../common/ray_syncer/ray_syncer_client.cc | 7 + .../common/ray_syncer/ray_syncer_server.cc | 33 +++- src/ray/common/ray_syncer/ray_syncer_server.h | 9 +- src/ray/common/tests/ray_syncer_test.cc | 155 ++++++++++++++++++ src/ray/gcs/gcs_server.cc | 4 +- src/ray/pubsub/python_gcs_subscriber.cc | 11 ++ src/ray/raylet/node_manager.cc | 5 +- src/ray/rpc/grpc_server.h | 2 + 10 files changed, 228 insertions(+), 8 deletions(-) diff --git a/src/ray/common/ray_syncer/ray_syncer.cc b/src/ray/common/ray_syncer/ray_syncer.cc index 7991fdcd2c92..49d4ac9d7404 100644 --- a/src/ray/common/ray_syncer/ray_syncer.cc +++ b/src/ray/common/ray_syncer/ray_syncer.cc @@ -244,7 +244,8 @@ ServerBidiReactor *RaySyncerService::StartSync(grpc::CallbackServerContext *cont } RAY_LOG(INFO).WithField(NodeID::FromBinary(node_id)) << "Connection is broken."; syncer_.node_state_->RemoveNode(node_id); - }); + }, + /*auth_token=*/auth_token_); RAY_LOG(DEBUG).WithField(NodeID::FromBinary(reactor->GetRemoteNodeID())) << "Get connection"; // Disconnect exiting connection if there is any. diff --git a/src/ray/common/ray_syncer/ray_syncer.h b/src/ray/common/ray_syncer/ray_syncer.h index e00cdffe3f3c..d83ab65157d0 100644 --- a/src/ray/common/ray_syncer/ray_syncer.h +++ b/src/ray/common/ray_syncer/ray_syncer.h @@ -197,7 +197,10 @@ class RaySyncer { /// like tree-based one. class RaySyncerService : public ray::rpc::syncer::RaySyncer::CallbackService { public: - explicit RaySyncerService(RaySyncer &syncer) : syncer_(syncer) {} + explicit RaySyncerService( + RaySyncer &syncer, + std::optional auth_token = std::nullopt) + : syncer_(syncer), auth_token_(std::move(auth_token)) {} grpc::ServerBidiReactor *StartSync( grpc::CallbackServerContext *context) override; @@ -205,6 +208,8 @@ class RaySyncerService : public ray::rpc::syncer::RaySyncer::CallbackService { private: // The ray syncer this RPC wrappers of. RaySyncer &syncer_; + // Authentication token for validation, will be empty if token authentication is disabled + std::optional auth_token_; }; } // namespace ray::syncer diff --git a/src/ray/common/ray_syncer/ray_syncer_client.cc b/src/ray/common/ray_syncer/ray_syncer_client.cc index 935879bfa731..6945878e88f2 100644 --- a/src/ray/common/ray_syncer/ray_syncer_client.cc +++ b/src/ray/common/ray_syncer/ray_syncer_client.cc @@ -18,6 +18,8 @@ #include #include +#include "ray/rpc/authentication/authentication_token_loader.h" + namespace ray::syncer { RayClientBidiReactor::RayClientBidiReactor( @@ -32,6 +34,11 @@ RayClientBidiReactor::RayClientBidiReactor( cleanup_cb_(std::move(cleanup_cb)), stub_(std::move(stub)) { client_context_.AddMetadata("node_id", NodeID::FromBinary(local_node_id).Hex()); + // Add authentication token if token authentication is enabled + auto auth_token = ray::rpc::AuthenticationTokenLoader::instance().GetToken(); + if (auth_token.has_value() && !auth_token->empty()) { + auth_token->SetMetadata(client_context_); + } stub_->async()->StartSync(&client_context_, this); // Prevent this call from being terminated. // Check https://github.com/grpc/proposal/blob/master/L67-cpp-callback-api.md diff --git a/src/ray/common/ray_syncer/ray_syncer_server.cc b/src/ray/common/ray_syncer/ray_syncer_server.cc index 2dfc569fc494..65a6188d92f9 100644 --- a/src/ray/common/ray_syncer/ray_syncer_server.cc +++ b/src/ray/common/ray_syncer/ray_syncer_server.cc @@ -17,6 +17,8 @@ #include #include +#include "ray/common/constants.h" + namespace ray::syncer { namespace { @@ -35,13 +37,40 @@ RayServerBidiReactor::RayServerBidiReactor( instrumented_io_context &io_context, const std::string &local_node_id, std::function)> message_processor, - std::function cleanup_cb) + std::function cleanup_cb, + const std::optional &auth_token) : RaySyncerBidiReactorBase( io_context, GetNodeIDFromServerContext(server_context), std::move(message_processor)), cleanup_cb_(std::move(cleanup_cb)), - server_context_(server_context) { + server_context_(server_context), + auth_token_(auth_token) { + + if (auth_token_.has_value() && !auth_token_->empty()) { + // Validate authentication token + const auto &metadata = server_context->client_metadata(); + auto it = metadata.find(ray::rpc::kAuthTokenKey); + if (it == metadata.end()) { + RAY_LOG(WARNING) << "Missing authorization header in syncer connection from node " + << NodeID::FromBinary(GetRemoteNodeID()); + Finish(grpc::Status(grpc::StatusCode::UNAUTHENTICATED, + "Missing authorization header")); + return; + } + + const std::string_view header(it->second.data(), it->second.length()); + ray::rpc::AuthenticationToken provided_token = + ray::rpc::AuthenticationToken::FromMetadata(header); + + if (!auth_token_->Equals(provided_token)) { + RAY_LOG(WARNING) << "Invalid bearer token in syncer connection from node " + << NodeID::FromBinary(GetRemoteNodeID()); + Finish(grpc::Status(grpc::StatusCode::UNAUTHENTICATED, "Invalid bearer token")); + return; + } + } + // Send the local node id to the remote server_context_->AddInitialMetadata("node_id", NodeID::FromBinary(local_node_id).Hex()); StartSendInitialMetadata(); diff --git a/src/ray/common/ray_syncer/ray_syncer_server.h b/src/ray/common/ray_syncer/ray_syncer_server.h index ca548822da73..983dc3dfea9f 100644 --- a/src/ray/common/ray_syncer/ray_syncer_server.h +++ b/src/ray/common/ray_syncer/ray_syncer_server.h @@ -16,11 +16,13 @@ #include +#include #include #include "ray/common/ray_syncer/common.h" #include "ray/common/ray_syncer/ray_syncer_bidi_reactor.h" #include "ray/common/ray_syncer/ray_syncer_bidi_reactor_base.h" +#include "ray/rpc/authentication/authentication_token.h" namespace ray::syncer { @@ -35,7 +37,8 @@ class RayServerBidiReactor : public RaySyncerBidiReactorBase instrumented_io_context &io_context, const std::string &local_node_id, std::function)> message_processor, - std::function cleanup_cb); + std::function cleanup_cb, + const std::optional &auth_token); ~RayServerBidiReactor() override = default; @@ -49,6 +52,10 @@ class RayServerBidiReactor : public RaySyncerBidiReactorBase /// grpc callback context grpc::CallbackServerContext *server_context_; + + /// Authentication token for validation, will be empty if token authentication is disabled + std::optional auth_token_; + FRIEND_TEST(SyncerReactorTest, TestReactorFailure); }; diff --git a/src/ray/common/tests/ray_syncer_test.cc b/src/ray/common/tests/ray_syncer_test.cc index cb2b81579eb4..90ddd7f471cd 100644 --- a/src/ray/common/tests/ray_syncer_test.cc +++ b/src/ray/common/tests/ray_syncer_test.cc @@ -983,6 +983,161 @@ TEST_F(SyncerReactorTest, TestReactorFailure) { ASSERT_EQ(true, c_cleanup.second); } +// Authentication tests +class SyncerAuthenticationTest : public ::testing::Test { + protected: + void SetUp() override { + // Clear any existing environment variables + unsetenv("RAY_AUTH_TOKEN"); + ray::rpc::AuthenticationTokenLoader::instance().ResetCache(); + } + + void TearDown() override { + unsetenv("RAY_AUTH_TOKEN"); + ray::rpc::AuthenticationTokenLoader::instance().ResetCache(); + } + + std::unique_ptr CreateAuthenticatedServer(const std::string &port, + const std::string &token) { + auto node_id = NodeID::FromRandom(); + auto server = std::make_unique(port); + + // Recreate server with authentication token + server->server->Shutdown(); + server->server->Wait(); + server->service.reset(); + + // Create new service with authentication token + server->service = std::make_unique( + *server->syncer, ray::rpc::AuthenticationToken(token)); + + auto server_address = BuildAddress("0.0.0.0", port); + grpc::ServerBuilder builder; + builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); + builder.RegisterService(server->service.get()); + server->server = builder.BuildAndStart(); + + return server; + } +}; + +TEST_F(SyncerAuthenticationTest, MatchingTokens) { + // Test that connections succeed when client and server use the same token + const std::string test_token = "matching-test-token-12345"; + + // Set client token via environment variable + setenv("RAY_AUTH_TOKEN", test_token.c_str(), 1); + ray::rpc::AuthenticationTokenLoader::instance().ResetCache(); + + // Create authenticated server + auto server = CreateAuthenticatedServer("37892", test_token); + auto channel = grpc::CreateChannel(BuildAddress("0.0.0.0", "37892"), + grpc::InsecureChannelCredentials()); + auto syncer = + std::make_unique(server->io_context, NodeID::FromRandom().Binary()); + + // Should connect successfully with matching token + syncer->Connect(NodeID::FromRandom().Binary(), channel); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + // Verify connection is established + ASSERT_GT(syncer->GetAllConnectedNodeIDs().size(), 0); +} + +TEST_F(SyncerAuthenticationTest, MismatchedTokens) { + // Test that connections fail when client and server use different tokens + const std::string server_token = "server-token-12345"; + const std::string client_token = "different-client-token"; + + // Set client token via environment variable + setenv("RAY_AUTH_TOKEN", client_token.c_str(), 1); + ray::rpc::AuthenticationTokenLoader::instance().ResetCache(); + + // Create authenticated server with different token + auto server = CreateAuthenticatedServer("37893", server_token); + auto channel = grpc::CreateChannel(BuildAddress("0.0.0.0", "37893"), + grpc::InsecureChannelCredentials()); + auto syncer = + std::make_unique(server->io_context, NodeID::FromRandom().Binary()); + + // Should fail to connect with mismatched token + syncer->Connect(NodeID::FromRandom().Binary(), channel); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + // Verify connection fails - no connected nodes + ASSERT_EQ(syncer->GetAllConnectedNodeIDs().size(), 0); +} + +TEST_F(SyncerAuthenticationTest, ClientHasTokenServerDoesNot) { + // Test that connections fail when client has token but server doesn't require it + const std::string client_token = "client-token-12345"; + + // Set client token via environment variable + setenv("RAY_AUTH_TOKEN", client_token.c_str(), 1); + ray::rpc::AuthenticationTokenLoader::instance().ResetCache(); + + // Create server without authentication + auto server = std::make_unique("37894"); + auto channel = grpc::CreateChannel(BuildAddress("0.0.0.0", "37894"), + grpc::InsecureChannelCredentials()); + auto syncer = + std::make_unique(server->io_context, NodeID::FromRandom().Binary()); + + // Should connect successfully - server accepts any client when auth is not required + syncer->Connect(NodeID::FromRandom().Binary(), channel); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + // Verify connection is established + ASSERT_GT(syncer->GetAllConnectedNodeIDs().size(), 0); +} + +TEST_F(SyncerAuthenticationTest, ServerHasTokenClientDoesNot) { + // Test that connections fail when server requires token but client doesn't provide it + const std::string server_token = "server-token-12345"; + + // Client has no token + unsetenv("RAY_AUTH_TOKEN"); + ray::rpc::AuthenticationTokenLoader::instance().ResetCache(); + + // Create authenticated server + auto server = CreateAuthenticatedServer("37895", server_token); + auto channel = grpc::CreateChannel(BuildAddress("0.0.0.0", "37895"), + grpc::InsecureChannelCredentials()); + auto syncer = + std::make_unique(server->io_context, NodeID::FromRandom().Binary()); + + // Should fail to connect without token + syncer->Connect(NodeID::FromRandom().Binary(), channel); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + // Verify connection fails - no connected nodes + ASSERT_EQ(syncer->GetAllConnectedNodeIDs().size(), 0); +} + +TEST_F(SyncerAuthenticationTest, MismatchedTokens) { + // Test that connections fail when server requires token but client doesn't provide it + const std::string server_token = "server-token-12345"; + const std::string client_token = "different-client-token"; + + // Client has incorrect token + setenv("RAY_AUTH_TOKEN", client_token.c_str(), 1); + ray::rpc::AuthenticationTokenLoader::instance().ResetCache(); + + // Create authenticated server + auto server = CreateAuthenticatedServer("37895", server_token); + auto channel = grpc::CreateChannel(BuildAddress("0.0.0.0", "37895"), + grpc::InsecureChannelCredentials()); + auto syncer = + std::make_unique(server->io_context, NodeID::FromRandom().Binary()); + + // Should fail to connect without token + syncer->Connect(NodeID::FromRandom().Binary(), channel); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + // Verify connection fails - no connected nodes + ASSERT_EQ(syncer->GetAllConnectedNodeIDs().size(), 0); +} + } // namespace syncer } // namespace ray diff --git a/src/ray/gcs/gcs_server.cc b/src/ray/gcs/gcs_server.cc index 0f8043d35639..72f14b3a888f 100644 --- a/src/ray/gcs/gcs_server.cc +++ b/src/ray/gcs/gcs_server.cc @@ -614,7 +614,9 @@ void GcsServer::InitRaySyncer(const GcsInitData &gcs_init_data) { syncer::MessageType::RESOURCE_VIEW, nullptr, gcs_resource_manager_.get()); ray_syncer_->Register( syncer::MessageType::COMMANDS, nullptr, gcs_resource_manager_.get()); - rpc_server_.RegisterService(std::make_unique(*ray_syncer_)); + // Pass auth token from the RPC server to the syncer service + rpc_server_.RegisterService(std::make_unique( + *ray_syncer_, rpc_server_.GetAuthToken())); } void GcsServer::InitFunctionManager() { diff --git a/src/ray/pubsub/python_gcs_subscriber.cc b/src/ray/pubsub/python_gcs_subscriber.cc index 5d54c4c94a1b..c0ead5274e0c 100644 --- a/src/ray/pubsub/python_gcs_subscriber.cc +++ b/src/ray/pubsub/python_gcs_subscriber.cc @@ -22,6 +22,7 @@ #include #include "ray/gcs_rpc_client/rpc_client.h" +#include "ray/rpc/authentication/authentication_token_loader.h" namespace ray { namespace pubsub { @@ -51,6 +52,11 @@ Status PythonGcsSubscriber::Subscribe() { } grpc::ClientContext context; + // Add authentication token + auto auth_token = ray::rpc::AuthenticationTokenLoader::instance().GetToken(); + if (auth_token.has_value() && !auth_token->empty()) { + auth_token->SetMetadata(context); + } rpc::GcsSubscriberCommandBatchRequest request; request.set_subscriber_id(subscriber_id_); @@ -78,6 +84,11 @@ Status PythonGcsSubscriber::DoPoll(int64_t timeout_ms, rpc::PubMessage *message) return Status::OK(); } current_polling_context_ = std::make_shared(); + // Add authentication token + auto auth_token = ray::rpc::AuthenticationTokenLoader::instance().GetToken(); + if (auth_token.has_value() && !auth_token->empty()) { + auth_token->SetMetadata(*current_polling_context_); + } if (timeout_ms != -1) { current_polling_context_->set_deadline(std::chrono::system_clock::now() + std::chrono::milliseconds(timeout_ms)); diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index d9e30d76917a..5274f357b655 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -254,8 +254,9 @@ NodeManager::NodeManager( // Run the node manager rpc server. node_manager_server_.RegisterService( std::make_unique(io_service, *this), false); - node_manager_server_.RegisterService( - std::make_unique(ray_syncer_)); + // Pass auth token from the RPC server to the syncer service + node_manager_server_.RegisterService(std::make_unique( + ray_syncer_, node_manager_server_.GetAuthToken())); node_manager_server_.Run(); // GCS will check the health of the service named with the node id. // Fail to setup this will lead to the health check failure. diff --git a/src/ray/rpc/grpc_server.h b/src/ray/rpc/grpc_server.h index bf7eb7f8c5d1..df3070b01101 100644 --- a/src/ray/rpc/grpc_server.h +++ b/src/ray/rpc/grpc_server.h @@ -151,6 +151,8 @@ class GrpcServer { cluster_id_ = cluster_id; } + const std::optional &GetAuthToken() const { return auth_token_; } + protected: /// Initialize this server. void Init(); From b8bec0c11a4c1995f20b8e485ba7ec4b65de84e9 Mon Sep 17 00:00:00 2001 From: sampan Date: Sun, 26 Oct 2025 05:51:55 +0000 Subject: [PATCH 08/48] fix lint Signed-off-by: sampan --- src/ray/common/ray_syncer/ray_syncer.h | 3 ++- src/ray/common/ray_syncer/ray_syncer_server.cc | 1 - src/ray/common/ray_syncer/ray_syncer_server.h | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/ray/common/ray_syncer/ray_syncer.h b/src/ray/common/ray_syncer/ray_syncer.h index d83ab65157d0..823d693b6740 100644 --- a/src/ray/common/ray_syncer/ray_syncer.h +++ b/src/ray/common/ray_syncer/ray_syncer.h @@ -208,7 +208,8 @@ class RaySyncerService : public ray::rpc::syncer::RaySyncer::CallbackService { private: // The ray syncer this RPC wrappers of. RaySyncer &syncer_; - // Authentication token for validation, will be empty if token authentication is disabled + // Authentication token for validation, will be empty if token authentication is + // disabled std::optional auth_token_; }; diff --git a/src/ray/common/ray_syncer/ray_syncer_server.cc b/src/ray/common/ray_syncer/ray_syncer_server.cc index 65a6188d92f9..8fd28b0a2000 100644 --- a/src/ray/common/ray_syncer/ray_syncer_server.cc +++ b/src/ray/common/ray_syncer/ray_syncer_server.cc @@ -46,7 +46,6 @@ RayServerBidiReactor::RayServerBidiReactor( cleanup_cb_(std::move(cleanup_cb)), server_context_(server_context), auth_token_(auth_token) { - if (auth_token_.has_value() && !auth_token_->empty()) { // Validate authentication token const auto &metadata = server_context->client_metadata(); diff --git a/src/ray/common/ray_syncer/ray_syncer_server.h b/src/ray/common/ray_syncer/ray_syncer_server.h index 983dc3dfea9f..a570673efa62 100644 --- a/src/ray/common/ray_syncer/ray_syncer_server.h +++ b/src/ray/common/ray_syncer/ray_syncer_server.h @@ -53,7 +53,8 @@ class RayServerBidiReactor : public RaySyncerBidiReactorBase /// grpc callback context grpc::CallbackServerContext *server_context_; - /// Authentication token for validation, will be empty if token authentication is disabled + /// Authentication token for validation, will be empty if token authentication is + /// disabled std::optional auth_token_; FRIEND_TEST(SyncerReactorTest, TestReactorFailure); From 1ca6f2f39c37046f1c344d0b5356291a8217be48 Mon Sep 17 00:00:00 2001 From: sampan Date: Sun, 26 Oct 2025 05:55:29 +0000 Subject: [PATCH 09/48] add missing import Signed-off-by: sampan --- src/ray/common/ray_syncer/ray_syncer.h | 1 + 1 file changed, 1 insertion(+) diff --git a/src/ray/common/ray_syncer/ray_syncer.h b/src/ray/common/ray_syncer/ray_syncer.h index 823d693b6740..d935e7cb73b4 100644 --- a/src/ray/common/ray_syncer/ray_syncer.h +++ b/src/ray/common/ray_syncer/ray_syncer.h @@ -19,6 +19,7 @@ #include #include +#include #include #include "absl/container/flat_hash_map.h" From e9cc57f0878efec5b0db7d4a936fbd9076d8dda9 Mon Sep 17 00:00:00 2001 From: sampan Date: Sun, 26 Oct 2025 05:58:18 +0000 Subject: [PATCH 10/48] address comments Signed-off-by: sampan --- src/ray/rpc/grpc_server.cc | 4 +--- src/ray/rpc/server_call.h | 4 ++-- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/ray/rpc/grpc_server.cc b/src/ray/rpc/grpc_server.cc index 785d65be741b..daeebff99e28 100644 --- a/src/ray/rpc/grpc_server.cc +++ b/src/ray/rpc/grpc_server.cc @@ -180,9 +180,7 @@ void GrpcServer::RegisterService(std::unique_ptr &&grpc_service) void GrpcServer::RegisterService(std::unique_ptr &&service, bool cluster_id_auth_enabled) { - if (cluster_id_auth_enabled && cluster_id_.IsNil()) { - RAY_LOG(FATAL) << "Expected cluster ID for cluster ID authentication!"; - } + RAY_CHECK(cluster_id_auth_enabled && cluster_id_.IsNil()) << "Expected cluster ID for cluster ID authentication!"; for (int i = 0; i < num_threads_; i++) { service->InitServerCallFactories( cqs_[i], &server_call_factories_, cluster_id_, auth_token_); diff --git a/src/ray/rpc/server_call.h b/src/ray/rpc/server_call.h index bb7a431934cf..b691cc52fe09 100644 --- a/src/ray/rpc/server_call.h +++ b/src/ray/rpc/server_call.h @@ -205,7 +205,7 @@ class ServerCallImpl : public ServerCall { bool cluster_id_auth_failed = false; // Token authentication - if (!ValidateBearerToken()) { + if (!ValidateAuthenticationToken()) { auth_success = false; token_auth_failed = true; } @@ -341,7 +341,7 @@ class ServerCallImpl : public ServerCall { /// Validates token-based authentication. /// Returns true if authentication succeeds or is not required. /// Returns false if authentication is required but fails. - bool ValidateBearerToken() { + bool ValidateAuthenticationToken() { if (!auth_token_.has_value() || auth_token_->empty()) { return true; // No auth required } From f8c08e0c3018de9e1f685894f2a6d7ad21df4494 Mon Sep 17 00:00:00 2001 From: sampan Date: Sun, 26 Oct 2025 06:26:24 +0000 Subject: [PATCH 11/48] fix lint Signed-off-by: sampan --- src/ray/rpc/grpc_server.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/ray/rpc/grpc_server.cc b/src/ray/rpc/grpc_server.cc index daeebff99e28..781c2da790df 100644 --- a/src/ray/rpc/grpc_server.cc +++ b/src/ray/rpc/grpc_server.cc @@ -180,7 +180,8 @@ void GrpcServer::RegisterService(std::unique_ptr &&grpc_service) void GrpcServer::RegisterService(std::unique_ptr &&service, bool cluster_id_auth_enabled) { - RAY_CHECK(cluster_id_auth_enabled && cluster_id_.IsNil()) << "Expected cluster ID for cluster ID authentication!"; + RAY_CHECK(cluster_id_auth_enabled && cluster_id_.IsNil()) + << "Expected cluster ID for cluster ID authentication!"; for (int i = 0; i < num_threads_; i++) { service->InitServerCallFactories( cqs_[i], &server_call_factories_, cluster_id_, auth_token_); From a7a8efa42c3dfb7bd095f30b50a7eb17ae3d3801 Mon Sep 17 00:00:00 2001 From: sampan Date: Sun, 26 Oct 2025 06:55:55 +0000 Subject: [PATCH 12/48] fix ci Signed-off-by: sampan --- src/ray/common/grpc_util.h | 4 ++++ src/ray/rpc/grpc_server.cc | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/ray/common/grpc_util.h b/src/ray/common/grpc_util.h index ed2f8c73eda1..52858cca2207 100644 --- a/src/ray/common/grpc_util.h +++ b/src/ray/common/grpc_util.h @@ -110,6 +110,10 @@ inline Status GrpcStatusToRayStatus(const grpc::Status &grpc_status) { // status code. return {StatusCode::TimedOut, GrpcStatusToRayStatusMessage(grpc_status)}; } + if (grpc_status.error_code() == grpc::StatusCode::UNAUTHENTICATED) { + // UNAUTHENTICATED means authentication failed (e.g., wrong cluster ID). + return Status::Unauthenticated(GrpcStatusToRayStatusMessage(grpc_status)); + } if (grpc_status.error_code() == grpc::StatusCode::ABORTED) { // This is a status generated by ray code. // See RayStatusToGrpcStatus for details. diff --git a/src/ray/rpc/grpc_server.cc b/src/ray/rpc/grpc_server.cc index 781c2da790df..5809cc005783 100644 --- a/src/ray/rpc/grpc_server.cc +++ b/src/ray/rpc/grpc_server.cc @@ -180,7 +180,7 @@ void GrpcServer::RegisterService(std::unique_ptr &&grpc_service) void GrpcServer::RegisterService(std::unique_ptr &&service, bool cluster_id_auth_enabled) { - RAY_CHECK(cluster_id_auth_enabled && cluster_id_.IsNil()) + RAY_CHECK(!cluster_id_auth_enabled || !cluster_id_.IsNil()) << "Expected cluster ID for cluster ID authentication!"; for (int i = 0; i < num_threads_; i++) { service->InitServerCallFactories( From 5910ecfca9f961744c1747553782ea4b081c0a86 Mon Sep 17 00:00:00 2001 From: sampan Date: Mon, 27 Oct 2025 02:59:48 +0000 Subject: [PATCH 13/48] fix build.bazel and imports Signed-off-by: sampan --- src/ray/rpc/tests/BUILD.bazel | 17 +++++++++++++++++ src/ray/rpc/tests/grpc_auth_token_tests.cc | 8 ++++---- src/ray/rpc/tests/grpc_server_client_test.cc | 8 ++++---- 3 files changed, 25 insertions(+), 8 deletions(-) diff --git a/src/ray/rpc/tests/BUILD.bazel b/src/ray/rpc/tests/BUILD.bazel index d5113ae0d3aa..279b68f91ba3 100644 --- a/src/ray/rpc/tests/BUILD.bazel +++ b/src/ray/rpc/tests/BUILD.bazel @@ -18,6 +18,23 @@ ray_cc_test( size = "small", srcs = [ "grpc_server_client_test.cc", + "grpc_test_common.h", + ], + tags = ["team:core"], + deps = [ + "//src/ray/protobuf:test_service_cc_grpc", + "//src/ray/rpc:grpc_client", + "//src/ray/rpc:grpc_server", + "@com_google_googletest//:gtest_main", + ], +) + +ray_cc_test( + name = "grpc_auth_token_tests", + size = "small", + srcs = [ + "grpc_auth_token_tests.cc", + "grpc_test_common.h", ], tags = ["team:core"], deps = [ diff --git a/src/ray/rpc/tests/grpc_auth_token_tests.cc b/src/ray/rpc/tests/grpc_auth_token_tests.cc index 5501e6e568b0..1c55dfb511a5 100644 --- a/src/ray/rpc/tests/grpc_auth_token_tests.cc +++ b/src/ray/rpc/tests/grpc_auth_token_tests.cc @@ -19,11 +19,11 @@ #include #include "gtest/gtest.h" -#include "ray/protobuf/test_service.grpc.pb.h" #include "ray/rpc/authentication/authentication_token_loader.h" -#include "ray/rpc/grpc_client.h" -#include "ray/rpc/grpc_server.h" -#include "ray/rpc/tests/grpc_test_common.h" +#include "src/ray/protobuf/test_service.grpc.pb.h" +#include "src/ray/rpc/grpc_client.h" +#include "src/ray/rpc/grpc_server.h" +#include "src/ray/rpc/tests/grpc_test_common.h" namespace ray { namespace rpc { diff --git a/src/ray/rpc/tests/grpc_server_client_test.cc b/src/ray/rpc/tests/grpc_server_client_test.cc index 8bc6e8284493..07e87c1a2f44 100644 --- a/src/ray/rpc/tests/grpc_server_client_test.cc +++ b/src/ray/rpc/tests/grpc_server_client_test.cc @@ -17,10 +17,10 @@ #include #include "gtest/gtest.h" -#include "ray/protobuf/test_service.grpc.pb.h" -#include "ray/rpc/grpc_client.h" -#include "ray/rpc/grpc_server.h" -#include "ray/rpc/tests/grpc_test_common.h" +#include "src/ray/protobuf/test_service.grpc.pb.h" +#include "src/ray/rpc/grpc_client.h" +#include "src/ray/rpc/grpc_server.h" +#include "src/ray/rpc/tests/grpc_test_common.h" namespace ray { namespace rpc { From d36e22fd6c26814afda5c2ac665fb08d5badd4f0 Mon Sep 17 00:00:00 2001 From: sampan Date: Mon, 27 Oct 2025 03:05:22 +0000 Subject: [PATCH 14/48] fix lint Signed-off-by: sampan --- src/ray/rpc/tests/grpc_auth_token_tests.cc | 4 ++-- src/ray/rpc/tests/grpc_server_client_test.cc | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/ray/rpc/tests/grpc_auth_token_tests.cc b/src/ray/rpc/tests/grpc_auth_token_tests.cc index 1c55dfb511a5..5feaf5563add 100644 --- a/src/ray/rpc/tests/grpc_auth_token_tests.cc +++ b/src/ray/rpc/tests/grpc_auth_token_tests.cc @@ -20,9 +20,9 @@ #include "gtest/gtest.h" #include "ray/rpc/authentication/authentication_token_loader.h" +#include "ray/rpc/grpc_client.h" +#include "ray/rpc/grpc_server.h" #include "src/ray/protobuf/test_service.grpc.pb.h" -#include "src/ray/rpc/grpc_client.h" -#include "src/ray/rpc/grpc_server.h" #include "src/ray/rpc/tests/grpc_test_common.h" namespace ray { diff --git a/src/ray/rpc/tests/grpc_server_client_test.cc b/src/ray/rpc/tests/grpc_server_client_test.cc index 07e87c1a2f44..0e95fc9823a5 100644 --- a/src/ray/rpc/tests/grpc_server_client_test.cc +++ b/src/ray/rpc/tests/grpc_server_client_test.cc @@ -17,9 +17,9 @@ #include #include "gtest/gtest.h" +#include "ray/rpc/grpc_client.h" +#include "ray/rpc/grpc_server.h" #include "src/ray/protobuf/test_service.grpc.pb.h" -#include "src/ray/rpc/grpc_client.h" -#include "src/ray/rpc/grpc_server.h" #include "src/ray/rpc/tests/grpc_test_common.h" namespace ray { From 4063d743d06a8cc9a79acb2f3f02674c8ceea58b Mon Sep 17 00:00:00 2001 From: sampan Date: Mon, 27 Oct 2025 03:13:18 +0000 Subject: [PATCH 15/48] fix lint issues Signed-off-by: sampan --- src/ray/rpc/tests/BUILD.bazel | 17 ++++++++++++++--- src/ray/rpc/tests/grpc_auth_token_tests.cc | 2 +- src/ray/rpc/tests/grpc_server_client_test.cc | 2 +- 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/src/ray/rpc/tests/BUILD.bazel b/src/ray/rpc/tests/BUILD.bazel index 279b68f91ba3..6f12ac30f65e 100644 --- a/src/ray/rpc/tests/BUILD.bazel +++ b/src/ray/rpc/tests/BUILD.bazel @@ -1,4 +1,4 @@ -load("//bazel:ray.bzl", "ray_cc_test") +load("//bazel:ray.bzl", "ray_cc_library", "ray_cc_test") ray_cc_test( name = "rpc_chaos_test", @@ -13,15 +13,25 @@ ray_cc_test( ], ) +ray_cc_library( + name = "grpc_test_common", + testonly = True, + hdrs = ["grpc_test_common.h"], + deps = [ + "//src/ray/protobuf:test_service_cc_grpc", + "//src/ray/rpc:grpc_server", + ], +) + ray_cc_test( name = "grpc_server_client_test", size = "small", srcs = [ "grpc_server_client_test.cc", - "grpc_test_common.h", ], tags = ["team:core"], deps = [ + ":grpc_test_common", "//src/ray/protobuf:test_service_cc_grpc", "//src/ray/rpc:grpc_client", "//src/ray/rpc:grpc_server", @@ -34,13 +44,14 @@ ray_cc_test( size = "small", srcs = [ "grpc_auth_token_tests.cc", - "grpc_test_common.h", ], tags = ["team:core"], deps = [ + ":grpc_test_common", "//src/ray/protobuf:test_service_cc_grpc", "//src/ray/rpc:grpc_client", "//src/ray/rpc:grpc_server", + "//src/ray/rpc/authentication:authentication_token_loader", "@com_google_googletest//:gtest_main", ], ) diff --git a/src/ray/rpc/tests/grpc_auth_token_tests.cc b/src/ray/rpc/tests/grpc_auth_token_tests.cc index 5feaf5563add..4499b4c43129 100644 --- a/src/ray/rpc/tests/grpc_auth_token_tests.cc +++ b/src/ray/rpc/tests/grpc_auth_token_tests.cc @@ -22,8 +22,8 @@ #include "ray/rpc/authentication/authentication_token_loader.h" #include "ray/rpc/grpc_client.h" #include "ray/rpc/grpc_server.h" +#include "ray/rpc/tests/grpc_test_common.h" #include "src/ray/protobuf/test_service.grpc.pb.h" -#include "src/ray/rpc/tests/grpc_test_common.h" namespace ray { namespace rpc { diff --git a/src/ray/rpc/tests/grpc_server_client_test.cc b/src/ray/rpc/tests/grpc_server_client_test.cc index 0e95fc9823a5..09a168eac9b5 100644 --- a/src/ray/rpc/tests/grpc_server_client_test.cc +++ b/src/ray/rpc/tests/grpc_server_client_test.cc @@ -19,8 +19,8 @@ #include "gtest/gtest.h" #include "ray/rpc/grpc_client.h" #include "ray/rpc/grpc_server.h" +#include "ray/rpc/tests/grpc_test_common.h" #include "src/ray/protobuf/test_service.grpc.pb.h" -#include "src/ray/rpc/tests/grpc_test_common.h" namespace ray { namespace rpc { From 63273bdaf801ca4b9af1fcc6486628a13046a1cb Mon Sep 17 00:00:00 2001 From: sampan Date: Mon, 27 Oct 2025 05:44:06 +0000 Subject: [PATCH 16/48] fix ray_syncer tests Signed-off-by: sampan --- src/ray/common/BUILD.bazel | 3 + src/ray/common/ray_syncer/ray_syncer.h | 1 + .../common/ray_syncer/ray_syncer_server.cc | 2 +- src/ray/common/tests/BUILD.bazel | 1 + src/ray/common/tests/ray_syncer_test.cc | 184 +++++++++++------- 5 files changed, 120 insertions(+), 71 deletions(-) diff --git a/src/ray/common/BUILD.bazel b/src/ray/common/BUILD.bazel index df407ae16d52..4be4ed2999b0 100644 --- a/src/ray/common/BUILD.bazel +++ b/src/ray/common/BUILD.bazel @@ -353,6 +353,9 @@ ray_cc_library( ":asio", ":id", "//:ray_syncer_cc_grpc", + "//src/ray/common:constants", + "//src/ray/rpc/authentication:authentication_token", + "//src/ray/rpc/authentication:authentication_token_loader", "@com_github_grpc_grpc//:grpc++", "@com_google_absl//absl/container:flat_hash_map", ], diff --git a/src/ray/common/ray_syncer/ray_syncer.h b/src/ray/common/ray_syncer/ray_syncer.h index d935e7cb73b4..6b9953ca1fb6 100644 --- a/src/ray/common/ray_syncer/ray_syncer.h +++ b/src/ray/common/ray_syncer/ray_syncer.h @@ -29,6 +29,7 @@ #include "ray/common/asio/periodical_runner.h" #include "ray/common/id.h" #include "ray/common/ray_syncer/common.h" +#include "ray/rpc/authentication/authentication_token.h" #include "src/ray/protobuf/ray_syncer.grpc.pb.h" namespace ray::syncer { diff --git a/src/ray/common/ray_syncer/ray_syncer_server.cc b/src/ray/common/ray_syncer/ray_syncer_server.cc index 8fd28b0a2000..a7d465466ae2 100644 --- a/src/ray/common/ray_syncer/ray_syncer_server.cc +++ b/src/ray/common/ray_syncer/ray_syncer_server.cc @@ -49,7 +49,7 @@ RayServerBidiReactor::RayServerBidiReactor( if (auth_token_.has_value() && !auth_token_->empty()) { // Validate authentication token const auto &metadata = server_context->client_metadata(); - auto it = metadata.find(ray::rpc::kAuthTokenKey); + auto it = metadata.find(kAuthTokenKey); if (it == metadata.end()) { RAY_LOG(WARNING) << "Missing authorization header in syncer connection from node " << NodeID::FromBinary(GetRemoteNodeID()); diff --git a/src/ray/common/tests/BUILD.bazel b/src/ray/common/tests/BUILD.bazel index 840a4a2b4f5a..5388384f4eb4 100644 --- a/src/ray/common/tests/BUILD.bazel +++ b/src/ray/common/tests/BUILD.bazel @@ -13,6 +13,7 @@ ray_cc_test( "//:ray_mock_syncer", "//src/ray/common:ray_syncer", "//src/ray/rpc:grpc_server", + "//src/ray/rpc/authentication:authentication_token", "//src/ray/util:network_util", "//src/ray/util:path_utils", "//src/ray/util:raii", diff --git a/src/ray/common/tests/ray_syncer_test.cc b/src/ray/common/tests/ray_syncer_test.cc index 90ddd7f471cd..e5083df93a59 100644 --- a/src/ray/common/tests/ray_syncer_test.cc +++ b/src/ray/common/tests/ray_syncer_test.cc @@ -37,6 +37,7 @@ #include "ray/common/ray_syncer/ray_syncer.h" #include "ray/common/ray_syncer/ray_syncer_client.h" #include "ray/common/ray_syncer/ray_syncer_server.h" +#include "ray/rpc/authentication/authentication_token.h" #include "ray/rpc/grpc_server.h" #include "ray/util/network_util.h" #include "ray/util/path_utils.h" @@ -840,8 +841,12 @@ struct MockRaySyncerService : public ray::rpc::syncer::RaySyncer::CallbackServic io_context(_io_context) {} grpc::ServerBidiReactor *StartSync( grpc::CallbackServerContext *context) override { - reactor = new RayServerBidiReactor( - context, io_context, node_id.Binary(), message_processor, cleanup_cb); + reactor = new RayServerBidiReactor(context, + io_context, + node_id.Binary(), + message_processor, + cleanup_cb, + std::nullopt); return reactor; } @@ -987,38 +992,90 @@ TEST_F(SyncerReactorTest, TestReactorFailure) { class SyncerAuthenticationTest : public ::testing::Test { protected: void SetUp() override { - // Clear any existing environment variables + // Clear any existing environment variables and reset state unsetenv("RAY_AUTH_TOKEN"); ray::rpc::AuthenticationTokenLoader::instance().ResetCache(); + RayConfig::instance().auth_mode() = "disabled"; } void TearDown() override { unsetenv("RAY_AUTH_TOKEN"); ray::rpc::AuthenticationTokenLoader::instance().ResetCache(); + RayConfig::instance().auth_mode() = "disabled"; } - std::unique_ptr CreateAuthenticatedServer(const std::string &port, - const std::string &token) { - auto node_id = NodeID::FromRandom(); - auto server = std::make_unique(port); + struct AuthenticatedSyncerServerTest { + std::string server_port; + instrumented_io_context io_context; + boost::asio::executor_work_guard work_guard; + std::unique_ptr thread; + std::unique_ptr syncer; + std::unique_ptr service; + std::unique_ptr server; + + AuthenticatedSyncerServerTest(const std::string &port, const std::string &token) + : server_port(port), work_guard(io_context.get_executor()) { + // Setup syncer and grpc server + syncer = std::make_unique(io_context, NodeID::FromRandom().Binary()); + thread = std::make_unique([this] { io_context.run(); }); + + // Create service with authentication token + service = std::make_unique( + *syncer, + token.empty() ? std::nullopt + : std::make_optional(ray::rpc::AuthenticationToken(token))); + + auto server_address = BuildAddress("0.0.0.0", port); + grpc::ServerBuilder builder; + builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); + builder.RegisterService(service.get()); + server = builder.BuildAndStart(); + } - // Recreate server with authentication token - server->server->Shutdown(); - server->server->Wait(); - server->service.reset(); + ~AuthenticatedSyncerServerTest() { + server->Shutdown(); + server->Wait(); + work_guard.reset(); + io_context.stop(); + thread->join(); + } + }; - // Create new service with authentication token - server->service = std::make_unique( - *server->syncer, ray::rpc::AuthenticationToken(token)); + std::unique_ptr CreateAuthenticatedServer( + const std::string &port, const std::string &token) { + return std::make_unique(port, token); + } - auto server_address = BuildAddress("0.0.0.0", port); - grpc::ServerBuilder builder; - builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); - builder.RegisterService(server->service.get()); - server->server = builder.BuildAndStart(); + // Helper struct to manage client io_context and syncer + struct ClientSyncer { + instrumented_io_context io_context; + boost::asio::executor_work_guard work_guard; + std::thread thread; + std::unique_ptr syncer; + std::string remote_node_id; + + ClientSyncer() + : work_guard(boost::asio::make_work_guard(io_context.get_executor())), + thread([this]() { io_context.run(); }) { + syncer = std::make_unique(io_context, NodeID::FromRandom().Binary()); + remote_node_id = NodeID::FromRandom().Binary(); + } - return server; - } + ~ClientSyncer() { + if (syncer) { + syncer->Disconnect(remote_node_id); + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + syncer.reset(); + } + work_guard.reset(); + io_context.stop(); + thread.join(); + } + + void Connect(const std::shared_ptr &channel) { + syncer->Connect(remote_node_id, channel); + } + }; }; TEST_F(SyncerAuthenticationTest, MatchingTokens) { @@ -1027,21 +1084,24 @@ TEST_F(SyncerAuthenticationTest, MatchingTokens) { // Set client token via environment variable setenv("RAY_AUTH_TOKEN", test_token.c_str(), 1); + // Enable token authentication + RayConfig::instance().auth_mode() = "token"; ray::rpc::AuthenticationTokenLoader::instance().ResetCache(); // Create authenticated server auto server = CreateAuthenticatedServer("37892", test_token); + + // Create client with separate io_context + ClientSyncer client; auto channel = grpc::CreateChannel(BuildAddress("0.0.0.0", "37892"), grpc::InsecureChannelCredentials()); - auto syncer = - std::make_unique(server->io_context, NodeID::FromRandom().Binary()); // Should connect successfully with matching token - syncer->Connect(NodeID::FromRandom().Binary(), channel); + client.Connect(channel); std::this_thread::sleep_for(std::chrono::milliseconds(100)); // Verify connection is established - ASSERT_GT(syncer->GetAllConnectedNodeIDs().size(), 0); + ASSERT_GT(client.syncer->GetAllConnectedNodeIDs().size(), 0); } TEST_F(SyncerAuthenticationTest, MismatchedTokens) { @@ -1051,91 +1111,75 @@ TEST_F(SyncerAuthenticationTest, MismatchedTokens) { // Set client token via environment variable setenv("RAY_AUTH_TOKEN", client_token.c_str(), 1); + // Enable token authentication + RayConfig::instance().auth_mode() = "token"; ray::rpc::AuthenticationTokenLoader::instance().ResetCache(); // Create authenticated server with different token auto server = CreateAuthenticatedServer("37893", server_token); + + // Create client with separate io_context + ClientSyncer client; auto channel = grpc::CreateChannel(BuildAddress("0.0.0.0", "37893"), grpc::InsecureChannelCredentials()); - auto syncer = - std::make_unique(server->io_context, NodeID::FromRandom().Binary()); // Should fail to connect with mismatched token - syncer->Connect(NodeID::FromRandom().Binary(), channel); + client.Connect(channel); std::this_thread::sleep_for(std::chrono::milliseconds(100)); // Verify connection fails - no connected nodes - ASSERT_EQ(syncer->GetAllConnectedNodeIDs().size(), 0); -} - -TEST_F(SyncerAuthenticationTest, ClientHasTokenServerDoesNot) { - // Test that connections fail when client has token but server doesn't require it - const std::string client_token = "client-token-12345"; - - // Set client token via environment variable - setenv("RAY_AUTH_TOKEN", client_token.c_str(), 1); - ray::rpc::AuthenticationTokenLoader::instance().ResetCache(); - - // Create server without authentication - auto server = std::make_unique("37894"); - auto channel = grpc::CreateChannel(BuildAddress("0.0.0.0", "37894"), - grpc::InsecureChannelCredentials()); - auto syncer = - std::make_unique(server->io_context, NodeID::FromRandom().Binary()); - - // Should connect successfully - server accepts any client when auth is not required - syncer->Connect(NodeID::FromRandom().Binary(), channel); - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - - // Verify connection is established - ASSERT_GT(syncer->GetAllConnectedNodeIDs().size(), 0); + ASSERT_EQ(client.syncer->GetAllConnectedNodeIDs().size(), 0); } TEST_F(SyncerAuthenticationTest, ServerHasTokenClientDoesNot) { // Test that connections fail when server requires token but client doesn't provide it const std::string server_token = "server-token-12345"; - // Client has no token + // Client has no token - auth mode is disabled (default from SetUp) unsetenv("RAY_AUTH_TOKEN"); ray::rpc::AuthenticationTokenLoader::instance().ResetCache(); // Create authenticated server auto server = CreateAuthenticatedServer("37895", server_token); + + // Create client with separate io_context + ClientSyncer client; auto channel = grpc::CreateChannel(BuildAddress("0.0.0.0", "37895"), grpc::InsecureChannelCredentials()); - auto syncer = - std::make_unique(server->io_context, NodeID::FromRandom().Binary()); // Should fail to connect without token - syncer->Connect(NodeID::FromRandom().Binary(), channel); + client.Connect(channel); std::this_thread::sleep_for(std::chrono::milliseconds(100)); // Verify connection fails - no connected nodes - ASSERT_EQ(syncer->GetAllConnectedNodeIDs().size(), 0); + ASSERT_EQ(client.syncer->GetAllConnectedNodeIDs().size(), 0); } -TEST_F(SyncerAuthenticationTest, MismatchedTokens) { - // Test that connections fail when server requires token but client doesn't provide it - const std::string server_token = "server-token-12345"; +TEST_F(SyncerAuthenticationTest, ClientHasTokenServerDoesNotRequire) { + // Test that connections succeed when client has token but server doesn't require it + const std::string server_token = ""; const std::string client_token = "different-client-token"; - // Client has incorrect token + // Set client token setenv("RAY_AUTH_TOKEN", client_token.c_str(), 1); + // Enable token authentication + RayConfig::instance().auth_mode() = "token"; ray::rpc::AuthenticationTokenLoader::instance().ResetCache(); - // Create authenticated server - auto server = CreateAuthenticatedServer("37895", server_token); - auto channel = grpc::CreateChannel(BuildAddress("0.0.0.0", "37895"), + // Create server without authentication (empty token) + auto server = CreateAuthenticatedServer("37896", server_token); + + // Create client with separate io_context + ClientSyncer client; + auto channel = grpc::CreateChannel(BuildAddress("0.0.0.0", "37896"), grpc::InsecureChannelCredentials()); - auto syncer = - std::make_unique(server->io_context, NodeID::FromRandom().Binary()); - // Should fail to connect without token - syncer->Connect(NodeID::FromRandom().Binary(), channel); + // Should connect successfully - server accepts any client when auth is not required + client.Connect(channel); std::this_thread::sleep_for(std::chrono::milliseconds(100)); - // Verify connection fails - no connected nodes - ASSERT_EQ(syncer->GetAllConnectedNodeIDs().size(), 0); + // Verify connection is established + ASSERT_GT(client.syncer->GetAllConnectedNodeIDs().size(), 0); } } // namespace syncer From 12c7c04a1dc7c7509bd72c45c520a6565800080d Mon Sep 17 00:00:00 2001 From: sampan Date: Mon, 27 Oct 2025 06:20:43 +0000 Subject: [PATCH 17/48] add pub-sub test Signed-off-by: sampan --- src/ray/gcs/BUILD.bazel | 1 + src/ray/gcs/gcs_server.cc | 4 +- src/ray/pubsub/tests/BUILD.bazel | 17 + .../tests/python_gcs_subscriber_auth_test.cc | 336 ++++++++++++++++++ src/ray/rpc/grpc_server.h | 2 - 5 files changed, 356 insertions(+), 4 deletions(-) create mode 100644 src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc diff --git a/src/ray/gcs/BUILD.bazel b/src/ray/gcs/BUILD.bazel index b764ad410db6..db4a0b77614c 100644 --- a/src/ray/gcs/BUILD.bazel +++ b/src/ray/gcs/BUILD.bazel @@ -513,6 +513,7 @@ ray_cc_library( ":grpc_service_interfaces", ":grpc_services", ":metrics", + "//src/ray/authentication:authentication_token_loader", "//src/ray/core_worker_rpc_client:core_worker_client", "//src/ray/core_worker_rpc_client:core_worker_client_pool", "//src/ray/gcs/store_client", diff --git a/src/ray/gcs/gcs_server.cc b/src/ray/gcs/gcs_server.cc index 72f14b3a888f..b4392dbf9c73 100644 --- a/src/ray/gcs/gcs_server.cc +++ b/src/ray/gcs/gcs_server.cc @@ -39,6 +39,7 @@ #include "ray/observability/metric_constants.h" #include "ray/pubsub/publisher.h" #include "ray/raylet_rpc_client/raylet_client.h" +#include "ray/rpc/authentication/authentication_token_loader.h" #include "ray/stats/stats.h" #include "ray/util/network_util.h" @@ -614,9 +615,8 @@ void GcsServer::InitRaySyncer(const GcsInitData &gcs_init_data) { syncer::MessageType::RESOURCE_VIEW, nullptr, gcs_resource_manager_.get()); ray_syncer_->Register( syncer::MessageType::COMMANDS, nullptr, gcs_resource_manager_.get()); - // Pass auth token from the RPC server to the syncer service rpc_server_.RegisterService(std::make_unique( - *ray_syncer_, rpc_server_.GetAuthToken())); + *ray_syncer_, ray::rpc::AuthenticationTokenLoader::instance().GetToken())); } void GcsServer::InitFunctionManager() { diff --git a/src/ray/pubsub/tests/BUILD.bazel b/src/ray/pubsub/tests/BUILD.bazel index fc1d17ffda7b..395bb7c8b240 100644 --- a/src/ray/pubsub/tests/BUILD.bazel +++ b/src/ray/pubsub/tests/BUILD.bazel @@ -40,3 +40,20 @@ ray_cc_test( "@com_google_googletest//:gtest_main", ], ) + +ray_cc_test( + name = "python_gcs_subscriber_auth_test", + size = "small", + srcs = ["python_gcs_subscriber_auth_test.cc"], + tags = ["team:core"], + deps = [ + "//src/ray/common:ray_config", + "//src/ray/common:status", + "//src/ray/protobuf:gcs_service_cc_grpc", + "//src/ray/pubsub:python_gcs_subscriber", + "//src/ray/rpc:grpc_server", + "//src/ray/rpc/authentication:authentication_token", + "//src/ray/rpc/authentication:authentication_token_loader", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc b/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc new file mode 100644 index 000000000000..a96f252ea81b --- /dev/null +++ b/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc @@ -0,0 +1,336 @@ +// Copyright 2025 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "gtest/gtest.h" +#include "ray/common/ray_config.h" +#include "ray/common/status.h" +#include "ray/pubsub/python_gcs_subscriber.h" +#include "ray/rpc/authentication/authentication_token.h" +#include "ray/rpc/authentication/authentication_token_loader.h" +#include "ray/rpc/grpc_server.h" +#include "src/ray/protobuf/gcs_service.grpc.pb.h" + +namespace ray { +namespace pubsub { + +// Mock implementation of InternalPubSubGcsService for testing authentication +class MockInternalPubSubGcsService final : public rpc::InternalPubSubGcsService::Service { + public: + explicit MockInternalPubSubGcsService(bool should_accept_requests) + : should_accept_requests_(should_accept_requests) {} + + grpc::Status GcsSubscriberCommandBatch( + grpc::ServerContext *context, + const rpc::GcsSubscriberCommandBatchRequest *request, + rpc::GcsSubscriberCommandBatchReply *reply) override { + if (should_accept_requests_) { + subscribe_count_++; + return grpc::Status::OK; + } else { + return grpc::Status(grpc::StatusCode::UNAUTHENTICATED, "Authentication failed"); + } + } + + grpc::Status GcsSubscriberPoll(grpc::ServerContext *context, + const rpc::GcsSubscriberPollRequest *request, + rpc::GcsSubscriberPollReply *reply) override { + if (should_accept_requests_) { + poll_count_++; + // Return empty response with publisher_id + reply->set_publisher_id("test-publisher"); + return grpc::Status::OK; + } else { + return grpc::Status(grpc::StatusCode::UNAUTHENTICATED, "Authentication failed"); + } + } + + int subscribe_count() const { return subscribe_count_; } + int poll_count() const { return poll_count_; } + + private: + bool should_accept_requests_; + std::atomic subscribe_count_{0}; + std::atomic poll_count_{0}; +}; + +class PythonGcsSubscriberAuthTest : public ::testing::Test { + protected: + void SetUp() override { + // Enable token authentication by default + RayConfig::instance().initialize(R"({"auth_mode": "token"})"); + rpc::AuthenticationTokenLoader::instance().ResetCache(); + } + + void TearDown() override { + if (server_) { + server_->Shutdown(); + server_.reset(); + } + unsetenv("RAY_AUTH_TOKEN"); + // Reset to default auth mode + RayConfig::instance().initialize(R"({"auth_mode": "disabled"})"); + rpc::AuthenticationTokenLoader::instance().ResetCache(); + } + + // Start a GCS server with optional authentication token + void StartServer(const std::string &server_token, bool should_accept_requests = true) { + auto mock_service = + std::make_unique(should_accept_requests); + mock_service_ptr_ = mock_service.get(); + + std::optional auth_token; + if (!server_token.empty()) { + auth_token = rpc::AuthenticationToken(server_token); + } else { + // Empty token means no auth required + auth_token = std::nullopt; + } + + server_ = std::make_unique("test-gcs-server", + 0, // Random port + true, + 1, + 7200000, + auth_token); + + server_->RegisterService(std::move(mock_service)); + server_->Run(); + + // Wait for server to start + while (server_->GetPort() == 0) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + server_port_ = server_->GetPort(); + } + + // Set client authentication token via environment variable + void SetClientToken(const std::string &client_token) { + if (!client_token.empty()) { + setenv("RAY_AUTH_TOKEN", client_token.c_str(), 1); + RayConfig::instance().initialize(R"({"auth_mode": "token"})"); + } else { + unsetenv("RAY_AUTH_TOKEN"); + RayConfig::instance().initialize(R"({"auth_mode": "disabled"})"); + } + rpc::AuthenticationTokenLoader::instance().ResetCache(); + } + + std::unique_ptr CreateSubscriber() { + return std::make_unique("127.0.0.1", + server_port_, + rpc::ChannelType::RAY_LOG_CHANNEL, + "test-subscriber-id", + "test-worker-id"); + } + + std::unique_ptr server_; + MockInternalPubSubGcsService *mock_service_ptr_ = nullptr; + int server_port_ = 0; +}; + +TEST_F(PythonGcsSubscriberAuthTest, MatchingTokens) { + // Test that subscription succeeds when client and server use the same token + const std::string test_token = "matching-test-token-12345"; + + StartServer(test_token); + SetClientToken(test_token); + + auto subscriber = CreateSubscriber(); + Status status = subscriber->Subscribe(); + + ASSERT_TRUE(status.ok()) << "Subscribe should succeed with matching tokens: " + << status.ToString(); + EXPECT_EQ(mock_service_ptr_->subscribe_count(), 1); + + subscriber->Close(); +} + +TEST_F(PythonGcsSubscriberAuthTest, MismatchedTokens) { + // Test that subscription fails when client and server use different tokens + const std::string server_token = "server-token-12345"; + const std::string client_token = "wrong-client-token-67890"; + + StartServer(server_token, false); // Server will reject requests + SetClientToken(client_token); + + auto subscriber = CreateSubscriber(); + Status status = subscriber->Subscribe(); + + ASSERT_FALSE(status.ok()) << "Subscribe should fail with mismatched tokens"; + EXPECT_TRUE(status.IsRpcError()) << "Status should be RpcError"; + + subscriber->Close(); +} + +TEST_F(PythonGcsSubscriberAuthTest, ClientTokenServerNoAuth) { + // Test that subscription succeeds when client provides token but server doesn't require + // it + const std::string client_token = "client-token-12345"; + + StartServer(""); // Server doesn't require auth + SetClientToken(client_token); + + auto subscriber = CreateSubscriber(); + Status status = subscriber->Subscribe(); + + ASSERT_TRUE(status.ok()) + << "Subscribe should succeed when server doesn't require auth: " + << status.ToString(); + EXPECT_EQ(mock_service_ptr_->subscribe_count(), 1); + + subscriber->Close(); +} + +TEST_F(PythonGcsSubscriberAuthTest, ServerTokenClientNoAuth) { + // Test that subscription fails when server requires token but client doesn't provide it + const std::string server_token = "server-token-12345"; + + StartServer(server_token, false); // Server will reject requests without valid token + SetClientToken(""); // Client doesn't provide token + + auto subscriber = CreateSubscriber(); + Status status = subscriber->Subscribe(); + + ASSERT_FALSE(status.ok()) + << "Subscribe should fail when server requires token but client doesn't provide it"; + EXPECT_TRUE(status.IsRpcError()) << "Status should be RpcError"; + + subscriber->Close(); +} + +TEST_F(PythonGcsSubscriberAuthTest, MatchingTokensPoll) { + // Test that polling succeeds when client and server use the same token + const std::string test_token = "matching-test-token-12345"; + + StartServer(test_token); + SetClientToken(test_token); + + auto subscriber = CreateSubscriber(); + Status status = subscriber->Subscribe(); + ASSERT_TRUE(status.ok()) << "Subscribe should succeed: " << status.ToString(); + + // Test polling with matching tokens - use very short timeout to avoid blocking + std::string key_id; + rpc::LogBatch log_batch; + status = subscriber->PollLogs(&key_id, 10, &log_batch); + + // Poll should succeed (returns OK even on timeout or when no messages available) + ASSERT_TRUE(status.ok()) << "Poll should succeed with matching tokens: " + << status.ToString(); + // At least one poll should have been made + EXPECT_GE(mock_service_ptr_->poll_count(), 1); + + subscriber->Close(); +} + +TEST_F(PythonGcsSubscriberAuthTest, MismatchedTokensPoll) { + // Test that polling fails when tokens don't match + const std::string server_token = "server-token-12345"; + const std::string client_token = "wrong-client-token-67890"; + + StartServer(server_token, false); // Server will reject requests + SetClientToken(client_token); + + auto subscriber = CreateSubscriber(); + + // Subscribe will fail, but let's try anyway + subscriber->Subscribe(); + + // Test polling with mismatched tokens - use very short timeout + std::string key_id; + rpc::LogBatch log_batch; + Status status = subscriber->PollLogs(&key_id, 10, &log_batch); + + // Poll should fail with auth error or return OK if it was cancelled + // (OK is acceptable because the subscriber may have been closed) + if (!status.ok()) { + EXPECT_TRUE(status.IsInvalid() || status.IsRpcError()) + << "Status should be Invalid or RpcError: " << status.ToString(); + } + + subscriber->Close(); +} + +TEST_F(PythonGcsSubscriberAuthTest, MatchingTokensClose) { + // Test that closing/unsubscribing succeeds with matching tokens + const std::string test_token = "matching-test-token-12345"; + + StartServer(test_token); + SetClientToken(test_token); + + auto subscriber = CreateSubscriber(); + Status status = subscriber->Subscribe(); + ASSERT_TRUE(status.ok()) << "Subscribe should succeed: " << status.ToString(); + + // Close should succeed with matching tokens + status = subscriber->Close(); + ASSERT_TRUE(status.ok()) << "Close should succeed with matching tokens: " + << status.ToString(); +} + +TEST_F(PythonGcsSubscriberAuthTest, NoAuthRequired) { + // Test that everything works when neither client nor server use auth + StartServer(""); // Server doesn't require auth + SetClientToken(""); // Client doesn't provide token + + auto subscriber = CreateSubscriber(); + Status status = subscriber->Subscribe(); + + ASSERT_TRUE(status.ok()) << "Subscribe should succeed without auth: " + << status.ToString(); + EXPECT_EQ(mock_service_ptr_->subscribe_count(), 1); + + // Test polling without auth - use very short timeout + std::string key_id; + rpc::LogBatch log_batch; + status = subscriber->PollLogs(&key_id, 10, &log_batch); + ASSERT_TRUE(status.ok()) << "Poll should succeed without auth: " << status.ToString(); + + // Test close without auth + status = subscriber->Close(); + ASSERT_TRUE(status.ok()) << "Close should succeed without auth: " << status.ToString(); +} + +TEST_F(PythonGcsSubscriberAuthTest, MultipleSubscribersMatchingTokens) { + // Test multiple subscribers with the same token + const std::string test_token = "shared-token-12345"; + + StartServer(test_token); + SetClientToken(test_token); + + auto subscriber1 = CreateSubscriber(); + auto subscriber2 = CreateSubscriber(); + + Status status1 = subscriber1->Subscribe(); + Status status2 = subscriber2->Subscribe(); + + ASSERT_TRUE(status1.ok()) << "First subscriber should succeed: " << status1.ToString(); + ASSERT_TRUE(status2.ok()) << "Second subscriber should succeed: " << status2.ToString(); + EXPECT_EQ(mock_service_ptr_->subscribe_count(), 2); + + subscriber1->Close(); + subscriber2->Close(); +} + +} // namespace pubsub +} // namespace ray + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/src/ray/rpc/grpc_server.h b/src/ray/rpc/grpc_server.h index df3070b01101..bf7eb7f8c5d1 100644 --- a/src/ray/rpc/grpc_server.h +++ b/src/ray/rpc/grpc_server.h @@ -151,8 +151,6 @@ class GrpcServer { cluster_id_ = cluster_id; } - const std::optional &GetAuthToken() const { return auth_token_; } - protected: /// Initialize this server. void Init(); From ae9345b9edc2a023df3f738633297c3a3bd2486b Mon Sep 17 00:00:00 2001 From: sampan Date: Mon, 27 Oct 2025 06:22:25 +0000 Subject: [PATCH 18/48] fix lint Signed-off-by: sampan --- src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc b/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc index a96f252ea81b..87780ca50f60 100644 --- a/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc +++ b/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc @@ -15,6 +15,7 @@ #include #include #include +#include #include "gtest/gtest.h" #include "ray/common/ray_config.h" From 9ac5effc7fd591d9c0f2c5d5be9b52017b985546 Mon Sep 17 00:00:00 2001 From: sampan Date: Mon, 27 Oct 2025 06:34:30 +0000 Subject: [PATCH 19/48] [Core] Support token auth in ray Pub-Sub Signed-off-by: sampan --- src/ray/pubsub/python_gcs_subscriber.cc | 11 + src/ray/pubsub/tests/BUILD.bazel | 17 + .../tests/python_gcs_subscriber_auth_test.cc | 337 ++++++++++++++++++ 3 files changed, 365 insertions(+) create mode 100644 src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc diff --git a/src/ray/pubsub/python_gcs_subscriber.cc b/src/ray/pubsub/python_gcs_subscriber.cc index 5d54c4c94a1b..c0ead5274e0c 100644 --- a/src/ray/pubsub/python_gcs_subscriber.cc +++ b/src/ray/pubsub/python_gcs_subscriber.cc @@ -22,6 +22,7 @@ #include #include "ray/gcs_rpc_client/rpc_client.h" +#include "ray/rpc/authentication/authentication_token_loader.h" namespace ray { namespace pubsub { @@ -51,6 +52,11 @@ Status PythonGcsSubscriber::Subscribe() { } grpc::ClientContext context; + // Add authentication token + auto auth_token = ray::rpc::AuthenticationTokenLoader::instance().GetToken(); + if (auth_token.has_value() && !auth_token->empty()) { + auth_token->SetMetadata(context); + } rpc::GcsSubscriberCommandBatchRequest request; request.set_subscriber_id(subscriber_id_); @@ -78,6 +84,11 @@ Status PythonGcsSubscriber::DoPoll(int64_t timeout_ms, rpc::PubMessage *message) return Status::OK(); } current_polling_context_ = std::make_shared(); + // Add authentication token + auto auth_token = ray::rpc::AuthenticationTokenLoader::instance().GetToken(); + if (auth_token.has_value() && !auth_token->empty()) { + auth_token->SetMetadata(*current_polling_context_); + } if (timeout_ms != -1) { current_polling_context_->set_deadline(std::chrono::system_clock::now() + std::chrono::milliseconds(timeout_ms)); diff --git a/src/ray/pubsub/tests/BUILD.bazel b/src/ray/pubsub/tests/BUILD.bazel index fc1d17ffda7b..395bb7c8b240 100644 --- a/src/ray/pubsub/tests/BUILD.bazel +++ b/src/ray/pubsub/tests/BUILD.bazel @@ -40,3 +40,20 @@ ray_cc_test( "@com_google_googletest//:gtest_main", ], ) + +ray_cc_test( + name = "python_gcs_subscriber_auth_test", + size = "small", + srcs = ["python_gcs_subscriber_auth_test.cc"], + tags = ["team:core"], + deps = [ + "//src/ray/common:ray_config", + "//src/ray/common:status", + "//src/ray/protobuf:gcs_service_cc_grpc", + "//src/ray/pubsub:python_gcs_subscriber", + "//src/ray/rpc:grpc_server", + "//src/ray/rpc/authentication:authentication_token", + "//src/ray/rpc/authentication:authentication_token_loader", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc b/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc new file mode 100644 index 000000000000..87780ca50f60 --- /dev/null +++ b/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc @@ -0,0 +1,337 @@ +// Copyright 2025 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "gtest/gtest.h" +#include "ray/common/ray_config.h" +#include "ray/common/status.h" +#include "ray/pubsub/python_gcs_subscriber.h" +#include "ray/rpc/authentication/authentication_token.h" +#include "ray/rpc/authentication/authentication_token_loader.h" +#include "ray/rpc/grpc_server.h" +#include "src/ray/protobuf/gcs_service.grpc.pb.h" + +namespace ray { +namespace pubsub { + +// Mock implementation of InternalPubSubGcsService for testing authentication +class MockInternalPubSubGcsService final : public rpc::InternalPubSubGcsService::Service { + public: + explicit MockInternalPubSubGcsService(bool should_accept_requests) + : should_accept_requests_(should_accept_requests) {} + + grpc::Status GcsSubscriberCommandBatch( + grpc::ServerContext *context, + const rpc::GcsSubscriberCommandBatchRequest *request, + rpc::GcsSubscriberCommandBatchReply *reply) override { + if (should_accept_requests_) { + subscribe_count_++; + return grpc::Status::OK; + } else { + return grpc::Status(grpc::StatusCode::UNAUTHENTICATED, "Authentication failed"); + } + } + + grpc::Status GcsSubscriberPoll(grpc::ServerContext *context, + const rpc::GcsSubscriberPollRequest *request, + rpc::GcsSubscriberPollReply *reply) override { + if (should_accept_requests_) { + poll_count_++; + // Return empty response with publisher_id + reply->set_publisher_id("test-publisher"); + return grpc::Status::OK; + } else { + return grpc::Status(grpc::StatusCode::UNAUTHENTICATED, "Authentication failed"); + } + } + + int subscribe_count() const { return subscribe_count_; } + int poll_count() const { return poll_count_; } + + private: + bool should_accept_requests_; + std::atomic subscribe_count_{0}; + std::atomic poll_count_{0}; +}; + +class PythonGcsSubscriberAuthTest : public ::testing::Test { + protected: + void SetUp() override { + // Enable token authentication by default + RayConfig::instance().initialize(R"({"auth_mode": "token"})"); + rpc::AuthenticationTokenLoader::instance().ResetCache(); + } + + void TearDown() override { + if (server_) { + server_->Shutdown(); + server_.reset(); + } + unsetenv("RAY_AUTH_TOKEN"); + // Reset to default auth mode + RayConfig::instance().initialize(R"({"auth_mode": "disabled"})"); + rpc::AuthenticationTokenLoader::instance().ResetCache(); + } + + // Start a GCS server with optional authentication token + void StartServer(const std::string &server_token, bool should_accept_requests = true) { + auto mock_service = + std::make_unique(should_accept_requests); + mock_service_ptr_ = mock_service.get(); + + std::optional auth_token; + if (!server_token.empty()) { + auth_token = rpc::AuthenticationToken(server_token); + } else { + // Empty token means no auth required + auth_token = std::nullopt; + } + + server_ = std::make_unique("test-gcs-server", + 0, // Random port + true, + 1, + 7200000, + auth_token); + + server_->RegisterService(std::move(mock_service)); + server_->Run(); + + // Wait for server to start + while (server_->GetPort() == 0) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + server_port_ = server_->GetPort(); + } + + // Set client authentication token via environment variable + void SetClientToken(const std::string &client_token) { + if (!client_token.empty()) { + setenv("RAY_AUTH_TOKEN", client_token.c_str(), 1); + RayConfig::instance().initialize(R"({"auth_mode": "token"})"); + } else { + unsetenv("RAY_AUTH_TOKEN"); + RayConfig::instance().initialize(R"({"auth_mode": "disabled"})"); + } + rpc::AuthenticationTokenLoader::instance().ResetCache(); + } + + std::unique_ptr CreateSubscriber() { + return std::make_unique("127.0.0.1", + server_port_, + rpc::ChannelType::RAY_LOG_CHANNEL, + "test-subscriber-id", + "test-worker-id"); + } + + std::unique_ptr server_; + MockInternalPubSubGcsService *mock_service_ptr_ = nullptr; + int server_port_ = 0; +}; + +TEST_F(PythonGcsSubscriberAuthTest, MatchingTokens) { + // Test that subscription succeeds when client and server use the same token + const std::string test_token = "matching-test-token-12345"; + + StartServer(test_token); + SetClientToken(test_token); + + auto subscriber = CreateSubscriber(); + Status status = subscriber->Subscribe(); + + ASSERT_TRUE(status.ok()) << "Subscribe should succeed with matching tokens: " + << status.ToString(); + EXPECT_EQ(mock_service_ptr_->subscribe_count(), 1); + + subscriber->Close(); +} + +TEST_F(PythonGcsSubscriberAuthTest, MismatchedTokens) { + // Test that subscription fails when client and server use different tokens + const std::string server_token = "server-token-12345"; + const std::string client_token = "wrong-client-token-67890"; + + StartServer(server_token, false); // Server will reject requests + SetClientToken(client_token); + + auto subscriber = CreateSubscriber(); + Status status = subscriber->Subscribe(); + + ASSERT_FALSE(status.ok()) << "Subscribe should fail with mismatched tokens"; + EXPECT_TRUE(status.IsRpcError()) << "Status should be RpcError"; + + subscriber->Close(); +} + +TEST_F(PythonGcsSubscriberAuthTest, ClientTokenServerNoAuth) { + // Test that subscription succeeds when client provides token but server doesn't require + // it + const std::string client_token = "client-token-12345"; + + StartServer(""); // Server doesn't require auth + SetClientToken(client_token); + + auto subscriber = CreateSubscriber(); + Status status = subscriber->Subscribe(); + + ASSERT_TRUE(status.ok()) + << "Subscribe should succeed when server doesn't require auth: " + << status.ToString(); + EXPECT_EQ(mock_service_ptr_->subscribe_count(), 1); + + subscriber->Close(); +} + +TEST_F(PythonGcsSubscriberAuthTest, ServerTokenClientNoAuth) { + // Test that subscription fails when server requires token but client doesn't provide it + const std::string server_token = "server-token-12345"; + + StartServer(server_token, false); // Server will reject requests without valid token + SetClientToken(""); // Client doesn't provide token + + auto subscriber = CreateSubscriber(); + Status status = subscriber->Subscribe(); + + ASSERT_FALSE(status.ok()) + << "Subscribe should fail when server requires token but client doesn't provide it"; + EXPECT_TRUE(status.IsRpcError()) << "Status should be RpcError"; + + subscriber->Close(); +} + +TEST_F(PythonGcsSubscriberAuthTest, MatchingTokensPoll) { + // Test that polling succeeds when client and server use the same token + const std::string test_token = "matching-test-token-12345"; + + StartServer(test_token); + SetClientToken(test_token); + + auto subscriber = CreateSubscriber(); + Status status = subscriber->Subscribe(); + ASSERT_TRUE(status.ok()) << "Subscribe should succeed: " << status.ToString(); + + // Test polling with matching tokens - use very short timeout to avoid blocking + std::string key_id; + rpc::LogBatch log_batch; + status = subscriber->PollLogs(&key_id, 10, &log_batch); + + // Poll should succeed (returns OK even on timeout or when no messages available) + ASSERT_TRUE(status.ok()) << "Poll should succeed with matching tokens: " + << status.ToString(); + // At least one poll should have been made + EXPECT_GE(mock_service_ptr_->poll_count(), 1); + + subscriber->Close(); +} + +TEST_F(PythonGcsSubscriberAuthTest, MismatchedTokensPoll) { + // Test that polling fails when tokens don't match + const std::string server_token = "server-token-12345"; + const std::string client_token = "wrong-client-token-67890"; + + StartServer(server_token, false); // Server will reject requests + SetClientToken(client_token); + + auto subscriber = CreateSubscriber(); + + // Subscribe will fail, but let's try anyway + subscriber->Subscribe(); + + // Test polling with mismatched tokens - use very short timeout + std::string key_id; + rpc::LogBatch log_batch; + Status status = subscriber->PollLogs(&key_id, 10, &log_batch); + + // Poll should fail with auth error or return OK if it was cancelled + // (OK is acceptable because the subscriber may have been closed) + if (!status.ok()) { + EXPECT_TRUE(status.IsInvalid() || status.IsRpcError()) + << "Status should be Invalid or RpcError: " << status.ToString(); + } + + subscriber->Close(); +} + +TEST_F(PythonGcsSubscriberAuthTest, MatchingTokensClose) { + // Test that closing/unsubscribing succeeds with matching tokens + const std::string test_token = "matching-test-token-12345"; + + StartServer(test_token); + SetClientToken(test_token); + + auto subscriber = CreateSubscriber(); + Status status = subscriber->Subscribe(); + ASSERT_TRUE(status.ok()) << "Subscribe should succeed: " << status.ToString(); + + // Close should succeed with matching tokens + status = subscriber->Close(); + ASSERT_TRUE(status.ok()) << "Close should succeed with matching tokens: " + << status.ToString(); +} + +TEST_F(PythonGcsSubscriberAuthTest, NoAuthRequired) { + // Test that everything works when neither client nor server use auth + StartServer(""); // Server doesn't require auth + SetClientToken(""); // Client doesn't provide token + + auto subscriber = CreateSubscriber(); + Status status = subscriber->Subscribe(); + + ASSERT_TRUE(status.ok()) << "Subscribe should succeed without auth: " + << status.ToString(); + EXPECT_EQ(mock_service_ptr_->subscribe_count(), 1); + + // Test polling without auth - use very short timeout + std::string key_id; + rpc::LogBatch log_batch; + status = subscriber->PollLogs(&key_id, 10, &log_batch); + ASSERT_TRUE(status.ok()) << "Poll should succeed without auth: " << status.ToString(); + + // Test close without auth + status = subscriber->Close(); + ASSERT_TRUE(status.ok()) << "Close should succeed without auth: " << status.ToString(); +} + +TEST_F(PythonGcsSubscriberAuthTest, MultipleSubscribersMatchingTokens) { + // Test multiple subscribers with the same token + const std::string test_token = "shared-token-12345"; + + StartServer(test_token); + SetClientToken(test_token); + + auto subscriber1 = CreateSubscriber(); + auto subscriber2 = CreateSubscriber(); + + Status status1 = subscriber1->Subscribe(); + Status status2 = subscriber2->Subscribe(); + + ASSERT_TRUE(status1.ok()) << "First subscriber should succeed: " << status1.ToString(); + ASSERT_TRUE(status2.ok()) << "Second subscriber should succeed: " << status2.ToString(); + EXPECT_EQ(mock_service_ptr_->subscribe_count(), 2); + + subscriber1->Close(); + subscriber2->Close(); +} + +} // namespace pubsub +} // namespace ray + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} From 10eb3b027652645439e2ec259fe5c22c90496fe3 Mon Sep 17 00:00:00 2001 From: sampan Date: Mon, 27 Oct 2025 06:39:43 +0000 Subject: [PATCH 20/48] get rid of getToken() Signed-off-by: sampan --- src/ray/raylet/BUILD.bazel | 1 + src/ray/raylet/node_manager.cc | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/ray/raylet/BUILD.bazel b/src/ray/raylet/BUILD.bazel index 0db9a2516d0b..a8486bdca4d8 100644 --- a/src/ray/raylet/BUILD.bazel +++ b/src/ray/raylet/BUILD.bazel @@ -244,6 +244,7 @@ ray_cc_library( ":worker", ":worker_killing_policy", ":worker_pool", + "//src/ray/authentication:authentication_token_loader", "//src/ray/common:buffer", "//src/ray/common:flatbuf_utils", "//src/ray/common:lease", diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 5274f357b655..29c5fd33d614 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -50,6 +50,7 @@ #include "ray/raylet/worker_killing_policy_group_by_owner.h" #include "ray/raylet/worker_pool.h" #include "ray/raylet_ipc_client/client_connection.h" +#include "ray/rpc/authentication/authentication_token_loader.h" #include "ray/stats/metric_defs.h" #include "ray/util/cmd_line_utils.h" #include "ray/util/event.h" @@ -256,7 +257,7 @@ NodeManager::NodeManager( std::make_unique(io_service, *this), false); // Pass auth token from the RPC server to the syncer service node_manager_server_.RegisterService(std::make_unique( - ray_syncer_, node_manager_server_.GetAuthToken())); + ray_syncer_, ray::rpc::AuthenticationTokenLoader::instance().GetToken())); node_manager_server_.Run(); // GCS will check the health of the service named with the node id. // Fail to setup this will lead to the health check failure. From e3b8c3f96111cc614a1ef1b5ef403606eec145be Mon Sep 17 00:00:00 2001 From: sampan Date: Mon, 27 Oct 2025 06:55:05 +0000 Subject: [PATCH 21/48] address comments Signed-off-by: sampan --- src/ray/pubsub/python_gcs_subscriber.cc | 20 +++++++++---------- src/ray/pubsub/python_gcs_subscriber.h | 4 ++++ .../tests/python_gcs_subscriber_auth_test.cc | 13 +++++++++++- 3 files changed, 26 insertions(+), 11 deletions(-) diff --git a/src/ray/pubsub/python_gcs_subscriber.cc b/src/ray/pubsub/python_gcs_subscriber.cc index c0ead5274e0c..c4b5ae762e9b 100644 --- a/src/ray/pubsub/python_gcs_subscriber.cc +++ b/src/ray/pubsub/python_gcs_subscriber.cc @@ -52,11 +52,7 @@ Status PythonGcsSubscriber::Subscribe() { } grpc::ClientContext context; - // Add authentication token - auto auth_token = ray::rpc::AuthenticationTokenLoader::instance().GetToken(); - if (auth_token.has_value() && !auth_token->empty()) { - auth_token->SetMetadata(context); - } + SetAuthenticationToken(context); rpc::GcsSubscriberCommandBatchRequest request; request.set_subscriber_id(subscriber_id_); @@ -84,11 +80,7 @@ Status PythonGcsSubscriber::DoPoll(int64_t timeout_ms, rpc::PubMessage *message) return Status::OK(); } current_polling_context_ = std::make_shared(); - // Add authentication token - auto auth_token = ray::rpc::AuthenticationTokenLoader::instance().GetToken(); - if (auth_token.has_value() && !auth_token->empty()) { - auth_token->SetMetadata(*current_polling_context_); - } + SetAuthenticationToken(*current_polling_context_); if (timeout_ms != -1) { current_polling_context_->set_deadline(std::chrono::system_clock::now() + std::chrono::milliseconds(timeout_ms)); @@ -184,6 +176,7 @@ Status PythonGcsSubscriber::Close() { } grpc::ClientContext context; + SetAuthenticationToken(context); rpc::GcsSubscriberCommandBatchRequest request; request.set_subscriber_id(subscriber_id_); @@ -206,5 +199,12 @@ int64_t PythonGcsSubscriber::last_batch_size() { return last_batch_size_; } +void PythonGcsSubscriber::SetAuthenticationToken(grpc::ClientContext &context) { + auto auth_token = ray::rpc::AuthenticationTokenLoader::instance().GetToken(); + if (auth_token.has_value() && !auth_token->empty()) { + auth_token->SetMetadata(context); + } +} + } // namespace pubsub } // namespace ray diff --git a/src/ray/pubsub/python_gcs_subscriber.h b/src/ray/pubsub/python_gcs_subscriber.h index 5fe4eda29812..e8aeaa116566 100644 --- a/src/ray/pubsub/python_gcs_subscriber.h +++ b/src/ray/pubsub/python_gcs_subscriber.h @@ -80,6 +80,10 @@ class RAY_EXPORT PythonGcsSubscriber { std::deque queue_ ABSL_GUARDED_BY(mu_); bool closed_ ABSL_GUARDED_BY(mu_) = false; std::shared_ptr current_polling_context_ ABSL_GUARDED_BY(mu_); + + // Set authentication token on a gRPC client context if token-based authentication is + // enabled + void SetAuthenticationToken(grpc::ClientContext &context); }; /// Get the .lines() attribute of a LogBatch as a std::vector diff --git a/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc b/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc index 87780ca50f60..10648f5490c2 100644 --- a/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc +++ b/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc @@ -40,7 +40,13 @@ class MockInternalPubSubGcsService final : public rpc::InternalPubSubGcsService: const rpc::GcsSubscriberCommandBatchRequest *request, rpc::GcsSubscriberCommandBatchReply *reply) override { if (should_accept_requests_) { - subscribe_count_++; + for (const auto &command : request->commands()) { + if (command.has_subscribe_message()) { + subscribe_count_++; + } else if (command.has_unsubscribe_message()) { + unsubscribe_count_++; + } + } return grpc::Status::OK; } else { return grpc::Status(grpc::StatusCode::UNAUTHENTICATED, "Authentication failed"); @@ -62,11 +68,13 @@ class MockInternalPubSubGcsService final : public rpc::InternalPubSubGcsService: int subscribe_count() const { return subscribe_count_; } int poll_count() const { return poll_count_; } + int unsubscribe_count() const { return unsubscribe_count_; } private: bool should_accept_requests_; std::atomic subscribe_count_{0}; std::atomic poll_count_{0}; + std::atomic unsubscribe_count_{0}; }; class PythonGcsSubscriberAuthTest : public ::testing::Test { @@ -277,11 +285,14 @@ TEST_F(PythonGcsSubscriberAuthTest, MatchingTokensClose) { auto subscriber = CreateSubscriber(); Status status = subscriber->Subscribe(); ASSERT_TRUE(status.ok()) << "Subscribe should succeed: " << status.ToString(); + EXPECT_EQ(mock_service_ptr_->subscribe_count(), 1); // Close should succeed with matching tokens status = subscriber->Close(); ASSERT_TRUE(status.ok()) << "Close should succeed with matching tokens: " << status.ToString(); + // This assertion will fail until auth is added to `Close()` and the mock is updated. + EXPECT_EQ(mock_service_ptr_->unsubscribe_count(), 1); } TEST_F(PythonGcsSubscriberAuthTest, NoAuthRequired) { From cd0f9338bd94c8dd11819daf16159f83cc399df4 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Mon, 27 Oct 2025 09:51:34 -0500 Subject: [PATCH 22/48] fix Signed-off-by: Edward Oakes --- .../tests/python_gcs_subscriber_auth_test.cc | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc b/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc index 10648f5490c2..c32c9c4a0960 100644 --- a/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc +++ b/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc @@ -166,7 +166,7 @@ TEST_F(PythonGcsSubscriberAuthTest, MatchingTokens) { << status.ToString(); EXPECT_EQ(mock_service_ptr_->subscribe_count(), 1); - subscriber->Close(); + ASSERT_TRUE(subscriber->Close().ok()); } TEST_F(PythonGcsSubscriberAuthTest, MismatchedTokens) { @@ -183,7 +183,7 @@ TEST_F(PythonGcsSubscriberAuthTest, MismatchedTokens) { ASSERT_FALSE(status.ok()) << "Subscribe should fail with mismatched tokens"; EXPECT_TRUE(status.IsRpcError()) << "Status should be RpcError"; - subscriber->Close(); + ASSERT_TRUE(subscriber->Close().ok()); } TEST_F(PythonGcsSubscriberAuthTest, ClientTokenServerNoAuth) { @@ -202,7 +202,7 @@ TEST_F(PythonGcsSubscriberAuthTest, ClientTokenServerNoAuth) { << status.ToString(); EXPECT_EQ(mock_service_ptr_->subscribe_count(), 1); - subscriber->Close(); + ASSERT_TRUE(subscriber->Close().ok()); } TEST_F(PythonGcsSubscriberAuthTest, ServerTokenClientNoAuth) { @@ -219,7 +219,7 @@ TEST_F(PythonGcsSubscriberAuthTest, ServerTokenClientNoAuth) { << "Subscribe should fail when server requires token but client doesn't provide it"; EXPECT_TRUE(status.IsRpcError()) << "Status should be RpcError"; - subscriber->Close(); + ASSERT_TRUE(subscriber->Close().ok()); } TEST_F(PythonGcsSubscriberAuthTest, MatchingTokensPoll) { @@ -244,7 +244,7 @@ TEST_F(PythonGcsSubscriberAuthTest, MatchingTokensPoll) { // At least one poll should have been made EXPECT_GE(mock_service_ptr_->poll_count(), 1); - subscriber->Close(); + ASSERT_TRUE(subscriber->Close().ok()); } TEST_F(PythonGcsSubscriberAuthTest, MismatchedTokensPoll) { @@ -258,7 +258,7 @@ TEST_F(PythonGcsSubscriberAuthTest, MismatchedTokensPoll) { auto subscriber = CreateSubscriber(); // Subscribe will fail, but let's try anyway - subscriber->Subscribe(); + ASSERT_TRUE(subscriber->Subscribe().ok()); // Test polling with mismatched tokens - use very short timeout std::string key_id; @@ -272,7 +272,7 @@ TEST_F(PythonGcsSubscriberAuthTest, MismatchedTokensPoll) { << "Status should be Invalid or RpcError: " << status.ToString(); } - subscriber->Close(); + ASSERT_TRUE(subscriber->Close().ok()); } TEST_F(PythonGcsSubscriberAuthTest, MatchingTokensClose) { @@ -288,7 +288,7 @@ TEST_F(PythonGcsSubscriberAuthTest, MatchingTokensClose) { EXPECT_EQ(mock_service_ptr_->subscribe_count(), 1); // Close should succeed with matching tokens - status = subscriber->Close(); + ASSERT_TRUE(subscriber->Close().ok()); ASSERT_TRUE(status.ok()) << "Close should succeed with matching tokens: " << status.ToString(); // This assertion will fail until auth is added to `Close()` and the mock is updated. @@ -335,8 +335,8 @@ TEST_F(PythonGcsSubscriberAuthTest, MultipleSubscribersMatchingTokens) { ASSERT_TRUE(status2.ok()) << "Second subscriber should succeed: " << status2.ToString(); EXPECT_EQ(mock_service_ptr_->subscribe_count(), 2); - subscriber1->Close(); - subscriber2->Close(); + ASSERT_TRUE(subscriber1->Close().ok()); + ASSERT_TRUE(subscriber2->Close().ok()); } } // namespace pubsub From 199d18e4ccaa3df45c04a3ebb7126bcb4248ce01 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Mon, 27 Oct 2025 09:58:25 -0500 Subject: [PATCH 23/48] fix Signed-off-by: Edward Oakes --- src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc b/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc index c32c9c4a0960..b9e67424fb29 100644 --- a/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc +++ b/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc @@ -258,7 +258,7 @@ TEST_F(PythonGcsSubscriberAuthTest, MismatchedTokensPoll) { auto subscriber = CreateSubscriber(); // Subscribe will fail, but let's try anyway - ASSERT_TRUE(subscriber->Subscribe().ok()); + ASSERT_FALSE(subscriber->Subscribe().ok()); // Test polling with mismatched tokens - use very short timeout std::string key_id; @@ -288,8 +288,7 @@ TEST_F(PythonGcsSubscriberAuthTest, MatchingTokensClose) { EXPECT_EQ(mock_service_ptr_->subscribe_count(), 1); // Close should succeed with matching tokens - ASSERT_TRUE(subscriber->Close().ok()); - ASSERT_TRUE(status.ok()) << "Close should succeed with matching tokens: " + ASSERT_TRUE(subscriber->Close().ok()) << "Close should succeed with matching tokens: " << status.ToString(); // This assertion will fail until auth is added to `Close()` and the mock is updated. EXPECT_EQ(mock_service_ptr_->unsubscribe_count(), 1); From 537e90aedbd7fe14e52dcc80039acf7cb97fc61b Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Mon, 27 Oct 2025 09:58:30 -0500 Subject: [PATCH 24/48] fix Signed-off-by: Edward Oakes --- src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc b/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc index b9e67424fb29..32b211044388 100644 --- a/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc +++ b/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc @@ -288,8 +288,8 @@ TEST_F(PythonGcsSubscriberAuthTest, MatchingTokensClose) { EXPECT_EQ(mock_service_ptr_->subscribe_count(), 1); // Close should succeed with matching tokens - ASSERT_TRUE(subscriber->Close().ok()) << "Close should succeed with matching tokens: " - << status.ToString(); + ASSERT_TRUE(subscriber->Close().ok()) + << "Close should succeed with matching tokens: " << status.ToString(); // This assertion will fail until auth is added to `Close()` and the mock is updated. EXPECT_EQ(mock_service_ptr_->unsubscribe_count(), 1); } From 3bc34f2b7923c290df53de7348bc8e7ae0196063 Mon Sep 17 00:00:00 2001 From: sampan Date: Tue, 28 Oct 2025 03:38:37 +0000 Subject: [PATCH 25/48] [Core] Introduce new macros for user facing exceptions Signed-off-by: sampan --- src/ray/util/logging.cc | 11 ++++++++--- src/ray/util/logging.h | 9 ++++++++- src/ray/util/tests/logging_test.cc | 25 +++++++++++++++++++++++++ 3 files changed, 41 insertions(+), 4 deletions(-) diff --git a/src/ray/util/logging.cc b/src/ray/util/logging.cc index 0b4c4155a5f6..02c8b90e501d 100644 --- a/src/ray/util/logging.cc +++ b/src/ray/util/logging.cc @@ -269,6 +269,8 @@ static spdlog::level::level_enum GetMappedSeverity(RayLogLevel severity) { return spdlog::level::warn; case RayLogLevel::ERROR: return spdlog::level::err; + case RayLogLevel::USER_FATAL: + return spdlog::level::critical; case RayLogLevel::FATAL: return spdlog::level::critical; default: @@ -530,7 +532,7 @@ void RayLog::AddFatalLogCallbacks( RayLog::RayLog(const char *file_name, int line_number, RayLogLevel severity) : is_enabled_(severity >= severity_threshold_), severity_(severity), - is_fatal_(severity == RayLogLevel::FATAL) { + is_fatal_(severity == RayLogLevel::FATAL || severity == RayLogLevel::USER_FATAL) { if (is_fatal_) { #ifdef _WIN32 int pid = _getpid(); @@ -569,8 +571,11 @@ RayLog::~RayLog() { if (IsFatal()) { msg_osstream_ << "\n*** StackTrace Information ***\n" << ray::StackTrace(); expose_fatal_osstream_ << "\n*** StackTrace Information ***\n" << ray::StackTrace(); + const char *callback_label = (severity_ == RayLogLevel::USER_FATAL) + ? "RAY_USER_ERROR" + : "RAY_FATAL_CHECK_FAILED"; for (const auto &callback : fatal_log_callbacks_) { - callback("RAY_FATAL_CHECK_FAILED", expose_fatal_osstream_.str()); + callback(callback_label, expose_fatal_osstream_.str()); } } @@ -593,7 +598,7 @@ RayLog::~RayLog() { } logger->flush(); - if (severity_ == RayLogLevel::FATAL) { + if (severity_ == RayLogLevel::FATAL || severity_ == RayLogLevel::USER_FATAL) { std::_Exit(EXIT_FAILURE); } } diff --git a/src/ray/util/logging.h b/src/ray/util/logging.h index 405d772b57ef..10b86883d289 100644 --- a/src/ray/util/logging.h +++ b/src/ray/util/logging.h @@ -122,7 +122,8 @@ enum class RayLogLevel { INFO = 0, WARNING = 1, ERROR = 2, - FATAL = 3 + USER_FATAL = 3, // User-facing fatal errors (config/usage errors) + FATAL = 4 // Internal fatal errors (bugs in Ray) }; #define RAY_LOG_INTERNAL(level) ::ray::RayLog(__FILE__, __LINE__, level) @@ -151,6 +152,12 @@ enum class RayLogLevel { #define RAY_CHECK(condition) RAY_CHECK_WITH_DISPLAY(condition, #condition) +// User-facing fatal check without "bug in Ray" message +#define RAY_USER_CHECK(condition) \ + RAY_PREDICT_TRUE((condition)) \ + ? RAY_IGNORE_EXPR(0) \ + : ::ray::Voidify() & ::ray::RayLog(__FILE__, __LINE__, ray::RayLogLevel::USER_FATAL) + #ifdef NDEBUG #define RAY_DCHECK(condition) \ diff --git a/src/ray/util/tests/logging_test.cc b/src/ray/util/tests/logging_test.cc index ab82497afccf..11527ea3be66 100644 --- a/src/ray/util/tests/logging_test.cc +++ b/src/ray/util/tests/logging_test.cc @@ -379,6 +379,31 @@ TEST(PrintLogTest, TestFailureSignalHandler) { ASSERT_DEATH(abort(), ".*SIGABRT received.*"); } +TEST(UserErrorTest, TestUserCheck) { + // RAY_USER_CHECK should not exit the application when condition is true + RAY_USER_CHECK(true) << "This should not trigger"; + // RAY_USER_CHECK should exit the application when condition is false + // and should NOT contain "bug in Ray" message + ASSERT_DEATH( + { RAY_USER_CHECK(false) << "Custom user error message"; }, + testing::AllOf( + testing::HasSubstr("Custom user error message"), + testing::Not(testing::HasSubstr("bug in Ray")) + ) + ); +} + +TEST(UserErrorTest, TestUserError) { + // RAY_USER_ERROR should always fail with user message + ASSERT_DEATH( + { RAY_USER_ERROR() << "Token authentication is enabled but token is missing"; }, + testing::AllOf( + testing::HasSubstr("Token authentication is enabled but token is missing"), + testing::Not(testing::HasSubstr("bug in Ray")) + ) + ); +} + } // namespace ray int main(int argc, char **argv) { From e5b90babb840164846c4740df8546d0c591f8bca Mon Sep 17 00:00:00 2001 From: sampan Date: Tue, 28 Oct 2025 03:38:56 +0000 Subject: [PATCH 26/48] fix lint issues Signed-off-by: sampan --- src/ray/util/tests/logging_test.cc | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/src/ray/util/tests/logging_test.cc b/src/ray/util/tests/logging_test.cc index 11527ea3be66..9781c3e5584d 100644 --- a/src/ray/util/tests/logging_test.cc +++ b/src/ray/util/tests/logging_test.cc @@ -384,24 +384,18 @@ TEST(UserErrorTest, TestUserCheck) { RAY_USER_CHECK(true) << "This should not trigger"; // RAY_USER_CHECK should exit the application when condition is false // and should NOT contain "bug in Ray" message - ASSERT_DEATH( - { RAY_USER_CHECK(false) << "Custom user error message"; }, - testing::AllOf( - testing::HasSubstr("Custom user error message"), - testing::Not(testing::HasSubstr("bug in Ray")) - ) - ); + ASSERT_DEATH({ RAY_USER_CHECK(false) << "Custom user error message"; }, + testing::AllOf(testing::HasSubstr("Custom user error message"), + testing::Not(testing::HasSubstr("bug in Ray")))); } TEST(UserErrorTest, TestUserError) { // RAY_USER_ERROR should always fail with user message ASSERT_DEATH( - { RAY_USER_ERROR() << "Token authentication is enabled but token is missing"; }, - testing::AllOf( - testing::HasSubstr("Token authentication is enabled but token is missing"), - testing::Not(testing::HasSubstr("bug in Ray")) - ) - ); + { RAY_USER_ERROR() << "Token authentication is enabled but token is missing"; }, + testing::AllOf( + testing::HasSubstr("Token authentication is enabled but token is missing"), + testing::Not(testing::HasSubstr("bug in Ray")))); } } // namespace ray From acd95ac19882ef4af41f2f0ff8ee3d11877a6759 Mon Sep 17 00:00:00 2001 From: sampan Date: Tue, 28 Oct 2025 03:57:47 +0000 Subject: [PATCH 27/48] dont print stack trace for user errors Signed-off-by: sampan --- src/ray/util/logging.cc | 7 +++++-- src/ray/util/tests/logging_test.cc | 18 +++++------------- 2 files changed, 10 insertions(+), 15 deletions(-) diff --git a/src/ray/util/logging.cc b/src/ray/util/logging.cc index 02c8b90e501d..4195ef5eb3d8 100644 --- a/src/ray/util/logging.cc +++ b/src/ray/util/logging.cc @@ -569,8 +569,11 @@ bool RayLog::IsFatal() const { return is_fatal_; } RayLog::~RayLog() { if (IsFatal()) { - msg_osstream_ << "\n*** StackTrace Information ***\n" << ray::StackTrace(); - expose_fatal_osstream_ << "\n*** StackTrace Information ***\n" << ray::StackTrace(); + // Only add stack trace for internal fatal errors, not user-facing errors + if (severity_ == RayLogLevel::FATAL) { + msg_osstream_ << "\n*** StackTrace Information ***\n" << ray::StackTrace(); + expose_fatal_osstream_ << "\n*** StackTrace Information ***\n" << ray::StackTrace(); + } const char *callback_label = (severity_ == RayLogLevel::USER_FATAL) ? "RAY_USER_ERROR" : "RAY_FATAL_CHECK_FAILED"; diff --git a/src/ray/util/tests/logging_test.cc b/src/ray/util/tests/logging_test.cc index 9781c3e5584d..752c8b10579b 100644 --- a/src/ray/util/tests/logging_test.cc +++ b/src/ray/util/tests/logging_test.cc @@ -379,23 +379,15 @@ TEST(PrintLogTest, TestFailureSignalHandler) { ASSERT_DEATH(abort(), ".*SIGABRT received.*"); } -TEST(UserErrorTest, TestUserCheck) { +TEST(PrintLogTest, TestUserCheck) { // RAY_USER_CHECK should not exit the application when condition is true RAY_USER_CHECK(true) << "This should not trigger"; // RAY_USER_CHECK should exit the application when condition is false - // and should NOT contain "bug in Ray" message - ASSERT_DEATH({ RAY_USER_CHECK(false) << "Custom user error message"; }, - testing::AllOf(testing::HasSubstr("Custom user error message"), - testing::Not(testing::HasSubstr("bug in Ray")))); -} - -TEST(UserErrorTest, TestUserError) { - // RAY_USER_ERROR should always fail with user message ASSERT_DEATH( - { RAY_USER_ERROR() << "Token authentication is enabled but token is missing"; }, - testing::AllOf( - testing::HasSubstr("Token authentication is enabled but token is missing"), - testing::Not(testing::HasSubstr("bug in Ray")))); + { RAY_USER_CHECK(false) << "Custom user error message"; }, + testing::AllOf(testing::HasSubstr("Custom user error message"), + testing::Not(testing::HasSubstr("bug in Ray")), + testing::Not(testing::HasSubstr("StackTrace Information")))); } } // namespace ray From c5be15ff4748b6ee0e4803a7f4eca371a0ab8999 Mon Sep 17 00:00:00 2001 From: sampan Date: Tue, 28 Oct 2025 04:04:57 +0000 Subject: [PATCH 28/48] address comments Signed-off-by: sampan --- src/ray/util/logging.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/ray/util/logging.cc b/src/ray/util/logging.cc index 4195ef5eb3d8..70ce5908645f 100644 --- a/src/ray/util/logging.cc +++ b/src/ray/util/logging.cc @@ -297,6 +297,8 @@ void RayLog::InitSeverityThreshold(RayLogLevel severity_threshold) { severity_threshold = RayLogLevel::WARNING; } else if (data == "error") { severity_threshold = RayLogLevel::ERROR; + } else if (data == "user_fatal") { + severity_threshold = RayLogLevel::USER_FATAL; } else if (data == "fatal") { severity_threshold = RayLogLevel::FATAL; } else { @@ -601,7 +603,7 @@ RayLog::~RayLog() { } logger->flush(); - if (severity_ == RayLogLevel::FATAL || severity_ == RayLogLevel::USER_FATAL) { + if (IsFatal()) { std::_Exit(EXIT_FAILURE); } } From 2698b8deb99217c6979f0dc5061143687a053b8a Mon Sep 17 00:00:00 2001 From: sampan Date: Tue, 28 Oct 2025 04:20:44 +0000 Subject: [PATCH 29/48] use RAY_USER_CHECK instead of RAY_CHECK + fix test Signed-off-by: sampan --- src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc | 1 + .../rpc/authentication/authentication_token_loader.cc | 9 +++++---- src/ray/rpc/tests/authentication_token_loader_test.cc | 3 ++- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc b/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc index 32b211044388..b42ad5f98e6d 100644 --- a/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc +++ b/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc @@ -107,6 +107,7 @@ class PythonGcsSubscriberAuthTest : public ::testing::Test { auth_token = rpc::AuthenticationToken(server_token); } else { // Empty token means no auth required + RayConfig::instance().initialize(R"({"auth_mode": "disabled"})"); auth_token = std::nullopt; } diff --git a/src/ray/rpc/authentication/authentication_token_loader.cc b/src/ray/rpc/authentication/authentication_token_loader.cc index 621f28fe351c..d2e25f978442 100644 --- a/src/ray/rpc/authentication/authentication_token_loader.cc +++ b/src/ray/rpc/authentication/authentication_token_loader.cc @@ -57,8 +57,9 @@ std::optional AuthenticationTokenLoader::GetToken() { AuthenticationToken token = LoadTokenFromSources(); // If no token found and auth is enabled, fail with RAY_CHECK - RAY_CHECK(!token.empty()) - << "Token authentication is enabled but Ray couldn't find an authentication token. " + RAY_USER_CHECK(!token.empty()) + << "Ray Setup Error: Token authentication is enabled but Ray couldn't find an " + "authentication token. " << "Set the RAY_AUTH_TOKEN environment variable, or set RAY_AUTH_TOKEN_PATH to " "point to a file with the token, " << "or create a token file at ~/.ray/auth_token."; @@ -100,8 +101,8 @@ AuthenticationToken AuthenticationTokenLoader::LoadTokenFromSources() { std::string path_str(env_token_path); if (!path_str.empty()) { std::string token_str = TrimWhitespace(ReadTokenFromFile(path_str)); - RAY_CHECK(!token_str.empty()) - << "RAY_AUTH_TOKEN_PATH is set but file cannot be opened or is empty: " + RAY_USER_CHECK(!token_str.empty()) + << "Ray Setup Error: RAY_AUTH_TOKEN_PATH is set but file cannot be opened or is empty: " << path_str; RAY_LOG(DEBUG) << "Loaded authentication token from file: " << path_str; return AuthenticationToken(token_str); diff --git a/src/ray/rpc/tests/authentication_token_loader_test.cc b/src/ray/rpc/tests/authentication_token_loader_test.cc index 2332c6d09313..b758acd64c05 100644 --- a/src/ray/rpc/tests/authentication_token_loader_test.cc +++ b/src/ray/rpc/tests/authentication_token_loader_test.cc @@ -299,7 +299,8 @@ TEST_F(AuthenticationTokenLoaderTest, TestErrorWhenAuthEnabledButNoToken) { auto &loader = AuthenticationTokenLoader::instance(); loader.GetToken(); }, - "Token authentication is enabled but Ray couldn't find an authentication token."); + "Ray Setup Error: Token authentication is enabled but Ray couldn't find an " + "authentication token."); } TEST_F(AuthenticationTokenLoaderTest, TestCaching) { From 8572c0120c969ce5c33833c6e23e560517bdb8d4 Mon Sep 17 00:00:00 2001 From: sampan Date: Tue, 28 Oct 2025 04:21:08 +0000 Subject: [PATCH 30/48] fix lint Signed-off-by: sampan --- src/ray/rpc/authentication/authentication_token_loader.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/ray/rpc/authentication/authentication_token_loader.cc b/src/ray/rpc/authentication/authentication_token_loader.cc index d2e25f978442..7e3903423890 100644 --- a/src/ray/rpc/authentication/authentication_token_loader.cc +++ b/src/ray/rpc/authentication/authentication_token_loader.cc @@ -101,9 +101,9 @@ AuthenticationToken AuthenticationTokenLoader::LoadTokenFromSources() { std::string path_str(env_token_path); if (!path_str.empty()) { std::string token_str = TrimWhitespace(ReadTokenFromFile(path_str)); - RAY_USER_CHECK(!token_str.empty()) - << "Ray Setup Error: RAY_AUTH_TOKEN_PATH is set but file cannot be opened or is empty: " - << path_str; + RAY_USER_CHECK(!token_str.empty()) << "Ray Setup Error: RAY_AUTH_TOKEN_PATH is set " + "but file cannot be opened or is empty: " + << path_str; RAY_LOG(DEBUG) << "Loaded authentication token from file: " << path_str; return AuthenticationToken(token_str); } From d47ae2bc0f19c8f085e34e2e9af2e5db8bed18cd Mon Sep 17 00:00:00 2001 From: sampan Date: Tue, 28 Oct 2025 04:25:22 +0000 Subject: [PATCH 31/48] improve test Signed-off-by: sampan --- src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc b/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc index b42ad5f98e6d..c25d4d50f429 100644 --- a/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc +++ b/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc @@ -107,8 +107,7 @@ class PythonGcsSubscriberAuthTest : public ::testing::Test { auth_token = rpc::AuthenticationToken(server_token); } else { // Empty token means no auth required - RayConfig::instance().initialize(R"({"auth_mode": "disabled"})"); - auth_token = std::nullopt; + auth_token = rpc::AuthenticationToken("");; } server_ = std::make_unique("test-gcs-server", From d054131bdbb632564392ab7cb1c79b8268de1850 Mon Sep 17 00:00:00 2001 From: sampan Date: Tue, 28 Oct 2025 04:25:39 +0000 Subject: [PATCH 32/48] fix lint Signed-off-by: sampan --- src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc b/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc index c25d4d50f429..7089b1e9b65d 100644 --- a/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc +++ b/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc @@ -107,7 +107,8 @@ class PythonGcsSubscriberAuthTest : public ::testing::Test { auth_token = rpc::AuthenticationToken(server_token); } else { // Empty token means no auth required - auth_token = rpc::AuthenticationToken("");; + auth_token = rpc::AuthenticationToken(""); + ; } server_ = std::make_unique("test-gcs-server", From 2ee555599f540cab4d73aa951a244cd8fcc41f11 Mon Sep 17 00:00:00 2001 From: sampan Date: Tue, 28 Oct 2025 04:26:28 +0000 Subject: [PATCH 33/48] fix lint Signed-off-by: sampan --- src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc b/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc index 7089b1e9b65d..b99e1b2353b4 100644 --- a/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc +++ b/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc @@ -108,7 +108,6 @@ class PythonGcsSubscriberAuthTest : public ::testing::Test { } else { // Empty token means no auth required auth_token = rpc::AuthenticationToken(""); - ; } server_ = std::make_unique("test-gcs-server", From 06d1773b4e59ce9d0948146361dec6b8705d9967 Mon Sep 17 00:00:00 2001 From: sampan Date: Tue, 28 Oct 2025 08:55:36 +0000 Subject: [PATCH 34/48] fix builds Signed-off-by: sampan --- src/ray/gcs/BUILD.bazel | 2 +- src/ray/raylet/BUILD.bazel | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ray/gcs/BUILD.bazel b/src/ray/gcs/BUILD.bazel index db4a0b77614c..09c96f5a6c1e 100644 --- a/src/ray/gcs/BUILD.bazel +++ b/src/ray/gcs/BUILD.bazel @@ -513,7 +513,7 @@ ray_cc_library( ":grpc_service_interfaces", ":grpc_services", ":metrics", - "//src/ray/authentication:authentication_token_loader", + "//src/ray/rpc/authentication:authentication_token_loader", "//src/ray/core_worker_rpc_client:core_worker_client", "//src/ray/core_worker_rpc_client:core_worker_client_pool", "//src/ray/gcs/store_client", diff --git a/src/ray/raylet/BUILD.bazel b/src/ray/raylet/BUILD.bazel index a8486bdca4d8..586a0d9d9958 100644 --- a/src/ray/raylet/BUILD.bazel +++ b/src/ray/raylet/BUILD.bazel @@ -244,7 +244,7 @@ ray_cc_library( ":worker", ":worker_killing_policy", ":worker_pool", - "//src/ray/authentication:authentication_token_loader", + "//src/ray/rpc/authentication:authentication_token_loader", "//src/ray/common:buffer", "//src/ray/common:flatbuf_utils", "//src/ray/common:lease", From cc69ae33f957fbb8c168ed35cdebbd93ffa3e010 Mon Sep 17 00:00:00 2001 From: sampan Date: Tue, 28 Oct 2025 08:56:05 +0000 Subject: [PATCH 35/48] lint Signed-off-by: sampan --- src/ray/gcs/BUILD.bazel | 2 +- src/ray/raylet/BUILD.bazel | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ray/gcs/BUILD.bazel b/src/ray/gcs/BUILD.bazel index 09c96f5a6c1e..0b6315b935ac 100644 --- a/src/ray/gcs/BUILD.bazel +++ b/src/ray/gcs/BUILD.bazel @@ -513,7 +513,6 @@ ray_cc_library( ":grpc_service_interfaces", ":grpc_services", ":metrics", - "//src/ray/rpc/authentication:authentication_token_loader", "//src/ray/core_worker_rpc_client:core_worker_client", "//src/ray/core_worker_rpc_client:core_worker_client_pool", "//src/ray/gcs/store_client", @@ -530,6 +529,7 @@ ray_cc_library( "//src/ray/raylet_rpc_client:raylet_client_pool", "//src/ray/rpc:grpc_server", "//src/ray/rpc:metrics_agent_client", + "//src/ray/rpc/authentication:authentication_token_loader", "//src/ray/util:counter_map", "//src/ray/util:exponential_backoff", "//src/ray/util:network_util", diff --git a/src/ray/raylet/BUILD.bazel b/src/ray/raylet/BUILD.bazel index 586a0d9d9958..3090f94aa001 100644 --- a/src/ray/raylet/BUILD.bazel +++ b/src/ray/raylet/BUILD.bazel @@ -244,7 +244,6 @@ ray_cc_library( ":worker", ":worker_killing_policy", ":worker_pool", - "//src/ray/rpc/authentication:authentication_token_loader", "//src/ray/common:buffer", "//src/ray/common:flatbuf_utils", "//src/ray/common:lease", @@ -261,6 +260,7 @@ ray_cc_library( "//src/ray/raylet/scheduling:scheduler", "//src/ray/rpc:node_manager_server", "//src/ray/rpc:rpc_callback_types", + "//src/ray/rpc/authentication:authentication_token_loader", "//src/ray/stats:stats_lib", "//src/ray/util:cmd_line_utils", "//src/ray/util:container_util", From 9f0a56355ef83a2bc41f98471578451925c53ae2 Mon Sep 17 00:00:00 2001 From: sampan Date: Tue, 28 Oct 2025 15:29:11 +0000 Subject: [PATCH 36/48] attempt to fix test Signed-off-by: sampan --- src/ray/pubsub/python_gcs_subscriber.cc | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/ray/pubsub/python_gcs_subscriber.cc b/src/ray/pubsub/python_gcs_subscriber.cc index c4b5ae762e9b..0db5445fe3a1 100644 --- a/src/ray/pubsub/python_gcs_subscriber.cc +++ b/src/ray/pubsub/python_gcs_subscriber.cc @@ -75,7 +75,7 @@ Status PythonGcsSubscriber::Subscribe() { Status PythonGcsSubscriber::DoPoll(int64_t timeout_ms, rpc::PubMessage *message) { absl::MutexLock lock(&mu_); - while (queue_.empty()) { + if (queue_.empty()) { if (closed_) { return Status::OK(); } @@ -135,6 +135,11 @@ Status PythonGcsSubscriber::DoPoll(int64_t timeout_ms, rpc::PubMessage *message) } } + if (queue_.empty()) { + // No messages available after polling + return Status::OK(); + } + *message = std::move(queue_.front()); queue_.pop_front(); From 15aa5e2b2341bf1feda87b073058b3df54adb157 Mon Sep 17 00:00:00 2001 From: sampan Date: Tue, 28 Oct 2025 15:29:11 +0000 Subject: [PATCH 37/48] attempt to fix test Signed-off-by: sampan --- src/ray/pubsub/python_gcs_subscriber.cc | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/ray/pubsub/python_gcs_subscriber.cc b/src/ray/pubsub/python_gcs_subscriber.cc index c4b5ae762e9b..0db5445fe3a1 100644 --- a/src/ray/pubsub/python_gcs_subscriber.cc +++ b/src/ray/pubsub/python_gcs_subscriber.cc @@ -75,7 +75,7 @@ Status PythonGcsSubscriber::Subscribe() { Status PythonGcsSubscriber::DoPoll(int64_t timeout_ms, rpc::PubMessage *message) { absl::MutexLock lock(&mu_); - while (queue_.empty()) { + if (queue_.empty()) { if (closed_) { return Status::OK(); } @@ -135,6 +135,11 @@ Status PythonGcsSubscriber::DoPoll(int64_t timeout_ms, rpc::PubMessage *message) } } + if (queue_.empty()) { + // No messages available after polling + return Status::OK(); + } + *message = std::move(queue_.front()); queue_.pop_front(); From 1c600e6d348130e846941a8f146b492645a250f8 Mon Sep 17 00:00:00 2001 From: sampan Date: Wed, 29 Oct 2025 05:33:50 +0000 Subject: [PATCH 38/48] fix test Signed-off-by: sampan --- src/ray/pubsub/python_gcs_subscriber.cc | 7 +------ .../pubsub/tests/python_gcs_subscriber_auth_test.cc | 11 ++++++++--- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/ray/pubsub/python_gcs_subscriber.cc b/src/ray/pubsub/python_gcs_subscriber.cc index 0db5445fe3a1..c4b5ae762e9b 100644 --- a/src/ray/pubsub/python_gcs_subscriber.cc +++ b/src/ray/pubsub/python_gcs_subscriber.cc @@ -75,7 +75,7 @@ Status PythonGcsSubscriber::Subscribe() { Status PythonGcsSubscriber::DoPoll(int64_t timeout_ms, rpc::PubMessage *message) { absl::MutexLock lock(&mu_); - if (queue_.empty()) { + while (queue_.empty()) { if (closed_) { return Status::OK(); } @@ -135,11 +135,6 @@ Status PythonGcsSubscriber::DoPoll(int64_t timeout_ms, rpc::PubMessage *message) } } - if (queue_.empty()) { - // No messages available after polling - return Status::OK(); - } - *message = std::move(queue_.front()); queue_.pop_front(); diff --git a/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc b/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc index b99e1b2353b4..a807cab479a9 100644 --- a/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc +++ b/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc @@ -58,9 +58,14 @@ class MockInternalPubSubGcsService final : public rpc::InternalPubSubGcsService: rpc::GcsSubscriberPollReply *reply) override { if (should_accept_requests_) { poll_count_++; - // Return empty response with publisher_id - reply->set_publisher_id("test-publisher"); - return grpc::Status::OK; + // Simulate long polling: block until deadline expires since we have no messages + // Real server would hold the connection open until messages arrive or timeout + auto deadline = context->deadline(); + std::this_thread::sleep_until(deadline); + + // Return deadline exceeded (timeout) with empty messages + // This simulates the real server behavior when no messages are published + return grpc::Status(grpc::StatusCode::DEADLINE_EXCEEDED, "Long poll timeout"); } else { return grpc::Status(grpc::StatusCode::UNAUTHENTICATED, "Authentication failed"); } From f23ea2e53115ac02939d5893fb0ae1bd0af6c032 Mon Sep 17 00:00:00 2001 From: sampan Date: Fri, 31 Oct 2025 02:34:43 +0000 Subject: [PATCH 39/48] revert logging changes Signed-off-by: sampan --- .../authentication_token_loader.cc | 8 ++++---- src/ray/util/logging.cc | 20 +++++-------------- src/ray/util/logging.h | 9 +-------- src/ray/util/tests/logging_test.cc | 11 ---------- 4 files changed, 10 insertions(+), 38 deletions(-) diff --git a/src/ray/rpc/authentication/authentication_token_loader.cc b/src/ray/rpc/authentication/authentication_token_loader.cc index 7e3903423890..1c8082c5a4a5 100644 --- a/src/ray/rpc/authentication/authentication_token_loader.cc +++ b/src/ray/rpc/authentication/authentication_token_loader.cc @@ -57,7 +57,7 @@ std::optional AuthenticationTokenLoader::GetToken() { AuthenticationToken token = LoadTokenFromSources(); // If no token found and auth is enabled, fail with RAY_CHECK - RAY_USER_CHECK(!token.empty()) + RAY_CHECK(!token.empty()) << "Ray Setup Error: Token authentication is enabled but Ray couldn't find an " "authentication token. " << "Set the RAY_AUTH_TOKEN environment variable, or set RAY_AUTH_TOKEN_PATH to " @@ -101,9 +101,9 @@ AuthenticationToken AuthenticationTokenLoader::LoadTokenFromSources() { std::string path_str(env_token_path); if (!path_str.empty()) { std::string token_str = TrimWhitespace(ReadTokenFromFile(path_str)); - RAY_USER_CHECK(!token_str.empty()) << "Ray Setup Error: RAY_AUTH_TOKEN_PATH is set " - "but file cannot be opened or is empty: " - << path_str; + RAY_CHECK(!token_str.empty()) << "Ray Setup Error: RAY_AUTH_TOKEN_PATH is set " + "but file cannot be opened or is empty: " + << path_str; RAY_LOG(DEBUG) << "Loaded authentication token from file: " << path_str; return AuthenticationToken(token_str); } diff --git a/src/ray/util/logging.cc b/src/ray/util/logging.cc index 70ce5908645f..0b4c4155a5f6 100644 --- a/src/ray/util/logging.cc +++ b/src/ray/util/logging.cc @@ -269,8 +269,6 @@ static spdlog::level::level_enum GetMappedSeverity(RayLogLevel severity) { return spdlog::level::warn; case RayLogLevel::ERROR: return spdlog::level::err; - case RayLogLevel::USER_FATAL: - return spdlog::level::critical; case RayLogLevel::FATAL: return spdlog::level::critical; default: @@ -297,8 +295,6 @@ void RayLog::InitSeverityThreshold(RayLogLevel severity_threshold) { severity_threshold = RayLogLevel::WARNING; } else if (data == "error") { severity_threshold = RayLogLevel::ERROR; - } else if (data == "user_fatal") { - severity_threshold = RayLogLevel::USER_FATAL; } else if (data == "fatal") { severity_threshold = RayLogLevel::FATAL; } else { @@ -534,7 +530,7 @@ void RayLog::AddFatalLogCallbacks( RayLog::RayLog(const char *file_name, int line_number, RayLogLevel severity) : is_enabled_(severity >= severity_threshold_), severity_(severity), - is_fatal_(severity == RayLogLevel::FATAL || severity == RayLogLevel::USER_FATAL) { + is_fatal_(severity == RayLogLevel::FATAL) { if (is_fatal_) { #ifdef _WIN32 int pid = _getpid(); @@ -571,16 +567,10 @@ bool RayLog::IsFatal() const { return is_fatal_; } RayLog::~RayLog() { if (IsFatal()) { - // Only add stack trace for internal fatal errors, not user-facing errors - if (severity_ == RayLogLevel::FATAL) { - msg_osstream_ << "\n*** StackTrace Information ***\n" << ray::StackTrace(); - expose_fatal_osstream_ << "\n*** StackTrace Information ***\n" << ray::StackTrace(); - } - const char *callback_label = (severity_ == RayLogLevel::USER_FATAL) - ? "RAY_USER_ERROR" - : "RAY_FATAL_CHECK_FAILED"; + msg_osstream_ << "\n*** StackTrace Information ***\n" << ray::StackTrace(); + expose_fatal_osstream_ << "\n*** StackTrace Information ***\n" << ray::StackTrace(); for (const auto &callback : fatal_log_callbacks_) { - callback(callback_label, expose_fatal_osstream_.str()); + callback("RAY_FATAL_CHECK_FAILED", expose_fatal_osstream_.str()); } } @@ -603,7 +593,7 @@ RayLog::~RayLog() { } logger->flush(); - if (IsFatal()) { + if (severity_ == RayLogLevel::FATAL) { std::_Exit(EXIT_FAILURE); } } diff --git a/src/ray/util/logging.h b/src/ray/util/logging.h index 10b86883d289..405d772b57ef 100644 --- a/src/ray/util/logging.h +++ b/src/ray/util/logging.h @@ -122,8 +122,7 @@ enum class RayLogLevel { INFO = 0, WARNING = 1, ERROR = 2, - USER_FATAL = 3, // User-facing fatal errors (config/usage errors) - FATAL = 4 // Internal fatal errors (bugs in Ray) + FATAL = 3 }; #define RAY_LOG_INTERNAL(level) ::ray::RayLog(__FILE__, __LINE__, level) @@ -152,12 +151,6 @@ enum class RayLogLevel { #define RAY_CHECK(condition) RAY_CHECK_WITH_DISPLAY(condition, #condition) -// User-facing fatal check without "bug in Ray" message -#define RAY_USER_CHECK(condition) \ - RAY_PREDICT_TRUE((condition)) \ - ? RAY_IGNORE_EXPR(0) \ - : ::ray::Voidify() & ::ray::RayLog(__FILE__, __LINE__, ray::RayLogLevel::USER_FATAL) - #ifdef NDEBUG #define RAY_DCHECK(condition) \ diff --git a/src/ray/util/tests/logging_test.cc b/src/ray/util/tests/logging_test.cc index 752c8b10579b..ab82497afccf 100644 --- a/src/ray/util/tests/logging_test.cc +++ b/src/ray/util/tests/logging_test.cc @@ -379,17 +379,6 @@ TEST(PrintLogTest, TestFailureSignalHandler) { ASSERT_DEATH(abort(), ".*SIGABRT received.*"); } -TEST(PrintLogTest, TestUserCheck) { - // RAY_USER_CHECK should not exit the application when condition is true - RAY_USER_CHECK(true) << "This should not trigger"; - // RAY_USER_CHECK should exit the application when condition is false - ASSERT_DEATH( - { RAY_USER_CHECK(false) << "Custom user error message"; }, - testing::AllOf(testing::HasSubstr("Custom user error message"), - testing::Not(testing::HasSubstr("bug in Ray")), - testing::Not(testing::HasSubstr("StackTrace Information")))); -} - } // namespace ray int main(int argc, char **argv) { From 4646909ff6685d36d48f5bf5e5be2d68b0bf88bd Mon Sep 17 00:00:00 2001 From: sampan Date: Fri, 31 Oct 2025 02:43:09 +0000 Subject: [PATCH 40/48] address comments Signed-off-by: sampan --- src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc b/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc index a807cab479a9..c2518d30a07c 100644 --- a/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc +++ b/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc @@ -273,8 +273,7 @@ TEST_F(PythonGcsSubscriberAuthTest, MismatchedTokensPoll) { // Poll should fail with auth error or return OK if it was cancelled // (OK is acceptable because the subscriber may have been closed) if (!status.ok()) { - EXPECT_TRUE(status.IsInvalid() || status.IsRpcError()) - << "Status should be Invalid or RpcError: " << status.ToString(); + EXPECT_TRUE(status.IsInvalid()) << "Status should be Invalid: " << status.ToString(); } ASSERT_TRUE(subscriber->Close().ok()); @@ -295,7 +294,6 @@ TEST_F(PythonGcsSubscriberAuthTest, MatchingTokensClose) { // Close should succeed with matching tokens ASSERT_TRUE(subscriber->Close().ok()) << "Close should succeed with matching tokens: " << status.ToString(); - // This assertion will fail until auth is added to `Close()` and the mock is updated. EXPECT_EQ(mock_service_ptr_->unsubscribe_count(), 1); } From 1274e7410e41f17a87b8ebb67f1a2c73cb3f5ac3 Mon Sep 17 00:00:00 2001 From: Sampan S Nayak Date: Tue, 28 Oct 2025 15:29:11 +0000 Subject: [PATCH 41/48] attempt to fix test Signed-off-by: Sampan S Nayak --- src/ray/pubsub/python_gcs_subscriber.cc | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/ray/pubsub/python_gcs_subscriber.cc b/src/ray/pubsub/python_gcs_subscriber.cc index c4b5ae762e9b..0db5445fe3a1 100644 --- a/src/ray/pubsub/python_gcs_subscriber.cc +++ b/src/ray/pubsub/python_gcs_subscriber.cc @@ -75,7 +75,7 @@ Status PythonGcsSubscriber::Subscribe() { Status PythonGcsSubscriber::DoPoll(int64_t timeout_ms, rpc::PubMessage *message) { absl::MutexLock lock(&mu_); - while (queue_.empty()) { + if (queue_.empty()) { if (closed_) { return Status::OK(); } @@ -135,6 +135,11 @@ Status PythonGcsSubscriber::DoPoll(int64_t timeout_ms, rpc::PubMessage *message) } } + if (queue_.empty()) { + // No messages available after polling + return Status::OK(); + } + *message = std::move(queue_.front()); queue_.pop_front(); From fe6bab4a5b9baf20de5e6fb756abe09025316d76 Mon Sep 17 00:00:00 2001 From: Sampan S Nayak Date: Thu, 30 Oct 2025 04:42:47 +0530 Subject: [PATCH 42/48] [Core] Support token auth in ray Pub-Sub (#58186) This PR adds token-based authentication support to the PythonGcsSubscriber, which previously made direct gRPC calls via the stub without auth. The rest of the pub-sub layer already uses the shared gRPC infrastructure (GrpcServer, GrpcClient), which supports token authentication. --------- Signed-off-by: Sampan S Nayak Signed-off-by: sampan Signed-off-by: Edward Oakes Co-authored-by: sampan Co-authored-by: Edward Oakes --- src/ray/pubsub/python_gcs_subscriber.cc | 11 + src/ray/pubsub/python_gcs_subscriber.h | 4 + src/ray/pubsub/tests/BUILD.bazel | 17 + .../tests/python_gcs_subscriber_auth_test.cc | 352 ++++++++++++++++++ .../authentication_token_loader.cc | 11 +- .../tests/authentication_token_loader_test.cc | 3 +- 6 files changed, 392 insertions(+), 6 deletions(-) create mode 100644 src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc diff --git a/src/ray/pubsub/python_gcs_subscriber.cc b/src/ray/pubsub/python_gcs_subscriber.cc index 5d54c4c94a1b..c4b5ae762e9b 100644 --- a/src/ray/pubsub/python_gcs_subscriber.cc +++ b/src/ray/pubsub/python_gcs_subscriber.cc @@ -22,6 +22,7 @@ #include #include "ray/gcs_rpc_client/rpc_client.h" +#include "ray/rpc/authentication/authentication_token_loader.h" namespace ray { namespace pubsub { @@ -51,6 +52,7 @@ Status PythonGcsSubscriber::Subscribe() { } grpc::ClientContext context; + SetAuthenticationToken(context); rpc::GcsSubscriberCommandBatchRequest request; request.set_subscriber_id(subscriber_id_); @@ -78,6 +80,7 @@ Status PythonGcsSubscriber::DoPoll(int64_t timeout_ms, rpc::PubMessage *message) return Status::OK(); } current_polling_context_ = std::make_shared(); + SetAuthenticationToken(*current_polling_context_); if (timeout_ms != -1) { current_polling_context_->set_deadline(std::chrono::system_clock::now() + std::chrono::milliseconds(timeout_ms)); @@ -173,6 +176,7 @@ Status PythonGcsSubscriber::Close() { } grpc::ClientContext context; + SetAuthenticationToken(context); rpc::GcsSubscriberCommandBatchRequest request; request.set_subscriber_id(subscriber_id_); @@ -195,5 +199,12 @@ int64_t PythonGcsSubscriber::last_batch_size() { return last_batch_size_; } +void PythonGcsSubscriber::SetAuthenticationToken(grpc::ClientContext &context) { + auto auth_token = ray::rpc::AuthenticationTokenLoader::instance().GetToken(); + if (auth_token.has_value() && !auth_token->empty()) { + auth_token->SetMetadata(context); + } +} + } // namespace pubsub } // namespace ray diff --git a/src/ray/pubsub/python_gcs_subscriber.h b/src/ray/pubsub/python_gcs_subscriber.h index 5fe4eda29812..e8aeaa116566 100644 --- a/src/ray/pubsub/python_gcs_subscriber.h +++ b/src/ray/pubsub/python_gcs_subscriber.h @@ -80,6 +80,10 @@ class RAY_EXPORT PythonGcsSubscriber { std::deque queue_ ABSL_GUARDED_BY(mu_); bool closed_ ABSL_GUARDED_BY(mu_) = false; std::shared_ptr current_polling_context_ ABSL_GUARDED_BY(mu_); + + // Set authentication token on a gRPC client context if token-based authentication is + // enabled + void SetAuthenticationToken(grpc::ClientContext &context); }; /// Get the .lines() attribute of a LogBatch as a std::vector diff --git a/src/ray/pubsub/tests/BUILD.bazel b/src/ray/pubsub/tests/BUILD.bazel index fc1d17ffda7b..395bb7c8b240 100644 --- a/src/ray/pubsub/tests/BUILD.bazel +++ b/src/ray/pubsub/tests/BUILD.bazel @@ -40,3 +40,20 @@ ray_cc_test( "@com_google_googletest//:gtest_main", ], ) + +ray_cc_test( + name = "python_gcs_subscriber_auth_test", + size = "small", + srcs = ["python_gcs_subscriber_auth_test.cc"], + tags = ["team:core"], + deps = [ + "//src/ray/common:ray_config", + "//src/ray/common:status", + "//src/ray/protobuf:gcs_service_cc_grpc", + "//src/ray/pubsub:python_gcs_subscriber", + "//src/ray/rpc:grpc_server", + "//src/ray/rpc/authentication:authentication_token", + "//src/ray/rpc/authentication:authentication_token_loader", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc b/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc new file mode 100644 index 000000000000..a807cab479a9 --- /dev/null +++ b/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc @@ -0,0 +1,352 @@ +// Copyright 2025 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "gtest/gtest.h" +#include "ray/common/ray_config.h" +#include "ray/common/status.h" +#include "ray/pubsub/python_gcs_subscriber.h" +#include "ray/rpc/authentication/authentication_token.h" +#include "ray/rpc/authentication/authentication_token_loader.h" +#include "ray/rpc/grpc_server.h" +#include "src/ray/protobuf/gcs_service.grpc.pb.h" + +namespace ray { +namespace pubsub { + +// Mock implementation of InternalPubSubGcsService for testing authentication +class MockInternalPubSubGcsService final : public rpc::InternalPubSubGcsService::Service { + public: + explicit MockInternalPubSubGcsService(bool should_accept_requests) + : should_accept_requests_(should_accept_requests) {} + + grpc::Status GcsSubscriberCommandBatch( + grpc::ServerContext *context, + const rpc::GcsSubscriberCommandBatchRequest *request, + rpc::GcsSubscriberCommandBatchReply *reply) override { + if (should_accept_requests_) { + for (const auto &command : request->commands()) { + if (command.has_subscribe_message()) { + subscribe_count_++; + } else if (command.has_unsubscribe_message()) { + unsubscribe_count_++; + } + } + return grpc::Status::OK; + } else { + return grpc::Status(grpc::StatusCode::UNAUTHENTICATED, "Authentication failed"); + } + } + + grpc::Status GcsSubscriberPoll(grpc::ServerContext *context, + const rpc::GcsSubscriberPollRequest *request, + rpc::GcsSubscriberPollReply *reply) override { + if (should_accept_requests_) { + poll_count_++; + // Simulate long polling: block until deadline expires since we have no messages + // Real server would hold the connection open until messages arrive or timeout + auto deadline = context->deadline(); + std::this_thread::sleep_until(deadline); + + // Return deadline exceeded (timeout) with empty messages + // This simulates the real server behavior when no messages are published + return grpc::Status(grpc::StatusCode::DEADLINE_EXCEEDED, "Long poll timeout"); + } else { + return grpc::Status(grpc::StatusCode::UNAUTHENTICATED, "Authentication failed"); + } + } + + int subscribe_count() const { return subscribe_count_; } + int poll_count() const { return poll_count_; } + int unsubscribe_count() const { return unsubscribe_count_; } + + private: + bool should_accept_requests_; + std::atomic subscribe_count_{0}; + std::atomic poll_count_{0}; + std::atomic unsubscribe_count_{0}; +}; + +class PythonGcsSubscriberAuthTest : public ::testing::Test { + protected: + void SetUp() override { + // Enable token authentication by default + RayConfig::instance().initialize(R"({"auth_mode": "token"})"); + rpc::AuthenticationTokenLoader::instance().ResetCache(); + } + + void TearDown() override { + if (server_) { + server_->Shutdown(); + server_.reset(); + } + unsetenv("RAY_AUTH_TOKEN"); + // Reset to default auth mode + RayConfig::instance().initialize(R"({"auth_mode": "disabled"})"); + rpc::AuthenticationTokenLoader::instance().ResetCache(); + } + + // Start a GCS server with optional authentication token + void StartServer(const std::string &server_token, bool should_accept_requests = true) { + auto mock_service = + std::make_unique(should_accept_requests); + mock_service_ptr_ = mock_service.get(); + + std::optional auth_token; + if (!server_token.empty()) { + auth_token = rpc::AuthenticationToken(server_token); + } else { + // Empty token means no auth required + auth_token = rpc::AuthenticationToken(""); + } + + server_ = std::make_unique("test-gcs-server", + 0, // Random port + true, + 1, + 7200000, + auth_token); + + server_->RegisterService(std::move(mock_service)); + server_->Run(); + + // Wait for server to start + while (server_->GetPort() == 0) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + server_port_ = server_->GetPort(); + } + + // Set client authentication token via environment variable + void SetClientToken(const std::string &client_token) { + if (!client_token.empty()) { + setenv("RAY_AUTH_TOKEN", client_token.c_str(), 1); + RayConfig::instance().initialize(R"({"auth_mode": "token"})"); + } else { + unsetenv("RAY_AUTH_TOKEN"); + RayConfig::instance().initialize(R"({"auth_mode": "disabled"})"); + } + rpc::AuthenticationTokenLoader::instance().ResetCache(); + } + + std::unique_ptr CreateSubscriber() { + return std::make_unique("127.0.0.1", + server_port_, + rpc::ChannelType::RAY_LOG_CHANNEL, + "test-subscriber-id", + "test-worker-id"); + } + + std::unique_ptr server_; + MockInternalPubSubGcsService *mock_service_ptr_ = nullptr; + int server_port_ = 0; +}; + +TEST_F(PythonGcsSubscriberAuthTest, MatchingTokens) { + // Test that subscription succeeds when client and server use the same token + const std::string test_token = "matching-test-token-12345"; + + StartServer(test_token); + SetClientToken(test_token); + + auto subscriber = CreateSubscriber(); + Status status = subscriber->Subscribe(); + + ASSERT_TRUE(status.ok()) << "Subscribe should succeed with matching tokens: " + << status.ToString(); + EXPECT_EQ(mock_service_ptr_->subscribe_count(), 1); + + ASSERT_TRUE(subscriber->Close().ok()); +} + +TEST_F(PythonGcsSubscriberAuthTest, MismatchedTokens) { + // Test that subscription fails when client and server use different tokens + const std::string server_token = "server-token-12345"; + const std::string client_token = "wrong-client-token-67890"; + + StartServer(server_token, false); // Server will reject requests + SetClientToken(client_token); + + auto subscriber = CreateSubscriber(); + Status status = subscriber->Subscribe(); + + ASSERT_FALSE(status.ok()) << "Subscribe should fail with mismatched tokens"; + EXPECT_TRUE(status.IsRpcError()) << "Status should be RpcError"; + + ASSERT_TRUE(subscriber->Close().ok()); +} + +TEST_F(PythonGcsSubscriberAuthTest, ClientTokenServerNoAuth) { + // Test that subscription succeeds when client provides token but server doesn't require + // it + const std::string client_token = "client-token-12345"; + + StartServer(""); // Server doesn't require auth + SetClientToken(client_token); + + auto subscriber = CreateSubscriber(); + Status status = subscriber->Subscribe(); + + ASSERT_TRUE(status.ok()) + << "Subscribe should succeed when server doesn't require auth: " + << status.ToString(); + EXPECT_EQ(mock_service_ptr_->subscribe_count(), 1); + + ASSERT_TRUE(subscriber->Close().ok()); +} + +TEST_F(PythonGcsSubscriberAuthTest, ServerTokenClientNoAuth) { + // Test that subscription fails when server requires token but client doesn't provide it + const std::string server_token = "server-token-12345"; + + StartServer(server_token, false); // Server will reject requests without valid token + SetClientToken(""); // Client doesn't provide token + + auto subscriber = CreateSubscriber(); + Status status = subscriber->Subscribe(); + + ASSERT_FALSE(status.ok()) + << "Subscribe should fail when server requires token but client doesn't provide it"; + EXPECT_TRUE(status.IsRpcError()) << "Status should be RpcError"; + + ASSERT_TRUE(subscriber->Close().ok()); +} + +TEST_F(PythonGcsSubscriberAuthTest, MatchingTokensPoll) { + // Test that polling succeeds when client and server use the same token + const std::string test_token = "matching-test-token-12345"; + + StartServer(test_token); + SetClientToken(test_token); + + auto subscriber = CreateSubscriber(); + Status status = subscriber->Subscribe(); + ASSERT_TRUE(status.ok()) << "Subscribe should succeed: " << status.ToString(); + + // Test polling with matching tokens - use very short timeout to avoid blocking + std::string key_id; + rpc::LogBatch log_batch; + status = subscriber->PollLogs(&key_id, 10, &log_batch); + + // Poll should succeed (returns OK even on timeout or when no messages available) + ASSERT_TRUE(status.ok()) << "Poll should succeed with matching tokens: " + << status.ToString(); + // At least one poll should have been made + EXPECT_GE(mock_service_ptr_->poll_count(), 1); + + ASSERT_TRUE(subscriber->Close().ok()); +} + +TEST_F(PythonGcsSubscriberAuthTest, MismatchedTokensPoll) { + // Test that polling fails when tokens don't match + const std::string server_token = "server-token-12345"; + const std::string client_token = "wrong-client-token-67890"; + + StartServer(server_token, false); // Server will reject requests + SetClientToken(client_token); + + auto subscriber = CreateSubscriber(); + + // Subscribe will fail, but let's try anyway + ASSERT_FALSE(subscriber->Subscribe().ok()); + + // Test polling with mismatched tokens - use very short timeout + std::string key_id; + rpc::LogBatch log_batch; + Status status = subscriber->PollLogs(&key_id, 10, &log_batch); + + // Poll should fail with auth error or return OK if it was cancelled + // (OK is acceptable because the subscriber may have been closed) + if (!status.ok()) { + EXPECT_TRUE(status.IsInvalid() || status.IsRpcError()) + << "Status should be Invalid or RpcError: " << status.ToString(); + } + + ASSERT_TRUE(subscriber->Close().ok()); +} + +TEST_F(PythonGcsSubscriberAuthTest, MatchingTokensClose) { + // Test that closing/unsubscribing succeeds with matching tokens + const std::string test_token = "matching-test-token-12345"; + + StartServer(test_token); + SetClientToken(test_token); + + auto subscriber = CreateSubscriber(); + Status status = subscriber->Subscribe(); + ASSERT_TRUE(status.ok()) << "Subscribe should succeed: " << status.ToString(); + EXPECT_EQ(mock_service_ptr_->subscribe_count(), 1); + + // Close should succeed with matching tokens + ASSERT_TRUE(subscriber->Close().ok()) + << "Close should succeed with matching tokens: " << status.ToString(); + // This assertion will fail until auth is added to `Close()` and the mock is updated. + EXPECT_EQ(mock_service_ptr_->unsubscribe_count(), 1); +} + +TEST_F(PythonGcsSubscriberAuthTest, NoAuthRequired) { + // Test that everything works when neither client nor server use auth + StartServer(""); // Server doesn't require auth + SetClientToken(""); // Client doesn't provide token + + auto subscriber = CreateSubscriber(); + Status status = subscriber->Subscribe(); + + ASSERT_TRUE(status.ok()) << "Subscribe should succeed without auth: " + << status.ToString(); + EXPECT_EQ(mock_service_ptr_->subscribe_count(), 1); + + // Test polling without auth - use very short timeout + std::string key_id; + rpc::LogBatch log_batch; + status = subscriber->PollLogs(&key_id, 10, &log_batch); + ASSERT_TRUE(status.ok()) << "Poll should succeed without auth: " << status.ToString(); + + // Test close without auth + status = subscriber->Close(); + ASSERT_TRUE(status.ok()) << "Close should succeed without auth: " << status.ToString(); +} + +TEST_F(PythonGcsSubscriberAuthTest, MultipleSubscribersMatchingTokens) { + // Test multiple subscribers with the same token + const std::string test_token = "shared-token-12345"; + + StartServer(test_token); + SetClientToken(test_token); + + auto subscriber1 = CreateSubscriber(); + auto subscriber2 = CreateSubscriber(); + + Status status1 = subscriber1->Subscribe(); + Status status2 = subscriber2->Subscribe(); + + ASSERT_TRUE(status1.ok()) << "First subscriber should succeed: " << status1.ToString(); + ASSERT_TRUE(status2.ok()) << "Second subscriber should succeed: " << status2.ToString(); + EXPECT_EQ(mock_service_ptr_->subscribe_count(), 2); + + ASSERT_TRUE(subscriber1->Close().ok()); + ASSERT_TRUE(subscriber2->Close().ok()); +} + +} // namespace pubsub +} // namespace ray + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/src/ray/rpc/authentication/authentication_token_loader.cc b/src/ray/rpc/authentication/authentication_token_loader.cc index 621f28fe351c..7e3903423890 100644 --- a/src/ray/rpc/authentication/authentication_token_loader.cc +++ b/src/ray/rpc/authentication/authentication_token_loader.cc @@ -57,8 +57,9 @@ std::optional AuthenticationTokenLoader::GetToken() { AuthenticationToken token = LoadTokenFromSources(); // If no token found and auth is enabled, fail with RAY_CHECK - RAY_CHECK(!token.empty()) - << "Token authentication is enabled but Ray couldn't find an authentication token. " + RAY_USER_CHECK(!token.empty()) + << "Ray Setup Error: Token authentication is enabled but Ray couldn't find an " + "authentication token. " << "Set the RAY_AUTH_TOKEN environment variable, or set RAY_AUTH_TOKEN_PATH to " "point to a file with the token, " << "or create a token file at ~/.ray/auth_token."; @@ -100,9 +101,9 @@ AuthenticationToken AuthenticationTokenLoader::LoadTokenFromSources() { std::string path_str(env_token_path); if (!path_str.empty()) { std::string token_str = TrimWhitespace(ReadTokenFromFile(path_str)); - RAY_CHECK(!token_str.empty()) - << "RAY_AUTH_TOKEN_PATH is set but file cannot be opened or is empty: " - << path_str; + RAY_USER_CHECK(!token_str.empty()) << "Ray Setup Error: RAY_AUTH_TOKEN_PATH is set " + "but file cannot be opened or is empty: " + << path_str; RAY_LOG(DEBUG) << "Loaded authentication token from file: " << path_str; return AuthenticationToken(token_str); } diff --git a/src/ray/rpc/tests/authentication_token_loader_test.cc b/src/ray/rpc/tests/authentication_token_loader_test.cc index 2332c6d09313..b758acd64c05 100644 --- a/src/ray/rpc/tests/authentication_token_loader_test.cc +++ b/src/ray/rpc/tests/authentication_token_loader_test.cc @@ -299,7 +299,8 @@ TEST_F(AuthenticationTokenLoaderTest, TestErrorWhenAuthEnabledButNoToken) { auto &loader = AuthenticationTokenLoader::instance(); loader.GetToken(); }, - "Token authentication is enabled but Ray couldn't find an authentication token."); + "Ray Setup Error: Token authentication is enabled but Ray couldn't find an " + "authentication token."); } TEST_F(AuthenticationTokenLoaderTest, TestCaching) { From 82d0b7cbda5369dff2e7bf889f08cf9b7e5ad491 Mon Sep 17 00:00:00 2001 From: sampan Date: Sat, 1 Nov 2025 15:06:44 +0000 Subject: [PATCH 43/48] fix build Signed-off-by: sampan --- src/ray/ray_syncer/BUILD.bazel | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ray/ray_syncer/BUILD.bazel b/src/ray/ray_syncer/BUILD.bazel index a5f287dd1e8e..a6b6896a5253 100644 --- a/src/ray/ray_syncer/BUILD.bazel +++ b/src/ray/ray_syncer/BUILD.bazel @@ -20,7 +20,7 @@ ray_cc_library( deps = [ ":asio", ":id", - "//:ray_syncer_cc_grpc", + "//src/ray/protobuf:ray_syncer_cc_grpc", "//src/ray/common:constants", "//src/ray/rpc/authentication:authentication_token", "//src/ray/rpc/authentication:authentication_token_loader", From b64bd44f9738d625d886d2ea6abe7fb732b951b9 Mon Sep 17 00:00:00 2001 From: sampan Date: Sat, 1 Nov 2025 15:07:24 +0000 Subject: [PATCH 44/48] fix lint Signed-off-by: sampan --- src/ray/ray_syncer/BUILD.bazel | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ray/ray_syncer/BUILD.bazel b/src/ray/ray_syncer/BUILD.bazel index a6b6896a5253..2ca673035a74 100644 --- a/src/ray/ray_syncer/BUILD.bazel +++ b/src/ray/ray_syncer/BUILD.bazel @@ -20,8 +20,8 @@ ray_cc_library( deps = [ ":asio", ":id", - "//src/ray/protobuf:ray_syncer_cc_grpc", "//src/ray/common:constants", + "//src/ray/protobuf:ray_syncer_cc_grpc", "//src/ray/rpc/authentication:authentication_token", "//src/ray/rpc/authentication:authentication_token_loader", "@com_github_grpc_grpc//:grpc++", From f27a39af18665f40b01b95ac8efac503c57d0213 Mon Sep 17 00:00:00 2001 From: sampan Date: Sun, 2 Nov 2025 13:00:20 +0000 Subject: [PATCH 45/48] fix build Signed-off-by: sampan --- src/ray/ray_syncer/BUILD.bazel | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ray/ray_syncer/BUILD.bazel b/src/ray/ray_syncer/BUILD.bazel index 2ca673035a74..c5c65ac532c0 100644 --- a/src/ray/ray_syncer/BUILD.bazel +++ b/src/ray/ray_syncer/BUILD.bazel @@ -18,9 +18,9 @@ ray_cc_library( "ray_syncer/ray_syncer_server.h", ], deps = [ - ":asio", - ":id", + "//src/ray/common:asio", "//src/ray/common:constants", + "//src/ray/common:id", "//src/ray/protobuf:ray_syncer_cc_grpc", "//src/ray/rpc/authentication:authentication_token", "//src/ray/rpc/authentication:authentication_token_loader", From b8f2d6389c15868a33689dc9f7ab0b2669efa89d Mon Sep 17 00:00:00 2001 From: sampan Date: Sun, 2 Nov 2025 15:36:06 +0000 Subject: [PATCH 46/48] fix issue during merge conflict Signed-off-by: sampan --- src/ray/ray_syncer/BUILD.bazel | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/ray/ray_syncer/BUILD.bazel b/src/ray/ray_syncer/BUILD.bazel index c5c65ac532c0..c7cd8ca2a0b0 100644 --- a/src/ray/ray_syncer/BUILD.bazel +++ b/src/ray/ray_syncer/BUILD.bazel @@ -3,19 +3,19 @@ load("//bazel:ray.bzl", "ray_cc_library") ray_cc_library( name = "ray_syncer", srcs = [ - "ray_syncer/node_state.cc", - "ray_syncer/ray_syncer.cc", - "ray_syncer/ray_syncer_client.cc", - "ray_syncer/ray_syncer_server.cc", + "node_state.cc", + "ray_syncer.cc", + "ray_syncer_client.cc", + "ray_syncer_server.cc", ], hdrs = [ - "ray_syncer/common.h", - "ray_syncer/node_state.h", - "ray_syncer/ray_syncer.h", - "ray_syncer/ray_syncer_bidi_reactor.h", - "ray_syncer/ray_syncer_bidi_reactor_base.h", - "ray_syncer/ray_syncer_client.h", - "ray_syncer/ray_syncer_server.h", + "common.h", + "node_state.h", + "ray_syncer.h", + "ray_syncer_bidi_reactor.h", + "ray_syncer_bidi_reactor_base.h", + "ray_syncer_client.h", + "ray_syncer_server.h", ], deps = [ "//src/ray/common:asio", From 6d1a9c65a74c1da4af79984095dd1b48f9b8a0d3 Mon Sep 17 00:00:00 2001 From: sampan Date: Mon, 3 Nov 2025 02:32:15 +0000 Subject: [PATCH 47/48] revert unneeded changes Signed-off-by: sampan --- src/ray/pubsub/python_gcs_subscriber.cc | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/ray/pubsub/python_gcs_subscriber.cc b/src/ray/pubsub/python_gcs_subscriber.cc index 0db5445fe3a1..c4b5ae762e9b 100644 --- a/src/ray/pubsub/python_gcs_subscriber.cc +++ b/src/ray/pubsub/python_gcs_subscriber.cc @@ -75,7 +75,7 @@ Status PythonGcsSubscriber::Subscribe() { Status PythonGcsSubscriber::DoPoll(int64_t timeout_ms, rpc::PubMessage *message) { absl::MutexLock lock(&mu_); - if (queue_.empty()) { + while (queue_.empty()) { if (closed_) { return Status::OK(); } @@ -135,11 +135,6 @@ Status PythonGcsSubscriber::DoPoll(int64_t timeout_ms, rpc::PubMessage *message) } } - if (queue_.empty()) { - // No messages available after polling - return Status::OK(); - } - *message = std::move(queue_.front()); queue_.pop_front(); From 80c77439c195a7fb3936bf47e3eecb97716fd5ff Mon Sep 17 00:00:00 2001 From: sampan Date: Mon, 3 Nov 2025 03:13:14 +0000 Subject: [PATCH 48/48] address cursor comment Signed-off-by: sampan --- src/ray/ray_syncer/ray_syncer.cc | 7 +++++++ src/ray/ray_syncer/ray_syncer_server.h | 11 +++++++++++ 2 files changed, 18 insertions(+) diff --git a/src/ray/ray_syncer/ray_syncer.cc b/src/ray/ray_syncer/ray_syncer.cc index eecbbdaff3b7..837b314ed104 100644 --- a/src/ray/ray_syncer/ray_syncer.cc +++ b/src/ray/ray_syncer/ray_syncer.cc @@ -248,6 +248,13 @@ ServerBidiReactor *RaySyncerService::StartSync(grpc::CallbackServerContext *cont /*auth_token=*/auth_token_); RAY_LOG(DEBUG).WithField(NodeID::FromBinary(reactor->GetRemoteNodeID())) << "Get connection"; + + // If the reactor has already called Finish() (e.g., due to authentication failure), + // skip registration. The reactor will clean itself up via OnDone(). + if (reactor->IsFinished()) { + return reactor; + } + // Disconnect exiting connection if there is any. // This can happen when there is transient network error // and the client reconnects. diff --git a/src/ray/ray_syncer/ray_syncer_server.h b/src/ray/ray_syncer/ray_syncer_server.h index 2b239f54ead6..6db427958667 100644 --- a/src/ray/ray_syncer/ray_syncer_server.h +++ b/src/ray/ray_syncer/ray_syncer_server.h @@ -16,6 +16,7 @@ #include +#include #include #include @@ -42,11 +43,18 @@ class RayServerBidiReactor : public RaySyncerBidiReactorBase ~RayServerBidiReactor() override = default; + bool IsFinished() const { return finished_.load(); } + private: void DoDisconnect() override; void OnCancel() override; void OnDone() override; + void Finish(grpc::Status status) { + finished_.store(true); + ServerBidiReactor::Finish(status); + } + /// Cleanup callback when the call ends. const std::function cleanup_cb_; @@ -57,6 +65,9 @@ class RayServerBidiReactor : public RaySyncerBidiReactorBase /// disabled std::optional auth_token_; + /// Track if Finish() has been called to avoid using a reactor that is terminating + std::atomic finished_{false}; + FRIEND_TEST(SyncerReactorTest, TestReactorFailure); };