
Commit 6e1f067

liulehui authored and elliot-barn committed
[train][jax_trainer] add jax.distributed.shutdown() for JaxBackend (#57802)
## Description

1. This PR adds `jax.distributed.shutdown()` to the JaxBackend in order to free up any leaked resources on TPU RayTrainWorkers.
2. If `jax.distributed` was never initialized, the call is a no-op: https://docs.jax.dev/en/latest/_autosummary/jax.distributed.shutdown.html
3. Tested on an Anyscale workspace (screenshot: https://github.com/user-attachments/assets/f28102ff-f6d1-4da0-b41a-6cc785603e72).

Signed-off-by: elliot-barn <[email protected]>
1 parent 44dff2b commit 6e1f067
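
The safety claim in point 2 can be seen in a minimal sketch (not part of the PR; assumes `jax` is installed):

```python
# Minimal sketch: per the linked JAX docs, jax.distributed.shutdown()
# does nothing if the distributed runtime was never initialized, so the
# backend can call it unconditionally on every worker.
import jax

jax.distributed.shutdown()  # no-op without a prior initialize()
print("shutdown returned cleanly")
```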

File tree

3 files changed: +64 -0 lines changed

python/ray/train/constants.py
python/ray/train/tests/test_backend.py
python/ray/train/v2/jax/config.py

python/ray/train/constants.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -132,6 +132,10 @@ def _v2_migration_warnings_enabled() -> bool:
 TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S = "TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S"
 DEFAULT_TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S = 30
 
+# Seconds to wait for JAX distributed shutdown.
+JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S = "JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S"
+DEFAULT_JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S = 30
+
 # NOTE: When adding a new environment variable, please track it in this list.
 TRAIN_ENV_VARS = {
     ENABLE_DETAILED_AUTOFILLED_METRICS_ENV,
@@ -144,6 +148,7 @@ def _v2_migration_warnings_enabled() -> bool:
     RAY_TRAIN_ENABLE_STATE_TRACKING,
     TUNE_ONLY_STORE_CHECKPOINT_SCORE_ATTRIBUTE,
     TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S,
+    JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S,
 }
 
 # Key for AIR Checkpoint metadata in TrainingResult metadata
```
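
As a usage note, here is a hedged sketch of how these constants resolve at runtime via Ray's internal `ray_constants.env_integer` helper (the same helper the backend uses in the config.py diff below; the standalone snippet is illustrative):

```python
# Illustrative only: the env var, when set, overrides the 30-second
# default shutdown timeout.
import os

from ray._private import ray_constants

os.environ["JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S"] = "60"
timeout_s = ray_constants.env_integer("JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S", 30)
assert timeout_s == 60  # the env var wins over the default
```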

python/ray/train/tests/test_backend.py

Lines changed: 17 additions & 0 deletions
```diff
@@ -28,10 +28,12 @@
 from ray.train.constants import (
     ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV,
     ENABLE_SHARE_NEURON_CORES_ACCELERATOR_ENV,
+    JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S,
     TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S,
     TRAIN_ENABLE_WORKER_SPREAD_ENV,
 )
 from ray.train.torch import TorchConfig
+from ray.train.v2.jax.config import JaxConfig
 from ray.util.placement_group import get_current_placement_group
 from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
 from ray.util.state import list_actors
@@ -631,6 +633,21 @@ def test():
     assert worker_result != placement_group.id
 
 
+@pytest.mark.parametrize("timeout_s", [5, 0])
+def test_jax_distributed_shutdown_timeout(ray_start_2_cpus, monkeypatch, timeout_s):
+    """Test that JAX distributed shutdown respects the timeout env var."""
+    monkeypatch.setenv(JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S, str(timeout_s))
+    jax_config = JaxConfig(use_tpu=True)
+    e = BackendExecutor(jax_config, num_workers=2)
+    e.start()
+
+    _start_training(e, lambda: 1)
+    assert e.finish_training() == [1, 1]
+
+    # Verify that we do not raise an exception even if we time out
+    e._backend.on_shutdown(e.worker_group, e._backend_config)
+
+
 if __name__ == "__main__":
     import sys
 
```
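
The `timeout_s=0` case is the one that exercises the timeout branch: in recent Ray versions, `ray.get` with `timeout=0` raises `GetTimeoutError` immediately when results are not already available. A self-contained sketch of that behavior (illustrative, assumes a local Ray install):

```python
# Sketch of what the timeout_s=0 parametrization relies on: a zero
# timeout makes ray.get raise GetTimeoutError right away if the task
# has not finished; on_shutdown is expected to catch it and only warn.
import time

import ray

@ray.remote
def slow_cleanup():
    time.sleep(10)  # stand-in for a hung worker

try:
    ray.get(slow_cleanup.remote(), timeout=0)
except ray.exceptions.GetTimeoutError:
    print("timed out immediately, as the timeout_s=0 case expects")
```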
python/ray/train/v2/jax/config.py

Lines changed: 42 additions & 0 deletions
```diff
@@ -3,9 +3,14 @@
 from dataclasses import dataclass
 
 import ray
+from ray._private import ray_constants
 from ray.train._internal.utils import get_address_and_port
 from ray.train._internal.worker_group import WorkerGroup
 from ray.train.backend import Backend, BackendConfig
+from ray.train.constants import (
+    DEFAULT_JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S,
+    JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S,
+)
 from ray.util import PublicAPI
 
 logger = logging.getLogger(__name__)
@@ -36,6 +41,20 @@ def _setup_jax_tpu_environment(
     jax.distributed.initialize(master_addr_with_port, num_workers, index)
 
 
+def _shutdown_jax_distributed():
+    """Shutdown JAX distributed environment.
+
+    This function should be called on each worker during cleanup.
+    If JAX distributed was not initialized, this is a no-op.
+    """
+    try:
+        import jax
+
+        jax.distributed.shutdown()
+    except Exception as e:
+        logger.warning(f"Error during JAX distributed shutdown: {e}")
+
+
 class _JaxBackend(Backend):
     def on_start(self, worker_group: WorkerGroup, backend_config: JaxConfig):
         if not backend_config.use_tpu:
@@ -57,3 +76,26 @@ def on_start(self, worker_group: WorkerGroup, backend_config: JaxConfig):
             )
         )
         ray.get(setup_futures)
+
+    def on_shutdown(self, worker_group: WorkerGroup, backend_config: JaxConfig):
+        """Cleanup JAX distributed resources when shutting down worker group."""
+        if not backend_config.use_tpu:
+            return
+
+        # Shutdown JAX distributed on all workers
+        shutdown_futures = worker_group.execute_async(_shutdown_jax_distributed)
+
+        timeout_s = ray_constants.env_integer(
+            JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S,
+            DEFAULT_JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S,
+        )
+        try:
+            ray.get(shutdown_futures, timeout=timeout_s)
+            logger.debug("JAX distributed shutdown completed")
+        except ray.exceptions.GetTimeoutError:
+            logger.warning(
+                f"JAX distributed shutdown timed out after {timeout_s} seconds. "
+                "This may indicate workers are hung or unresponsive."
+            )
+        except Exception as e:
+            logger.warning(f"Error during JAX distributed shutdown: {e}")
```
