From 0bff06498f47f2953cd5946bc9a8f7ceaea7d1c4 Mon Sep 17 00:00:00 2001
From: Lehui Liu
Date: Mon, 4 Aug 2025 18:02:20 -0700
Subject: [PATCH 01/13] test jax trainer

Signed-off-by: Lehui Liu
---
 .../__pycache__/serialization_context.cpython-39.pyc.4431272656 | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 create mode 100644 python/ray/experimental/channel/__pycache__/serialization_context.cpython-39.pyc.4431272656

diff --git a/python/ray/experimental/channel/__pycache__/serialization_context.cpython-39.pyc.4431272656 b/python/ray/experimental/channel/__pycache__/serialization_context.cpython-39.pyc.4431272656
new file mode 100644
index 000000000000..e69de29bb2d1

From 4be64d6a8f4aefeca8327a50128d2ec74938c472 Mon Sep 17 00:00:00 2001
From: Lehui Liu
Date: Wed, 15 Oct 2025 21:40:51 -0700
Subject: [PATCH 02/13] set JAX_PLATFORMS automatically

Signed-off-by: Lehui Liu
---
 python/ray/train/v2/jax/config.py             |  28 ++++-
 python/ray/train/v2/tests/test_jax_trainer.py | 111 ++++++++++++++++++
 2 files changed, 133 insertions(+), 6 deletions(-)

diff --git a/python/ray/train/v2/jax/config.py b/python/ray/train/v2/jax/config.py
index 9b3f5c7d5cdc..5150f7488a88 100644
--- a/python/ray/train/v2/jax/config.py
+++ b/python/ray/train/v2/jax/config.py
@@ -26,7 +26,23 @@ def backend_cls(self):
         return _JaxBackend
 
 
