python/ray/train/v2/jax/config.py (24 additions, 7 deletions)
@@ -26,17 +26,32 @@ def backend_cls(self):
         return _JaxBackend
 
 
-def _setup_jax_tpu_environment(
-    master_addr_with_port: str, num_workers: int, index: int
+def _setup_jax_distributed_environment(
+    master_addr_with_port: str, num_workers: int, index: int, use_tpu: bool
 ):
     """Set up distributed Jax training information.
 
-    This function should be called on each worker.
-    """
-    import jax
+    This function should be called on each worker. It sets JAX environment
+    variables and initializes JAX distributed training.
+
+    Args:
+        master_addr_with_port: The master address with port for coordination.
+        num_workers: Total number of workers.
+        index: Index of this worker.
+        use_tpu: Whether to configure for TPU. If True and JAX_PLATFORMS is not
+            already set, it will be set to "tpu".
+    """
+    # Get JAX_PLATFORMS from environment if already set
+    jax_platforms = os.environ.get("JAX_PLATFORMS", "").lower()
+
+    if not jax_platforms and use_tpu:
+        os.environ["JAX_PLATFORMS"] = "tpu"
+        jax_platforms = "tpu"
+
+    # TODO(lehui): Add env vars for JAX on GPU.
+
+    import jax
 
-    jax.distributed.initialize(master_addr_with_port, num_workers, index)
+    if "tpu" in jax_platforms.split(","):
+        jax.distributed.initialize(master_addr_with_port, num_workers, index)

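The gating above gives JAX_PLATFORMS a clear precedence order. Below is a minimal sketch of that behavior using a hypothetical `_select_platform` helper that operates on a plain dict (not part of the PR) so it can run anywhere:

```python
def _select_platform(use_tpu: bool, env: dict) -> str:
    """Mirrors the gating in _setup_jax_distributed_environment."""
    jax_platforms = env.get("JAX_PLATFORMS", "").lower()
    if not jax_platforms and use_tpu:
        env["JAX_PLATFORMS"] = "tpu"
        jax_platforms = "tpu"
    return jax_platforms

# An explicitly set value always wins, even with use_tpu=True.
assert _select_platform(use_tpu=True, env={"JAX_PLATFORMS": "cpu"}) == "cpu"
# With nothing set, use_tpu=True fills in the TPU default.
assert _select_platform(use_tpu=True, env={}) == "tpu"
# With nothing set and use_tpu=False, JAX keeps its own platform defaults.
assert _select_platform(use_tpu=False, env={}) == ""
# Comma-separated lists are honored, hence the split(",") membership check.
assert "tpu" in _select_platform(use_tpu=True, env={"JAX_PLATFORMS": "cpu,tpu"}).split(",")
```

Because the check is membership in the split list, a value like "cpu,tpu" still triggers `jax.distributed.initialize`.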
@@ -63,16 +78,18 @@ def on_start(self, worker_group: WorkerGroup, backend_config: JaxConfig):
         master_addr, master_port = worker_group.execute_single(0, get_address_and_port)
         master_addr_with_port = f"{master_addr}:{master_port}"
 
-        # Get setup tasks in order to throw errors on failure.
+        # Set up the JAX distributed environment on all workers. This sets the
+        # JAX_PLATFORMS env var and initializes JAX distributed training.
         setup_futures = []
         for i in range(len(worker_group)):
             setup_futures.append(
                 worker_group.execute_single_async(
                     i,
-                    _setup_jax_tpu_environment,
+                    _setup_jax_distributed_environment,
                     master_addr_with_port=master_addr_with_port,
                     num_workers=len(worker_group),
                     index=i,
                     use_tpu=backend_config.use_tpu,
                 )
             )
         ray.get(setup_futures)
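For reference, the call each worker ends up making is JAX's standard multi-process initialization. Outside of Ray the equivalent would be roughly the following, where the address, world size, and rank values are placeholders:

```python
import jax

# Worker 0's address doubles as the coordinator; every process passes the
# same address and world size, plus its own rank.
jax.distributed.initialize(
    coordinator_address="10.0.0.1:1234",  # master_addr_with_port
    num_processes=4,                      # len(worker_group)
    process_id=0,                         # this worker's index
)
```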
python/ray/train/v2/tests/test_jax_trainer.py (10 additions, 0 deletions)
@@ -83,6 +83,11 @@ def test_minimal_singlehost(ray_tpu_single_host, tmp_path):
         ),
         run_config=RunConfig(
             storage_path=str(tmp_path),
+            worker_runtime_env={
+                "env_vars": {
+                    "JAX_PLATFORMS": "cpu",
+                },
+            },
         ),
     )
     result = trainer.fit()
@@ -108,6 +113,11 @@ def test_minimal_multihost(ray_tpu_multi_host, tmp_path):
         ),
         run_config=RunConfig(
             storage_path=str(tmp_path),
+            worker_runtime_env={
+                "env_vars": {
+                    "JAX_PLATFORMS": "cpu",
+                },
+            },
         ),
     )
     result = trainer.fit()
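Pinning JAX_PLATFORMS to "cpu" through `worker_runtime_env` lets these tests run without TPU hardware: since the backend respects a pre-set value, `use_tpu` never overrides it. A hedged sketch of how a training function could verify this (the assertions are illustrative, not from the test file):

```python
import os

def train_func():
    import jax

    # worker_runtime_env set JAX_PLATFORMS before the backend ran, so the
    # use_tpu default never kicked in and JAX enumerates CPU devices only.
    assert os.environ["JAX_PLATFORMS"] == "cpu"
    assert all(d.platform == "cpu" for d in jax.devices())
```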