Skip to content
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions python/ray/train/v2/jax/config.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any way we could add a unit test for this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried a few approaches, including using caplog and checking jax.process_count(),
but it seems JAX does not support the CPU distributed environment well. I am adding a release test for it; would it be OK to check there instead? See #57815

Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,20 @@ def _setup_jax_tpu_environment(
jax.distributed.initialize(master_addr_with_port, num_workers, index)


def _shutdown_jax_distributed():
    """Tear down the JAX distributed runtime on the calling worker.

    Meant to run on every worker as part of cleanup. The call is
    best-effort: any exception raised while importing JAX or while
    invoking ``jax.distributed.shutdown()`` is logged as a warning and
    never propagated to the caller.
    """
    try:
        # Imported lazily, inside the try, so an import failure is also
        # reported through the warning path below instead of raising.
        import jax

        shutdown = jax.distributed.shutdown
        shutdown()
    except Exception as exc:
        message = f"Error during JAX distributed shutdown: {exc}"
        logger.warning(message)


class _JaxBackend(Backend):
def on_start(self, worker_group: WorkerGroup, backend_config: JaxConfig):
if not backend_config.use_tpu:
Expand All @@ -57,3 +71,23 @@ def on_start(self, worker_group: WorkerGroup, backend_config: JaxConfig):
)
)
ray.get(setup_futures)

def on_shutdown(self, worker_group: WorkerGroup, backend_config: JaxConfig):
    """Release JAX distributed resources when the worker group shuts down.

    Does nothing unless ``backend_config.use_tpu`` is set. Otherwise
    ``_shutdown_jax_distributed`` is dispatched to the worker group via
    ``execute_async`` and the results are awaited for a bounded time.
    A timeout or any other error is logged as a warning and not raised.
    """
    if not backend_config.use_tpu:
        return

    # Upper bound, in seconds, on how long to wait for the workers.
    wait_limit_s = 30

    # Dispatch the shutdown routine to the worker group.
    pending = worker_group.execute_async(_shutdown_jax_distributed)

    try:
        ray.get(pending, timeout=wait_limit_s)
        logger.debug("JAX distributed shutdown completed")
    except ray.exceptions.GetTimeoutError:
        timeout_message = (
            f"JAX distributed shutdown timed out after {wait_limit_s} seconds. "
            "This may indicate workers are hung or unresponsive."
        )
        logger.warning(timeout_message)
    except Exception as exc:
        logger.warning(f"Error during JAX distributed shutdown: {exc}")