diff --git a/python/ray/train/v2/tests/test_jax_gpu.py b/python/ray/train/v2/tests/test_jax_gpu.py
index dea20e40bcc9..0e46a28a27f4 100644
--- a/python/ray/train/v2/tests/test_jax_gpu.py
+++ b/python/ray/train/v2/tests/test_jax_gpu.py
@@ -19,6 +19,10 @@ def reduce_health_check_interval(monkeypatch):
 
 
 @pytest.mark.skipif(sys.platform == "darwin", reason="JAX GPU not supported on macOS")
+@pytest.mark.skipif(
+    sys.version_info >= (3, 12),
+    reason="Current jax version is not supported in python 3.12+",
+)
 def test_jax_distributed_gpu_training(ray_start_4_cpus_2_gpus, tmp_path):
     """Test multi-GPU JAX distributed training.
 
diff --git a/python/ray/train/v2/tests/test_jax_trainer.py b/python/ray/train/v2/tests/test_jax_trainer.py
index 976f3d507819..c6398b727e93 100644
--- a/python/ray/train/v2/tests/test_jax_trainer.py
+++ b/python/ray/train/v2/tests/test_jax_trainer.py
@@ -1,3 +1,5 @@
+import sys
+
 import pytest
 
 import ray
@@ -71,6 +73,10 @@ def train_func():
     train.report({"result": [str(d) for d in devices]})
 
 
+@pytest.mark.skipif(
+    sys.version_info >= (3, 12),
+    reason="Current jax version is not supported in python 3.12+",
+)
 def test_minimal_singlehost(ray_tpu_single_host, tmp_path):
     trainer = JaxTrainer(
         train_loop_per_worker=train_func,
@@ -101,6 +107,10 @@ def test_minimal_singlehost(ray_tpu_single_host, tmp_path):
     assert len(labeled_nodes) == 1
 
 
+@pytest.mark.skipif(
+    sys.version_info >= (3, 12),
+    reason="Current jax version is not supported in python 3.12+",
+)
 def test_minimal_multihost(ray_tpu_multi_host, tmp_path):
     trainer = JaxTrainer(
         train_loop_per_worker=train_func,
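
For context, the pattern applied in each hunk above is pytest's `skipif` marker gating on the interpreter version: `sys.version_info` is a tuple that compares element-wise against a version tuple like `(3, 12)`. A minimal standalone sketch of the same idiom (the test name here is illustrative, not part of this PR):

```python
import sys

import pytest


# Skip on Python 3.12+ until the pinned jax release supports it;
# (3, 12) compares element-wise against sys.version_info, so this
# matches 3.12.0 and every later version.
@pytest.mark.skipif(
    sys.version_info >= (3, 12),
    reason="Current jax version is not supported in python 3.12+",
)
def test_runs_only_below_py312():
    assert sys.version_info < (3, 12)
```

Skipped tests are reported by pytest with the given `reason`, so the suite stays green on Python 3.12+ runners without silently hiding the coverage gap.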