Merged
31 changes: 1 addition & 30 deletions tests/ut/kv_connector/test_mooncake_layerwise_connector.py
@@ -170,36 +170,6 @@ def test_transfer_skips_when_no_local_blocks(self):
        self.thread._transfer_kv_cache(send_task)
        self.engine.batch_transfer_sync_write.assert_not_called()

    def test_transfer_skips_when_tp_not_sender(self):

        thread = KVCacheSendingLayerThread(
            engine=self.engine,
            total_layers=2,
            ready_event=self.ready_event,
            tp_rank=1,
            pd_head_ratio=1,
            num_head_replica=2,
            kv_cache_base_addr=[1000, 2000, 3000, 4000],
            use_mla=False,
            block_len=[1024],
            decode_tp_size=1,
            first_kv_cache=self.first_kv_cache,
            k_buffer=MagicMock(),
            v_buffer=MagicMock(),
            resharding_stream=MagicMock(),
            callback_func=MagicMock())
        req_meta = self.req_meta_base
        send_task = SendTask(
            send_request={"req3": req_meta},
            wait_event=MagicMock(),
            k_cache=self.key,
            v_cache=self.value,
            layer_idx=1,
            rearrange_block_ids=[],
        )
        thread._transfer_kv_cache(send_task)
        self.engine.batch_transfer_sync_write.assert_not_called()

    @patch(
        "vllm_ascend.distributed.mooncake_layerwise_connector.group_concurrent_contiguous",
        side_effect=group_concurrent_contiguous)
@@ -425,6 +395,7 @@ def __init__(self):
        self.parallel_config.data_parallel_size = 1
        self.parallel_config.data_parallel_rank = 0
        self.cache_config.block_size = 16
        self.model_config.hf_config.num_key_value_heads = 1

        self.kv_transfer_config.engine_id = "test_engine"
        self.kv_transfer_config.kv_port = 5000
6 changes: 6 additions & 0 deletions vllm_ascend/attention/sfa_v1.py
@@ -109,6 +109,7 @@ class AscendSFAMetadata:
    # chunked prefill by default if no attn_states passed
    attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
    sfa_cp_context: Optional[SfaCpContext] = None
    reshape_cache_event: torch.npu.Event = None


M = TypeVar("M", bound=AscendSFAMetadata)
@@ -371,6 +372,7 @@ def __init__(
        self.enable_sfa_cp = enable_sp()
        self.local_num_heads = self.num_heads
        self.vllm_config = get_current_vllm_config()
        self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
        if self.enable_sfa_cp:
            self.local_num_heads = self.num_heads * self.tp_size

@@ -901,11 +903,15 @@ def indexer_select(
            k = get_tp_group().all_gather(k, 0)

        if kv_cache is not None:
            if self.is_kv_producer:
                attn_metadata.reshape_cache_event = torch.npu.Event()
            torch_npu.npu_scatter_nd_update_(
                kv_cache[2].view(-1, k.shape[-1]),
                attn_metadata.slot_mapping.view(-1, 1),
                k.view(-1, k.shape[-1]))  # b, s, n, d
            if self.is_kv_producer:
                attn_metadata.reshape_cache_event.record()

        weights, _ = self.weights_proj(x)
        weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
222 changes: 138 additions & 84 deletions vllm_ascend/distributed/mooncake_layerwise_connector.py
@@ -162,6 +162,7 @@ def __init__(
        self.kv_caches_base_addr = kv_cache_base_addr
        self.total_layers = total_layers
        self.use_mla = use_mla
        self.use_sparse = len(block_len) == 3
        self.block_len = block_len
        self._decode_tp_size = decode_tp_size
        self.resharding_stream = resharding_stream
@@ -195,17 +196,6 @@ def get_transfer_meta(self, send_task: SendTask, req_id: str,
        src_list: list[str] = []
        dst_list: list[str] = []
        length_list: list[int] = []
        # no need to send kv cache
        if self.tp_rank % self.num_head_replica != 0:
            logger.debug(
                f"Cancelling KV cache transfer for request {req_id}. Reason: TP rank excluded from head replication (TP Rank: {self.tp_rank}, Replicas: {self.num_head_replica})."
            )
            return (src_list, dst_list, length_list)
        if self.use_mla and self.tp_rank >= self._decode_tp_size:
            logger.debug(
                f"Cancelling KV cache transfer for request {req_id}. Reason: MLA mode active and TP rank outside decoding group (TP Rank: {self.tp_rank}, Decode TP Size: {self._decode_tp_size})."
            )
            return (src_list, dst_list, length_list)

        layer_idx = send_task.layer_idx
        remote_block_ids = req_meta.remote_block_ids
@@ -214,21 +204,36 @@
        local_block_ids = req_meta.local_block_ids

        if self.pd_head_ratio == 1:
            layer_local_kv_base_addr = [
                local_kv_base_addr[i]
                for i in [2 * layer_idx, 2 * layer_idx + 1]
            ]
            layer_remote_kv_base_addr = [
                remote_kv_base_addrs[i]  # type:ignore
                for i in [2 * layer_idx, 2 * layer_idx + 1]
            ]
            if self.use_sparse:
                layer_local_kv_base_addr = [
                    local_kv_base_addr[i] for i in
                    [3 * layer_idx, 3 * layer_idx + 1, 3 * layer_idx + 2]
                ]
                layer_remote_kv_base_addr = [
                    remote_kv_base_addrs[i]  # type:ignore
                    for i in
                    [3 * layer_idx, 3 * layer_idx + 1, 3 * layer_idx + 2]
                ]
            else:
                layer_local_kv_base_addr = [
                    local_kv_base_addr[i]
                    for i in [2 * layer_idx, 2 * layer_idx + 1]
                ]
                layer_remote_kv_base_addr = [
                    remote_kv_base_addrs[i]  # type:ignore
                    for i in [2 * layer_idx, 2 * layer_idx + 1]
                ]
            grouped_remote_block_ids, grouped_local_block_ids = \
                group_concurrent_contiguous(remote_block_ids, local_block_ids)

            for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
                    zip(layer_local_kv_base_addr, layer_remote_kv_base_addr)):
                block_len = self.block_len[
                    k % 2] if self.use_mla else self.block_len[0]
                if self.use_mla:
                    block_len = (self.block_len[k % 2])
                elif self.use_sparse:
                    block_len = (self.block_len[k % 3])
Contributor review comment (critical):

There's an indexing error when calculating block_len for the sparse attention case. When self.use_sparse is true, layer_local_kv_base_addr is constructed using the first and third elements of the per-layer cache addresses (indices 3 * layer_idx and 3 * layer_idx + 2). However, self.block_len is indexed with k % 3, which results in using self.block_len[0] and self.block_len[1]. This is incorrect: it should use self.block_len[0] and self.block_len[2] to match the base addresses. This will cause incorrect transfer lengths.

Suggested change:
                    block_len = (self.block_len[k % 3])
becomes:
                    block_len = (self.block_len[k * 2])
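To make the arithmetic in this comment concrete, here is a minimal standalone sketch; the block lengths and the two-region selection are illustrative assumptions, not code from this PR:

# Hypothetical per-layer block lengths for the sparse layout:
# index 0 = norm region, 1 = pe region, 2 = k region.
block_len = [1024, 512, 256]

# Scenario described in the review: only the first and third per-layer
# regions are selected, so the enumerate loop sees two base addresses.
selected_regions = [0, 2]

for k, region in enumerate(selected_regions):
    mismatched = block_len[k % 3]  # k=1 -> block_len[1] (pe), wrong region
    matched = block_len[k * 2]     # k=1 -> block_len[2] (k), as intended
    print(f"region {region}: k % 3 -> {mismatched}, k * 2 -> {matched}")
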
                else:
                    block_len = (self.block_len[0])
                for group_remote_block_id, group_local_block_id in zip(
                        grouped_remote_block_ids, grouped_local_block_ids):
                    src = src_layer_base_addr + group_local_block_id[
@@ -931,7 +936,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):

        # TODO(tms): Find a more robust way to detect and handle MLA
        self.use_mla = first_kv_cache_tuple[0].size(
            -1) != first_kv_cache_tuple[1].size(-1)
            -1) != first_kv_cache_tuple[1].size(-1) and len(
                first_kv_cache_tuple) == 2
        self.use_sparse = len(first_kv_cache_tuple) == 3
        if self.use_mla:
            # MLA case. [num_block, block_size, 1, hidden_dim]
            self.num_blocks = first_kv_cache.shape[0]
@@ -945,6 +952,21 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
            logger.info(
                "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
                self.num_blocks, block_shape_norm, block_shape_pe)
        elif self.use_sparse:
            self.num_blocks = first_kv_cache.shape[0]
            block_rank = 3  # [block_size, latent_dim]
            block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
            block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
            block_shape_k = first_kv_cache_tuple[2].shape[-block_rank:]
            self.block_len = [
                first_kv_cache[0].element_size() * math.prod(block_shape_norm),
                first_kv_cache[1].element_size() * math.prod(block_shape_pe),
                first_kv_cache[2].element_size() * math.prod(block_shape_k)
            ]
            logger.info(
                "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s, block_shape_k: %s",
                self.num_blocks, block_shape_norm, block_shape_pe,
                block_shape_k)
        else:
            # [num_block, block_size, num_head, hidden_dim]
            self.num_blocks = first_kv_cache.shape[0]
@@ -955,8 +977,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
logger.info("num_blocks: %s, block_shape: %s", self.num_blocks,
block_shape)

logger.info("Registering KV_Caches. use_mla: %s, shape %s",
self.use_mla, first_kv_cache.shape)
logger.info(
"Registering KV_Caches. use_mla: %s, use_sparse: %s, shape %s",
self.use_mla, self.use_sparse, first_kv_cache.shape)

self.kv_caches = kv_caches
kv_caches_base_addr = []
@@ -971,9 +994,17 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
                kv_caches_base_addr.append(base_addr)
                ptrs.append(base_addr)
                lengths.append(region_len)
            elif self.use_sparse:
                for i, cache in enumerate(cache_or_caches, 0):
                    base_addr = cache.data_ptr()
                    region_len = self.num_blocks * self.block_len[i % 3]
                    kv_caches_base_addr.append(base_addr)
                    ptrs.append(base_addr)
                    lengths.append(region_len)
            else:
                cache_list = [cache_or_caches
                              ] if self.use_mla else cache_or_caches
                cache_list = [
                    cache_or_caches
                ] if self.use_mla or self.use_sparse else cache_or_caches
Contributor review comment on lines +1005 to +1007 (critical):

The logic for creating cache_list in the non-MLA, non-sparse case is incorrect. In this else block, cache_or_caches is a single tensor, and the expression [cache_or_caches] if self.use_mla or self.use_sparse else cache_or_caches evaluates to cache_list = cache_or_caches. The subsequent loop for cache in cache_list: will then iterate over the rows of the tensor, which is not the intended behavior and will lead to incorrect memory registration. cache_list should be a list containing the single tensor:

                cache_list = [cache_or_caches]

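A small self-contained sketch of the pitfall described above; the tensor shape is an illustrative stand-in, not the PR's actual KV cache:

import torch

use_mla = use_sparse = False
cache_or_caches = torch.zeros(4, 2)  # stand-in for a single KV cache tensor

# With both flags false the conditional yields the tensor itself, so the
# registration loop would iterate over its 4 rows, one "cache" per row.
cache_list = [cache_or_caches] if use_mla or use_sparse else cache_or_caches
print(len(cache_list))  # 4

# Wrapping unconditionally registers the whole tensor exactly once.
cache_list = [cache_or_caches]
print(len(cache_list))  # 1
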
                for cache in cache_list:
                    base_addr = cache.data_ptr()
                    region_len = self.num_blocks * self.block_len[0]
@@ -1046,56 +1077,72 @@ def save_kv_layer(self, layer_name: str, kv_layer: Tuple[torch.Tensor,
        if self.vllm_config.kv_transfer_config.is_kv_producer and connector_metadata.requests.keys(
        ):
            # enable decode prefix cache
            if self.use_mla:
                reshape_cache_event = attn_metadata[
                    layer_name].reshape_cache_event
            else:
                reshape_cache_event = attn_metadata.reshape_cache_event

            if self.pd_head_ratio != 1:
                assert self.resharding_stream is not None
                with npu_stream_switch(self.resharding_stream):
                    reshape_cache_event.wait()
                    rearrange_block_ids = sorted({
                        block_id
                        for request in connector_metadata.requests.values()
                        for block_id in request.local_block_ids
                    })

                    keys = kv_layer[0][rearrange_block_ids].clone()
                    values = kv_layer[1][rearrange_block_ids].clone()
                    # sort kv caches for each block
                    keys = keys.view(keys.size(0), self.pd_head_ratio, -1,
                                     *keys.shape[2:]).transpose(
                                         0, 1).reshape_as(keys)
                    values = values.view(values.size(0), self.pd_head_ratio,
                                         -1, *values.shape[2:]).transpose(
                                             0, 1).reshape_as(values)
                    # reshard kv cache
                    keys = keys.reshape(-1, *kv_layer[0].shape[2:])
                    values = values.reshape(-1, *kv_layer[1].shape[2:])
                    (keys, values) = kv_alltoall_and_rearrange(
                        self.pd_head_ratio, keys, values)
            else:
                keys = None
                values = None
                rearrange_block_ids = None

            assert self.kv_send_layer_thread is not None
            assert reshape_cache_event is not None
            send_task = SendTask(wait_event=reshape_cache_event,
                                 k_cache=keys,
                                 v_cache=values,
                                 layer_idx=self.current_layer,
                                 rearrange_block_ids=rearrange_block_ids)
            for req_id, req_meta in connector_metadata.requests.items():
                req_meta_update = self.update_decoder_info(req_id, req_meta)
                logger.debug(
                    f"Add request {req_id} to kv send layer thread. {req_meta_update=}"
                )
                send_task.send_request[req_id] = req_meta_update

            self.kv_send_layer_thread.send_queue.put(send_task)
            if self.use_mla or self.use_sparse:
                num_kv_head = self._decode_tp_size
            else:
                num_kv_head = self.vllm_config.model_config.hf_config.num_key_value_heads
            num_replica_groups = self.tp_size // num_kv_head if self.tp_size >= num_kv_head else 1
            replica_group_idx = self.tp_rank % num_replica_groups
            req_ids = sorted(list(connector_metadata.requests.keys()))
            selected_req_ids = [
                req_id for i, req_id in enumerate(req_ids)
                if i % num_replica_groups == replica_group_idx
            ]
            if selected_req_ids:
                if self.use_mla or self.use_sparse:
                    reshape_cache_event = attn_metadata[
                        layer_name].reshape_cache_event
                else:
                    reshape_cache_event = attn_metadata.reshape_cache_event

                if self.pd_head_ratio != 1:
                    assert self.resharding_stream is not None
                    with npu_stream_switch(self.resharding_stream):
                        reshape_cache_event.wait()
                        rearrange_block_ids = sorted({
                            block_id
                            for req_id in selected_req_ids
                            for block_id in
                            connector_metadata.requests[req_id].local_block_ids
                        })

                        keys = kv_layer[0][rearrange_block_ids].clone()
                        values = kv_layer[1][rearrange_block_ids].clone()
                        # sort kv caches for each block
                        keys = keys.view(keys.size(0), self.pd_head_ratio, -1,
                                         *keys.shape[2:]).transpose(
                                             0, 1).reshape_as(keys)
                        values = values.view(values.size(0),
                                             self.pd_head_ratio, -1,
                                             *values.shape[2:]).transpose(
                                                 0, 1).reshape_as(values)
                        # reshard kv cache
                        keys = keys.reshape(-1, *kv_layer[0].shape[2:])
                        values = values.reshape(-1, *kv_layer[1].shape[2:])
                        (keys, values) = kv_alltoall_and_rearrange(
                            self.pd_head_ratio, keys, values)
                else:
                    keys = None
                    values = None
                    rearrange_block_ids = None

                assert self.kv_send_layer_thread is not None
                assert reshape_cache_event is not None
                send_task = SendTask(wait_event=reshape_cache_event,
                                     k_cache=keys,
                                     v_cache=values,
                                     layer_idx=self.current_layer,
                                     rearrange_block_ids=rearrange_block_ids)
                for req_id, req_meta in connector_metadata.requests.items():
                    if req_id in selected_req_ids:
                        req_meta_update = self.update_decoder_info(
                            req_id, req_meta)
                        logger.debug(
                            f"Add request {req_id} to kv send layer thread. {req_meta_update=}"
                        )
                        send_task.send_request[req_id] = req_meta_update

                self.kv_send_layer_thread.send_queue.put(send_task)
        self.current_layer += 1

    def _get_remote_socket(
@@ -1121,8 +1168,13 @@

    def update_decoder_info(self, req_id, req_meta):
        req_meta_update = copy.deepcopy(req_meta)
        req_meta_update.remote_port = req_meta_update.remote_port + (
            self.tp_rank // self.pd_tp_ratio) % self._decode_tp_size
        if self.use_mla or self.use_sparse:
            pd_tp_ratio = self.tp_size // self._decode_tp_size
            req_meta_update.remote_port = req_meta_update.remote_port + (
                self.tp_rank // pd_tp_ratio) % self._decode_tp_size
        else:
            req_meta_update.remote_port = req_meta_update.remote_port + (
                self.tp_rank // self.pd_tp_ratio) % self._decode_tp_size
        if req_meta_update.remote_engine_id not in self.remote_kv_caches_base_addr or \
                req_meta_update.remote_port not in self.remote_kv_caches_base_addr[req_meta_update.remote_engine_id]:
            try:
@@ -1146,14 +1198,16 @@ def update_decoder_info(self, req_id, req_meta):
                logger.info(
                    f"Query to port and kv base addr for request {req_id} from {req_meta_update.remote_host}:{req_meta_update.remote_port} success {agent_meta.kv_caches_base_addr=} {agent_meta.te_rpc_port=}"
                )
                session_id = f"{req_meta_update.remote_host}:{agent_meta.te_rpc_port}"
                ret = self.engine.batch_transfer_sync_write(
                    session_id, [self.kv_caches_base_addr[0]],
                    [agent_meta.kv_caches_base_addr[0]], [128])
                if ret < 0:
                    logger.error(
                        f"Mooncake transfer failed to create link to device {session_id}"
                    )
                if self.pd_head_ratio > 1:
                    # for unequal TP sizes, pre-create the link to prevent alltoall out of memory
                    session_id = f"{req_meta_update.remote_host}:{agent_meta.te_rpc_port}"
                    ret = self.engine.batch_transfer_sync_write(
                        session_id, [self.kv_caches_base_addr[0]],
                        [agent_meta.kv_caches_base_addr[0]], [128])
                    if ret < 0:
                        logger.error(
                            f"Mooncake transfer failed to create link to device {session_id}"
                        )
        req_meta_update.remote_te_rpc_port = self.remote_te_port[
            req_meta_update.remote_engine_id][req_meta_update.remote_port]
        req_meta_update.remote_kv_caches_base_addr = self.remote_kv_caches_base_addr[