47 changes: 24 additions & 23 deletions python/ray/train/v2/jax/config.py
@@ -26,31 +26,33 @@ def backend_cls(self):
return _JaxBackend


def _set_jax_env_vars(use_tpu: bool):
"""Set JAX environment variables based on configuration.

If JAX_PLATFORMS is already set (by user or test), we trust that configuration
and do nothing. Otherwise, if use_tpu=True, we set it to "tpu".
"""
# If user/test already set JAX_PLATFORMS, respect their choice
if os.environ.get("JAX_PLATFORMS"):
return

# Only set JAX_PLATFORMS if not already specified
if use_tpu:
os.environ["JAX_PLATFORMS"] = "tpu"


def _setup_jax_distributed_environment(
master_addr_with_port: str, num_workers: int, index: int
master_addr_with_port: str, num_workers: int, index: int, use_tpu: bool
):
"""Set up distributed Jax training information.

This function should be called on each worker.
This function should be called on each worker. It sets JAX environment
variables and initializes JAX distributed training.

Args:
master_addr_with_port: The master address with port for coordination.
num_workers: Total number of workers.
index: Index of this worker.
use_tpu: Whether to configure for TPU. If True and JAX_PLATFORMS is not
already set, it will be set to "tpu".
"""
import jax
# If user/test already set JAX_PLATFORMS, respect their choice
if os.environ.get("JAX_PLATFORMS"):
jax_platforms = os.environ.get("JAX_PLATFORMS", "").lower()
else:
# Set JAX_PLATFORMS based on configuration
if use_tpu:
os.environ["JAX_PLATFORMS"] = "tpu"
jax_platforms = "tpu"
else:
jax_platforms = ""

jax_platforms = os.environ.get("JAX_PLATFORMS", "").lower()
import jax

if "tpu" in jax_platforms.split(","):
jax.distributed.initialize(master_addr_with_port, num_workers, index)
@@ -75,13 +77,11 @@ def on_start(self, worker_group: WorkerGroup, backend_config: JaxConfig):
if not backend_config.use_tpu:
return

# Set JAX environment variables on all workers
worker_group.execute(_set_jax_env_vars, use_tpu=backend_config.use_tpu)

master_addr, master_port = worker_group.execute_single(0, get_address_and_port)
master_addr_with_port = f"{master_addr}:{master_port}"

# Get setup tasks in order to throw errors on failure.
# Set up the JAX distributed environment on all workers.
# This sets the JAX_PLATFORMS env var and initializes JAX distributed training.
setup_futures = []
for i in range(len(worker_group)):
setup_futures.append(
@@ -91,6 +91,7 @@ def on_start(self, worker_group: WorkerGroup, backend_config: JaxConfig):
master_addr_with_port=master_addr_with_port,
num_workers=len(worker_group),
index=i,
use_tpu=backend_config.use_tpu,
)
)
ray.get(setup_futures)
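
For reference, here is a minimal standalone sketch (not part of the diff) of the JAX_PLATFORMS precedence rule that the reworked _setup_jax_distributed_environment implements: a value already present in the environment always wins; otherwise use_tpu=True sets it to "tpu". The resolve_jax_platforms helper below is hypothetical and exists only to illustrate that behavior.

import os

def resolve_jax_platforms(use_tpu: bool) -> str:
    # Hypothetical helper mirroring the precedence logic in this diff.
    # An existing JAX_PLATFORMS value always takes priority over the use_tpu flag.
    if os.environ.get("JAX_PLATFORMS"):
        return os.environ["JAX_PLATFORMS"].lower()
    if use_tpu:
        os.environ["JAX_PLATFORMS"] = "tpu"
        return "tpu"
    return ""

os.environ.pop("JAX_PLATFORMS", None)
assert resolve_jax_platforms(use_tpu=True) == "tpu"

os.environ["JAX_PLATFORMS"] = "cpu"
assert resolve_jax_platforms(use_tpu=True) == "cpu"  # user's explicit choice is respected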
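
A hedged sketch of what each worker could check after jax.distributed.initialize() has run; jax.process_index(), jax.process_count(), jax.local_device_count(), and jax.device_count() are standard JAX APIs, but the snippet assumes it runs inside a worker the backend above has already initialized on a TPU pod.

import jax

# Assumes jax.distributed.initialize() was already called by the backend above.
# Each worker reports its process index plus local and global device counts.
print(f"process {jax.process_index()} of {jax.process_count()}")
print(f"local devices: {jax.local_device_count()}, global devices: {jax.device_count()}")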