From f0c063bbad7d6880de0dc3f45678fcc8c248d1b6 Mon Sep 17 00:00:00 2001 From: sampan Date: Mon, 3 Nov 2025 08:01:54 +0000 Subject: [PATCH 01/26] [core] Support token based auth in ray dashboard UI Signed-off-by: sampan --- python/ray/dashboard/client/src/App.tsx | 102 +++++++++ .../TokenAuthenticationDialog.test.tsx | 202 ++++++++++++++++++ .../TokenAuthenticationDialog.tsx | 146 +++++++++++++ .../src/authentication/authentication.ts | 50 +++++ .../client/src/authentication/cookies.test.ts | 107 ++++++++++ .../client/src/authentication/cookies.ts | 82 +++++++ .../ray/dashboard/client/src/service/event.ts | 8 +- .../client/src/service/requestHandlers.ts | 48 ++++- .../ray/dashboard/client/src/service/util.ts | 10 +- python/ray/dashboard/http_server_head.py | 53 +++++ .../dashboard/tests/test_dashboard_auth.py | 39 ++++ 11 files changed, 837 insertions(+), 10 deletions(-) create mode 100644 python/ray/dashboard/client/src/authentication/TokenAuthenticationDialog.test.tsx create mode 100644 python/ray/dashboard/client/src/authentication/TokenAuthenticationDialog.tsx create mode 100644 python/ray/dashboard/client/src/authentication/authentication.ts create mode 100644 python/ray/dashboard/client/src/authentication/cookies.test.ts create mode 100644 python/ray/dashboard/client/src/authentication/cookies.ts diff --git a/python/ray/dashboard/client/src/App.tsx b/python/ray/dashboard/client/src/App.tsx index 7e37d6819ca6..4df0f4ba40ce 100644 --- a/python/ray/dashboard/client/src/App.tsx +++ b/python/ray/dashboard/client/src/App.tsx @@ -4,6 +4,15 @@ import dayjs from "dayjs"; import duration from "dayjs/plugin/duration"; import React, { Suspense, useEffect, useState } from "react"; import { HashRouter, Navigate, Route, Routes } from "react-router-dom"; +import TokenAuthenticationDialog from "./authentication/TokenAuthenticationDialog"; +import { + getAuthenticationMode, + testTokenValidity, +} from "./authentication/authentication"; +import { + getAuthenticationToken, + setAuthenticationToken, +} from "./authentication/cookies"; import ActorDetailPage, { ActorDetailLayout } from "./pages/actor/ActorDetail"; import { ActorLayout } from "./pages/actor/ActorLayout"; import Loading from "./pages/exception/Loading"; @@ -147,6 +156,15 @@ const App = () => { dashboardDatasource: undefined, serverTimeZone: undefined, }); + + // Authentication state + const [authenticationDialogOpen, setAuthenticationDialogOpen] = + useState(false); + const [hasAttemptedAuthentication, setHasAttemptedAuthentication] = + useState(false); + const [authenticationError, setAuthenticationError] = useState< + string | undefined + >(); useEffect(() => { getNodeList().then((res) => { if (res?.data?.data?.summary) { @@ -218,12 +236,96 @@ const App = () => { updateTimezone(); }, []); + // Check authentication mode on mount + useEffect(() => { + const checkAuthentication = async () => { + try { + const { authentication_mode } = await getAuthenticationMode(); + + if (authentication_mode === "token") { + // Token authentication is enabled + const existingToken = getAuthenticationToken(); + + if (!existingToken) { + // No token found - show dialog immediately + setAuthenticationDialogOpen(true); + } + // If token exists, let it be used by interceptor + // If invalid, interceptor will trigger dialog via 401/403 + } + } catch (error) { + console.error("Failed to check authentication mode:", error); + } + }; + + checkAuthentication(); + }, []); + + // Listen for authentication errors from axios interceptor + useEffect(() => { + const handleAuthenticationError = (event: Event) => { + const customEvent = event as CustomEvent<{ hadToken: boolean }>; + const hadToken = customEvent.detail?.hadToken ?? false; + + setHasAttemptedAuthentication(hadToken); + setAuthenticationDialogOpen(true); + }; + + window.addEventListener( + "ray-authentication-error", + handleAuthenticationError, + ); + + return () => { + window.removeEventListener( + "ray-authentication-error", + handleAuthenticationError, + ); + }; + }, []); + + // Handle token submission from dialog + const handleTokenSubmit = async (token: string) => { + try { + // Test if token is valid + const isValid = await testTokenValidity(token); + + if (isValid) { + // Save token to cookie + setAuthenticationToken(token); + setHasAttemptedAuthentication(true); + setAuthenticationDialogOpen(false); + setAuthenticationError(undefined); + + // Reload the page to refetch all data with the new token + window.location.reload(); + } else { + // Token is invalid + setHasAttemptedAuthentication(true); + setAuthenticationError( + "Invalid authentication token. Please check and try again.", + ); + } + } catch (error) { + console.error("Failed to validate token:", error); + setAuthenticationError( + "Failed to validate token. Please check your connection and try again.", + ); + } + }; + return ( + {/* Redirect people hitting the /new path to root. TODO(aguo): Delete this redirect in ray 2.5 */} diff --git a/python/ray/dashboard/client/src/authentication/TokenAuthenticationDialog.test.tsx b/python/ray/dashboard/client/src/authentication/TokenAuthenticationDialog.test.tsx new file mode 100644 index 000000000000..5eea6cdc1d25 --- /dev/null +++ b/python/ray/dashboard/client/src/authentication/TokenAuthenticationDialog.test.tsx @@ -0,0 +1,202 @@ +import { render, screen, waitFor } from "@testing-library/react"; +import userEvent from "@testing-library/user-event"; +import React from "react"; +import "@testing-library/jest-dom"; +import TokenAuthenticationDialog from "./TokenAuthenticationDialog"; + +describe("TokenAuthenticationDialog", () => { + const mockOnSubmit = jest.fn(); + + beforeEach(() => { + mockOnSubmit.mockClear(); + }); + + it("renders with initial message when no existing token", () => { + render( + , + ); + + expect(screen.getByText("Token Authentication Required")).toBeInTheDocument(); + expect( + screen.getByText(/token authentication is enabled for this cluster/i), + ).toBeInTheDocument(); + }); + + it("renders with re-authentication message when has existing token", () => { + render( + , + ); + + expect(screen.getByText("Token Authentication Required")).toBeInTheDocument(); + expect( + screen.getByText(/authentication token is invalid or has expired/i), + ).toBeInTheDocument(); + }); + + it("displays error message when provided", () => { + const errorMessage = "Invalid token provided"; + render( + , + ); + + expect(screen.getByText(errorMessage)).toBeInTheDocument(); + }); + + it("calls onSubmit with entered token when submit is clicked", async () => { + const user = userEvent.setup(); + mockOnSubmit.mockResolvedValue(undefined); + + render( + , + ); + + const input = screen.getByLabelText(/authentication token/i); + await user.type(input, "test-token-123"); + + const submitButton = screen.getByRole("button", { name: /submit/i }); + await user.click(submitButton); + + await waitFor(() => { + expect(mockOnSubmit).toHaveBeenCalledWith("test-token-123"); + }); + }); + + it("calls onSubmit when Enter key is pressed", async () => { + const user = userEvent.setup(); + mockOnSubmit.mockResolvedValue(undefined); + + render( + , + ); + + const input = screen.getByLabelText(/authentication token/i); + await user.type(input, "test-token-123{Enter}"); + + await waitFor(() => { + expect(mockOnSubmit).toHaveBeenCalledWith("test-token-123"); + }); + }); + + it("disables submit button when token is empty", () => { + render( + , + ); + + const submitButton = screen.getByRole("button", { name: /submit/i }); + expect(submitButton).toBeDisabled(); + }); + + it("enables submit button when token is entered", async () => { + const user = userEvent.setup(); + render( + , + ); + + const submitButton = screen.getByRole("button", { name: /submit/i }); + expect(submitButton).toBeDisabled(); + + const input = screen.getByLabelText(/authentication token/i); + await user.type(input, "test-token"); + + expect(submitButton).not.toBeDisabled(); + }); + + it("toggles token visibility when visibility icon is clicked", async () => { + const user = userEvent.setup(); + render( + , + ); + + const input = screen.getByLabelText(/authentication token/i); + await user.type(input, "secret-token"); + + // Initially should be password type (hidden) + expect(input).toHaveAttribute("type", "password"); + + // Click visibility toggle + const toggleButton = screen.getByLabelText(/toggle token visibility/i); + await user.click(toggleButton); + + // Should now be text type (visible) + expect(input).toHaveAttribute("type", "text"); + + // Click again to hide + await user.click(toggleButton); + expect(input).toHaveAttribute("type", "password"); + }); + + it("shows loading state during submission", async () => { + const user = userEvent.setup(); + // Mock a slow submission + mockOnSubmit.mockImplementation( + () => new Promise((resolve) => setTimeout(resolve, 100)), + ); + + render( + , + ); + + const input = screen.getByLabelText(/authentication token/i); + await user.type(input, "test-token"); + + const submitButton = screen.getByRole("button", { name: /submit/i }); + await user.click(submitButton); + + // Should show validating state + await waitFor(() => { + expect(screen.getByText(/validating.../i)).toBeInTheDocument(); + }); + }); + + it("does not render when open is false", () => { + const { container } = render( + , + ); + + // Dialog should not be visible + expect( + screen.queryByText("Token Authentication Required"), + ).not.toBeInTheDocument(); + }); +}); diff --git a/python/ray/dashboard/client/src/authentication/TokenAuthenticationDialog.tsx b/python/ray/dashboard/client/src/authentication/TokenAuthenticationDialog.tsx new file mode 100644 index 000000000000..cb28d27d7ad4 --- /dev/null +++ b/python/ray/dashboard/client/src/authentication/TokenAuthenticationDialog.tsx @@ -0,0 +1,146 @@ +/** + * Dialog component for Ray dashboard token authentication. + * Prompts users to enter their authentication token when token auth is enabled. + */ + +import React, { useState } from "react"; +import { + Alert, + Button, + CircularProgress, + Dialog, + DialogActions, + DialogContent, + DialogContentText, + DialogTitle, + IconButton, + InputAdornment, + TextField, +} from "@mui/material"; +import { Visibility, VisibilityOff } from "@mui/icons-material"; + +export type TokenAuthenticationDialogProps = { + /** Whether the dialog is open */ + open: boolean; + /** Whether the user has previously entered a token (affects messaging) */ + hasExistingToken: boolean; + /** Callback when user submits a token */ + onSubmit: (token: string) => Promise; + /** Optional error message to display */ + error?: string; +}; + +/** + * Token Authentication Dialog Component. + * + * Shows different messages based on whether this is the first time + * (hasExistingToken=false) or if a previously stored token was rejected + * (hasExistingToken=true). + */ +export const TokenAuthenticationDialog: React.FC< + TokenAuthenticationDialogProps +> = ({ open, hasExistingToken, onSubmit, error }) => { + const [token, setToken] = useState(""); + const [showToken, setShowToken] = useState(false); + const [isSubmitting, setIsSubmitting] = useState(false); + + const handleSubmit = async () => { + if (!token.trim()) { + return; + } + + setIsSubmitting(true); + try { + await onSubmit(token.trim()); + // If successful, the parent component will close the dialog + // and likely reload the page + } catch (err) { + // Error is handled by parent component via the error prop + console.error("Failed to submit token:", err); + } finally { + setIsSubmitting(false); + } + }; + + const handleKeyPress = (event: React.KeyboardEvent) => { + if (event.key === "Enter" && !isSubmitting) { + handleSubmit(); + } + }; + + const toggleShowToken = () => { + setShowToken(!showToken); + }; + + // Different messages based on whether this is initial auth or re-auth + const title = "Token Authentication Required"; + const message = hasExistingToken + ? "The authentication token is invalid or has expired. Please provide a valid authentication token." + : "Token authentication is enabled for this cluster. Please provide a valid authentication token."; + + return ( + + {title} + + + {message} + + + {error && ( + + {error} + + )} + + setToken(e.target.value)} + onKeyPress={handleKeyPress} + disabled={isSubmitting} + placeholder="Enter your authentication token" + InputProps={{ + endAdornment: ( + + + {showToken ? : } + + + ), + }} + /> + + + + + + ); +}; + +export default TokenAuthenticationDialog; diff --git a/python/ray/dashboard/client/src/authentication/authentication.ts b/python/ray/dashboard/client/src/authentication/authentication.ts new file mode 100644 index 000000000000..fd47d4d28f56 --- /dev/null +++ b/python/ray/dashboard/client/src/authentication/authentication.ts @@ -0,0 +1,50 @@ +/** + * Authentication service for Ray dashboard. + * Provides functions to check authentication mode and validate tokens when token auth is enabled. + */ + +import axios from "axios"; + +/** + * Response type for authentication mode endpoint. + */ +export type AuthenticationModeResponse = { + authentication_mode: "disabled" | "token"; +}; + +/** + * Get the current authentication mode from the server. + * This endpoint is public and does not require authentication. + * + * @returns Promise resolving to the authentication mode + */ +export const getAuthenticationMode = + async (): Promise => { + const response = await axios.get( + "/api/authentication_mode", + ); + return response.data; + }; + +/** + * Test if a token is valid by making a request to the /api/version endpoint + * which is fast and reliable. + * + * @param token - The authentication token to test + * @returns Promise resolving to true if token is valid, false otherwise + */ +export const testTokenValidity = async (token: string): Promise => { + try { + await axios.get("/api/version", { + headers: { Authorization: `Bearer ${token}` }, + }); + return true; + } catch (error: any) { + // 401 (Unauthorized) or 403 (Forbidden) means invalid token + if (error.response?.status === 401 || error.response?.status === 403) { + return false; + } + // For other errors (network, server errors, etc.), re-throw + throw error; + } +}; diff --git a/python/ray/dashboard/client/src/authentication/cookies.test.ts b/python/ray/dashboard/client/src/authentication/cookies.test.ts new file mode 100644 index 000000000000..d1cb22d67a24 --- /dev/null +++ b/python/ray/dashboard/client/src/authentication/cookies.test.ts @@ -0,0 +1,107 @@ +import "@testing-library/jest-dom"; +import { + clearAuthenticationToken, + deleteCookie, + getAuthenticationToken, + getCookie, + setAuthenticationToken, + setCookie, +} from "./cookies"; + +describe("Cookie utilities", () => { + beforeEach(() => { + // Clear all cookies before each test + document.cookie.split(";").forEach((cookie) => { + const name = cookie.split("=")[0].trim(); + document.cookie = `${name}=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=/;`; + }); + }); + + describe("setCookie and getCookie", () => { + it("sets and retrieves a cookie", () => { + setCookie("test-cookie", "test-value"); + const value = getCookie("test-cookie"); + expect(value).toBe("test-value"); + }); + + it("returns null for non-existent cookie", () => { + const value = getCookie("non-existent"); + expect(value).toBeNull(); + }); + + it("overwrites existing cookie with same name", () => { + setCookie("test-cookie", "value1"); + setCookie("test-cookie", "value2"); + const value = getCookie("test-cookie"); + expect(value).toBe("value2"); + }); + }); + + describe("deleteCookie", () => { + it("deletes an existing cookie", () => { + setCookie("test-cookie", "test-value"); + expect(getCookie("test-cookie")).toBe("test-value"); + + deleteCookie("test-cookie"); + expect(getCookie("test-cookie")).toBeNull(); + }); + + it("handles deletion of non-existent cookie", () => { + // Should not throw error + expect(() => deleteCookie("non-existent")).not.toThrow(); + }); + }); + + describe("Authentication token functions", () => { + it("sets and retrieves authentication token", () => { + const testToken = "test-auth-token-123"; + setAuthenticationToken(testToken); + + const retrievedToken = getAuthenticationToken(); + expect(retrievedToken).toBe(testToken); + }); + + it("returns null when no authentication token is set", () => { + const token = getAuthenticationToken(); + expect(token).toBeNull(); + }); + + it("clears authentication token", () => { + setAuthenticationToken("test-token"); + expect(getAuthenticationToken()).toBe("test-token"); + + clearAuthenticationToken(); + expect(getAuthenticationToken()).toBeNull(); + }); + + it("overwrites existing authentication token", () => { + setAuthenticationToken("token1"); + expect(getAuthenticationToken()).toBe("token1"); + + setAuthenticationToken("token2"); + expect(getAuthenticationToken()).toBe("token2"); + }); + }); + + describe("Multiple cookies", () => { + it("handles multiple cookies independently", () => { + setCookie("cookie1", "value1"); + setCookie("cookie2", "value2"); + setCookie("cookie3", "value3"); + + expect(getCookie("cookie1")).toBe("value1"); + expect(getCookie("cookie2")).toBe("value2"); + expect(getCookie("cookie3")).toBe("value3"); + }); + + it("deletes only specified cookie", () => { + setCookie("cookie1", "value1"); + setCookie("cookie2", "value2"); + + deleteCookie("cookie1"); + + expect(getCookie("cookie1")).toBeNull(); + expect(getCookie("cookie2")).toBe("value2"); + }); + }); +}); diff --git a/python/ray/dashboard/client/src/authentication/cookies.ts b/python/ray/dashboard/client/src/authentication/cookies.ts new file mode 100644 index 000000000000..929a30030738 --- /dev/null +++ b/python/ray/dashboard/client/src/authentication/cookies.ts @@ -0,0 +1,82 @@ +/** + * Cookie utility functions for Ray dashboard authentication. + */ + +const AUTHENTICATION_TOKEN_COOKIE_NAME = "ray-authentication-token"; + +/** + * Get a cookie value by name. + * + * @param name - The name of the cookie to retrieve + * @returns The cookie value if found, null otherwise + */ +export const getCookie = (name: string): string | null => { + const nameEQ = name + "="; + const cookies = document.cookie.split(";"); + + for (let i = 0; i < cookies.length; i++) { + let cookie = cookies[i]; + while (cookie.charAt(0) === " ") { + cookie = cookie.substring(1, cookie.length); + } + if (cookie.indexOf(nameEQ) === 0) { + return cookie.substring(nameEQ.length, cookie.length); + } + } + return null; +}; + +/** + * Set a cookie with the given name, value, and expiration. + * + * @param name - The name of the cookie + * @param value - The value to store in the cookie + * @param days - Number of days until the cookie expires (default: 30) + */ +export const setCookie = ( + name: string, + value: string, + days = 30, +): void => { + let expires = ""; + if (days) { + const date = new Date(); + date.setTime(date.getTime() + days * 24 * 60 * 60 * 1000); + expires = "; expires=" + date.toUTCString(); + } + document.cookie = name + "=" + (value || "") + expires + "; path=/"; +}; + +/** + * Delete a cookie by name. + * + * @param name - The name of the cookie to delete + */ +export const deleteCookie = (name: string): void => { + document.cookie = name + "=; Max-Age=-99999999; path=/"; +}; + +/** + * Get the authentication token from cookies. + * + * @returns The authentication token if found, null otherwise + */ +export const getAuthenticationToken = (): string | null => { + return getCookie(AUTHENTICATION_TOKEN_COOKIE_NAME); +}; + +/** + * Set the authentication token in cookies. + * + * @param token - The authentication token to store + */ +export const setAuthenticationToken = (token: string): void => { + setCookie(AUTHENTICATION_TOKEN_COOKIE_NAME, token); +}; + +/** + * Clear the authentication token from cookies. + */ +export const clearAuthenticationToken = (): void => { + deleteCookie(AUTHENTICATION_TOKEN_COOKIE_NAME); +}; diff --git a/python/ray/dashboard/client/src/service/event.ts b/python/ray/dashboard/client/src/service/event.ts index dcd153ed4542..a73dc926379f 100644 --- a/python/ray/dashboard/client/src/service/event.ts +++ b/python/ray/dashboard/client/src/service/event.ts @@ -1,18 +1,18 @@ -import axios from "axios"; +import { axiosInstance } from "./requestHandlers"; import { EventGlobalRsp, EventRsp } from "../type/event"; export const getEvents = (jobId: string) => { if (jobId) { - return axios.get(`events?job_id=${jobId}`); + return axiosInstance.get(`events?job_id=${jobId}`); } }; export const getPipelineEvents = (jobId: string) => { if (jobId) { - return axios.get(`events?job_id=${jobId}&view=pipeline`); + return axiosInstance.get(`events?job_id=${jobId}&view=pipeline`); } }; export const getGlobalEvents = () => { - return axios.get("events"); + return axiosInstance.get("events"); }; diff --git a/python/ray/dashboard/client/src/service/requestHandlers.ts b/python/ray/dashboard/client/src/service/requestHandlers.ts index 9da2ff6fc8aa..f601a3f522e4 100644 --- a/python/ray/dashboard/client/src/service/requestHandlers.ts +++ b/python/ray/dashboard/client/src/service/requestHandlers.ts @@ -9,6 +9,7 @@ */ import axios, { AxiosRequestConfig, AxiosResponse } from "axios"; +import { getAuthenticationToken } from "../authentication/cookies"; /** * This function formats URLs such that the user's browser @@ -26,9 +27,54 @@ export const formatUrl = (url: string): string => { return url; }; +// Create axios instance with interceptors for authentication +const axiosInstance = axios.create(); + +// Export the configured axios instance for direct use when needed +export { axiosInstance }; + +// Request interceptor: Add authentication token if available +axiosInstance.interceptors.request.use( + (config) => { + const token = getAuthenticationToken(); + if (token) { + config.headers.Authorization = `Bearer ${token}`; + } + return config; + }, + (error) => { + return Promise.reject(error); + }, +); + +// Response interceptor: Handle 401/403 errors +axiosInstance.interceptors.response.use( + (response) => { + return response; + }, + (error) => { + // If we get 401 (Unauthorized) or 403 (Forbidden), dispatch an event + // so the App component can show the authentication dialog + if (error.response?.status === 401 || error.response?.status === 403) { + // Check if there was a token in the request + const hadToken = !!getAuthenticationToken(); + + // Dispatch custom event for authentication error + window.dispatchEvent( + new CustomEvent("ray-authentication-error", { + detail: { hadToken }, + }), + ); + } + + // Re-throw the error so the caller can handle it if needed + return Promise.reject(error); + }, +); + export const get = >( url: string, config?: AxiosRequestConfig, ): Promise => { - return axios.get(formatUrl(url), config); + return axiosInstance.get(formatUrl(url), config); }; diff --git a/python/ray/dashboard/client/src/service/util.ts b/python/ray/dashboard/client/src/service/util.ts index 966c82db2919..e666c6fbc8d2 100644 --- a/python/ray/dashboard/client/src/service/util.ts +++ b/python/ray/dashboard/client/src/service/util.ts @@ -1,4 +1,4 @@ -import axios from "axios"; +import { axiosInstance } from "./requestHandlers"; type CMDRsp = { result: boolean; @@ -9,7 +9,7 @@ type CMDRsp = { }; export const getJstack = (ip: string, pid: string) => { - return axios.get("utils/jstack", { + return axiosInstance.get("utils/jstack", { params: { ip, pid, @@ -18,7 +18,7 @@ export const getJstack = (ip: string, pid: string) => { }; export const getJmap = (ip: string, pid: string) => { - return axios.get("utils/jmap", { + return axiosInstance.get("utils/jmap", { params: { ip, pid, @@ -27,7 +27,7 @@ export const getJmap = (ip: string, pid: string) => { }; export const getJstat = (ip: string, pid: string, options: string) => { - return axios.get("utils/jstat", { + return axiosInstance.get("utils/jstat", { params: { ip, pid, @@ -48,5 +48,5 @@ type NamespacesRsp = { }; export const getNamespaces = () => { - return axios.get("namespaces"); + return axiosInstance.get("namespaces"); }; diff --git a/python/ray/dashboard/http_server_head.py b/python/ray/dashboard/http_server_head.py index 5f7054900c18..502a5247828a 100644 --- a/python/ray/dashboard/http_server_head.py +++ b/python/ray/dashboard/http_server_head.py @@ -20,6 +20,7 @@ from ray._common.network_utils import build_address, parse_address from ray._common.usage.usage_lib import TagKey, record_extra_usage_tag from ray._common.utils import get_or_create_event_loop +from ray._raylet import AuthenticationMode, get_authentication_mode from ray.dashboard import authentication_utils as auth_utils from ray.dashboard.dashboard_metrics import DashboardPrometheusMetrics from ray.dashboard.head import DashboardHeadModule @@ -160,6 +161,28 @@ async def get_timezone(self, req) -> aiohttp.web.Response: status=500, text="Internal Server Error:" + str(e) ) + @routes.get("/api/authentication_mode") + async def get_authentication_mode(self, req) -> aiohttp.web.Response: + """Get the current authentication mode. + + Returns: + JSON response with authentication_mode field. + Possible values: "disabled", "token" + """ + try: + mode = get_authentication_mode() + if mode == AuthenticationMode.TOKEN: + mode_str = "token" + else: + mode_str = "disabled" + + return aiohttp.web.json_response({"authentication_mode": mode_str}) + except Exception as e: + logger.error(f"Error getting authentication mode: {e}") + return aiohttp.web.Response( + status=500, text="Internal Server Error: " + str(e) + ) + def get_address(self): assert self.http_host and self.http_port return self.http_host, self.http_port @@ -172,6 +195,36 @@ async def auth_middleware(self, request, handler): if not auth_utils.is_token_auth_enabled(): return await handler(request) + # Public endpoints that don't require authentication + # These endpoints are needed to check authentication status or serve static content + public_endpoints = [ + "/api/authentication_mode", + ] + + # Public paths (using startswith for path prefixes) + public_path_prefixes = [ + "/static/", # Static assets (JS, CSS, images) + ] + + # Public exact paths + public_exact_paths = [ + "/", # Root index.html + "/favicon.ico", # Favicon + ] + + # Skip authentication for public endpoints + if request.path in public_endpoints: + return await handler(request) + + # Skip authentication for public path prefixes + for prefix in public_path_prefixes: + if request.path.startswith(prefix): + return await handler(request) + + # Skip authentication for public exact paths + if request.path in public_exact_paths: + return await handler(request) + # Extract and validate token auth_header = request.headers.get("Authorization", "") diff --git a/python/ray/dashboard/tests/test_dashboard_auth.py b/python/ray/dashboard/tests/test_dashboard_auth.py index 7407fc199a1d..5f4f9b8ffc11 100644 --- a/python/ray/dashboard/tests/test_dashboard_auth.py +++ b/python/ray/dashboard/tests/test_dashboard_auth.py @@ -63,6 +63,45 @@ def test_dashboard_auth_disabled(setup_cluster_without_token_auth): assert response.status_code == 200 +def test_authentication_mode_endpoint_with_token_auth(setup_cluster_with_token_auth): + """Test authentication_mode endpoint returns 'token' when auth is enabled.""" + + cluster_info = setup_cluster_with_token_auth + + # This endpoint should be accessible WITHOUT authentication + response = requests.get(f"{cluster_info['dashboard_url']}/api/authentication_mode") + + assert response.status_code == 200 + assert response.json() == {"authentication_mode": "token"} + + +def test_authentication_mode_endpoint_without_auth(setup_cluster_without_token_auth): + """Test authentication_mode endpoint returns 'disabled' when auth is off.""" + + cluster_info = setup_cluster_without_token_auth + + response = requests.get(f"{cluster_info['dashboard_url']}/api/authentication_mode") + + assert response.status_code == 200 + assert response.json() == {"authentication_mode": "disabled"} + + +def test_authentication_mode_endpoint_is_public(setup_cluster_with_token_auth): + """Test authentication_mode endpoint works without Authorization header.""" + + cluster_info = setup_cluster_with_token_auth + + # Call WITHOUT any authorization header - should still succeed + response = requests.get( + f"{cluster_info['dashboard_url']}/api/authentication_mode", + headers={}, # Explicitly no auth + ) + + # Should succeed even with token auth enabled + assert response.status_code == 200 + assert response.json() == {"authentication_mode": "token"} + + if __name__ == "__main__": sys.exit(pytest.main(["-vv", __file__])) From 0d3316c49d529b8bec8da7a9d0a5b9a901e2d5d1 Mon Sep 17 00:00:00 2001 From: sampan Date: Mon, 3 Nov 2025 08:08:07 +0000 Subject: [PATCH 02/26] fix lint issues Signed-off-by: sampan --- python/ray/dashboard/client/src/App.tsx | 2 +- .../src/authentication/TokenAuthenticationDialog.test.tsx | 2 +- .../client/src/authentication/TokenAuthenticationDialog.tsx | 4 ++-- python/ray/dashboard/client/src/service/event.ts | 2 +- python/ray/dashboard/http_server_head.py | 6 ------ 5 files changed, 5 insertions(+), 11 deletions(-) diff --git a/python/ray/dashboard/client/src/App.tsx b/python/ray/dashboard/client/src/App.tsx index 4df0f4ba40ce..db8b7c67c3d6 100644 --- a/python/ray/dashboard/client/src/App.tsx +++ b/python/ray/dashboard/client/src/App.tsx @@ -4,7 +4,6 @@ import dayjs from "dayjs"; import duration from "dayjs/plugin/duration"; import React, { Suspense, useEffect, useState } from "react"; import { HashRouter, Navigate, Route, Routes } from "react-router-dom"; -import TokenAuthenticationDialog from "./authentication/TokenAuthenticationDialog"; import { getAuthenticationMode, testTokenValidity, @@ -13,6 +12,7 @@ import { getAuthenticationToken, setAuthenticationToken, } from "./authentication/cookies"; +import TokenAuthenticationDialog from "./authentication/TokenAuthenticationDialog"; import ActorDetailPage, { ActorDetailLayout } from "./pages/actor/ActorDetail"; import { ActorLayout } from "./pages/actor/ActorLayout"; import Loading from "./pages/exception/Loading"; diff --git a/python/ray/dashboard/client/src/authentication/TokenAuthenticationDialog.test.tsx b/python/ray/dashboard/client/src/authentication/TokenAuthenticationDialog.test.tsx index 5eea6cdc1d25..ae67f36cbf7f 100644 --- a/python/ray/dashboard/client/src/authentication/TokenAuthenticationDialog.test.tsx +++ b/python/ray/dashboard/client/src/authentication/TokenAuthenticationDialog.test.tsx @@ -186,7 +186,7 @@ describe("TokenAuthenticationDialog", () => { }); it("does not render when open is false", () => { - const { container } = render( + render( { if (jobId) { diff --git a/python/ray/dashboard/http_server_head.py b/python/ray/dashboard/http_server_head.py index 502a5247828a..46d427905aa0 100644 --- a/python/ray/dashboard/http_server_head.py +++ b/python/ray/dashboard/http_server_head.py @@ -163,12 +163,6 @@ async def get_timezone(self, req) -> aiohttp.web.Response: @routes.get("/api/authentication_mode") async def get_authentication_mode(self, req) -> aiohttp.web.Response: - """Get the current authentication mode. - - Returns: - JSON response with authentication_mode field. - Possible values: "disabled", "token" - """ try: mode = get_authentication_mode() if mode == AuthenticationMode.TOKEN: From 7660ade4c42e09041c1510e066071486d4d0478c Mon Sep 17 00:00:00 2001 From: sampan Date: Mon, 3 Nov 2025 08:35:29 +0000 Subject: [PATCH 03/26] address comments Signed-off-by: sampan --- python/ray/dashboard/client/src/App.tsx | 10 +- .../TokenAuthenticationDialog.test.tsx | 8 +- .../TokenAuthenticationDialog.tsx | 190 +++++++++--------- .../src/authentication/authentication.ts | 10 +- .../client/src/authentication/constants.ts | 9 + .../client/src/authentication/cookies.ts | 6 +- .../client/src/service/requestHandlers.ts | 3 +- python/ray/dashboard/http_server_head.py | 37 +--- 8 files changed, 135 insertions(+), 138 deletions(-) create mode 100644 python/ray/dashboard/client/src/authentication/constants.ts diff --git a/python/ray/dashboard/client/src/App.tsx b/python/ray/dashboard/client/src/App.tsx index db8b7c67c3d6..ddb8164d3c9e 100644 --- a/python/ray/dashboard/client/src/App.tsx +++ b/python/ray/dashboard/client/src/App.tsx @@ -8,6 +8,7 @@ import { getAuthenticationMode, testTokenValidity, } from "./authentication/authentication"; +import { AUTHENTICATION_ERROR_EVENT } from "./authentication/constants"; import { getAuthenticationToken, setAuthenticationToken, @@ -162,9 +163,8 @@ const App = () => { useState(false); const [hasAttemptedAuthentication, setHasAttemptedAuthentication] = useState(false); - const [authenticationError, setAuthenticationError] = useState< - string | undefined - >(); + const [authenticationError, setAuthenticationError] = + useState(); useEffect(() => { getNodeList().then((res) => { if (res?.data?.data?.summary) { @@ -272,13 +272,13 @@ const App = () => { }; window.addEventListener( - "ray-authentication-error", + AUTHENTICATION_ERROR_EVENT, handleAuthenticationError, ); return () => { window.removeEventListener( - "ray-authentication-error", + AUTHENTICATION_ERROR_EVENT, handleAuthenticationError, ); }; diff --git a/python/ray/dashboard/client/src/authentication/TokenAuthenticationDialog.test.tsx b/python/ray/dashboard/client/src/authentication/TokenAuthenticationDialog.test.tsx index ae67f36cbf7f..bf7a0c0419b3 100644 --- a/python/ray/dashboard/client/src/authentication/TokenAuthenticationDialog.test.tsx +++ b/python/ray/dashboard/client/src/authentication/TokenAuthenticationDialog.test.tsx @@ -20,7 +20,9 @@ describe("TokenAuthenticationDialog", () => { />, ); - expect(screen.getByText("Token Authentication Required")).toBeInTheDocument(); + expect( + screen.getByText("Token Authentication Required"), + ).toBeInTheDocument(); expect( screen.getByText(/token authentication is enabled for this cluster/i), ).toBeInTheDocument(); @@ -35,7 +37,9 @@ describe("TokenAuthenticationDialog", () => { />, ); - expect(screen.getByText("Token Authentication Required")).toBeInTheDocument(); + expect( + screen.getByText("Token Authentication Required"), + ).toBeInTheDocument(); expect( screen.getByText(/authentication token is invalid or has expired/i), ).toBeInTheDocument(); diff --git a/python/ray/dashboard/client/src/authentication/TokenAuthenticationDialog.tsx b/python/ray/dashboard/client/src/authentication/TokenAuthenticationDialog.tsx index 940b1b8b7479..e260b1d49bf3 100644 --- a/python/ray/dashboard/client/src/authentication/TokenAuthenticationDialog.tsx +++ b/python/ray/dashboard/client/src/authentication/TokenAuthenticationDialog.tsx @@ -37,110 +37,106 @@ export type TokenAuthenticationDialogProps = { * (hasExistingToken=false) or if a previously stored token was rejected * (hasExistingToken=true). */ -export const TokenAuthenticationDialog: React.FC< - TokenAuthenticationDialogProps -> = ({ open, hasExistingToken, onSubmit, error }) => { - const [token, setToken] = useState(""); - const [showToken, setShowToken] = useState(false); - const [isSubmitting, setIsSubmitting] = useState(false); +export const TokenAuthenticationDialog: React.FC = + ({ open, hasExistingToken, onSubmit, error }) => { + const [token, setToken] = useState(""); + const [showToken, setShowToken] = useState(false); + const [isSubmitting, setIsSubmitting] = useState(false); - const handleSubmit = async () => { - if (!token.trim()) { - return; - } + const handleSubmit = async () => { + if (!token.trim()) { + return; + } - setIsSubmitting(true); - try { - await onSubmit(token.trim()); - // If successful, the parent component will close the dialog - // and likely reload the page - } catch (err) { - // Error is handled by parent component via the error prop - console.error("Failed to submit token:", err); - } finally { - setIsSubmitting(false); - } - }; + setIsSubmitting(true); + try { + await onSubmit(token.trim()); + // If successful, the parent component will close the dialog + // and likely reload the page + } finally { + setIsSubmitting(false); + } + }; - const handleKeyPress = (event: React.KeyboardEvent) => { - if (event.key === "Enter" && !isSubmitting) { - handleSubmit(); - } - }; + const handleKeyDown = (event: React.KeyboardEvent) => { + if (event.key === "Enter" && !isSubmitting) { + handleSubmit(); + } + }; - const toggleShowToken = () => { - setShowToken(!showToken); - }; + const toggleShowToken = () => { + setShowToken(!showToken); + }; - // Different messages based on whether this is initial auth or re-auth - const title = "Token Authentication Required"; - const message = hasExistingToken - ? "The authentication token is invalid or has expired. Please provide a valid authentication token." - : "Token authentication is enabled for this cluster. Please provide a valid authentication token."; + // Different messages based on whether this is initial auth or re-auth + const title = "Token Authentication Required"; + const message = hasExistingToken + ? "The authentication token is invalid or has expired. Please provide a valid authentication token." + : "Token authentication is enabled for this cluster. Please provide a valid authentication token."; - return ( - - {title} - - - {message} - + return ( + + {title} + + + {message} + - {error && ( - - {error} - - )} + {error && ( + + {error} + + )} - setToken(e.target.value)} - onKeyPress={handleKeyPress} - disabled={isSubmitting} - placeholder="Enter your authentication token" - InputProps={{ - endAdornment: ( - - - {showToken ? : } - - - ), - }} - /> - - - - - - ); -}; + setToken(e.target.value)} + onKeyDown={handleKeyDown} + disabled={isSubmitting} + placeholder="Enter your authentication token" + InputProps={{ + endAdornment: ( + + + {showToken ? : } + + + ), + }} + /> + + + + + + ); + }; export default TokenAuthenticationDialog; diff --git a/python/ray/dashboard/client/src/authentication/authentication.ts b/python/ray/dashboard/client/src/authentication/authentication.ts index fd47d4d28f56..c7579a1995fa 100644 --- a/python/ray/dashboard/client/src/authentication/authentication.ts +++ b/python/ray/dashboard/client/src/authentication/authentication.ts @@ -4,6 +4,7 @@ */ import axios from "axios"; +import { formatUrl, get } from "../service/requestHandlers"; /** * Response type for authentication mode endpoint. @@ -20,7 +21,7 @@ export type AuthenticationModeResponse = { */ export const getAuthenticationMode = async (): Promise => { - const response = await axios.get( + const response = await get( "/api/authentication_mode", ); return response.data; @@ -30,12 +31,17 @@ export const getAuthenticationMode = * Test if a token is valid by making a request to the /api/version endpoint * which is fast and reliable. * + * Note: This uses plain axios (not axiosInstance) to avoid the request interceptor + * that would add the token from cookies, since we want to test the specific token + * passed as a parameter. It also avoids the response interceptor that would dispatch + * global authentication error events, since we handle 401/403 errors locally. + * * @param token - The authentication token to test * @returns Promise resolving to true if token is valid, false otherwise */ export const testTokenValidity = async (token: string): Promise => { try { - await axios.get("/api/version", { + await axios.get(formatUrl("/api/version"), { headers: { Authorization: `Bearer ${token}` }, }); return true; diff --git a/python/ray/dashboard/client/src/authentication/constants.ts b/python/ray/dashboard/client/src/authentication/constants.ts new file mode 100644 index 000000000000..fce013e5f30d --- /dev/null +++ b/python/ray/dashboard/client/src/authentication/constants.ts @@ -0,0 +1,9 @@ +/** + * Authentication-related constants for the Ray dashboard. + */ + +/** + * Event name dispatched when an authentication error occurs (401 or 403). + * Listened to by App.tsx to show the authentication dialog. + */ +export const AUTHENTICATION_ERROR_EVENT = "ray-authentication-error"; diff --git a/python/ray/dashboard/client/src/authentication/cookies.ts b/python/ray/dashboard/client/src/authentication/cookies.ts index 929a30030738..12180de6b973 100644 --- a/python/ray/dashboard/client/src/authentication/cookies.ts +++ b/python/ray/dashboard/client/src/authentication/cookies.ts @@ -33,11 +33,7 @@ export const getCookie = (name: string): string | null => { * @param value - The value to store in the cookie * @param days - Number of days until the cookie expires (default: 30) */ -export const setCookie = ( - name: string, - value: string, - days = 30, -): void => { +export const setCookie = (name: string, value: string, days = 30): void => { let expires = ""; if (days) { const date = new Date(); diff --git a/python/ray/dashboard/client/src/service/requestHandlers.ts b/python/ray/dashboard/client/src/service/requestHandlers.ts index f601a3f522e4..5addbaf518a8 100644 --- a/python/ray/dashboard/client/src/service/requestHandlers.ts +++ b/python/ray/dashboard/client/src/service/requestHandlers.ts @@ -9,6 +9,7 @@ */ import axios, { AxiosRequestConfig, AxiosResponse } from "axios"; +import { AUTHENTICATION_ERROR_EVENT } from "../authentication/constants"; import { getAuthenticationToken } from "../authentication/cookies"; /** @@ -61,7 +62,7 @@ axiosInstance.interceptors.response.use( // Dispatch custom event for authentication error window.dispatchEvent( - new CustomEvent("ray-authentication-error", { + new CustomEvent(AUTHENTICATION_ERROR_EVENT, { detail: { hadToken }, }), ); diff --git a/python/ray/dashboard/http_server_head.py b/python/ray/dashboard/http_server_head.py index 46d427905aa0..d1a8e9b9e5b3 100644 --- a/python/ray/dashboard/http_server_head.py +++ b/python/ray/dashboard/http_server_head.py @@ -189,34 +189,19 @@ async def auth_middleware(self, request, handler): if not auth_utils.is_token_auth_enabled(): return await handler(request) - # Public endpoints that don't require authentication - # These endpoints are needed to check authentication status or serve static content - public_endpoints = [ + # Public endpoints that don't require authentication. + # These are needed for the dashboard to load and request an auth token. + public_exact_paths = { + "/", # Root index.html + "/favicon.ico", "/api/authentication_mode", - ] - - # Public paths (using startswith for path prefixes) - public_path_prefixes = [ + } + public_path_prefixes = ( "/static/", # Static assets (JS, CSS, images) - ] - - # Public exact paths - public_exact_paths = [ - "/", # Root index.html - "/favicon.ico", # Favicon - ] - - # Skip authentication for public endpoints - if request.path in public_endpoints: - return await handler(request) - - # Skip authentication for public path prefixes - for prefix in public_path_prefixes: - if request.path.startswith(prefix): - return await handler(request) - - # Skip authentication for public exact paths - if request.path in public_exact_paths: + ) + if request.path in public_exact_paths or request.path.startswith( + public_path_prefixes + ): return await handler(request) # Extract and validate token From 5234bfe99f4f387dfc6d70183b6145a84522f208 Mon Sep 17 00:00:00 2001 From: sampan Date: Mon, 3 Nov 2025 08:36:10 +0000 Subject: [PATCH 04/26] fix lint Signed-off-by: sampan --- python/ray/dashboard/http_server_head.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/ray/dashboard/http_server_head.py b/python/ray/dashboard/http_server_head.py index d1a8e9b9e5b3..f3c04de54d8e 100644 --- a/python/ray/dashboard/http_server_head.py +++ b/python/ray/dashboard/http_server_head.py @@ -196,9 +196,7 @@ async def auth_middleware(self, request, handler): "/favicon.ico", "/api/authentication_mode", } - public_path_prefixes = ( - "/static/", # Static assets (JS, CSS, images) - ) + public_path_prefixes = ("/static/",) # Static assets (JS, CSS, images) if request.path in public_exact_paths or request.path.startswith( public_path_prefixes ): From f61c27f53a1d7d16c164dd58347faf400d34e49f Mon Sep 17 00:00:00 2001 From: sampan Date: Tue, 4 Nov 2025 04:35:58 +0000 Subject: [PATCH 05/26] fix typo Signed-off-by: sampan --- python/ray/_private/authentication/http_token_authentication.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/ray/_private/authentication/http_token_authentication.py b/python/ray/_private/authentication/http_token_authentication.py index 277e0bf0ba41..d6feef61bc33 100644 --- a/python/ray/_private/authentication/http_token_authentication.py +++ b/python/ray/_private/authentication/http_token_authentication.py @@ -21,6 +21,7 @@ def get_token_auth_middleware( whitelisted_path_prefixes: List of path prefixes that don't require authentication Returns: An aiohttp middleware function + """ @aiohttp_module.web.middleware async def token_auth_middleware(request, handler): From 2da0b9b10b211183862ac9e5689b817d7cce9470 Mon Sep 17 00:00:00 2001 From: sampan Date: Tue, 4 Nov 2025 06:31:22 +0000 Subject: [PATCH 06/26] [core] Configure an interceptor to pass auth token in python direct grpc calls Signed-off-by: sampan --- .../authentication}/authentication_utils.py | 0 .../grpc_authentication_client_interceptor.py | 170 ++++++++++++++++++ .../http_token_authentication.py | 2 +- python/ray/_private/utils.py | 33 +++- .../ray/tests/test_token_auth_integration.py | 54 ++++++ 5 files changed, 256 insertions(+), 3 deletions(-) rename python/ray/{dashboard => _private/authentication}/authentication_utils.py (100%) create mode 100644 python/ray/_private/authentication/grpc_authentication_client_interceptor.py diff --git a/python/ray/dashboard/authentication_utils.py b/python/ray/_private/authentication/authentication_utils.py similarity index 100% rename from python/ray/dashboard/authentication_utils.py rename to python/ray/_private/authentication/authentication_utils.py diff --git a/python/ray/_private/authentication/grpc_authentication_client_interceptor.py b/python/ray/_private/authentication/grpc_authentication_client_interceptor.py new file mode 100644 index 000000000000..188a125af469 --- /dev/null +++ b/python/ray/_private/authentication/grpc_authentication_client_interceptor.py @@ -0,0 +1,170 @@ +"""gRPC client interceptor for token-based authentication.""" + +import logging +from collections import namedtuple +from typing import Tuple + +import grpc +from grpc import aio as aiogrpc + +from ray._private.authentication import authentication_utils +from ray._private.authentication.authentication_constants import ( + AUTHORIZATION_HEADER_NAME, +) +from ray._raylet import AuthenticationTokenLoader + +logger = logging.getLogger(__name__) + + +# Named tuple to hold client call details +_ClientCallDetails = namedtuple( + "_ClientCallDetails", + ("method", "timeout", "metadata", "credentials", "wait_for_ready", "compression"), +) + + +def _get_authentication_metadata_tuple() -> Tuple[Tuple[str, str], ...]: + """Get gRPC metadata tuple for authentication. Currently only supported for token authentication. + + Returns: + tuple: Empty tuple or ((AUTHORIZATION_HEADER_NAME, "Bearer "),) + """ + token_loader = AuthenticationTokenLoader.instance() + if not token_loader.has_token(): + return () + + headers = token_loader.get_token_for_http_header() + + # Convert HTTP header dict to gRPC metadata tuple + # gRPC expects: (("key", "value"), ...) + return tuple((k, v) for k, v in headers.items()) + + +class AuthenticationMetadataClientInterceptor( + grpc.UnaryUnaryClientInterceptor, + grpc.UnaryStreamClientInterceptor, + grpc.StreamUnaryClientInterceptor, + grpc.StreamStreamClientInterceptor, +): + """Synchronous gRPC client interceptor that adds authentication metadata.""" + + def intercept_unary_unary(self, continuation, client_call_details, request): + metadata = list(client_call_details.metadata or []) + metadata.extend(_get_authentication_metadata_tuple()) + + new_details = _ClientCallDetails( + method=client_call_details.method, + timeout=client_call_details.timeout, + metadata=metadata, + credentials=client_call_details.credentials, + wait_for_ready=getattr(client_call_details, "wait_for_ready", None), + compression=getattr(client_call_details, "compression", None), + ) + return continuation(new_details, request) + + def intercept_unary_stream(self, continuation, client_call_details, request): + metadata = list(client_call_details.metadata or []) + metadata.extend(_get_authentication_metadata_tuple()) + + new_details = _ClientCallDetails( + method=client_call_details.method, + timeout=client_call_details.timeout, + metadata=metadata, + credentials=client_call_details.credentials, + wait_for_ready=getattr(client_call_details, "wait_for_ready", None), + compression=getattr(client_call_details, "compression", None), + ) + return continuation(new_details, request) + + def intercept_stream_unary(self, continuation, client_call_details, request_iterator): + metadata = list(client_call_details.metadata or []) + metadata.extend(_get_authentication_metadata_tuple()) + + new_details = _ClientCallDetails( + method=client_call_details.method, + timeout=client_call_details.timeout, + metadata=metadata, + credentials=client_call_details.credentials, + wait_for_ready=getattr(client_call_details, "wait_for_ready", None), + compression=getattr(client_call_details, "compression", None), + ) + return continuation(new_details, request_iterator) + + def intercept_stream_stream(self, continuation, client_call_details, request_iterator): + metadata = list(client_call_details.metadata or []) + metadata.extend(_get_authentication_metadata_tuple()) + + new_details = _ClientCallDetails( + method=client_call_details.method, + timeout=client_call_details.timeout, + metadata=metadata, + credentials=client_call_details.credentials, + wait_for_ready=getattr(client_call_details, "wait_for_ready", None), + compression=getattr(client_call_details, "compression", None), + ) + return continuation(new_details, request_iterator) + + +class AsyncAuthenticationMetadataClientInterceptor( + aiogrpc.UnaryUnaryClientInterceptor, + aiogrpc.UnaryStreamClientInterceptor, + aiogrpc.StreamUnaryClientInterceptor, + aiogrpc.StreamStreamClientInterceptor, +): + """Async gRPC client interceptor that adds authentication metadata.""" + + async def intercept_unary_unary(self, continuation, client_call_details, request): + metadata = list(client_call_details.metadata or []) + metadata.extend(_get_authentication_metadata_tuple()) + + new_details = _ClientCallDetails( + method=client_call_details.method, + timeout=client_call_details.timeout, + metadata=metadata, + credentials=client_call_details.credentials, + wait_for_ready=getattr(client_call_details, "wait_for_ready", None), + compression=getattr(client_call_details, "compression", None), + ) + return await continuation(new_details, request) + + async def intercept_unary_stream(self, continuation, client_call_details, request): + metadata = list(client_call_details.metadata or []) + metadata.extend(_get_authentication_metadata_tuple()) + + new_details = _ClientCallDetails( + method=client_call_details.method, + timeout=client_call_details.timeout, + metadata=metadata, + credentials=client_call_details.credentials, + wait_for_ready=getattr(client_call_details, "wait_for_ready", None), + compression=getattr(client_call_details, "compression", None), + ) + return await continuation(new_details, request) + + async def intercept_stream_unary(self, continuation, client_call_details, request_iterator): + metadata = list(client_call_details.metadata or []) + metadata.extend(_get_authentication_metadata_tuple()) + + new_details = _ClientCallDetails( + method=client_call_details.method, + timeout=client_call_details.timeout, + metadata=metadata, + credentials=client_call_details.credentials, + wait_for_ready=getattr(client_call_details, "wait_for_ready", None), + compression=getattr(client_call_details, "compression", None), + ) + return await continuation(new_details, request_iterator) + + async def intercept_stream_stream(self, continuation, client_call_details, request_iterator): + metadata = list(client_call_details.metadata or []) + metadata.extend(_get_authentication_metadata_tuple()) + + new_details = _ClientCallDetails( + method=client_call_details.method, + timeout=client_call_details.timeout, + metadata=metadata, + credentials=client_call_details.credentials, + wait_for_ready=getattr(client_call_details, "wait_for_ready", None), + compression=getattr(client_call_details, "compression", None), + ) + return await continuation(new_details, request_iterator) diff --git a/python/ray/_private/authentication/http_token_authentication.py b/python/ray/_private/authentication/http_token_authentication.py index d6feef61bc33..1be2125e0d26 100644 --- a/python/ray/_private/authentication/http_token_authentication.py +++ b/python/ray/_private/authentication/http_token_authentication.py @@ -3,7 +3,7 @@ from typing import Dict, List, Optional from ray._private.authentication import authentication_constants -from ray.dashboard import authentication_utils as auth_utils +from ray._private.authentication import authentication_utils as auth_utils logger = logging.getLogger(__name__) diff --git a/python/ray/_private/utils.py b/python/ray/_private/utils.py index 6e5c6dfb0544..ce80946b7796 100644 --- a/python/ray/_private/utils.py +++ b/python/ray/_private/utils.py @@ -1026,6 +1026,7 @@ def init_grpc_channel( import grpc from grpc import aio as aiogrpc + from ray._private.authentication import authentication_utils from ray._private.tls_utils import load_certs_from_env grpc_module = aiogrpc if asynchronous else grpc @@ -1040,6 +1041,20 @@ def init_grpc_channel( ) options = options_dict.items() + # Build interceptors list + interceptors = [] + if authentication_utils.is_token_auth_enabled(): + from ray._private.authentication.grpc_authentication_client_interceptor import ( + AsyncAuthenticationMetadataClientInterceptor, + AuthenticationMetadataClientInterceptor, + ) + + if asynchronous: + interceptors.append(AsyncAuthenticationMetadataClientInterceptor()) + else: + interceptors.append(AuthenticationMetadataClientInterceptor()) + + # Create channel with TLS if enabled if os.environ.get("RAY_USE_TLS", "0").lower() in ("1", "true"): server_cert_chain, private_key, ca_cert = load_certs_from_env() credentials = grpc.ssl_channel_credentials( @@ -1047,9 +1062,23 @@ def init_grpc_channel( private_key=private_key, root_certificates=ca_cert, ) - channel = grpc_module.secure_channel(address, credentials, options=options) + if asynchronous: + channel = grpc_module.secure_channel( + address, credentials, options=options, interceptors=interceptors + ) + else: + channel = grpc_module.secure_channel(address, credentials, options=options) else: - channel = grpc_module.insecure_channel(address, options=options) + if asynchronous: + channel = grpc_module.insecure_channel( + address, options=options, interceptors=interceptors + ) + else: + channel = grpc_module.insecure_channel(address, options=options) + + # Apply interceptors for sync channels (async channels get them in constructor) + if not asynchronous and interceptors: + channel = grpc.intercept_channel(channel, *interceptors) return channel diff --git a/python/ray/tests/test_token_auth_integration.py b/python/ray/tests/test_token_auth_integration.py index 9765d7fd1705..076d96592db6 100644 --- a/python/ray/tests/test_token_auth_integration.py +++ b/python/ray/tests/test_token_auth_integration.py @@ -342,5 +342,59 @@ def worker_joined(): _cleanup_ray_start(env) +def test_dashboard_with_token_auth_integration(setup_cluster_with_token_auth): + """Test that dashboard components work with token authentication. + + This verifies that with token auth enabled: + 1. Job submission works + 2. Tasks execute successfully + 3. Actors can be created and called + """ + cluster_info = setup_cluster_with_token_auth + + # Test 1: Submit a simple task + @ray.remote + def simple_task(x): + return x + 1 + + result = ray.get(simple_task.remote(41)) + assert result == 42, f"Task should return 42, got {result}" + + # Test 2: Create and use an actor + @ray.remote + class SimpleActor: + def __init__(self): + self.value = 0 + + def increment(self): + self.value += 1 + return self.value + + actor = SimpleActor.remote() + result = ray.get(actor.increment.remote()) + assert result == 1, f"Actor method should return 1, got {result}" + + # Test 3: Submit a job and wait for completion + from ray.job_submission import JobSubmissionClient + + # Create job submission client (uses HTTP with auth headers) + client = JobSubmissionClient(address=cluster_info["dashboard_url"]) + + # Submit a simple job + job_id = client.submit_job( + entrypoint="echo 'Hello from job'", + ) + + # Wait for job to complete + def job_finished(): + status = client.get_job_status(job_id) + return status in ["SUCCEEDED", "FAILED", "STOPPED"] + + wait_for_condition(job_finished, timeout=30) + + final_status = client.get_job_status(job_id) + assert final_status == "SUCCEEDED", f"Job should succeed, got status: {final_status}" + + if __name__ == "__main__": sys.exit(pytest.main(["-vv", __file__])) From 1f1897d9c03c2e4a01a8456d9e70dc664e1cd507 Mon Sep 17 00:00:00 2001 From: sampan Date: Tue, 4 Nov 2025 06:32:34 +0000 Subject: [PATCH 07/26] fix lint issues Signed-off-by: sampan --- .../grpc_authentication_client_interceptor.py | 20 +++++++++++-------- .../http_token_authentication.py | 6 ++++-- .../ray/tests/test_token_auth_integration.py | 4 +++- 3 files changed, 19 insertions(+), 11 deletions(-) diff --git a/python/ray/_private/authentication/grpc_authentication_client_interceptor.py b/python/ray/_private/authentication/grpc_authentication_client_interceptor.py index 188a125af469..987f51a51ea6 100644 --- a/python/ray/_private/authentication/grpc_authentication_client_interceptor.py +++ b/python/ray/_private/authentication/grpc_authentication_client_interceptor.py @@ -7,10 +7,6 @@ import grpc from grpc import aio as aiogrpc -from ray._private.authentication import authentication_utils -from ray._private.authentication.authentication_constants import ( - AUTHORIZATION_HEADER_NAME, -) from ray._raylet import AuthenticationTokenLoader logger = logging.getLogger(__name__) @@ -76,7 +72,9 @@ def intercept_unary_stream(self, continuation, client_call_details, request): ) return continuation(new_details, request) - def intercept_stream_unary(self, continuation, client_call_details, request_iterator): + def intercept_stream_unary( + self, continuation, client_call_details, request_iterator + ): metadata = list(client_call_details.metadata or []) metadata.extend(_get_authentication_metadata_tuple()) @@ -90,7 +88,9 @@ def intercept_stream_unary(self, continuation, client_call_details, request_iter ) return continuation(new_details, request_iterator) - def intercept_stream_stream(self, continuation, client_call_details, request_iterator): + def intercept_stream_stream( + self, continuation, client_call_details, request_iterator + ): metadata = list(client_call_details.metadata or []) metadata.extend(_get_authentication_metadata_tuple()) @@ -141,7 +141,9 @@ async def intercept_unary_stream(self, continuation, client_call_details, reques ) return await continuation(new_details, request) - async def intercept_stream_unary(self, continuation, client_call_details, request_iterator): + async def intercept_stream_unary( + self, continuation, client_call_details, request_iterator + ): metadata = list(client_call_details.metadata or []) metadata.extend(_get_authentication_metadata_tuple()) @@ -155,7 +157,9 @@ async def intercept_stream_unary(self, continuation, client_call_details, reques ) return await continuation(new_details, request_iterator) - async def intercept_stream_stream(self, continuation, client_call_details, request_iterator): + async def intercept_stream_stream( + self, continuation, client_call_details, request_iterator + ): metadata = list(client_call_details.metadata or []) metadata.extend(_get_authentication_metadata_tuple()) diff --git a/python/ray/_private/authentication/http_token_authentication.py b/python/ray/_private/authentication/http_token_authentication.py index 1be2125e0d26..70e234a6d2cc 100644 --- a/python/ray/_private/authentication/http_token_authentication.py +++ b/python/ray/_private/authentication/http_token_authentication.py @@ -2,8 +2,10 @@ from types import ModuleType from typing import Dict, List, Optional -from ray._private.authentication import authentication_constants -from ray._private.authentication import authentication_utils as auth_utils +from ray._private.authentication import ( + authentication_constants, + authentication_utils as auth_utils, +) logger = logging.getLogger(__name__) diff --git a/python/ray/tests/test_token_auth_integration.py b/python/ray/tests/test_token_auth_integration.py index 076d96592db6..9aa82438098e 100644 --- a/python/ray/tests/test_token_auth_integration.py +++ b/python/ray/tests/test_token_auth_integration.py @@ -393,7 +393,9 @@ def job_finished(): wait_for_condition(job_finished, timeout=30) final_status = client.get_job_status(job_id) - assert final_status == "SUCCEEDED", f"Job should succeed, got status: {final_status}" + assert ( + final_status == "SUCCEEDED" + ), f"Job should succeed, got status: {final_status}" if __name__ == "__main__": From 6aa4ce044b0f0e2bd2cf295c245ed751a0daeb52 Mon Sep 17 00:00:00 2001 From: sampan Date: Tue, 4 Nov 2025 07:00:48 +0000 Subject: [PATCH 08/26] reduce code duplication Signed-off-by: sampan --- .../grpc_authentication_client_interceptor.py | 88 ++++--------------- python/ray/_private/utils.py | 31 ++++--- 2 files changed, 33 insertions(+), 86 deletions(-) diff --git a/python/ray/_private/authentication/grpc_authentication_client_interceptor.py b/python/ray/_private/authentication/grpc_authentication_client_interceptor.py index 987f51a51ea6..9a0800a90a56 100644 --- a/python/ray/_private/authentication/grpc_authentication_client_interceptor.py +++ b/python/ray/_private/authentication/grpc_authentication_client_interceptor.py @@ -44,11 +44,12 @@ class AuthenticationMetadataClientInterceptor( ): """Synchronous gRPC client interceptor that adds authentication metadata.""" - def intercept_unary_unary(self, continuation, client_call_details, request): + def _intercept_call_details(self, client_call_details): + """Helper method to add authentication metadata to client call details.""" metadata = list(client_call_details.metadata or []) metadata.extend(_get_authentication_metadata_tuple()) - new_details = _ClientCallDetails( + return _ClientCallDetails( method=client_call_details.method, timeout=client_call_details.timeout, metadata=metadata, @@ -56,52 +57,25 @@ def intercept_unary_unary(self, continuation, client_call_details, request): wait_for_ready=getattr(client_call_details, "wait_for_ready", None), compression=getattr(client_call_details, "compression", None), ) + + def intercept_unary_unary(self, continuation, client_call_details, request): + new_details = self._intercept_call_details(client_call_details) return continuation(new_details, request) def intercept_unary_stream(self, continuation, client_call_details, request): - metadata = list(client_call_details.metadata or []) - metadata.extend(_get_authentication_metadata_tuple()) - - new_details = _ClientCallDetails( - method=client_call_details.method, - timeout=client_call_details.timeout, - metadata=metadata, - credentials=client_call_details.credentials, - wait_for_ready=getattr(client_call_details, "wait_for_ready", None), - compression=getattr(client_call_details, "compression", None), - ) + new_details = self._intercept_call_details(client_call_details) return continuation(new_details, request) def intercept_stream_unary( self, continuation, client_call_details, request_iterator ): - metadata = list(client_call_details.metadata or []) - metadata.extend(_get_authentication_metadata_tuple()) - - new_details = _ClientCallDetails( - method=client_call_details.method, - timeout=client_call_details.timeout, - metadata=metadata, - credentials=client_call_details.credentials, - wait_for_ready=getattr(client_call_details, "wait_for_ready", None), - compression=getattr(client_call_details, "compression", None), - ) + new_details = self._intercept_call_details(client_call_details) return continuation(new_details, request_iterator) def intercept_stream_stream( self, continuation, client_call_details, request_iterator ): - metadata = list(client_call_details.metadata or []) - metadata.extend(_get_authentication_metadata_tuple()) - - new_details = _ClientCallDetails( - method=client_call_details.method, - timeout=client_call_details.timeout, - metadata=metadata, - credentials=client_call_details.credentials, - wait_for_ready=getattr(client_call_details, "wait_for_ready", None), - compression=getattr(client_call_details, "compression", None), - ) + new_details = self._intercept_call_details(client_call_details) return continuation(new_details, request_iterator) @@ -113,11 +87,12 @@ class AsyncAuthenticationMetadataClientInterceptor( ): """Async gRPC client interceptor that adds authentication metadata.""" - async def intercept_unary_unary(self, continuation, client_call_details, request): + def _intercept_call_details(self, client_call_details): + """Helper method to add authentication metadata to client call details.""" metadata = list(client_call_details.metadata or []) metadata.extend(_get_authentication_metadata_tuple()) - new_details = _ClientCallDetails( + return _ClientCallDetails( method=client_call_details.method, timeout=client_call_details.timeout, metadata=metadata, @@ -125,50 +100,23 @@ async def intercept_unary_unary(self, continuation, client_call_details, request wait_for_ready=getattr(client_call_details, "wait_for_ready", None), compression=getattr(client_call_details, "compression", None), ) + + async def intercept_unary_unary(self, continuation, client_call_details, request): + new_details = self._intercept_call_details(client_call_details) return await continuation(new_details, request) async def intercept_unary_stream(self, continuation, client_call_details, request): - metadata = list(client_call_details.metadata or []) - metadata.extend(_get_authentication_metadata_tuple()) - - new_details = _ClientCallDetails( - method=client_call_details.method, - timeout=client_call_details.timeout, - metadata=metadata, - credentials=client_call_details.credentials, - wait_for_ready=getattr(client_call_details, "wait_for_ready", None), - compression=getattr(client_call_details, "compression", None), - ) + new_details = self._intercept_call_details(client_call_details) return await continuation(new_details, request) async def intercept_stream_unary( self, continuation, client_call_details, request_iterator ): - metadata = list(client_call_details.metadata or []) - metadata.extend(_get_authentication_metadata_tuple()) - - new_details = _ClientCallDetails( - method=client_call_details.method, - timeout=client_call_details.timeout, - metadata=metadata, - credentials=client_call_details.credentials, - wait_for_ready=getattr(client_call_details, "wait_for_ready", None), - compression=getattr(client_call_details, "compression", None), - ) + new_details = self._intercept_call_details(client_call_details) return await continuation(new_details, request_iterator) async def intercept_stream_stream( self, continuation, client_call_details, request_iterator ): - metadata = list(client_call_details.metadata or []) - metadata.extend(_get_authentication_metadata_tuple()) - - new_details = _ClientCallDetails( - method=client_call_details.method, - timeout=client_call_details.timeout, - metadata=metadata, - credentials=client_call_details.credentials, - wait_for_ready=getattr(client_call_details, "wait_for_ready", None), - compression=getattr(client_call_details, "compression", None), - ) + new_details = self._intercept_call_details(client_call_details) return await continuation(new_details, request_iterator) diff --git a/python/ray/_private/utils.py b/python/ray/_private/utils.py index ce80946b7796..5bda84aad204 100644 --- a/python/ray/_private/utils.py +++ b/python/ray/_private/utils.py @@ -1055,30 +1055,29 @@ def init_grpc_channel( interceptors.append(AuthenticationMetadataClientInterceptor()) # Create channel with TLS if enabled - if os.environ.get("RAY_USE_TLS", "0").lower() in ("1", "true"): + use_tls = os.environ.get("RAY_USE_TLS", "0").lower() in ("1", "true") + if use_tls: server_cert_chain, private_key, ca_cert = load_certs_from_env() credentials = grpc.ssl_channel_credentials( certificate_chain=server_cert_chain, private_key=private_key, root_certificates=ca_cert, ) - if asynchronous: - channel = grpc_module.secure_channel( - address, credentials, options=options, interceptors=interceptors - ) - else: - channel = grpc_module.secure_channel(address, credentials, options=options) + channel_creator = grpc_module.secure_channel + base_args = (address, credentials) else: - if asynchronous: - channel = grpc_module.insecure_channel( - address, options=options, interceptors=interceptors - ) - else: - channel = grpc_module.insecure_channel(address, options=options) + channel_creator = grpc_module.insecure_channel + base_args = (address,) - # Apply interceptors for sync channels (async channels get them in constructor) - if not asynchronous and interceptors: - channel = grpc.intercept_channel(channel, *interceptors) + # Create channel (async channels get interceptors in constructor, sync via intercept_channel) + if asynchronous: + channel = channel_creator( + *base_args, options=options, interceptors=interceptors + ) + else: + channel = channel_creator(*base_args, options=options) + if interceptors: + channel = grpc.intercept_channel(channel, *interceptors) return channel From 8186c0902bafbeb165268ae81822b7b2906926fb Mon Sep 17 00:00:00 2001 From: sampan Date: Tue, 4 Nov 2025 08:56:54 +0000 Subject: [PATCH 09/26] empty commit Signed-off-by: sampan From d90ffd3c2f498417238691c068bc5b1126d92bcd Mon Sep 17 00:00:00 2001 From: sampan Date: Wed, 5 Nov 2025 04:14:41 +0000 Subject: [PATCH 10/26] [Core] Add Service Interceptor to support token authentication in dashboard agent Signed-off-by: sampan --- .../grpc_authentication_server_interceptor.py | 132 ++++++++++++ python/ray/dashboard/agent.py | 14 +- ..._grpc_authentication_server_interceptor.py | 195 ++++++++++++++++++ 3 files changed, 340 insertions(+), 1 deletion(-) create mode 100644 python/ray/_private/authentication/grpc_authentication_server_interceptor.py create mode 100644 python/ray/tests/test_grpc_authentication_server_interceptor.py diff --git a/python/ray/_private/authentication/grpc_authentication_server_interceptor.py b/python/ray/_private/authentication/grpc_authentication_server_interceptor.py new file mode 100644 index 000000000000..afed3f660854 --- /dev/null +++ b/python/ray/_private/authentication/grpc_authentication_server_interceptor.py @@ -0,0 +1,132 @@ +"""gRPC server interceptor for token-based authentication.""" + +import logging +from typing import Awaitable, Callable + +import grpc +from grpc import aio as aiogrpc + +from ray._private.authentication.authentication_constants import ( + AUTHORIZATION_HEADER_NAME, +) +from ray._private.authentication.authentication_utils import ( + is_token_auth_enabled, + validate_request_token, +) + +logger = logging.getLogger(__name__) + + +class AsyncAuthenticationServerInterceptor(aiogrpc.ServerInterceptor): + """Async gRPC server interceptor that validates authentication tokens. + + This interceptor checks the "authorization" metadata header for a valid + Bearer token when token authentication is enabled via RAY_auth_mode=token. + If the token is missing or invalid, the request is rejected with UNAUTHENTICATED status. + """ + + def _validate_authentication(self, metadata: tuple) -> bool: + """Validate authentication token from gRPC metadata. + + Args: + metadata: gRPC metadata tuple of (key, value) pairs + + Returns: + True if authentication succeeds or is not required, False otherwise + """ + # If token auth is not enabled, allow all requests + if not is_token_auth_enabled(): + return True + + # Extract authorization header from metadata + auth_header = None + for key, value in metadata: + if key.lower() == AUTHORIZATION_HEADER_NAME: + auth_header = value + break + + if not auth_header: + logger.warning( + "Authentication required but no authorization header provided" + ) + return False + + # Validate the token format and value + # validate_request_token returns bool (True if valid, False otherwise) + return validate_request_token(auth_header) + + async def intercept_service( + self, + continuation: Callable[ + [grpc.HandlerCallDetails], Awaitable[grpc.RpcMethodHandler] + ], + handler_call_details: grpc.HandlerCallDetails, + ) -> grpc.RpcMethodHandler: + """Intercept service calls to validate authentication. + + This method is called once per RPC to get the handler. We wrap the handler + to validate authentication before executing the actual RPC method. + """ + # Get the actual handler + handler = await continuation(handler_call_details) + + if handler is None: + return None + + # Wrap the RPC behavior with authentication check + def wrap_rpc_behavior(behavior): + """Wrap an RPC method to validate authentication first.""" + if behavior is None: + return None + + async def wrapped(request_or_iterator, context): + if not self._validate_authentication(context.invocation_metadata()): + await context.abort( + grpc.StatusCode.UNAUTHENTICATED, + "Invalid or missing authentication token", + ) + return await behavior(request_or_iterator, context) + + return wrapped + + # Create a wrapper class that implements RpcMethodHandler interface + class AuthenticatedHandler: + """Wrapper handler that validates authentication.""" + + def __init__(self, original_handler, wrapper_func): + self._original = original_handler + self._wrap = wrapper_func + + @property + def request_streaming(self): + return self._original.request_streaming + + @property + def response_streaming(self): + return self._original.response_streaming + + @property + def request_deserializer(self): + return self._original.request_deserializer + + @property + def response_serializer(self): + return self._original.response_serializer + + @property + def unary_unary(self): + return self._wrap(self._original.unary_unary) + + @property + def unary_stream(self): + return self._wrap(self._original.unary_stream) + + @property + def stream_unary(self): + return self._wrap(self._original.stream_unary) + + @property + def stream_stream(self): + return self._wrap(self._original.stream_stream) + + return AuthenticatedHandler(handler, wrap_rpc_behavior) diff --git a/python/ray/dashboard/agent.py b/python/ray/dashboard/agent.py index 6b95ad4d1444..355ac03d9aa5 100644 --- a/python/ray/dashboard/agent.py +++ b/python/ray/dashboard/agent.py @@ -82,6 +82,12 @@ def __init__( def _init_non_minimal(self): from grpc import aio as aiogrpc + from ray._private.authentication.authentication_utils import ( + is_token_auth_enabled, + ) + from ray._private.authentication.grpc_authentication_server_interceptor import ( + AsyncAuthenticationServerInterceptor, + ) from ray._private.tls_utils import add_port_to_grpc_server from ray.dashboard.http_server_agent import HttpServerAgent @@ -98,7 +104,13 @@ def _init_non_minimal(self): else: aiogrpc.init_grpc_aio() + # Add authentication interceptor if token auth is enabled + interceptors = [] + if is_token_auth_enabled(): + interceptors.append(AsyncAuthenticationServerInterceptor()) + self.server = aiogrpc.server( + interceptors=interceptors, options=( ("grpc.so_reuseport", 0), ( @@ -109,7 +121,7 @@ def _init_non_minimal(self): "grpc.max_receive_message_length", AGENT_GRPC_MAX_MESSAGE_LENGTH, ), - ) # noqa + ), # noqa ) try: add_port_to_grpc_server(self.server, build_address(self.ip, self.grpc_port)) diff --git a/python/ray/tests/test_grpc_authentication_server_interceptor.py b/python/ray/tests/test_grpc_authentication_server_interceptor.py new file mode 100644 index 000000000000..cfea7cfd48a2 --- /dev/null +++ b/python/ray/tests/test_grpc_authentication_server_interceptor.py @@ -0,0 +1,195 @@ +"""Unit tests for gRPC server authentication interceptor.""" + +import uuid + +import grpc +import pytest +from grpc import aio as aiogrpc + +from ray._private.authentication.grpc_authentication_server_interceptor import ( + AsyncAuthenticationServerInterceptor, +) + +# Create a simple test service for testing +from ray.core.generated import reporter_pb2, reporter_pb2_grpc +from ray.tests.authentication_test_utils import ( + authentication_env_guard, + reset_auth_token_state, + set_auth_mode, + set_env_auth_token, +) + + +class TestReporterServicer(reporter_pb2_grpc.ReporterServiceServicer): + """Simple test servicer for testing authentication.""" + + async def HealthCheck(self, request, context): + """Return a health check response.""" + return reporter_pb2.HealthCheckReply() + + +@pytest.fixture +async def auth_server_and_port(): + """Create a gRPC server with authentication interceptor.""" + interceptor = AsyncAuthenticationServerInterceptor() + server = aiogrpc.server(interceptors=[interceptor]) + + servicer = TestReporterServicer() + reporter_pb2_grpc.add_ReporterServiceServicer_to_server(servicer, server) + + port = server.add_insecure_port("[::]:0") + await server.start() + + yield server, port + + await server.stop(grace=1) + + +@pytest.fixture +async def no_auth_server_and_port(): + """Create a gRPC server without authentication interceptor.""" + server = aiogrpc.server() + + servicer = TestReporterServicer() + reporter_pb2_grpc.add_ReporterServiceServicer_to_server(servicer, server) + + port = server.add_insecure_port("[::]:0") + await server.start() + + yield server, port + + await server.stop(grace=1) + + +@pytest.mark.asyncio +async def test_server_interceptor_allows_valid_token(auth_server_and_port): + """Test that server interceptor allows requests with valid tokens.""" + with authentication_env_guard(): + # Set up token authentication + token = uuid.uuid4().hex + set_auth_mode("token") + set_env_auth_token(token) + reset_auth_token_state() + + # Get server from fixture + _, port = auth_server_and_port + + # Create client with valid token in metadata + async with aiogrpc.insecure_channel(f"localhost:{port}") as channel: + stub = reporter_pb2_grpc.ReporterServiceStub(channel) + + # Add valid token to metadata + metadata = (("authorization", f"Bearer {token}"),) + + request = reporter_pb2.HealthCheckRequest() + response = await stub.HealthCheck(request, metadata=metadata) + + # Should succeed (response exists and is not None) + assert response is not None + + +@pytest.mark.asyncio +async def test_server_interceptor_rejects_invalid_token(auth_server_and_port): + """Test that server interceptor rejects requests with invalid tokens.""" + with authentication_env_guard(): + # Set up token authentication + correct_token = uuid.uuid4().hex + set_auth_mode("token") + set_env_auth_token(correct_token) + reset_auth_token_state() + + # Get server from fixture + _, port = auth_server_and_port + + # Create client with invalid token in metadata + async with aiogrpc.insecure_channel(f"localhost:{port}") as channel: + stub = reporter_pb2_grpc.ReporterServiceStub(channel) + + # Add invalid token to metadata + wrong_token = uuid.uuid4().hex + metadata = (("authorization", f"Bearer {wrong_token}"),) + + request = reporter_pb2.HealthCheckRequest() + + # Should fail with UNAUTHENTICATED status + with pytest.raises(grpc.RpcError) as exc_info: + await stub.HealthCheck(request, metadata=metadata) + + assert exc_info.value.code() == grpc.StatusCode.UNAUTHENTICATED + + +@pytest.mark.asyncio +async def test_server_interceptor_rejects_missing_token(auth_server_and_port): + """Test that server interceptor rejects requests without tokens.""" + with authentication_env_guard(): + # Set up token authentication + token = uuid.uuid4().hex + set_auth_mode("token") + set_env_auth_token(token) + reset_auth_token_state() + + # Get server from fixture + _, port = auth_server_and_port + + # Create client without token in metadata + async with aiogrpc.insecure_channel(f"localhost:{port}") as channel: + stub = reporter_pb2_grpc.ReporterServiceStub(channel) + + request = reporter_pb2.HealthCheckRequest() + + # Should fail with UNAUTHENTICATED status + with pytest.raises(grpc.RpcError) as exc_info: + await stub.HealthCheck(request) + + assert exc_info.value.code() == grpc.StatusCode.UNAUTHENTICATED + + +@pytest.mark.asyncio +async def test_server_interceptor_disabled_auth_allows_all(auth_server_and_port): + """Test that when auth is disabled, all requests are allowed.""" + with authentication_env_guard(): + # Set auth mode to disabled (or don't set it at all) + set_auth_mode("disabled") + reset_auth_token_state() + + # Get server from fixture + _, port = auth_server_and_port + + # Create client without any token + async with aiogrpc.insecure_channel(f"localhost:{port}") as channel: + stub = reporter_pb2_grpc.ReporterServiceStub(channel) + + request = reporter_pb2.HealthCheckRequest() + response = await stub.HealthCheck(request) + + # Should succeed even without token + assert response is not None + + +@pytest.mark.asyncio +async def test_no_interceptor_allows_all_requests(no_auth_server_and_port): + """Test that server without interceptor allows all requests.""" + with authentication_env_guard(): + # Even with token auth enabled, server without interceptor allows all + token = uuid.uuid4().hex + set_auth_mode("token") + set_env_auth_token(token) + reset_auth_token_state() + + _, port = no_auth_server_and_port + + # Create client without token + async with aiogrpc.insecure_channel(f"localhost:{port}") as channel: + stub = reporter_pb2_grpc.ReporterServiceStub(channel) + + request = reporter_pb2.HealthCheckRequest() + response = await stub.HealthCheck(request) + + # Should succeed (no interceptor means no auth check) + assert response is not None + + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main([__file__, "-v"])) From b464cd5b51670249195dec641ccc501eb0959f9c Mon Sep 17 00:00:00 2001 From: sampan Date: Wed, 5 Nov 2025 05:42:43 +0000 Subject: [PATCH 11/26] [core] Token auth usability improvements Signed-off-by: sampan --- .../authentication/authentication_token_setup.py | 8 ++++---- .../grpc_authentication_server_interceptor.py | 2 +- python/ray/dashboard/client/src/App.tsx | 7 +++++++ python/ray/dashboard/http_server_head.py | 4 ++++ python/ray/tests/authentication_test_utils.py | 4 ++-- python/ray/tests/test_token_auth_integration.py | 6 +++--- src/ray/common/ray_config_def.h | 2 +- .../pubsub/tests/python_gcs_subscriber_auth_test.cc | 8 ++++---- src/ray/ray_syncer/tests/ray_syncer_test.cc | 10 +++++----- src/ray/raylet/tests/runtime_env_agent_client_test.cc | 6 +++--- src/ray/rpc/authentication/authentication_mode.cc | 2 +- src/ray/rpc/tests/authentication_token_loader_test.cc | 6 +++--- src/ray/rpc/tests/grpc_auth_token_tests.cc | 4 ++-- 13 files changed, 40 insertions(+), 29 deletions(-) diff --git a/python/ray/_private/authentication/authentication_token_setup.py b/python/ray/_private/authentication/authentication_token_setup.py index 8ad292430406..f87491a02c65 100644 --- a/python/ray/_private/authentication/authentication_token_setup.py +++ b/python/ray/_private/authentication/authentication_token_setup.py @@ -67,7 +67,7 @@ def ensure_token_if_auth_enabled( 3. Generate and save a default token for new local clusters if one doesn't already exist. Args: - system_config: Ray raises an error if you set auth_mode in system_config instead of the environment. + system_config: Ray raises an error if you set AUTH_MODE in system_config instead of the environment. create_token_if_missing: Generate a new token if one doesn't already exist. Raises: @@ -79,11 +79,11 @@ def ensure_token_if_auth_enabled( if get_authentication_mode() != AuthenticationMode.TOKEN: if ( system_config - and "auth_mode" in system_config - and system_config["auth_mode"] != "disabled" + and "AUTH_MODE" in system_config + and system_config["AUTH_MODE"] != "disabled" ): raise RuntimeError( - "Set authentication mode can only be set with the `RAY_auth_mode` environment variable, not using the system_config." + "Set authentication mode can only be set with the `RAY_AUTH_MODE` environment variable, not using the system_config." ) return diff --git a/python/ray/_private/authentication/grpc_authentication_server_interceptor.py b/python/ray/_private/authentication/grpc_authentication_server_interceptor.py index afed3f660854..18bc4103d303 100644 --- a/python/ray/_private/authentication/grpc_authentication_server_interceptor.py +++ b/python/ray/_private/authentication/grpc_authentication_server_interceptor.py @@ -21,7 +21,7 @@ class AsyncAuthenticationServerInterceptor(aiogrpc.ServerInterceptor): """Async gRPC server interceptor that validates authentication tokens. This interceptor checks the "authorization" metadata header for a valid - Bearer token when token authentication is enabled via RAY_auth_mode=token. + Bearer token when token authentication is enabled via RAY_AUTH_MODE=token. If the token is missing or invalid, the request is rejected with UNAUTHENTICATED status. """ diff --git a/python/ray/dashboard/client/src/App.tsx b/python/ray/dashboard/client/src/App.tsx index ddb8164d3c9e..0c900c0dc4d0 100644 --- a/python/ray/dashboard/client/src/App.tsx +++ b/python/ray/dashboard/client/src/App.tsx @@ -10,6 +10,7 @@ import { } from "./authentication/authentication"; import { AUTHENTICATION_ERROR_EVENT } from "./authentication/constants"; import { + clearAuthenticationToken, getAuthenticationToken, setAuthenticationToken, } from "./authentication/cookies"; @@ -252,6 +253,12 @@ const App = () => { } // If token exists, let it be used by interceptor // If invalid, interceptor will trigger dialog via 401/403 + } else { + // Auth mode is disabled - clear any existing token from cookie + const existingToken = getAuthenticationToken(); + if (existingToken) { + clearAuthenticationToken(); + } } } catch (error) { console.error("Failed to check authentication mode:", error); diff --git a/python/ray/dashboard/http_server_head.py b/python/ray/dashboard/http_server_head.py index 593cfbbbb2b4..055c337d03f8 100644 --- a/python/ray/dashboard/http_server_head.py +++ b/python/ray/dashboard/http_server_head.py @@ -272,6 +272,10 @@ async def run( "/", # Root index.html "/favicon.ico", "/api/authentication_mode", + "/api/healthz", # General healthcheck + "/api/gcs_healthz", # GCS health check + "/api/local_raylet_healthz", # Raylet health check + "/-/healthz", # Serve health check } public_path_prefixes = ("/static/",) # Static assets (JS, CSS, images) diff --git a/python/ray/tests/authentication_test_utils.py b/python/ray/tests/authentication_test_utils.py index e98711b51e88..d69aaa15c163 100644 --- a/python/ray/tests/authentication_test_utils.py +++ b/python/ray/tests/authentication_test_utils.py @@ -8,7 +8,7 @@ from ray._raylet import AuthenticationTokenLoader, Config -_AUTH_ENV_VARS = ("RAY_auth_mode", "RAY_AUTH_TOKEN", "RAY_AUTH_TOKEN_PATH") +_AUTH_ENV_VARS = ("RAY_AUTH_MODE", "RAY_AUTH_TOKEN", "RAY_AUTH_TOKEN_PATH") _DEFAULT_AUTH_TOKEN_RELATIVE_PATH = Path(".ray") / "auth_token" @@ -22,7 +22,7 @@ def reset_auth_token_state() -> None: def set_auth_mode(mode: str) -> None: """Set the authentication mode environment variable.""" - os.environ["RAY_auth_mode"] = mode + os.environ["RAY_AUTH_MODE"] = mode def set_env_auth_token(token: str) -> None: diff --git a/python/ray/tests/test_token_auth_integration.py b/python/ray/tests/test_token_auth_integration.py index 9aa82438098e..65ea0c95bf07 100644 --- a/python/ray/tests/test_token_auth_integration.py +++ b/python/ray/tests/test_token_auth_integration.py @@ -222,7 +222,7 @@ def test_ray_start_without_token_raises_error(is_head, request): """Test that ray start fails when auth_mode=token but no token exists.""" # Set up environment with token auth enabled but no token env = os.environ.copy() - env["RAY_auth_mode"] = "token" + env["RAY_AUTH_MODE"] = "token" env.pop("RAY_AUTH_TOKEN", None) env.pop("RAY_AUTH_TOKEN_PATH", None) @@ -253,7 +253,7 @@ def test_ray_start_head_with_token_succeeds(): test_token = "a" * 32 env = os.environ.copy() env["RAY_AUTH_TOKEN"] = test_token - env["RAY_auth_mode"] = "token" + env["RAY_AUTH_MODE"] = "token" try: # Start head node - should succeed @@ -303,7 +303,7 @@ def test_ray_start_address_with_token(token_match, setup_cluster_with_token_auth # Set up environment for worker env = os.environ.copy() - env["RAY_auth_mode"] = "token" + env["RAY_AUTH_MODE"] = "token" if token_match == "correct": env["RAY_AUTH_TOKEN"] = cluster_token diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h index e4e8fc1d48ef..05b25b7d22fb 100644 --- a/src/ray/common/ray_config_def.h +++ b/src/ray/common/ray_config_def.h @@ -39,7 +39,7 @@ RAY_CONFIG(bool, enable_cluster_auth, true) /// will be converted to AuthenticationMode enum defined in /// rpc/authentication/authentication_mode.h /// use GetAuthenticationMode() to get the authentication mode enum value. -RAY_CONFIG(std::string, auth_mode, "disabled") +RAY_CONFIG(std::string, AUTH_MODE, "disabled") /// The interval of periodic event loop stats print. /// -1 means the feature is disabled. In this case, stats are available diff --git a/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc b/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc index c2518d30a07c..f701fb0f2119 100644 --- a/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc +++ b/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc @@ -86,7 +86,7 @@ class PythonGcsSubscriberAuthTest : public ::testing::Test { protected: void SetUp() override { // Enable token authentication by default - RayConfig::instance().initialize(R"({"auth_mode": "token"})"); + RayConfig::instance().initialize(R"({"AUTH_MODE": "token"})"); rpc::AuthenticationTokenLoader::instance().ResetCache(); } @@ -97,7 +97,7 @@ class PythonGcsSubscriberAuthTest : public ::testing::Test { } unsetenv("RAY_AUTH_TOKEN"); // Reset to default auth mode - RayConfig::instance().initialize(R"({"auth_mode": "disabled"})"); + RayConfig::instance().initialize(R"({"AUTH_MODE": "disabled"})"); rpc::AuthenticationTokenLoader::instance().ResetCache(); } @@ -136,10 +136,10 @@ class PythonGcsSubscriberAuthTest : public ::testing::Test { void SetClientToken(const std::string &client_token) { if (!client_token.empty()) { setenv("RAY_AUTH_TOKEN", client_token.c_str(), 1); - RayConfig::instance().initialize(R"({"auth_mode": "token"})"); + RayConfig::instance().initialize(R"({"AUTH_MODE": "token"})"); } else { unsetenv("RAY_AUTH_TOKEN"); - RayConfig::instance().initialize(R"({"auth_mode": "disabled"})"); + RayConfig::instance().initialize(R"({"AUTH_MODE": "disabled"})"); } rpc::AuthenticationTokenLoader::instance().ResetCache(); } diff --git a/src/ray/ray_syncer/tests/ray_syncer_test.cc b/src/ray/ray_syncer/tests/ray_syncer_test.cc index be2e14ae53f7..a7aa857cc336 100644 --- a/src/ray/ray_syncer/tests/ray_syncer_test.cc +++ b/src/ray/ray_syncer/tests/ray_syncer_test.cc @@ -995,13 +995,13 @@ class SyncerAuthenticationTest : public ::testing::Test { // Clear any existing environment variables and reset state unsetenv("RAY_AUTH_TOKEN"); ray::rpc::AuthenticationTokenLoader::instance().ResetCache(); - RayConfig::instance().auth_mode() = "disabled"; + RayConfig::instance().AUTH_MODE() = "disabled"; } void TearDown() override { unsetenv("RAY_AUTH_TOKEN"); ray::rpc::AuthenticationTokenLoader::instance().ResetCache(); - RayConfig::instance().auth_mode() = "disabled"; + RayConfig::instance().AUTH_MODE() = "disabled"; } struct AuthenticatedSyncerServerTest { @@ -1085,7 +1085,7 @@ TEST_F(SyncerAuthenticationTest, MatchingTokens) { // Set client token via environment variable setenv("RAY_AUTH_TOKEN", test_token.c_str(), 1); // Enable token authentication - RayConfig::instance().auth_mode() = "token"; + RayConfig::instance().AUTH_MODE() = "token"; ray::rpc::AuthenticationTokenLoader::instance().ResetCache(); // Create authenticated server @@ -1112,7 +1112,7 @@ TEST_F(SyncerAuthenticationTest, MismatchedTokens) { // Set client token via environment variable setenv("RAY_AUTH_TOKEN", client_token.c_str(), 1); // Enable token authentication - RayConfig::instance().auth_mode() = "token"; + RayConfig::instance().AUTH_MODE() = "token"; ray::rpc::AuthenticationTokenLoader::instance().ResetCache(); // Create authenticated server with different token @@ -1163,7 +1163,7 @@ TEST_F(SyncerAuthenticationTest, ClientHasTokenServerDoesNotRequire) { // Set client token setenv("RAY_AUTH_TOKEN", client_token.c_str(), 1); // Enable token authentication - RayConfig::instance().auth_mode() = "token"; + RayConfig::instance().AUTH_MODE() = "token"; ray::rpc::AuthenticationTokenLoader::instance().ResetCache(); // Create server without authentication (empty token) diff --git a/src/ray/raylet/tests/runtime_env_agent_client_test.cc b/src/ray/raylet/tests/runtime_env_agent_client_test.cc index ea5bb4974097..f44422af2e2f 100644 --- a/src/ray/raylet/tests/runtime_env_agent_client_test.cc +++ b/src/ray/raylet/tests/runtime_env_agent_client_test.cc @@ -193,7 +193,7 @@ delay_after(instrumented_io_context &ioc) { auto dummy_shutdown_raylet_gracefully = [](const rpc::NodeDeathInfo &) {}; TEST(RuntimeEnvAgentClientTest, GetOrCreateRuntimeEnvOK) { - RayConfig::instance().initialize(R"({"auth_mode": "disabled"})"); + RayConfig::instance().initialize(R"({"AUTH_MODE": "disabled"})"); unsetenv("RAY_AUTH_TOKEN"); rpc::AuthenticationTokenLoader::instance().ResetCache(); @@ -365,7 +365,7 @@ TEST(RuntimeEnvAgentClientTest, GetOrCreateRuntimeEnvRetriesOnServerNotStarted) } TEST(RuntimeEnvAgentClientTest, AttachesAuthHeaderWhenEnabled) { - RayConfig::instance().initialize(R"({"auth_mode": "token"})"); + RayConfig::instance().initialize(R"({"AUTH_MODE": "token"})"); setenv("RAY_AUTH_TOKEN", "header_token", 1); rpc::AuthenticationTokenLoader::instance().ResetCache(); @@ -427,7 +427,7 @@ TEST(RuntimeEnvAgentClientTest, AttachesAuthHeaderWhenEnabled) { ASSERT_EQ(called_times, 1); ASSERT_EQ(observed_auth_header, "Bearer header_token"); - RayConfig::instance().initialize(R"({"auth_mode": "disabled"})"); + RayConfig::instance().initialize(R"({"AUTH_MODE": "disabled"})"); unsetenv("RAY_AUTH_TOKEN"); rpc::AuthenticationTokenLoader::instance().ResetCache(); } diff --git a/src/ray/rpc/authentication/authentication_mode.cc b/src/ray/rpc/authentication/authentication_mode.cc index 1bbe209733ce..7a02865f0b43 100644 --- a/src/ray/rpc/authentication/authentication_mode.cc +++ b/src/ray/rpc/authentication/authentication_mode.cc @@ -24,7 +24,7 @@ namespace ray { namespace rpc { AuthenticationMode GetAuthenticationMode() { - std::string auth_mode_lower = absl::AsciiStrToLower(RayConfig::instance().auth_mode()); + std::string auth_mode_lower = absl::AsciiStrToLower(RayConfig::instance().AUTH_MODE()); if (auth_mode_lower == "token") { return AuthenticationMode::TOKEN; diff --git a/src/ray/rpc/tests/authentication_token_loader_test.cc b/src/ray/rpc/tests/authentication_token_loader_test.cc index b758acd64c05..06f8968718c2 100644 --- a/src/ray/rpc/tests/authentication_token_loader_test.cc +++ b/src/ray/rpc/tests/authentication_token_loader_test.cc @@ -45,7 +45,7 @@ class AuthenticationTokenLoaderTest : public ::testing::Test { protected: void SetUp() override { // Enable token authentication for tests - RayConfig::instance().initialize(R"({"auth_mode": "token"})"); + RayConfig::instance().initialize(R"({"AUTH_MODE": "token"})"); // If HOME is not set (e.g., in Bazel sandbox), set it to a test directory // This ensures tests work in environments where HOME isn't provided @@ -90,7 +90,7 @@ class AuthenticationTokenLoaderTest : public ::testing::Test { // Reset the singleton's cached state for test isolation AuthenticationTokenLoader::instance().ResetCache(); // Disable token auth after tests - RayConfig::instance().initialize(R"({"auth_mode": "disabled"})"); + RayConfig::instance().initialize(R"({"AUTH_MODE": "disabled"})"); } void cleanup_env() { @@ -277,7 +277,7 @@ TEST_P(AuthenticationTokenLoaderPrecedenceTest, Precedence) { TEST_F(AuthenticationTokenLoaderTest, TestNoTokenFoundWhenAuthDisabled) { // Disable auth for this specific test - RayConfig::instance().initialize(R"({"auth_mode": "disabled"})"); + RayConfig::instance().initialize(R"({"AUTH_MODE": "disabled"})"); AuthenticationTokenLoader::instance().ResetCache(); // No token set anywhere, but auth is disabled diff --git a/src/ray/rpc/tests/grpc_auth_token_tests.cc b/src/ray/rpc/tests/grpc_auth_token_tests.cc index 4499b4c43129..fe9e9cd2571d 100644 --- a/src/ray/rpc/tests/grpc_auth_token_tests.cc +++ b/src/ray/rpc/tests/grpc_auth_token_tests.cc @@ -32,7 +32,7 @@ class TestGrpcServerClientTokenAuthFixture : public ::testing::Test { public: void SetUp() override { // Configure token auth via RayConfig - std::string config_json = R"({"auth_mode": "token"})"; + std::string config_json = R"({"AUTH_MODE": "token"})"; RayConfig::instance().initialize(config_json); AuthenticationTokenLoader::instance().ResetCache(); } @@ -44,7 +44,7 @@ class TestGrpcServerClientTokenAuthFixture : public ::testing::Test { if (!client_token.empty()) { setenv("RAY_AUTH_TOKEN", client_token.c_str(), 1); } else { - RayConfig::instance().initialize(R"({"auth_mode": "disabled"})"); + RayConfig::instance().initialize(R"({"AUTH_MODE": "disabled"})"); AuthenticationTokenLoader::instance().ResetCache(); unsetenv("RAY_AUTH_TOKEN"); } From 29f511a31da5e19ac236e4032e2f60113142bd92 Mon Sep 17 00:00:00 2001 From: sampan Date: Wed, 5 Nov 2025 06:27:54 +0000 Subject: [PATCH 12/26] address comment Signed-off-by: sampan --- src/ray/rpc/tests/authentication_token_loader_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ray/rpc/tests/authentication_token_loader_test.cc b/src/ray/rpc/tests/authentication_token_loader_test.cc index 06f8968718c2..288b5c72725c 100644 --- a/src/ray/rpc/tests/authentication_token_loader_test.cc +++ b/src/ray/rpc/tests/authentication_token_loader_test.cc @@ -288,7 +288,7 @@ TEST_F(AuthenticationTokenLoaderTest, TestNoTokenFoundWhenAuthDisabled) { EXPECT_FALSE(loader.GetToken().has_value()); // Re-enable for other tests - RayConfig::instance().initialize(R"({"auth_mode": "token"})"); + RayConfig::instance().initialize(R"({"AUTH_MODE": "token"})"); } TEST_F(AuthenticationTokenLoaderTest, TestErrorWhenAuthEnabledButNoToken) { From 26de367d0eece6884e1e71052934346dce0b4aa6 Mon Sep 17 00:00:00 2001 From: sampan Date: Wed, 5 Nov 2025 06:30:59 +0000 Subject: [PATCH 13/26] add test_grpc_authentication_server_interceptor to BUILD.bazel Signed-off-by: sampan --- python/ray/tests/BUILD.bazel | 1 + 1 file changed, 1 insertion(+) diff --git a/python/ray/tests/BUILD.bazel b/python/ray/tests/BUILD.bazel index 337b745184bc..dcda47a261cb 100644 --- a/python/ray/tests/BUILD.bazel +++ b/python/ray/tests/BUILD.bazel @@ -56,6 +56,7 @@ py_test_module_list( "test_gcs_utils.py", "test_get_locations.py", "test_global_state.py", + "test_grpc_authentication_server_interceptor", "test_healthcheck.py", "test_metric_cardinality.py", "test_metrics_agent.py", From 2b5e01bc3a2debbc7ab2db280efb13930a8512f1 Mon Sep 17 00:00:00 2001 From: sampan Date: Wed, 5 Nov 2025 07:16:10 +0000 Subject: [PATCH 14/26] fix typo Signed-off-by: sampan --- python/ray/tests/BUILD.bazel | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/tests/BUILD.bazel b/python/ray/tests/BUILD.bazel index dcda47a261cb..e28983d2a181 100644 --- a/python/ray/tests/BUILD.bazel +++ b/python/ray/tests/BUILD.bazel @@ -56,7 +56,7 @@ py_test_module_list( "test_gcs_utils.py", "test_get_locations.py", "test_global_state.py", - "test_grpc_authentication_server_interceptor", + "test_grpc_authentication_server_interceptor.py", "test_healthcheck.py", "test_metric_cardinality.py", "test_metrics_agent.py", From c0182036d08a52bc17a0305f53825c5a1667b656 Mon Sep 17 00:00:00 2001 From: sampan Date: Thu, 6 Nov 2025 06:30:56 +0000 Subject: [PATCH 15/26] [core] use client interceptor for adding auth token in c++ client calls Signed-off-by: sampan --- python/ray/autoscaler/v2/tests/test_sdk.py | 73 ++++++++++++- src/ray/pubsub/python_gcs_subscriber.cc | 10 -- src/ray/pubsub/python_gcs_subscriber.h | 4 - src/ray/rpc/BUILD.bazel | 5 +- src/ray/rpc/client_call.h | 8 -- src/ray/rpc/grpc_client.cc | 117 +++++++++++++++++++++ src/ray/rpc/grpc_client.h | 47 +++------ 7 files changed, 206 insertions(+), 58 deletions(-) create mode 100644 src/ray/rpc/grpc_client.cc diff --git a/python/ray/autoscaler/v2/tests/test_sdk.py b/python/ray/autoscaler/v2/tests/test_sdk.py index dc8f96130972..c6d260500267 100644 --- a/python/ray/autoscaler/v2/tests/test_sdk.py +++ b/python/ray/autoscaler/v2/tests/test_sdk.py @@ -31,6 +31,7 @@ from ray.core.generated.autoscaler_pb2 import ClusterResourceState, NodeStatus from ray.core.generated.common_pb2 import LabelSelectorOperator from ray.util.state.api import list_nodes +from ray.tests import authentication_test_utils def _autoscaler_state_service_stub(): @@ -403,7 +404,6 @@ def verify(): def test_pg_usage_labels(shutdown_only): - ray.init(num_cpus=1) # Create a pg @@ -924,6 +924,77 @@ def verify(): wait_for_condition(verify) +@pytest.mark.parametrize( + "token_state,setup_token,should_fail", + [ + ("valid", lambda: None, False), + ("invalid", lambda: _setup_invalid_token(), True), + ], +) +def test_autoscaler_api_with_token_auth( + setup_cluster_with_token_auth, + cleanup_auth_token_env, + token_state, + setup_token, + should_fail, +): + """Parametrized test for autoscaler API with different token states. + + Tests request_cluster_resources with valid, invalid, and missing tokens. + """ + cluster_info = setup_cluster_with_token_auth + cluster = cluster_info["cluster"] + + # Setup token state (this changes the client-side token) + setup_token() + + # Ray is already initialized by the fixture, so just use it + # For invalid token test, this creates a mismatch between client and server tokens + if should_fail: + # API call should fail with invalid token + with pytest.raises(Exception) as exc_info: + request_cluster_resources( + ray.get_runtime_context().gcs_address, + [{"resources": {"CPU": 1}, "label_selector": {}}], + ) + + # Verify it's an authentication error + error_str = str(exc_info.value).lower() + assert ( + "unauthenticated" in error_str or "invalidauthtoken" in error_str + ), ( + f"request_cluster_resources with {token_state} token should return auth error, got: {exc_info.value}" + ) + else: + # API call should succeed with valid token + request_cluster_resources( + ray.get_runtime_context().gcs_address, + [{"resources": {"CPU": 1}, "label_selector": {}}], + ) + + # Verify the request was successful using the autoscaler state service stub + stub = _autoscaler_state_service_stub() + state = get_cluster_resource_state(stub) + assert len(state.cluster_resource_constraints) > 0, ( + f"request_cluster_resources with {token_state} token should succeed" + ) + + +def _setup_invalid_token(): + """Helper to set up an invalid authentication token.""" + + invalid_token = "invalid_token_value" + authentication_test_utils.set_env_auth_token(invalid_token) + authentication_test_utils.reset_auth_token_state() + + +def _clear_token(): + """Helper to clear authentication token sources.""" + + authentication_test_utils.clear_auth_token_sources() + authentication_test_utils.reset_auth_token_state() + + if __name__ == "__main__": if os.environ.get("PARALLEL_CI"): sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__])) diff --git a/src/ray/pubsub/python_gcs_subscriber.cc b/src/ray/pubsub/python_gcs_subscriber.cc index c4b5ae762e9b..ce2b025e5383 100644 --- a/src/ray/pubsub/python_gcs_subscriber.cc +++ b/src/ray/pubsub/python_gcs_subscriber.cc @@ -52,7 +52,6 @@ Status PythonGcsSubscriber::Subscribe() { } grpc::ClientContext context; - SetAuthenticationToken(context); rpc::GcsSubscriberCommandBatchRequest request; request.set_subscriber_id(subscriber_id_); @@ -80,7 +79,6 @@ Status PythonGcsSubscriber::DoPoll(int64_t timeout_ms, rpc::PubMessage *message) return Status::OK(); } current_polling_context_ = std::make_shared(); - SetAuthenticationToken(*current_polling_context_); if (timeout_ms != -1) { current_polling_context_->set_deadline(std::chrono::system_clock::now() + std::chrono::milliseconds(timeout_ms)); @@ -176,7 +174,6 @@ Status PythonGcsSubscriber::Close() { } grpc::ClientContext context; - SetAuthenticationToken(context); rpc::GcsSubscriberCommandBatchRequest request; request.set_subscriber_id(subscriber_id_); @@ -199,12 +196,5 @@ int64_t PythonGcsSubscriber::last_batch_size() { return last_batch_size_; } -void PythonGcsSubscriber::SetAuthenticationToken(grpc::ClientContext &context) { - auto auth_token = ray::rpc::AuthenticationTokenLoader::instance().GetToken(); - if (auth_token.has_value() && !auth_token->empty()) { - auth_token->SetMetadata(context); - } -} - } // namespace pubsub } // namespace ray diff --git a/src/ray/pubsub/python_gcs_subscriber.h b/src/ray/pubsub/python_gcs_subscriber.h index e8aeaa116566..5fe4eda29812 100644 --- a/src/ray/pubsub/python_gcs_subscriber.h +++ b/src/ray/pubsub/python_gcs_subscriber.h @@ -80,10 +80,6 @@ class RAY_EXPORT PythonGcsSubscriber { std::deque queue_ ABSL_GUARDED_BY(mu_); bool closed_ ABSL_GUARDED_BY(mu_) = false; std::shared_ptr current_polling_context_ ABSL_GUARDED_BY(mu_); - - // Set authentication token on a gRPC client context if token-based authentication is - // enabled - void SetAuthenticationToken(grpc::ClientContext &context); }; /// Get the .lines() attribute of a LogBatch as a std::vector diff --git a/src/ray/rpc/BUILD.bazel b/src/ray/rpc/BUILD.bazel index 637655b02e29..746a6c77222a 100644 --- a/src/ray/rpc/BUILD.bazel +++ b/src/ray/rpc/BUILD.bazel @@ -21,22 +21,25 @@ ray_cc_library( "//src/ray/common:grpc_util", "//src/ray/common:id", "//src/ray/common:status", - "//src/ray/rpc/authentication:authentication_token_loader", "@com_google_absl//absl/synchronization", ], ) ray_cc_library( name = "grpc_client", + srcs = ["grpc_client.cc"], hdrs = ["grpc_client.h"], visibility = ["//visibility:public"], deps = [ ":client_call", ":common", ":rpc_chaos", + "//src/ray/common:constants", "//src/ray/common:grpc_util", "//src/ray/common:ray_config", "//src/ray/common:status", + "//src/ray/rpc/authentication:authentication_mode", + "//src/ray/rpc/authentication:authentication_token_loader", "//src/ray/util:network_util", ], ) diff --git a/src/ray/rpc/client_call.h b/src/ray/rpc/client_call.h index 319915f3e17a..7955727d4f8b 100644 --- a/src/ray/rpc/client_call.h +++ b/src/ray/rpc/client_call.h @@ -31,8 +31,6 @@ #include "ray/common/grpc_util.h" #include "ray/common/id.h" #include "ray/common/status.h" -#include "ray/rpc/authentication/authentication_token.h" -#include "ray/rpc/authentication/authentication_token_loader.h" #include "ray/rpc/metrics.h" #include "ray/rpc/rpc_callback_types.h" #include "ray/util/thread_utils.h" @@ -74,7 +72,6 @@ class ClientCallImpl : public ClientCall { /// \param[in] callback The callback function to handle the reply. explicit ClientCallImpl(const ClientCallback &callback, const ClusterID &cluster_id, - const std::optional &auth_token, std::shared_ptr stats_handle, bool record_stats, int64_t timeout_ms = -1) @@ -89,10 +86,6 @@ class ClientCallImpl : public ClientCall { if (!cluster_id.IsNil()) { context_.AddMetadata(kClusterIdKey, cluster_id.Hex()); } - // Add authentication token if provided - if (auth_token.has_value()) { - auth_token->SetMetadata(context_); - } } Status GetStatus() override { @@ -286,7 +279,6 @@ class ClientCallManager { auto call = std::make_shared>( callback, cluster_id_, - AuthenticationTokenLoader::instance().GetToken(), std::move(stats_handle), record_stats_, method_timeout_ms); diff --git a/src/ray/rpc/grpc_client.cc b/src/ray/rpc/grpc_client.cc new file mode 100644 index 000000000000..f4a37645dbcc --- /dev/null +++ b/src/ray/rpc/grpc_client.cc @@ -0,0 +1,117 @@ +// Copyright 2025 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ray/rpc/grpc_client.h" + +#include + +#include "ray/common/constants.h" +#include "ray/rpc/authentication/authentication_mode.h" +#include "ray/rpc/authentication/authentication_token_loader.h" + +namespace ray { +namespace rpc { + +namespace { + +/// Client interceptor that automatically adds Ray authentication tokens to outgoing RPCs. +/// The token is loaded from AuthenticationTokenLoader and added as a Bearer token +/// in the "authorization" metadata key. +class RayTokenAuthClientInterceptor : public grpc::experimental::Interceptor { + public: + void Intercept(grpc::experimental::InterceptorBatchMethods *methods) override { + if (methods->QueryInterceptionHookPoint( + grpc::experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + auto token = AuthenticationTokenLoader::instance().GetToken(); + + // If token is present and non-empty, add it to the metadata + if (token.has_value() && !token->empty()) { + // Get the metadata map and add the authorization header + auto *metadata = methods->GetSendInitialMetadata(); + metadata->insert(std::make_pair(kAuthTokenKey, + token->ToAuthorizationHeaderValue())); + } + } + methods->Proceed(); + } +}; + +/// Factory for creating RayAuthClientInterceptor instances +class RayTokenAuthClientInterceptorFactory + : public grpc::experimental::ClientInterceptorFactoryInterface { + public: + grpc::experimental::Interceptor *CreateClientInterceptor( + grpc::experimental::ClientRpcInfo *info) override { + return new RayTokenAuthClientInterceptor(); + } +}; + +} // namespace + +std::shared_ptr BuildChannel( + const std::string &address, + int port, + std::optional arguments) { + // Set up channel arguments with default values if not provided + if (!arguments.has_value()) { + arguments = grpc::ChannelArguments(); + } + + arguments->SetInt(GRPC_ARG_ENABLE_HTTP_PROXY, + ::RayConfig::instance().grpc_enable_http_proxy() ? 1 : 0); + arguments->SetMaxSendMessageSize(::RayConfig::instance().max_grpc_message_size()); + arguments->SetMaxReceiveMessageSize(::RayConfig::instance().max_grpc_message_size()); + arguments->SetInt(GRPC_ARG_HTTP2_WRITE_BUFFER_SIZE, + ::RayConfig::instance().grpc_stream_buffer_size()); + + // Step 1: Get transport credentials (TLS or insecure) + std::shared_ptr channel_creds; + + if (::RayConfig::instance().USE_TLS()) { + std::string server_cert_file = std::string(::RayConfig::instance().TLS_SERVER_CERT()); + std::string server_key_file = std::string(::RayConfig::instance().TLS_SERVER_KEY()); + std::string root_cert_file = std::string(::RayConfig::instance().TLS_CA_CERT()); + std::string server_cert_chain = ReadCert(server_cert_file); + std::string private_key = ReadCert(server_key_file); + std::string cacert = ReadCert(root_cert_file); + + grpc::SslCredentialsOptions ssl_opts; + ssl_opts.pem_root_certs = cacert; + ssl_opts.pem_private_key = private_key; + ssl_opts.pem_cert_chain = server_cert_chain; + channel_creds = grpc::SslCredentials(ssl_opts); + } else { + channel_creds = grpc::InsecureChannelCredentials(); + } + + // Step 2: Create channel with interceptors if token auth is enabled + std::string target_address = BuildAddress(address, port); + + if (GetAuthenticationMode() == AuthenticationMode::TOKEN) { + // Create channel with auth interceptor + std::vector> + interceptor_factories; + interceptor_factories.push_back( + std::make_unique()); + + return grpc::experimental::CreateCustomChannelWithInterceptors( + target_address, channel_creds, *arguments, std::move(interceptor_factories)); + } else { + // Create channel without interceptors + return grpc::CreateCustomChannel(target_address, channel_creds, *arguments); + } +} + +} // namespace rpc +} // namespace ray diff --git a/src/ray/rpc/grpc_client.h b/src/ray/rpc/grpc_client.h index bb6ad2a949e0..9657e84e33e3 100644 --- a/src/ray/rpc/grpc_client.h +++ b/src/ray/rpc/grpc_client.h @@ -56,42 +56,21 @@ namespace rpc { INVOKE_RPC_CALL(SERVICE, METHOD, request, callback, rpc_client, method_timeout_ms); \ } -inline std::shared_ptr BuildChannel( +/// Build a gRPC channel to the specified address. +/// +/// This is the ONLY recommended way to create gRPC channels in Ray. +/// Use of raw grpc::CreateCustomChannel() should be avoided. +/// +/// Authentication tokens are automatically added in metadata when RAY_AUTH_MODE=token. +/// +/// \param address The server address +/// \param port The server port +/// \param arguments Optional channel arguments for customization +/// \return A shared pointer to the gRPC channel +std::shared_ptr BuildChannel( const std::string &address, int port, - std::optional arguments = std::nullopt) { - if (!arguments.has_value()) { - arguments = grpc::ChannelArguments(); - } - - arguments->SetInt(GRPC_ARG_ENABLE_HTTP_PROXY, - ::RayConfig::instance().grpc_enable_http_proxy() ? 1 : 0); - arguments->SetMaxSendMessageSize(::RayConfig::instance().max_grpc_message_size()); - arguments->SetMaxReceiveMessageSize(::RayConfig::instance().max_grpc_message_size()); - arguments->SetInt(GRPC_ARG_HTTP2_WRITE_BUFFER_SIZE, - ::RayConfig::instance().grpc_stream_buffer_size()); - std::shared_ptr channel; - if (::RayConfig::instance().USE_TLS()) { - std::string server_cert_file = std::string(::RayConfig::instance().TLS_SERVER_CERT()); - std::string server_key_file = std::string(::RayConfig::instance().TLS_SERVER_KEY()); - std::string root_cert_file = std::string(::RayConfig::instance().TLS_CA_CERT()); - std::string server_cert_chain = ReadCert(server_cert_file); - std::string private_key = ReadCert(server_key_file); - std::string cacert = ReadCert(root_cert_file); - - grpc::SslCredentialsOptions ssl_opts; - ssl_opts.pem_root_certs = cacert; - ssl_opts.pem_private_key = private_key; - ssl_opts.pem_cert_chain = server_cert_chain; - auto ssl_creds = grpc::SslCredentials(ssl_opts); - channel = - grpc::CreateCustomChannel(BuildAddress(address, port), ssl_creds, *arguments); - } else { - channel = grpc::CreateCustomChannel( - BuildAddress(address, port), grpc::InsecureChannelCredentials(), *arguments); - } - return channel; -} + std::optional arguments = std::nullopt); template class GrpcClient { From d66af1697d8d17565b44fe07be08d45141daba97 Mon Sep 17 00:00:00 2001 From: sampan Date: Thu, 6 Nov 2025 06:34:47 +0000 Subject: [PATCH 16/26] fix lint issues Signed-off-by: sampan --- python/ray/autoscaler/v2/tests/test_sdk.py | 17 +++++------------ src/ray/rpc/client_call.h | 6 +----- src/ray/rpc/grpc_client.cc | 9 +++++++-- 3 files changed, 13 insertions(+), 19 deletions(-) diff --git a/python/ray/autoscaler/v2/tests/test_sdk.py b/python/ray/autoscaler/v2/tests/test_sdk.py index c6d260500267..17c390e625d0 100644 --- a/python/ray/autoscaler/v2/tests/test_sdk.py +++ b/python/ray/autoscaler/v2/tests/test_sdk.py @@ -30,8 +30,8 @@ from ray.core.generated import autoscaler_pb2, autoscaler_pb2_grpc from ray.core.generated.autoscaler_pb2 import ClusterResourceState, NodeStatus from ray.core.generated.common_pb2 import LabelSelectorOperator -from ray.util.state.api import list_nodes from ray.tests import authentication_test_utils +from ray.util.state.api import list_nodes def _autoscaler_state_service_stub(): @@ -942,14 +942,9 @@ def test_autoscaler_api_with_token_auth( Tests request_cluster_resources with valid, invalid, and missing tokens. """ - cluster_info = setup_cluster_with_token_auth - cluster = cluster_info["cluster"] - # Setup token state (this changes the client-side token) setup_token() - # Ray is already initialized by the fixture, so just use it - # For invalid token test, this creates a mismatch between client and server tokens if should_fail: # API call should fail with invalid token with pytest.raises(Exception) as exc_info: @@ -962,9 +957,7 @@ def test_autoscaler_api_with_token_auth( error_str = str(exc_info.value).lower() assert ( "unauthenticated" in error_str or "invalidauthtoken" in error_str - ), ( - f"request_cluster_resources with {token_state} token should return auth error, got: {exc_info.value}" - ) + ), f"request_cluster_resources with {token_state} token should return auth error, got: {exc_info.value}" else: # API call should succeed with valid token request_cluster_resources( @@ -975,9 +968,9 @@ def test_autoscaler_api_with_token_auth( # Verify the request was successful using the autoscaler state service stub stub = _autoscaler_state_service_stub() state = get_cluster_resource_state(stub) - assert len(state.cluster_resource_constraints) > 0, ( - f"request_cluster_resources with {token_state} token should succeed" - ) + assert ( + len(state.cluster_resource_constraints) > 0 + ), f"request_cluster_resources with {token_state} token should succeed" def _setup_invalid_token(): diff --git a/src/ray/rpc/client_call.h b/src/ray/rpc/client_call.h index 7955727d4f8b..0c890ee080cb 100644 --- a/src/ray/rpc/client_call.h +++ b/src/ray/rpc/client_call.h @@ -277,11 +277,7 @@ class ClientCallManager { } auto call = std::make_shared>( - callback, - cluster_id_, - std::move(stats_handle), - record_stats_, - method_timeout_ms); + callback, cluster_id_, std::move(stats_handle), record_stats_, method_timeout_ms); // Send request. // Find the next completion queue to wait for response. call->response_reader_ = (stub.*prepare_async_function)( diff --git a/src/ray/rpc/grpc_client.cc b/src/ray/rpc/grpc_client.cc index f4a37645dbcc..aa6f4e0ba6ca 100644 --- a/src/ray/rpc/grpc_client.cc +++ b/src/ray/rpc/grpc_client.cc @@ -16,6 +16,11 @@ #include +#include +#include +#include +#include + #include "ray/common/constants.h" #include "ray/rpc/authentication/authentication_mode.h" #include "ray/rpc/authentication/authentication_token_loader.h" @@ -39,8 +44,8 @@ class RayTokenAuthClientInterceptor : public grpc::experimental::Interceptor { if (token.has_value() && !token->empty()) { // Get the metadata map and add the authorization header auto *metadata = methods->GetSendInitialMetadata(); - metadata->insert(std::make_pair(kAuthTokenKey, - token->ToAuthorizationHeaderValue())); + metadata->insert( + std::make_pair(kAuthTokenKey, token->ToAuthorizationHeaderValue())); } } methods->Proceed(); From fbcfea4a244d271c10d45c372d5ddc0426823637 Mon Sep 17 00:00:00 2001 From: sampan Date: Thu, 6 Nov 2025 07:30:44 +0000 Subject: [PATCH 17/26] separate out intererceptor code Signed-off-by: sampan --- src/ray/rpc/BUILD.bazel | 2 +- src/ray/rpc/authentication/BUILD.bazel | 13 ++++ .../token_auth_client_interceptor.cc | 63 +++++++++++++++++++ .../token_auth_client_interceptor.h | 47 ++++++++++++++ src/ray/rpc/grpc_client.cc | 50 +-------------- 5 files changed, 126 insertions(+), 49 deletions(-) create mode 100644 src/ray/rpc/authentication/token_auth_client_interceptor.cc create mode 100644 src/ray/rpc/authentication/token_auth_client_interceptor.h diff --git a/src/ray/rpc/BUILD.bazel b/src/ray/rpc/BUILD.bazel index 746a6c77222a..0cf589f3fe60 100644 --- a/src/ray/rpc/BUILD.bazel +++ b/src/ray/rpc/BUILD.bazel @@ -39,7 +39,7 @@ ray_cc_library( "//src/ray/common:ray_config", "//src/ray/common:status", "//src/ray/rpc/authentication:authentication_mode", - "//src/ray/rpc/authentication:authentication_token_loader", + "//src/ray/rpc/authentication:token_auth_client_interceptor", "//src/ray/util:network_util", ], ) diff --git a/src/ray/rpc/authentication/BUILD.bazel b/src/ray/rpc/authentication/BUILD.bazel index 8da78e5d728b..712cf9ae6ef7 100644 --- a/src/ray/rpc/authentication/BUILD.bazel +++ b/src/ray/rpc/authentication/BUILD.bazel @@ -32,3 +32,16 @@ ray_cc_library( "//src/ray/util:logging", ], ) + +ray_cc_library( + name = "token_auth_client_interceptor", + srcs = ["token_auth_client_interceptor.cc"], + hdrs = ["token_auth_client_interceptor.h"], + visibility = ["//visibility:public"], + deps = [ + ":authentication_token", + ":authentication_token_loader", + "//src/ray/common:constants", + "@com_github_grpc_grpc//:grpc++", + ], +) diff --git a/src/ray/rpc/authentication/token_auth_client_interceptor.cc b/src/ray/rpc/authentication/token_auth_client_interceptor.cc new file mode 100644 index 000000000000..bece4f3bc3b9 --- /dev/null +++ b/src/ray/rpc/authentication/token_auth_client_interceptor.cc @@ -0,0 +1,63 @@ +// Copyright 2025 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ray/rpc/authentication/token_auth_client_interceptor.h" + +#include + +#include +#include +#include +#include + +#include "ray/common/constants.h" +#include "ray/rpc/authentication/authentication_token_loader.h" + +namespace ray { +namespace rpc { + +void RayTokenAuthClientInterceptor::Intercept( + grpc::experimental::InterceptorBatchMethods *methods) { + if (methods->QueryInterceptionHookPoint( + grpc::experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + auto token = AuthenticationTokenLoader::instance().GetToken(); + + // If token is present and non-empty, add it to the metadata + if (token.has_value() && !token->empty()) { + // Get the metadata map and add the authorization header + auto *metadata = methods->GetSendInitialMetadata(); + metadata->insert( + std::make_pair(kAuthTokenKey, token->ToAuthorizationHeaderValue())); + } + } + methods->Proceed(); +} + +grpc::experimental::Interceptor * +RayTokenAuthClientInterceptorFactory::CreateClientInterceptor( + grpc::experimental::ClientRpcInfo *info) { + return new RayTokenAuthClientInterceptor(); +} + +std::vector> +CreateTokenAuthInterceptorFactories() { + std::vector> + interceptor_factories; + interceptor_factories.push_back( + std::make_unique()); + return interceptor_factories; +} + +} // namespace rpc +} // namespace ray diff --git a/src/ray/rpc/authentication/token_auth_client_interceptor.h b/src/ray/rpc/authentication/token_auth_client_interceptor.h new file mode 100644 index 000000000000..8dff955c0516 --- /dev/null +++ b/src/ray/rpc/authentication/token_auth_client_interceptor.h @@ -0,0 +1,47 @@ +// Copyright 2025 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include +#include + +namespace ray { +namespace rpc { + +/// Client interceptor that automatically adds Ray authentication tokens to outgoing RPCs. +/// The token is loaded from AuthenticationTokenLoader and added as a Bearer token +/// in the "authorization" metadata key. +class RayTokenAuthClientInterceptor : public grpc::experimental::Interceptor { + public: + void Intercept(grpc::experimental::InterceptorBatchMethods *methods) override; +}; + +/// Factory for creating RayTokenAuthClientInterceptor instances +class RayTokenAuthClientInterceptorFactory + : public grpc::experimental::ClientInterceptorFactoryInterface { + public: + grpc::experimental::Interceptor *CreateClientInterceptor( + grpc::experimental::ClientRpcInfo *info) override; +}; + +/// Creates a vector of interceptor factories for token authentication. +/// This should be used when creating gRPC channels with token auth enabled. +std::vector> +CreateTokenAuthInterceptorFactories(); + +} // namespace rpc +} // namespace ray diff --git a/src/ray/rpc/grpc_client.cc b/src/ray/rpc/grpc_client.cc index aa6f4e0ba6ca..83f4bce557e7 100644 --- a/src/ray/rpc/grpc_client.cc +++ b/src/ray/rpc/grpc_client.cc @@ -14,56 +14,15 @@ #include "ray/rpc/grpc_client.h" -#include - #include #include -#include -#include -#include "ray/common/constants.h" #include "ray/rpc/authentication/authentication_mode.h" -#include "ray/rpc/authentication/authentication_token_loader.h" +#include "ray/rpc/authentication/token_auth_client_interceptor.h" namespace ray { namespace rpc { -namespace { - -/// Client interceptor that automatically adds Ray authentication tokens to outgoing RPCs. -/// The token is loaded from AuthenticationTokenLoader and added as a Bearer token -/// in the "authorization" metadata key. -class RayTokenAuthClientInterceptor : public grpc::experimental::Interceptor { - public: - void Intercept(grpc::experimental::InterceptorBatchMethods *methods) override { - if (methods->QueryInterceptionHookPoint( - grpc::experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { - auto token = AuthenticationTokenLoader::instance().GetToken(); - - // If token is present and non-empty, add it to the metadata - if (token.has_value() && !token->empty()) { - // Get the metadata map and add the authorization header - auto *metadata = methods->GetSendInitialMetadata(); - metadata->insert( - std::make_pair(kAuthTokenKey, token->ToAuthorizationHeaderValue())); - } - } - methods->Proceed(); - } -}; - -/// Factory for creating RayAuthClientInterceptor instances -class RayTokenAuthClientInterceptorFactory - : public grpc::experimental::ClientInterceptorFactoryInterface { - public: - grpc::experimental::Interceptor *CreateClientInterceptor( - grpc::experimental::ClientRpcInfo *info) override { - return new RayTokenAuthClientInterceptor(); - } -}; - -} // namespace - std::shared_ptr BuildChannel( const std::string &address, int port, @@ -105,13 +64,8 @@ std::shared_ptr BuildChannel( if (GetAuthenticationMode() == AuthenticationMode::TOKEN) { // Create channel with auth interceptor - std::vector> - interceptor_factories; - interceptor_factories.push_back( - std::make_unique()); - return grpc::experimental::CreateCustomChannelWithInterceptors( - target_address, channel_creds, *arguments, std::move(interceptor_factories)); + target_address, channel_creds, *arguments, CreateTokenAuthInterceptorFactories()); } else { // Create channel without interceptors return grpc::CreateCustomChannel(target_address, channel_creds, *arguments); From 1240a8c6f0e43f8f8252cc72783e596f1cd1bd4e Mon Sep 17 00:00:00 2001 From: sampan Date: Tue, 11 Nov 2025 05:05:15 +0000 Subject: [PATCH 18/26] empty commit Signed-off-by: sampan From 3e473ba1341f8ca43dc6bc8d1b06bd7ee9109ca3 Mon Sep 17 00:00:00 2001 From: sampan Date: Tue, 11 Nov 2025 05:09:20 +0000 Subject: [PATCH 19/26] address comment Signed-off-by: sampan --- python/ray/_private/authentication/authentication_utils.py | 2 +- python/ray/dashboard/tests/test_dashboard_auth.py | 2 +- python/ray/tests/authentication_test_utils.py | 2 +- src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc | 4 ++-- src/ray/raylet/tests/runtime_env_agent_client_test.cc | 6 +++--- 5 files changed, 8 insertions(+), 8 deletions(-) diff --git a/python/ray/_private/authentication/authentication_utils.py b/python/ray/_private/authentication/authentication_utils.py index a87821e64509..ce1351d3c2fe 100644 --- a/python/ray/_private/authentication/authentication_utils.py +++ b/python/ray/_private/authentication/authentication_utils.py @@ -15,7 +15,7 @@ def is_token_auth_enabled() -> bool: """Check if token authentication is enabled. Returns: - bool: True if auth_mode is set to "token", False otherwise + bool: True if AUTH_MODE is set to "token", False otherwise """ if not _RAYLET_AVAILABLE: return False diff --git a/python/ray/dashboard/tests/test_dashboard_auth.py b/python/ray/dashboard/tests/test_dashboard_auth.py index 5f4f9b8ffc11..1df2a65d96f0 100644 --- a/python/ray/dashboard/tests/test_dashboard_auth.py +++ b/python/ray/dashboard/tests/test_dashboard_auth.py @@ -51,7 +51,7 @@ def test_dashboard_request_requires_auth_invalid_token(setup_cluster_with_token_ def test_dashboard_auth_disabled(setup_cluster_without_token_auth): - """Test that auth is not enforced when auth_mode is disabled.""" + """Test that auth is not enforced when AUTH_MODE is disabled.""" cluster_info = setup_cluster_without_token_auth diff --git a/python/ray/tests/authentication_test_utils.py b/python/ray/tests/authentication_test_utils.py index d69aaa15c163..ab0119e52426 100644 --- a/python/ray/tests/authentication_test_utils.py +++ b/python/ray/tests/authentication_test_utils.py @@ -13,7 +13,7 @@ def reset_auth_token_state() -> None: - """Reset authentication token and auth_mode ray config.""" + """Reset authentication token and AUTH_MODE ray config.""" AuthenticationTokenLoader.instance().reset_cache() Config.initialize("") diff --git a/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc b/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc index 5d3fa6ead8fe..9bc3d76ca916 100644 --- a/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc +++ b/src/ray/pubsub/tests/python_gcs_subscriber_auth_test.cc @@ -137,10 +137,10 @@ class PythonGcsSubscriberAuthTest : public ::testing::Test { void SetClientToken(const std::string &client_token) { if (!client_token.empty()) { ray::SetEnv("RAY_AUTH_TOKEN", client_token); - RayConfig::instance().initialize(R"({"auth_mode": "token"})"); + RayConfig::instance().initialize(R"({"AUTH_MODE": "token"})"); } else { ray::UnsetEnv("RAY_AUTH_TOKEN"); - RayConfig::instance().initialize(R"({"auth_mode": "disabled"})"); + RayConfig::instance().initialize(R"({"AUTH_MODE": "disabled"})"); } rpc::AuthenticationTokenLoader::instance().ResetCache(); } diff --git a/src/ray/raylet/tests/runtime_env_agent_client_test.cc b/src/ray/raylet/tests/runtime_env_agent_client_test.cc index 003908d6bf26..92b186b11736 100644 --- a/src/ray/raylet/tests/runtime_env_agent_client_test.cc +++ b/src/ray/raylet/tests/runtime_env_agent_client_test.cc @@ -194,7 +194,7 @@ delay_after(instrumented_io_context &ioc) { auto dummy_shutdown_raylet_gracefully = [](const rpc::NodeDeathInfo &) {}; TEST(RuntimeEnvAgentClientTest, GetOrCreateRuntimeEnvOK) { - RayConfig::instance().initialize(R"({"auth_mode": "disabled"})"); + RayConfig::instance().initialize(R"({"AUTH_MODE": "disabled"})"); ray::UnsetEnv("RAY_AUTH_TOKEN"); rpc::AuthenticationTokenLoader::instance().ResetCache(); @@ -366,7 +366,7 @@ TEST(RuntimeEnvAgentClientTest, GetOrCreateRuntimeEnvRetriesOnServerNotStarted) } TEST(RuntimeEnvAgentClientTest, AttachesAuthHeaderWhenEnabled) { - RayConfig::instance().initialize(R"({"auth_mode": "token"})"); + RayConfig::instance().initialize(R"({"AUTH_MODE": "token"})"); ray::SetEnv("RAY_AUTH_TOKEN", "header_token"); rpc::AuthenticationTokenLoader::instance().ResetCache(); @@ -428,7 +428,7 @@ TEST(RuntimeEnvAgentClientTest, AttachesAuthHeaderWhenEnabled) { ASSERT_EQ(called_times, 1); ASSERT_EQ(observed_auth_header, "Bearer header_token"); - RayConfig::instance().initialize(R"({"auth_mode": "disabled"})"); + RayConfig::instance().initialize(R"({"AUTH_MODE": "disabled"})"); ray::UnsetEnv("RAY_AUTH_TOKEN"); rpc::AuthenticationTokenLoader::instance().ResetCache(); } From 215b625b1353d76c478434ed3df14b3d18349a85 Mon Sep 17 00:00:00 2001 From: sampan Date: Wed, 12 Nov 2025 04:37:51 +0000 Subject: [PATCH 20/26] [Core] support token auth in ray client server Signed-off-by: sampan --- .../grpc_authentication_client_interceptor.py | 2 +- .../grpc_authentication_server_interceptor.py | 142 ++++++++++++---- python/ray/_private/gcs_utils.py | 2 +- python/ray/_private/grpc_utils.py | 157 +++++++++++++++++ python/ray/_private/internal_api.py | 6 +- python/ray/_private/test_utils.py | 4 +- python/ray/_private/utils.py | 64 ------- python/ray/autoscaler/v2/tests/test_sdk.py | 2 +- .../aggregator/tests/test_aggregator_agent.py | 2 +- .../ray/dashboard/modules/node/node_head.py | 2 +- .../modules/reporter/reporter_head.py | 2 +- python/ray/tests/BUILD.bazel | 1 + python/ray/tests/authentication/__init__.py | 1 + python/ray/tests/authentication/conftest.py | 126 ++++++++++++++ .../test_async_grpc_interceptors.py | 159 ++++++++++++++++++ .../test_sync_grpc_interceptors.py | 153 +++++++++++++++++ python/ray/tests/test_memory_pressure.py | 2 +- python/ray/tests/test_state_api.py | 2 +- .../ray/tests/test_token_auth_integration.py | 19 ++- python/ray/util/client/server/proxier.py | 10 +- python/ray/util/client/server/server.py | 24 ++- python/ray/util/client/worker.py | 53 +++--- python/ray/util/state/state_manager.py | 4 +- 23 files changed, 787 insertions(+), 152 deletions(-) create mode 100644 python/ray/_private/grpc_utils.py create mode 100644 python/ray/tests/authentication/__init__.py create mode 100644 python/ray/tests/authentication/conftest.py create mode 100644 python/ray/tests/authentication/test_async_grpc_interceptors.py create mode 100644 python/ray/tests/authentication/test_sync_grpc_interceptors.py diff --git a/python/ray/_private/authentication/grpc_authentication_client_interceptor.py b/python/ray/_private/authentication/grpc_authentication_client_interceptor.py index 9a0800a90a56..01fdd48d7732 100644 --- a/python/ray/_private/authentication/grpc_authentication_client_interceptor.py +++ b/python/ray/_private/authentication/grpc_authentication_client_interceptor.py @@ -36,7 +36,7 @@ def _get_authentication_metadata_tuple() -> Tuple[Tuple[str, str], ...]: return tuple((k, v) for k, v in headers.items()) -class AuthenticationMetadataClientInterceptor( +class SyncAuthenticationMetadataClientInterceptor( grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor, grpc.StreamUnaryClientInterceptor, diff --git a/python/ray/_private/authentication/grpc_authentication_server_interceptor.py b/python/ray/_private/authentication/grpc_authentication_server_interceptor.py index 18bc4103d303..e6da22bb7019 100644 --- a/python/ray/_private/authentication/grpc_authentication_server_interceptor.py +++ b/python/ray/_private/authentication/grpc_authentication_server_interceptor.py @@ -17,6 +17,33 @@ logger = logging.getLogger(__name__) +def _authenticate_request(metadata: tuple) -> bool: + """Authenticate incoming request. currently only supports token authentication. + + Args: + metadata: gRPC metadata tuple of (key, value) pairs + Returns: + True if authentication succeeds or is not required, False otherwise + """ + if not is_token_auth_enabled(): + return True + + # Extract authorization header from metadata + auth_header = None + for key, value in metadata: + if key.lower() == AUTHORIZATION_HEADER_NAME: + auth_header = value + break + + if not auth_header: + logger.warning("Authentication required but no authorization header provided") + return False + + # Validate the token format and value + # validate_request_token returns bool (True if valid, False otherwise) + return validate_request_token(auth_header) + + class AsyncAuthenticationServerInterceptor(aiogrpc.ServerInterceptor): """Async gRPC server interceptor that validates authentication tokens. @@ -25,36 +52,6 @@ class AsyncAuthenticationServerInterceptor(aiogrpc.ServerInterceptor): If the token is missing or invalid, the request is rejected with UNAUTHENTICATED status. """ - def _validate_authentication(self, metadata: tuple) -> bool: - """Validate authentication token from gRPC metadata. - - Args: - metadata: gRPC metadata tuple of (key, value) pairs - - Returns: - True if authentication succeeds or is not required, False otherwise - """ - # If token auth is not enabled, allow all requests - if not is_token_auth_enabled(): - return True - - # Extract authorization header from metadata - auth_header = None - for key, value in metadata: - if key.lower() == AUTHORIZATION_HEADER_NAME: - auth_header = value - break - - if not auth_header: - logger.warning( - "Authentication required but no authorization header provided" - ) - return False - - # Validate the token format and value - # validate_request_token returns bool (True if valid, False otherwise) - return validate_request_token(auth_header) - async def intercept_service( self, continuation: Callable[ @@ -80,7 +77,7 @@ def wrap_rpc_behavior(behavior): return None async def wrapped(request_or_iterator, context): - if not self._validate_authentication(context.invocation_metadata()): + if not _authenticate_request(context.invocation_metadata()): await context.abort( grpc.StatusCode.UNAUTHENTICATED, "Invalid or missing authentication token", @@ -130,3 +127,86 @@ def stream_stream(self): return self._wrap(self._original.stream_stream) return AuthenticatedHandler(handler, wrap_rpc_behavior) + + +class SyncAuthenticationServerInterceptor(grpc.ServerInterceptor): + """Synchronous gRPC server interceptor that validates authentication tokens. + + This interceptor checks the "authorization" metadata header for a valid + Bearer token when token authentication is enabled via RAY_AUTH_MODE=token. + If the token is missing or invalid, the request is rejected with UNAUTHENTICATED status. + """ + + def intercept_service( + self, + continuation: Callable[[grpc.HandlerCallDetails], grpc.RpcMethodHandler], + handler_call_details: grpc.HandlerCallDetails, + ) -> grpc.RpcMethodHandler: + """Intercept service calls to validate authentication. + + This method is called once per RPC to get the handler. We wrap the handler + to validate authentication before executing the actual RPC method. + """ + # Get the actual handler + handler = continuation(handler_call_details) + + if handler is None: + return None + + # Wrap the RPC behavior with authentication check + def wrap_rpc_behavior(behavior): + """Wrap an RPC method to validate authentication first.""" + if behavior is None: + return None + + def wrapped(request_or_iterator, context): + if not _authenticate_request(context.invocation_metadata()): + context.abort( + grpc.StatusCode.UNAUTHENTICATED, + "Invalid or missing authentication token", + ) + return behavior(request_or_iterator, context) + + return wrapped + + # Create a wrapper class that implements RpcMethodHandler interface + class AuthenticatedHandler: + """Wrapper handler that validates authentication.""" + + def __init__(self, original_handler, wrapper_func): + self._original = original_handler + self._wrap = wrapper_func + + @property + def request_streaming(self): + return self._original.request_streaming + + @property + def response_streaming(self): + return self._original.response_streaming + + @property + def request_deserializer(self): + return self._original.request_deserializer + + @property + def response_serializer(self): + return self._original.response_serializer + + @property + def unary_unary(self): + return self._wrap(self._original.unary_unary) + + @property + def unary_stream(self): + return self._wrap(self._original.unary_stream) + + @property + def stream_unary(self): + return self._wrap(self._original.stream_unary) + + @property + def stream_stream(self): + return self._wrap(self._original.stream_stream) + + return AuthenticatedHandler(handler, wrap_rpc_behavior) diff --git a/python/ray/_private/gcs_utils.py b/python/ray/_private/gcs_utils.py index 5678d681794d..4590e7370a68 100644 --- a/python/ray/_private/gcs_utils.py +++ b/python/ray/_private/gcs_utils.py @@ -79,7 +79,7 @@ def create_gcs_channel(address: str, aio=False): Returns: grpc.Channel or grpc.aio.Channel to GCS """ - from ray._private.utils import init_grpc_channel + from ray._private.grpc_utils import init_grpc_channel return init_grpc_channel(address, options=_GRPC_OPTIONS, asynchronous=aio) diff --git a/python/ray/_private/grpc_utils.py b/python/ray/_private/grpc_utils.py new file mode 100644 index 000000000000..9d22b60a7a37 --- /dev/null +++ b/python/ray/_private/grpc_utils.py @@ -0,0 +1,157 @@ +"""Utilities for creating gRPC channels and servers with authentication support.""" + +import os +from concurrent import futures +from typing import Any, Optional, Sequence, Tuple + +import grpc +from grpc import aio as aiogrpc +from ray._private.authentication import authentication_utils +from ray._private.tls_utils import load_certs_from_env + +import ray + + +def init_grpc_channel( + address: str, + options: Optional[Sequence[Tuple[str, Any]]] = None, + asynchronous: bool = False, + credentials: Optional[grpc.ChannelCredentials] = None, +): + """Create a gRPC channel with authentication interceptors if token auth is enabled. + + This function handles: + - TLS configuration via RAY_USE_TLS environment variable or custom credentials + - Authentication interceptors when token auth is enabled + - Keepalive settings from Ray config + - Both synchronous and asynchronous channels + + Args: + address: The gRPC server address (host:port) + options: Optional gRPC channel options as sequence of (key, value) tuples + asynchronous: If True, create async channel; otherwise sync + credentials: Optional custom gRPC credentials for TLS. If provided, takes + precedence over RAY_USE_TLS environment variable. + + Returns: + grpc.Channel or grpc.aio.Channel: Configured gRPC channel with interceptors + """ + grpc_module = aiogrpc if asynchronous else grpc + + options = options or [] + options_dict = dict(options) + options_dict["grpc.keepalive_time_ms"] = options_dict.get( + "grpc.keepalive_time_ms", ray._config.grpc_client_keepalive_time_ms() + ) + options_dict["grpc.keepalive_timeout_ms"] = options_dict.get( + "grpc.keepalive_timeout_ms", ray._config.grpc_client_keepalive_timeout_ms() + ) + options = options_dict.items() + + # Build interceptors list + interceptors = [] + if authentication_utils.is_token_auth_enabled(): + from ray._private.authentication.grpc_authentication_client_interceptor import ( + AsyncAuthenticationMetadataClientInterceptor, + SyncAuthenticationMetadataClientInterceptor, + ) + + if asynchronous: + interceptors.append(AsyncAuthenticationMetadataClientInterceptor()) + else: + interceptors.append(SyncAuthenticationMetadataClientInterceptor()) + + # Determine channel type and credentials + if credentials is not None: + # Use provided custom credentials (takes precedence) + channel_creator = grpc_module.secure_channel + base_args = (address, credentials) + elif os.environ.get("RAY_USE_TLS", "0").lower() in ("1", "true"): + # Use TLS from environment variables + server_cert_chain, private_key, ca_cert = load_certs_from_env() + tls_credentials = grpc.ssl_channel_credentials( + certificate_chain=server_cert_chain, + private_key=private_key, + root_certificates=ca_cert, + ) + channel_creator = grpc_module.secure_channel + base_args = (address, tls_credentials) + else: + # Insecure channel + channel_creator = grpc_module.insecure_channel + base_args = (address,) + + # Create channel (async channels get interceptors in constructor, sync via intercept_channel) + if asynchronous: + channel = channel_creator( + *base_args, options=options, interceptors=interceptors + ) + else: + channel = channel_creator(*base_args, options=options) + if interceptors: + channel = grpc.intercept_channel(channel, *interceptors) + + return channel + + +def create_grpc_server_with_interceptors( + max_workers: Optional[int] = None, + thread_name_prefix: str = "grpc_server", + options: Optional[Sequence[Tuple[str, Any]]] = None, + asynchronous: bool = False, +): + """Create a gRPC server with authentication interceptors if token auth is enabled. + + This function handles: + - Authentication interceptors when token auth is enabled + - Both synchronous and asynchronous servers + - Thread pool configuration for sync servers + + Args: + max_workers: Max thread pool workers (required for sync, ignored for async) + thread_name_prefix: Thread name prefix for sync thread pool + options: Optional gRPC server options as sequence of (key, value) tuples + asynchronous: If True, create async server; otherwise sync + + Returns: + grpc.Server or grpc.aio.Server: Configured gRPC server with interceptors + """ + grpc_module = aiogrpc if asynchronous else grpc + + # Build interceptors list + interceptors = [] + if authentication_utils.is_token_auth_enabled(): + if asynchronous: + from ray._private.authentication.grpc_authentication_server_interceptor import ( + AsyncAuthenticationServerInterceptor, + ) + + interceptors.append(AsyncAuthenticationServerInterceptor()) + else: + from ray._private.authentication.grpc_authentication_server_interceptor import ( + SyncAuthenticationServerInterceptor, + ) + + interceptors.append(SyncAuthenticationServerInterceptor()) + + # Create server + if asynchronous: + server = grpc_module.server( + interceptors=interceptors if interceptors else None, + options=options, + ) + else: + if max_workers is None: + raise ValueError("max_workers is required for synchronous gRPC servers") + + executor = futures.ThreadPoolExecutor( + max_workers=max_workers, + thread_name_prefix=thread_name_prefix, + ) + server = grpc_module.server( + executor, + interceptors=interceptors if interceptors else None, + options=options, + ) + + return server diff --git a/python/ray/_private/internal_api.py b/python/ray/_private/internal_api.py index a461a09360bb..704a59cfa80f 100644 --- a/python/ray/_private/internal_api.py +++ b/python/ray/_private/internal_api.py @@ -58,6 +58,7 @@ def get_memory_info_reply(state, node_manager_address=None, node_manager_port=No """Returns global memory info.""" from ray.core.generated import node_manager_pb2, node_manager_pb2_grpc + from ray._private.grpc_utils import init_grpc_channel # We can ask any Raylet for the global memory info, that Raylet internally # asks all nodes in the cluster for memory stats. @@ -75,7 +76,7 @@ def get_memory_info_reply(state, node_manager_address=None, node_manager_port=No else: raylet_address = build_address(node_manager_address, node_manager_port) - channel = utils.init_grpc_channel( + channel = init_grpc_channel( raylet_address, options=[ ("grpc.max_send_message_length", MAX_MESSAGE_LENGTH), @@ -97,11 +98,12 @@ def node_stats( """Returns NodeStats object describing memory usage in the cluster.""" from ray.core.generated import node_manager_pb2, node_manager_pb2_grpc + from ray._private.grpc_utils import init_grpc_channel # We can ask any Raylet for the global memory info. assert node_manager_address is not None and node_manager_port is not None raylet_address = build_address(node_manager_address, node_manager_port) - channel = utils.init_grpc_channel( + channel = init_grpc_channel( raylet_address, options=[ ("grpc.max_send_message_length", MAX_MESSAGE_LENGTH), diff --git a/python/ray/_private/test_utils.py b/python/ray/_private/test_utils.py index 51508be58357..cd23c8d3e333 100644 --- a/python/ray/_private/test_utils.py +++ b/python/ray/_private/test_utils.py @@ -1867,7 +1867,7 @@ def get_node_stats(raylet, num_retry=5, timeout=2): raylet_address = build_address( raylet["NodeManagerAddress"], raylet["NodeManagerPort"] ) - channel = ray._private.utils.init_grpc_channel(raylet_address) + channel = ray._private.grpc_utils.init_grpc_channel(raylet_address) stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel) for _ in range(num_retry): try: @@ -1888,7 +1888,7 @@ def get_resource_usage(gcs_address, timeout=10): if not gcs_address: gcs_address = ray.worker._global_node.gcs_address - gcs_channel = ray._private.utils.init_grpc_channel( + gcs_channel = ray._private.grpc_utils.init_grpc_channel( gcs_address, ray_constants.GLOBAL_GRPC_OPTIONS, asynchronous=False ) diff --git a/python/ray/_private/utils.py b/python/ray/_private/utils.py index 5bda84aad204..4d1bfe7c1f52 100644 --- a/python/ray/_private/utils.py +++ b/python/ray/_private/utils.py @@ -1018,70 +1018,6 @@ def validate_namespace(namespace: str): ) -def init_grpc_channel( - address: str, - options: Optional[Sequence[Tuple[str, Any]]] = None, - asynchronous: bool = False, -): - import grpc - from grpc import aio as aiogrpc - - from ray._private.authentication import authentication_utils - from ray._private.tls_utils import load_certs_from_env - - grpc_module = aiogrpc if asynchronous else grpc - - options = options or [] - options_dict = dict(options) - options_dict["grpc.keepalive_time_ms"] = options_dict.get( - "grpc.keepalive_time_ms", ray._config.grpc_client_keepalive_time_ms() - ) - options_dict["grpc.keepalive_timeout_ms"] = options_dict.get( - "grpc.keepalive_timeout_ms", ray._config.grpc_client_keepalive_timeout_ms() - ) - options = options_dict.items() - - # Build interceptors list - interceptors = [] - if authentication_utils.is_token_auth_enabled(): - from ray._private.authentication.grpc_authentication_client_interceptor import ( - AsyncAuthenticationMetadataClientInterceptor, - AuthenticationMetadataClientInterceptor, - ) - - if asynchronous: - interceptors.append(AsyncAuthenticationMetadataClientInterceptor()) - else: - interceptors.append(AuthenticationMetadataClientInterceptor()) - - # Create channel with TLS if enabled - use_tls = os.environ.get("RAY_USE_TLS", "0").lower() in ("1", "true") - if use_tls: - server_cert_chain, private_key, ca_cert = load_certs_from_env() - credentials = grpc.ssl_channel_credentials( - certificate_chain=server_cert_chain, - private_key=private_key, - root_certificates=ca_cert, - ) - channel_creator = grpc_module.secure_channel - base_args = (address, credentials) - else: - channel_creator = grpc_module.insecure_channel - base_args = (address,) - - # Create channel (async channels get interceptors in constructor, sync via intercept_channel) - if asynchronous: - channel = channel_creator( - *base_args, options=options, interceptors=interceptors - ) - else: - channel = channel_creator(*base_args, options=options) - if interceptors: - channel = grpc.intercept_channel(channel, *interceptors) - - return channel - - def get_dashboard_dependency_error() -> Optional[ImportError]: """Returns the exception error if Ray Dashboard dependencies are not installed. None if they are installed. diff --git a/python/ray/autoscaler/v2/tests/test_sdk.py b/python/ray/autoscaler/v2/tests/test_sdk.py index 17c390e625d0..648e007706c8 100644 --- a/python/ray/autoscaler/v2/tests/test_sdk.py +++ b/python/ray/autoscaler/v2/tests/test_sdk.py @@ -37,7 +37,7 @@ def _autoscaler_state_service_stub(): """Get the grpc stub for the autoscaler state service""" gcs_address = ray.get_runtime_context().gcs_address - gcs_channel = ray._private.utils.init_grpc_channel( + gcs_channel = ray._private.grpc_utils.init_grpc_channel( gcs_address, ray_constants.GLOBAL_GRPC_OPTIONS ) return autoscaler_pb2_grpc.AutoscalerStateServiceStub(gcs_channel) diff --git a/python/ray/dashboard/modules/aggregator/tests/test_aggregator_agent.py b/python/ray/dashboard/modules/aggregator/tests/test_aggregator_agent.py index 499fc4e3b2e4..f28b59be584c 100644 --- a/python/ray/dashboard/modules/aggregator/tests/test_aggregator_agent.py +++ b/python/ray/dashboard/modules/aggregator/tests/test_aggregator_agent.py @@ -11,7 +11,7 @@ from ray._common.network_utils import find_free_port from ray._private import ray_constants from ray._private.test_utils import wait_for_condition -from ray._private.utils import init_grpc_channel +from ray._private.grpc_utils import init_grpc_channel from ray._raylet import GcsClient from ray.core.generated.common_pb2 import ( ErrorType, diff --git a/python/ray/dashboard/modules/node/node_head.py b/python/ray/dashboard/modules/node/node_head.py index d1d8ae775ab3..94a1ba87a0a4 100644 --- a/python/ray/dashboard/modules/node/node_head.py +++ b/python/ray/dashboard/modules/node/node_head.py @@ -290,7 +290,7 @@ async def _update_node(self, node: dict): node["nodeManagerAddress"], int(node["nodeManagerPort"]) ) options = ray_constants.GLOBAL_GRPC_OPTIONS - channel = ray._private.utils.init_grpc_channel( + channel = ray._private.grpc_utils.init_grpc_channel( address, options, asynchronous=True ) stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel) diff --git a/python/ray/dashboard/modules/reporter/reporter_head.py b/python/ray/dashboard/modules/reporter/reporter_head.py index 8971991c4dea..2f213b04b071 100644 --- a/python/ray/dashboard/modules/reporter/reporter_head.py +++ b/python/ray/dashboard/modules/reporter/reporter_head.py @@ -24,7 +24,7 @@ KV_NAMESPACE_DASHBOARD, env_integer, ) -from ray._private.utils import init_grpc_channel +from ray._private.grpc_utils import init_grpc_channel from ray.autoscaler._private.commands import debug_status from ray.core.generated import reporter_pb2, reporter_pb2_grpc from ray.dashboard.consts import GCS_RPC_TIMEOUT_SECONDS diff --git a/python/ray/tests/BUILD.bazel b/python/ray/tests/BUILD.bazel index 9275fa8dd520..7f41eeafb6c9 100644 --- a/python/ray/tests/BUILD.bazel +++ b/python/ray/tests/BUILD.bazel @@ -227,6 +227,7 @@ py_test_module_list( "test_placement_group_5.py", "test_scheduling.py", "test_scheduling_2.py", + "test_token_auth_integration.py", "test_wait.py", ], name_suffix = "_client_mode", diff --git a/python/ray/tests/authentication/__init__.py b/python/ray/tests/authentication/__init__.py new file mode 100644 index 000000000000..ee3ad837bc34 --- /dev/null +++ b/python/ray/tests/authentication/__init__.py @@ -0,0 +1 @@ +"""Tests for gRPC authentication interceptors.""" diff --git a/python/ray/tests/authentication/conftest.py b/python/ray/tests/authentication/conftest.py new file mode 100644 index 000000000000..0d97bb941965 --- /dev/null +++ b/python/ray/tests/authentication/conftest.py @@ -0,0 +1,126 @@ +"""Shared fixtures for gRPC authentication interceptor integration tests.""" + +import uuid + +import grpc +import pytest +from grpc import aio as aiogrpc + +from ray._private.grpc_utils import create_grpc_server_with_interceptors +from ray.core.generated import reporter_pb2, reporter_pb2_grpc +from ray.tests.authentication_test_utils import ( + authentication_env_guard, + set_auth_mode, + set_env_auth_token, + reset_auth_token_state, +) + + +class SyncReporterService(reporter_pb2_grpc.ReporterServiceServicer): + """Simple synchronous test service for testing auth interceptors.""" + + def HealthCheck(self, request, context): + """Simple health check endpoint.""" + return reporter_pb2.HealthCheckReply() + + +class AsyncReporterService(reporter_pb2_grpc.ReporterServiceServicer): + """Simple asynchronous test service for testing auth interceptors.""" + + async def HealthCheck(self, request, context): + """Simple health check endpoint (async version).""" + return reporter_pb2.HealthCheckReply() + + +def _create_test_server_base( + *, + asynchronous: bool, + with_auth: bool, + servicer_cls, +): + """Internal helper to create sync or async test server with optional auth.""" + + if with_auth: + # Auth is enabled - server will use interceptor + server = create_grpc_server_with_interceptors( + max_workers=None if asynchronous else 10, + thread_name_prefix="test_server", + options=None, + asynchronous=asynchronous, + ) + else: + # Auth is disabled - create server without helper (no interceptor) + if asynchronous: + server = aiogrpc.server(options=None) + else: + from concurrent import futures + + server = grpc.server( + futures.ThreadPoolExecutor(max_workers=10), + options=None, + ) + + # Add test service + servicer = servicer_cls() + reporter_pb2_grpc.add_ReporterServiceServicer_to_server(servicer, server) + + # Bind to ephemeral port + port = server.add_insecure_port("[::]:0") + + return server, port + +@pytest.fixture +def create_sync_test_server(): + """Factory to create synchronous gRPC test server. + + Returns a function that creates a test server and returns (server, port). + The server must be stopped by the caller. + """ + + def _create(with_auth=True): + server, port = _create_test_server_base( + asynchronous=False, + with_auth=with_auth, + servicer_cls=SyncReporterService, + ) + server.start() + return server, port + + return _create + + +@pytest.fixture +def create_async_test_server(): + """Factory to create asynchronous gRPC test server. + + Returns an async function that creates a test server and returns (server, port). + The server must be stopped by the caller. + """ + + async def _create(with_auth=True): + server, port = _create_test_server_base( + asynchronous=True, + with_auth=with_auth, + servicer_cls=AsyncReporterService, + ) + await server.start() + return server, port + + return _create + + + +@pytest.fixture +def test_token(): + """Generate a test authentication token.""" + return uuid.uuid4().hex + + +@pytest.fixture +def setup_auth_environment(test_token): + """Set up authentication environment with test token.""" + with authentication_env_guard(): + set_auth_mode("token") + set_env_auth_token(test_token) + reset_auth_token_state() + yield test_token diff --git a/python/ray/tests/authentication/test_async_grpc_interceptors.py b/python/ray/tests/authentication/test_async_grpc_interceptors.py new file mode 100644 index 000000000000..95d222a6dcd3 --- /dev/null +++ b/python/ray/tests/authentication/test_async_grpc_interceptors.py @@ -0,0 +1,159 @@ +"""Integration tests for asynchronous gRPC authentication interceptors.""" + +import uuid + +import grpc +import pytest +from grpc import aio as aiogrpc + +from ray._private.grpc_utils import init_grpc_channel +from ray.core.generated import reporter_pb2, reporter_pb2_grpc +from ray.tests.authentication_test_utils import ( + authentication_env_guard, + set_auth_mode, + set_env_auth_token, + reset_auth_token_state, +) + + +@pytest.mark.asyncio +async def test_async_server_and_client_with_valid_token(create_async_test_server): + """Test async server + client with matching token succeeds.""" + token = uuid.uuid4().hex + + with authentication_env_guard(): + set_auth_mode("token") + set_env_auth_token(token) + reset_auth_token_state() + + # Create server with auth enabled + server, port = await create_async_test_server(with_auth=True) + + try: + # Client with auth interceptor via init_grpc_channel + channel = init_grpc_channel( + f"localhost:{port}", + options=None, + asynchronous=True, + ) + stub = reporter_pb2_grpc.ReporterServiceStub(channel) + request = reporter_pb2.HealthCheckRequest() + response = await stub.HealthCheck(request, timeout=5) + assert response is not None + finally: + await server.stop(grace=1) + + +@pytest.mark.asyncio +async def test_async_server_and_client_with_invalid_token(create_async_test_server): + """Test async server + client with mismatched token fails.""" + server_token = uuid.uuid4().hex + wrong_token = uuid.uuid4().hex + + with authentication_env_guard(): + # Set up server with server_token + set_auth_mode("token") + set_env_auth_token(server_token) + reset_auth_token_state() + + server, port = await create_async_test_server(with_auth=True) + + try: + # Create client channel and manually add wrong token to metadata + channel = aiogrpc.insecure_channel(f"localhost:{port}") + stub = reporter_pb2_grpc.ReporterServiceStub(channel) + + # Add invalid token to metadata (not using client interceptor) + metadata = (("authorization", f"Bearer {wrong_token}"),) + request = reporter_pb2.HealthCheckRequest() + + # Should fail with UNAUTHENTICATED + with pytest.raises(grpc.RpcError) as exc_info: + await stub.HealthCheck(request, metadata=metadata, timeout=5) + + assert exc_info.value.code() == grpc.StatusCode.UNAUTHENTICATED + finally: + await server.stop(grace=1) + + +@pytest.mark.asyncio +async def test_async_server_with_auth_client_without_token(create_async_test_server): + """Test async server with auth, client without token fails.""" + token = uuid.uuid4().hex + + with authentication_env_guard(): + # Set up server with auth enabled + set_auth_mode("token") + set_env_auth_token(token) + reset_auth_token_state() + + server, port = await create_async_test_server(with_auth=True) + + try: + # Create channel without auth metadata + channel = aiogrpc.insecure_channel(f"localhost:{port}") + stub = reporter_pb2_grpc.ReporterServiceStub(channel) + request = reporter_pb2.HealthCheckRequest() + + # Should fail with UNAUTHENTICATED (no metadata provided) + with pytest.raises(grpc.RpcError) as exc_info: + await stub.HealthCheck(request, timeout=5) + + assert exc_info.value.code() == grpc.StatusCode.UNAUTHENTICATED + finally: + await server.stop(grace=1) + + +@pytest.mark.asyncio +async def test_async_server_without_auth(create_async_test_server): + """Test async server without auth allows unauthenticated requests.""" + with authentication_env_guard(): + # Disable auth mode + set_auth_mode("disabled") + reset_auth_token_state() + + # Create server without auth + server, port = await create_async_test_server(with_auth=False) + + try: + # Client without auth + channel = aiogrpc.insecure_channel(f"localhost:{port}") + stub = reporter_pb2_grpc.ReporterServiceStub(channel) + request = reporter_pb2.HealthCheckRequest() + + # Should succeed without auth + response = await stub.HealthCheck(request, timeout=5) + assert response is not None + finally: + await server.stop(grace=1) + + +@pytest.mark.asyncio +async def test_async_server_with_auth_disabled_allows_all(create_async_test_server): + """Test async server allows requests when auth mode is disabled.""" + with authentication_env_guard(): + # Disable auth mode globally + set_auth_mode("disabled") + reset_auth_token_state() + + # Even though we call create_async_test_server with with_auth=True, + # the server won't enforce auth because auth mode is disabled + server, port = await create_async_test_server(with_auth=True) + + try: + # Client without token + channel = aiogrpc.insecure_channel(f"localhost:{port}") + stub = reporter_pb2_grpc.ReporterServiceStub(channel) + request = reporter_pb2.HealthCheckRequest() + + # Should succeed because auth is disabled + response = await stub.HealthCheck(request, timeout=5) + assert response is not None + finally: + await server.stop(grace=1) + + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main(["-vv", __file__])) diff --git a/python/ray/tests/authentication/test_sync_grpc_interceptors.py b/python/ray/tests/authentication/test_sync_grpc_interceptors.py new file mode 100644 index 000000000000..c3ddff2e706b --- /dev/null +++ b/python/ray/tests/authentication/test_sync_grpc_interceptors.py @@ -0,0 +1,153 @@ +"""Integration tests for synchronous gRPC authentication interceptors.""" + +import uuid + +import grpc +import pytest + +from ray._private.grpc_utils import init_grpc_channel +from ray.core.generated import reporter_pb2, reporter_pb2_grpc +from ray.tests.authentication_test_utils import ( + authentication_env_guard, + set_auth_mode, + set_env_auth_token, + reset_auth_token_state, +) + + +def test_sync_server_and_client_with_valid_token(create_sync_test_server): + """Test sync server + client with matching token succeeds.""" + token = uuid.uuid4().hex + + with authentication_env_guard(): + set_auth_mode("token") + set_env_auth_token(token) + reset_auth_token_state() + + # Create server with auth enabled + server, port = create_sync_test_server(with_auth=True) + + try: + # Client with auth interceptor via init_grpc_channel + channel = init_grpc_channel( + f"localhost:{port}", + options=None, + asynchronous=False, + ) + stub = reporter_pb2_grpc.ReporterServiceStub(channel) + request = reporter_pb2.HealthCheckRequest() + response = stub.HealthCheck(request, timeout=5) + assert response is not None + finally: + server.stop(grace=1) + + +def test_sync_server_and_client_with_invalid_token(create_sync_test_server): + """Test sync server + client with mismatched token fails.""" + server_token = uuid.uuid4().hex + wrong_token = uuid.uuid4().hex + + with authentication_env_guard(): + # Set up server with server_token + set_auth_mode("token") + set_env_auth_token(server_token) + reset_auth_token_state() + + server, port = create_sync_test_server(with_auth=True) + + try: + # Create client channel and manually add wrong token to metadata + channel = grpc.insecure_channel(f"localhost:{port}") + stub = reporter_pb2_grpc.ReporterServiceStub(channel) + + # Add invalid token to metadata (not using client interceptor) + metadata = (("authorization", f"Bearer {wrong_token}"),) + request = reporter_pb2.HealthCheckRequest() + + # Should fail with UNAUTHENTICATED + with pytest.raises(grpc.RpcError) as exc_info: + stub.HealthCheck(request, metadata=metadata, timeout=5) + + assert exc_info.value.code() == grpc.StatusCode.UNAUTHENTICATED + finally: + server.stop(grace=1) + + +def test_sync_server_with_auth_client_without_token(create_sync_test_server): + """Test server with auth, client without token fails.""" + token = uuid.uuid4().hex + + with authentication_env_guard(): + # Set up server with auth enabled + set_auth_mode("token") + set_env_auth_token(token) + reset_auth_token_state() + + server, port = create_sync_test_server(with_auth=True) + + try: + # Create channel without auth metadata + channel = grpc.insecure_channel(f"localhost:{port}") + stub = reporter_pb2_grpc.ReporterServiceStub(channel) + request = reporter_pb2.HealthCheckRequest() + + # Should fail with UNAUTHENTICATED (no metadata provided) + with pytest.raises(grpc.RpcError) as exc_info: + stub.HealthCheck(request, timeout=5) + + assert exc_info.value.code() == grpc.StatusCode.UNAUTHENTICATED + finally: + server.stop(grace=1) + + +def test_sync_server_without_auth(create_sync_test_server): + """Test server without auth allows unauthenticated requests.""" + with authentication_env_guard(): + # Disable auth mode + set_auth_mode("disabled") + reset_auth_token_state() + + # Create server without auth + server, port = create_sync_test_server(with_auth=False) + + try: + # Client without auth + channel = grpc.insecure_channel(f"localhost:{port}") + stub = reporter_pb2_grpc.ReporterServiceStub(channel) + request = reporter_pb2.HealthCheckRequest() + + # Should succeed without auth + response = stub.HealthCheck(request, timeout=5) + assert response is not None + finally: + server.stop(grace=1) + + +def test_sync_server_with_auth_disabled_allows_all(create_sync_test_server): + """Test server allows requests when auth mode is disabled.""" + with authentication_env_guard(): + # Disable auth mode globally + set_auth_mode("disabled") + reset_auth_token_state() + + # Even though we call create_sync_test_server with with_auth=True, + # the server won't enforce auth because auth mode is disabled + server, port = create_sync_test_server(with_auth=True) + + try: + # Client without token + channel = grpc.insecure_channel(f"localhost:{port}") + stub = reporter_pb2_grpc.ReporterServiceStub(channel) + request = reporter_pb2.HealthCheckRequest() + + # Should succeed because auth is disabled + response = stub.HealthCheck(request, timeout=5) + assert response is not None + finally: + server.stop(grace=1) + + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main(["-vv", __file__])) diff --git a/python/ray/tests/test_memory_pressure.py b/python/ray/tests/test_memory_pressure.py index a3bd0213482d..6833522cd2d4 100644 --- a/python/ray/tests/test_memory_pressure.py +++ b/python/ray/tests/test_memory_pressure.py @@ -25,7 +25,7 @@ def get_local_state_client(): - gcs_channel = ray._private.utils.init_grpc_channel( + gcs_channel = ray._private.grpc_utils.init_grpc_channel( ray.worker._global_node.gcs_address, ray_constants.GLOBAL_GRPC_OPTIONS, asynchronous=True, diff --git a/python/ray/tests/test_state_api.py b/python/ray/tests/test_state_api.py index b65b187b5e9d..ded8d63cec6e 100644 --- a/python/ray/tests/test_state_api.py +++ b/python/ray/tests/test_state_api.py @@ -144,7 +144,7 @@ def state_source_client(gcs_address): ("grpc.max_send_message_length", ray_constants.GRPC_CPP_MAX_MESSAGE_SIZE), ("grpc.max_receive_message_length", ray_constants.GRPC_CPP_MAX_MESSAGE_SIZE), ) - gcs_channel = ray._private.utils.init_grpc_channel( + gcs_channel = ray._private.grpc_utils.init_grpc_channel( gcs_address, GRPC_CHANNEL_OPTIONS, asynchronous=True ) gcs_client = GcsClient(address=gcs_address) diff --git a/python/ray/tests/test_token_auth_integration.py b/python/ray/tests/test_token_auth_integration.py index 5fa5f99ccc6d..2ca56d97d50d 100644 --- a/python/ray/tests/test_token_auth_integration.py +++ b/python/ray/tests/test_token_auth_integration.py @@ -9,7 +9,7 @@ import pytest import ray -from ray._private.test_utils import wait_for_condition +from ray._private.test_utils import wait_for_condition, client_test_enabled try: from ray._raylet import AuthenticationTokenLoader @@ -117,7 +117,10 @@ def clean_token_sources(cleanup_auth_token_env): reset_auth_token_state() - +@pytest.mark.skipif( + client_test_enabled(), + reason="This test is for starting a new local cluster, not compatible with client mode" +) def test_local_cluster_generates_token(): """Test ray.init() generates token for local cluster when auth_mode=token is set.""" # Ensure no token exists @@ -217,6 +220,10 @@ def test_func(): ray.shutdown() +@pytest.mark.skipif( + client_test_enabled(), + reason="Uses subprocess ray start, not compatible with client mode" +) @pytest.mark.parametrize("is_head", [True, False]) def test_ray_start_without_token_raises_error(is_head, request): """Test that ray start fails when auth_mode=token but no token exists.""" @@ -247,6 +254,10 @@ def test_ray_start_without_token_raises_error(is_head, request): _run_ray_start_and_verify_status(args, env, expect_success=False) +@pytest.mark.skipif( + client_test_enabled(), + reason="Uses subprocess ray start, not compatible with client mode" +) def test_ray_start_head_with_token_succeeds(): """Test that ray start --head succeeds when token auth is enabled with a valid token.""" # Set up environment with token auth and a valid token @@ -290,6 +301,10 @@ def test_func(): _cleanup_ray_start(env) +@pytest.mark.skipif( + client_test_enabled(), + reason="Uses subprocess ray start, not compatible with client mode" +) @pytest.mark.parametrize("token_match", ["correct", "incorrect"]) def test_ray_start_address_with_token(token_match, setup_cluster_with_token_auth): """Test ray start --address=... with correct or incorrect token.""" diff --git a/python/ray/util/client/server/proxier.py b/python/ray/util/client/server/proxier.py index 4abab76a825d..80070a5b235e 100644 --- a/python/ray/util/client/server/proxier.py +++ b/python/ray/util/client/server/proxier.py @@ -210,7 +210,7 @@ def create_specific_server(self, client_id: str) -> SpecificServer: server = SpecificServer( port=port, process_handle_future=futures.Future(), - channel=ray._private.utils.init_grpc_channel( + channel=ray._private.grpc_utils.init_grpc_channel( build_address(host, port), options=GRPC_OPTIONS ), ) @@ -874,9 +874,13 @@ def serve_proxier( gcs_cli = GcsClient(address=gcs_address) ray.experimental.internal_kv._initialize_internal_kv(gcs_cli) - server = grpc.server( - futures.ThreadPoolExecutor(max_workers=CLIENT_SERVER_MAX_THREADS), + from ray._private.grpc_utils import create_grpc_server_with_interceptors + + server = create_grpc_server_with_interceptors( + max_workers=CLIENT_SERVER_MAX_THREADS, + thread_name_prefix="ray_client_proxier", options=GRPC_OPTIONS, + asynchronous=False, ) proxy_manager = ProxyManager( gcs_address, diff --git a/python/ray/util/client/server/server.py b/python/ray/util/client/server/server.py index c4e5d897e09d..c59ca1720792 100644 --- a/python/ray/util/client/server/server.py +++ b/python/ray/util/client/server/server.py @@ -10,7 +10,6 @@ import threading import time from collections import defaultdict -from concurrent import futures from typing import Any, Callable, Dict, List, Optional, Set, Union import grpc @@ -135,7 +134,7 @@ def Init( logger.exception("Running Ray Init failed:") return ray_client_pb2.InitResponse( ok=False, - msg="Call to `ray.init()` on the server " f"failed with: {e}", + msg=f"Call to `ray.init()` on the server failed with: {e}", ) if job_config is None: return ray_client_pb2.InitResponse(ok=True) @@ -378,8 +377,7 @@ def Terminate(self, req, context=None): return_exception_in_context(e, context) else: raise RuntimeError( - "Client requested termination without providing a valid " - "terminate_type" + "Client requested termination without providing a valid terminate_type" ) return ray_client_pb2.TerminateResponse(ok=True) @@ -397,7 +395,7 @@ def _async_get_object( """ if len(request.ids) != 1: raise ValueError( - "Async get() must have exactly 1 Object ID. " f"Actual: {request}" + f"Async get() must have exactly 1 Object ID. Actual: {request}" ) rid = request.ids[0] ref = self.object_refs[client_id].get(rid, None) @@ -479,8 +477,7 @@ def _get_object(self, request: ray_client_pb2.GetRequest, client_id: str): valid=False, error=cloudpickle.dumps( ValueError( - f"ClientObjectRef {rid} is not found for client " - f"{client_id}" + f"ClientObjectRef {rid} is not found for client {client_id}" ) ), ) @@ -773,13 +770,14 @@ def default_connect_handler( if not ray.is_initialized(): return ray.init(job_config=job_config, **ray_init_kwargs) + from ray._private.grpc_utils import create_grpc_server_with_interceptors + ray_connect_handler = ray_connect_handler or default_connect_handler - server = grpc.server( - futures.ThreadPoolExecutor( - max_workers=CLIENT_SERVER_MAX_THREADS, - thread_name_prefix="ray_client_server", - ), + server = create_grpc_server_with_interceptors( + max_workers=CLIENT_SERVER_MAX_THREADS, + thread_name_prefix="ray_client_server", options=GRPC_OPTIONS, + asynchronous=False, ) task_servicer = RayletServicer(ray_connect_handler) data_servicer = DataServicer(task_servicer) @@ -929,7 +927,7 @@ def main(): ) except Exception as e: logger.error( - f"[{args.mode}] Failed to put health check " f"on {args.address}" + f"[{args.mode}] Failed to put health check on {args.address}" ) logger.exception(e) diff --git a/python/ray/util/client/worker.py b/python/ray/util/client/worker.py index a13b5bca8535..c9487e79308f 100644 --- a/python/ray/util/client/worker.py +++ b/python/ray/util/client/worker.py @@ -2,6 +2,7 @@ It implements the Ray API functions that are forwarded through grpc calls to the server. """ + import base64 import json import logging @@ -17,7 +18,6 @@ import grpc -import ray._private.tls_utils import ray.cloudpickle as cloudpickle import ray.core.generated.ray_client_pb2 as ray_client_pb2 import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc @@ -71,9 +71,13 @@ # Links to the Ray Design Pattern doc to use in the task overhead warning # message -DESIGN_PATTERN_FINE_GRAIN_TASKS_LINK = "https://docs.ray.io/en/latest/ray-core/patterns/too-fine-grained-tasks.html" # noqa E501 +DESIGN_PATTERN_FINE_GRAIN_TASKS_LINK = ( + "https://docs.ray.io/en/latest/ray-core/patterns/too-fine-grained-tasks.html" # noqa E501 +) -DESIGN_PATTERN_LARGE_OBJECTS_LINK = "https://docs.ray.io/en/latest/ray-core/patterns/closure-capture-large-objects.html" # noqa E501 +DESIGN_PATTERN_LARGE_OBJECTS_LINK = ( + "https://docs.ray.io/en/latest/ray-core/patterns/closure-capture-large-objects.html" # noqa E501 +) def backoff(timeout: int) -> int: @@ -173,27 +177,28 @@ def _connect_channel(self, reconnecting=False) -> None: self.channel.unsubscribe(self._on_channel_state_change) self.channel.close() + from ray._private.grpc_utils import init_grpc_channel + + # Prepare credentials if secure connection is requested + credentials = None if self._secure: if self._credentials is not None: credentials = self._credentials elif os.environ.get("RAY_USE_TLS", "0").lower() in ("1", "true"): - ( - server_cert_chain, - private_key, - ca_cert, - ) = ray._private.tls_utils.load_certs_from_env() - credentials = grpc.ssl_channel_credentials( - certificate_chain=server_cert_chain, - private_key=private_key, - root_certificates=ca_cert, - ) + # init_grpc_channel will handle this via load_certs_from_env() + credentials = None else: + # Default SSL credentials (no specific certs) credentials = grpc.ssl_channel_credentials() - self.channel = grpc.secure_channel( - self._conn_str, credentials, options=GRPC_OPTIONS - ) - else: - self.channel = grpc.insecure_channel(self._conn_str, options=GRPC_OPTIONS) + + # Create channel with auth interceptors via helper + # This automatically adds auth interceptors when token auth is enabled + self.channel = init_grpc_channel( + self._conn_str, + options=GRPC_OPTIONS, + asynchronous=False, + credentials=credentials, + ) self.channel.subscribe(self._on_channel_state_change) @@ -233,15 +238,14 @@ def _connect_channel(self, reconnecting=False) -> None: # which is why we do not sleep here. except grpc.RpcError as e: logger.debug( - "Ray client server unavailable, " f"retrying in {timeout}s..." + f"Ray client server unavailable, retrying in {timeout}s..." ) logger.debug(f"Received when checking init: {e.details()}") # Ray is not ready yet, wait a timeout. time.sleep(timeout) # Fallthrough, backoff, and retry at the top of the loop logger.debug( - "Waiting for Ray to become ready on the server, " - f"retry in {timeout}s..." + f"Waiting for Ray to become ready on the server, retry in {timeout}s..." ) if not reconnecting: # Don't increase backoff when trying to reconnect -- @@ -523,7 +527,7 @@ def wait( ) -> Tuple[List[ClientObjectRef], List[ClientObjectRef]]: if not isinstance(object_refs, list): raise TypeError( - "wait() expected a list of ClientObjectRef, " f"got {type(object_refs)}" + f"wait() expected a list of ClientObjectRef, got {type(object_refs)}" ) for ref in object_refs: if not isinstance(ref, ClientObjectRef): @@ -612,9 +616,8 @@ def populate_ids(resp: Union[ray_client_pb2.DataResponse, Exception]) -> None: self.data_client.Schedule(task, populate_ids) self.total_outbound_message_size_bytes += task.ByteSize() - if ( - self.total_outbound_message_size_bytes > MESSAGE_SIZE_THRESHOLD - and log_once("client_communication_overhead_warning") + if self.total_outbound_message_size_bytes > MESSAGE_SIZE_THRESHOLD and log_once( + "client_communication_overhead_warning" ): warnings.warn( "More than 10MB of messages have been created to schedule " diff --git a/python/ray/util/state/state_manager.py b/python/ray/util/state/state_manager.py index b22ba784e8c2..f6a3c3b2118f 100644 --- a/python/ray/util/state/state_manager.py +++ b/python/ray/util/state/state_manager.py @@ -146,7 +146,7 @@ def register_gcs_client(self, gcs_channel: grpc.aio.Channel): def get_raylet_stub(self, ip: str, port: int): options = _STATE_MANAGER_GRPC_OPTIONS - channel = ray._private.utils.init_grpc_channel( + channel = ray._private.grpc_utils.init_grpc_channel( build_address(ip, port), options, asynchronous=True ) return NodeManagerServiceStub(channel) @@ -162,7 +162,7 @@ async def get_log_service_stub(self, node_id: NodeID) -> LogServiceStub: return None ip, http_port, grpc_port = json.loads(agent_addr) options = ray_constants.GLOBAL_GRPC_OPTIONS - channel = ray._private.utils.init_grpc_channel( + channel = ray._private.grpc_utils.init_grpc_channel( build_address(ip, grpc_port), options=options, asynchronous=True ) return LogServiceStub(channel) From 4a0566b184341324d68922faac05a15e258b3e79 Mon Sep 17 00:00:00 2001 From: sampan Date: Wed, 12 Nov 2025 04:38:47 +0000 Subject: [PATCH 21/26] fix lint Signed-off-by: sampan --- python/ray/_private/grpc_utils.py | 4 ++-- python/ray/_private/internal_api.py | 5 ++--- python/ray/_private/utils.py | 2 -- .../aggregator/tests/test_aggregator_agent.py | 2 +- .../ray/dashboard/modules/reporter/reporter_head.py | 2 +- python/ray/tests/authentication/conftest.py | 4 ++-- .../authentication/test_async_grpc_interceptors.py | 2 +- .../authentication/test_sync_grpc_interceptors.py | 2 +- python/ray/tests/test_token_auth_integration.py | 11 ++++++----- python/ray/util/client/worker.py | 13 +++++-------- 10 files changed, 21 insertions(+), 26 deletions(-) diff --git a/python/ray/_private/grpc_utils.py b/python/ray/_private/grpc_utils.py index 9d22b60a7a37..b795d2c58c01 100644 --- a/python/ray/_private/grpc_utils.py +++ b/python/ray/_private/grpc_utils.py @@ -6,10 +6,10 @@ import grpc from grpc import aio as aiogrpc -from ray._private.authentication import authentication_utils -from ray._private.tls_utils import load_certs_from_env import ray +from ray._private.authentication import authentication_utils +from ray._private.tls_utils import load_certs_from_env def init_grpc_channel( diff --git a/python/ray/_private/internal_api.py b/python/ray/_private/internal_api.py index 704a59cfa80f..7a5573216988 100644 --- a/python/ray/_private/internal_api.py +++ b/python/ray/_private/internal_api.py @@ -4,7 +4,6 @@ import ray import ray._private.profiling as profiling import ray._private.services as services -import ray._private.utils as utils import ray._private.worker from ray._common.network_utils import build_address from ray._private.state import GlobalState @@ -57,8 +56,8 @@ def memory_summary( def get_memory_info_reply(state, node_manager_address=None, node_manager_port=None): """Returns global memory info.""" - from ray.core.generated import node_manager_pb2, node_manager_pb2_grpc from ray._private.grpc_utils import init_grpc_channel + from ray.core.generated import node_manager_pb2, node_manager_pb2_grpc # We can ask any Raylet for the global memory info, that Raylet internally # asks all nodes in the cluster for memory stats. @@ -97,8 +96,8 @@ def node_stats( ): """Returns NodeStats object describing memory usage in the cluster.""" - from ray.core.generated import node_manager_pb2, node_manager_pb2_grpc from ray._private.grpc_utils import init_grpc_channel + from ray.core.generated import node_manager_pb2, node_manager_pb2_grpc # We can ask any Raylet for the global memory info. assert node_manager_address is not None and node_manager_port is not None diff --git a/python/ray/_private/utils.py b/python/ray/_private/utils.py index 4d1bfe7c1f52..f7ba9af1465d 100644 --- a/python/ray/_private/utils.py +++ b/python/ray/_private/utils.py @@ -16,12 +16,10 @@ from subprocess import list2cmdline from typing import ( TYPE_CHECKING, - Any, Dict, List, Mapping, Optional, - Sequence, Tuple, Union, ) diff --git a/python/ray/dashboard/modules/aggregator/tests/test_aggregator_agent.py b/python/ray/dashboard/modules/aggregator/tests/test_aggregator_agent.py index f28b59be584c..1e9eb0330938 100644 --- a/python/ray/dashboard/modules/aggregator/tests/test_aggregator_agent.py +++ b/python/ray/dashboard/modules/aggregator/tests/test_aggregator_agent.py @@ -10,8 +10,8 @@ import ray.dashboard.consts as dashboard_consts from ray._common.network_utils import find_free_port from ray._private import ray_constants -from ray._private.test_utils import wait_for_condition from ray._private.grpc_utils import init_grpc_channel +from ray._private.test_utils import wait_for_condition from ray._raylet import GcsClient from ray.core.generated.common_pb2 import ( ErrorType, diff --git a/python/ray/dashboard/modules/reporter/reporter_head.py b/python/ray/dashboard/modules/reporter/reporter_head.py index 2f213b04b071..74fe9c447c03 100644 --- a/python/ray/dashboard/modules/reporter/reporter_head.py +++ b/python/ray/dashboard/modules/reporter/reporter_head.py @@ -14,6 +14,7 @@ from ray import ActorID, NodeID from ray._common.network_utils import build_address from ray._common.usage.usage_constants import CLUSTER_METADATA_KEY +from ray._private.grpc_utils import init_grpc_channel from ray._private.metrics_agent import PrometheusServiceDiscoveryWriter from ray._private.ray_constants import ( DEBUG_AUTOSCALING_ERROR, @@ -24,7 +25,6 @@ KV_NAMESPACE_DASHBOARD, env_integer, ) -from ray._private.grpc_utils import init_grpc_channel from ray.autoscaler._private.commands import debug_status from ray.core.generated import reporter_pb2, reporter_pb2_grpc from ray.dashboard.consts import GCS_RPC_TIMEOUT_SECONDS diff --git a/python/ray/tests/authentication/conftest.py b/python/ray/tests/authentication/conftest.py index 0d97bb941965..5452c030ba9b 100644 --- a/python/ray/tests/authentication/conftest.py +++ b/python/ray/tests/authentication/conftest.py @@ -10,9 +10,9 @@ from ray.core.generated import reporter_pb2, reporter_pb2_grpc from ray.tests.authentication_test_utils import ( authentication_env_guard, + reset_auth_token_state, set_auth_mode, set_env_auth_token, - reset_auth_token_state, ) @@ -69,6 +69,7 @@ def _create_test_server_base( return server, port + @pytest.fixture def create_sync_test_server(): """Factory to create synchronous gRPC test server. @@ -109,7 +110,6 @@ async def _create(with_auth=True): return _create - @pytest.fixture def test_token(): """Generate a test authentication token.""" diff --git a/python/ray/tests/authentication/test_async_grpc_interceptors.py b/python/ray/tests/authentication/test_async_grpc_interceptors.py index 95d222a6dcd3..410543caa2fd 100644 --- a/python/ray/tests/authentication/test_async_grpc_interceptors.py +++ b/python/ray/tests/authentication/test_async_grpc_interceptors.py @@ -10,9 +10,9 @@ from ray.core.generated import reporter_pb2, reporter_pb2_grpc from ray.tests.authentication_test_utils import ( authentication_env_guard, + reset_auth_token_state, set_auth_mode, set_env_auth_token, - reset_auth_token_state, ) diff --git a/python/ray/tests/authentication/test_sync_grpc_interceptors.py b/python/ray/tests/authentication/test_sync_grpc_interceptors.py index c3ddff2e706b..07429964b1d1 100644 --- a/python/ray/tests/authentication/test_sync_grpc_interceptors.py +++ b/python/ray/tests/authentication/test_sync_grpc_interceptors.py @@ -9,9 +9,9 @@ from ray.core.generated import reporter_pb2, reporter_pb2_grpc from ray.tests.authentication_test_utils import ( authentication_env_guard, + reset_auth_token_state, set_auth_mode, set_env_auth_token, - reset_auth_token_state, ) diff --git a/python/ray/tests/test_token_auth_integration.py b/python/ray/tests/test_token_auth_integration.py index 2ca56d97d50d..51aab367b69a 100644 --- a/python/ray/tests/test_token_auth_integration.py +++ b/python/ray/tests/test_token_auth_integration.py @@ -9,7 +9,7 @@ import pytest import ray -from ray._private.test_utils import wait_for_condition, client_test_enabled +from ray._private.test_utils import client_test_enabled, wait_for_condition try: from ray._raylet import AuthenticationTokenLoader @@ -117,9 +117,10 @@ def clean_token_sources(cleanup_auth_token_env): reset_auth_token_state() + @pytest.mark.skipif( client_test_enabled(), - reason="This test is for starting a new local cluster, not compatible with client mode" + reason="This test is for starting a new local cluster, not compatible with client mode", ) def test_local_cluster_generates_token(): """Test ray.init() generates token for local cluster when auth_mode=token is set.""" @@ -222,7 +223,7 @@ def test_func(): @pytest.mark.skipif( client_test_enabled(), - reason="Uses subprocess ray start, not compatible with client mode" + reason="Uses subprocess ray start, not compatible with client mode", ) @pytest.mark.parametrize("is_head", [True, False]) def test_ray_start_without_token_raises_error(is_head, request): @@ -256,7 +257,7 @@ def test_ray_start_without_token_raises_error(is_head, request): @pytest.mark.skipif( client_test_enabled(), - reason="Uses subprocess ray start, not compatible with client mode" + reason="Uses subprocess ray start, not compatible with client mode", ) def test_ray_start_head_with_token_succeeds(): """Test that ray start --head succeeds when token auth is enabled with a valid token.""" @@ -303,7 +304,7 @@ def test_func(): @pytest.mark.skipif( client_test_enabled(), - reason="Uses subprocess ray start, not compatible with client mode" + reason="Uses subprocess ray start, not compatible with client mode", ) @pytest.mark.parametrize("token_match", ["correct", "incorrect"]) def test_ray_start_address_with_token(token_match, setup_cluster_with_token_auth): diff --git a/python/ray/util/client/worker.py b/python/ray/util/client/worker.py index c9487e79308f..166ec6cb675b 100644 --- a/python/ray/util/client/worker.py +++ b/python/ray/util/client/worker.py @@ -71,13 +71,9 @@ # Links to the Ray Design Pattern doc to use in the task overhead warning # message -DESIGN_PATTERN_FINE_GRAIN_TASKS_LINK = ( - "https://docs.ray.io/en/latest/ray-core/patterns/too-fine-grained-tasks.html" # noqa E501 -) +DESIGN_PATTERN_FINE_GRAIN_TASKS_LINK = "https://docs.ray.io/en/latest/ray-core/patterns/too-fine-grained-tasks.html" # noqa E501 -DESIGN_PATTERN_LARGE_OBJECTS_LINK = ( - "https://docs.ray.io/en/latest/ray-core/patterns/closure-capture-large-objects.html" # noqa E501 -) +DESIGN_PATTERN_LARGE_OBJECTS_LINK = "https://docs.ray.io/en/latest/ray-core/patterns/closure-capture-large-objects.html" # noqa E501 def backoff(timeout: int) -> int: @@ -616,8 +612,9 @@ def populate_ids(resp: Union[ray_client_pb2.DataResponse, Exception]) -> None: self.data_client.Schedule(task, populate_ids) self.total_outbound_message_size_bytes += task.ByteSize() - if self.total_outbound_message_size_bytes > MESSAGE_SIZE_THRESHOLD and log_once( - "client_communication_overhead_warning" + if ( + self.total_outbound_message_size_bytes > MESSAGE_SIZE_THRESHOLD + and log_once("client_communication_overhead_warning") ): warnings.warn( "More than 10MB of messages have been created to schedule " From 99513f70b3d3edd2871cef8fb5dba8b4e106212a Mon Sep 17 00:00:00 2001 From: sampan Date: Wed, 12 Nov 2025 09:12:08 +0000 Subject: [PATCH 22/26] fix import after merge Signed-off-by: sampan --- python/ray/autoscaler/v2/tests/test_sdk.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/autoscaler/v2/tests/test_sdk.py b/python/ray/autoscaler/v2/tests/test_sdk.py index 17c390e625d0..a61d4cb5ab6d 100644 --- a/python/ray/autoscaler/v2/tests/test_sdk.py +++ b/python/ray/autoscaler/v2/tests/test_sdk.py @@ -30,7 +30,7 @@ from ray.core.generated import autoscaler_pb2, autoscaler_pb2_grpc from ray.core.generated.autoscaler_pb2 import ClusterResourceState, NodeStatus from ray.core.generated.common_pb2 import LabelSelectorOperator -from ray.tests import authentication_test_utils +from ray._private import authentication_test_utils from ray.util.state.api import list_nodes From 0317fdaf3bd24462e6793d66a47d3f57a242e001 Mon Sep 17 00:00:00 2001 From: sampan Date: Wed, 12 Nov 2025 09:12:49 +0000 Subject: [PATCH 23/26] fix lint issues Signed-off-by: sampan --- python/ray/autoscaler/v2/tests/test_sdk.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/autoscaler/v2/tests/test_sdk.py b/python/ray/autoscaler/v2/tests/test_sdk.py index a61d4cb5ab6d..278ff4886d89 100644 --- a/python/ray/autoscaler/v2/tests/test_sdk.py +++ b/python/ray/autoscaler/v2/tests/test_sdk.py @@ -11,6 +11,7 @@ import ray import ray._private.ray_constants as ray_constants from ray._common.test_utils import wait_for_condition +from ray._private import authentication_test_utils from ray.autoscaler.v2.schema import ( ClusterStatus, LaunchRequest, @@ -30,7 +31,6 @@ from ray.core.generated import autoscaler_pb2, autoscaler_pb2_grpc from ray.core.generated.autoscaler_pb2 import ClusterResourceState, NodeStatus from ray.core.generated.common_pb2 import LabelSelectorOperator -from ray._private import authentication_test_utils from ray.util.state.api import list_nodes From fb3d6cfb1710afb213fb19e24e48ef149229b228 Mon Sep 17 00:00:00 2001 From: sampan Date: Wed, 12 Nov 2025 11:26:30 +0000 Subject: [PATCH 24/26] fix imports Signed-off-by: sampan --- python/ray/_private/test_utils.py | 5 +++-- python/ray/autoscaler/v2/tests/test_sdk.py | 5 ++--- python/ray/dashboard/modules/node/node_head.py | 5 ++--- python/ray/tests/test_memory_pressure.py | 3 ++- python/ray/tests/test_state_api.py | 3 ++- python/ray/util/client/server/proxier.py | 3 ++- python/ray/util/state/state_manager.py | 8 +++----- 7 files changed, 16 insertions(+), 16 deletions(-) diff --git a/python/ray/_private/test_utils.py b/python/ray/_private/test_utils.py index cd23c8d3e333..40ca6ae643dc 100644 --- a/python/ray/_private/test_utils.py +++ b/python/ray/_private/test_utils.py @@ -37,6 +37,7 @@ from ray._private import ( ray_constants, ) +from ray._private.grpc_utils import init_grpc_channel from ray._private.internal_api import memory_summary from ray._private.tls_utils import generate_self_signed_tls_certs from ray._private.worker import RayContext @@ -1867,7 +1868,7 @@ def get_node_stats(raylet, num_retry=5, timeout=2): raylet_address = build_address( raylet["NodeManagerAddress"], raylet["NodeManagerPort"] ) - channel = ray._private.grpc_utils.init_grpc_channel(raylet_address) + channel = init_grpc_channel(raylet_address) stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel) for _ in range(num_retry): try: @@ -1888,7 +1889,7 @@ def get_resource_usage(gcs_address, timeout=10): if not gcs_address: gcs_address = ray.worker._global_node.gcs_address - gcs_channel = ray._private.grpc_utils.init_grpc_channel( + gcs_channel = init_grpc_channel( gcs_address, ray_constants.GLOBAL_GRPC_OPTIONS, asynchronous=False ) diff --git a/python/ray/autoscaler/v2/tests/test_sdk.py b/python/ray/autoscaler/v2/tests/test_sdk.py index 648e007706c8..b7108766a253 100644 --- a/python/ray/autoscaler/v2/tests/test_sdk.py +++ b/python/ray/autoscaler/v2/tests/test_sdk.py @@ -11,6 +11,7 @@ import ray import ray._private.ray_constants as ray_constants from ray._common.test_utils import wait_for_condition +from ray._private.grpc_utils import init_grpc_channel from ray.autoscaler.v2.schema import ( ClusterStatus, LaunchRequest, @@ -37,9 +38,7 @@ def _autoscaler_state_service_stub(): """Get the grpc stub for the autoscaler state service""" gcs_address = ray.get_runtime_context().gcs_address - gcs_channel = ray._private.grpc_utils.init_grpc_channel( - gcs_address, ray_constants.GLOBAL_GRPC_OPTIONS - ) + gcs_channel = init_grpc_channel(gcs_address, ray_constants.GLOBAL_GRPC_OPTIONS) return autoscaler_pb2_grpc.AutoscalerStateServiceStub(gcs_channel) diff --git a/python/ray/dashboard/modules/node/node_head.py b/python/ray/dashboard/modules/node/node_head.py index 94a1ba87a0a4..4836ec31b09f 100644 --- a/python/ray/dashboard/modules/node/node_head.py +++ b/python/ray/dashboard/modules/node/node_head.py @@ -21,6 +21,7 @@ GcsAioNodeInfoSubscriber, GcsAioResourceUsageSubscriber, ) +from ray._private.grpc_utils import init_grpc_channel from ray._private.ray_constants import ( DEBUG_AUTOSCALING_ERROR, DEBUG_AUTOSCALING_STATUS, @@ -290,9 +291,7 @@ async def _update_node(self, node: dict): node["nodeManagerAddress"], int(node["nodeManagerPort"]) ) options = ray_constants.GLOBAL_GRPC_OPTIONS - channel = ray._private.grpc_utils.init_grpc_channel( - address, options, asynchronous=True - ) + channel = init_grpc_channel(address, options, asynchronous=True) stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel) self._stubs[node_id] = stub diff --git a/python/ray/tests/test_memory_pressure.py b/python/ray/tests/test_memory_pressure.py index 6833522cd2d4..5bfcd1394db2 100644 --- a/python/ray/tests/test_memory_pressure.py +++ b/python/ray/tests/test_memory_pressure.py @@ -11,6 +11,7 @@ from ray._private import ( ray_constants, ) +from ray._private.grpc_utils import init_grpc_channel from ray._private.state_api_test_utils import verify_failed_task from ray._private.test_utils import raw_metrics from ray._private.utils import get_used_memory @@ -25,7 +26,7 @@ def get_local_state_client(): - gcs_channel = ray._private.grpc_utils.init_grpc_channel( + gcs_channel = init_grpc_channel( ray.worker._global_node.gcs_address, ray_constants.GLOBAL_GRPC_OPTIONS, asynchronous=True, diff --git a/python/ray/tests/test_state_api.py b/python/ray/tests/test_state_api.py index ded8d63cec6e..faad1b1e7c30 100644 --- a/python/ray/tests/test_state_api.py +++ b/python/ray/tests/test_state_api.py @@ -23,6 +23,7 @@ async_wait_for_condition, wait_for_condition, ) +from ray._private.grpc_utils import init_grpc_channel from ray._private.state_api_test_utils import ( create_api_options, get_state_api_manager, @@ -144,7 +145,7 @@ def state_source_client(gcs_address): ("grpc.max_send_message_length", ray_constants.GRPC_CPP_MAX_MESSAGE_SIZE), ("grpc.max_receive_message_length", ray_constants.GRPC_CPP_MAX_MESSAGE_SIZE), ) - gcs_channel = ray._private.grpc_utils.init_grpc_channel( + gcs_channel = init_grpc_channel( gcs_address, GRPC_CHANNEL_OPTIONS, asynchronous=True ) gcs_client = GcsClient(address=gcs_address) diff --git a/python/ray/util/client/server/proxier.py b/python/ray/util/client/server/proxier.py index 80070a5b235e..c0468dfa5fe8 100644 --- a/python/ray/util/client/server/proxier.py +++ b/python/ray/util/client/server/proxier.py @@ -24,6 +24,7 @@ get_auth_headers_if_auth_enabled, ) from ray._private.client_mode_hook import disable_client_hook +from ray._private.grpc_utils import init_grpc_channel from ray._private.parameter import RayParams from ray._private.runtime_env.context import RuntimeEnvContext from ray._private.services import ProcessInfo, start_ray_client_server @@ -210,7 +211,7 @@ def create_specific_server(self, client_id: str) -> SpecificServer: server = SpecificServer( port=port, process_handle_future=futures.Future(), - channel=ray._private.grpc_utils.init_grpc_channel( + channel=init_grpc_channel( build_address(host, port), options=GRPC_OPTIONS ), ) diff --git a/python/ray/util/state/state_manager.py b/python/ray/util/state/state_manager.py index f6a3c3b2118f..22e387bbbd5c 100644 --- a/python/ray/util/state/state_manager.py +++ b/python/ray/util/state/state_manager.py @@ -9,12 +9,12 @@ import grpc from grpc.aio._call import UnaryStreamCall -import ray import ray.dashboard.consts as dashboard_consts import ray.dashboard.modules.log.log_consts as log_consts from ray._common.network_utils import build_address from ray._common.utils import hex_to_binary from ray._private import ray_constants +from ray._private.grpc_utils import init_grpc_channel from ray._raylet import ActorID, GcsClient, JobID, NodeID, TaskID from ray.core.generated import gcs_service_pb2_grpc from ray.core.generated.gcs_pb2 import ActorTableData, GcsNodeInfo @@ -146,9 +146,7 @@ def register_gcs_client(self, gcs_channel: grpc.aio.Channel): def get_raylet_stub(self, ip: str, port: int): options = _STATE_MANAGER_GRPC_OPTIONS - channel = ray._private.grpc_utils.init_grpc_channel( - build_address(ip, port), options, asynchronous=True - ) + channel = init_grpc_channel(build_address(ip, port), options, asynchronous=True) return NodeManagerServiceStub(channel) async def get_log_service_stub(self, node_id: NodeID) -> LogServiceStub: @@ -162,7 +160,7 @@ async def get_log_service_stub(self, node_id: NodeID) -> LogServiceStub: return None ip, http_port, grpc_port = json.loads(agent_addr) options = ray_constants.GLOBAL_GRPC_OPTIONS - channel = ray._private.grpc_utils.init_grpc_channel( + channel = init_grpc_channel( build_address(ip, grpc_port), options=options, asynchronous=True ) return LogServiceStub(channel) From 2fd9386c9327949a7a2e5c477249430d42b88802 Mon Sep 17 00:00:00 2001 From: sampan Date: Wed, 12 Nov 2025 15:18:59 +0000 Subject: [PATCH 25/26] make grpc_utils imports lazy Signed-off-by: sampan --- python/ray/_private/grpc_utils.py | 2 -- python/ray/_private/test_utils.py | 3 ++- python/ray/autoscaler/v2/tests/test_sdk.py | 3 ++- python/ray/tests/authentication/__init__.py | 1 - python/ray/tests/authentication/conftest.py | 8 +++----- .../tests/authentication/test_async_grpc_interceptors.py | 8 +++----- .../tests/authentication/test_sync_grpc_interceptors.py | 8 +++----- python/ray/util/state/state_manager.py | 5 ++++- 8 files changed, 17 insertions(+), 21 deletions(-) diff --git a/python/ray/_private/grpc_utils.py b/python/ray/_private/grpc_utils.py index b795d2c58c01..be5a90fee01a 100644 --- a/python/ray/_private/grpc_utils.py +++ b/python/ray/_private/grpc_utils.py @@ -1,5 +1,3 @@ -"""Utilities for creating gRPC channels and servers with authentication support.""" - import os from concurrent import futures from typing import Any, Optional, Sequence, Tuple diff --git a/python/ray/_private/test_utils.py b/python/ray/_private/test_utils.py index 40ca6ae643dc..44e787643c7b 100644 --- a/python/ray/_private/test_utils.py +++ b/python/ray/_private/test_utils.py @@ -37,7 +37,6 @@ from ray._private import ( ray_constants, ) -from ray._private.grpc_utils import init_grpc_channel from ray._private.internal_api import memory_summary from ray._private.tls_utils import generate_self_signed_tls_certs from ray._private.worker import RayContext @@ -1863,6 +1862,7 @@ def wandb_setup_api_key_hook(): def get_node_stats(raylet, num_retry=5, timeout=2): import grpc + from ray._private.grpc_utils import init_grpc_channel from ray.core.generated import node_manager_pb2_grpc raylet_address = build_address( @@ -1884,6 +1884,7 @@ def get_node_stats(raylet, num_retry=5, timeout=2): # Gets resource usage assuming gcs is local. def get_resource_usage(gcs_address, timeout=10): + from ray._private.grpc_utils import init_grpc_channel from ray.core.generated import gcs_service_pb2_grpc if not gcs_address: diff --git a/python/ray/autoscaler/v2/tests/test_sdk.py b/python/ray/autoscaler/v2/tests/test_sdk.py index d71454bc64f6..43616aa4c392 100644 --- a/python/ray/autoscaler/v2/tests/test_sdk.py +++ b/python/ray/autoscaler/v2/tests/test_sdk.py @@ -12,7 +12,6 @@ import ray._private.ray_constants as ray_constants from ray._common.test_utils import wait_for_condition from ray._private import authentication_test_utils -from ray._private.grpc_utils import init_grpc_channel from ray.autoscaler.v2.schema import ( ClusterStatus, LaunchRequest, @@ -37,6 +36,8 @@ def _autoscaler_state_service_stub(): """Get the grpc stub for the autoscaler state service""" + from ray._private.grpc_utils import init_grpc_channel + gcs_address = ray.get_runtime_context().gcs_address gcs_channel = init_grpc_channel(gcs_address, ray_constants.GLOBAL_GRPC_OPTIONS) return autoscaler_pb2_grpc.AutoscalerStateServiceStub(gcs_channel) diff --git a/python/ray/tests/authentication/__init__.py b/python/ray/tests/authentication/__init__.py index ee3ad837bc34..e69de29bb2d1 100644 --- a/python/ray/tests/authentication/__init__.py +++ b/python/ray/tests/authentication/__init__.py @@ -1 +0,0 @@ -"""Tests for gRPC authentication interceptors.""" diff --git a/python/ray/tests/authentication/conftest.py b/python/ray/tests/authentication/conftest.py index 5452c030ba9b..4b1e876f4665 100644 --- a/python/ray/tests/authentication/conftest.py +++ b/python/ray/tests/authentication/conftest.py @@ -1,19 +1,17 @@ -"""Shared fixtures for gRPC authentication interceptor integration tests.""" - import uuid import grpc import pytest from grpc import aio as aiogrpc -from ray._private.grpc_utils import create_grpc_server_with_interceptors -from ray.core.generated import reporter_pb2, reporter_pb2_grpc -from ray.tests.authentication_test_utils import ( +from ray._private.authentication_test_utils import ( authentication_env_guard, reset_auth_token_state, set_auth_mode, set_env_auth_token, ) +from ray._private.grpc_utils import create_grpc_server_with_interceptors +from ray.core.generated import reporter_pb2, reporter_pb2_grpc class SyncReporterService(reporter_pb2_grpc.ReporterServiceServicer): diff --git a/python/ray/tests/authentication/test_async_grpc_interceptors.py b/python/ray/tests/authentication/test_async_grpc_interceptors.py index 410543caa2fd..ca37fcc6ed60 100644 --- a/python/ray/tests/authentication/test_async_grpc_interceptors.py +++ b/python/ray/tests/authentication/test_async_grpc_interceptors.py @@ -1,19 +1,17 @@ -"""Integration tests for asynchronous gRPC authentication interceptors.""" - import uuid import grpc import pytest from grpc import aio as aiogrpc -from ray._private.grpc_utils import init_grpc_channel -from ray.core.generated import reporter_pb2, reporter_pb2_grpc -from ray.tests.authentication_test_utils import ( +from ray._private.authentication_test_utils import ( authentication_env_guard, reset_auth_token_state, set_auth_mode, set_env_auth_token, ) +from ray._private.grpc_utils import init_grpc_channel +from ray.core.generated import reporter_pb2, reporter_pb2_grpc @pytest.mark.asyncio diff --git a/python/ray/tests/authentication/test_sync_grpc_interceptors.py b/python/ray/tests/authentication/test_sync_grpc_interceptors.py index 07429964b1d1..276b81740e66 100644 --- a/python/ray/tests/authentication/test_sync_grpc_interceptors.py +++ b/python/ray/tests/authentication/test_sync_grpc_interceptors.py @@ -1,18 +1,16 @@ -"""Integration tests for synchronous gRPC authentication interceptors.""" - import uuid import grpc import pytest -from ray._private.grpc_utils import init_grpc_channel -from ray.core.generated import reporter_pb2, reporter_pb2_grpc -from ray.tests.authentication_test_utils import ( +from ray._private.authentication_test_utils import ( authentication_env_guard, reset_auth_token_state, set_auth_mode, set_env_auth_token, ) +from ray._private.grpc_utils import init_grpc_channel +from ray.core.generated import reporter_pb2, reporter_pb2_grpc def test_sync_server_and_client_with_valid_token(create_sync_test_server): diff --git a/python/ray/util/state/state_manager.py b/python/ray/util/state/state_manager.py index 22e387bbbd5c..d693eea0a780 100644 --- a/python/ray/util/state/state_manager.py +++ b/python/ray/util/state/state_manager.py @@ -14,7 +14,6 @@ from ray._common.network_utils import build_address from ray._common.utils import hex_to_binary from ray._private import ray_constants -from ray._private.grpc_utils import init_grpc_channel from ray._raylet import ActorID, GcsClient, JobID, NodeID, TaskID from ray.core.generated import gcs_service_pb2_grpc from ray.core.generated.gcs_pb2 import ActorTableData, GcsNodeInfo @@ -145,12 +144,16 @@ def register_gcs_client(self, gcs_channel: grpc.aio.Channel): ) def get_raylet_stub(self, ip: str, port: int): + from ray._private.grpc_utils import init_grpc_channel + options = _STATE_MANAGER_GRPC_OPTIONS channel = init_grpc_channel(build_address(ip, port), options, asynchronous=True) return NodeManagerServiceStub(channel) async def get_log_service_stub(self, node_id: NodeID) -> LogServiceStub: """Returns None if the agent on the node is not registered in Internal KV.""" + from ray._private.grpc_utils import init_grpc_channel + agent_addr = await self._gcs_client.async_internal_kv_get( f"{dashboard_consts.DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX}{node_id.hex()}".encode(), namespace=ray_constants.KV_NAMESPACE_DASHBOARD, From 417f5750db5a662d29abcf3af7bdbfb783da85ab Mon Sep 17 00:00:00 2001 From: sampan Date: Wed, 12 Nov 2025 15:25:25 +0000 Subject: [PATCH 26/26] add BUILD.bazel Signed-off-by: sampan --- python/ray/tests/BUILD.bazel | 1 - python/ray/tests/authentication/BUILD.bazel | 23 +++ ..._grpc_authentication_server_interceptor.py | 195 ------------------ 3 files changed, 23 insertions(+), 196 deletions(-) create mode 100644 python/ray/tests/authentication/BUILD.bazel delete mode 100644 python/ray/tests/test_grpc_authentication_server_interceptor.py diff --git a/python/ray/tests/BUILD.bazel b/python/ray/tests/BUILD.bazel index 87b2f2a99971..5cd50a7eeff3 100644 --- a/python/ray/tests/BUILD.bazel +++ b/python/ray/tests/BUILD.bazel @@ -56,7 +56,6 @@ py_test_module_list( "test_gcs_utils.py", "test_get_locations.py", "test_global_state.py", - "test_grpc_authentication_server_interceptor.py", "test_healthcheck.py", "test_metric_cardinality.py", "test_metrics_agent.py", diff --git a/python/ray/tests/authentication/BUILD.bazel b/python/ray/tests/authentication/BUILD.bazel new file mode 100644 index 000000000000..2d01db995c2e --- /dev/null +++ b/python/ray/tests/authentication/BUILD.bazel @@ -0,0 +1,23 @@ +load("@rules_python//python:defs.bzl", "py_library") +load("//bazel:python.bzl", "py_test_run_all_subdirectory") + +py_library( + name = "conftest", + srcs = ["conftest.py"], +) + +py_test_run_all_subdirectory( + size = "medium", + include = glob(["test_*.py"]), + exclude = [], + extra_srcs = [], + tags = [ + "exclusive", + "medium_size_python_tests_a_to_j", + "team:core", + ], + deps = [ + ":conftest", + "//:ray_lib", + ], +) diff --git a/python/ray/tests/test_grpc_authentication_server_interceptor.py b/python/ray/tests/test_grpc_authentication_server_interceptor.py deleted file mode 100644 index 4b8b5b88710d..000000000000 --- a/python/ray/tests/test_grpc_authentication_server_interceptor.py +++ /dev/null @@ -1,195 +0,0 @@ -"""Unit tests for gRPC server authentication interceptor.""" - -import uuid - -import grpc -import pytest -from grpc import aio as aiogrpc - -from ray._private.authentication.grpc_authentication_server_interceptor import ( - AsyncAuthenticationServerInterceptor, -) -from ray._private.authentication_test_utils import ( - authentication_env_guard, - reset_auth_token_state, - set_auth_mode, - set_env_auth_token, -) - -# Create a simple test service for testing -from ray.core.generated import reporter_pb2, reporter_pb2_grpc - - -class TestReporterServicer(reporter_pb2_grpc.ReporterServiceServicer): - """Simple test servicer for testing authentication.""" - - async def HealthCheck(self, request, context): - """Return a health check response.""" - return reporter_pb2.HealthCheckReply() - - -@pytest.fixture -async def auth_server_and_port(): - """Create a gRPC server with authentication interceptor.""" - interceptor = AsyncAuthenticationServerInterceptor() - server = aiogrpc.server(interceptors=[interceptor]) - - servicer = TestReporterServicer() - reporter_pb2_grpc.add_ReporterServiceServicer_to_server(servicer, server) - - port = server.add_insecure_port("[::]:0") - await server.start() - - yield server, port - - await server.stop(grace=1) - - -@pytest.fixture -async def no_auth_server_and_port(): - """Create a gRPC server without authentication interceptor.""" - server = aiogrpc.server() - - servicer = TestReporterServicer() - reporter_pb2_grpc.add_ReporterServiceServicer_to_server(servicer, server) - - port = server.add_insecure_port("[::]:0") - await server.start() - - yield server, port - - await server.stop(grace=1) - - -@pytest.mark.asyncio -async def test_server_interceptor_allows_valid_token(auth_server_and_port): - """Test that server interceptor allows requests with valid tokens.""" - with authentication_env_guard(): - # Set up token authentication - token = uuid.uuid4().hex - set_auth_mode("token") - set_env_auth_token(token) - reset_auth_token_state() - - # Get server from fixture - _, port = auth_server_and_port - - # Create client with valid token in metadata - async with aiogrpc.insecure_channel(f"localhost:{port}") as channel: - stub = reporter_pb2_grpc.ReporterServiceStub(channel) - - # Add valid token to metadata - metadata = (("authorization", f"Bearer {token}"),) - - request = reporter_pb2.HealthCheckRequest() - response = await stub.HealthCheck(request, metadata=metadata) - - # Should succeed (response exists and is not None) - assert response is not None - - -@pytest.mark.asyncio -async def test_server_interceptor_rejects_invalid_token(auth_server_and_port): - """Test that server interceptor rejects requests with invalid tokens.""" - with authentication_env_guard(): - # Set up token authentication - correct_token = uuid.uuid4().hex - set_auth_mode("token") - set_env_auth_token(correct_token) - reset_auth_token_state() - - # Get server from fixture - _, port = auth_server_and_port - - # Create client with invalid token in metadata - async with aiogrpc.insecure_channel(f"localhost:{port}") as channel: - stub = reporter_pb2_grpc.ReporterServiceStub(channel) - - # Add invalid token to metadata - wrong_token = uuid.uuid4().hex - metadata = (("authorization", f"Bearer {wrong_token}"),) - - request = reporter_pb2.HealthCheckRequest() - - # Should fail with UNAUTHENTICATED status - with pytest.raises(grpc.RpcError) as exc_info: - await stub.HealthCheck(request, metadata=metadata) - - assert exc_info.value.code() == grpc.StatusCode.UNAUTHENTICATED - - -@pytest.mark.asyncio -async def test_server_interceptor_rejects_missing_token(auth_server_and_port): - """Test that server interceptor rejects requests without tokens.""" - with authentication_env_guard(): - # Set up token authentication - token = uuid.uuid4().hex - set_auth_mode("token") - set_env_auth_token(token) - reset_auth_token_state() - - # Get server from fixture - _, port = auth_server_and_port - - # Create client without token in metadata - async with aiogrpc.insecure_channel(f"localhost:{port}") as channel: - stub = reporter_pb2_grpc.ReporterServiceStub(channel) - - request = reporter_pb2.HealthCheckRequest() - - # Should fail with UNAUTHENTICATED status - with pytest.raises(grpc.RpcError) as exc_info: - await stub.HealthCheck(request) - - assert exc_info.value.code() == grpc.StatusCode.UNAUTHENTICATED - - -@pytest.mark.asyncio -async def test_server_interceptor_disabled_auth_allows_all(auth_server_and_port): - """Test that when auth is disabled, all requests are allowed.""" - with authentication_env_guard(): - # Set auth mode to disabled (or don't set it at all) - set_auth_mode("disabled") - reset_auth_token_state() - - # Get server from fixture - _, port = auth_server_and_port - - # Create client without any token - async with aiogrpc.insecure_channel(f"localhost:{port}") as channel: - stub = reporter_pb2_grpc.ReporterServiceStub(channel) - - request = reporter_pb2.HealthCheckRequest() - response = await stub.HealthCheck(request) - - # Should succeed even without token - assert response is not None - - -@pytest.mark.asyncio -async def test_no_interceptor_allows_all_requests(no_auth_server_and_port): - """Test that server without interceptor allows all requests.""" - with authentication_env_guard(): - # Even with token auth enabled, server without interceptor allows all - token = uuid.uuid4().hex - set_auth_mode("token") - set_env_auth_token(token) - reset_auth_token_state() - - _, port = no_auth_server_and_port - - # Create client without token - async with aiogrpc.insecure_channel(f"localhost:{port}") as channel: - stub = reporter_pb2_grpc.ReporterServiceStub(channel) - - request = reporter_pb2.HealthCheckRequest() - response = await stub.HealthCheck(request) - - # Should succeed (no interceptor means no auth check) - assert response is not None - - -if __name__ == "__main__": - import sys - - sys.exit(pytest.main([__file__, "-v"]))