Merged
Changes from all commits
5 changes: 5 additions & 0 deletions python/ray/train/constants.py
@@ -132,6 +132,10 @@ def _v2_migration_warnings_enabled() -> bool:
TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S = "TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S"
DEFAULT_TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S = 30

# Seconds to wait for JAX distributed shutdown.
JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S = "JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S"
DEFAULT_JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S = 30

# NOTE: When adding a new environment variable, please track it in this list.
TRAIN_ENV_VARS = {
ENABLE_DETAILED_AUTOFILLED_METRICS_ENV,
@@ -144,6 +148,7 @@ def _v2_migration_warnings_enabled() -> bool:
RAY_TRAIN_ENABLE_STATE_TRACKING,
TUNE_ONLY_STORE_CHECKPOINT_SCORE_ATTRIBUTE,
TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S,
JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S,
}

# Key for AIR Checkpoint metadata in TrainingResult metadata
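For context, the constant's value is the environment variable's name, so the shutdown wait can be tuned without code changes. A minimal sketch of overriding it from the driver script (the 60-second value is an arbitrary example, not a recommendation):

import os

# Set in the process that drives training (the test below does the same via
# monkeypatch.setenv). The backend resolves it at shutdown time with
# ray_constants.env_integer(JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S,
# DEFAULT_JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S), falling back to 30 when unset.
os.environ["JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S"] = "60"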
17 changes: 17 additions & 0 deletions python/ray/train/tests/test_backend.py
@@ -28,10 +28,12 @@
from ray.train.constants import (
ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV,
ENABLE_SHARE_NEURON_CORES_ACCELERATOR_ENV,
JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S,
TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S,
TRAIN_ENABLE_WORKER_SPREAD_ENV,
)
from ray.train.torch import TorchConfig
from ray.train.v2.jax.config import JaxConfig
from ray.util.placement_group import get_current_placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from ray.util.state import list_actors
@@ -631,6 +633,21 @@ def test():
assert worker_result != placement_group.id


@pytest.mark.parametrize("timeout_s", [5, 0])
def test_jax_distributed_shutdown_timeout(ray_start_2_cpus, monkeypatch, timeout_s):
"""Test that JAX distributed shutdown respects the timeout env var."""
monkeypatch.setenv(JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S, str(timeout_s))
jax_config = JaxConfig(use_tpu=True)
e = BackendExecutor(jax_config, num_workers=2)
e.start()

_start_training(e, lambda: 1)
assert e.finish_training() == [1, 1]

# Verify that we do not raise an exception even if we time out
e._backend.on_shutdown(e.worker_group, e._backend_config)


if __name__ == "__main__":
import sys

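Aside: monkeypatch.setenv works here because the backend re-reads the variable through ray_constants.env_integer at shutdown time. A simplified stand-in for that helper, under the assumption that it parses the variable as an integer and falls back to the default when unset:

import os

def env_integer(key: str, default: int) -> int:
    # Simplified stand-in for ray._private.ray_constants.env_integer;
    # the real helper also guards against non-numeric values.
    value = os.environ.get(key)
    return int(value) if value is not None else default

assert env_integer("JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S", 30) == 30  # unset -> default
os.environ["JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S"] = "5"
assert env_integer("JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S", 30) == 5   # env var wins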
42 changes: 42 additions & 0 deletions python/ray/train/v2/jax/config.py
Contributor:
Any way we could add a unit test for this?

Contributor Author:
I tried a few ways, including using caplog and checking jax.process_count(), but JAX does not seem to support the CPU distributed environment well. I am adding a release test for it instead; would it be OK to check it there? #57815

@@ -3,9 +3,14 @@
from dataclasses import dataclass

import ray
from ray._private import ray_constants
from ray.train._internal.utils import get_address_and_port
from ray.train._internal.worker_group import WorkerGroup
from ray.train.backend import Backend, BackendConfig
from ray.train.constants import (
DEFAULT_JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S,
JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S,
)
from ray.util import PublicAPI

logger = logging.getLogger(__name__)
@@ -36,6 +41,20 @@ def _setup_jax_tpu_environment(
jax.distributed.initialize(master_addr_with_port, num_workers, index)


def _shutdown_jax_distributed():
"""Shutdown JAX distributed environment.

This function should be called on each worker during cleanup.
If JAX distributed was not initialized, this is a no-op.
"""
try:
import jax

jax.distributed.shutdown()
except Exception as e:
logger.warning(f"Error during JAX distributed shutdown: {e}")


class _JaxBackend(Backend):
def on_start(self, worker_group: WorkerGroup, backend_config: JaxConfig):
if not backend_config.use_tpu:
@@ -57,3 +76,26 @@ def on_start(self, worker_group: WorkerGroup, backend_config: JaxConfig):
)
)
ray.get(setup_futures)

def on_shutdown(self, worker_group: WorkerGroup, backend_config: JaxConfig):
"""Cleanup JAX distributed resources when shutting down worker group."""
if not backend_config.use_tpu:
return

# Shutdown JAX distributed on all workers
shutdown_futures = worker_group.execute_async(_shutdown_jax_distributed)

timeout_s = ray_constants.env_integer(
JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S,
DEFAULT_JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S,
)
try:
ray.get(shutdown_futures, timeout=timeout_s)
logger.debug("JAX distributed shutdown completed")
except ray.exceptions.GetTimeoutError:
logger.warning(
f"JAX distributed shutdown timed out after {timeout_s} seconds. "
"This may indicate workers are hung or unresponsive."
)
except Exception as e:
logger.warning(f"Error during JAX distributed shutdown: {e}")