Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions doc/source/ray-core/api/utility.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@ Utility
ray.util.serialization.register_serializer
ray.util.serialization.deregister_serializer

ray.util.accelerators.tpu.get_current_pod_worker_count
ray.util.accelerators.tpu.get_current_pod_name
ray.util.accelerators.tpu.get_num_tpu_chips_on_node
ray.util.tpu.get_current_pod_worker_count
ray.util.tpu.get_current_pod_name
ray.util.tpu.get_num_tpu_chips_on_node
ray.util.tpu.SlicePlacementGroup
ray.util.tpu.slice_placement_group

ray.nodes
ray.cluster_resources
Expand Down
108 changes: 101 additions & 7 deletions python/ray/_private/accelerators/tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import re
from functools import lru_cache
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Set, Tuple

import requests

Expand Down Expand Up @@ -64,6 +64,43 @@
# The valid TPU types.
VALID_TPU_TYPES = ("v2", "v3", "v4", "v5p", "v5litepod", "v6e")

# Helper used only to enumerate the larger 3D TPU topologies (v4 / v5p).
def _get_larger_3d_topologies(max_x: int, max_y: int, max_z: int) -> Set[str]:
    """Enumerate 3D TPU topologies up to the given per-axis maxima.

    Each axis is stepped in increments of DEFAULT_TPU_NUM_CHIPS_PER_HOST,
    starting at that increment, so results look like "4x8x12".

    Args:
        max_x: Largest value of the first axis (inclusive).
        max_y: Largest value of the second axis (inclusive).
        max_z: Largest value of the third axis (inclusive).

    Returns:
        A set of "{x}x{y}x{z}" topology strings.
    """
    step = DEFAULT_TPU_NUM_CHIPS_PER_HOST
    return {
        f"{x}x{y}x{z}"
        for x in range(step, max_x + 1, step)
        for y in range(step, max_y + 1, step)
        for z in range(step, max_z + 1, step)
    }


# The valid TPU topologies for each of the TPU types.
# 2D "{A}x{B}" shapes are used for v2/v3/v5litepod/v6e; 3D "{A}x{B}x{C}"
# shapes for v4/v5p, where the larger 3D shapes are generated in
# DEFAULT_TPU_NUM_CHIPS_PER_HOST increments per axis.
VALID_TPU_TOPOLOGY = {
    "v2": {"4x4", "4x8", "8x8", "8x16", "16x16"},
    "v3": {"4x4", "4x8", "8x8", "8x16", "16x16", "16x32", "32x32"},
    # v4: explicit small base shapes, plus generated shapes up to 12x12x16.
    "v4": {"2x2x1", "2x2x2", "2x2x4", "2x4x4"}.union(
        _get_larger_3d_topologies(12, 12, 16)
    ),
    # v5p: same base shapes, scaling further up to 16x16x24.
    "v5p": {
        "2x2x1",
        "2x2x2",
        "2x2x4",
        "2x4x4",
    }.union(_get_larger_3d_topologies(16, 16, 24)),
    "v5litepod": {"2x8", "4x4", "4x8", "8x8", "8x16", "16x16"},
    "v6e": {"2x8", "4x4", "4x8", "8x8", "8x16", "16x16"},
}


def _get_tpu_metadata(key: str) -> Optional[str]:
"""Poll and get TPU metadata."""
Expand Down Expand Up @@ -115,17 +152,19 @@ def infer_tpu_pod_type_from_topology(
topology: str, accelerator_type: str
) -> Optional[str]:
"""Infer the TPU pod type (e.g. v4-32) from topology and accelerator type."""
if not topology or not accelerator_type:
return None
try:
num_chips = 1
for value in topology.strip().lower().split("x"):
num_chips *= int(value)
generation = accelerator_type.lower().replace("tpu-", "")
return f"{generation}-{num_chips}"
except Exception as e:
logger.warning(
f"Failed to infer pod type from topology {topology} and type {accelerator_type}: {e}"
)
return None
raise ValueError(
f"Failed to infer pod type from topology '{topology}' "
f"and type '{accelerator_type}'"
) from e


def fetch_tpu_slice_name_from_pg(pg):
Expand All @@ -142,6 +181,35 @@ def _get_tpu_slice_name():
return ray.get(tpu_name_ref)


def get_chips_per_host(topology: str, accelerator_version: str) -> int:
    """Get the number of chips per host (aka VM) for a topology and version.

    The current rule is as follows:
      * The default chips per host is DEFAULT_TPU_NUM_CHIPS_PER_HOST.
      * If the accelerator version is in SINGLE_HOST_8_CHIPS_TPU_TYPES
        (presumably v5e/v6e — confirm against the constant's definition)
        AND the topology product is <= 8, the chips per host is just the
        product, i.e. 1, 4, or 8.
      * Otherwise (larger topologies, or versions such as v5p), the
        default applies.

    Args:
        topology: The TPU topology string (e.g. "2x2x2").
        accelerator_version: The accelerator version of the node
            (e.g. "V4", "v4"); matching is case-insensitive.

    Returns:
        An int representing the number of chips per host (aka VM).
    """
    # Total chip count in the slice is the product of the topology dims.
    total_chips = 1
    for dim in topology.strip().lower().split("x"):
        total_chips *= int(dim)

    # Small slices of single-host-8-chip generations fit on one VM, so the
    # host sees every chip in the slice.
    if (
        total_chips <= 8
        and accelerator_version.strip().lower() in SINGLE_HOST_8_CHIPS_TPU_TYPES
    ):
        return total_chips

    return DEFAULT_TPU_NUM_CHIPS_PER_HOST


