@@ -1,15 +1,18 @@
+import os
+
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
 
 from vllm.platforms import current_platform
 
 if current_platform.is_tpu():
-    import ray
     import torch_xla.core.xla_model as xm
     import torch_xla.runtime as xr
     from torch_xla._internal import pjrt
 
+    from vllm.executor import ray_utils
+
 
 class TpuCommunicator:
 
@@ -24,9 +27,29 @@ def __init__(self, group: ProcessGroup):
         # be simply calculated as follows.
         global_rank = dist.get_rank(group)
         global_world_size = dist.get_world_size(group)
-        num_nodes = len(ray.nodes())
+
+        # Calculate how many TPU nodes are in the current deployment. This
+        # is the Ray placement group if it is deployed with Ray. Default
+        # to the number of TPU nodes in the Ray cluster. The number of TPU
+        # nodes is computed by the total number of TPUs divided by the
+        # number of TPU accelerators per node, to account for clusters
+        # with both CPUs and TPUs.
+        num_nodes = ray_utils.get_num_tpu_nodes()
+        num_nodes_in_pg = ray_utils.get_num_nodes_in_placement_group()
+        if num_nodes_in_pg > 0:
+            num_nodes = num_nodes_in_pg
+
         local_world_size = global_world_size // num_nodes
         local_rank = global_rank % local_world_size
+
+        # Ensure environment variables are set for multihost deployments.
+        # On GKE, this is needed for libtpu and TPU driver to know which TPU
+        # chip is actually visible. Otherwise the TPU driver will fail to
+        # initialize because the number of devices would be different from
+        # the number of visible worker addresses.
+        os.environ["CLOUD_TPU_TASK_ID"] = str(global_rank)
+        os.environ["TPU_VISIBLE_CHIPS"] = str(local_rank)
+
         pjrt.initialize_multiprocess(local_rank, local_world_size)
         xr._init_world_size_ordinal()
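For reference, below is a minimal sketch of what the two ray_utils helpers referenced in the diff could look like, assuming only Ray's public ray.cluster_resources(), ray.nodes(), and ray.util.get_current_placement_group() APIs; the actual vllm.executor.ray_utils implementation may differ (for example, it may inspect the placement group table rather than bundle counts).

import ray
from ray.util import get_current_placement_group


def get_num_tpu_nodes() -> int:
    # Hypothetical sketch: divide the cluster's total TPU count by the TPU
    # count of a single TPU node, so CPU-only nodes in a mixed cluster are
    # not counted as TPU nodes.
    total_tpus = int(ray.cluster_resources().get("TPU", 0))
    tpus_per_node = 0
    for node in ray.nodes():
        node_tpus = int(node.get("Resources", {}).get("TPU", 0))
        if node_tpus > 0:
            tpus_per_node = node_tpus
            break
    return total_tpus // tpus_per_node if tpus_per_node else 0


def get_num_nodes_in_placement_group() -> int:
    # Hypothetical sketch: number of nodes backing the current Ray placement
    # group, or 0 when no placement group is active. Assumes one bundle per
    # node.
    pg = get_current_placement_group()
    return len(pg.bundle_specs) if pg is not None else 0

With these semantics, the __init__ change above prefers the placement-group node count when vLLM is launched through Ray, and falls back to the whole-cluster TPU node count otherwise.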