Skip to content
22 changes: 20 additions & 2 deletions python/ray/train/v2/jax/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,22 @@ def backend_cls(self):
return _JaxBackend


def _setup_jax_tpu_environment(
def _set_jax_env_vars(use_tpu: bool):
    """Configure JAX-related environment variables for the current process.

    A non-empty ``JAX_PLATFORMS`` value that is already present in the
    environment always wins and is left untouched. When no such value exists
    and ``use_tpu`` is true, ``JAX_PLATFORMS`` is set to ``"tpu"``.

    Args:
        use_tpu: Whether the TPU platform should be selected.
    """
    # An unset or empty JAX_PLATFORMS counts as "not configured".
    already_configured = bool(os.environ.get("JAX_PLATFORMS"))
    if use_tpu and not already_configured:
        os.environ["JAX_PLATFORMS"] = "tpu"


def _setup_jax_distributed_environment(
master_addr_with_port: str, num_workers: int, index: int
):
"""Set up distributed Jax training information.
Expand Down Expand Up @@ -60,6 +75,9 @@ def on_start(self, worker_group: WorkerGroup, backend_config: JaxConfig):
if not backend_config.use_tpu:
return

# Set JAX environment variables on all workers
worker_group.execute(_set_jax_env_vars, use_tpu=backend_config.use_tpu)
Copy link
Contributor

@matthewdeng matthewdeng Nov 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to call this in a separate `execute` call? Could we instead call this directly in `_setup_jax_distributed_environment` to avoid the overhead of an additional blocking method call?


master_addr, master_port = worker_group.execute_single(0, get_address_and_port)
master_addr_with_port = f"{master_addr}:{master_port}"

Expand All @@ -69,7 +87,7 @@ def on_start(self, worker_group: WorkerGroup, backend_config: JaxConfig):
setup_futures.append(
worker_group.execute_single_async(
i,
_setup_jax_tpu_environment,
_setup_jax_distributed_environment,
master_addr_with_port=master_addr_with_port,
num_workers=len(worker_group),
index=i,
Expand Down
10 changes: 10 additions & 0 deletions python/ray/train/v2/tests/test_jax_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ def test_minimal_singlehost(ray_tpu_single_host, tmp_path):
),
run_config=RunConfig(
storage_path=str(tmp_path),
worker_runtime_env={
"env_vars": {
"JAX_PLATFORMS": "cpu",
},
},
),
)
result = trainer.fit()
Expand All @@ -108,6 +113,11 @@ def test_minimal_multihost(ray_tpu_multi_host, tmp_path):
),
run_config=RunConfig(
storage_path=str(tmp_path),
worker_runtime_env={
"env_vars": {
"JAX_PLATFORMS": "cpu",
},
},
),
)
result = trainer.fit()
Expand Down