diff --git a/python/ray/_private/accelerators/tpu.py b/python/ray/_private/accelerators/tpu.py
index c6df2c858779..83da22475879 100644
--- a/python/ray/_private/accelerators/tpu.py
+++ b/python/ray/_private/accelerators/tpu.py
@@ -9,6 +9,7 @@
 import ray
 from ray._private.accelerators.accelerator import AcceleratorManager
+from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
 
 logger = logging.getLogger(__name__)
 
@@ -110,6 +111,91 @@ def get_tpu_cores_per_chip(accelerator_type: str) -> int:
     return DEFAULT_TPU_NUM_CORES_PER_CHIP
 
 
+def infer_tpu_pod_type_from_topology(
+    topology: str, accelerator_type: str
+) -> Optional[str]:
+    """Infer the TPU pod type (e.g. v4-32) from topology and accelerator type."""
+    try:
+        num_chips = 1
+        for value in topology.strip().lower().split("x"):
+            num_chips *= int(value)
+        generation = accelerator_type.lower().replace("tpu-", "")
+        return f"{generation}-{num_chips}"
+    except Exception as e:
+        logger.warning(
+            f"Failed to infer pod type from topology {topology} and type {accelerator_type}: {e}"
+        )
+        return None
+
+
+def fetch_tpu_slice_name_from_pg(pg):
+    """Fetch the TPU slice name by running a probe task on the placement group's head bundle."""
+
+    @ray.remote(num_cpus=0)
+    def _get_tpu_slice_name():
+        return TPUAcceleratorManager.get_current_node_tpu_name()
+
+    tpu_name_ref = _get_tpu_slice_name.options(
+        scheduling_strategy=PlacementGroupSchedulingStrategy(
+            placement_group=pg, placement_group_bundle_index=0
+        )
+    ).remote()
+
+    return ray.get(tpu_name_ref)
+
+
+def reserve_tpu_slice(
+    topology: str,
+    accelerator_type: str,
+) -> Optional[str]:
+    """Reserves a TPU slice using its head resource and returns the slice name.
+
+    This enables gang scheduling of training workers with multi-host TPUs.
+    This is used by JaxTrainer with TPUs in Ray Train.
+
+    Args:
+        topology: The TPU topology string (e.g. "2x2x2").
+        accelerator_type: The accelerator type of the node (e.g. "TPU-V4").
+
+    Returns:
+        A string representing a unique TPU slice name.
+    """
+    pod_type = infer_tpu_pod_type_from_topology(topology, accelerator_type)
+    if pod_type is None:
+        return None
+
+    # Reserve a slice by creating a placement group on the TPU head.
+    head_label_selector = {
+        "ray.io/tpu-worker-id": "0",
+        "ray.io/tpu-pod-type": pod_type,
+    }
+    head_placement_group = ray.util.placement_group(
+        bundles=[{f"TPU-{pod_type}-head": 1}],
+        bundle_label_selector=[head_label_selector],
+    )
+
+    logger.debug("Waiting to reserve multi-host slice head.")
+    timeout = 100  # seconds
+    ready, _ = ray.wait([head_placement_group.ready()], timeout=timeout)
+
+    if not ready:
+        raise TimeoutError(
+            "Failed to reserve TPU head for slice with shape: {}. "
+            "Ensure your cluster has sufficient resources. Requesting TPU "
+            "head node with labels: {}. Current resources: {}".format(
+                pod_type, head_label_selector, ray.available_resources()
+            )
+        )
+
+    # Retrieve the unique slice ID.
+    slice_name = fetch_tpu_slice_name_from_pg(head_placement_group)
+    if slice_name is None:
+        raise RuntimeError(
+            "Failed to retrieve TPU slice name after reserving head placement group. "
+            "Ensure that TPU slice metadata is available and correctly configured on multi-host nodes."
+        )
+
+    # TODO: return both the slice name and reference to the PG reservation.
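+    # NOTE: the local handle to `head_placement_group` goes out of scope here, but
+    # the placement group is not garbage collected with it; the reservation persists
+    # until it is removed explicitly or the creating job exits.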
+    return slice_name
+
+
 class TPUAcceleratorManager(AcceleratorManager):
     """Google TPU accelerators."""
 
diff --git a/python/ray/_private/resource_and_label_spec.py b/python/ray/_private/resource_and_label_spec.py
index 03c8efd0e119..d737b70961da 100644
--- a/python/ray/_private/resource_and_label_spec.py
+++ b/python/ray/_private/resource_and_label_spec.py
@@ -10,7 +10,6 @@
 from ray._common.utils import RESOURCE_CONSTRAINT_PREFIX
 from ray._private import accelerators
 from ray._private.accelerators import AcceleratorManager
-from ray._private.accelerators.tpu import TPUAcceleratorManager
 
 logger = logging.getLogger(__name__)
 
@@ -292,10 +291,11 @@ def _get_default_labels(
                 ray._raylet.RAY_NODE_ACCELERATOR_TYPE_KEY
             ] = accelerator_type
 
-        # Set TPU specific default labels to enable SPMD scheduling.
-        if isinstance(accelerator_manager, TPUAcceleratorManager):
+        # Set TPU specific default labels to enable multi-host scheduling.
+        if accelerator_manager.get_resource_name() == "TPU":
             tpu_labels = accelerator_manager.get_current_node_accelerator_labels()
-            default_labels.update(tpu_labels)
+            if tpu_labels:
+                default_labels.update(tpu_labels)
 
         return default_labels
 
diff --git a/python/ray/tests/accelerators/test_tpu.py b/python/ray/tests/accelerators/test_tpu.py
index e0d405f60efd..f13f27ffea80 100644
--- a/python/ray/tests/accelerators/test_tpu.py
+++ b/python/ray/tests/accelerators/test_tpu.py
@@ -8,6 +8,7 @@
 import ray
 from ray._private.accelerators import TPUAcceleratorManager
 from ray._private.accelerators import tpu
+from ray.tests.conftest import _ray_start_cluster
 
 
 @patch("glob.glob")
@@ -353,5 +354,76 @@ def test_get_current_node_tpu_topology_from_metadata():
     assert topology == "2x2x4"
 
 
+@pytest.mark.parametrize(
+    "topology, accelerator_type, expected_pod_type",
+    [
+        ("2x4", "TPU-V6E", "v6e-8"),
+        ("2x2x2", "TPU-V4", "v4-8"),
+        ("2x4x4", "TPU-V3", "v3-32"),
+        ("4x4", "TPU-V5P", "v5p-16"),
+        ("8x16", "TPU-V6E", "v6e-128"),
+        ("", "TPU-V3", None),
+        ("4x", "TPU-V3", None),
+    ],
+)
+def test_infer_tpu_pod_type_from_topology(
+    topology, accelerator_type, expected_pod_type
+):
+    assert (
+        tpu.infer_tpu_pod_type_from_topology(topology, accelerator_type)
+        == expected_pod_type
+    )
+
+
+@pytest.fixture
+def ray_start_cpu():
+    address_info = ray.init(num_cpus=1)
+    yield address_info
+    ray.shutdown()
+
+
+@pytest.fixture
+def ray_tpu_cluster(monkeypatch):
+    """Start a mock TPU Ray cluster."""
+    with _ray_start_cluster() as cluster:
+        monkeypatch.setenv("TPU_NAME", "test-slice-0")
+        monkeypatch.setenv("TPU_WORKER_ID", "0")
+        monkeypatch.setenv("TPU_ACCELERATOR_TYPE", "v4-8")
+        monkeypatch.setenv("TPU_TOPOLOGY", "2x2x2")
+
+        cluster.add_node(
+            num_cpus=2,
+            resources={"TPU": 4, "TPU-v4-8-head": 1},
+        )
+        monkeypatch.setenv("TPU_WORKER_ID", "1")
+        cluster.add_node(
+            num_cpus=2,
+            resources={"TPU": 4},
+        )
+        ray.init(address=cluster.address)
+
+        yield cluster
+        ray.shutdown()
+
+
+def test_fetch_tpu_slice_name_from_pg(ray_tpu_cluster):
+    """Tests that the slice name can be fetched from a PG."""
+    tpu_head_pg = ray.util.placement_group(bundles=[{"TPU-v4-8-head": 1}])
+    ray.get(tpu_head_pg.ready())
+
+    tpu_slice_name = "test-slice-0"
+    slice_name = tpu.fetch_tpu_slice_name_from_pg(tpu_head_pg)
+    assert slice_name == tpu_slice_name
+
+    ray.util.remove_placement_group(tpu_head_pg)
+
+
+def test_reserve_tpu_slice(ray_tpu_cluster):
+    """Tests that a TPU slice can be successfully reserved."""
+    tpu_slice_name = "test-slice-0"
+    reserved_name = tpu.reserve_tpu_slice(topology="2x2x2", accelerator_type="TPU-V4")
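+    # The mock cluster exposes a "TPU-v4-8-head" resource and a TPU_NAME of
+    # "test-slice-0", so the reservation should resolve to that slice name.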
+    assert reserved_name == tpu_slice_name
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main(["-sv", __file__]))
diff --git a/python/ray/train/v2/BUILD b/python/ray/train/v2/BUILD
index 31bab21afc44..8b0128a3a15d 100644
--- a/python/ray/train/v2/BUILD
+++ b/python/ray/train/v2/BUILD
@@ -133,6 +133,22 @@ py_test(
     ],
 )
 
+py_test(
+    name = "test_jax_trainer",
+    size = "small",
+    srcs = ["tests/test_jax_trainer.py"],
+    env = {"RAY_TRAIN_V2_ENABLED": "1"},
+    tags = [
+        "exclusive",
+        "team:ml",
+        "train_v2",
+    ],
+    deps = [
+        ":conftest",
+        "//:ray_lib",
+    ],
+)
+
 py_test(
     name = "test_lightgbm_trainer",
     size = "small",
diff --git a/python/ray/train/v2/_internal/callbacks/__init__.py b/python/ray/train/v2/_internal/callbacks/__init__.py
index 5c5b204acdcf..3db8d835fba3 100644
--- a/python/ray/train/v2/_internal/callbacks/__init__.py
+++ b/python/ray/train/v2/_internal/callbacks/__init__.py
@@ -2,6 +2,7 @@
 from .backend_setup import BackendSetupCallback
 from .datasets import DatasetsSetupCallback
 from .state_manager import StateManagerCallback
+from .tpu_reservation_callback import TPUReservationCallback
 from .working_dir_setup import WorkingDirectorySetupCallback
 
 __all__ = [
@@ -9,6 +10,7 @@
     "BackendSetupCallback",
     "DatasetsSetupCallback",
     "StateManagerCallback",
+    "TPUReservationCallback",
     "WorkingDirectorySetupCallback",
 ]
diff --git a/python/ray/train/v2/_internal/callbacks/tpu_reservation_callback.py b/python/ray/train/v2/_internal/callbacks/tpu_reservation_callback.py
new file mode 100644
index 000000000000..acb7b70847ea
--- /dev/null
+++ b/python/ray/train/v2/_internal/callbacks/tpu_reservation_callback.py
@@ -0,0 +1,45 @@
+from typing import Dict, Optional
+
+import ray
+from ray._private.accelerators.tpu import reserve_tpu_slice
+from ray.train.v2._internal.execution.callback import ControllerCallback
+from ray.train.v2.api.config import ScalingConfig
+
+
+class TPUReservationCallback(ControllerCallback):
+    """A callback to handle TPU slice reservation for multi-host training."""
+
+    def on_controller_start_worker_group(
+        self, *, scaling_config: ScalingConfig, num_workers: int
+    ) -> Optional[Dict[str, str]]:
+        """Reserves a multi-host TPU slice before the worker group starts.
+
+        This hook is called by the TrainController. It checks if multi-host
+        TPUs are being used and, if so, reserves a slice.
+
+        Args:
+            scaling_config: The scaling configuration for the run.
+            num_workers: The number of workers to be started.
+
+        Returns:
+            A dictionary defining a `bundle_label_selector` to gang schedule
+            the worker group on the reserved TPU slice.
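+            Returns None when no reservation is needed (for example, single-host
+            runs or non-TPU scaling configs).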
+ """ + bundle_label_selector = None + + if scaling_config.use_tpu and num_workers > 1: + assert scaling_config.accelerator_type is not None + assert scaling_config.topology is not None + + slice_name = reserve_tpu_slice( + topology=scaling_config.topology, + accelerator_type=scaling_config.accelerator_type, + ) + if not slice_name: + raise RuntimeError("Failed to reserve TPU slice.") + + bundle_label_selector = { + ray._raylet.RAY_NODE_TPU_SLICE_NAME_KEY: slice_name + } + + return bundle_label_selector diff --git a/python/ray/train/v2/_internal/execution/callback.py b/python/ray/train/v2/_internal/execution/callback.py index 50796a0700c8..f5cfd3584f79 100644 --- a/python/ray/train/v2/_internal/execution/callback.py +++ b/python/ray/train/v2/_internal/execution/callback.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional from ray.train.v2.api.callback import RayTrainCallback +from ray.train.v2.api.config import ScalingConfig from ray.train.v2.api.result import Result from ray.util.annotations import DeveloperAPI @@ -78,6 +79,28 @@ def after_controller_start(self, train_run_context: "TrainRunContext"): before the control loop starts executing.""" pass + # TODO(matthewdeng): Revisit this callback interface for better extensibility. + # This hook was added for the specific use case of setting a `bundle_label_selector` + # for new worker groups (e.g., for TPU reservations). The current interface is + # tightly coupled to this purpose and limits its reuse for other use-cases. + def on_controller_start_worker_group( + self, *, scaling_config: ScalingConfig, num_workers: int + ) -> Optional[Dict[str, str]]: + """Called by the TrainController before the worker group is started. + + This hook can be used to perform setup that modifies the worker group's + placement, such as reserving an accelerator slice. + + Args: + scaling_config: The scaling configuration for the run. + num_workers: The number of workers to be started. + + Returns: + An optional dictionary defining a `bundle_label_selector` + to gang schedule the worker group on the reserved TPU slice. + """ + return None + def before_controller_shutdown(self): """Called before `TrainController.run` exits, after the control loop has exited.""" diff --git a/python/ray/train/v2/_internal/execution/controller/controller.py b/python/ray/train/v2/_internal/execution/controller/controller.py index e2f916d140a4..59d3760bb503 100644 --- a/python/ray/train/v2/_internal/execution/controller/controller.py +++ b/python/ray/train/v2/_internal/execution/controller/controller.py @@ -280,12 +280,28 @@ def _start_worker_group( ControllerError if the worker group failed to start. """ placement_strategy = self._scaling_policy.scaling_config.placement_strategy + scaling_config = self._train_run_context.scaling_config + + # Check for `bundle_label_selector` to influence WorkerGroup scheduling. 
+        bundle_label_selector = None
+        try:
+            for callback in self._controller_callbacks:
+                selector = callback.on_controller_start_worker_group(
+                    scaling_config=scaling_config, num_workers=num_workers
+                )
+                if selector:
+                    bundle_label_selector = selector
+                    break
+        except Exception as e:
+            return ControllerError(e)
+
         worker_group_context = WorkerGroupContext(
             run_attempt_id=self._get_run_attempt_id(),
             train_fn_ref=self._train_fn_ref,
             num_workers=num_workers,
             resources_per_worker=resources_per_worker,
             placement_strategy=placement_strategy,
+            bundle_label_selector=bundle_label_selector,
         )
         try:
             self._worker_group = self.worker_group_cls.create(
diff --git a/python/ray/train/v2/_internal/execution/worker_group/worker_group.py b/python/ray/train/v2/_internal/execution/worker_group/worker_group.py
index 8931145ecdbb..81087c0f91c7 100644
--- a/python/ray/train/v2/_internal/execution/worker_group/worker_group.py
+++ b/python/ray/train/v2/_internal/execution/worker_group/worker_group.py
@@ -89,6 +89,7 @@ class WorkerGroupContext:
         num_workers: The number of workers in the worker group.
         resources_per_worker: The resources per worker.
         placement_strategy: Strategy for placing workers.
+        bundle_label_selector: Optional label selector applied to each worker's bundle.
     """
 
     run_attempt_id: str
@@ -96,6 +97,7 @@ class WorkerGroupContext:
     num_workers: int
     resources_per_worker: Dict[str, float]
     placement_strategy: str = "PACK"
+    bundle_label_selector: Optional[Dict[str, str]] = None
 
 
 class WorkerGroup:
@@ -268,10 +270,18 @@ def _start_impl(
         for callback in self._callbacks:
             callback.before_worker_group_start(worker_group_context)
 
+        bundle_label_selector = (
+            [worker_group_context.bundle_label_selector.copy()]
+            * worker_group_context.num_workers
+            if worker_group_context.bundle_label_selector
+            else None
+        )
+
         pg = placement_group(
             bundles=[worker_group_context.resources_per_worker]
             * worker_group_context.num_workers,
             strategy=worker_group_context.placement_strategy,
+            bundle_label_selector=bundle_label_selector,
         )
         logger.info(
             f"Attempting to start training worker group of size {worker_group_context.num_workers} with "
diff --git a/python/ray/train/v2/api/config.py b/python/ray/train/v2/api/config.py
index 4efc25a2960c..665b3998cf70 100644
--- a/python/ray/train/v2/api/config.py
+++ b/python/ray/train/v2/api/config.py
@@ -22,7 +22,6 @@
 if TYPE_CHECKING:
     from ray.train import UserCallback
 
-
 logger = logging.getLogger(__name__)
 
 
@@ -51,7 +50,17 @@ class ScalingConfig(ScalingConfigV1):
             of accelerators. See :ref:`the available accelerator types <accelerator_types>`.
             Ensure that your cluster has instances with the specified accelerator type
-            or is able to autoscale to fulfill the request.
+            or is able to autoscale to fulfill the request. This field is required
+            when `use_tpu` is True and `num_workers` is greater than 1.
+        use_tpu: [Experimental] If True, training will be done on TPUs (1 TPU VM
+            per worker). Defaults to False. The number of TPUs reserved by each
+            worker can be overridden with the ``resources_per_worker``
+            argument. This arg enables SPMD execution of the training workload.
+        topology: [Experimental] If specified, Ray Train will launch the training
+            coordinator and workers on nodes with the specified topology. Topology is
+            auto-detected for TPUs and added as Ray node labels. This arg enables
+            SPMD execution of the training workload. This field is required
+            when `use_tpu` is True and `num_workers` is greater than 1.
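+            For example, ``"2x4"`` describes a single v6e host with 8 chips, while
+            ``"2x2x2"`` describes a v4-8 slice spanning two hosts with 4 chips each.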
 
     Example:
 
@@ -73,17 +82,62 @@ class ScalingConfig(ScalingConfigV1):
     """
 
     trainer_resources: Optional[dict] = None
+    use_tpu: bool = False
+    topology: Optional[str] = None
 
     def __post_init__(self):
         if self.trainer_resources is not None:
             raise DeprecationWarning(TRAINER_RESOURCES_DEPRECATION_MESSAGE)
 
+        if self.use_gpu and self.use_tpu:
+            raise ValueError("Cannot specify both `use_gpu=True` and `use_tpu=True`.")
+
+        if not self.use_tpu and self.num_tpus_per_worker > 0:
+            raise ValueError(
+                "`use_tpu` is False but `TPU` was found in "
+                "`resources_per_worker`. Either set `use_tpu` to True or "
+                "remove `TPU` from `resources_per_worker`."
+            )
+
+        if self.use_tpu and self.num_tpus_per_worker == 0:
+            raise ValueError(
+                "`use_tpu` is True but `TPU` is set to 0 in "
+                "`resources_per_worker`. Either set `use_tpu` to False or "
+                "request a positive number of `TPU` in "
+                "`resources_per_worker`."
+            )
+
+        if self.use_tpu and self.num_workers > 1:
+            if not self.topology:
+                raise ValueError(
+                    "`topology` must be specified in ScalingConfig when "
+                    "`use_tpu=True` and `num_workers` > 1."
+                )
+            if not self.accelerator_type:
+                raise ValueError(
+                    "`accelerator_type` must be specified in ScalingConfig when "
+                    "`use_tpu=True` and `num_workers` > 1."
+                )
+
         super().__post_init__()
 
+    @property
+    def _resources_per_worker_not_none(self):
+        if self.resources_per_worker is None:
+            if self.use_tpu:
+                return {"TPU": 1}
+
+        return super()._resources_per_worker_not_none
+
     @property
     def _trainer_resources_not_none(self):
         return {}
 
+    @property
+    def num_tpus_per_worker(self):
+        """The number of TPUs requested per worker."""
+        return self._resources_per_worker_not_none.get("TPU", 0)
+
 
 @dataclass
 class FailureConfig(FailureConfigV1):
diff --git a/python/ray/train/v2/api/data_parallel_trainer.py b/python/ray/train/v2/api/data_parallel_trainer.py
index 40fede922b90..369d8762e87d 100644
--- a/python/ray/train/v2/api/data_parallel_trainer.py
+++ b/python/ray/train/v2/api/data_parallel_trainer.py
@@ -27,6 +27,7 @@
     AcceleratorSetupCallback,
     BackendSetupCallback,
     DatasetsSetupCallback,
+    TPUReservationCallback,
     WorkingDirectorySetupCallback,
 )
 from ray.train.v2._internal.callbacks.datasets import GenDataset
@@ -154,9 +155,11 @@ def _create_default_callbacks(self) -> List[RayTrainCallback]:
             data_config=self.data_config,
             scaling_config=self.scaling_config,
         )
+        tpu_reservation_setup_callback = TPUReservationCallback()
         callbacks.extend(
             [
                 accelerator_setup_callback,
+                tpu_reservation_setup_callback,
                 backend_setup_callback,
                 datasets_setup_callback,
             ]
diff --git a/python/ray/train/v2/jax/__init__.py b/python/ray/train/v2/jax/__init__.py
new file mode 100644
index 000000000000..097ee852b783
--- /dev/null
+++ b/python/ray/train/v2/jax/__init__.py
@@ -0,0 +1,15 @@
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    try:
+        import jax  # noqa: F401
+    except ModuleNotFoundError as exception:
+        raise ModuleNotFoundError(
+            "JAX isn't installed. To install JAX, please check"
+            " `https://github.com/google/jax#installation` for instructions."
+        ) from exception
+
+from ray.train.v2.jax.config import JaxConfig
+from ray.train.v2.jax.jax_trainer import JaxTrainer
+
+__all__ = ["JaxConfig", "JaxTrainer"]
diff --git a/python/ray/train/v2/jax/config.py b/python/ray/train/v2/jax/config.py
new file mode 100644
index 000000000000..5e8dc5ba33e4
--- /dev/null
+++ b/python/ray/train/v2/jax/config.py
@@ -0,0 +1,59 @@
+import logging
+import os
+from dataclasses import dataclass
+
+import ray
+from ray.train._internal.utils import get_address_and_port
+from ray.train._internal.worker_group import WorkerGroup
+from ray.train.backend import Backend, BackendConfig
+from ray.util import PublicAPI
+
+logger = logging.getLogger(__name__)
+
+
+@PublicAPI(stability="alpha")
+@dataclass
+class JaxConfig(BackendConfig):
+    use_tpu: bool = False
+
+    @property
+    def backend_cls(self):
+        return _JaxBackend
+
+
+def _setup_jax_tpu_environment(
+    master_addr_with_port: str, num_workers: int, index: int
+):
+    """Set up distributed JAX training information.
+
+    This function should be called on each worker.
+    """
+    import jax
+
+    jax_platforms = os.environ.get("JAX_PLATFORMS", "").lower()
+
+    if "tpu" in jax_platforms.split(","):
+        jax.distributed.initialize(master_addr_with_port, num_workers, index)
+
+
+class _JaxBackend(Backend):
+    def on_start(self, worker_group: WorkerGroup, backend_config: JaxConfig):
+        if not backend_config.use_tpu:
+            return
+
+        master_addr, master_port = worker_group.execute_single(0, get_address_and_port)
+        master_addr_with_port = f"{master_addr}:{master_port}"
+
+        # Run setup on all workers and block so that any errors surface here.
+        setup_futures = []
+        for i in range(len(worker_group)):
+            setup_futures.append(
+                worker_group.execute_single_async(
+                    i,
+                    _setup_jax_tpu_environment,
+                    master_addr_with_port=master_addr_with_port,
+                    num_workers=len(worker_group),
+                    index=i,
+                )
+            )
+        ray.get(setup_futures)
diff --git a/python/ray/train/v2/jax/jax_trainer.py b/python/ray/train/v2/jax/jax_trainer.py
new file mode 100644
index 000000000000..f1845d8d50ff
--- /dev/null
+++ b/python/ray/train/v2/jax/jax_trainer.py
@@ -0,0 +1,162 @@
+import logging
+from typing import TYPE_CHECKING, Callable, Dict, Optional, Union
+
+from ray.air._internal.config import ensure_only_allowed_dataclass_keys_updated
+from ray.train import Checkpoint, DataConfig
+from ray.train.trainer import GenDataset
+from ray.train.v2.api.config import RunConfig, ScalingConfig
+from ray.train.v2.api.data_parallel_trainer import DataParallelTrainer
+from ray.train.v2.jax.config import JaxConfig
+from ray.util import PublicAPI
+
+if TYPE_CHECKING:
+    pass
+
+logger = logging.getLogger(__name__)
+
+
+@PublicAPI(stability="alpha")
+class JaxTrainer(DataParallelTrainer):
+    """A Trainer for Single-Program Multi-Data (SPMD) JAX training.
+
+    Currently only TPUs are supported; GPUs will be supported in a future version.
+
+    This Trainer runs the function ``train_loop_per_worker`` on multiple Ray
+    Actors. These actors are expected to be scheduled on TPU VMs within the same
+    TPU slice, connected via inter-chip interconnects (ICI). The ``train_loop_per_worker``
+    function is expected to take in either zero or one argument:
+
+    .. testcode::
+
+        import os
+        from absl import app
+        import logging
+        from typing import Sequence
+
+        import ray
+        from ray.train.v2.api.config import ScalingConfig, RunConfig
+        from ray.train.v2.jax import JaxTrainer
+        from MaxText.train import main as maxtext_main
+
+        def train_loop_per_worker(config):
+            argv = config["argv"]
+            maxtext_main(argv)
+
+        def main(argv: Sequence[str]):
+            ray.init()
+
+            trainer = JaxTrainer(
+                train_loop_per_worker=train_loop_per_worker,
+                train_loop_config={"argv": argv},
+                scaling_config=ScalingConfig(
+                    use_tpu=True,
+                    num_workers=4,
+                    topology="4x4",
+                    accelerator_type="TPU-V6E",
+                    resources_per_worker={"TPU": 4},
+                    placement_strategy="SPREAD",
+                ),
+                run_config=RunConfig(
+                    name="maxtext_jaxtrainer",
+                    worker_runtime_env={
+                        "env_vars": {
+                            "JAX_PLATFORMS": "tpu",
+                            "ENABLE_PJRT_COMPATIBILITY": "true",
+                            "TPU_SLICE_BUILDER_DUMP_CHIP_FORCE": "true",
+                            "TPU_SLICE_BUILDER_DUMP_ICI": "true",
+                            "XLA_FLAGS": "--xla_dump_to=/tmp/xla_dump_file --xla_dump_hlo_as_proto",
+                        }
+                    },
+                ),
+            )
+
+            result = trainer.fit()
+
+    .. testoutput::
+        :options: +ELLIPSIS
+        :hide:
+
+    If ``train_loop_per_worker`` accepts an argument, then
+    ``train_loop_config`` will be passed in as the argument.
+
+    If the ``datasets`` dict contains a training dataset (denoted by
+    the "train" key), then it will be split into multiple dataset
+    shards that can then be accessed by ``session.get_dataset_shard("train")``.
+
+    Note:
+        * Only TPU-based distributed training is supported.
+        * Each worker must be assigned the TPU chips on its host via
+          ``resources_per_worker``, e.g. ``{"TPU": 4}`` on a v4 host.
+        * Placement strategy is automatically set to ``SPREAD`` to ensure
+          TPU workers are placed on separate VMs.
+        * Importing `jax` should occur within `train_loop_per_worker` to
+          avoid driver-side TPU lock issues.
+
+    Args:
+        train_loop_per_worker: The training function to execute on each worker.
+            This function can either take in zero arguments or a single ``Dict``
+            argument which is set by defining ``train_loop_config``.
+            Within this function you can use any of the
+            :ref:`Ray Train Loop utilities <train-loop-api>`.
+        train_loop_config: A configuration ``Dict`` to pass in as an argument to
+            ``train_loop_per_worker``.
+            This is typically used for specifying hyperparameters. Passing large
+            datasets via `train_loop_config` is not recommended and may introduce
+            large overhead and unknown issues with serialization and deserialization.
+        jax_config: The configuration for setting up the JAX backend.
+            If set to None, a default configuration with TPUs will be used.
+        scaling_config: Configuration for how to scale data parallel training
+            with SPMD. ``num_workers`` should be set to the number of TPU hosts
+            and ``topology`` should be set to the TPU topology.
+            See :class:`~ray.train.ScalingConfig` for more info.
+        dataset_config: The configuration for ingesting the input ``datasets``.
+            By default, all the Ray Datasets are split equally across workers.
+            See :class:`~ray.train.DataConfig` for more details.
+        run_config: The configuration for the execution of the training run.
+            See :class:`~ray.train.RunConfig` for more info.
+        datasets: The Ray Datasets to ingest for training.
+            Datasets are keyed by name (``{name: dataset}``).
+            Each dataset can be accessed from within the ``train_loop_per_worker``
+            by calling ``ray.train.get_dataset_shard(name)``.
+            Sharding and additional configuration can be done by
+            passing in a ``dataset_config``.
+        resume_from_checkpoint: A checkpoint to resume training from.
+            This checkpoint can be accessed from within ``train_loop_per_worker``
+            by calling ``ray.train.get_checkpoint()``.
+    """
+
+    def __init__(
+        self,
+        train_loop_per_worker: Union[Callable[[], None], Callable[[Dict], None]],
+        *,
+        train_loop_config: Optional[Dict] = None,
+        jax_config: Optional[JaxConfig] = None,
+        scaling_config: Optional[ScalingConfig] = None,
+        dataset_config: Optional[Dict[str, DataConfig]] = None,
+        run_config: Optional[RunConfig] = None,
+        datasets: Optional[Dict[str, GenDataset]] = None,
+        resume_from_checkpoint: Optional[Checkpoint] = None,
+    ):
+        if not jax_config:
+            jax_config = JaxConfig(
+                use_tpu=scaling_config.use_tpu,
+            )
+        super(JaxTrainer, self).__init__(
+            train_loop_per_worker=train_loop_per_worker,
+            train_loop_config=train_loop_config,
+            backend_config=jax_config,
+            scaling_config=scaling_config,
+            dataset_config=dataset_config,
+            run_config=run_config,
+            datasets=datasets,
+            resume_from_checkpoint=resume_from_checkpoint,
+        )
+
+    @classmethod
+    def _validate_scaling_config(cls, scaling_config: ScalingConfig) -> ScalingConfig:
+        """Return scaling config dataclass after validating updated keys."""
+        ensure_only_allowed_dataclass_keys_updated(
+            dataclass=scaling_config,
+            allowed_keys=cls._scaling_config_allowed_keys,
+        )
+
+        return scaling_config
diff --git a/python/ray/train/v2/tests/test_jax_trainer.py b/python/ray/train/v2/tests/test_jax_trainer.py
new file mode 100644
index 000000000000..a6449577181b
--- /dev/null
+++ b/python/ray/train/v2/tests/test_jax_trainer.py
@@ -0,0 +1,137 @@
+import pytest
+
+import ray
+from ray.tests.conftest import _ray_start_cluster
+from ray.train.v2._internal.constants import HEALTH_CHECK_INTERVAL_S_ENV_VAR
+from ray.train.v2.api.config import RunConfig, ScalingConfig
+from ray.train.v2.jax import JaxTrainer
+
+
+@pytest.fixture
+def ray_tpu_single_host(monkeypatch):
+    """Start a mock single-host TPU Ray cluster with 2x4 v6e (8 chips per host)."""
+    with _ray_start_cluster() as cluster:
+        monkeypatch.setenv("TPU_ACCELERATOR_TYPE", "v6e-8")
+
+        # Simulate one node with 8 TPU chips.
+        cluster.add_node(
+            num_cpus=4,
+            resources={"TPU": 8},
+        )
+
+        ray.init(address=cluster.address)
+
+        yield cluster
+        ray.shutdown()
+
+
+@pytest.fixture
+def ray_tpu_multi_host(monkeypatch):
+    """Start a simulated multi-host TPU Ray cluster."""
+    with _ray_start_cluster() as cluster:
+        monkeypatch.setenv("TPU_NAME", "test-slice-1")
+        monkeypatch.setenv("TPU_WORKER_ID", "0")
+        monkeypatch.setenv("TPU_ACCELERATOR_TYPE", "v4-8")
+        monkeypatch.setenv("TPU_TOPOLOGY", "2x2x2")
+
+        cluster.add_node(
+            num_cpus=2,
+            resources={"TPU": 4, "TPU-v4-8-head": 1},
+        )
+        monkeypatch.setenv("TPU_WORKER_ID", "1")
+        cluster.add_node(
+            num_cpus=2,
+            resources={"TPU": 4},
+        )
+
+        ray.init(address=cluster.address)
+
+        yield cluster
+        ray.shutdown()
+
+
+@pytest.fixture(autouse=True)
+def reduce_health_check_interval(monkeypatch):
+    monkeypatch.setenv(HEALTH_CHECK_INTERVAL_S_ENV_VAR, "0.2")
+    yield
+
+
+def train_func():
+    import jax
+
+    from ray import train
+
+    devices = jax.devices()
+    print(f"Devices on this worker: {devices}")
+    train.report({"result": [str(d) for d in devices]})
+
+
+def test_minimal_singlehost(ray_tpu_single_host, tmp_path):
+    trainer = JaxTrainer(
+        train_loop_per_worker=train_func,
+        # Topology can be omitted for single-host.
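+        # (No multi-host slice reservation is needed when num_workers == 1.)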
+        scaling_config=ScalingConfig(
+            num_workers=1,
+            resources_per_worker={"TPU": 8},
+            use_tpu=True,
+            accelerator_type="TPU-V6E",
+        ),
+        run_config=RunConfig(
+            storage_path=str(tmp_path),
+            worker_runtime_env={
+                "pip": ["jax"],
+                "env_vars": {
+                    "JAX_PLATFORMS": "cpu",
+                },
+            },
+        ),
+    )
+    result = trainer.fit()
+    assert result.error is None
+
+    # Check that exactly 1 TPU node was used.
+    nodes = ray.nodes()
+    labeled_nodes = [
+        node for node in nodes if node["Alive"] and node["Resources"].get("TPU") == 8
+    ]
+    assert len(labeled_nodes) == 1
+
+
+def test_minimal_multihost(ray_tpu_multi_host, tmp_path):
+    trainer = JaxTrainer(
+        train_loop_per_worker=train_func,
+        scaling_config=ScalingConfig(
+            num_workers=2,
+            resources_per_worker={"TPU": 4},
+            use_tpu=True,
+            topology="2x2x2",
+            accelerator_type="TPU-V4",
+        ),
+        run_config=RunConfig(
+            storage_path=str(tmp_path),
+            worker_runtime_env={
+                "pip": ["jax"],
+                "env_vars": {
+                    "JAX_PLATFORMS": "cpu",
+                },
+            },
+        ),
+    )
+    result = trainer.fit()
+    assert result.error is None
+
+    # Check that multi-host slice was scheduled atomically.
+    nodes = ray.nodes()
+    slice_label = "test-slice-1"
+    labeled_nodes = [
+        node
+        for node in nodes
+        if node["Alive"] and node["Labels"].get("ray.io/tpu-slice-name") == slice_label
+    ]
+    assert len(labeled_nodes) == 2
+
+
+if __name__ == "__main__":
+    import sys
+
+    sys.exit(pytest.main(["-v", "-x", __file__]))