-
Notifications
You must be signed in to change notification settings - Fork 7k
[Train] Implement a JaxTrainer to support SPMD with TPUs #55207
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
76983cf
703e185
69d349c
b9f0764
c48734a
360b952
a341722
d2ee6a1
7f165a6
566a788
4885a45
62730cd
7ccc7c5
2fb6e88
536f83c
dbbd4cf
4e8bbc2
444b9de
27ab61b
e530846
67f077c
05e61cf
5588927
116c37f
4a8ea0d
e069ab7
a8b1829
4a21677
64c5a95
6598487
b8638c4
1028bfe
f10e268
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,45 @@ | ||
| from typing import Dict, Optional | ||
|
|
||
| import ray | ||
| from ray._private.accelerators.tpu import reserve_tpu_slice | ||
| from ray.train.v2._internal.execution.callback import ControllerCallback | ||
| from ray.train.v2.api.config import ScalingConfig | ||
|
|
||
|
|
||
class TPUReservationCallback(ControllerCallback):
    """A callback to handle TPU slice reservation for multi-host training."""

    def on_controller_start_worker_group(
        self, *, scaling_config: ScalingConfig, num_workers: int
    ) -> Optional[Dict[str, str]]:
        """Reserves a multi-host TPU slice before the worker group starts.

        This hook is called by the TrainController. It checks if multi-host
        TPUs are being used (``use_tpu`` is set and more than one worker is
        requested) and, if so, reserves a slice.

        Args:
            scaling_config: The scaling configuration for the run.
            num_workers: The number of workers to be started.

        Returns:
            A dictionary defining a `bundle_label_selector` to gang schedule
            the worker group on the reserved TPU slice, or None when no
            slice is reserved (TPUs are not in use, or there is only a
            single worker).

        Raises:
            ValueError: If a slice reservation is needed but
                `accelerator_type` or `topology` is not set on
                `scaling_config`.
            RuntimeError: If the reservation does not yield a slice name.
        """
        # Nothing to reserve unless TPUs are used with more than one worker.
        if not (scaling_config.use_tpu and num_workers > 1):
            return None

        # Both fields are required to reserve a slice. Validate explicitly
        # rather than with `assert`, which is stripped under `python -O`.
        missing = [
            field
            for field in ("accelerator_type", "topology")
            if getattr(scaling_config, field) is None
        ]
        if missing:
            raise ValueError(
                "Multi-host TPU training requires ScalingConfig."
                + " and ScalingConfig.".join(missing)
                + " to be set."
            )

        slice_name = reserve_tpu_slice(
            topology=scaling_config.topology,
            accelerator_type=scaling_config.accelerator_type,
        )
        if not slice_name:
            raise RuntimeError("Failed to reserve TPU slice.")

        # Pin every bundle of the worker group to the reserved slice.
        return {ray._raylet.RAY_NODE_TPU_SLICE_NAME_KEY: slice_name}
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,6 +2,7 @@ | |
| from typing import TYPE_CHECKING, Any, Dict, List, Optional | ||
|
|
||
| from ray.train.v2.api.callback import RayTrainCallback | ||
| from ray.train.v2.api.config import ScalingConfig | ||
| from ray.train.v2.api.result import Result | ||
| from ray.util.annotations import DeveloperAPI | ||
|
|
||
|
|
@@ -78,6 +79,28 @@ def after_controller_start(self, train_run_context: "TrainRunContext"): | |
| before the control loop starts executing.""" | ||
| pass | ||
|
|
||
| # TODO(matthewdeng): Revisit this callback interface for better extensibility. | ||
| # This hook was added for the specific use case of setting a `bundle_label_selector` | ||
| # for new worker groups (e.g., for TPU reservations). The current interface is | ||
| # tightly coupled to this purpose and limits its reuse for other use-cases. | ||
def on_controller_start_worker_group(
    self, *, scaling_config: ScalingConfig, num_workers: int
) -> Optional[Dict[str, str]]:
    """Hook invoked by the TrainController before a worker group is started.

    Override this to run setup that influences where the worker group is
    placed, for example reserving an accelerator slice.

    The default implementation does nothing and returns None.

    Args:
        scaling_config: The scaling configuration for the run.
        num_workers: The number of workers to be started.

    Returns:
        An optional dictionary defining a `bundle_label_selector` used to
        gang schedule the worker group (e.g. on a reserved TPU slice).
    """
    return None
|
|
||
| def before_controller_shutdown(self): | ||
| """Called before `TrainController.run` exits, | ||
| after the control loop has exited.""" | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.