4 changes: 4 additions & 0 deletions vllm_ascend/distributed/kvpool/config_data.py
@@ -1,6 +1,7 @@
from dataclasses import dataclass
from typing import Iterable, List, Optional, Tuple, Union

import torch
from vllm.distributed.kv_transfer.kv_connector.v1.base import \
KVConnectorMetadata
from vllm.logger import logger
@@ -284,6 +285,8 @@ class ReqMeta:

is_last_chunk: Optional[bool] = None

current_event: Optional[torch.npu.Event] = None

@staticmethod
def from_request_tracker(
tracker: RequestTracker,
@@ -375,3 +378,4 @@ class LasyerMultiBlockReqMeta:
block_ids: list[int]
layer_id: int
is_last_chunk: Optional[bool] = True
current_event: Optional[torch.npu.Event] = None
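For context: the new `current_event` field lets the transfer thread wait until the device-side KV writes for a request have finished before touching the data. A minimal sketch of `torch.npu.Event` record/synchronize semantics, assuming the `torch_npu` Ascend backend is installed; the snippet is illustrative, not part of this diff:

```python
import torch
import torch_npu  # noqa: F401  -- registers the torch.npu backend (assumed installed)

# Producer side: record an event on the current NPU stream after the
# kernels that write the KV cache have been enqueued.
event = torch.npu.Event()
event.record()

# Consumer side (e.g. a send thread): block until all device work
# enqueued before the record() has finished executing.
event.synchronize()
```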
12 changes: 12 additions & 0 deletions vllm_ascend/distributed/kvpool/kv_transfer.py
@@ -114,6 +114,7 @@ def _handle_request(self, req_meta: ReqMeta):
block_ids = req_meta.block_ids
req_id = req_meta.req_id
is_last_chunk = req_meta.is_last_chunk
current_event = req_meta.current_event
starts = []
ends = []
keys = []
@@ -161,6 +162,14 @@ def _handle_request(self, req_meta: ReqMeta):
addrs.append(addr)
sizes.append(size)
if keys:
"""
Note: Due to a bug in ADXL, calling current_event.synchronize() may occasionally hang.
This issue will be fixed in CANN version 8.5.rc1.
You can manually build the master branch of the project at https://gitcode.com/cann/hixl
to resolve this issue before the 8.5.RC1 release.
"""
if current_event is not None:
current_event.synchronize()
self.m_store.put(keys, addrs, sizes)

if is_last_chunk:
@@ -235,6 +244,7 @@ def _handle_request( # type: ignore[override]
ends = req_meta.ends
keys = req_meta.keys
layer_id = req_meta.layer_id
current_event = req_meta.current_event
total_block = len(keys)
is_last_chunk = req_meta.is_last_chunk
if not self.dcp_size > 1:
@@ -270,6 +280,8 @@
addr_list.append(addr)
size_list.append(size)

if current_event is not None:
current_event.synchronize()
self.m_store.put(key_list, addr_list, size_list)

if layer_id == self.final_layer_id and is_last_chunk:
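Both `_handle_request` variants now apply the same guard: synchronize on the batch's event (when present) immediately before the blocking `m_store.put`, so the store never reads device memory that is still being written. A hedged sketch of that guard in isolation; `put_when_ready`, `store`, and `req` are stand-ins, not the PR's objects:

```python
def put_when_ready(store, req) -> None:
    # Wait for the NPU stream that produced this request's KV blocks.
    # Per the note in the diff, synchronize() may occasionally hang on
    # ADXL builds prior to CANN 8.5.RC1.
    if req.current_event is not None:
        req.current_event.synchronize()
    # Only then perform the host-side copy into the pooled KV store.
    store.put(req.keys, req.addrs, req.sizes)
```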
24 changes: 22 additions & 2 deletions vllm_ascend/distributed/kvpool/pool_worker.py
@@ -251,12 +251,20 @@ def save_kv_layer(self,
connector_metadata: AscendConnectorMetadata) -> None:
if self.current_layer == 0:
self.layerwise_storers = []
current_event = None
for request in connector_metadata.requests:
can_save = request.can_save
if can_save is None or not can_save:
continue
current_event = torch.npu.Event()
current_event.record()
break
for request in connector_metadata.requests:
can_save = request.can_save
if can_save is None or not can_save:
continue

layerwise_storer = self.store_layer(request)
layerwise_storer = self.store_layer(request, current_event)
self.layerwise_storers.append(layerwise_storer)
for layerwise_storer in self.layerwise_storers:
try:
@@ -266,11 +274,21 @@ def save_kv_layer(self,
self.current_layer = self.current_layer + 1

def wait_for_save(self, connector_metadata: AscendConnectorMetadata):
current_event = None
for request in connector_metadata.requests:
can_save = request.can_save
if can_save is None or not can_save:
continue
current_event = torch.npu.Event()
current_event.record()
break

for request in connector_metadata.requests:
can_save = request.can_save
if can_save is None or not can_save:
continue

request.current_event = current_event
self.kv_send_thread.add_request( # type: ignore[union-attr]
request, )
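Both `save_kv_layer` and `wait_for_save` use the same two-pass shape: the first loop records a single event as soon as one saveable request is found, and the second attaches that shared event to (or passes it along with) every saveable request. A compact sketch of the pattern; `attach_batch_event` is a hypothetical helper, not PR code:

```python
import torch

def attach_batch_event(requests, kv_send_thread) -> None:
    # Record once per batch, share across all saveable requests.
    saveable = [r for r in requests if r.can_save]
    if not saveable:
        return
    event = torch.npu.Event()
    event.record()  # marks this point on the current NPU stream
    for request in saveable:
        request.current_event = event
        kv_send_thread.add_request(request)
```

One event per batch suffices, presumably because all requests' KV writes are enqueued on the same stream, so synchronizing on the single recorded point covers every request's work.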

@@ -347,6 +365,7 @@ def retrieve_layer(
def store_layer(
self,
request: ReqMeta,
current_event: Optional[torch.npu.Event],
) -> Generator[None, None, None]:
"""
Store the KV cache in a layerwise manner.
Expand Down Expand Up @@ -385,7 +404,8 @@ def store_layer(
keys_multi_chunk, starts,
ends, request.block_ids,
layer_id,
request.is_last_chunk)
request.is_last_chunk,
current_event)
self.kv_send_thread.add_request( # type: ignore[union-attr, call-arg]
req_meta) # type: ignore[union-attr, call-arg, arg-type]
yield
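For orientation, `store_layer` is a generator that enqueues one layer's transfer request per `next()` call; the elided `try` block in `save_kv_layer` presumably advances each storer this way. A hypothetical driver sketch (`drive_layer` is not in the PR):

```python
def drive_layer(layerwise_storers) -> None:
    # Advance every storer by one layer; each next() enqueues that
    # layer's transfer request on the send thread.
    for storer in layerwise_storers:
        try:
            next(storer)
        except StopIteration:
            pass  # this storer has already emitted all of its layers
```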