Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 14 additions & 56 deletions vllm_ascend/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,65 +14,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from vllm.distributed.device_communicators.base_device_communicator import \
DeviceCommunicatorBase


class NPUCommunicator:
class NPUCommunicator(DeviceCommunicatorBase):

def __init__(self, group, unique_name=""):
    """Wrap a torch.distributed ProcessGroup for NPU collective ops.

    Args:
        group: the ProcessGroup all collectives in this communicator use.
        unique_name: optional label identifying this communicator instance.
    """
    self.group = group
    self.unique_name = unique_name
    # Rank of this process within `group` (torch returns the group-local
    # rank when a group is passed to dist.get_rank).
    self.rank = dist.get_rank(group)
    # Number of participating processes in the group.
    self.world_size = dist.get_world_size(self.group)
    # Global ranks of every member of the group, in group order; used by
    # gather() to translate a group-local dst into a global rank.
    self.ranks = dist.get_process_group_ranks(self.group)
    global_rank = dist.get_rank()
    # Position of this process inside the group (0..world_size-1).
    self.rank_in_group = dist.get_group_rank(self.group, global_rank)

def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
    """All-reduce ``x`` in place across the group and return it.

    Uses torch.distributed's default reduce op (SUM). The input tensor
    is mutated; the same tensor object is returned for convenience.
    """
    dist.all_reduce(x, group=self.group)
    return x

def gather(self, input_: torch.Tensor, dst: int = 0, dim: int = -1):
    """Gather ``input_`` from every rank onto rank ``dst`` of the group.

    Args:
        input_: tensor contributed by this rank; must have the same
            shape on all ranks (required by dist.gather).
        dst: group-local rank that receives the result.
        dim: dimension along which the per-rank tensors are concatenated.

    Returns:
        On rank ``dst``: the concatenated tensor. On other ranks: None.
    """
    # NOTE: We assume that the input tensor is on the same device across
    # all the ranks.
    # NOTE: `dst` is the local rank of the destination rank.
    # Allocate output tensor.
    if self.rank_in_group == dst:
        gather_list = [
            torch.empty_like(input_) for _ in range(self.world_size)
        ]
    else:
        gather_list = None
    # Gather. dist.gather expects a *global* rank, hence the translation
    # through self.ranks.
    dist.gather(input_, gather_list, dst=self.ranks[dst], group=self.group)
    if self.rank_in_group == dst:
        output_tensor = torch.cat(gather_list, dim=dim)
    else:
        output_tensor = None
    return output_tensor

def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """All-gather ``input_`` across the group, concatenated along ``dim``.

    Every rank receives the same result: the per-rank tensors joined
    along ``dim``, so the output's size on ``dim`` is
    ``world_size * input_.size(dim)``.
    """
    if dim < 0:
        # Convert negative dim to positive.
        dim += input_.dim()
    input_size = input_.size()
    # NOTE: we have to use concat-style all-gather here,
    # stack-style all-gather has compatibility issues with
    # torch.compile . see https://github.com/pytorch/pytorch/issues/138795
    output_size = (input_size[0] * self.world_size, ) + input_size[1:]
    # Allocate output tensor.
    output_tensor = torch.empty(output_size,
                                dtype=input_.dtype,
                                device=input_.device)
    # All-gather. Ranks' contributions are laid out contiguously along
    # dim 0 of the flat output buffer.
    dist.all_gather_into_tensor(output_tensor, input_, group=self.group)
    # Reshape: expose the per-rank axis, move it next to `dim`, then
    # fold it into `dim` so the result looks like a concat along `dim`.
    output_tensor = output_tensor.reshape((self.world_size, ) + input_size)
    output_tensor = output_tensor.movedim(0, dim)
    output_tensor = output_tensor.reshape(input_size[:dim] +
                                          (self.world_size *
                                           input_size[dim], ) +
                                          input_size[dim + 1:])
    return output_tensor
def __init__(self,
             cpu_group: ProcessGroup,
             device: Optional[torch.device] = None,
             device_group: Optional[ProcessGroup] = None,
             unique_name: str = ""):
    """NPU device communicator built on vLLM's DeviceCommunicatorBase.

    Args:
        cpu_group: CPU-side ProcessGroup handed to the base class.
        device: target device; NOTE(review) — overwritten below, so any
            caller-supplied value is ignored. Confirm this is intended.
        device_group: device-side ProcessGroup used to derive the rank.
        unique_name: label forwarded to the base class.
    """
    super().__init__(cpu_group, device, device_group, unique_name)
    # init device according to local rank
    # NOTE(review): dist.get_rank(device_group) returns the rank *within*
    # device_group (or the global rank when device_group is None), which
    # equals the node-local NPU index only if group ranks map 1:1 to
    # devices on this node — verify for multi-node / PP group setups.
    local_rank = dist.get_rank(device_group)
    self.device = torch.device(f"npu:{local_rank}")
18 changes: 0 additions & 18 deletions vllm_ascend/patch/__init__.py

This file was deleted.

69 changes: 0 additions & 69 deletions vllm_ascend/patch/patch_commnicator.py

This file was deleted.

2 changes: 0 additions & 2 deletions vllm_ascend/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,8 +457,6 @@ def init_worker_distributed_environment(
backend: str = "hccl") -> None:
"""Initialize the distributed environment."""
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
# register communicator patch before init dist env
from vllm_ascend import patch # noqa: F401

init_distributed_environment(parallel_config.world_size, rank,
distributed_init_method, local_rank, backend)
Expand Down