[Train] Implement a JaxTrainer to support SPMD with TPUs #55207
New file (42 added lines): a `TPUReservationCallback` that reserves a multi-host TPU slice before the worker group starts.

```python
from typing import Dict, Optional

import ray
from ray.train.v2._internal.execution.callback import ControllerCallback
from ray.train.v2.api.config import ScalingConfig
from ray.train.v2.jax.tpu_utils import reserve_tpu_slice


class TPUReservationCallback(ControllerCallback):
    """A callback to handle TPU slice reservation for multi-host training."""

    def on_controller_start_worker_group(
        self, *, scaling_config: ScalingConfig, num_workers: int
    ) -> Optional[Dict[str, str]]:
        """Reserves a multi-host TPU slice before the worker group starts.

        This hook is called by the TrainController. It checks whether
        multi-host TPUs are being used and, if so, reserves a slice.

        Args:
            scaling_config: The scaling configuration for the run.
            num_workers: The number of workers to be started.

        Returns:
            A dictionary defining a `bundle_label_selector` to gang schedule
            the worker group on the reserved TPU slice.
        """
        bundle_label_selector = None

        if getattr(scaling_config, "use_tpu", False) and num_workers > 1:
            slice_name = reserve_tpu_slice(
                topology=getattr(scaling_config, "topology", None),
                accelerator_type=getattr(scaling_config, "accelerator_type", None),
            )

            if not slice_name:
                raise RuntimeError("Failed to reserve TPU slice.")

            bundle_label_selector = {
                ray._raylet.RAY_NODE_TPU_SLICE_NAME_KEY: slice_name
            }

        return bundle_label_selector
```

Review note (Member): added a check for both `use_tpu` and `num_workers > 1`.
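For intuition, the selector this hook hands back on a multi-host run looks like the following; the slice name is made up for illustration, since a real one comes from `reserve_tpu_slice`:

```python
import ray

# Hypothetical return value; "tpu-slice-abc123" is an invented slice name.
selector = {ray._raylet.RAY_NODE_TPU_SLICE_NAME_KEY: "tpu-slice-abc123"}

# The controller applies this selector to every bundle in the worker group's
# placement group, so all workers land on nodes of the same reserved slice.
```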
Changes to the controller callback interface (hunks `@@ -2,6 +2,7 @@` and `@@ -78,6 +79,24 @@`), adding the `on_controller_start_worker_group` hook:

```python
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from ray.train.v2.api.callback import RayTrainCallback
from ray.train.v2.api.config import ScalingConfig
from ray.train.v2.api.result import Result
from ray.util.annotations import DeveloperAPI
```

```python
    def after_controller_start(self, train_run_context: "TrainRunContext"):
        """Called after `TrainController` is initialized,
        before the control loop starts executing."""
        pass

    def on_controller_start_worker_group(
        self, *, scaling_config: ScalingConfig, num_workers: int
    ) -> Optional[Dict[str, str]]:
        """Called by the TrainController before the worker group is started.

        This hook can be used to perform setup that modifies the worker group's
        placement, such as reserving an accelerator slice.

        Args:
            scaling_config: The scaling configuration for the run.
            num_workers: The number of workers to be started.

        Returns:
            An optional dictionary defining a `bundle_label_selector`
            to gang schedule the worker group on the reserved TPU slice.
        """
        return None

    def before_controller_shutdown(self):
        """Called before `TrainController.run` exits,
        after the control loop has exited."""
```

Review exchange on the new hook:

Contributor: Add a TODO for me to come back to this interface. It works well for this use case, but I think it'll be hard to extend/support different use cases in the future.

Member: Added this:
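To illustrate the extension point, here is a minimal hypothetical callback built on this hook. The class name, label key, and label value are invented for the example; only the base class and hook signature come from the interface above:

```python
from typing import Dict, Optional

from ray.train.v2._internal.execution.callback import ControllerCallback
from ray.train.v2.api.config import ScalingConfig


class PinToNodePoolCallback(ControllerCallback):
    """Hypothetical callback that gang-schedules multi-worker runs onto
    nodes carrying a custom label."""

    def on_controller_start_worker_group(
        self, *, scaling_config: ScalingConfig, num_workers: int
    ) -> Optional[Dict[str, str]]:
        # A returned selector is applied to every bundle in the worker
        # group; returning None leaves placement unchanged.
        if num_workers > 1:
            return {"example.com/node-pool": "trainer-pool"}
        return None
```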
Changes to the controller's `_start_worker_group` (hunk `@@ -280,12 +280,29 @@`), which queries callbacks for a `bundle_label_selector` and threads it into the `WorkerGroupContext`:

```python
            ControllerError if the worker group failed to start.
        """
        placement_strategy = self._scaling_policy.scaling_config.placement_strategy
        scaling_config = self._train_run_context.scaling_config

        # Check for `bundle_label_selector` to influence WorkerGroup scheduling.
        bundle_label_selector = None
        try:
            for callback in self._callbacks:
                if hasattr(callback, "on_controller_start_worker_group"):
                    selector = callback.on_controller_start_worker_group(
                        scaling_config=scaling_config, num_workers=num_workers
                    )
                    if selector:
                        bundle_label_selector = selector
                        break
        except Exception as e:
            return ControllerError(e)

        worker_group_context = WorkerGroupContext(
            run_attempt_id=self._get_run_attempt_id(),
            train_fn_ref=self._train_fn_ref,
            num_workers=num_workers,
            resources_per_worker=resources_per_worker,
            placement_strategy=placement_strategy,
            bundle_label_selector=bundle_label_selector,
        )
        try:
            self._worker_group = self.worker_group_cls.create(
```
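Design note: the callback loop takes the first non-None selector and stops, so at most one callback influences placement per attempt, and any exception raised during reservation is returned as a `ControllerError` rather than propagated, consistent with the failure contract described in the method's docstring.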
Changes to the config module (hunks `@@ -1,7 +1,7 @@` and `@@ -21,7 +21,9 @@`), importing `Dict` and introducing the `SampleRange` alias:

```python
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional, Union

import pyarrow.fs
```

```python
if TYPE_CHECKING:
    from ray.train import UserCallback
    from ray.tune.search.sample import Domain

    SampleRange = Union["Domain", Dict[str, List]]
```
Review discussion on the new `topology` and `accelerator_type` fields:

- Is this specific to TPUs? Should it be `tpu_topology` instead?
- I'm not super familiar with GPUs, but I think the field can probably be extended to auto-set fields in the Config (when left out) for GPUs too, so leaving it as `topology` might be fine. I don't have much of a preference either way, though.
- Is topology a first-class Ray Core concept? We'd want to make sure it's easy to understand from the API what inputs this takes in and how they'll be used.
- Also, for TPU users, how familiar are topology and accelerator type? Would it be easier for the user to just specify the pod type directly?
- I don't see topology used in Ray Core at all, except to configure TPU env vars and node labels, but any user of multi-host TPUs should be familiar with the concept. The concept is also already introduced in KubeRay through the `numOfHosts` field: https://docs.ray.io/en/latest/cluster/kubernetes/user-guides/tpu.html. I think topology and accelerator type are the best top-level variables for users to specify, since these are the two values users configure in GKE when creating their node pool and when scheduling pods to it using the `cloud.google.com/gke-tpu-accelerator` and `cloud.google.com/gke-tpu-topology` nodeSelectors: https://cloud.google.com/kubernetes-engine/docs/how-to/tpus.
- Topology is a quite standard TPU concept. TPU type / pod type is in some cases not uniquely mapped to a topology.
- Awesome, is it safe to say that this API would then be super intuitive for a TPU user? Is there any other grouping/organization that might be more natural to how a user thinks about setting up their workload?

  ```python
  scaling_config=ScalingConfig(
      use_tpu=True,
      num_workers=4,
      topology="2x2x4",
      accelerator_type="TPU-V4",
      resources_per_worker={"TPU": 4},
      placement_strategy="SPREAD",
  ),
  ```

- Yeah, I think this top-level API should be clear for TPU users. The only thing I can think of is that `num_workers`, `resources_per_worker`, and `placement_strategy` could be auto-set based on the topology if not provided. For example, with a multi-host topology of 4x4 v6e, we could automatically detect that `num_workers` should be 4, `resources_per_worker` should be `{"TPU": 4}` since that's the number of chips on each host, and `placement_strategy` should be `SPREAD` (see the sketch after this list).
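A minimal sketch of that auto-derivation idea (not part of this PR): parse the topology string, take the product of its dimensions as the total chip count, and divide by the chips available per host. The `CHIPS_PER_HOST` mapping and `derive_defaults` helper are hypothetical, and real chips-per-host values vary by TPU generation and machine type:

```python
import math

# Hypothetical mapping; actual chips-per-host depends on the TPU generation
# and machine type, so treat these values as placeholders.
CHIPS_PER_HOST = {"TPU-V4": 4, "TPU-V6E": 4}


def derive_defaults(topology: str, accelerator_type: str):
    """Derive (num_workers, resources_per_worker) from a topology string."""
    # "4x4" -> 16 total chips; "2x2x4" -> 16 total chips.
    total_chips = math.prod(int(dim) for dim in topology.split("x"))
    chips_per_host = CHIPS_PER_HOST[accelerator_type]
    num_workers = total_chips // chips_per_host
    return num_workers, {"TPU": chips_per_host}


# A 4x4 v6e slice: 16 chips / 4 chips per host -> 4 workers, {"TPU": 4} each.
assert derive_defaults("4x4", "TPU-V6E") == (4, {"TPU": 4})
```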
Two earlier review notes (now outdated):

- Remove the `SampleRange` part; it's no longer needed for Train V2.
- I don't think the outer `resources_per_worker` check is needed.
New file (15 added lines): the `ray.train.v2.jax` package entry point exporting the public API:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    try:
        import jax  # noqa: F401
    except ModuleNotFoundError as exception:
        raise ModuleNotFoundError(
            "Jax isn't installed. To install Jax, please check"
            " `https://github.com/google/jax#installation` for the instructions."
        ) from exception

from ray.train.v2.jax.config import JaxConfig
from ray.train.v2.jax.jax_trainer import JaxTrainer

__all__ = ["JaxConfig", "JaxTrainer"]
```
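For context, a usage sketch tying the pieces together. The `JaxTrainer` constructor arguments are assumed to follow the Train V2 trainer pattern (`train_loop_per_worker` plus a `scaling_config`); this diff doesn't show its signature, and the scaling config mirrors the example from the review thread:

```python
import jax

from ray.train.v2.api.config import ScalingConfig
from ray.train.v2.jax import JaxTrainer


def train_loop_per_worker():
    # Assumption: JaxConfig sets up the distributed Jax runtime before this
    # function runs, so each worker sees its local TPU devices.
    print(jax.device_count(), jax.local_device_count())


trainer = JaxTrainer(
    train_loop_per_worker,
    scaling_config=ScalingConfig(
        use_tpu=True,
        num_workers=4,
        topology="2x2x4",
        accelerator_type="TPU-V4",
        resources_per_worker={"TPU": 4},
        placement_strategy="SPREAD",
    ),
)
result = trainer.fit()
```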