Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion vllm/config/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ def stateless_init_dp_group(self) -> ProcessGroup:
self.get_next_dp_init_port(),
self.data_parallel_rank,
self.data_parallel_size,
backend="gloo",
backend=current_platform.dist_backend,
)
except DistNetworkError as e:
# We only want to retry when the root cause is EADDRINUSE.
Expand Down
26 changes: 13 additions & 13 deletions vllm/distributed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,6 @@ def create(


def init_gloo_process_group(
backend: Backend,
prefix_store: PrefixStore,
group_rank: int,
group_size: int,
Expand All @@ -432,7 +431,7 @@ def init_gloo_process_group(
group_size,
)
else:
options = ProcessGroup.Options(backend=backend)
options = ProcessGroup.Options(backend="gloo")
pg = ProcessGroup(
prefix_store,
group_rank,
Expand Down Expand Up @@ -504,24 +503,25 @@ def stateless_init_torch_distributed_process_group(
# Use a PrefixStore to avoid accidental overrides of keys used by
# different systems (e.g. RPC) in case the store is multi-tenant.
prefix_store = PrefixStore(init_method, store)
try:
from vllm.platforms import current_platform

if backend == "gloo":
return init_gloo_process_group(
return current_platform.stateless_init_device_torch_dist_pg(
backend=backend,
prefix_store=prefix_store,
group_rank=group_rank,
group_size=group_size,
timeout=timeout,
)
from vllm.platforms import current_platform

return current_platform.stateless_init_device_torch_dist_pg(
backend=backend,
prefix_store=prefix_store,
group_rank=group_rank,
group_size=group_size,
timeout=timeout,
)
except NotImplementedError:
# If platform doesn't implement stateless_init_device_torch_dist_pg, it
# will raise a NotImplementedError. In this case, we fall back to gloo.
return init_gloo_process_group(
prefix_store=prefix_store,
group_rank=group_rank,
group_size=group_size,
timeout=timeout,
)


def stateless_destroy_torch_distributed_process_group(pg: ProcessGroup) -> None:
Expand Down
34 changes: 0 additions & 34 deletions vllm/platforms/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,10 @@

import os
from collections.abc import Callable
from datetime import timedelta
from functools import cache, wraps
from typing import TYPE_CHECKING, TypeVar

import torch
from torch.distributed import PrefixStore, ProcessGroup
from torch.distributed.distributed_c10d import is_nccl_available
from typing_extensions import ParamSpec

# import custom ops, trigger op registration
Expand Down Expand Up @@ -455,37 +452,6 @@ def opaque_attention_op(cls) -> bool:
def get_static_graph_wrapper_cls(cls) -> str:
return "vllm.compilation.cuda_graph.CUDAGraphWrapper"

# Build a torch.distributed ProcessGroup and register a ProcessGroupNCCL
# backend on it for the "cuda" device, then return the group.
#
# Parameters:
#   backend      - accepted for interface compatibility; not read in this body.
#   prefix_store - store handed to both the ProcessGroup wrapper and the
#                  ProcessGroupNCCL backend.
#   group_rank   - rank passed to both the wrapper and the backend.
#   group_size   - size passed to both the wrapper and the backend.
#   timeout      - assigned to the NCCL backend options (`_timeout`).
#
# Raises AssertionError when `is_nccl_available()` is false. Note that
# `assert` statements are skipped when Python runs with -O.
@classmethod
def stateless_init_device_torch_dist_pg(
cls,
backend: str,
prefix_store: PrefixStore,
group_rank: int,
group_size: int,
timeout: timedelta,
) -> ProcessGroup:
assert is_nccl_available()
# Backend-agnostic wrapper; the concrete backend is attached further down.
pg: ProcessGroup = ProcessGroup(
prefix_store,
group_rank,
group_size,
)
# Imported locally, after the availability assert above.
from torch.distributed.distributed_c10d import ProcessGroupNCCL

backend_options = ProcessGroupNCCL.Options()
# The timeout is set through the private `_timeout` attribute of the options.
backend_options._timeout = timeout

# The backend is built from the same store, rank and size as the wrapper.
backend_class = ProcessGroupNCCL(
prefix_store, group_rank, group_size, backend_options
)
backend_type = ProcessGroup.BackendType.NCCL
device = torch.device("cuda")
# NOTE(review): these private torch calls presumably mirror what
# torch.distributed does internally when creating a group - confirm the
# required call order against the installed torch version before reordering.
pg._set_default_backend(backend_type)
backend_class._set_sequence_number_for_group()

# Attach the NCCL backend to the wrapper for the "cuda" device.
pg._register_backend(device, backend_type, backend_class)
return pg

@classmethod
def device_count(cls) -> int:
return cuda_device_count_stateless()
Expand Down
2 changes: 1 addition & 1 deletion vllm/platforms/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,7 @@ def stateless_init_device_torch_dist_pg(
"""
Init platform-specific torch distributed process group.
"""
raise RuntimeError(f"Unsupported torch distributed backend: {backend}")
raise NotImplementedError

@classmethod
def is_kv_cache_dtype_supported(
Expand Down
34 changes: 0 additions & 34 deletions vllm/platforms/rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
from datetime import timedelta
from functools import cache, lru_cache, wraps
from typing import TYPE_CHECKING

import torch
from torch.distributed import PrefixStore, ProcessGroup
from torch.distributed.distributed_c10d import is_nccl_available

import vllm.envs as envs
from vllm.logger import init_logger
Expand Down Expand Up @@ -476,37 +473,6 @@ def is_navi(cls) -> bool:
def get_static_graph_wrapper_cls(cls) -> str:
return "vllm.compilation.cuda_graph.CUDAGraphWrapper"

# Build a torch.distributed ProcessGroup and register a ProcessGroupNCCL
# backend on it for the "cuda" device, then return the group.
# NOTE(review): this sits in the ROCm platform file yet uses the NCCL backend
# class and the "cuda" device name - presumably because ROCm builds of torch
# expose their devices and collectives under those names; confirm.
#
# Parameters:
#   backend      - accepted for interface compatibility; not read in this body.
#   prefix_store - store handed to both the ProcessGroup wrapper and the
#                  ProcessGroupNCCL backend.
#   group_rank   - rank passed to both the wrapper and the backend.
#   group_size   - size passed to both the wrapper and the backend.
#   timeout      - assigned to the NCCL backend options (`_timeout`).
#
# Raises AssertionError when `is_nccl_available()` is false. Note that
# `assert` statements are skipped when Python runs with -O.
@classmethod
def stateless_init_device_torch_dist_pg(
cls,
backend: str,
prefix_store: PrefixStore,
group_rank: int,
group_size: int,
timeout: timedelta,
) -> ProcessGroup:
assert is_nccl_available()
# Backend-agnostic wrapper; the concrete backend is attached further down.
pg: ProcessGroup = ProcessGroup(
prefix_store,
group_rank,
group_size,
)
# Imported locally, after the availability assert above.
from torch.distributed.distributed_c10d import ProcessGroupNCCL

backend_options = ProcessGroupNCCL.Options()
# The timeout is set through the private `_timeout` attribute of the options.
backend_options._timeout = timeout

# The backend is built from the same store, rank and size as the wrapper.
backend_class = ProcessGroupNCCL(
prefix_store, group_rank, group_size, backend_options
)
backend_type = ProcessGroup.BackendType.NCCL
device = torch.device("cuda")
# NOTE(review): these private torch calls presumably mirror what
# torch.distributed does internally when creating a group - confirm the
# required call order against the installed torch version before reordering.
pg._set_default_backend(backend_type)
backend_class._set_sequence_number_for_group()

# Attach the NCCL backend to the wrapper for the "cuda" device.
pg._register_backend(device, backend_type, backend_class)
return pg

@classmethod
def device_count(cls) -> int:
return cuda_device_count_stateless()
Expand Down