From 3d09927728c0017ae386d259e9e15c173862111d Mon Sep 17 00:00:00 2001
From: Lehui Liu
Date: Thu, 16 Oct 2025 09:49:53 -0700
Subject: [PATCH 1/7] Add shutdown for jaxtrainer

Signed-off-by: Lehui Liu
---
 python/ray/train/v2/jax/config.py             | 184 ++++++++++++++++++
 python/ray/train/v2/tests/test_jax_trainer.py |  46 +++++
 2 files changed, 230 insertions(+)

diff --git a/python/ray/train/v2/jax/config.py b/python/ray/train/v2/jax/config.py
index 5e8dc5ba33e4..90c1a103c0c1 100644
--- a/python/ray/train/v2/jax/config.py
+++ b/python/ray/train/v2/jax/config.py
@@ -36,6 +36,68 @@ def _setup_jax_tpu_environment(
     jax.distributed.initialize(master_addr_with_port, num_workers, index)
 
 
+def _setup_jax_distributed_environment(
+    master_addr_with_port: str, num_workers: int, index: int
+):
+    """Set up distributed Jax training information.
+
+    This function should be called on each worker.
+    """
+    import jax
+
+    jax.distributed.initialize(master_addr_with_port, num_workers, index)
+
+
+def _shutdown_jax_distributed():
+    """Shutdown JAX distributed environment.
+
+    This function should be called on each worker during cleanup.
+    """
+    try:
+        import jax
+
+        # Only shutdown if JAX distributed was initialized
+        if jax.process_count() > 1:
+            jax.distributed.shutdown()
+            logger.debug("JAX distributed shutdown completed")
+    except Exception as e:
+        # Log but don't raise - we want graceful degradation during shutdown
+        logger.warning(f"Error during JAX distributed shutdown: {e}")
+        os.environ["JAX_PLATFORMS"] = "tpu," + existing_jax_platforms
+    else:
+        # No existing platforms, just set to "tpu"
+        os.environ["JAX_PLATFORMS"] = "tpu"
+
+
+def _setup_jax_distributed_environment(
+    master_addr_with_port: str, num_workers: int, index: int
+):
+    """Set up distributed Jax training information.
+
+    This function should be called on each worker.
+    """
+    import jax
+
+    jax.distributed.initialize(master_addr_with_port, num_workers, index)
+
+
+def _shutdown_jax_distributed():
+    """Shutdown JAX distributed environment.
+
+    This function should be called on each worker during cleanup.
+    """
+    try:
+        import jax
+
+        # Only shutdown if JAX distributed was initialized
+        if jax.process_count() > 1:
+            jax.distributed.shutdown()
+            logger.debug("JAX distributed shutdown completed")
+    except Exception as e:
+        # Log but don't raise - we want graceful degradation during shutdown
+        logger.warning(f"Error during JAX distributed shutdown: {e}")
+
+
 class _JaxBackend(Backend):
     def on_start(self, worker_group: WorkerGroup, backend_config: JaxConfig):
         if not backend_config.use_tpu:
@@ -57,3 +119,125 @@ def on_start(self, worker_group: WorkerGroup, backend_config: JaxConfig):
             )
         )
         ray.get(setup_futures)
+
+
+class _JaxBackend(Backend):
+    def on_start(self, worker_group: WorkerGroup, backend_config: JaxConfig):
+        if not backend_config.use_tpu:
+            return
+
+        # Set JAX environment variables on all workers
+        worker_group.execute(_set_jax_env_vars, use_tpu=backend_config.use_tpu)
+
+        master_addr, master_port = worker_group.execute_single(0, get_address_and_port)
+        master_addr_with_port = f"{master_addr}:{master_port}"
+
+        # Get setup tasks in order to throw errors on failure.
+        setup_futures = []
+        for i in range(len(worker_group)):
+            setup_futures.append(
+                worker_group.execute_single_async(
+                    i,
+                    _setup_jax_distributed_environment,
+                    master_addr_with_port=master_addr_with_port,
+                    num_workers=len(worker_group),
+                    index=i,
+                )
+            )
+        ray.get(setup_futures)
+
+
+    class _JaxBackend(Backend):
+        def on_start(self, worker_group: WorkerGroup, backend_config: JaxConfig):
+            if not backend_config.use_tpu:
+                return
+
+            # Set JAX environment variables on all workers
+            worker_group.execute(_set_jax_env_vars, use_tpu=backend_config.use_tpu)
+
+            master_addr, master_port = worker_group.execute_single(0, get_address_and_port)
+            master_addr_with_port = f"{master_addr}:{master_port}"
+
+            # Get setup tasks in order to throw errors on failure.
+            setup_futures = []
+            for i in range(len(worker_group)):
+                setup_futures.append(
+                    worker_group.execute_single_async(
+                        i,
+                        _setup_jax_distributed_environment,
+                        master_addr_with_port=master_addr_with_port,
+                        num_workers=len(worker_group),
+                        index=i,
+                    )
+                )
+            ray.get(setup_futures)
+
+    def on_shutdown(self, worker_group: WorkerGroup, backend_config: JaxConfig):
+        """Cleanup JAX distributed resources when shutting down worker group.
+
+        This is critical to prevent resource leaks and hanging workers.
+        """
+        if not backend_config.use_tpu:
+            return
+
+        # Shutdown JAX distributed on all workers
+        shutdown_futures = worker_group.execute_async(_shutdown_jax_distributed)
+
+        # Wait for shutdown to complete with a reasonable timeout
+        timeout_s = 30  # JAX shutdown should be quick
+        try:
+            ray.get(shutdown_futures, timeout=timeout_s)
+        except ray.exceptions.GetTimeoutError:
+            logger.warning(
+                f"JAX distributed shutdown timed out after {timeout_s} seconds. "
+                "This may indicate workers are hung or unresponsive."
+            )
+        except Exception as e:
+            logger.warning(f"Error during JAX distributed shutdown: {e}")
+class _JaxBackend(Backend):
+    def on_start(self, worker_group: WorkerGroup, backend_config: JaxConfig):
+        if not backend_config.use_tpu:
+            return
+
+        # Set JAX environment variables on all workers
+        worker_group.execute(_set_jax_env_vars, use_tpu=backend_config.use_tpu)
+
+        master_addr, master_port = worker_group.execute_single(0, get_address_and_port)
+        master_addr_with_port = f"{master_addr}:{master_port}"
+
+        # Get setup tasks in order to throw errors on failure.
+        setup_futures = []
+        for i in range(len(worker_group)):
+            setup_futures.append(
+                worker_group.execute_single_async(
+                    i,
+                    _setup_jax_distributed_environment,
+                    master_addr_with_port=master_addr_with_port,
+                    num_workers=len(worker_group),
+                    index=i,
+                )
+            )
+        ray.get(setup_futures)
+
+    def on_shutdown(self, worker_group: WorkerGroup, backend_config: JaxConfig):
+        """Cleanup JAX distributed resources when shutting down worker group.
+
+        This is critical to prevent resource leaks and hanging workers.
+        """
+        if not backend_config.use_tpu:
+            return
+
+        # Shutdown JAX distributed on all workers
+        shutdown_futures = worker_group.execute_async(_shutdown_jax_distributed)
+
+        # Wait for shutdown to complete with a reasonable timeout
+        timeout_s = 30  # JAX shutdown should be quick
+        try:
+            ray.get(shutdown_futures, timeout=timeout_s)
+        except ray.exceptions.GetTimeoutError:
+            logger.warning(
+                f"JAX distributed shutdown timed out after {timeout_s} seconds. "
+                "This may indicate workers are hung or unresponsive."
+            )
+        except Exception as e:
+            logger.warning(f"Error during JAX distributed shutdown: {e}")
diff --git a/python/ray/train/v2/tests/test_jax_trainer.py b/python/ray/train/v2/tests/test_jax_trainer.py
index 4f8db5cf8dd8..2dcb9e479213 100644
--- a/python/ray/train/v2/tests/test_jax_trainer.py
+++ b/python/ray/train/v2/tests/test_jax_trainer.py
@@ -71,6 +71,25 @@ def train_func():
     train.report({"result": [str(d) for d in devices]})
 
 
+def train_func_check_distributed_shutdown():
+    """Training function to verify JAX distributed is properly initialized and can be checked."""
+    import jax
+
+    from ray import train
+
+    # Verify JAX distributed is initialized
+    process_count = jax.process_count()
+    process_index = jax.process_index()
+
+    train.report(
+        {
+            "process_count": process_count,
+            "process_index": process_index,
+            "initialized": process_count > 1,
+        }
+    )
+
+
 def test_minimal_singlehost(ray_tpu_single_host, tmp_path):
     trainer = JaxTrainer(
         train_loop_per_worker=train_func,
@@ -124,6 +143,33 @@ def test_minimal_multihost(ray_tpu_multi_host, tmp_path):
 
     assert len(labeled_nodes) == 2
 
+
+def test_jax_distributed_initialization_multihost(ray_tpu_multi_host, tmp_path):
+    """Test that JAX distributed is properly initialized in multi-host setup.
+
+    This test also verifies that the shutdown logic works correctly.
+    """
+    trainer = JaxTrainer(
+        train_loop_per_worker=train_func_check_distributed_shutdown,
+        scaling_config=ScalingConfig(
+            num_workers=2,
+            resources_per_worker={"TPU": 4},
+            use_tpu=True,
+            topology="2x2x2",
+            accelerator_type="TPU-V4",
+        ),
+        run_config=RunConfig(
+            storage_path=str(tmp_path),
+        ),
+    )
+    result = trainer.fit()
+    assert result.error is None
+
+    # Verify that JAX distributed was initialized with correct process count
+    assert result.metrics["process_count"] == 2
+    assert result.metrics["initialized"] is True
+
+
 if __name__ == "__main__":
     import sys

From ed041ecd5d3b8bbcb08da9eea6cad91f35dbc79c Mon Sep 17 00:00:00 2001
From: Lehui Liu
Date: Thu, 16 Oct 2025 10:47:27 -0700
Subject: [PATCH 2/7] change unit test to capture debug log

Signed-off-by: Lehui Liu
---
 python/ray/train/v2/jax/config.py             | 151 +-----------------
 python/ray/train/v2/tests/test_jax_trainer.py |  74 ++++-----
 2 files changed, 32 insertions(+), 193 deletions(-)

diff --git a/python/ray/train/v2/jax/config.py b/python/ray/train/v2/jax/config.py
index 90c1a103c0c1..12c6881229b0 100644
--- a/python/ray/train/v2/jax/config.py
+++ b/python/ray/train/v2/jax/config.py
@@ -36,68 +36,22 @@ def _setup_jax_tpu_environment(
     jax.distributed.initialize(master_addr_with_port, num_workers, index)
 
 
-def _setup_jax_distributed_environment(
-    master_addr_with_port: str, num_workers: int, index: int
-):
-    """Set up distributed Jax training information.
-
-    This function should be called on each worker.
-    """
-    import jax
-
-    jax.distributed.initialize(master_addr_with_port, num_workers, index)
-
-
 def _shutdown_jax_distributed():
     """Shutdown JAX distributed environment.
 
     This function should be called on each worker during cleanup.
+    If JAX distributed was not initialized, this is a no-op.
""" try: import jax - # Only shutdown if JAX distributed was initialized - if jax.process_count() > 1: - jax.distributed.shutdown() - logger.debug("JAX distributed shutdown completed") - except Exception as e: - # Log but don't raise - we want graceful degradation during shutdown - logger.warning(f"Error during JAX distributed shutdown: {e}") os.environ["JAX_PLATFORMS"] = "tpu," + existing_jax_platforms - else: - # No existing platforms, just set to "tpu" - os.environ["JAX_PLATFORMS"] = "tpu" - - -def _setup_jax_distributed_environment( - master_addr_with_port: str, num_workers: int, index: int -): - """Set up distributed Jax training information. - - This function should be called on each worker. - """ - import jax - - jax.distributed.initialize(master_addr_with_port, num_workers, index) - - -def _shutdown_jax_distributed(): - """Shutdown JAX distributed environment. - - This function should be called on each worker during cleanup. - """ - try: - import jax - - # Only shutdown if JAX distributed was initialized - if jax.process_count() > 1: - jax.distributed.shutdown() - logger.debug("JAX distributed shutdown completed") + jax.distributed.shutdown() + logger.debug("JAX distributed shutdown completed") except Exception as e: # Log but don't raise - we want graceful degradation during shutdown logger.warning(f"Error during JAX distributed shutdown: {e}") - class _JaxBackend(Backend): def on_start(self, worker_group: WorkerGroup, backend_config: JaxConfig): if not backend_config.use_tpu: @@ -120,105 +74,6 @@ def on_start(self, worker_group: WorkerGroup, backend_config: JaxConfig): ) ray.get(setup_futures) - - -class _JaxBackend(Backend): - def on_start(self, worker_group: WorkerGroup, backend_config: JaxConfig): - if not backend_config.use_tpu: - return - - # Set JAX environment variables on all workers - worker_group.execute(_set_jax_env_vars, use_tpu=backend_config.use_tpu) - - master_addr, master_port = worker_group.execute_single(0, get_address_and_port) - master_addr_with_port = f"{master_addr}:{master_port}" - - # Get setup tasks in order to throw errors on failure. - setup_futures = [] - for i in range(len(worker_group)): - setup_futures.append( - worker_group.execute_single_async( - i, - _setup_jax_distributed_environment, - master_addr_with_port=master_addr_with_port, - num_workers=len(worker_group), - index=i, - ) - ) - ray.get(setup_futures) - - - class _JaxBackend(Backend): - def on_start(self, worker_group: WorkerGroup, backend_config: JaxConfig): - if not backend_config.use_tpu: - return - - # Set JAX environment variables on all workers - worker_group.execute(_set_jax_env_vars, use_tpu=backend_config.use_tpu) - - master_addr, master_port = worker_group.execute_single(0, get_address_and_port) - master_addr_with_port = f"{master_addr}:{master_port}" - - # Get setup tasks in order to throw errors on failure. - setup_futures = [] - for i in range(len(worker_group)): - setup_futures.append( - worker_group.execute_single_async( - i, - _setup_jax_distributed_environment, - master_addr_with_port=master_addr_with_port, - num_workers=len(worker_group), - index=i, - ) - ) - ray.get(setup_futures) - - def on_shutdown(self, worker_group: WorkerGroup, backend_config: JaxConfig): - """Cleanup JAX distributed resources when shutting down worker group. - - This is critical to prevent resource leaks and hanging workers. 
- """ - if not backend_config.use_tpu: - return - - # Shutdown JAX distributed on all workers - shutdown_futures = worker_group.execute_async(_shutdown_jax_distributed) - - # Wait for shutdown to complete with a reasonable timeout - timeout_s = 30 # JAX shutdown should be quick - try: - ray.get(shutdown_futures, timeout=timeout_s) - except ray.exceptions.GetTimeoutError: - logger.warning( - f"JAX distributed shutdown timed out after {timeout_s} seconds. " - "This may indicate workers are hung or unresponsive." - ) - except Exception as e: - logger.warning(f"Error during JAX distributed shutdown: {e}")xBackend(Backend): - def on_start(self, worker_group: WorkerGroup, backend_config: JaxConfig): - if not backend_config.use_tpu: - return - - # Set JAX environment variables on all workers - worker_group.execute(_set_jax_env_vars, use_tpu=backend_config.use_tpu) - - master_addr, master_port = worker_group.execute_single(0, get_address_and_port) - master_addr_with_port = f"{master_addr}:{master_port}" - - # Get setup tasks in order to throw errors on failure. - setup_futures = [] - for i in range(len(worker_group)): - setup_futures.append( - worker_group.execute_single_async( - i, - _setup_jax_distributed_environment, - master_addr_with_port=master_addr_with_port, - num_workers=len(worker_group), - index=i, - ) - ) - ray.get(setup_futures) - def on_shutdown(self, worker_group: WorkerGroup, backend_config: JaxConfig): """Cleanup JAX distributed resources when shutting down worker group. diff --git a/python/ray/train/v2/tests/test_jax_trainer.py b/python/ray/train/v2/tests/test_jax_trainer.py index 2dcb9e479213..a8a72a4b9c81 100644 --- a/python/ray/train/v2/tests/test_jax_trainer.py +++ b/python/ray/train/v2/tests/test_jax_trainer.py @@ -1,3 +1,5 @@ +import logging + import pytest import ray @@ -71,26 +73,11 @@ def train_func(): train.report({"result": [str(d) for d in devices]}) -def train_func_check_distributed_shutdown(): - """Training function to verify JAX distributed is properly initialized and can be checked.""" - import jax - - from ray import train - - # Verify JAX distributed is initialized - process_count = jax.process_count() - process_index = jax.process_index() - - train.report( - { - "process_count": process_count, - "process_index": process_index, - "initialized": process_count > 1, - } - ) - +def test_minimal_singlehost(ray_tpu_single_host, tmp_path, caplog): + """Test single-host TPU training and verify shutdown callback is invoked.""" + # Capture DEBUG logs to verify shutdown was called + caplog.set_level(logging.DEBUG) -def test_minimal_singlehost(ray_tpu_single_host, tmp_path): trainer = JaxTrainer( train_loop_per_worker=train_func, # Topology can be omitted for single-host. @@ -114,8 +101,22 @@ def test_minimal_singlehost(ray_tpu_single_host, tmp_path): ] assert len(labeled_nodes) == 1 + # Verify that JAX distributed shutdown was called via backend on_shutdown callback + shutdown_log_found = any( + "JAX distributed shutdown completed" in record.message + for record in caplog.records + ) + assert shutdown_log_found, ( + "Expected 'JAX distributed shutdown completed' in logs. " + "This indicates the backend on_shutdown() callback was not invoked." 
+ ) + + +def test_minimal_multihost(ray_tpu_multi_host, tmp_path, caplog): + """Test multi-host TPU training and verify shutdown callback is invoked.""" + # Capture DEBUG logs to verify shutdown was called + caplog.set_level(logging.DEBUG) -def test_minimal_multihost(ray_tpu_multi_host, tmp_path): trainer = JaxTrainer( train_loop_per_worker=train_func, scaling_config=ScalingConfig( @@ -142,32 +143,15 @@ def test_minimal_multihost(ray_tpu_multi_host, tmp_path): ] assert len(labeled_nodes) == 2 - - -def test_jax_distributed_initialization_multihost(ray_tpu_multi_host, tmp_path): - """Test that JAX distributed is properly initialized in multi-host setup. - - This test also verifies that the shutdown logic works correctly. - """ - trainer = JaxTrainer( - train_loop_per_worker=train_func_check_distributed_shutdown, - scaling_config=ScalingConfig( - num_workers=2, - resources_per_worker={"TPU": 4}, - use_tpu=True, - topology="2x2x2", - accelerator_type="TPU-V4", - ), - run_config=RunConfig( - storage_path=str(tmp_path), - ), + # Verify that JAX distributed shutdown was called via backend on_shutdown callback + shutdown_log_found = any( + "JAX distributed shutdown completed" in record.message + for record in caplog.records + ) + assert shutdown_log_found, ( + "Expected 'JAX distributed shutdown completed' in logs. " + "This indicates the backend on_shutdown() callback was not invoked." ) - result = trainer.fit() - assert result.error is None - - # Verify that JAX distributed was initialized with correct process count - assert result.metrics["process_count"] == 2 - assert result.metrics["initialized"] is True if __name__ == "__main__": From 0b92a063027418e355e836bd3e5f0841c824647a Mon Sep 17 00:00:00 2001 From: Lehui Liu Date: Thu, 16 Oct 2025 11:57:47 -0700 Subject: [PATCH 3/7] add debug log check in controller Signed-off-by: Lehui Liu --- python/ray/train/v2/jax/config.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/python/ray/train/v2/jax/config.py b/python/ray/train/v2/jax/config.py index 12c6881229b0..7a8f955122b6 100644 --- a/python/ray/train/v2/jax/config.py +++ b/python/ray/train/v2/jax/config.py @@ -46,9 +46,7 @@ def _shutdown_jax_distributed(): import jax jax.distributed.shutdown() - logger.debug("JAX distributed shutdown completed") except Exception as e: - # Log but don't raise - we want graceful degradation during shutdown logger.warning(f"Error during JAX distributed shutdown: {e}") @@ -75,20 +73,17 @@ def on_start(self, worker_group: WorkerGroup, backend_config: JaxConfig): ray.get(setup_futures) def on_shutdown(self, worker_group: WorkerGroup, backend_config: JaxConfig): - """Cleanup JAX distributed resources when shutting down worker group. - - This is critical to prevent resource leaks and hanging workers. - """ + """Cleanup JAX distributed resources when shutting down worker group.""" if not backend_config.use_tpu: return # Shutdown JAX distributed on all workers shutdown_futures = worker_group.execute_async(_shutdown_jax_distributed) - # Wait for shutdown to complete with a reasonable timeout - timeout_s = 30 # JAX shutdown should be quick + timeout_s = 30 try: ray.get(shutdown_futures, timeout=timeout_s) + logger.debug("JAX distributed shutdown completed") except ray.exceptions.GetTimeoutError: logger.warning( f"JAX distributed shutdown timed out after {timeout_s} seconds. 
" From b1f131f69ff02640051e6766c18506b974609455 Mon Sep 17 00:00:00 2001 From: Lehui Liu Date: Thu, 16 Oct 2025 13:05:09 -0700 Subject: [PATCH 4/7] not asserting logging since it is ray actors Signed-off-by: Lehui Liu --- python/ray/train/v2/tests/test_jax_trainer.py | 34 ++----------------- 1 file changed, 2 insertions(+), 32 deletions(-) diff --git a/python/ray/train/v2/tests/test_jax_trainer.py b/python/ray/train/v2/tests/test_jax_trainer.py index a8a72a4b9c81..4f8db5cf8dd8 100644 --- a/python/ray/train/v2/tests/test_jax_trainer.py +++ b/python/ray/train/v2/tests/test_jax_trainer.py @@ -1,5 +1,3 @@ -import logging - import pytest import ray @@ -73,11 +71,7 @@ def train_func(): train.report({"result": [str(d) for d in devices]}) -def test_minimal_singlehost(ray_tpu_single_host, tmp_path, caplog): - """Test single-host TPU training and verify shutdown callback is invoked.""" - # Capture DEBUG logs to verify shutdown was called - caplog.set_level(logging.DEBUG) - +def test_minimal_singlehost(ray_tpu_single_host, tmp_path): trainer = JaxTrainer( train_loop_per_worker=train_func, # Topology can be omitted for single-host. @@ -101,22 +95,8 @@ def test_minimal_singlehost(ray_tpu_single_host, tmp_path, caplog): ] assert len(labeled_nodes) == 1 - # Verify that JAX distributed shutdown was called via backend on_shutdown callback - shutdown_log_found = any( - "JAX distributed shutdown completed" in record.message - for record in caplog.records - ) - assert shutdown_log_found, ( - "Expected 'JAX distributed shutdown completed' in logs. " - "This indicates the backend on_shutdown() callback was not invoked." - ) - - -def test_minimal_multihost(ray_tpu_multi_host, tmp_path, caplog): - """Test multi-host TPU training and verify shutdown callback is invoked.""" - # Capture DEBUG logs to verify shutdown was called - caplog.set_level(logging.DEBUG) +def test_minimal_multihost(ray_tpu_multi_host, tmp_path): trainer = JaxTrainer( train_loop_per_worker=train_func, scaling_config=ScalingConfig( @@ -143,16 +123,6 @@ def test_minimal_multihost(ray_tpu_multi_host, tmp_path, caplog): ] assert len(labeled_nodes) == 2 - # Verify that JAX distributed shutdown was called via backend on_shutdown callback - shutdown_log_found = any( - "JAX distributed shutdown completed" in record.message - for record in caplog.records - ) - assert shutdown_log_found, ( - "Expected 'JAX distributed shutdown completed' in logs. " - "This indicates the backend on_shutdown() callback was not invoked." - ) - if __name__ == "__main__": import sys From f18748c9769cba9cf617adce30dff2e10283dd0d Mon Sep 17 00:00:00 2001 From: Lehui Liu Date: Fri, 17 Oct 2025 11:21:01 -0700 Subject: [PATCH 5/7] address comment Signed-off-by: Lehui Liu --- python/ray/train/constants.py | 5 +++++ python/ray/train/tests/test_backend.py | 17 +++++++++++++++++ python/ray/train/v2/jax/config.py | 10 +++++++++- 3 files changed, 31 insertions(+), 1 deletion(-) diff --git a/python/ray/train/constants.py b/python/ray/train/constants.py index 8e2294827dcd..a51cc4920845 100644 --- a/python/ray/train/constants.py +++ b/python/ray/train/constants.py @@ -132,6 +132,10 @@ def _v2_migration_warnings_enabled() -> bool: TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S = "TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S" DEFAULT_TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S = 30 +# Seconds to wait for JAX distributed shutdown. 
+JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S = "JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S"
+DEFAULT_JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S = 30
+
 # NOTE: When adding a new environment variable, please track it in this list.
 TRAIN_ENV_VARS = {
     ENABLE_DETAILED_AUTOFILLED_METRICS_ENV,
@@ -144,6 +148,7 @@ def _v2_migration_warnings_enabled() -> bool:
     RAY_TRAIN_ENABLE_STATE_TRACKING,
     TUNE_ONLY_STORE_CHECKPOINT_SCORE_ATTRIBUTE,
     TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S,
+    JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S,
 }
 
 # Key for AIR Checkpoint metadata in TrainingResult metadata
diff --git a/python/ray/train/tests/test_backend.py b/python/ray/train/tests/test_backend.py
index 7ea35b11e6bc..bc957d62bb95 100644
--- a/python/ray/train/tests/test_backend.py
+++ b/python/ray/train/tests/test_backend.py
@@ -28,10 +28,12 @@
 from ray.train.constants import (
     ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV,
     ENABLE_SHARE_NEURON_CORES_ACCELERATOR_ENV,
+    JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S,
     TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S,
     TRAIN_ENABLE_WORKER_SPREAD_ENV,
 )
 from ray.train.torch import TorchConfig
+from ray.train.v2.jax.config import JaxConfig
 from ray.util.placement_group import get_current_placement_group
 from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
 from ray.util.state import list_actors
@@ -631,6 +633,21 @@ def test():
     assert worker_result != placement_group.id
 
 
+@pytest.mark.parametrize("timeout_s", [5, 0])
+def test_jax_distributed_shutdown_timeout(ray_start_2_cpus, monkeypatch, timeout_s):
+    monkeypatch.setenv(JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S, timeout_s)
+    monkeypatch.setenv("JAX_PLATFORMS", "tpu")
+    jax_config = JaxConfig(use_tpu=True)
+    e = BackendExecutor(jax_config, num_workers=2)
+    e.start()
+
+    _start_training(e, lambda: 1)
+    assert e.finish_training() == [1, 1]
+
+    # Verify that we do not raise an exception even if we time out
+    e._backend.on_shutdown(e.worker_group, e._backend_config)
+
+
 if __name__ == "__main__":
     import sys
 
diff --git a/python/ray/train/v2/jax/config.py b/python/ray/train/v2/jax/config.py
index 7a8f955122b6..9b3f5c7d5cdc 100644
--- a/python/ray/train/v2/jax/config.py
+++ b/python/ray/train/v2/jax/config.py
@@ -3,9 +3,14 @@
 from dataclasses import dataclass
 
 import ray
+from ray._private import ray_constants
 from ray.train._internal.utils import get_address_and_port
 from ray.train._internal.worker_group import WorkerGroup
 from ray.train.backend import Backend, BackendConfig
+from ray.train.constants import (
+    DEFAULT_JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S,
+    JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S,
+)
 from ray.util import PublicAPI
 
 logger = logging.getLogger(__name__)
@@ -80,7 +85,10 @@ def on_shutdown(self, worker_group: WorkerGroup, backend_config: JaxConfig):
         # Shutdown JAX distributed on all workers
         shutdown_futures = worker_group.execute_async(_shutdown_jax_distributed)
 
-        timeout_s = 30
+        timeout_s = ray_constants.env_integer(
+            JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S,
+            DEFAULT_JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S,
+        )
         try:
             ray.get(shutdown_futures, timeout=timeout_s)
             logger.debug("JAX distributed shutdown completed")

From 477407d7a124ec0e9ed57ea6f4fa021fdffc5757 Mon Sep 17 00:00:00 2001
From: Lehui Liu
Date: Fri, 17 Oct 2025 14:53:34 -0700
Subject: [PATCH 6/7] add jax backend timeout

Signed-off-by: Lehui Liu
---
 python/ray/train/tests/test_backend.py | 30 ++++++++++++++++++++++++--
 1 file changed, 28 insertions(+), 2 deletions(-)

diff --git a/python/ray/train/tests/test_backend.py b/python/ray/train/tests/test_backend.py
index bc957d62bb95..7718a090f11a 100644
--- a/python/ray/train/tests/test_backend.py
+++ b/python/ray/train/tests/test_backend.py
@@ -633,10 +633,36 @@ def test():
     assert worker_result != placement_group.id
 
 
+def test_jax_start_shutdown(ray_start_2_cpus, monkeypatch):
+    """Test that JAX distributed is properly initialized and shutdown."""
+    monkeypatch.setenv("JAX_PLATFORMS", "tpu")
+    jax_config = JaxConfig(use_tpu=True)
+    e = BackendExecutor(jax_config, num_workers=2)
+    e.start()
+
+    def check_jax_distributed():
+        try:
+            import jax
+
+            # Check if JAX distributed is initialized
+            return jax.process_count() == 2
+        except Exception:
+            # If jax is not available or not initialized, return False
+            return False
+
+    _start_training(e, check_jax_distributed)
+    assert all(e.finish_training())
+
+    e._backend.on_shutdown(e.worker_group, e._backend_config)
+
+    _start_training(e, check_jax_distributed)
+    assert not any(e.finish_training())
+
+
 @pytest.mark.parametrize("timeout_s", [5, 0])
 def test_jax_distributed_shutdown_timeout(ray_start_2_cpus, monkeypatch, timeout_s):
-    monkeypatch.setenv(JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S, timeout_s)
-    monkeypatch.setenv("JAX_PLATFORMS", "tpu")
+    """Test that JAX distributed shutdown respects the timeout env var."""
+    monkeypatch.setenv(JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S, str(timeout_s))
     jax_config = JaxConfig(use_tpu=True)
     e = BackendExecutor(jax_config, num_workers=2)
     e.start()

From 3c1e0a8a721307aad98a126dece8818010e9405d Mon Sep 17 00:00:00 2001
From: Lehui Liu
Date: Fri, 17 Oct 2025 15:50:06 -0700
Subject: [PATCH 7/7] address comments

Signed-off-by: Lehui Liu
---
 python/ray/train/tests/test_backend.py | 26 --------------------------
 1 file changed, 26 deletions(-)

diff --git a/python/ray/train/tests/test_backend.py b/python/ray/train/tests/test_backend.py
index 7718a090f11a..a33713d45578 100644
--- a/python/ray/train/tests/test_backend.py
+++ b/python/ray/train/tests/test_backend.py
@@ -633,32 +633,6 @@ def test():
     assert worker_result != placement_group.id
 
 
-def test_jax_start_shutdown(ray_start_2_cpus, monkeypatch):
-    """Test that JAX distributed is properly initialized and shutdown."""
-    monkeypatch.setenv("JAX_PLATFORMS", "tpu")
-    jax_config = JaxConfig(use_tpu=True)
-    e = BackendExecutor(jax_config, num_workers=2)
-    e.start()
-
-    def check_jax_distributed():
-        try:
-            import jax
-
-            # Check if JAX distributed is initialized
-            return jax.process_count() == 2
-        except Exception:
-            # If jax is not available or not initialized, return False
-            return False
-
-    _start_training(e, check_jax_distributed)
-    assert all(e.finish_training())
-
-    e._backend.on_shutdown(e.worker_group, e._backend_config)
-
-    _start_training(e, check_jax_distributed)
-    assert not any(e.finish_training())
-
-
 @pytest.mark.parametrize("timeout_s", [5, 0])
 def test_jax_distributed_shutdown_timeout(ray_start_2_cpus, monkeypatch, timeout_s):
     """Test that JAX distributed shutdown respects the timeout env var."""
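
The net effect of this series: _JaxBackend.on_shutdown fans _shutdown_jax_distributed out to every worker and waits up to JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S seconds (default 30), logging rather than raising on timeout or error. A minimal sketch of exercising the new knob end to end, assuming a TPU-enabled Ray cluster; the JaxTrainer import path, storage path, and the train_loop_per_worker body are illustrative assumptions, while the environment variable name, ScalingConfig arguments, and shutdown behavior come from the patches above:

    import os

    # Assumed usage: read by _JaxBackend.on_shutdown via
    # ray_constants.env_integer() (PATCH 5/7), so it must be set before
    # the worker group is torn down.
    os.environ["JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S"] = "10"

    from ray.train import RunConfig, ScalingConfig
    from ray.train.v2.jax import JaxTrainer  # import path assumed


    def train_loop_per_worker():
        # jax.distributed.initialize() has already run in the backend's
        # on_start hook, so process_count() reflects the full worker group.
        import jax

        from ray import train

        train.report({"process_count": jax.process_count()})


    trainer = JaxTrainer(
        train_loop_per_worker=train_loop_per_worker,
        scaling_config=ScalingConfig(
            num_workers=2, use_tpu=True, resources_per_worker={"TPU": 4}
        ),
        run_config=RunConfig(storage_path="/tmp/jax_shutdown_demo"),  # hypothetical path
    )
    result = trainer.fit()
    # By the time fit() returns, on_shutdown has invoked
    # jax.distributed.shutdown() on each worker, or logged a warning if the
    # 10 s timeout was exceeded.

Note that the timeout is deliberately non-fatal: a hung worker produces a warning instead of failing the run, which is exactly the behavior test_jax_distributed_shutdown_timeout pins down with timeout_s=0.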