-def _setup_jax_tpu_environment(
+def _set_jax_env_vars(use_tpu: bool):
+    """Set JAX environment variables based on configuration."""
+    if use_tpu:
+        # Get existing JAX_PLATFORMS if set
+        existing_jax_platforms = os.environ.get("JAX_PLATFORMS", "").lower()
+
+        if "tpu" in existing_jax_platforms:
+            return
+        elif existing_jax_platforms:
+            # Prepend tpu to existing platforms
+            os.environ["JAX_PLATFORMS"] = "tpu," + existing_jax_platforms
+        else:
+            # No existing platforms, just set to "tpu"
+            os.environ["JAX_PLATFORMS"] = "tpu"
+
+
+def _setup_jax_distributed_environment(
     master_addr_with_port: str, num_workers: int, index: int
 ):
     """Set up distributed JAX training.
@@ -35,10 +51,7 @@ def _setup_jax_tpu_environment( """ import jax - jax_platforms = os.environ.get("JAX_PLATFORMS", "").lower() - - if "tpu" in jax_platforms.split(","): - jax.distributed.initialize(master_addr_with_port, num_workers, index) + jax.distributed.initialize(master_addr_with_port, num_workers, index) def _shutdown_jax_distributed(): @@ -60,6 +73,9 @@ def on_start(self, worker_group: WorkerGroup, backend_config: JaxConfig): if not backend_config.use_tpu: return + # Set JAX environment variables on all workers + worker_group.execute(_set_jax_env_vars, use_tpu=backend_config.use_tpu) + master_addr, master_port = worker_group.execute_single(0, get_address_and_port) master_addr_with_port = f"{master_addr}:{master_port}" @@ -69,7 +85,7 @@ def on_start(self, worker_group: WorkerGroup, backend_config: JaxConfig): setup_futures.append( worker_group.execute_single_async( i, - _setup_jax_tpu_environment, + _setup_jax_distributed_environment, master_addr_with_port=master_addr_with_port, num_workers=len(worker_group), index=i, diff --git a/python/ray/train/v2/tests/test_jax_trainer.py b/python/ray/train/v2/tests/test_jax_trainer.py index 4f8db5cf8dd8..223b4357441a 100644 --- a/python/ray/train/v2/tests/test_jax_trainer.py +++ b/python/ray/train/v2/tests/test_jax_trainer.py @@ -71,6 +71,16 @@ def train_func(): train.report({"result": [str(d) for d in devices]}) +def train_func_check_env(): + """Training function to verify JAX_PLATFORMS env var is set.""" + import os + + from ray import train + + jax_platforms = os.environ.get("JAX_PLATFORMS", "") + train.report({"jax_platforms": jax_platforms}) + + def test_minimal_singlehost(ray_tpu_single_host, tmp_path): trainer = JaxTrainer( train_loop_per_worker=train_func, @@ -124,6 +134,107 @@ def test_minimal_multihost(ray_tpu_multi_host, tmp_path): assert len(labeled_nodes) == 2 +def test_jax_platforms_env_var_no_existing(ray_tpu_single_host, tmp_path, monkeypatch): + """Test that JAX_PLATFORMS is set to 'tpu' when use_tpu=True and no existing value.""" + # Ensure JAX_PLATFORMS is not set + monkeypatch.delenv("JAX_PLATFORMS", raising=False) + + trainer = JaxTrainer( + train_loop_per_worker=train_func_check_env, + scaling_config=ScalingConfig( + num_workers=1, + resources_per_worker={"TPU": 8}, + use_tpu=True, + accelerator_type="TPU-V6E", + ), + run_config=RunConfig( + storage_path=str(tmp_path), + ), + ) + result = trainer.fit() + assert result.error is None + + # Check that JAX_PLATFORMS was set to "tpu" + jax_platforms = result.metrics["jax_platforms"] + assert jax_platforms == "tpu" + + +def test_jax_platforms_env_var_with_existing_tpu( + ray_tpu_single_host, tmp_path, monkeypatch +): + """Test that JAX_PLATFORMS is not modified when 'tpu' is already present.""" + monkeypatch.setenv("JAX_PLATFORMS", "tpu,cpu") + + trainer = JaxTrainer( + train_loop_per_worker=train_func_check_env, + scaling_config=ScalingConfig( + num_workers=1, + resources_per_worker={"TPU": 8}, + use_tpu=True, + accelerator_type="TPU-V6E", + ), + run_config=RunConfig( + storage_path=str(tmp_path), + ), + ) + result = trainer.fit() + assert result.error is None + + # Check that JAX_PLATFORMS was not modified + jax_platforms = result.metrics["jax_platforms"] + assert jax_platforms == "tpu,cpu" + + +def test_jax_platforms_env_var_with_existing_other( + ray_tpu_single_host, tmp_path, monkeypatch +): + """Test that 'tpu' is prepended when JAX_PLATFORMS contains other platforms.""" + monkeypatch.setenv("JAX_PLATFORMS", "cpu") + + trainer = JaxTrainer( + 
train_loop_per_worker=train_func_check_env, + scaling_config=ScalingConfig( + num_workers=1, + resources_per_worker={"TPU": 8}, + use_tpu=True, + accelerator_type="TPU-V6E", + ), + run_config=RunConfig( + storage_path=str(tmp_path), + ), + ) + result = trainer.fit() + assert result.error is None + + # Check that 'tpu' was prepended to existing platforms + jax_platforms = result.metrics["jax_platforms"] + assert jax_platforms == "tpu,cpu" + + +def test_jax_platforms_env_var_no_tpu_flag(ray_tpu_single_host, tmp_path, monkeypatch): + """Test that JAX_PLATFORMS is not set when use_tpu=False.""" + monkeypatch.delenv("JAX_PLATFORMS", raising=False) + + trainer = JaxTrainer( + train_loop_per_worker=train_func_check_env, + scaling_config=ScalingConfig( + num_workers=1, + resources_per_worker={"TPU": 8}, + use_tpu=False, # Not using TPU + accelerator_type="TPU-V6E", + ), + run_config=RunConfig( + storage_path=str(tmp_path), + ), + ) + result = trainer.fit() + assert result.error is None + + # Check that JAX_PLATFORMS was not set + jax_platforms = result.metrics["jax_platforms"] + assert jax_platforms == "" + + if __name__ == "__main__": import sys From 6f94496feac6853d77373cb10f7027e56b33353c Mon Sep 17 00:00:00 2001 From: Lehui Liu Date: Wed, 29 Oct 2025 10:15:56 -0700 Subject: [PATCH 03/13] add unit test Signed-off-by: Lehui Liu --- python/ray/train/v2/tests/test_jax_trainer.py | 90 +++++++------------ 1 file changed, 34 insertions(+), 56 deletions(-) diff --git a/python/ray/train/v2/tests/test_jax_trainer.py b/python/ray/train/v2/tests/test_jax_trainer.py index 223b4357441a..6e4b3e3d64ae 100644 --- a/python/ray/train/v2/tests/test_jax_trainer.py +++ b/python/ray/train/v2/tests/test_jax_trainer.py @@ -71,16 +71,6 @@ def train_func(): train.report({"result": [str(d) for d in devices]}) -def train_func_check_env(): - """Training function to verify JAX_PLATFORMS env var is set.""" - import os - - from ray import train - - jax_platforms = os.environ.get("JAX_PLATFORMS", "") - train.report({"jax_platforms": jax_platforms}) - - def test_minimal_singlehost(ray_tpu_single_host, tmp_path): trainer = JaxTrainer( train_loop_per_worker=train_func, @@ -139,6 +129,13 @@ def test_jax_platforms_env_var_no_existing(ray_tpu_single_host, tmp_path, monkey # Ensure JAX_PLATFORMS is not set monkeypatch.delenv("JAX_PLATFORMS", raising=False) + def train_func_check_env(): + """Training function to verify JAX_PLATFORMS env var is set.""" + import os + + jax_platforms = os.environ.get("JAX_PLATFORMS", "") + assert jax_platforms == "tpu" + trainer = JaxTrainer( train_loop_per_worker=train_func_check_env, scaling_config=ScalingConfig( @@ -151,19 +148,18 @@ def test_jax_platforms_env_var_no_existing(ray_tpu_single_host, tmp_path, monkey storage_path=str(tmp_path), ), ) - result = trainer.fit() - assert result.error is None - - # Check that JAX_PLATFORMS was set to "tpu" - jax_platforms = result.metrics["jax_platforms"] - assert jax_platforms == "tpu" + trainer.fit() -def test_jax_platforms_env_var_with_existing_tpu( - ray_tpu_single_host, tmp_path, monkeypatch -): +def test_jax_platforms_env_var_with_existing_tpu(ray_tpu_single_host, tmp_path): """Test that JAX_PLATFORMS is not modified when 'tpu' is already present.""" - monkeypatch.setenv("JAX_PLATFORMS", "tpu,cpu") + + def train_func_check_env(): + """Training function to verify JAX_PLATFORMS env var is set.""" + import os + + jax_platforms = os.environ.get("JAX_PLATFORMS", "") + assert jax_platforms == "tpu,cpu" trainer = JaxTrainer( 
train_loop_per_worker=train_func_check_env, @@ -175,64 +171,46 @@ def test_jax_platforms_env_var_with_existing_tpu( ), run_config=RunConfig( storage_path=str(tmp_path), + worker_runtime_env={ + "env_vars": { + "JAX_PLATFORMS": "tpu,cpu", + }, + }, ), ) - result = trainer.fit() - assert result.error is None - - # Check that JAX_PLATFORMS was not modified - jax_platforms = result.metrics["jax_platforms"] - assert jax_platforms == "tpu,cpu" + trainer.fit() def test_jax_platforms_env_var_with_existing_other( ray_tpu_single_host, tmp_path, monkeypatch ): """Test that 'tpu' is prepended when JAX_PLATFORMS contains other platforms.""" - monkeypatch.setenv("JAX_PLATFORMS", "cpu") - - trainer = JaxTrainer( - train_loop_per_worker=train_func_check_env, - scaling_config=ScalingConfig( - num_workers=1, - resources_per_worker={"TPU": 8}, - use_tpu=True, - accelerator_type="TPU-V6E", - ), - run_config=RunConfig( - storage_path=str(tmp_path), - ), - ) - result = trainer.fit() - assert result.error is None - # Check that 'tpu' was prepended to existing platforms - jax_platforms = result.metrics["jax_platforms"] - assert jax_platforms == "tpu,cpu" + def train_func_check_env(): + """Training function to verify JAX_PLATFORMS env var is set.""" + import os - -def test_jax_platforms_env_var_no_tpu_flag(ray_tpu_single_host, tmp_path, monkeypatch): - """Test that JAX_PLATFORMS is not set when use_tpu=False.""" - monkeypatch.delenv("JAX_PLATFORMS", raising=False) + jax_platforms = os.environ.get("JAX_PLATFORMS", "") + assert jax_platforms == "tpu,cpu" trainer = JaxTrainer( train_loop_per_worker=train_func_check_env, scaling_config=ScalingConfig( num_workers=1, resources_per_worker={"TPU": 8}, - use_tpu=False, # Not using TPU + use_tpu=True, accelerator_type="TPU-V6E", ), run_config=RunConfig( storage_path=str(tmp_path), + worker_runtime_env={ + "env_vars": { + "JAX_PLATFORMS": "cpu", + }, + }, ), ) - result = trainer.fit() - assert result.error is None - - # Check that JAX_PLATFORMS was not set - jax_platforms = result.metrics["jax_platforms"] - assert jax_platforms == "" + trainer.fit() if __name__ == "__main__": From a58a54c706bf9647940f4035102cbd0cc64ea004 Mon Sep 17 00:00:00 2001 From: Lehui Liu Date: Wed, 15 Oct 2025 21:40:51 -0700 Subject: [PATCH 04/13] set JAX_PLATFORMS automatically Signed-off-by: Lehui Liu --- python/ray/train/v2/tests/test_jax_trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/ray/train/v2/tests/test_jax_trainer.py b/python/ray/train/v2/tests/test_jax_trainer.py index 6e4b3e3d64ae..01239e70f502 100644 --- a/python/ray/train/v2/tests/test_jax_trainer.py +++ b/python/ray/train/v2/tests/test_jax_trainer.py @@ -161,6 +161,7 @@ def train_func_check_env(): jax_platforms = os.environ.get("JAX_PLATFORMS", "") assert jax_platforms == "tpu,cpu" + trainer = JaxTrainer( train_loop_per_worker=train_func_check_env, scaling_config=ScalingConfig( From 9f6e6f7acaaf9756be0de01504a7a8cecca8ef15 Mon Sep 17 00:00:00 2001 From: Lehui Liu Date: Wed, 29 Oct 2025 10:56:30 -0700 Subject: [PATCH 05/13] remove comment Signed-off-by: Lehui Liu --- python/ray/train/v2/jax/config.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/ray/train/v2/jax/config.py b/python/ray/train/v2/jax/config.py index 5150f7488a88..3fa54c295e6f 100644 --- a/python/ray/train/v2/jax/config.py +++ b/python/ray/train/v2/jax/config.py @@ -35,10 +35,8 @@ def _set_jax_env_vars(use_tpu: bool): if "tpu" in existing_jax_platforms: return elif existing_jax_platforms: - # Prepend tpu to existing platforms 
os.environ["JAX_PLATFORMS"] = "tpu," + existing_jax_platforms else: - # No existing platforms, just set to "tpu" os.environ["JAX_PLATFORMS"] = "tpu" From 7e466754b34153a48a2e5c07e4638c6dc14ecd66 Mon Sep 17 00:00:00 2001 From: Lehui Liu Date: Wed, 29 Oct 2025 11:25:40 -0700 Subject: [PATCH 06/13] lint Signed-off-by: Lehui Liu --- python/ray/train/v2/tests/test_jax_trainer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/ray/train/v2/tests/test_jax_trainer.py b/python/ray/train/v2/tests/test_jax_trainer.py index 01239e70f502..6e4b3e3d64ae 100644 --- a/python/ray/train/v2/tests/test_jax_trainer.py +++ b/python/ray/train/v2/tests/test_jax_trainer.py @@ -161,7 +161,6 @@ def train_func_check_env(): jax_platforms = os.environ.get("JAX_PLATFORMS", "") assert jax_platforms == "tpu,cpu" - trainer = JaxTrainer( train_loop_per_worker=train_func_check_env, scaling_config=ScalingConfig( From 383cc369b7ba93d484c1a7453c898b76a5926520 Mon Sep 17 00:00:00 2001 From: Lehui Liu Date: Wed, 29 Oct 2025 11:34:47 -0700 Subject: [PATCH 07/13] try to mock jax distributed Signed-off-by: Lehui Liu --- python/ray/train/v2/tests/test_jax_trainer.py | 78 +++++++++++-------- 1 file changed, 44 insertions(+), 34 deletions(-) diff --git a/python/ray/train/v2/tests/test_jax_trainer.py b/python/ray/train/v2/tests/test_jax_trainer.py index 6e4b3e3d64ae..73f47fc243fb 100644 --- a/python/ray/train/v2/tests/test_jax_trainer.py +++ b/python/ray/train/v2/tests/test_jax_trainer.py @@ -1,3 +1,5 @@ +from unittest import mock + import pytest import ray @@ -161,24 +163,28 @@ def train_func_check_env(): jax_platforms = os.environ.get("JAX_PLATFORMS", "") assert jax_platforms == "tpu,cpu" - trainer = JaxTrainer( - train_loop_per_worker=train_func_check_env, - scaling_config=ScalingConfig( - num_workers=1, - resources_per_worker={"TPU": 8}, - use_tpu=True, - accelerator_type="TPU-V6E", - ), - run_config=RunConfig( - storage_path=str(tmp_path), - worker_runtime_env={ - "env_vars": { - "JAX_PLATFORMS": "tpu,cpu", + # Mock JAX distributed setup to avoid TPU initialization + with mock.patch( + "ray.train.v2.jax.config._setup_jax_distributed_environment" + ), mock.patch("ray.train.v2.jax.config._shutdown_jax_distributed"): + trainer = JaxTrainer( + train_loop_per_worker=train_func_check_env, + scaling_config=ScalingConfig( + num_workers=1, + resources_per_worker={"TPU": 8}, + use_tpu=True, + accelerator_type="TPU-V6E", + ), + run_config=RunConfig( + storage_path=str(tmp_path), + worker_runtime_env={ + "env_vars": { + "JAX_PLATFORMS": "tpu,cpu", + }, }, - }, - ), - ) - trainer.fit() + ), + ) + trainer.fit() def test_jax_platforms_env_var_with_existing_other( @@ -193,24 +199,28 @@ def train_func_check_env(): jax_platforms = os.environ.get("JAX_PLATFORMS", "") assert jax_platforms == "tpu,cpu" - trainer = JaxTrainer( - train_loop_per_worker=train_func_check_env, - scaling_config=ScalingConfig( - num_workers=1, - resources_per_worker={"TPU": 8}, - use_tpu=True, - accelerator_type="TPU-V6E", - ), - run_config=RunConfig( - storage_path=str(tmp_path), - worker_runtime_env={ - "env_vars": { - "JAX_PLATFORMS": "cpu", + # Mock JAX distributed setup to avoid TPU initialization + with mock.patch( + "ray.train.v2.jax.config._setup_jax_distributed_environment" + ), mock.patch("ray.train.v2.jax.config._shutdown_jax_distributed"): + trainer = JaxTrainer( + train_loop_per_worker=train_func_check_env, + scaling_config=ScalingConfig( + num_workers=1, + resources_per_worker={"TPU": 8}, + use_tpu=True, + accelerator_type="TPU-V6E", + ), + 
run_config=RunConfig( + storage_path=str(tmp_path), + worker_runtime_env={ + "env_vars": { + "JAX_PLATFORMS": "cpu", + }, }, - }, - ), - ) - trainer.fit() + ), + ) + trainer.fit() if __name__ == "__main__": From e76eb82aa2b58ba131a9a81f54b33335789b13fc Mon Sep 17 00:00:00 2001 From: Lehui Liu Date: Wed, 29 Oct 2025 12:14:42 -0700 Subject: [PATCH 08/13] mock jax distributed Signed-off-by: Lehui Liu --- python/ray/train/v2/tests/test_jax_trainer.py | 99 +++++++++++-------- 1 file changed, 57 insertions(+), 42 deletions(-) diff --git a/python/ray/train/v2/tests/test_jax_trainer.py b/python/ray/train/v2/tests/test_jax_trainer.py index 73f47fc243fb..a92282495442 100644 --- a/python/ray/train/v2/tests/test_jax_trainer.py +++ b/python/ray/train/v2/tests/test_jax_trainer.py @@ -63,6 +63,29 @@ def reduce_health_check_interval(monkeypatch): yield +@pytest.fixture(autouse=True) +def mock_jax_distributed(monkeypatch): + """Mock JAX distributed setup/shutdown to avoid TPU initialization in tests. + + This prevents the RuntimeError: Unable to initialize backend 'tpu' error + in test environments without actual TPU hardware. + """ + # Mock the setup and shutdown functions to prevent JAX TPU initialization + mock_setup = mock.MagicMock() + mock_shutdown = mock.MagicMock() + + monkeypatch.setattr( + "ray.train.v2.jax.config._setup_jax_distributed_environment", + mock_setup, + ) + monkeypatch.setattr( + "ray.train.v2.jax.config._shutdown_jax_distributed", + mock_shutdown, + ) + + yield + + def train_func(): import jax @@ -163,28 +186,24 @@ def train_func_check_env(): jax_platforms = os.environ.get("JAX_PLATFORMS", "") assert jax_platforms == "tpu,cpu" - # Mock JAX distributed setup to avoid TPU initialization - with mock.patch( - "ray.train.v2.jax.config._setup_jax_distributed_environment" - ), mock.patch("ray.train.v2.jax.config._shutdown_jax_distributed"): - trainer = JaxTrainer( - train_loop_per_worker=train_func_check_env, - scaling_config=ScalingConfig( - num_workers=1, - resources_per_worker={"TPU": 8}, - use_tpu=True, - accelerator_type="TPU-V6E", - ), - run_config=RunConfig( - storage_path=str(tmp_path), - worker_runtime_env={ - "env_vars": { - "JAX_PLATFORMS": "tpu,cpu", - }, + trainer = JaxTrainer( + train_loop_per_worker=train_func_check_env, + scaling_config=ScalingConfig( + num_workers=1, + resources_per_worker={"TPU": 8}, + use_tpu=True, + accelerator_type="TPU-V6E", + ), + run_config=RunConfig( + storage_path=str(tmp_path), + worker_runtime_env={ + "env_vars": { + "JAX_PLATFORMS": "tpu,cpu", }, - ), - ) - trainer.fit() + }, + ), + ) + trainer.fit() def test_jax_platforms_env_var_with_existing_other( @@ -199,28 +218,24 @@ def train_func_check_env(): jax_platforms = os.environ.get("JAX_PLATFORMS", "") assert jax_platforms == "tpu,cpu" - # Mock JAX distributed setup to avoid TPU initialization - with mock.patch( - "ray.train.v2.jax.config._setup_jax_distributed_environment" - ), mock.patch("ray.train.v2.jax.config._shutdown_jax_distributed"): - trainer = JaxTrainer( - train_loop_per_worker=train_func_check_env, - scaling_config=ScalingConfig( - num_workers=1, - resources_per_worker={"TPU": 8}, - use_tpu=True, - accelerator_type="TPU-V6E", - ), - run_config=RunConfig( - storage_path=str(tmp_path), - worker_runtime_env={ - "env_vars": { - "JAX_PLATFORMS": "cpu", - }, + trainer = JaxTrainer( + train_loop_per_worker=train_func_check_env, + scaling_config=ScalingConfig( + num_workers=1, + resources_per_worker={"TPU": 8}, + use_tpu=True, + accelerator_type="TPU-V6E", + ), + run_config=RunConfig( + 
storage_path=str(tmp_path), + worker_runtime_env={ + "env_vars": { + "JAX_PLATFORMS": "cpu", }, - ), - ) - trainer.fit() + }, + ), + ) + trainer.fit() if __name__ == "__main__": From 838662110523c7e85f677289b8dd70222ff92325 Mon Sep 17 00:00:00 2001 From: Lehui Liu Date: Wed, 29 Oct 2025 13:49:45 -0700 Subject: [PATCH 09/13] reset to cpu for vanilla tests Signed-off-by: Lehui Liu --- python/ray/train/v2/tests/test_jax_trainer.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/ray/train/v2/tests/test_jax_trainer.py b/python/ray/train/v2/tests/test_jax_trainer.py index a92282495442..e96c2f265ddc 100644 --- a/python/ray/train/v2/tests/test_jax_trainer.py +++ b/python/ray/train/v2/tests/test_jax_trainer.py @@ -87,6 +87,11 @@ def mock_jax_distributed(monkeypatch): def train_func(): + import os + + # Set JAX_PLATFORMS to cpu for testing (avoid TPU initialization without hardware) + os.environ["JAX_PLATFORMS"] = "cpu" + import jax from ray import train From 1902da427b9f259450e33e5ed43a3982f1842297 Mon Sep 17 00:00:00 2001 From: Lehui Liu Date: Wed, 29 Oct 2025 15:11:04 -0700 Subject: [PATCH 10/13] modify Signed-off-by: Lehui Liu --- python/ray/train/v2/jax/config.py | 26 ++-- python/ray/train/v2/tests/test_jax_trainer.py | 129 ++---------------- 2 files changed, 25 insertions(+), 130 deletions(-) diff --git a/python/ray/train/v2/jax/config.py b/python/ray/train/v2/jax/config.py index 3fa54c295e6f..5cfead3cdb8f 100644 --- a/python/ray/train/v2/jax/config.py +++ b/python/ray/train/v2/jax/config.py @@ -27,17 +27,18 @@ def backend_cls(self): def _set_jax_env_vars(use_tpu: bool): - """Set JAX environment variables based on configuration.""" - if use_tpu: - # Get existing JAX_PLATFORMS if set - existing_jax_platforms = os.environ.get("JAX_PLATFORMS", "").lower() + """Set JAX environment variables based on configuration. - if "tpu" in existing_jax_platforms: - return - elif existing_jax_platforms: - os.environ["JAX_PLATFORMS"] = "tpu," + existing_jax_platforms - else: - os.environ["JAX_PLATFORMS"] = "tpu" + If JAX_PLATFORMS is already set (by user or test), we trust that configuration + and do nothing. Otherwise, if use_tpu=True, we set it to "tpu". + """ + # If user/test already set JAX_PLATFORMS, respect their choice + if os.environ.get("JAX_PLATFORMS"): + return + + # Only set JAX_PLATFORMS if not already specified + if use_tpu: + os.environ["JAX_PLATFORMS"] = "tpu" def _setup_jax_distributed_environment( @@ -49,7 +50,10 @@ def _setup_jax_distributed_environment( """ import jax - jax.distributed.initialize(master_addr_with_port, num_workers, index) + jax_platforms = os.environ.get("JAX_PLATFORMS", "").lower() + + if "tpu" in jax_platforms.split(","): + jax.distributed.initialize(master_addr_with_port, num_workers, index) def _shutdown_jax_distributed(): diff --git a/python/ray/train/v2/tests/test_jax_trainer.py b/python/ray/train/v2/tests/test_jax_trainer.py index e96c2f265ddc..976f3d507819 100644 --- a/python/ray/train/v2/tests/test_jax_trainer.py +++ b/python/ray/train/v2/tests/test_jax_trainer.py @@ -1,5 +1,3 @@ -from unittest import mock - import pytest import ray @@ -63,35 +61,7 @@ def reduce_health_check_interval(monkeypatch): yield -@pytest.fixture(autouse=True) -def mock_jax_distributed(monkeypatch): - """Mock JAX distributed setup/shutdown to avoid TPU initialization in tests. - - This prevents the RuntimeError: Unable to initialize backend 'tpu' error - in test environments without actual TPU hardware. 
- """ - # Mock the setup and shutdown functions to prevent JAX TPU initialization - mock_setup = mock.MagicMock() - mock_shutdown = mock.MagicMock() - - monkeypatch.setattr( - "ray.train.v2.jax.config._setup_jax_distributed_environment", - mock_setup, - ) - monkeypatch.setattr( - "ray.train.v2.jax.config._shutdown_jax_distributed", - mock_shutdown, - ) - - yield - - def train_func(): - import os - - # Set JAX_PLATFORMS to cpu for testing (avoid TPU initialization without hardware) - os.environ["JAX_PLATFORMS"] = "cpu" - import jax from ray import train @@ -113,6 +83,11 @@ def test_minimal_singlehost(ray_tpu_single_host, tmp_path): ), run_config=RunConfig( storage_path=str(tmp_path), + worker_runtime_env={ + "env_vars": { + "JAX_PLATFORMS": "cpu", + }, + }, ), ) result = trainer.fit() @@ -138,6 +113,11 @@ def test_minimal_multihost(ray_tpu_multi_host, tmp_path): ), run_config=RunConfig( storage_path=str(tmp_path), + worker_runtime_env={ + "env_vars": { + "JAX_PLATFORMS": "cpu", + }, + }, ), ) result = trainer.fit() @@ -154,95 +134,6 @@ def test_minimal_multihost(ray_tpu_multi_host, tmp_path): assert len(labeled_nodes) == 2 -def test_jax_platforms_env_var_no_existing(ray_tpu_single_host, tmp_path, monkeypatch): - """Test that JAX_PLATFORMS is set to 'tpu' when use_tpu=True and no existing value.""" - # Ensure JAX_PLATFORMS is not set - monkeypatch.delenv("JAX_PLATFORMS", raising=False) - - def train_func_check_env(): - """Training function to verify JAX_PLATFORMS env var is set.""" - import os - - jax_platforms = os.environ.get("JAX_PLATFORMS", "") - assert jax_platforms == "tpu" - - trainer = JaxTrainer( - train_loop_per_worker=train_func_check_env, - scaling_config=ScalingConfig( - num_workers=1, - resources_per_worker={"TPU": 8}, - use_tpu=True, - accelerator_type="TPU-V6E", - ), - run_config=RunConfig( - storage_path=str(tmp_path), - ), - ) - trainer.fit() - - -def test_jax_platforms_env_var_with_existing_tpu(ray_tpu_single_host, tmp_path): - """Test that JAX_PLATFORMS is not modified when 'tpu' is already present.""" - - def train_func_check_env(): - """Training function to verify JAX_PLATFORMS env var is set.""" - import os - - jax_platforms = os.environ.get("JAX_PLATFORMS", "") - assert jax_platforms == "tpu,cpu" - - trainer = JaxTrainer( - train_loop_per_worker=train_func_check_env, - scaling_config=ScalingConfig( - num_workers=1, - resources_per_worker={"TPU": 8}, - use_tpu=True, - accelerator_type="TPU-V6E", - ), - run_config=RunConfig( - storage_path=str(tmp_path), - worker_runtime_env={ - "env_vars": { - "JAX_PLATFORMS": "tpu,cpu", - }, - }, - ), - ) - trainer.fit() - - -def test_jax_platforms_env_var_with_existing_other( - ray_tpu_single_host, tmp_path, monkeypatch -): - """Test that 'tpu' is prepended when JAX_PLATFORMS contains other platforms.""" - - def train_func_check_env(): - """Training function to verify JAX_PLATFORMS env var is set.""" - import os - - jax_platforms = os.environ.get("JAX_PLATFORMS", "") - assert jax_platforms == "tpu,cpu" - - trainer = JaxTrainer( - train_loop_per_worker=train_func_check_env, - scaling_config=ScalingConfig( - num_workers=1, - resources_per_worker={"TPU": 8}, - use_tpu=True, - accelerator_type="TPU-V6E", - ), - run_config=RunConfig( - storage_path=str(tmp_path), - worker_runtime_env={ - "env_vars": { - "JAX_PLATFORMS": "cpu", - }, - }, - ), - ) - trainer.fit() - - if __name__ == "__main__": import sys From 43ce7f769d361f9d42d0fad18af9e6b04b3756c2 Mon Sep 17 00:00:00 2001 From: Lehui Liu Date: Wed, 29 Oct 2025 15:11:42 -0700 Subject: 
[PATCH 11/13] remove

Signed-off-by: Lehui Liu
---
 .../__pycache__/serialization_context.cpython-39.pyc.4431272656 | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 delete mode 100644 python/ray/experimental/channel/__pycache__/serialization_context.cpython-39.pyc.4431272656

diff --git a/python/ray/experimental/channel/__pycache__/serialization_context.cpython-39.pyc.4431272656 b/python/ray/experimental/channel/__pycache__/serialization_context.cpython-39.pyc.4431272656
deleted file mode 100644
index e69de29bb2d1..000000000000

From 0ac0d1eaee8b1e14c050d56c8119b9c1bf91793b Mon Sep 17 00:00:00 2001
From: Lehui Liu
Date: Wed, 12 Nov 2025 10:54:42 -0800
Subject: [PATCH 12/13] aggregate into one async call

Signed-off-by: Lehui Liu
---
 python/ray/train/v2/jax/config.py | 47 ++++++++++++++++---------------
 1 file changed, 24 insertions(+), 23 deletions(-)

diff --git a/python/ray/train/v2/jax/config.py b/python/ray/train/v2/jax/config.py
index 5cfead3cdb8f..db150fb0bb6c 100644
--- a/python/ray/train/v2/jax/config.py
+++ b/python/ray/train/v2/jax/config.py
@@ -26,31 +26,33 @@ def backend_cls(self):
         return _JaxBackend
 
 
-def _set_jax_env_vars(use_tpu: bool):
-    """Set JAX environment variables based on configuration.
-
-    If JAX_PLATFORMS is already set (by user or test), we trust that configuration
-    and do nothing. Otherwise, if use_tpu=True, we set it to "tpu".
-    """
-    # If user/test already set JAX_PLATFORMS, respect their choice
-    if os.environ.get("JAX_PLATFORMS"):
-        return
-
-    # Only set JAX_PLATFORMS if not already specified
-    if use_tpu:
-        os.environ["JAX_PLATFORMS"] = "tpu"
-
-
 def _setup_jax_distributed_environment(
-    master_addr_with_port: str, num_workers: int, index: int
+    master_addr_with_port: str, num_workers: int, index: int, use_tpu: bool
 ):
     """Set up distributed JAX training.
 
