Merged
Changes from all commits
33 commits
76983cf
JaxTrainer support in V2 with SPMD
ryanaoleary Jul 30, 2025
703e185
Fix errors and comments
ryanaoleary Aug 5, 2025
69d349c
Add minimal unit tests and fix scheduling logic to use one placement …
ryanaoleary Aug 6, 2025
b9f0764
Merge branch 'master' into implement-jax-trainer
ryanaoleary Aug 6, 2025
c48734a
Merge branch 'master' into implement-jax-trainer
ryanaoleary Aug 7, 2025
360b952
Fix tpu default labels population
ryanaoleary Aug 7, 2025
a341722
Add callback and fix other comments
ryanaoleary Aug 7, 2025
d2ee6a1
Fix wording and remove unneeded change
ryanaoleary Aug 7, 2025
7f165a6
Fix test code string
ryanaoleary Aug 7, 2025
566a788
remove unused fields from config
ryanaoleary Aug 8, 2025
4885a45
Make fields required for multi-host 'use_tpu'
ryanaoleary Aug 8, 2025
62730cd
Reserve tpu slice take required args
ryanaoleary Aug 8, 2025
7ccc7c5
Update example now that tested on v6e
ryanaoleary Aug 8, 2025
2fb6e88
add jax pip install
ryanaoleary Aug 9, 2025
536f83c
Merge branch 'master' into implement-jax-trainer
ryanaoleary Aug 9, 2025
dbbd4cf
Merge branch 'master' into implement-jax-trainer
ryanaoleary Aug 11, 2025
4e8bbc2
ensure TPUReservationCallback is executed and fix minor review comments
andrewsykim Aug 12, 2025
444b9de
refactor _JaxBackend.onStart to fetch unique worker address and port …
andrewsykim Aug 12, 2025
27ab61b
Merge branch 'master' into implement-jax-trainer
andrewsykim Aug 12, 2025
e530846
temporarily remove TPU callback import
andrewsykim Aug 12, 2025
67f077c
Revert "temporarily remove TPU callback import"
andrewsykim Aug 12, 2025
05e61cf
formatting fixes
andrewsykim Aug 12, 2025
5588927
fix lint errors
andrewsykim Aug 12, 2025
116c37f
move some TPU util functions to Ray Core to resolve import errors
andrewsykim Aug 12, 2025
4a8ea0d
remove test_tpu_utils reference in ray train v2 BUILD
andrewsykim Aug 12, 2025
e069ab7
fix code format and lint failures
andrewsykim Aug 12, 2025
a8b1829
fix code formatting
andrewsykim Aug 13, 2025
4a21677
move field assertion to after use_tpu condition
andrewsykim Aug 13, 2025
64c5a95
Merge branch 'master' into implement-jax-trainer
andrewsykim Aug 13, 2025
6598487
address nits from matthewdeng
andrewsykim Aug 13, 2025
b8638c4
remove unused import caught by lint check
andrewsykim Aug 13, 2025
1028bfe
fix docstring for JaxTrainer
andrewsykim Aug 13, 2025
f10e268
fix docstring order for JaxTrainer
andrewsykim Aug 13, 2025
86 changes: 86 additions & 0 deletions python/ray/_private/accelerators/tpu.py
@@ -9,6 +9,7 @@

import ray
from ray._private.accelerators.accelerator import AcceleratorManager
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

logger = logging.getLogger(__name__)

@@ -110,6 +111,91 @@ def get_tpu_cores_per_chip(accelerator_type: str) -> int:
return DEFAULT_TPU_NUM_CORES_PER_CHIP


def infer_tpu_pod_type_from_topology(
topology: str, accelerator_type: str
) -> Optional[str]:
"""Infer the TPU pod type (e.g. v4-32) from topology and accelerator type."""
try:
num_chips = 1
for value in topology.strip().lower().split("x"):
num_chips *= int(value)
generation = accelerator_type.lower().replace("tpu-", "")
return f"{generation}-{num_chips}"
except Exception as e:
logger.warning(
f"Failed to infer pod type from topology {topology} and type {accelerator_type}: {e}"
)
return None


def fetch_tpu_slice_name_from_pg(pg):
@ray.remote(num_cpus=0)
def _get_tpu_slice_name():
return TPUAcceleratorManager.get_current_node_tpu_name()

tpu_name_ref = _get_tpu_slice_name.options(
scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=pg, placement_group_bundle_index=0
)
).remote()

return ray.get(tpu_name_ref)


def reserve_tpu_slice(
topology: str,
accelerator_type: str,
) -> Optional[str]:
"""Reserves a TPU slice using its head resource and returns the slice name.
This enables gang scheduling of training workers with multi-host TPUs.
This is used by JaxTrainer with TPUs in Ray Train.

Args:
topology: The TPU topology string (e.g. "2x2x2").
accelerator_type: The accelerator type of the node (e.g. "TPU-V4").

Returns:
A string representing a unique TPU slice name.
"""
pod_type = infer_tpu_pod_type_from_topology(topology, accelerator_type)
if pod_type is None:
return None

# Reserve a slice by creating a placement group on the TPU head.
head_label_selector = {
"ray.io/tpu-worker-id": "0",
"ray.io/tpu-pod-type": pod_type,
}
head_placement_group = ray.util.placement_group(
bundles=[{f"TPU-{pod_type}-head": 1}],
bundle_label_selector=[head_label_selector],
)

logger.debug("Waiting to reserve multi-host slice head.")
timeout = 100
ready, _ = ray.wait([head_placement_group.ready()], timeout=timeout)

if not ready:
raise TimeoutError(
"Failed to reserve TPU head for slice with shape: {}. "
"Ensure your cluster has sufficient resources. Requesting TPU "
"head node with labels: {}. Current resources: {}".format(
pod_type, head_label_selector, ray.available_resources()
)
)

# Retrieve the unique slice ID.
slice_name = fetch_tpu_slice_name_from_pg(head_placement_group)
if slice_name is None:
raise RuntimeError(
"Failed to retrieve TPU slice name after reserving head placement group. "
"Ensure that TPU slice metadata is available and correctly configured on multi-host nodes."
)

# TODO: return both the slice name and reference to the PG reservation.
return slice_name


class TPUAcceleratorManager(AcceleratorManager):
"""Google TPU accelerators."""

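Not part of the diff: a minimal usage sketch of the new helpers, assuming a running cluster that exposes a multi-host v4-8 slice (a `TPU-v4-8-head` resource plus the default TPU node labels set by the raylet):

import ray
from ray._private.accelerators import tpu

ray.init()  # assumes an existing cluster with a 2x2x2 (v4-8) multi-host TPU slice

# Derive the pod type string used for the head resource and node labels.
pod_type = tpu.infer_tpu_pod_type_from_topology("2x2x2", "TPU-V4")
assert pod_type == "v4-8"

# Reserve the slice head via a placement group and fetch its unique slice name,
# which callers can then use as a label selector for gang scheduling.
slice_name = tpu.reserve_tpu_slice(topology="2x2x2", accelerator_type="TPU-V4")
print(slice_name)
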
8 changes: 4 additions & 4 deletions python/ray/_private/resource_and_label_spec.py
@@ -10,7 +10,6 @@
from ray._common.utils import RESOURCE_CONSTRAINT_PREFIX
from ray._private import accelerators
from ray._private.accelerators import AcceleratorManager
from ray._private.accelerators.tpu import TPUAcceleratorManager

logger = logging.getLogger(__name__)

@@ -292,10 +291,11 @@ def _get_default_labels(
ray._raylet.RAY_NODE_ACCELERATOR_TYPE_KEY
] = accelerator_type

# Set TPU specific default labels to enable SPMD scheduling.
if isinstance(accelerator_manager, TPUAcceleratorManager):
# Set TPU specific default labels to enable multi-host scheduling.
if accelerator_manager.get_resource_name() == "TPU":
tpu_labels = accelerator_manager.get_current_node_accelerator_labels()
default_labels.update(tpu_labels)
if tpu_labels:
default_labels.update(tpu_labels)

return default_labels

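For context, a hypothetical example (values assumed) of the TPU default labels that `get_current_node_accelerator_labels()` would contribute on a multi-host TPU node, matching the selectors used in `reserve_tpu_slice` above:

# Example only; the actual label set comes from TPUAcceleratorManager on the node.
tpu_default_labels = {
    "ray.io/tpu-worker-id": "0",    # this node's worker index within the slice
    "ray.io/tpu-pod-type": "v4-8",  # generation + chip count, e.g. derived from 2x2x2
}
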
72 changes: 72 additions & 0 deletions python/ray/tests/accelerators/test_tpu.py
@@ -8,6 +8,7 @@
import ray
from ray._private.accelerators import TPUAcceleratorManager
from ray._private.accelerators import tpu
from ray.tests.conftest import _ray_start_cluster


@patch("glob.glob")
@@ -353,5 +354,76 @@ def test_get_current_node_tpu_topology_from_metadata():
assert topology == "2x2x4"


@pytest.mark.parametrize(
"topology, accelerator_type, expected_pod_type",
[
("2x4", "TPU-V6E", "v6e-8"),
("2x2x2", "TPU-V4", "v4-8"),
("2x4x4", "TPU-V3", "v3-32"),
("4x4", "TPU-V5P", "v5p-16"),
("8x16", "TPU-V6E", "v6e-128"),
("", "TPU-V3", None),
("4x", "TPU-V3", None),
],
)
def test_infer_tpu_pod_type_from_topology(
topology, accelerator_type, expected_pod_type
):
assert (
tpu.infer_tpu_pod_type_from_topology(topology, accelerator_type)
== expected_pod_type
)


@pytest.fixture
def ray_start_cpu():
address_info = ray.init(num_cpus=1)
yield address_info
ray.shutdown()


@pytest.fixture
def ray_tpu_cluster(monkeypatch):
"""Start a mock TPU Ray cluster."""
with _ray_start_cluster() as cluster:
monkeypatch.setenv("TPU_NAME", "test-slice-0")
monkeypatch.setenv("TPU_WORKER_ID", "0")
monkeypatch.setenv("TPU_ACCELERATOR_TYPE", "v4-8")
monkeypatch.setenv("TPU_TOPOLOGY", "2x2x2")

cluster.add_node(
num_cpus=2,
resources={"TPU": 4, "TPU-v4-8-head": 1},
)
monkeypatch.setenv("TPU_WORKER_ID", "1")
cluster.add_node(
num_cpus=2,
resources={"TPU": 4},
)
ray.init(address=cluster.address)

yield cluster
ray.shutdown()


def test_fetch_tpu_slice_name_from_pg(ray_tpu_cluster):
"""Tests that the slice name can be fetched from a PG."""
tpu_head_pg = ray.util.placement_group(bundles=[{"TPU-v4-8-head": 1}])
ray.get(tpu_head_pg.ready())

tpu_slice_name = "test-slice-0"
slice_name = tpu.fetch_tpu_slice_name_from_pg(tpu_head_pg)
assert slice_name == tpu_slice_name

ray.util.remove_placement_group(tpu_head_pg)


def test_reserve_tpu_slice(ray_tpu_cluster):
"""Tests that a TPU slice can be successfully reserved."""
tpu_slice_name = "test-slice-0"
reserved_name = tpu.reserve_tpu_slice(topology="2x2x2", accelerator_type="TPU-V4")
assert reserved_name == tpu_slice_name


if __name__ == "__main__":
sys.exit(pytest.main(["-sv", __file__]))
16 changes: 16 additions & 0 deletions python/ray/train/v2/BUILD
@@ -133,6 +133,22 @@ py_test(
],
)

py_test(
name = "test_jax_trainer",
size = "small",
srcs = ["tests/test_jax_trainer.py"],
env = {"RAY_TRAIN_V2_ENABLED": "1"},
tags = [
"exclusive",
"team:ml",
"train_v2",
],
deps = [
":conftest",
"//:ray_lib",
],
)

py_test(
name = "test_lightgbm_trainer",
size = "small",
2 changes: 2 additions & 0 deletions python/ray/train/v2/_internal/callbacks/__init__.py
@@ -2,13 +2,15 @@
from .backend_setup import BackendSetupCallback
from .datasets import DatasetsSetupCallback
from .state_manager import StateManagerCallback
from .tpu_reservation_callback import TPUReservationCallback
from .working_dir_setup import WorkingDirectorySetupCallback

__all__ = [
"AcceleratorSetupCallback",
"BackendSetupCallback",
"DatasetsSetupCallback",
"StateManagerCallback",
"TPUReservationCallback",
"WorkingDirectorySetupCallback",
]

45 changes: 45 additions & 0 deletions python/ray/train/v2/_internal/callbacks/tpu_reservation_callback.py
@@ -0,0 +1,45 @@
from typing import Dict, Optional

import ray
from ray._private.accelerators.tpu import reserve_tpu_slice
from ray.train.v2._internal.execution.callback import ControllerCallback
from ray.train.v2.api.config import ScalingConfig


class TPUReservationCallback(ControllerCallback):
"""A callback to handle TPU slice reservation for multi-host training."""

def on_controller_start_worker_group(
self, *, scaling_config: ScalingConfig, num_workers: int
) -> Optional[Dict[str, str]]:
"""Reserves a multi-host TPU slice before the worker group starts.

This hook is called by the TrainController. It checks if multi-host
TPUs are being used and, if so, reserves a slice.

Args:
scaling_config: The scaling configuration for the run.
num_workers: The number of workers to be started.
Contributor: num_workers can come from scaling_config as well.

Member: added a check for both

Returns:
A dictionary defining a `bundle_label_selector` to gang schedule
the worker group on the reserved TPU slice.
"""
bundle_label_selector = None

if scaling_config.use_tpu and num_workers > 1:
assert scaling_config.accelerator_type is not None
assert scaling_config.topology is not None

slice_name = reserve_tpu_slice(
topology=scaling_config.topology,
accelerator_type=scaling_config.accelerator_type,
)
if not slice_name:
raise RuntimeError("Failed to reserve TPU slice.")

bundle_label_selector = {
ray._raylet.RAY_NODE_TPU_SLICE_NAME_KEY: slice_name
}

return bundle_label_selector
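
A rough sketch of how this callback gets exercised end to end; the `ScalingConfig` field names below are taken from the diff, but treating them as constructor arguments is an assumption, not the exact public API:

from ray.train.v2._internal.callbacks import TPUReservationCallback
from ray.train.v2.api.config import ScalingConfig

# Multi-host TPU run: use_tpu, topology, and accelerator_type must all be set.
scaling_config = ScalingConfig(
    num_workers=2,
    use_tpu=True,
    topology="2x2x2",
    accelerator_type="TPU-V4",
)

# The controller invokes this hook before starting the worker group; on success
# it returns {ray._raylet.RAY_NODE_TPU_SLICE_NAME_KEY: <slice name>}, which is
# then used as the per-bundle label selector. Requires a cluster with the slice.
selector = TPUReservationCallback().on_controller_start_worker_group(
    scaling_config=scaling_config, num_workers=scaling_config.num_workers
)
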
23 changes: 23 additions & 0 deletions python/ray/train/v2/_internal/execution/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from ray.train.v2.api.callback import RayTrainCallback
from ray.train.v2.api.config import ScalingConfig
from ray.train.v2.api.result import Result
from ray.util.annotations import DeveloperAPI

Expand Down Expand Up @@ -78,6 +79,28 @@ def after_controller_start(self, train_run_context: "TrainRunContext"):
before the control loop starts executing."""
pass

# TODO(matthewdeng): Revisit this callback interface for better extensibility.
# This hook was added for the specific use case of setting a `bundle_label_selector`
# for new worker groups (e.g., for TPU reservations). The current interface is
# tightly coupled to this purpose and limits its reuse for other use-cases.
def on_controller_start_worker_group(
Contributor: Add a TODO for me to come back to this interface. It works well for this use case but I think it'll be hard to extend/support different use-cases in the future.
  1. The bundle_label_selector logic is a bit specific for this callback.
  2. Behavior for creating multiple callbacks with bundle_label_selector logic is undefined.

Member: Added this:

# TODO(matthewdeng): Revisit this callback interface for better extensibility.
# This hook was added for the specific use case of setting a `bundle_label_selector`
# for new worker groups (e.g., for TPU reservations). The current interface is
# tightly coupled to this purpose and limits its reuse for other use-cases.

self, *, scaling_config: ScalingConfig, num_workers: int
) -> Optional[Dict[str, str]]:
"""Called by the TrainController before the worker group is started.

This hook can be used to perform setup that modifies the worker group's
placement, such as reserving an accelerator slice.

Args:
scaling_config: The scaling configuration for the run.
num_workers: The number of workers to be started.

Returns:
An optional dictionary defining a `bundle_label_selector`
to gang schedule the worker group on the reserved TPU slice.
"""
return None

def before_controller_shutdown(self):
"""Called before `TrainController.run` exits,
after the control loop has exited."""
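For illustration, a minimal custom controller callback implementing the new hook; everything other than the imports and hook signature is assumed:

from typing import Dict, Optional

from ray.train.v2._internal.execution.callback import ControllerCallback
from ray.train.v2.api.config import ScalingConfig


class StaticSelectorCallback(ControllerCallback):
    """Hypothetical callback that pins all worker bundles to a fixed node label."""

    def on_controller_start_worker_group(
        self, *, scaling_config: ScalingConfig, num_workers: int
    ) -> Optional[Dict[str, str]]:
        # A returned dict becomes the bundle_label_selector for every worker
        # bundle; returning None leaves scheduling unchanged.
        return {"example.com/pool": "reserved"}
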
@@ -280,12 +280,28 @@ def _start_worker_group(
ControllerError if the worker group failed to start.
"""
placement_strategy = self._scaling_policy.scaling_config.placement_strategy
scaling_config = self._train_run_context.scaling_config

# Check for `bundle_label_selector` to influence WorkerGroup scheduling.
bundle_label_selector = None
try:
for callback in self._controller_callbacks:
selector = callback.on_controller_start_worker_group(
scaling_config=scaling_config, num_workers=num_workers
)
if selector:
bundle_label_selector = selector
break
except Exception as e:
return ControllerError(e)

worker_group_context = WorkerGroupContext(
run_attempt_id=self._get_run_attempt_id(),
train_fn_ref=self._train_fn_ref,
num_workers=num_workers,
resources_per_worker=resources_per_worker,
placement_strategy=placement_strategy,
bundle_label_selector=bundle_label_selector,
)
try:
self._worker_group = self.worker_group_cls.create(
@@ -89,13 +89,15 @@ class WorkerGroupContext:
num_workers: The number of workers in the worker group.
resources_per_worker: The resources per worker.
placement_strategy: Strategy for placing workers.
bundle_label_selector: Optional label selectors to apply per-bundle for workers.
"""

run_attempt_id: str
train_fn_ref: ObjectRefWrapper[Callable[[], None]]
num_workers: int
resources_per_worker: Dict[str, float]
placement_strategy: str = "PACK"
bundle_label_selector: Optional[Dict[str, str]] = None


class WorkerGroup:
@@ -268,10 +270,18 @@ def _start_impl(
for callback in self._callbacks:
callback.before_worker_group_start(worker_group_context)

bundle_label_selector = (
[worker_group_context.bundle_label_selector.copy()]
* worker_group_context.num_workers
if worker_group_context.bundle_label_selector
else None
)

pg = placement_group(
bundles=[worker_group_context.resources_per_worker]
* worker_group_context.num_workers,
strategy=worker_group_context.placement_strategy,
bundle_label_selector=bundle_label_selector,
)
logger.info(
f"Attempting to start training worker group of size {worker_group_context.num_workers} with "
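
To make the per-bundle expansion concrete, a small sketch (worker count, resources, and slice name are assumed values) of the placement group the worker group ends up requesting:

import ray

num_workers = 4
resources_per_worker = {"TPU": 4}
# The single selector returned by the callback is replicated once per bundle,
# so every worker is constrained to nodes belonging to the reserved slice.
selector = {ray._raylet.RAY_NODE_TPU_SLICE_NAME_KEY: "test-slice-0"}
pg = ray.util.placement_group(
    bundles=[resources_per_worker] * num_workers,
    strategy="PACK",
    bundle_label_selector=[selector.copy()] * num_workers,
)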