Merged
33 commits
76983cf
JaxTrainer support in V2 with SPMD
ryanaoleary Jul 30, 2025
703e185
Fix errors and comments
ryanaoleary Aug 5, 2025
69d349c
Add minimal unit tests and fix scheduling logic to use one placement …
ryanaoleary Aug 6, 2025
b9f0764
Merge branch 'master' into implement-jax-trainer
ryanaoleary Aug 6, 2025
c48734a
Merge branch 'master' into implement-jax-trainer
ryanaoleary Aug 7, 2025
360b952
Fix tpu default labels population
ryanaoleary Aug 7, 2025
a341722
Add callback and fix other comments
ryanaoleary Aug 7, 2025
d2ee6a1
Fix wording and remove unneeded change
ryanaoleary Aug 7, 2025
7f165a6
Fix test code string
ryanaoleary Aug 7, 2025
566a788
remove unused fields from config
ryanaoleary Aug 8, 2025
4885a45
Make fields required for multi-host 'use_tpu'
ryanaoleary Aug 8, 2025
62730cd
Reserve tpu slice take required args
ryanaoleary Aug 8, 2025
7ccc7c5
Update example now that tested on v6e
ryanaoleary Aug 8, 2025
2fb6e88
add jax pip install
ryanaoleary Aug 9, 2025
536f83c
Merge branch 'master' into implement-jax-trainer
ryanaoleary Aug 9, 2025
dbbd4cf
Merge branch 'master' into implement-jax-trainer
ryanaoleary Aug 11, 2025
4e8bbc2
ensure TPUReservationCallback is executed and fix minor review comments
andrewsykim Aug 12, 2025
444b9de
refactor _JaxBackend.onStart to fetch unique worker address and port …
andrewsykim Aug 12, 2025
27ab61b
Merge branch 'master' into implement-jax-trainer
andrewsykim Aug 12, 2025
e530846
temporarily remove TPU callback import
andrewsykim Aug 12, 2025
67f077c
Revert "temporarily remove TPU callback import"
andrewsykim Aug 12, 2025
05e61cf
formatting fixes
andrewsykim Aug 12, 2025
5588927
fix lint errors
andrewsykim Aug 12, 2025
116c37f
move some TPU util functions to Ray Core to resolve import errors
andrewsykim Aug 12, 2025
4a8ea0d
remove test_tpu_utils reference in ray train v2 BUILD
andrewsykim Aug 12, 2025
e069ab7
fix code format and lint failures
andrewsykim Aug 12, 2025
a8b1829
fix code formatting
andrewsykim Aug 13, 2025
4a21677
move field assertion to after use_tpu condition
andrewsykim Aug 13, 2025
64c5a95
Merge branch 'master' into implement-jax-trainer
andrewsykim Aug 13, 2025
6598487
address nits from matthewdeng
andrewsykim Aug 13, 2025
b8638c4
remove unused import caught by lint check
andrewsykim Aug 13, 2025
1028bfe
fix docstring for JaxTrainer
andrewsykim Aug 13, 2025
f10e268
fix docstring order for JaxTrainer
andrewsykim Aug 13, 2025
17 changes: 17 additions & 0 deletions python/ray/_private/accelerators/tpu.py
@@ -110,6 +110,23 @@ def get_tpu_cores_per_chip(accelerator_type: str) -> int:
    return DEFAULT_TPU_NUM_CORES_PER_CHIP


def infer_tpu_pod_type_from_topology(
    topology: str, accelerator_type: str
) -> Optional[str]:
    """Infer the TPU pod type (e.g. v4-32) from topology and accelerator type."""
    try:
        num_chips = 1
        for value in topology.strip().lower().split("x"):
            num_chips *= int(value)
        generation = accelerator_type.lower().replace("tpu-", "")
        return f"{generation}-{num_chips}"
    except Exception as e:
        logger.warning(
            f"Failed to infer pod type from topology {topology} and type {accelerator_type}: {e}"
        )
        return None


class TPUAcceleratorManager(AcceleratorManager):
"""Google TPU accelerators."""

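A quick worked example of the new helper (a sketch; the results follow directly from the function above):

    from ray._private.accelerators import tpu

    # "2x2x4" describes 2 * 2 * 4 = 16 chips; "TPU-V4" normalizes to generation "v4".
    assert tpu.infer_tpu_pod_type_from_topology("2x2x4", "TPU-V4") == "v4-16"
    # Malformed input logs a warning and returns None instead of raising.
    assert tpu.infer_tpu_pod_type_from_topology("4x", "TPU-V3") is None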
8 changes: 4 additions & 4 deletions python/ray/_private/resource_and_label_spec.py
@@ -10,7 +10,6 @@
from ray._common.utils import RESOURCE_CONSTRAINT_PREFIX
from ray._private import accelerators
from ray._private.accelerators import AcceleratorManager
from ray._private.accelerators.tpu import TPUAcceleratorManager

logger = logging.getLogger(__name__)

@@ -292,10 +291,11 @@ def _get_default_labels(
ray._raylet.RAY_NODE_ACCELERATOR_TYPE_KEY
] = accelerator_type

# Set TPU specific default labels to enable SPMD scheduling.
if isinstance(accelerator_manager, TPUAcceleratorManager):
# Set TPU specific default labels to enable multi-host scheduling.
if accelerator_manager.get_resource_name() == "TPU":
tpu_labels = accelerator_manager.get_current_node_accelerator_labels()
default_labels.update(tpu_labels)
if tpu_labels:
default_labels.update(tpu_labels)

return default_labels

21 changes: 21 additions & 0 deletions python/ray/tests/accelerators/test_tpu.py
@@ -353,5 +353,26 @@ def test_get_current_node_tpu_topology_from_metadata():
    assert topology == "2x2x4"


@pytest.mark.parametrize(
    "topology, accelerator_type, expected_pod_type",
    [
        ("2x4", "TPU-V6E", "v6e-8"),
        ("2x2x2", "TPU-V4", "v4-8"),
        ("2x4x4", "TPU-V3", "v3-32"),
        ("4x4", "TPU-V5P", "v5p-16"),
        ("8x16", "TPU-V6E", "v6e-128"),
        ("", "TPU-V3", None),
        ("4x", "TPU-V3", None),
    ],
)
def test_infer_tpu_pod_type_from_topology(
    topology, accelerator_type, expected_pod_type
):
    assert (
        tpu.infer_tpu_pod_type_from_topology(topology, accelerator_type)
        == expected_pod_type
    )


if __name__ == "__main__":
    sys.exit(pytest.main(["-sv", __file__]))
32 changes: 32 additions & 0 deletions python/ray/train/v2/BUILD
@@ -484,3 +484,35 @@ py_test(
"//:ray_lib",
],
)

py_test(
name = "test_jax_trainer",
size = "small",
srcs = ["tests/test_jax_trainer.py"],
env = {"RAY_TRAIN_V2_ENABLED": "1"},
tags = [
"exclusive",
"team:ml",
"train_v2",
],
deps = [
":conftest",
"//:ray_lib",
],
)

py_test(
name = "test_tpu_utils",
size = "small",
srcs = ["tests/test_tpu_utils.py"],
env = {"RAY_TRAIN_V2_ENABLED": "1"},
tags = [
"exclusive",
"team:ml",
"train_v2",
],
deps = [
":conftest",
"//:ray_lib",
],
)
@@ -0,0 +1,42 @@
from typing import Dict, Optional

import ray
from ray.train.v2._internal.execution.callback import ControllerCallback
from ray.train.v2.api.config import ScalingConfig
from ray.train.v2.jax.tpu_utils import reserve_tpu_slice


class TPUReservationCallback(ControllerCallback):
    """A callback to handle TPU slice reservation for multi-host training."""

    def on_controller_start_worker_group(
        self, *, scaling_config: ScalingConfig, num_workers: int
    ) -> Optional[Dict[str, str]]:
        """Reserves a multi-host TPU slice before the worker group starts.

        This hook is called by the TrainController. It checks if multi-host
        TPUs are being used and, if so, reserves a slice.

        Args:
            scaling_config: The scaling configuration for the run.
            num_workers: The number of workers to be started.

Review comment (Contributor): num_workers can come from scaling_config as well.
Reply (Member): added a check for both

        Returns:
            A dictionary defining a `bundle_label_selector` to gang schedule
            the worker group on the reserved TPU slice.
        """
        bundle_label_selector = None

        if getattr(scaling_config, "use_tpu", False) and num_workers > 1:
            slice_name = reserve_tpu_slice(
                topology=getattr(scaling_config, "topology", None),
                accelerator_type=getattr(scaling_config, "accelerator_type", None),

Review comment (Contributor): Assert that topology and accelerator_type are both set here.
Note to self: not sure if it's best to have the validation here, when this Callback is initialized, or when validating the ScalingConfig.
Reply (Member): done

            )

Review comment (Contributor): Access attributes directly from ScalingConfig rather than getattr.

            if not slice_name:
                raise RuntimeError("Failed to reserve TPU slice.")

            bundle_label_selector = {
                ray._raylet.RAY_NODE_TPU_SLICE_NAME_KEY: slice_name
            }

        return bundle_label_selector
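To make the hook's contract concrete, a minimal sketch of calling it directly; the ScalingConfig values are hypothetical, and in practice the TrainController invokes this hook (running it for real requires a cluster with a reservable multi-host TPU slice):

    from ray.train.v2.api.config import ScalingConfig

    # Hypothetical multi-host v4 slice: 4 workers, 4 TPU chips each.
    callback = TPUReservationCallback()
    selector = callback.on_controller_start_worker_group(
        scaling_config=ScalingConfig(
            num_workers=4,
            use_tpu=True,
            topology="2x2x4",
            accelerator_type="TPU-V4",
            resources_per_worker={"TPU": 4},
        ),
        num_workers=4,
    )
    # On success, `selector` maps the TPU slice-name label key to the reserved
    # slice, e.g. {ray._raylet.RAY_NODE_TPU_SLICE_NAME_KEY: "<slice name>"}.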
19 changes: 19 additions & 0 deletions python/ray/train/v2/_internal/execution/callback.py
@@ -2,6 +2,7 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from ray.train.v2.api.callback import RayTrainCallback
from ray.train.v2.api.config import ScalingConfig
from ray.train.v2.api.result import Result
from ray.util.annotations import DeveloperAPI

@@ -78,6 +79,24 @@ def after_controller_start(self, train_run_context: "TrainRunContext"):
before the control loop starts executing."""
pass

    def on_controller_start_worker_group(

Review comment (Contributor): Add a TODO for me to come back to this interface. It works well for this use case but I think it'll be hard to extend/support different use-cases in the future.
  1. The bundle_label_selector logic is a bit specific for this callback.
  2. Behavior for creating multiple callbacks with bundle_label_selector logic is undefined.
Reply (Member): Added this:

    # TODO(matthewdeng): Revisit this callback interface for better extensibility.
    # This hook was added for the specific use case of setting a `bundle_label_selector`
    # for new worker groups (e.g., for TPU reservations). The current interface is
    # tightly coupled to this purpose and limits its reuse for other use-cases.

        self, *, scaling_config: ScalingConfig, num_workers: int
    ) -> Optional[Dict[str, str]]:
        """Called by the TrainController before the worker group is started.

        This hook can be used to perform setup that modifies the worker group's
        placement, such as reserving an accelerator slice.

        Args:
            scaling_config: The scaling configuration for the run.
            num_workers: The number of workers to be started.

        Returns:
            An optional dictionary defining a `bundle_label_selector`
            to gang schedule the worker group on the reserved TPU slice.
        """
        return None
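As an illustration of the new hook's contract, a hypothetical ControllerCallback that pins every worker bundle to pre-labeled nodes; the class name and the label key/value are made up for this sketch:

    from typing import Dict, Optional

    from ray.train.v2._internal.execution.callback import ControllerCallback
    from ray.train.v2.api.config import ScalingConfig


    class PinToLabeledNodesCallback(ControllerCallback):
        """Hypothetical example: gang-schedule workers onto pre-labeled nodes."""

        def on_controller_start_worker_group(
            self, *, scaling_config: ScalingConfig, num_workers: int
        ) -> Optional[Dict[str, str]]:
            # Returning a dict makes the controller apply it as the
            # bundle_label_selector for every bundle in the placement group.
            return {"example.com/reserved-pool": "train-job-1"}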

    def before_controller_shutdown(self):
        """Called before `TrainController.run` exits,
        after the control loop has exited."""
@@ -280,12 +280,29 @@ def _start_worker_group(
            ControllerError if the worker group failed to start.
        """
        placement_strategy = self._scaling_policy.scaling_config.placement_strategy
        scaling_config = self._train_run_context.scaling_config

        # Check for `bundle_label_selector` to influence WorkerGroup scheduling.
        bundle_label_selector = None
        try:
            for callback in self._callbacks:
                if hasattr(callback, "on_controller_start_worker_group"):
                    selector = callback.on_controller_start_worker_group(
                        scaling_config=scaling_config, num_workers=num_workers

Review comment (Contributor): Get the ControllerCallback instead, which will always have this method.
Reply (Member): done

                    )
                    if selector:
                        bundle_label_selector = selector
                        break
        except Exception as e:
            return ControllerError(e)

        worker_group_context = WorkerGroupContext(
            run_attempt_id=self._get_run_attempt_id(),
            train_fn_ref=self._train_fn_ref,
            num_workers=num_workers,
            resources_per_worker=resources_per_worker,
            placement_strategy=placement_strategy,
            bundle_label_selector=bundle_label_selector,
        )
        try:
            self._worker_group = self.worker_group_cls.create(
@@ -89,13 +89,15 @@ class WorkerGroupContext:
        num_workers: The number of workers in the worker group.
        resources_per_worker: The resources per worker.
        placement_strategy: Strategy for placing workers.
        bundle_label_selector: Optional label selectors to apply per-bundle for workers.
    """

    run_attempt_id: str
    train_fn_ref: ObjectRefWrapper[Callable[[], None]]
    num_workers: int
    resources_per_worker: Dict[str, float]
    placement_strategy: str = "PACK"
    bundle_label_selector: Optional[Dict[str, str]] = None


class WorkerGroup:
@@ -268,10 +270,18 @@ def _start_impl(
        for callback in self._callbacks:
            callback.before_worker_group_start(worker_group_context)

        bundle_label_selector = (
            [worker_group_context.bundle_label_selector.copy()]
            * worker_group_context.num_workers
            if worker_group_context.bundle_label_selector
            else None
        )

        pg = placement_group(
            bundles=[worker_group_context.resources_per_worker]
            * worker_group_context.num_workers,
            strategy=worker_group_context.placement_strategy,
            bundle_label_selector=bundle_label_selector,
        )
        logger.info(
            f"Attempting to start training worker group of size {worker_group_context.num_workers} with "
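The single selector from the worker group context is replicated once per bundle, so every worker bundle must land on a node carrying the reserved slice's label. A minimal sketch of the resulting placement group request, with a hypothetical slice name and bundle shape:

    import ray
    from ray.util.placement_group import placement_group

    slice_name = "example-reserved-slice"  # hypothetical value returned by reserve_tpu_slice
    selector = {ray._raylet.RAY_NODE_TPU_SLICE_NAME_KEY: slice_name}

    # Four workers, each requesting 4 TPU chips, gang-scheduled onto the slice.
    pg = placement_group(
        bundles=[{"TPU": 4}] * 4,
        strategy="SPREAD",
        bundle_label_selector=[selector.copy()] * 4,
    )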
48 changes: 47 additions & 1 deletion python/ray/train/v2/api/config.py
@@ -1,7 +1,7 @@
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, List, Optional, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Union

import pyarrow.fs

@@ -21,7 +21,9 @@

if TYPE_CHECKING:
    from ray.train import UserCallback
    from ray.tune.search.sample import Domain

SampleRange = Union["Domain", Dict[str, List]]

Review comment (Contributor): We can remove this.
Suggested change:
    from ray.tune.search.sample import Domain
    SampleRange = Union["Domain", Dict[str, List]]
Reply (Member): done


logger = logging.getLogger(__name__)

@@ -52,6 +54,14 @@ class ScalingConfig(ScalingConfigV1):
            See :ref:`the available accelerator types <accelerator_types>`.
            Ensure that your cluster has instances with the specified accelerator type
            or is able to autoscale to fulfill the request.
        use_tpu: [Experimental] If True, training will be done on TPUs (1 TPU VM
            per worker). Defaults to False. The number of TPUs reserved by each
            worker can be overridden with the ``resources_per_worker``
            argument. This arg enables SPMD execution of the training workload.
        topology: [Experimental] If specified, Ray Train will launch the training

Review comment (Contributor): Is this specific to TPUs? Should it be tpu_topology instead?
Reply (Contributor Author): I'm not super familiar with GPUs, but I think the field can probably be extended to set fields automatically in the Config (when left out) for GPUs too - so leaving it as topology might be fine. I don't have much of a preference either way though.

            coordinator and workers on nodes with the specified topology. Topology is
            auto-detected for TPUs and added as Ray node labels. This arg enables
            SPMD execution of the training workload.

Review comment (Contributor): Is topology a first-class Ray Core concept? We'd want to make sure it's easy to understand from the API what inputs this takes in and how it'll be used.
Review comment (Contributor): Also, for TPU users, how familiar are topology/accelerator? Would it be easier for the user to just specify the pod type directly?
Reply (Contributor Author, ryanaoleary, Aug 7, 2025): I don't see topology used in Ray core at all, except to configure TPU env vars and node labels - but any users of multi-host TPUs should be familiar with the concept. The concept is also already introduced in KubeRay through the numOfHosts field: https://docs.ray.io/en/latest/cluster/kubernetes/user-guides/tpu.html.
I think topology and accelerator type are the best top-level variables for users to specify, since currently in GKE these are the two values users configure when creating their GKE nodepool and when scheduling pods to it using the cloud.google.com/gke-tpu-accelerator and cloud.google.com/gke-tpu-topology nodeSelectors: https://cloud.google.com/kubernetes-engine/docs/how-to/tpus.
Reply: Topology is quite a standard TPU concept. TPU type / pod type is in some cases not uniquely mapped to a topology.
Reply (Contributor): Awesome, is it safe to say that this API would then be super intuitive for a TPU user? Is there any other grouping/organization that might be more natural to how a user thinks about setting up their workload?

        scaling_config=ScalingConfig(
            use_tpu=True,
            num_workers=4,
            topology="2x2x4",
            accelerator_type="TPU-V4",
            resources_per_worker={"TPU": 4},
            placement_strategy="SPREAD",
        ),

Reply (Contributor Author, ryanaoleary, Aug 7, 2025): Yeah, I think this top-level API should be clear for TPU users - the only thing I can think of is that we could have num_workers, resources_per_worker, and placement_strategy be auto-set based on the topology if not provided. For example, if we have a multi-host topology of 4x4 v6e, we could automatically detect that num_workers should be 4, resources_per_worker should be TPU: 4 since that's the number of chips on each host, and placement_strategy should be SPREAD.

Example:

@@ -73,17 +83,53 @@ class ScalingConfig(ScalingConfigV1):
"""

    trainer_resources: Optional[dict] = None
    use_tpu: Union[bool, SampleRange] = False

Review comment (Contributor): Remove SampleRange part, no longer needed for Train V2.

    topology: Optional[str] = None

    def __post_init__(self):
        if self.trainer_resources is not None:
            raise DeprecationWarning(TRAINER_RESOURCES_DEPRECATION_MESSAGE)

        if self.resources_per_worker:
            if self.use_gpu and self.use_tpu:
                raise ValueError(
                    "Cannot specify both `use_gpu=True` and `use_tpu=True`."
                )

Review comment (Contributor): Don't think the outer resources_per_worker check is needed.


            if not self.use_tpu and self.num_tpus_per_worker > 0:
                raise ValueError(
                    "`use_tpu` is False but `TPU` was found in "
                    "`resources_per_worker`. Either set `use_tpu` to True or "
                    "remove `TPU` from `resources_per_worker`."
                )

            if self.use_tpu and self.num_tpus_per_worker == 0:
                raise ValueError(
                    "`use_tpu` is True but `TPU` is set to 0 in "
                    "`resources_per_worker`. Either set `use_tpu` to False or "
                    "request a positive number of `TPU` in "
                    "`resources_per_worker`."
                )

        super().__post_init__()

    @property
    def _resources_per_worker_not_none(self):
        if self.resources_per_worker is None:
            if self.use_tpu:
                return {"TPU": 1}

        return super()._resources_per_worker_not_none

    @property
    def _trainer_resources_not_none(self):
        return {}

    @property
    def num_tpus_per_worker(self):
        """The number of TPUs to set per worker."""
        return self._resources_per_worker_not_none.get("TPU", 0)


@dataclass
class FailureConfig(FailureConfigV1):
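For orientation, a rough end-to-end sketch of how these ScalingConfig fields are meant to be used with the new trainer; the JaxTrainer constructor arguments are not shown in this diff, so the kwargs below are assumptions modeled on other Ray Train trainers:

    from ray.train.v2.api.config import ScalingConfig
    from ray.train.v2.jax import JaxTrainer


    def train_loop_per_worker():
        import jax

        # Each worker should see its local TPU devices once the Jax backend
        # has initialized the distributed runtime.
        print(jax.devices())


    trainer = JaxTrainer(
        train_loop_per_worker=train_loop_per_worker,
        scaling_config=ScalingConfig(
            use_tpu=True,
            num_workers=4,
            topology="2x2x4",
            accelerator_type="TPU-V4",
            resources_per_worker={"TPU": 4},
            placement_strategy="SPREAD",
        ),
    )
    result = trainer.fit()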
15 changes: 15 additions & 0 deletions python/ray/train/v2/jax/__init__.py
@@ -0,0 +1,15 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    try:
        import jax  # noqa: F401
    except ModuleNotFoundError as exception:
        raise ModuleNotFoundError(
            "Jax isn't installed. To install Jax, please check"
            " `https://github.com/google/jax#installation` for the instructions."
        ) from exception

from ray.train.v2.jax.config import JaxConfig
from ray.train.v2.jax.jax_trainer import JaxTrainer

__all__ = ["JaxConfig", "JaxTrainer"]