diff --git a/python/ray/train/v2/jax/config.py b/python/ray/train/v2/jax/config.py
index 9b3f5c7d5cdc..e826909c2c6f 100644
--- a/python/ray/train/v2/jax/config.py
+++ b/python/ray/train/v2/jax/config.py
@@ -26,17 +26,32 @@ def backend_cls(self):
         return _JaxBackend
 
 
-def _setup_jax_tpu_environment(
-    master_addr_with_port: str, num_workers: int, index: int
+def _setup_jax_distributed_environment(
+    master_addr_with_port: str, num_workers: int, index: int, use_tpu: bool
 ):
     """Set up distributed Jax training information.
 
-    This function should be called on each worker.
-    """
-    import jax
+    This function should be called on each worker. It sets JAX environment
+    variables and initializes JAX distributed training.
+
+    Args:
+        master_addr_with_port: The master address with port for coordination.
+        num_workers: Total number of workers.
+        index: Index of this worker.
+        use_tpu: Whether to configure for TPU. If True and JAX_PLATFORMS is not
+            already set, it will be set to "tpu".
+    """
+    # Get JAX_PLATFORMS from environment if already set
     jax_platforms = os.environ.get("JAX_PLATFORMS", "").lower()
+    if not jax_platforms and use_tpu:
+        os.environ["JAX_PLATFORMS"] = "tpu"
+        jax_platforms = "tpu"
+
+    # TODO(lehui): Add env vars for JAX on GPU.
+
+    import jax
+
     if "tpu" in jax_platforms.split(","):
         jax.distributed.initialize(master_addr_with_port, num_workers, index)
 
@@ -63,16 +78,18 @@ def on_start(self, worker_group: WorkerGroup, backend_config: JaxConfig):
         master_addr, master_port = worker_group.execute_single(0, get_address_and_port)
         master_addr_with_port = f"{master_addr}:{master_port}"
 
-        # Get setup tasks in order to throw errors on failure.
+        # Set up the JAX distributed environment on all workers. This sets the
+        # JAX_PLATFORMS env var and initializes JAX distributed training.
         setup_futures = []
         for i in range(len(worker_group)):
            setup_futures.append(
                 worker_group.execute_single_async(
                     i,
-                    _setup_jax_tpu_environment,
+                    _setup_jax_distributed_environment,
                     master_addr_with_port=master_addr_with_port,
                     num_workers=len(worker_group),
                     index=i,
+                    use_tpu=backend_config.use_tpu,
                 )
             )
         ray.get(setup_futures)
diff --git a/python/ray/train/v2/tests/test_jax_trainer.py b/python/ray/train/v2/tests/test_jax_trainer.py
index 4f8db5cf8dd8..976f3d507819 100644
--- a/python/ray/train/v2/tests/test_jax_trainer.py
+++ b/python/ray/train/v2/tests/test_jax_trainer.py
@@ -83,6 +83,11 @@ def test_minimal_singlehost(ray_tpu_single_host, tmp_path):
         ),
         run_config=RunConfig(
             storage_path=str(tmp_path),
+            worker_runtime_env={
+                "env_vars": {
+                    "JAX_PLATFORMS": "cpu",
+                },
+            },
         ),
     )
     result = trainer.fit()
@@ -108,6 +113,11 @@ def test_minimal_multihost(ray_tpu_multi_host, tmp_path):
         ),
         run_config=RunConfig(
             storage_path=str(tmp_path),
+            worker_runtime_env={
+                "env_vars": {
+                    "JAX_PLATFORMS": "cpu",
+                },
+            },
         ),
     )
     result = trainer.fit()
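
Note (not part of the diff): the sketch below is a minimal, standalone illustration of the platform-resolution precedence introduced in _setup_jax_distributed_environment. The helper name resolve_jax_platforms and the __main__ checks are illustrative assumptions, not Ray APIs. An explicitly set JAX_PLATFORMS value takes precedence over use_tpu, which is presumably why the tests above pin JAX_PLATFORMS to "cpu" through worker_runtime_env so that the TPU branch (and jax.distributed.initialize) is skipped.

import os


def resolve_jax_platforms(use_tpu: bool) -> str:
    # Hypothetical helper mirroring the precedence in
    # _setup_jax_distributed_environment: an explicitly set JAX_PLATFORMS
    # value wins; otherwise use_tpu=True selects "tpu".
    jax_platforms = os.environ.get("JAX_PLATFORMS", "").lower()
    if not jax_platforms and use_tpu:
        os.environ["JAX_PLATFORMS"] = "tpu"
        jax_platforms = "tpu"
    return jax_platforms


if __name__ == "__main__":
    # With JAX_PLATFORMS pinned (as the tests do via worker_runtime_env),
    # use_tpu has no effect and the TPU branch is skipped.
    os.environ["JAX_PLATFORMS"] = "cpu"
    assert resolve_jax_platforms(use_tpu=True) == "cpu"

    # Without a preset value, use_tpu=True selects the TPU platform.
    del os.environ["JAX_PLATFORMS"]
    assert resolve_jax_platforms(use_tpu=True) == "tpu"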