Commit 2148441

[TPU] Support single and multi-host TPUs on GKE (vllm-project#7613)
1 parent dc13e99 commit 2148441

5 files changed: +74 -4 lines changed

requirements-tpu.txt (+1 -1)

@@ -4,4 +4,4 @@
 # Dependencies for TPU
 # Currently, the TPU backend uses a nightly version of PyTorch XLA.
 # You can install the dependencies in Dockerfile.tpu.
-ray
+ray[default]

vllm/attention/backends/pallas.py (+4 -1)

@@ -123,7 +123,10 @@ def __init__(
             raise NotImplementedError("TPU version must be 4 or higher.")

         self.megacore_mode = None
-        tpu_type = torch_xla.tpu.get_tpu_env()["TYPE"].lower()
+        tpu_env = torch_xla.tpu.get_tpu_env()
+        tpu_type = tpu_env.get("TYPE") or tpu_env.get("ACCELERATOR_TYPE")
+        tpu_type = tpu_type.lower()
+
         if "lite" not in tpu_type:
             if self.num_kv_heads % 2 == 0:
                 self.megacore_mode = "kv_head"
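The TYPE key is not always present in the TPU environment on GKE, so the lookup now falls back to ACCELERATOR_TYPE before lower-casing. Below is a minimal, hypothetical sketch of just the branch shown in this hunk; the helper name and the standalone dict are assumptions, and the real logic lives inside the Pallas attention backend's __init__ and may contain additional megacore branches not shown here.

from typing import Dict, Optional


def resolve_megacore_mode(tpu_env: Dict[str, str],
                          num_kv_heads: int) -> Optional[str]:
    # Fall back to ACCELERATOR_TYPE when TYPE is absent (as on GKE).
    tpu_type = tpu_env.get("TYPE") or tpu_env.get("ACCELERATOR_TYPE")
    tpu_type = tpu_type.lower()

    # "lite" chips do not use megacore; otherwise shard KV heads across
    # the two cores when the head count is even.
    if "lite" not in tpu_type and num_kv_heads % 2 == 0:
        return "kv_head"
    return None


print(resolve_megacore_mode({"ACCELERATOR_TYPE": "v4-16"}, num_kv_heads=8))  # kv_head
print(resolve_megacore_mode({"TYPE": "v5litepod-8"}, num_kv_heads=8))        # None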

vllm/distributed/device_communicators/tpu_communicator.py (+25 -2)

@@ -1,15 +1,18 @@
+import os
+
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup

 from vllm.platforms import current_platform

 if current_platform.is_tpu():
-    import ray
     import torch_xla.core.xla_model as xm
     import torch_xla.runtime as xr
     from torch_xla._internal import pjrt

+    from vllm.executor import ray_utils
+

 class TpuCommunicator:

@@ -24,9 +27,29 @@ def __init__(self, group: ProcessGroup):
         # be simply calculated as follows.
         global_rank = dist.get_rank(group)
         global_world_size = dist.get_world_size(group)
-        num_nodes = len(ray.nodes())
+
+        # Calculate how many TPU nodes are in the current deployment. This
+        # is the Ray placement group if it is deployed with Ray. Default
+        # to the number of TPU nodes in the Ray cluster. The number of TPU
+        # nodes is computed by the total number of TPUs divided by the
+        # number of TPU accelerators per node, to account for clusters
+        # with both CPUs and TPUs.
+        num_nodes = ray_utils.get_num_tpu_nodes()
+        num_nodes_in_pg = ray_utils.get_num_nodes_in_placement_group()
+        if num_nodes_in_pg > 0:
+            num_nodes = num_nodes_in_pg
+
         local_world_size = global_world_size // num_nodes
         local_rank = global_rank % local_world_size
+
+        # Ensure environment variables are set for multihost deployments.
+        # On GKE, this is needed for libtpu and TPU driver to know which TPU
+        # chip is actually visible. Otherwise the TPU driver will fail to
+        # initialize because the number of devices would be different from
+        # the number of visible worker addresses.
+        os.environ["CLOUD_TPU_TASK_ID"] = str(global_rank)
+        os.environ["TPU_VISIBLE_CHIPS"] = str(local_rank)
+
         pjrt.initialize_multiprocess(local_rank, local_world_size)
         xr._init_world_size_ordinal()
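The rank arithmetic above assumes TPU workers are spread evenly across nodes: with N TPU nodes and W workers in the process group, each node hosts W // N workers, and a worker's local rank is its global rank modulo that per-node count. Here is a small illustrative sketch; the helper name is hypothetical, and the real code additionally exports CLOUD_TPU_TASK_ID and TPU_VISIBLE_CHIPS from these values before initializing PJRT.

from typing import Tuple


def compute_local_topology(global_rank: int, global_world_size: int,
                           num_nodes: int) -> Tuple[int, int]:
    # Workers are assumed to be evenly distributed over the TPU nodes.
    assert global_world_size % num_nodes == 0
    local_world_size = global_world_size // num_nodes
    local_rank = global_rank % local_world_size
    return local_rank, local_world_size


# Example: 2 hosts with 4 TPU chips each -> global rank 6 is chip 2 on host 1.
print(compute_local_topology(global_rank=6, global_world_size=8, num_nodes=2))  # (2, 4)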

vllm/executor/ray_tpu_executor.py (+15)

@@ -71,6 +71,19 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
             worker_module_name = "vllm.worker.tpu_worker"
             worker_class_name = "TPUWorker"

+            # GKE does not fetch environment information from metadata server
+            # and instead sets these from within the Ray process. Therefore we
+            # need to override the Ray environment variables manually.
+            override_env = {}
+            if "TPU_CHIPS_PER_HOST_BOUNDS" in os.environ:
+                override_env.update({
+                    "TPU_CHIPS_PER_HOST_BOUNDS":
+                    os.environ["TPU_CHIPS_PER_HOST_BOUNDS"]
+                })
+            if "TPU_HOST_BOUNDS" in os.environ:
+                override_env.update(
+                    {"TPU_HOST_BOUNDS": os.environ["TPU_HOST_BOUNDS"]})
+
             worker = ray.remote(
                 num_cpus=0,
                 resources={"TPU": 1},
@@ -81,6 +94,8 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
                 worker_class_name=worker_class_name,
                 trust_remote_code=self.model_config.trust_remote_code,
             )
+            if override_env:
+                worker.override_env_vars.remote(override_env)

             worker_ip = ray.get(worker.get_node_ip.remote())
             if worker_ip == driver_ip and self.driver_dummy_worker is None:
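The pattern here is to snapshot the TPU topology variables that GKE injects into the driver pod and replay them on each Ray worker through the new override_env_vars actor method (which simply calls os.environ.update). A hypothetical standalone sketch of the snapshot step follows; the helper name and the final print are illustrative only.

import os
from typing import Dict

# The two variables this commit forwards; extending the list is an assumption.
_TPU_TOPOLOGY_VARS = ("TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS")


def collect_tpu_topology_env() -> Dict[str, str]:
    # Forward only the variables that are actually set on the driver.
    return {k: os.environ[k] for k in _TPU_TOPOLOGY_VARS if k in os.environ}


override_env = collect_tpu_topology_env()
if override_env:
    # In the executor this becomes: worker.override_env_vars.remote(override_env)
    print("Would forward to workers:", override_env)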

vllm/executor/ray_utils.py (+29)

@@ -1,3 +1,4 @@
+import os
 import time
 from collections import defaultdict
 from typing import Dict, List, Optional, Tuple, Union
@@ -84,6 +85,9 @@ def execute_model_spmd(

             return output

+        def override_env_vars(self, vars: Dict[str, str]):
+            os.environ.update(vars)
+
     ray_import_err = None

 except ImportError as e:
@@ -291,3 +295,28 @@ def initialize_ray_cluster(
     _verify_bundles(current_placement_group, parallel_config, device_str)
     # Set the placement group in the parallel config
     parallel_config.placement_group = current_placement_group
+
+
+def get_num_tpu_nodes() -> int:
+    from ray._private.accelerators import TPUAcceleratorManager
+    cluster_resources = ray.cluster_resources()
+    total_tpus = int(cluster_resources["TPU"])
+    tpus_per_node = TPUAcceleratorManager.get_current_node_num_accelerators()
+    assert total_tpus % tpus_per_node == 0
+    return total_tpus // tpus_per_node
+
+
+def get_num_nodes_in_placement_group() -> int:
+    pg_table = ray.util.placement_group_table()
+    current_pg = ray.util.get_current_placement_group()
+    num_nodes = 0
+
+    if current_pg:
+        nodes_in_pg = set()
+        for pg_key, pg in pg_table.items():
+            if pg_key == current_pg.id.hex():
+                for _, node in pg["bundles_to_node_id"].items():
+                    nodes_in_pg.add(node)
+        num_nodes = len(nodes_in_pg)
+
+    return num_nodes
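Together these helpers let TpuCommunicator infer how many TPU hosts a deployment spans: get_num_tpu_nodes divides the cluster's total TPU resources by the per-node accelerator count, and get_num_nodes_in_placement_group counts the distinct node IDs backing the current placement group's bundles. A hedged usage sketch, assuming a Ray cluster with TPU resources is already reachable:

import ray

from vllm.executor import ray_utils

ray.init(ignore_reinit_error=True)

# Mirror the selection logic in TpuCommunicator: prefer the placement-group
# view when one is active, otherwise count TPU nodes across the whole cluster.
num_nodes = ray_utils.get_num_nodes_in_placement_group()
if num_nodes == 0:
    num_nodes = ray_utils.get_num_tpu_nodes()
print("TPU nodes in this deployment:", num_nodes)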