def reserve_tpu_slice(
topology: str,
accelerator_type: str,
Expand Down Expand Up @@ -265,6 +333,32 @@ def is_valid_tpu_accelerator_type(tpu_accelerator_type: str) -> bool:
return False
return True

@staticmethod
def is_valid_tpu_accelerator_topology(
    tpu_accelerator_version: str, tpu_topology: str
) -> bool:
    """Check whether the TPU topology is valid for the given TPU version.

    The accelerator version follows the form v{generation}. The topology
    follows either the form {A}x{B} or {A}x{B}x{C} depending on the
    v{generation}.

    Args:
        tpu_accelerator_version: The string representation of the
            accelerator version (e.g. "v6e", "V5P", or a pod type such
            as "v5p-8"; anything after a "-" is ignored).
        tpu_topology: The string representation of the accelerator
            topology to be checked for validity.

    Returns:
        True if it's a valid topology, False otherwise.
    """
    # Normalize e.g. "V5P-8" -> "v5p" before looking up valid topologies.
    # (The split result is already lowercased; no second lower() needed.)
    version = tpu_accelerator_version.strip().lower().split("-")[0]
    valid_topologies = VALID_TPU_TOPOLOGY.get(version)
    if valid_topologies is None:
        return False
    return tpu_topology.strip().lower() in valid_topologies

@staticmethod
def validate_resource_request_quantity(
quantity: float,
Expand Down Expand Up @@ -510,8 +604,8 @@ def my_jax_fn():
@ray.remote(resources={"TPU-v4-16-head"})
def run_jax_fn(executable):
# Note this will execute on worker 0
tpu_name = ray.util.accelerators.tpu.get_tpu_pod_name()
num_workers = ray.util.accelerators.tpu.get_tpu_num_workers()
tpu_name = ray.util.tpu.get_tpu_pod_name()
num_workers = ray.util.tpu.get_tpu_num_workers()
tpu_executable = executable.options(resources={"TPU": 4, tpu_name: 1})
return [tpu_executable.remote() for _ in range(num_workers)]

Expand Down
1 change: 1 addition & 0 deletions python/ray/tests/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,7 @@ py_test_module_list(
"test_task_events_3.py",
"test_task_metrics_reconstruction.py",
"test_top_level_api.py",
"test_tpu.py",
"test_tqdm.py",
"test_unhandled_error.py",
"test_wait.py",
Expand Down
165 changes: 0 additions & 165 deletions python/ray/tests/accelerators/test_tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@
import pytest
import requests

import ray
from ray._private.accelerators import TPUAcceleratorManager, tpu
from ray.tests.conftest import _ray_start_cluster


@patch("glob.glob")
Expand Down Expand Up @@ -262,168 +260,5 @@ def test_tpu_pod_detect_and_configure_worker(test_config):
assert final_resources == expected_value


def test_get_current_pod_name_smoke():
    """The TPU name reported by the accelerator manager is passed through."""
    patched = patch(
        "ray._private.accelerators.tpu.TPUAcceleratorManager.get_current_node_tpu_name",
        return_value="my-tpu",
    )
    with patched:
        assert ray.util.accelerators.tpu.get_current_pod_name() == "my-tpu"


def test_empty_get_current_pod_name_returns_none():
    """An empty TPU name from the manager is surfaced as None."""
    patched = patch(
        "ray._private.accelerators.tpu.TPUAcceleratorManager.get_current_node_tpu_name",
        return_value="",
    )
    with patched:
        assert ray.util.accelerators.tpu.get_current_pod_name() is None


@pytest.mark.parametrize(
    "test_case",
    [
        # (number_chips_per_host, accl_type, expected_worker_count)
        (4, "v2-4", 1),
        (4, "v3-32", 4),
        (4, "v4-8", 1),
        (4, "v4-16", 2),
        (8, "v5litepod-4", 1),
        (8, "v5litepod-8", 1),
        (8, "v5litepod-16", 2),
        (8, "v5litepod-32", 4),
        (4, "v5p-4", 1),
        (4, "v5p-8", 1),
        (4, "v5p-16", 2),
        (8, "v6e-4", 1),
        (8, "v6e-8", 1),
        (8, "v6e-16", 2),
    ],
)
@patch("glob.glob")
def test_worker_count(mock_glob, test_case):
    """Worker count = chips in the pod type / chips visible on this host."""
    num_devices, accelerator_type, expected = test_case
    mock_glob.return_value = [f"/dev/accel{i}" for i in range(num_devices)]
    TPUAcceleratorManager.get_current_node_num_accelerators.cache_clear()

    target = (
        "ray._private.accelerators.tpu.TPUAcceleratorManager."
        "get_current_node_tpu_pod_type"
    )
    with patch(target, return_value=accelerator_type):
        assert ray.util.accelerators.tpu.get_current_pod_worker_count() == expected


@patch("glob.glob")
def test_num_tpu_chips(mock_glob):
    """Four /dev/accel* device entries are reported as 4 TPU chips."""
    mock_glob.return_value = [f"/dev/accel{i}" for i in range(4)]
    TPUAcceleratorManager.get_current_node_num_accelerators.cache_clear()
    assert ray.util.accelerators.tpu.get_num_tpu_chips_on_node() == 4


def test_get_current_node_labels_env_only(monkeypatch):
    """TPU node labels are derived from the GKE TPU environment variables."""
    tpu_env = {
        "TPU_NAME": "tpu-worker-group-2",
        "TPU_WORKER_ID": "0",
        "TPU_ACCELERATOR_TYPE": "v6e-16",
        "TPU_TOPOLOGY": "4x4",
    }
    for name, value in tpu_env.items():
        monkeypatch.setenv(name, value)

    labels = TPUAcceleratorManager.get_current_node_accelerator_labels()

    expected = {
        "ray.io/tpu-slice-name": "tpu-worker-group-2",
        "ray.io/tpu-worker-id": "0",
        "ray.io/tpu-topology": "4x4",
        "ray.io/tpu-pod-type": "v6e-16",
    }
    for label, value in expected.items():
        assert labels[label] == value


def test_get_current_node_tpu_topology_from_metadata():
    """Topology is parsed out of the TOPOLOGY entry of the metadata blob."""
    metadata = "TPU_ACCELERATOR:v6e.\nTOPOLOGY: '2x2x4'\nTPU_HOST_BOUNDS:0,1,1,2"

    patched = patch(
        "ray._private.accelerators.tpu._get_tpu_metadata", return_value=metadata
    )
    with patched:
        assert TPUAcceleratorManager.get_current_node_tpu_topology() == "2x2x4"


@pytest.mark.parametrize(
    "topology, accelerator_type, expected_pod_type",
    [
        ("2x4", "TPU-V6E", "v6e-8"),
        ("2x2x2", "TPU-V4", "v4-8"),
        ("2x4x4", "TPU-V3", "v3-32"),
        ("4x4", "TPU-V5P", "v5p-16"),
        ("8x16", "TPU-V6E", "v6e-128"),
        ("", "TPU-V3", None),
        ("4x", "TPU-V3", None),
    ],
)
def test_infer_tpu_pod_type_from_topology(
    topology, accelerator_type, expected_pod_type
):
    """Pod type is '{generation}-{chip count}' derived from the topology."""
    actual = tpu.infer_tpu_pod_type_from_topology(topology, accelerator_type)
    assert actual == expected_pod_type


@pytest.fixture
def ray_start_cpu():
    """Start a single-CPU local Ray instance; shut it down on teardown."""
    info = ray.init(num_cpus=1)
    yield info
    ray.shutdown()


@pytest.fixture
def ray_tpu_cluster(monkeypatch):
    """Start a mock TPU Ray cluster."""
    # Two-node cluster that mimics a single v4-8 TPU slice "test-slice-0":
    # the first node exposes the "TPU-v4-8-head" resource, the second only
    # exposes TPU chips.
    with _ray_start_cluster() as cluster:
        monkeypatch.setenv("TPU_NAME", "test-slice-0")
        monkeypatch.setenv("TPU_WORKER_ID", "0")
        monkeypatch.setenv("TPU_ACCELERATOR_TYPE", "v4-8")
        monkeypatch.setenv("TPU_TOPOLOGY", "2x2x2")

        cluster.add_node(
            num_cpus=2,
            resources={"TPU": 4, "TPU-v4-8-head": 1},
        )
        # Re-point TPU_WORKER_ID before adding the second node so it
        # registers as worker 1 of the same slice — order matters here.
        monkeypatch.setenv("TPU_WORKER_ID", "1")
        cluster.add_node(
            num_cpus=2,
            resources={"TPU": 4},
        )
        ray.init(address=cluster.address)

        yield cluster
        ray.shutdown()


def test_fetch_tpu_slice_name_from_pg(ray_tpu_cluster):
    """Tests that the slice name can be fetched from a PG."""
    pg = ray.util.placement_group(bundles=[{"TPU-v4-8-head": 1}])
    ray.get(pg.ready())

    assert tpu.fetch_tpu_slice_name_from_pg(pg) == "test-slice-0"

    ray.util.remove_placement_group(pg)


def test_reserve_tpu_slice(ray_tpu_cluster):
    """Tests that a TPU slice can be successfully reserved."""
    reserved = tpu.reserve_tpu_slice(topology="2x2x2", accelerator_type="TPU-V4")
    assert reserved == "test-slice-0"


if __name__ == "__main__":
    # Allow running this test module directly: `python test_tpu.py`.
    sys.exit(pytest.main(["-sv", __file__]))
Loading