diff --git a/doc/source/ray-core/api/utility.rst b/doc/source/ray-core/api/utility.rst index 38d3026f629b..79444f7e5bfd 100644 --- a/doc/source/ray-core/api/utility.rst +++ b/doc/source/ray-core/api/utility.rst @@ -12,9 +12,11 @@ Utility ray.util.serialization.register_serializer ray.util.serialization.deregister_serializer - ray.util.accelerators.tpu.get_current_pod_worker_count - ray.util.accelerators.tpu.get_current_pod_name - ray.util.accelerators.tpu.get_num_tpu_chips_on_node + ray.util.tpu.get_current_pod_worker_count + ray.util.tpu.get_current_pod_name + ray.util.tpu.get_num_tpu_chips_on_node + ray.util.tpu.SlicePlacementGroup + ray.util.tpu.slice_placement_group ray.nodes ray.cluster_resources diff --git a/python/ray/_private/accelerators/tpu.py b/python/ray/_private/accelerators/tpu.py index 83da22475879..2115bf1d7f9e 100644 --- a/python/ray/_private/accelerators/tpu.py +++ b/python/ray/_private/accelerators/tpu.py @@ -3,7 +3,7 @@ import os import re from functools import lru_cache -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Set, Tuple import requests @@ -64,6 +64,43 @@ # The valid TPU types. VALID_TPU_TYPES = ("v2", "v3", "v4", "v5p", "v5litepod", "v6e") +# This is only used to construct TPU 3D topologies +def _get_larger_3d_topologies(max_x: int, max_y: int, max_z: int) -> Set[str]: + """Returns a set of larger 3D TPU topologies given the max x,y,z value. Using DEFAULT_TPU_NUM_CHIPS_PER_HOST as increment""" + topologies = set() + for x in range( + DEFAULT_TPU_NUM_CHIPS_PER_HOST, max_x + 1, DEFAULT_TPU_NUM_CHIPS_PER_HOST + ): + for y in range( + DEFAULT_TPU_NUM_CHIPS_PER_HOST, max_y + 1, DEFAULT_TPU_NUM_CHIPS_PER_HOST + ): + for z in range( + DEFAULT_TPU_NUM_CHIPS_PER_HOST, + max_z + 1, + DEFAULT_TPU_NUM_CHIPS_PER_HOST, + ): + topologies.add(f"{x}x{y}x{z}") + + return topologies + + +# The valid TPU topologies for each of the TPU types +VALID_TPU_TOPOLOGY = { + "v2": {"4x4", "4x8", "8x8", "8x16", "16x16"}, + "v3": {"4x4", "4x8", "8x8", "8x16", "16x16", "16x32", "32x32"}, + "v4": {"2x2x1", "2x2x2", "2x2x4", "2x4x4"}.union( + _get_larger_3d_topologies(12, 12, 16) + ), + "v5p": { + "2x2x1", + "2x2x2", + "2x2x4", + "2x4x4", + }.union(_get_larger_3d_topologies(16, 16, 24)), + "v5litepod": {"2x8", "4x4", "4x8", "8x8", "8x16", "16x16"}, + "v6e": {"2x8", "4x4", "4x8", "8x8", "8x16", "16x16"}, +} + def _get_tpu_metadata(key: str) -> Optional[str]: """Poll and get TPU metadata.""" @@ -115,6 +152,8 @@ def infer_tpu_pod_type_from_topology( topology: str, accelerator_type: str ) -> Optional[str]: """Infer the TPU pod type (e.g. v4-32) from topology and accelerator type.""" + if not topology or not accelerator_type: + return None try: num_chips = 1 for value in topology.strip().lower().split("x"): @@ -122,10 +161,10 @@ def infer_tpu_pod_type_from_topology( generation = accelerator_type.lower().replace("tpu-", "") return f"{generation}-{num_chips}" except Exception as e: - logger.warning( - f"Failed to infer pod type from topology {topology} and type {accelerator_type}: {e}" - ) - return None + raise ValueError( + f"Failed to infer pod type from topology '{topology}' " + f"and type '{accelerator_type}'" + ) from e def fetch_tpu_slice_name_from_pg(pg): @@ -142,6 +181,35 @@ def _get_tpu_slice_name(): return ray.get(tpu_name_ref) +def get_chips_per_host(topology: str, accelerator_version: str) -> int: + """Get the number of chips per host (aka VMs) based on topology and accelerator version. 
+    The current rule is as follows:
+    The default number of chips per host is 4.
+    If accelerator_version is v5e or v6e and the topology product is <= 8, the chips per host is that product (i.e. 1, 4, or 8).
+    If accelerator_version is v5e or v6e and the topology product is > 8, the chips per host is 4.
+    If accelerator_version is v5p or any other version, the chips per host is 4.
+
+    Args:
+        topology: The TPU topology string (e.g. "2x2x2").
+        accelerator_version: The accelerator version of the node (e.g. "V4", "v4").
+
+    Returns:
+        An int representing the number of chips per host (i.e. per VM).
+    """
+    chips_per_host = DEFAULT_TPU_NUM_CHIPS_PER_HOST
+    total_chips = 1
+    for value in topology.strip().lower().split("x"):
+        total_chips *= int(value)
+
+    if (
+        total_chips <= 8
+        and accelerator_version.strip().lower() in SINGLE_HOST_8_CHIPS_TPU_TYPES
+    ):
+        return total_chips
+
+    return chips_per_host
+
+
 def reserve_tpu_slice(
     topology: str,
     accelerator_type: str,
@@ -265,6 +333,32 @@ def is_valid_tpu_accelerator_type(tpu_accelerator_type: str) -> bool:
             return False
         return True
 
+    @staticmethod
+    def is_valid_tpu_accelerator_topology(
+        tpu_accelerator_version: str, tpu_topology: str
+    ) -> bool:
+        """Check whether the TPU topology is valid.
+
+        The tpu_accelerator_version field follows the form v{generation}.
+        The tpu_topology field follows either the form {A}x{B} or {A}x{B}x{C}, depending on the generation.
+
+        Args:
+            tpu_accelerator_version: The string representation of the accelerator version (e.g. v6e, V5P).
+            tpu_topology: The string representation of the accelerator topology
+                to be checked for validity.
+
+        Returns:
+            True if the topology is valid, False otherwise.
+        """
+        tpu_version_formatted = tpu_accelerator_version.strip().lower().split("-")[0]
+        if (
+            tpu_version_formatted.lower() not in VALID_TPU_TOPOLOGY
+            or tpu_topology.strip().lower()
+            not in VALID_TPU_TOPOLOGY[tpu_version_formatted]
+        ):
+            return False
+        return True
+
     @staticmethod
     def validate_resource_request_quantity(
         quantity: float,
@@ -510,8 +604,8 @@ def my_jax_fn():
         @ray.remote(resources={"TPU-v4-16-head"})
         def run_jax_fn(executable):
             # Note this will execute on worker 0
-            tpu_name = ray.util.accelerators.tpu.get_tpu_pod_name()
-            num_workers = ray.util.accelerators.tpu.get_tpu_num_workers()
+            tpu_name = ray.util.tpu.get_tpu_pod_name()
+            num_workers = ray.util.tpu.get_tpu_num_workers()
             tpu_executable = executable.options(resources={"TPU": 4, tpu_name: 1})
             return [tpu_executable.remote() for _ in range(num_workers)]
 
diff --git a/python/ray/tests/BUILD.bazel b/python/ray/tests/BUILD.bazel
index 963f4609bd0c..82410d5276b5 100644
--- a/python/ray/tests/BUILD.bazel
+++ b/python/ray/tests/BUILD.bazel
@@ -581,6 +581,7 @@ py_test_module_list(
         "test_task_events_3.py",
         "test_task_metrics_reconstruction.py",
         "test_top_level_api.py",
+        "test_tpu.py",
         "test_tqdm.py",
         "test_unhandled_error.py",
         "test_wait.py",
diff --git a/python/ray/tests/accelerators/test_tpu.py b/python/ray/tests/accelerators/test_tpu.py
index 3f2c53286996..6ed4eed9efe7 100644
--- a/python/ray/tests/accelerators/test_tpu.py
+++ b/python/ray/tests/accelerators/test_tpu.py
@@ -6,9 +6,7 @@
 import pytest
 import requests
 
-import ray
 from ray._private.accelerators import TPUAcceleratorManager, tpu
-from ray.tests.conftest import _ray_start_cluster
 
 
 @patch("glob.glob")
@@ -262,168 +260,5 @@ def test_tpu_pod_detect_and_configure_worker(test_config):
         assert final_resources == expected_value
 
 
-def test_get_current_pod_name_smoke():
-    with patch(
-        
"ray._private.accelerators.tpu.TPUAcceleratorManager.get_current_node_tpu_name", - return_value="my-tpu", - ): - name = ray.util.accelerators.tpu.get_current_pod_name() - assert name == "my-tpu" - - -def test_empty_get_current_pod_name_returns_none(): - with patch( - "ray._private.accelerators.tpu.TPUAcceleratorManager.get_current_node_tpu_name", - return_value="", - ): - name = ray.util.accelerators.tpu.get_current_pod_name() - assert name is None - - -@pytest.mark.parametrize( - "test_case", - [ - # (number_chips_per_host, accl_type, expected_worker_count) - (4, "v2-4", 1), - (4, "v3-32", 4), - (4, "v4-8", 1), - (4, "v4-16", 2), - (8, "v5litepod-4", 1), - (8, "v5litepod-8", 1), - (8, "v5litepod-16", 2), - (8, "v5litepod-32", 4), - (4, "v5p-4", 1), - (4, "v5p-8", 1), - (4, "v5p-16", 2), - (8, "v6e-4", 1), - (8, "v6e-8", 1), - (8, "v6e-16", 2), - ], -) -@patch("glob.glob") -def test_worker_count(mock_glob, test_case): - num_devices, accelerator_type, expected_worker_count = test_case - mock_glob.return_value = ["/dev/accel" + str(x) for x in range(num_devices)] - TPUAcceleratorManager.get_current_node_num_accelerators.cache_clear() - - with patch( - "ray._private.accelerators.tpu.TPUAcceleratorManager." - "get_current_node_tpu_pod_type", - return_value=accelerator_type, - ): - worker_count = ray.util.accelerators.tpu.get_current_pod_worker_count() - - assert worker_count == expected_worker_count - - -@patch("glob.glob") -def test_num_tpu_chips(mock_glob): - mock_glob.return_value = [ - "/dev/accel0", - "/dev/accel1", - "/dev/accel2", - "/dev/accel3", - ] - TPUAcceleratorManager.get_current_node_num_accelerators.cache_clear() - num_tpu_chips = ray.util.accelerators.tpu.get_num_tpu_chips_on_node() - assert num_tpu_chips == 4 - - -def test_get_current_node_labels_env_only(monkeypatch): - # Simulate GKE TPU environment variables - monkeypatch.setenv("TPU_NAME", "tpu-worker-group-2") - monkeypatch.setenv("TPU_WORKER_ID", "0") - monkeypatch.setenv("TPU_ACCELERATOR_TYPE", "v6e-16") - monkeypatch.setenv("TPU_TOPOLOGY", "4x4") - - tpu_labels = TPUAcceleratorManager.get_current_node_accelerator_labels() - - assert tpu_labels["ray.io/tpu-slice-name"] == "tpu-worker-group-2" - assert tpu_labels["ray.io/tpu-worker-id"] == "0" - assert tpu_labels["ray.io/tpu-topology"] == "4x4" - assert tpu_labels["ray.io/tpu-pod-type"] == "v6e-16" - - -def test_get_current_node_tpu_topology_from_metadata(): - tpu_env_string = "TPU_ACCELERATOR:v6e.\nTOPOLOGY: '2x2x4'\nTPU_HOST_BOUNDS:0,1,1,2" - - with patch( - "ray._private.accelerators.tpu._get_tpu_metadata", return_value=tpu_env_string - ): - topology = TPUAcceleratorManager.get_current_node_tpu_topology() - assert topology == "2x2x4" - - -@pytest.mark.parametrize( - "topology, accelerator_type, expected_pod_type", - [ - ("2x4", "TPU-V6E", "v6e-8"), - ("2x2x2", "TPU-V4", "v4-8"), - ("2x4x4", "TPU-V3", "v3-32"), - ("4x4", "TPU-V5P", "v5p-16"), - ("8x16", "TPU-V6E", "v6e-128"), - ("", "TPU-V3", None), - ("4x", "TPU-V3", None), - ], -) -def test_infer_tpu_pod_type_from_topology( - topology, accelerator_type, expected_pod_type -): - assert ( - tpu.infer_tpu_pod_type_from_topology(topology, accelerator_type) - == expected_pod_type - ) - - -@pytest.fixture -def ray_start_cpu(): - address_info = ray.init(num_cpus=1) - yield address_info - ray.shutdown() - - -@pytest.fixture -def ray_tpu_cluster(monkeypatch): - """Start a mock TPU Ray cluster.""" - with _ray_start_cluster() as cluster: - monkeypatch.setenv("TPU_NAME", "test-slice-0") - monkeypatch.setenv("TPU_WORKER_ID", "0") 
- monkeypatch.setenv("TPU_ACCELERATOR_TYPE", "v4-8") - monkeypatch.setenv("TPU_TOPOLOGY", "2x2x2") - - cluster.add_node( - num_cpus=2, - resources={"TPU": 4, "TPU-v4-8-head": 1}, - ) - monkeypatch.setenv("TPU_WORKER_ID", "1") - cluster.add_node( - num_cpus=2, - resources={"TPU": 4}, - ) - ray.init(address=cluster.address) - - yield cluster - ray.shutdown() - - -def test_fetch_tpu_slice_name_from_pg(ray_tpu_cluster): - """Tests that the slice name can be fetched from a PG.""" - tpu_head_pg = ray.util.placement_group(bundles=[{"TPU-v4-8-head": 1}]) - ray.get(tpu_head_pg.ready()) - - tpu_slice_name = "test-slice-0" - slice_name = tpu.fetch_tpu_slice_name_from_pg(tpu_head_pg) - assert slice_name == tpu_slice_name - - ray.util.remove_placement_group(tpu_head_pg) - - -def test_reserve_tpu_slice(ray_tpu_cluster): - """Tests that a TPU slice can be successfully reserved.""" - tpu_slice_name = "test-slice-0" - reserved_name = tpu.reserve_tpu_slice(topology="2x2x2", accelerator_type="TPU-V4") - assert reserved_name == tpu_slice_name - - if __name__ == "__main__": sys.exit(pytest.main(["-sv", __file__])) diff --git a/python/ray/tests/test_tpu.py b/python/ray/tests/test_tpu.py new file mode 100644 index 000000000000..6213e5f7eb71 --- /dev/null +++ b/python/ray/tests/test_tpu.py @@ -0,0 +1,269 @@ +import sys +from unittest.mock import patch + +import pytest + +import ray +from ray._private.accelerators import TPUAcceleratorManager, tpu +from ray.tests.conftest import _ray_start_cluster + + +def test_get_current_pod_name_smoke(): + with patch( + "ray._private.accelerators.tpu.TPUAcceleratorManager.get_current_node_tpu_name", + return_value="my-tpu", + ): + name = ray.util.tpu.get_current_pod_name() + assert name == "my-tpu" + + +def test_empty_get_current_pod_name_returns_none(): + with patch( + "ray._private.accelerators.tpu.TPUAcceleratorManager.get_current_node_tpu_name", + return_value="", + ): + name = ray.util.tpu.get_current_pod_name() + assert name is None + + +@pytest.mark.parametrize( + "test_case", + [ + # (number_chips_per_host, accl_type, expected_worker_count) + (4, "v2-4", 1), + (4, "v3-32", 4), + (4, "v4-8", 1), + (4, "v4-16", 2), + (8, "v5litepod-4", 1), + (8, "v5litepod-8", 1), + (8, "v5litepod-16", 2), + (8, "v5litepod-32", 4), + (4, "v5p-4", 1), + (4, "v5p-8", 1), + (4, "v5p-16", 2), + (8, "v6e-4", 1), + (8, "v6e-8", 1), + (8, "v6e-16", 2), + ], +) +@patch("glob.glob") +def test_worker_count(mock_glob, test_case): + num_devices, accelerator_type, expected_worker_count = test_case + mock_glob.return_value = ["/dev/accel" + str(x) for x in range(num_devices)] + TPUAcceleratorManager.get_current_node_num_accelerators.cache_clear() + + with patch( + "ray._private.accelerators.tpu.TPUAcceleratorManager." 
+ "get_current_node_tpu_pod_type", + return_value=accelerator_type, + ): + worker_count = ray.util.tpu.get_current_pod_worker_count() + + assert worker_count == expected_worker_count + + +@patch("glob.glob") +def test_num_tpu_chips(mock_glob): + mock_glob.return_value = [ + "/dev/accel0", + "/dev/accel1", + "/dev/accel2", + "/dev/accel3", + ] + TPUAcceleratorManager.get_current_node_num_accelerators.cache_clear() + num_tpu_chips = ray.util.tpu.get_num_tpu_chips_on_node() + assert num_tpu_chips == 4 + + +@pytest.mark.parametrize( + "test_case", + [ + # (accelerator_type, accelerator_topology, expected_result) + ("v2-16", "4x4", True), + ("v2-256", "16x16", True), + ("v2-4", "2x2", False), + ("v3-16", "4x4", True), + ("v3-1024", "32x32", True), + ("v3-4", "4x16", False), + ("v4-4", "2x2x1", True), + ("v4-32", "2x4x4", True), + ("v4-2048", "8x8x16", True), + ("v4-4", "16x16x16", False), + ("v5p-64", "4x4x4", True), + ("v5p-4096", "16x16x16", True), + ("v5p-6144", "16x16x24", True), + ("v5p-4", "24x24x24", False), + ("v5litepod-16", "2x8", True), + ("v5litepod-256", "16x16", True), + ("v5litepod-4", "2x2", False), + ("v6e-16", "4x4", True), + ("v6e-64", "8x8", True), + ("v6e-4", "4x16", False), + ], +) +@patch("glob.glob") +def test_is_valid_tpu_accelerator_topology(_mock_glob, test_case): + """Test valid TPU accelerator topologies.""" + accelerator_type, accelerator_topology, expected_result = test_case + actual_result = TPUAcceleratorManager.is_valid_tpu_accelerator_topology( + accelerator_type, accelerator_topology + ) + + assert actual_result == expected_result + + +def test_get_current_node_labels_env_only(monkeypatch): + # Simulate GKE TPU environment variables + monkeypatch.setenv("TPU_NAME", "tpu-worker-group-2") + monkeypatch.setenv("TPU_WORKER_ID", "0") + monkeypatch.setenv("TPU_ACCELERATOR_TYPE", "v6e-16") + monkeypatch.setenv("TPU_TOPOLOGY", "4x4") + + tpu_labels = TPUAcceleratorManager.get_current_node_accelerator_labels() + + assert tpu_labels["ray.io/tpu-slice-name"] == "tpu-worker-group-2" + assert tpu_labels["ray.io/tpu-worker-id"] == "0" + assert tpu_labels["ray.io/tpu-topology"] == "4x4" + assert tpu_labels["ray.io/tpu-pod-type"] == "v6e-16" + + +def test_get_current_node_tpu_topology_from_metadata(): + tpu_env_string = "TPU_ACCELERATOR:v6e.\nTOPOLOGY: '2x2x4'\nTPU_HOST_BOUNDS:0,1,1,2" + + with patch( + "ray._private.accelerators.tpu._get_tpu_metadata", return_value=tpu_env_string + ): + topology = TPUAcceleratorManager.get_current_node_tpu_topology() + assert topology == "2x2x4" + + +@pytest.mark.parametrize( + "topology, accelerator_type, expected_pod_type, should_raise", + [ + ("2x4", "TPU-V6E", "v6e-8", False), + ("2x2x2", "TPU-V4", "v4-8", False), + ("2x4x4", "TPU-V3", "v3-32", False), + ("4x4", "TPU-V5P", "v5p-16", False), + ("8x16", "TPU-V6E", "v6e-128", False), + ("", "TPU-V3", None, False), + ("4x", "TPU-V3", None, True), + ], +) +def test_infer_tpu_pod_type_from_topology( + topology, accelerator_type, expected_pod_type, should_raise +): + if should_raise: + with pytest.raises(ValueError): + tpu.infer_tpu_pod_type_from_topology(topology, accelerator_type) + else: + actual_result = tpu.infer_tpu_pod_type_from_topology(topology, accelerator_type) + assert actual_result == expected_pod_type + + +@pytest.fixture +def ray_start_cpu(): + address_info = ray.init(num_cpus=1) + yield address_info + ray.shutdown() + + +@pytest.fixture +def ray_tpu_cluster(monkeypatch): + """Start a mock TPU Ray cluster.""" + with _ray_start_cluster() as cluster: + 
monkeypatch.setenv("TPU_NAME", "test-slice-0") + monkeypatch.setenv("TPU_WORKER_ID", "0") + monkeypatch.setenv("TPU_ACCELERATOR_TYPE", "v4-8") + monkeypatch.setenv("TPU_TOPOLOGY", "2x2x2") + + # First slice - 2x2x2 with 2 TPU workers. + cluster.add_node( + num_cpus=2, + resources={"TPU": 4, "TPU-v4-8-head": 1}, + ) + monkeypatch.setenv("TPU_WORKER_ID", "1") + cluster.add_node( + num_cpus=2, + resources={"TPU": 4}, + ) + + # Second slice - 2x2x2 with 2 TPU workers. + monkeypatch.setenv("TPU_NAME", "test-slice-1") + monkeypatch.setenv("TPU_WORKER_ID", "0") + cluster.add_node( + num_cpus=2, + resources={"TPU": 4, "TPU-v4-8-head": 1}, + ) + monkeypatch.setenv("TPU_WORKER_ID", "1") + cluster.add_node( + num_cpus=2, + resources={"TPU": 4}, + ) + + ray.init(address=cluster.address) + + yield cluster + ray.shutdown() + + +def test_fetch_tpu_slice_name_from_pg(ray_tpu_cluster): + """Tests that the slice name can be fetched from a PG.""" + tpu_head_pg = ray.util.placement_group(bundles=[{"TPU-v4-8-head": 1}]) + ray.get(tpu_head_pg.ready()) + + expected_unique_slice_names = {"test-slice-0", "test-slice-1"} + slice_name = tpu.fetch_tpu_slice_name_from_pg(tpu_head_pg) + assert slice_name in expected_unique_slice_names + + ray.util.remove_placement_group(tpu_head_pg) + + +def test_reserve_tpu_slice(ray_tpu_cluster): + """Tests that a TPU slice can be successfully reserved.""" + reserved_name_0 = tpu.reserve_tpu_slice(topology="2x2x2", accelerator_type="TPU-V4") + reserved_name_1 = tpu.reserve_tpu_slice(topology="2x2x2", accelerator_type="TPU-V4") + assert ( + reserved_name_0 != reserved_name_1 + ), f"Expected to reserve two different slices, but got the same name: {reserved_name_0}" + expected_unique_slice_names = {"test-slice-0", "test-slice-1"} + actual_reserved_names = {reserved_name_0, reserved_name_1} + assert actual_reserved_names == expected_unique_slice_names, ( + f"Got unexpected slice names. 
Expected {expected_unique_slice_names}, " + f"but got {actual_reserved_names}" + ) + + +def test_slice_placement_group(ray_tpu_cluster): + """Test that single TPU slice can be successfully reserved.""" + slice_placement_group = ray.util.tpu.slice_placement_group( + topology="2x2x2", + accelerator_version="v4", + ) + assert slice_placement_group.chips_per_host == 4 + assert slice_placement_group.num_workers == 2 + assert slice_placement_group.placement_group.bundle_count == 2 + assert slice_placement_group.placement_group.bundle_specs == [ + {"TPU": 4}, + {"TPU": 4}, + ] + + +def test_multi_slice_placement_group(ray_tpu_cluster): + """Test that multiple whole TPU slices can be successfully reserved""" + multi_slice_placement_group = ray.util.tpu.slice_placement_group( + topology="2x2x2", + accelerator_version="v4", + num_slices=2, + ) + assert multi_slice_placement_group.placement_group.bundle_count == 4 + assert multi_slice_placement_group.num_workers == 4 + assert multi_slice_placement_group.placement_group.bundle_specs == [ + {"TPU": 4}, # slice 1 + {"TPU": 4}, + {"TPU": 4}, # slice 2 + {"TPU": 4}, + ] + + +if __name__ == "__main__": + sys.exit(pytest.main(["-sv", __file__])) diff --git a/python/ray/util/accelerators/__init__.py b/python/ray/util/accelerators/__init__.py index 6c757121207b..53d6a501fbaa 100644 --- a/python/ray/util/accelerators/__init__.py +++ b/python/ray/util/accelerators/__init__.py @@ -1,6 +1,6 @@ import warnings -from ray.util.accelerators import tpu +from ray.util import tpu from ray.util.accelerators.accelerators import ( AMD_INSTINCT_MI100, AMD_INSTINCT_MI210, diff --git a/python/ray/util/accelerators/tpu.py b/python/ray/util/accelerators/tpu.py deleted file mode 100644 index 4ba38bf46500..000000000000 --- a/python/ray/util/accelerators/tpu.py +++ /dev/null @@ -1,40 +0,0 @@ -from typing import Optional - -from ray._private.accelerators import TPUAcceleratorManager -from ray.util.annotations import PublicAPI - - -@PublicAPI(stability="alpha") -def get_current_pod_name() -> Optional[str]: - """ - Return the name of the TPU pod that the worker is a part of. - - Returns: - The name of the TPU pod. Returns None if not part of a TPU pod. - """ - tpu_name = TPUAcceleratorManager.get_current_node_tpu_name() - if tpu_name == "": - tpu_name = None - return tpu_name - - -@PublicAPI(stability="alpha") -def get_current_pod_worker_count() -> Optional[int]: - """ - Count the number of workers associated with the TPU pod that the worker belongs to. - - Returns: - The total number of workers in the TPU pod. Returns None if the worker is not - part of a TPU pod. - """ - return TPUAcceleratorManager.get_num_workers_in_current_tpu_pod() - - -@PublicAPI(stability="alpha") -def get_num_tpu_chips_on_node() -> int: - """ - Return the number of TPU chips on the node. - Returns: - The total number of chips on the TPU node. Returns 0 if none are found. 
- """ - return TPUAcceleratorManager.get_current_node_num_accelerators() diff --git a/python/ray/util/tpu.py b/python/ray/util/tpu.py new file mode 100644 index 000000000000..49cd8a2581e2 --- /dev/null +++ b/python/ray/util/tpu.py @@ -0,0 +1,253 @@ +from typing import Optional + +import ray +from ray._private.accelerators import TPUAcceleratorManager +from ray._private.accelerators.tpu import ( + VALID_TPU_TYPES, + get_chips_per_host, + reserve_tpu_slice, +) +from ray._private.client_mode_hook import client_mode_wrap +from ray.util.annotations import PublicAPI +from ray.util.placement_group import PlacementGroup, placement_group + + +@PublicAPI(stability="alpha") +def get_current_pod_name() -> Optional[str]: + """ + Return the name of the TPU pod that the worker is a part of. + + Returns: + The name of the TPU pod. Returns None if not part of a TPU pod. + """ + tpu_name = TPUAcceleratorManager.get_current_node_tpu_name() + if tpu_name == "": + tpu_name = None + return tpu_name + + +@PublicAPI(stability="alpha") +def get_current_pod_worker_count() -> Optional[int]: + """ + Count the number of workers associated with the TPU pod that the worker belongs to. + + Returns: + The total number of workers in the TPU pod. Returns None if the worker is not + part of a TPU pod. + """ + return TPUAcceleratorManager.get_num_workers_in_current_tpu_pod() + + +@PublicAPI(stability="alpha") +def get_num_tpu_chips_on_node() -> int: + """ + Return the number of TPU chips on the node. + Returns: + The total number of chips on the TPU node. Returns 0 if none are found. + """ + return TPUAcceleratorManager.get_current_node_num_accelerators() + + +@PublicAPI(stability="alpha") +class SlicePlacementGroup: + """ + A handle to a placement group reservation for a TPU slice. + + The following definitions are added for clarity: + + - Accelerator type: A string describing the accelerator type and version (e.g. TPU-V2, TPU-V6E). + - Accelerator version: The accelerator generation only (e.g. v6e, v5p, v5litepod). + - Pod type: The TPU accelerator version and the number of chips in a topology. (e.g. v6e-128, v5p-8). + - Accelerator topology: The physical topology representing the structure (e.g. 2x2x2, 16x16). + + Args: + topology: The TPU topology string (e.g. "2x2x2"). + accelerator_version: The TPU accelerator generation (e.g. "v6e", "v5p", "v4"). + strategy: PlacementGroup parameter. The strategy to create the placement group. Currently default to "SPREAD" + + - "PACK": Packs Bundles into as few nodes as possible. + - "SPREAD": Places Bundles across distinct nodes as even as possible. + - "STRICT_PACK": Packs Bundles into one node. The group is + not allowed to span multiple nodes. + - "STRICT_SPREAD": Packs Bundles across distinct nodes. + + lifetime: PlacementGroup parameter. Either `None`, which defaults to the placement group + will fate share with its creator and will be deleted once its + creator is dead, or "detached", which means the placement group + will live as a global object independent of the creator. + + num_slices: Number of TPU slices in the SlicePlacementGroup. Defaults to 1 when unspecified. + + Examples: + + .. 
testcode:: python
+        :skipif: True
+
+        import ray
+        from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
+        from ray.util.tpu import SlicePlacementGroup
+
+        slice_handle = SlicePlacementGroup(topology="4x4", accelerator_version="v6e")
+        slice_pg = slice_handle.placement_group
+        ray.get(slice_pg.ready(), timeout=10)
+
+        @ray.remote(num_cpus=0, resources={'TPU': 4})
+        def spmd_task(world, rank):
+            print(f"Current TPU is rank {rank} of {world}")
+
+        tasks = [
+            spmd_task.options(
+                scheduling_strategy=PlacementGroupSchedulingStrategy(
+                    placement_group=slice_pg,
+                )
+            ).remote(world=4, rank=i)
+            for i in range(slice_handle.num_workers)
+        ]
+
+    """
+
+    def __init__(
+        self,
+        topology: str,
+        accelerator_version: str,
+        # Placement group arguments.
+        strategy: str = "SPREAD",
+        name: str = "",
+        lifetime: Optional[str] = None,
+        # Defaults to a single slice.
+        num_slices: int = 1,
+    ):
+        self._topology = topology.strip().lower()
+        self._accelerator_version = accelerator_version.strip().lower()
+        self._num_slices = num_slices
+        self._validate_tpu_config()
+
+        # Reserve a TPU slice of the provided accelerator version and topology.
+        self._placement_group = self._reserve_slice(
+            strategy,
+            name,
+            lifetime,
+        )
+
+    def _accelerator_version_check(self, accelerator_version: str):
+        if accelerator_version not in VALID_TPU_TYPES:
+            raise ValueError(
+                f"Invalid accelerator version: {accelerator_version}. Must be one of: {VALID_TPU_TYPES}"
+            )
+
+    def _validate_tpu_config(self):
+        # Validate the topology and generation values, compute
+        # self._num_workers and self._chips_per_host, and raise a
+        # ValueError if the configuration is invalid.
+        self._accelerator_version_check(self.accelerator_version)
+        if not TPUAcceleratorManager.is_valid_tpu_accelerator_topology(
+            tpu_accelerator_version=self.accelerator_version,
+            tpu_topology=self._topology,
+        ):
+            raise ValueError(
+                f"Invalid accelerator topology: '{self._topology}' for "
+                f"accelerator version: '{self.accelerator_version}'"
+            )
+
+        total_chips = 1
+        for value in self._topology.strip().lower().split("x"):
+            total_chips *= int(value)
+
+        self._chips_per_host = get_chips_per_host(
+            self._topology, self.accelerator_version
+        )
+        self._num_workers_per_slice = total_chips // self._chips_per_host
+        self._num_workers = self._num_workers_per_slice * self._num_slices
+
+    def _reserve_slice(
+        self,
+        strategy: str = "SPREAD",
+        name: str = "",
+        lifetime: Optional[str] = None,
+    ) -> PlacementGroup:
+        """Performs the two-step scheduling to reserve the requested TPU slice(s)."""
+        bundle_label_selector = []
+        bundles = []
+
+        # Construct the accelerator type expected by reserve_tpu_slice, e.g. "v6e" -> "TPU-V6E", "v5p" -> "TPU-V5P".
+        accelerator_type = "TPU-" + self.accelerator_version.upper()
+        for _ in range(self.num_slices):
+            # Reserving a slice is done by constructing num_workers_per_slice bundles per slice, each with a
+            # label selector for the unique name of an available TPU slice.
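+            # reserve_tpu_slice returns the unique name of a slice that can satisfy the requested
+            # topology; the label selectors below pin each of this slice's bundles to that name.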
+            slice_name = reserve_tpu_slice(self._topology, accelerator_type)
+            bundle_label_selector += [
+                {ray._raylet.RAY_NODE_TPU_SLICE_NAME_KEY: slice_name}
+            ] * self._num_workers_per_slice
+            bundles += [{"TPU": self._chips_per_host}] * self._num_workers_per_slice
+
+        pg = placement_group(
+            bundles=bundles,
+            strategy=strategy,
+            name=name,
+            lifetime=lifetime,
+            bundle_label_selector=bundle_label_selector,
+        )
+
+        return pg
+
+    @property
+    def placement_group(self) -> PlacementGroup:
+        """The underlying PlacementGroup object."""
+        return self._placement_group
+
+    @property
+    def chips_per_host(self) -> int:
+        """The number of chips per host for this TPU slice."""
+        # This is the same value as the per-worker TPU resource request.
+        return self._chips_per_host
+
+    @property
+    def num_workers(self) -> int:
+        """The total number of hosts in the SlicePlacementGroup."""
+        return self._num_workers
+
+    @property
+    def topology(self) -> str:
+        """The physical topology of the TPU slice."""
+        return self._topology
+
+    @property
+    def accelerator_version(self) -> str:
+        """The TPU accelerator version (generation) of the slice."""
+        return self._accelerator_version
+
+    @property
+    def num_slices(self) -> int:
+        """The number of TPU slices this SlicePlacementGroup spans."""
+        return self._num_slices
+
+
+@PublicAPI(stability="alpha")
+@client_mode_wrap
+def slice_placement_group(
+    topology: str,
+    accelerator_version: str,
+    num_slices: int = 1,
+    **kwargs,
+) -> SlicePlacementGroup:
+    """Asynchronously creates a PlacementGroup for a TPU slice.
+
+    A slice placement group reserves num_slices TPU slice(s) and creates a placement
+    group for scheduling tasks.
+
+    Args:
+        topology: The desired TPU pod topology (e.g. "4x4", "2x8").
+        accelerator_version: The TPU accelerator generation (e.g. "V4", "V5P", "V6E").
+        num_slices: The number of TPU slices within the placement group.
+        **kwargs: Additional arguments for the placement group, such as 'name', 'lifetime', or 'strategy'.
+
+    Returns:
+        The handle for the created SlicePlacementGroup.
+    """
+
+    return SlicePlacementGroup(
+        topology=topology,
+        accelerator_version=accelerator_version,
+        num_slices=num_slices,
+        **kwargs,
+    )
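+
+
+# Minimal usage sketch of the functional API (an illustrative comment only; it assumes
+# a cluster that exposes "TPU" resources and per-slice node labels, such as the
+# two-slice v4-8 cluster constructed in python/ray/tests/test_tpu.py):
+#
+#   import ray
+#   from ray.util.tpu import slice_placement_group
+#
+#   handle = slice_placement_group(topology="2x2x2", accelerator_version="v4")
+#   ray.get(handle.placement_group.ready())
+#   # handle.num_workers == 2 and each bundle requests {"TPU": 4}.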