33 changes: 26 additions & 7 deletions python/ray/train/v2/jax/config.py
@@ -26,16 +26,33 @@ def backend_cls(self):
         return _JaxBackend


-def _setup_jax_tpu_environment(
-    master_addr_with_port: str, num_workers: int, index: int
+def _setup_jax_distributed_environment(
+    master_addr_with_port: str, num_workers: int, index: int, use_tpu: bool
 ):
     """Set up distributed Jax training information.

-    This function should be called on each worker.
+    This function should be called on each worker. It sets JAX environment
+    variables and initializes JAX distributed training.
+
+    Args:
+        master_addr_with_port: The master address with port for coordination.
+        num_workers: Total number of workers.
+        index: Index of this worker.
+        use_tpu: Whether to configure for TPU. If True and JAX_PLATFORMS is not
+            already set, it will be set to "tpu".
     """
-    import jax
+    # If user/test already set JAX_PLATFORMS, respect their choice
+    if os.environ.get("JAX_PLATFORMS"):
+        jax_platforms = os.environ.get("JAX_PLATFORMS", "").lower()
+    else:
+        # Set JAX_PLATFORMS based on configuration
+        if use_tpu:
+            os.environ["JAX_PLATFORMS"] = "tpu"
+            jax_platforms = "tpu"
+        else:
+            jax_platforms = ""

-    jax_platforms = os.environ.get("JAX_PLATFORMS", "").lower()
+    import jax

     if "tpu" in jax_platforms.split(","):
         jax.distributed.initialize(master_addr_with_port, num_workers, index)
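The key behavior in this hunk is precedence: a `JAX_PLATFORMS` value that is already set (for example, `cpu` from a test) always wins, and `use_tpu` only applies when nothing is set. Below is a minimal sketch of that resolution logic in isolation; the helper name `_resolve_jax_platforms` is hypothetical, for illustration only:

```python
import os


def _resolve_jax_platforms(use_tpu: bool) -> str:
    """Mirror the precedence above: a preset JAX_PLATFORMS is respected;
    otherwise it is set to "tpu" only when use_tpu is True."""
    if os.environ.get("JAX_PLATFORMS"):
        return os.environ["JAX_PLATFORMS"].lower()
    if use_tpu:
        os.environ["JAX_PLATFORMS"] = "tpu"
        return "tpu"
    return ""


# A preset value wins even when use_tpu=True.
os.environ["JAX_PLATFORMS"] = "cpu"
assert _resolve_jax_platforms(use_tpu=True) == "cpu"
```

Note that the diff also moves `import jax` until after `JAX_PLATFORMS` is settled, so the variable is in place before JAX initializes a backend. The positional arguments to `jax.distributed.initialize` are the coordinator address, total process count, and this worker's process index (its `coordinator_address`, `num_processes`, and `process_id` parameters).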
@@ -63,16 +80,18 @@ def on_start(self, worker_group: WorkerGroup, backend_config: JaxConfig):
         master_addr, master_port = worker_group.execute_single(0, get_address_and_port)
         master_addr_with_port = f"{master_addr}:{master_port}"

-        # Get setup tasks in order to throw errors on failure.
+        # Set up the JAX distributed environment on all workers. This sets
+        # the JAX_PLATFORMS env var and initializes JAX distributed training.
         setup_futures = []
         for i in range(len(worker_group)):
             setup_futures.append(
                 worker_group.execute_single_async(
                     i,
-                    _setup_jax_tpu_environment,
+                    _setup_jax_distributed_environment,
                     master_addr_with_port=master_addr_with_port,
                     num_workers=len(worker_group),
                     index=i,
+                    use_tpu=backend_config.use_tpu,
                 )
             )
         ray.get(setup_futures)
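`on_start` resolves the coordinator address once on worker 0, then fans the setup out to every worker and gathers the futures with `ray.get`, which re-raises the first worker-side exception; a failed `jax.distributed.initialize` on any worker therefore fails `on_start` loudly rather than hanging. A self-contained sketch of that gather-and-raise behavior; the `setup` task below is a hypothetical stand-in, not part of the Ray Train code:

```python
import ray

ray.init()


@ray.remote
def setup(index: int) -> str:
    # Stand-in for _setup_jax_distributed_environment on worker `index`.
    if index == 1:
        raise RuntimeError(f"worker {index}: initialize failed")
    return f"worker {index} ready"


futures = [setup.remote(i) for i in range(3)]
try:
    ray.get(futures)  # re-raises the worker-side RuntimeError
except ray.exceptions.RayTaskError as err:
    print(err)  # wraps: RuntimeError: worker 1: initialize failed
```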
10 changes: 10 additions & 0 deletions python/ray/train/v2/tests/test_jax_trainer.py
@@ -83,6 +83,11 @@ def test_minimal_singlehost(ray_tpu_single_host, tmp_path):
         ),
         run_config=RunConfig(
             storage_path=str(tmp_path),
+            worker_runtime_env={
+                "env_vars": {
+                    "JAX_PLATFORMS": "cpu",
+                },
+            },
         ),
     )
     result = trainer.fit()
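Pinning `JAX_PLATFORMS=cpu` through the worker runtime env exercises the precedence described above: the backend respects the preset value, so `use_tpu` cannot override it and the TPU-only `jax.distributed.initialize` call is skipped. A quick sketch of the condition the backend checks, assuming the variable is set before setup runs:

```python
import os

# The test's runtime env sets this before the backend's setup function runs.
os.environ["JAX_PLATFORMS"] = "cpu"

jax_platforms = os.environ.get("JAX_PLATFORMS", "").lower()
# "tpu" is not among the configured platforms, so initialize() is skipped.
assert "tpu" not in jax_platforms.split(",")
```

The multihost test below applies the same override.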
@@ -108,6 +113,11 @@ def test_minimal_multihost(ray_tpu_multi_host, tmp_path):
         ),
         run_config=RunConfig(
             storage_path=str(tmp_path),
+            worker_runtime_env={
+                "env_vars": {
+                    "JAX_PLATFORMS": "cpu",
+                },
+            },
         ),
     )
     result = trainer.fit()