5 changes: 0 additions & 5 deletions pyproject.toml
@@ -51,11 +51,6 @@ line-length = 120
# Folder to be modified
exclude = [
"tests/**",
# (3)
"vllm_ascend/attention/*.py",
"vllm_ascend/core/*.py",
"vllm_ascend/distributed/device_communicators/**",
"vllm_ascend/distributed/utils.py",
# (5)
"vllm_ascend/distributed/kv_transfer/kv_pool/**",
"vllm_ascend/distributed/kv_transfer/utils/**",
2 changes: 1 addition & 1 deletion vllm_ascend/attention/context_parallel/mla_cp.py
@@ -394,7 +394,7 @@ def mla_preprocess_prefill(self, q_c, kv_no_split, kv_cache, attn_metadata):
prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
prefill_kv_no_split = kv_no_split[:num_actual_tokens]
kv_c, k_pe = prefill_kv_no_split.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) # type: ignore[misc]
assert len(kv_cache) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
kv_c_normed = kv_c_normed.view([num_actual_tokens, self.num_kv_heads, -1])
k_pe = k_pe.unsqueeze(1)
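Review note: the only change in this hunk is the `# type: ignore[misc]` on the layernorm call, but the surrounding line splits the fused latent KV projection into its compressed part (`kv_c`) and its rope part (`k_pe`). A minimal standalone sketch of that split, with made-up sizes (`kv_lora_rank=512`, `qk_rope_head_dim=64` are placeholders, not values from this PR):

```python
import torch

# Illustrative sizes only; the real values come from the model config.
num_tokens, kv_lora_rank, qk_rope_head_dim = 4, 512, 64

kv_no_split = torch.randn(num_tokens, kv_lora_rank + qk_rope_head_dim)

# Same split as in mla_preprocess_prefill: last dim -> [compressed KV, rope part].
kv_c, k_pe = kv_no_split.split([kv_lora_rank, qk_rope_head_dim], dim=-1)

assert kv_c.shape == (num_tokens, kv_lora_rank)
assert k_pe.shape == (num_tokens, qk_rope_head_dim)
```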
891 changes: 407 additions & 484 deletions vllm_ascend/attention/mla_v1.py

Large diffs are not rendered by default.

541 changes: 242 additions & 299 deletions vllm_ascend/attention/sfa_v1.py

Large diffs are not rendered by default.

241 changes: 99 additions & 142 deletions vllm_ascend/core/recompute_scheduler.py

Large diffs are not rendered by default.

197 changes: 88 additions & 109 deletions vllm_ascend/core/scheduler_dynamic_batch.py

Large diffs are not rendered by default.

51 changes: 20 additions & 31 deletions vllm_ascend/distributed/device_communicators/npu_communicator.py
@@ -14,61 +14,50 @@
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from typing import List, Optional

import torch
import torch.distributed as dist
from vllm.distributed.device_communicators.base_device_communicator import \
DeviceCommunicatorBase
from vllm.distributed.device_communicators.base_device_communicator import DeviceCommunicatorBase


class NPUCommunicator(DeviceCommunicatorBase):

def __init__(self,
cpu_group: dist.ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[dist.ProcessGroup] = None,
unique_name: str = ""):
def __init__(
self,
cpu_group: dist.ProcessGroup,
device: torch.device | None = None,
device_group: dist.ProcessGroup | None = None,
unique_name: str = "",
):
super().__init__(cpu_group, device, device_group, unique_name)
# TODO(hz): Refer to CudaCommunicator's implementation to integrate PyHcclCommunicator
# init device according to rank
self.device = torch.npu.current_device()

def all_to_all(self,
input_: torch.Tensor,
scatter_dim: int = 0,
gather_dim: int = -1,
scatter_sizes: Optional[List[int]] = None,
gather_sizes: Optional[List[int]] = None) -> torch.Tensor:

def all_to_all(
self,
input_: torch.Tensor,
scatter_dim: int = 0,
gather_dim: int = -1,
scatter_sizes: list[int] | None = None,
gather_sizes: list[int] | None = None,
) -> torch.Tensor:
if scatter_dim < 0:
scatter_dim += input_.dim()
if gather_dim < 0:
gather_dim += input_.dim()

if scatter_sizes is not None and gather_sizes is not None:
input_list = [
t.contiguous()
for t in torch.split(input_, scatter_sizes, scatter_dim)
]
input_list = [t.contiguous() for t in torch.split(input_, scatter_sizes, scatter_dim)]
output_list = []
tensor_shape_base = input_list[self.rank].size()
for i in range(self.world_size):
tensor_shape = list(tensor_shape_base)
tensor_shape[gather_dim] = gather_sizes[i]
output_list.append(
torch.empty(tensor_shape,
dtype=input_.dtype,
device=input_.device))
output_list.append(torch.empty(tensor_shape, dtype=input_.dtype, device=input_.device))

else:
input_list = [
t.contiguous() for t in torch.tensor_split(
input_, self.world_size, scatter_dim)
]
output_list = [
torch.empty_like(input_list[i]) for i in range(self.world_size)
]
input_list = [t.contiguous() for t in torch.tensor_split(input_, self.world_size, scatter_dim)]
output_list = [torch.empty_like(input_list[i]) for i in range(self.world_size)]

dist.all_to_all(output_list, input_list, group=self.device_group)
output_tensor = torch.cat(output_list, dim=gather_dim).contiguous()
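Review note: the refactored `all_to_all` keeps its behavior; the non-obvious part is the receive-buffer bookkeeping in the uneven (`scatter_sizes`/`gather_sizes`) branch. A minimal single-process sketch of just that shape logic, with made-up `world_size`, `rank`, and sizes (the actual exchange still requires `dist.all_to_all` on an initialized group):

```python
import torch

# Stand-in values; in the communicator these come from the process group.
world_size, rank = 4, 1
scatter_sizes = [3, 5, 2, 6]  # chunk sent to each peer, along scatter_dim
gather_sizes = [5, 5, 5, 5]   # chunk received from each peer, along gather_dim
scatter_dim, gather_dim = 0, 0

input_ = torch.randn(sum(scatter_sizes), 8)

# Chunks this rank would scatter, one per peer.
input_list = [t.contiguous() for t in torch.split(input_, scatter_sizes, scatter_dim)]

# Receive buffers: start from this rank's own chunk shape and resize
# gather_dim to what each peer will send.
tensor_shape_base = input_list[rank].size()
output_list = []
for i in range(world_size):
    tensor_shape = list(tensor_shape_base)
    tensor_shape[gather_dim] = gather_sizes[i]
    output_list.append(torch.empty(tensor_shape, dtype=input_.dtype, device=input_.device))

# dist.all_to_all(output_list, input_list, group=...) would fill these buffers,
# after which they are concatenated along gather_dim.
print([tuple(t.shape) for t in output_list])  # four chunks of shape (5, 8)
```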
63 changes: 36 additions & 27 deletions vllm_ascend/distributed/device_communicators/pyhccl.py
@@ -15,7 +15,6 @@
# limitations under the License.
#

from typing import Optional, Union

import torch
import torch.distributed as dist
@@ -24,18 +23,23 @@
from vllm.logger import logger

from vllm_ascend.distributed.device_communicators.pyhccl_wrapper import (
HCCLLibrary, aclrtStream_t, buffer_type, hcclComm_t, hcclDataTypeEnum,
hcclRedOpTypeEnum, hcclUniqueId)
HCCLLibrary,
aclrtStream_t,
buffer_type,
hcclComm_t,
hcclDataTypeEnum,
hcclRedOpTypeEnum,
hcclUniqueId,
)
from vllm_ascend.utils import current_stream


class PyHcclCommunicator:

def __init__(
self,
group: Union[ProcessGroup, StatelessProcessGroup],
device: Union[int, str, torch.device],
library_path: Optional[str] = None,
group: ProcessGroup | StatelessProcessGroup,
device: int | str | torch.device,
library_path: str | None = None,
):
"""
Args:
@@ -52,7 +56,8 @@ def __init__(
if not isinstance(group, StatelessProcessGroup):
assert dist.is_initialized()
assert dist.get_backend(group) != dist.Backend.HCCL, (
"PyHcclCommunicator should be attached to a non-HCCL group.")
"PyHcclCommunicator should be attached to a non-HCCL group."
)
# note: this rank is the rank in the group
self.rank = dist.get_rank(group)
self.world_size = dist.get_world_size(group)
@@ -113,8 +118,7 @@ def __init__(
# `torch.npu.device` is a context manager that changes the
# current npu device to the specified one
with torch.npu.device(device):
self.comm: hcclComm_t = self.hccl.hcclCommInitRank(
self.world_size, self.unique_id, self.rank)
self.comm: hcclComm_t = self.hccl.hcclCommInitRank(self.world_size, self.unique_id, self.rank)

stream = current_stream()
# A small all_reduce for warmup.
@@ -123,43 +127,48 @@ def __init__(
stream.synchronize()
del data

def all_reduce(self,
in_tensor: torch.Tensor,
op: ReduceOp = ReduceOp.SUM,
stream=None) -> torch.Tensor:
def all_reduce(self, in_tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None) -> torch.Tensor:
if self.disabled:
return None
# hccl communicator created on a specific device
# will only work on tensors on the same device
# otherwise it will cause "illegal memory access"
assert in_tensor.device == self.device, (
f"this hccl communicator is created to work on {self.device}, "
f"but the input tensor is on {in_tensor.device}")
f"this hccl communicator is created to work on {self.device}, but the input tensor is on {in_tensor.device}"
)

out_tensor = torch.empty_like(in_tensor)

if stream is None:
stream = current_stream()
self.hccl.hcclAllReduce(buffer_type(in_tensor.data_ptr()),
buffer_type(out_tensor.data_ptr()),
in_tensor.numel(),
hcclDataTypeEnum.from_torch(in_tensor.dtype),
hcclRedOpTypeEnum.from_torch(op), self.comm,
aclrtStream_t(stream.npu_stream))
self.hccl.hcclAllReduce(
buffer_type(in_tensor.data_ptr()),
buffer_type(out_tensor.data_ptr()),
in_tensor.numel(),
hcclDataTypeEnum.from_torch(in_tensor.dtype),
hcclRedOpTypeEnum.from_torch(op),
self.comm,
aclrtStream_t(stream.npu_stream),
)
return out_tensor

def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
if self.disabled:
return
assert tensor.device == self.device, (
f"this hccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
f"this hccl communicator is created to work on {self.device}, but the input tensor is on {tensor.device}"
)
if stream is None:
stream = current_stream()
if src == self.rank:
buffer = buffer_type(tensor.data_ptr())
else:
buffer = buffer_type(tensor.data_ptr())
self.hccl.hcclBroadcast(buffer, tensor.numel(),
hcclDataTypeEnum.from_torch(tensor.dtype), src,
self.comm, aclrtStream_t(stream.npu_stream))
self.hccl.hcclBroadcast(
buffer,
tensor.numel(),
hcclDataTypeEnum.from_torch(tensor.dtype),
src,
self.comm,
aclrtStream_t(stream.npu_stream),
)
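Review note: a hedged usage sketch for the reformatted `PyHcclCommunicator`. Only the constructor, `all_reduce`, `broadcast`, `device`, and `disabled` come from this diff; the `gloo` backend choice and the `npu:{rank}` device string are assumptions, and this runs only on Ascend hardware with `torch_npu` installed and the usual MASTER_ADDR/MASTER_PORT rendezvous setup:

```python
import torch
import torch.distributed as dist

from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator

# Per the assert in __init__, the attached group must be non-HCCL,
# so a gloo control-plane group is one valid choice.
dist.init_process_group(backend="gloo")
rank = dist.get_rank()

comm = PyHcclCommunicator(group=dist.group.WORLD, device=f"npu:{rank}")

x = torch.ones(16, device=comm.device)
if not comm.disabled:
    y = comm.all_reduce(x)    # defaults to ReduceOp.SUM on the current stream
    comm.broadcast(y, src=0)  # in-place broadcast from rank 0
```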