Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions vllm_ascend/compilation/acl_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import dataclasses
import weakref
from collections.abc import Callable
from contextlib import ExitStack
from dataclasses import dataclass
from typing import Any
from typing import Any, ClassVar
from unittest.mock import patch

import torch
Expand Down Expand Up @@ -60,6 +61,16 @@ class ACLGraphWrapper:
guaranteed when VLLM_LOGGING_LEVEL == "DEBUG".
"""

_all_instances: ClassVar[weakref.WeakSet["ACLGraphWrapper"]] = weakref.WeakSet()
graph_pool: ClassVar[tuple[int, int]] = current_platform.get_global_graph_pool()

@classmethod
def clear_all_graphs(cls) -> None:
    """Drop the captured graphs of every tracked wrapper, then advance the pool id.

    The contents of ``cls._all_instances`` are copied into a list before the
    loop, so the loop walks that snapshot rather than the live collection.
    Once each wrapper has run ``clear_graphs()``, ``cls.graph_pool`` is
    rebound to a tuple whose first element is unchanged and whose second
    element is one larger than before.
    """
    tracked_wrappers = list(cls._all_instances)
    for wrapper in tracked_wrappers:
        wrapper.clear_graphs()

    pool_head = cls.graph_pool[0]
    pool_tail = cls.graph_pool[1]
    cls.graph_pool = (pool_head, pool_tail + 1)
Comment on lines +67 to +72
Copy link

Copilot AI Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

clear_all_graphs is a new public classmethod that is now used during elastic EP reconfiguration, but there’s no unit test asserting that instances are registered and that clear_all_graphs() actually clears concrete_aclgraph_entries / bumps the graph pool. Adding a focused test in tests/ut/compilation/test_acl_graph.py would prevent regressions here.

Copilot uses AI. Check for mistakes.

def __init__(
self,
runnable: Callable,
Expand All @@ -79,7 +90,6 @@ def __init__(
# assert runtime_mode is not NONE(no aclgraph), otherwise, we don't
# need to initialize a ACLGraphWrapper.
assert self.runtime_mode != CUDAGraphMode.NONE
self.graph_pool = current_platform.get_global_graph_pool()

if cudagraph_options is None:
cudagraph_options = CUDAGraphOptions()
Expand All @@ -88,6 +98,8 @@ def __init__(
# aclgraphs for.
self.concrete_aclgraph_entries: dict[BatchDescriptor, ACLGraphEntry] = {}

ACLGraphWrapper._all_instances.add(self)

def __getattr__(self, key: str):
# allow accessing the attributes of the runnable.
if hasattr(self.runnable, key):
Expand All @@ -102,6 +114,13 @@ def unwrap(self) -> Callable:
# in case we need to access the original runnable.
return self.runnable

def clear_graphs(self) -> None:
    """Reset every stored graph and empty ``concrete_aclgraph_entries``.

    For each entry the graph's ``reset()`` is called, after which the
    ``aclgraph``, ``batch_descriptor``, ``output`` and ``input_addresses``
    attributes are deleted from the entry. The mapping itself is cleared
    only after all entries have been processed.
    """
    for captured in self.concrete_aclgraph_entries.values():
        captured.aclgraph.reset()
        # Remove the entry's attributes explicitly, one reference at a time.
        del captured.aclgraph
        del captured.batch_descriptor
        del captured.output
        del captured.input_addresses
    self.concrete_aclgraph_entries.clear()

def __call__(self, *args, **kwargs):
forward_context = get_forward_context()
batch_descriptor = forward_context.batch_descriptor
Expand Down Expand Up @@ -149,7 +168,7 @@ def __call__(self, *args, **kwargs):

# mind-exploding: carefully manage the reference and memory.
forward_context.capturing = True
with torch.npu.graph(aclgraph, pool=self.graph_pool):
with torch.npu.graph(aclgraph, pool=ACLGraphWrapper.graph_pool):
# `output` is managed by pytorch's aclgraph pool
output = self.runnable(*args, **kwargs)
if self.aclgraph_options.weak_ref_output:
Expand Down
33 changes: 32 additions & 1 deletion vllm_ascend/distributed/device_communicators/npu_communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import torch
import torch.distributed as dist
from vllm.distributed.device_communicators.base_device_communicator import DeviceCommunicatorBase
from vllm.distributed.utils import StatelessProcessGroup


class NPUCommunicator(DeviceCommunicatorBase):
Expand All @@ -27,12 +28,30 @@ def __init__(
device: torch.device | None = None,
device_group: dist.ProcessGroup | None = None,
unique_name: str = "",
global_ranks: list[int] | None = None,
global_world_size: int | None = None,
tcp_store_group: StatelessProcessGroup | None = None,
):
super().__init__(cpu_group, device, device_group, unique_name)
super().__init__(
cpu_group,
device,
device_group,
unique_name,
global_ranks,
global_world_size,
)
# TODO(hz): Refer to CudaCommunicator's implementation to integrate PyHcclCommunicator
# init device according to rank
self.device = torch.npu.current_device()

from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator

self.pyhccl_comm: PyHcclCommunicator | None = None
if self.world_size > 1:
self.pyhccl_comm = PyHcclCommunicator(
group=self.cpu_group if tcp_store_group is None else tcp_store_group, device=self.device
)

def all_to_all(
self,
input_: torch.Tensor,
Expand Down Expand Up @@ -62,3 +81,15 @@ def all_to_all(
dist.all_to_all(output_list, input_list, group=self.device_group)
output_tensor = torch.cat(output_list, dim=gather_dim).contiguous()
return output_tensor

def destroy(self):
    """Destroy the PyHccl communicator, if one exists, and forget it.

    When ``self.pyhccl_comm`` is ``None`` nothing happens; otherwise its
    ``destroy()`` is invoked and the attribute is set to ``None``.
    """
    if self.pyhccl_comm is None:
        return
    self.pyhccl_comm.destroy()
    self.pyhccl_comm = None

def batch_isend_irecv(self, p2p_ops: list):
    """Forward a batch of point-to-point ops to the PyHccl communicator.

    Args:
        p2p_ops: The operations to hand to ``pyhccl_comm.batch_isend_irecv``.

    Raises:
        ValueError: If ``self.pyhccl_comm`` is ``None`` or reports itself
            as disabled.
    """
    comm = self.pyhccl_comm
    # Guard first: without a usable communicator there is nothing to delegate to.
    if comm is None or comm.disabled:
        raise ValueError("No PyHccl communicator found")
    comm.batch_isend_irecv(p2p_ops)
57 changes: 51 additions & 6 deletions vllm_ascend/distributed/device_communicators/pyhccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,12 @@ def __init__(
with torch.npu.device(device):
self.comm: hcclComm_t = self.hccl.hcclCommInitRank(self.world_size, self.unique_id, self.rank)

stream = current_stream()
# A small all_reduce for warmup.
data = torch.zeros(1, device=device)
self.all_reduce(data)
stream.synchronize()
del data
# Tear down this communicator: call hcclCommDestroy on self.comm and mark the
# object as unavailable and disabled.
def destroy(self):
# Skip the native destroy call unless the communicator is available and not disabled.
if self.available and not self.disabled:
# Run the destroy call inside torch.accelerator.device_index(self.device.index).
with torch.accelerator.device_index(self.device.index):
self.hccl.hcclCommDestroy(self.comm)
# NOTE(review): indentation is lost in this view, so it cannot be determined
# here whether the two flag updates below belong to the `if` body or run
# unconditionally - confirm against the real file before relying on either.
self.available = False
self.disabled = True

def all_reduce(self, in_tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None) -> torch.Tensor:
if self.disabled:
Expand All @@ -152,6 +152,40 @@ def all_reduce(self, in_tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, strea
)
return out_tensor

def send(self, tensor: torch.Tensor, dst: int, stream=None):
    """Issue an HCCL send of ``tensor`` to peer ``dst``.

    Args:
        tensor: Data to send; its device must equal ``self.device``.
        dst: Peer rank forwarded to ``hcclSend``.
        stream: Stream to issue the call on; ``current_stream()`` is used
            when this is ``None``.

    Returns:
        ``None``. When the communicator is disabled the call is a no-op.
    """
    if self.disabled:
        return None
    assert tensor.device == self.device, (
        f"this hccl communicator is created to work on {self.device}, but the tensor is on {tensor.device}"
    )
    active_stream = current_stream() if stream is None else stream

    # Marshal each argument into the form the native binding expects.
    payload = buffer_type(tensor.data_ptr())
    element_count = tensor.numel()
    hccl_dtype = hcclDataTypeEnum.from_torch(tensor.dtype)
    native_stream = aclrtStream_t(active_stream.npu_stream)
    self.hccl.hcclSend(payload, element_count, hccl_dtype, dst, self.comm, native_stream)

def recv(self, tensor: torch.Tensor, src: int, stream=None):
    """Issue an HCCL receive into ``tensor`` from peer ``src``.

    Args:
        tensor: Destination buffer; its device must equal ``self.device``.
        src: Peer rank forwarded to ``hcclRecv``.
        stream: Stream to issue the call on; ``current_stream()`` is used
            when this is ``None``.

    Returns:
        ``None``. When the communicator is disabled the call is a no-op.
    """
    if self.disabled:
        return None
    assert tensor.device == self.device, (
        f"this hccl communicator is created to work on {self.device}, but the tensor is on {tensor.device}"
    )
    active_stream = current_stream() if stream is None else stream

    # Marshal each argument into the form the native binding expects.
    landing_buffer = buffer_type(tensor.data_ptr())
    element_count = tensor.numel()
    hccl_dtype = hcclDataTypeEnum.from_torch(tensor.dtype)
    native_stream = aclrtStream_t(active_stream.npu_stream)
    self.hccl.hcclRecv(landing_buffer, element_count, hccl_dtype, src, self.comm, native_stream)

def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
if self.disabled:
return
Expand All @@ -172,3 +206,14 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
self.comm,
aclrtStream_t(stream.npu_stream),
)

def batch_isend_irecv(self, p2p_ops: list, stream=None):
    """Run a list of point-to-point ops through ``send`` / ``recv`` in order.

    Args:
        p2p_ops: Objects exposing ``op``, ``tensor`` and ``group_peer``.
            An op whose ``op`` is ``torch.distributed.isend`` is sent, one
            whose ``op`` is ``torch.distributed.irecv`` is received, and
            any other op is skipped.
        stream: Stream passed to every transfer; ``current_stream()`` is
            used when this is ``None``.

    Returns:
        ``None``. When the communicator is disabled the call is a no-op.
    """
    if self.disabled:
        return
    if stream is None:
        stream = current_stream()

    for request in p2p_ops:
        if request.op is torch.distributed.isend:
            transfer = self.send
        elif request.op is torch.distributed.irecv:
            transfer = self.recv
        else:
            # Neither a send nor a receive: leave it untouched.
            continue
        transfer(request.tensor, request.group_peer, stream)
54 changes: 54 additions & 0 deletions vllm_ascend/distributed/device_communicators/pyhccl_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,38 @@ class HCCLLibrary:
aclrtStream_t,
],
),
# HcclResult HcclSend(
# void *sendBuf, uint64_t count,
# HcclDataType dataType, uint32_t destRank,
# HcclComm comm, aclrtStream stream);
Function(
"HcclSend",
hcclResult_t,
[
buffer_type,
ctypes.c_size_t,
hcclDataType_t,
ctypes.c_int,
hcclComm_t,
aclrtStream_t,
],
),
# HcclResult HcclRecv(
# void *recvBuf, uint64_t count,
# HcclDataType dataType, uint32_t srcRank,
# HcclComm comm, aclrtStream stream);
Function(
"HcclRecv",
hcclResult_t,
[
buffer_type,
ctypes.c_size_t,
hcclDataType_t,
ctypes.c_int,
hcclComm_t,
aclrtStream_t,
],
),
# HcclResult HcclBroadcast(
# void *buf, uint64_t count,
# HcclDataType dataType, uint32_t root,
Expand Down Expand Up @@ -243,6 +275,28 @@ def hcclAllReduce(
# by ctypes automatically
self.HCCL_CHECK(self._funcs["HcclAllReduce"](sendbuff, recvbuff, count, datatype, op, comm, stream))

def hcclSend(
    self,
    sendbuff: buffer_type,
    count: int,
    datatype: int,
    dest: int,
    comm: hcclComm_t,
    stream: aclrtStream_t,
) -> None:
    """Call the loaded ``HcclSend`` symbol and check its result.

    The arguments are forwarded positionally, in the order given, to the
    function registered under ``"HcclSend"`` in ``self._funcs``; the value
    it returns is passed to ``self.HCCL_CHECK``.
    """
    native_send = self._funcs["HcclSend"]
    status = native_send(sendbuff, count, datatype, dest, comm, stream)
    self.HCCL_CHECK(status)

def hcclRecv(
    self,
    sendbuff: buffer_type,
    count: int,
    datatype: int,
    dest: int,
    comm: hcclComm_t,
    stream: aclrtStream_t,
) -> None:
    """Call the loaded ``HcclRecv`` symbol and check its result.

    The arguments are forwarded positionally, in the order given, to the
    function registered under ``"HcclRecv"`` in ``self._funcs``; the value
    it returns is passed to ``self.HCCL_CHECK``.

    NOTE(review): the parameter names ``sendbuff`` and ``dest`` mirror
    ``hcclSend``; for a receive they presumably denote the receive buffer
    and the source rank. The names are kept so keyword callers keep working.
    """
    native_recv = self._funcs["HcclRecv"]
    status = native_recv(sendbuff, count, datatype, dest, comm, stream)
    self.HCCL_CHECK(status)

def hcclBroadcast(
self, buf: buffer_type, count: int, datatype: int, root: int, comm: hcclComm_t, stream: aclrtStream_t
) -> None:
Expand Down
Empty file.
Loading