-    This function should be called on each worker.
+    This function should be called on each worker. It sets JAX environment
+    variables and initializes JAX distributed training.
+
+    Args:
+        master_addr_with_port: The master address with port for coordination.
+        num_workers: Total number of workers.
+        index: Index of this worker.
+        use_tpu: Whether to configure for TPU. If True and JAX_PLATFORMS is not
+            already set, it will be set to "tpu".
     """
-    import jax
+    # If user/test already set JAX_PLATFORMS, respect their choice
+    if os.environ.get("JAX_PLATFORMS"):
+        jax_platforms = os.environ.get("JAX_PLATFORMS", "").lower()
+    else:
+        # Set JAX_PLATFORMS based on configuration
+        if use_tpu:
+            os.environ["JAX_PLATFORMS"] = "tpu"
+            jax_platforms = "tpu"
+        else:
+            jax_platforms = ""
 
-    jax_platforms = os.environ.get("JAX_PLATFORMS", "").lower()
+    import jax
 
     if "tpu" in jax_platforms.split(","):
         jax.distributed.initialize(master_addr_with_port, num_workers, index)
@@ -75,13 +77,11 @@ def on_start(self, worker_group: WorkerGroup, backend_config: JaxConfig):
         if not backend_config.use_tpu:
             return
 
-        # Set JAX environment variables on all workers
-        worker_group.execute(_set_jax_env_vars, use_tpu=backend_config.use_tpu)
-
         master_addr, master_port = worker_group.execute_single(0, get_address_and_port)
         master_addr_with_port = f"{master_addr}:{master_port}"
 
-        # Get setup tasks in order to throw errors on failure.
+        # Set up JAX distributed environment on all workers
+        # This sets the JAX_PLATFORMS env var and initializes JAX distributed training
         setup_futures = []
         for i in range(len(worker_group)):
             setup_futures.append(
@@ -91,6 +91,7 @@ def on_start(self, worker_group: WorkerGroup, backend_config: JaxConfig):
                     master_addr_with_port=master_addr_with_port,
                     num_workers=len(worker_group),
                     index=i,
+                    use_tpu=backend_config.use_tpu,
                 )
             )
         ray.get(setup_futures)

From 26da0623d2442deac06f423d096b4e32317985fb Mon Sep 17 00:00:00 2001
From: Lehui Liu
Date: Wed, 12 Nov 2025 11:18:24 -0800
Subject: [PATCH 13/13] resolve comment

Signed-off-by: Lehui Liu
---
 python/ray/train/v2/jax/config.py | 18 ++++++++----------
 1 file changed, 8 insertions(+), 10 deletions(-)

diff --git a/python/ray/train/v2/jax/config.py b/python/ray/train/v2/jax/config.py
index db150fb0bb6c..e826909c2c6f 100644
--- a/python/ray/train/v2/jax/config.py
+++ b/python/ray/train/v2/jax/config.py
@@ -41,16 +41,14 @@ def _setup_jax_distributed_environment(
         use_tpu: Whether to configure for TPU. If True and JAX_PLATFORMS is not
             already set, it will be set to "tpu".
     """
-    # If user/test already set JAX_PLATFORMS, respect their choice
-    if os.environ.get("JAX_PLATFORMS"):
-        jax_platforms = os.environ.get("JAX_PLATFORMS", "").lower()
-    else:
-        # Set JAX_PLATFORMS based on configuration
-        if use_tpu:
-            os.environ["JAX_PLATFORMS"] = "tpu"
-            jax_platforms = "tpu"
-        else:
-            jax_platforms = ""
+    # Get JAX_PLATFORMS from environment if already set
+    jax_platforms = os.environ.get("JAX_PLATFORMS", "").lower()
+
+    if not jax_platforms and use_tpu:
+        os.environ["JAX_PLATFORMS"] = "tpu"
+        jax_platforms = "tpu"
+
+    # TODO(lehui): Add env vars for JAX on GPU.
 
     import jax