Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 21 additions & 67 deletions vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,6 @@ def __init__(
self.block_len = block_len
# TODO(jianzs): find a better way to detect MLA.
self.use_mla = len(block_len) == 2
self.use_sparse = len(block_len) == 3

self.request_queue: queue.Queue[Any] = queue.Queue()
self.executor = ThreadPoolExecutor(max_workers=32)
Expand Down Expand Up @@ -529,15 +528,11 @@ def _transfer_kv_cache(self, req_meta: dict[str, Any]):

req_start_time = time.perf_counter()
src_list, dst_list, length_list = [], [], []
block_length = len(self.block_len)
for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
zip(local_kv_caches_base_addrs, remote_kv_caches_base_addrs)
):
if self.use_mla:
block_len = self.block_len[k % 2]
elif self.use_sparse:
block_len = self.block_len[k % 3]
else:
block_len = self.block_len[0]
block_len = self.block_len[k % block_length]
inner_block_len = block_len // tp_num_need_pulls
for remote_block_id, local_block_id in zip(grouped_remote_block_ids, grouped_local_block_ids):
src = src_layer_base_addr + local_block_id[0] * block_len + inner_offset * inner_block_len
Expand Down Expand Up @@ -1196,51 +1191,25 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
first_kv_cache_tuple[0].size(-1) != first_kv_cache_tuple[1].size(-1) and len(first_kv_cache_tuple) == 2
)
self.use_sparse = len(first_kv_cache_tuple) == 3
if self.use_mla:
# MLA case.[num_block, block_size, 1, hidden_dim]
self.num_blocks = first_kv_cache.shape[0]
block_rank = 3 # [block_size, latent_dim]
block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
self.block_len = [
first_kv_cache[0].element_size() * math.prod(block_shape_norm),
first_kv_cache[1].element_size() * math.prod(block_shape_pe),
]
logger.info(
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
self.num_blocks,
block_shape_norm,
block_shape_pe,
)
elif self.use_sparse:
self.num_blocks = first_kv_cache.shape[0]

self.num_blocks = first_kv_cache.shape[0]
logger.info("num_blocks: %s", self.num_blocks)
self.block_len = []
if self.use_mla or self.use_sparse:
block_rank = 3 # [block_size, latent_dim]
block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
block_shape_k = first_kv_cache_tuple[2].shape[-block_rank:]
self.block_len = [
first_kv_cache[0].element_size() * math.prod(block_shape_norm),
first_kv_cache[1].element_size() * math.prod(block_shape_pe),
first_kv_cache[2].element_size() * math.prod(block_shape_k),
]
logger.info(
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s, block_shape_k: %s",
self.num_blocks,
block_shape_norm,
block_shape_pe,
block_shape_k,
)
for i in range(len(first_kv_cache_tuple)):
block_shape = first_kv_cache_tuple[i].shape[-block_rank:]
logger.info("block_shape: %s", block_shape)
self.block_len.append(first_kv_cache[i].element_size() * math.prod(block_shape))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

There seems to be a potential bug here. first_kv_cache is assigned first_kv_cache_tuple[0], so it's the first tensor in the tuple. When calculating block_len, you are using first_kv_cache[i].element_size(), which is equivalent to first_kv_cache.element_size() for any valid i. This means you are using the element size of the first tensor (first_kv_cache_tuple[0]) for all tensors in first_kv_cache_tuple. If the tensors in first_kv_cache_tuple have different dtypes, this will lead to an incorrect block_len calculation. It should probably be first_kv_cache_tuple[i].element_size() to get the element size of the correct tensor in the tuple.

Suggested change
self.block_len.append(first_kv_cache[i].element_size() * math.prod(block_shape))
self.block_len.append(first_kv_cache_tuple[i].element_size() * math.prod(block_shape))

else:
# eager:[num_block, block_size, num_head, hidden_dim]
# torchair:[num_block, block_size, num_head*hidden_dim]
self.num_blocks = first_kv_cache.shape[0]
kv_elem_size = first_kv_cache.element_size()
block_rank = (
len(first_kv_cache.shape) - 1
) # [block_size, kv_heads, head_dim] or [block_size, kv_heads*head_dim]
block_shape = first_kv_cache.shape[-block_rank:]
self.block_len = [kv_elem_size * math.prod(block_shape)]
logger.info("num_blocks: %s, block_shape: %s", self.num_blocks, block_shape)
logger.info("block_shape: %s", block_shape)
self.block_len = [first_kv_cache.element_size() * math.prod(block_shape)]

logger.info(
"Registering KV_Caches. use_mla: %s, use_sparse: %s, shape %s",
self.use_mla,
Expand All @@ -1252,30 +1221,15 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
kv_caches_base_addr = []
ptrs = []
lengths = []
length = len(self.block_len)
for cache_or_caches in kv_caches.values():
# Normalize to always be a list of caches
if self.use_mla:
for i, cache in enumerate(cache_or_caches, 0):
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len[i % 2]
kv_caches_base_addr.append(base_addr)
ptrs.append(base_addr)
lengths.append(region_len)
elif self.use_sparse:
for i, cache in enumerate(cache_or_caches, 0):
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len[i % 3]
kv_caches_base_addr.append(base_addr)
ptrs.append(base_addr)
lengths.append(region_len)
else:
cache_list = [cache_or_caches] if self.use_mla or self.use_sparse else cache_or_caches
for cache in cache_list:
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len[0]
kv_caches_base_addr.append(base_addr)
ptrs.append(base_addr)
lengths.append(region_len)
for i, cache in enumerate(cache_or_caches, 0):
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len[i % length]
kv_caches_base_addr.append(base_addr)
ptrs.append(base_addr)
lengths.append(region_len)
global_te.register_buffer(ptrs, lengths)
# After KV Caches registered, start the sending or receiving thread.
metadata = MooncakeAgentMetadata(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,15 +233,11 @@ def get_transfer_meta(self, send_task: SendTask, req_id: str, req_meta: ReqMeta)
remote_block_ids, local_block_ids
)

block_length = len(self.block_len)
for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
zip(layer_local_kv_base_addr, layer_remote_kv_base_addr)
):
if self.use_mla:
block_len = self.block_len[k % 2]
elif self.use_sparse:
block_len = self.block_len[k % 3]
else:
block_len = self.block_len[0]
block_len = self.block_len[k % block_length]
for group_remote_block_id, group_local_block_id in zip(
grouped_remote_block_ids, grouped_local_block_ids
):
Expand Down Expand Up @@ -925,48 +921,21 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
first_kv_cache_tuple[0].size(-1) != first_kv_cache_tuple[1].size(-1) and len(first_kv_cache_tuple) == 2
)
self.use_sparse = len(first_kv_cache_tuple) == 3
if self.use_mla:
# MLA case.[num_block, block_size, 1, hidden_dim]
self.num_blocks = first_kv_cache.shape[0]
block_rank = 3 # [block_size, latent_dim]
block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
self.block_len = [
first_kv_cache[0].element_size() * math.prod(block_shape_norm),
first_kv_cache[1].element_size() * math.prod(block_shape_pe),
]
logger.info(
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
self.num_blocks,
block_shape_norm,
block_shape_pe,
)
elif self.use_sparse:
self.num_blocks = first_kv_cache.shape[0]
block_rank = 3 # [block_size, latent_dim]
block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
block_shape_k = first_kv_cache_tuple[2].shape[-block_rank:]
self.block_len = [
first_kv_cache[0].element_size() * math.prod(block_shape_norm),
first_kv_cache[1].element_size() * math.prod(block_shape_pe),
first_kv_cache[2].element_size() * math.prod(block_shape_k),
]
logger.info(
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s, block_shape_k: %s",
self.num_blocks,
block_shape_norm,
block_shape_pe,
block_shape_k,
)

self.num_blocks = first_kv_cache.shape[0]
logger.info("num_blocks: %s", self.num_blocks)
block_rank = 3
self.block_len = []
if self.use_mla or self.use_sparse:
for i in range(len(first_kv_cache_tuple)):
block_shape = first_kv_cache_tuple[i].shape[-block_rank:]
logger.info("block_shape: %s", block_shape)
self.block_len.append(first_kv_cache[i].element_size() * math.prod(block_shape))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Similar to the other file, there's a potential bug here. first_kv_cache is first_kv_cache_tuple[0]. The code first_kv_cache[i].element_size() uses the element size of the first tensor for all calculations within the loop. If tensors in first_kv_cache_tuple can have different dtypes, this will be incorrect. You should use first_kv_cache_tuple[i].element_size() to ensure you're using the element size of the correct tensor.

Suggested change
self.block_len.append(first_kv_cache[i].element_size() * math.prod(block_shape))
self.block_len.append(first_kv_cache_tuple[i].element_size() * math.prod(block_shape))

else:
# [num_block, block_size, num_head, hidden_dim]
self.num_blocks = first_kv_cache.shape[0]
kv_elem_size = first_kv_cache.element_size()
block_rank = 3 # [block_size, kv_heads, head_dim]
block_shape = first_kv_cache.shape[-block_rank:]
self.block_len = [kv_elem_size * math.prod(block_shape)]
logger.info("num_blocks: %s, block_shape: %s", self.num_blocks, block_shape)
logger.info("block_shape: %s", block_shape)
self.block_len = [first_kv_cache.element_size() * math.prod(block_shape)]

logger.info(
"Registering KV_Caches. use_mla: %s, use_sparse: %s, shape %s",
Expand All @@ -979,30 +948,15 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
kv_caches_base_addr = []
ptrs = []
lengths = []
length = len(self.block_len)
for cache_or_caches in kv_caches.values():
# Normalize to always be a list of caches
if self.use_mla:
for i, cache in enumerate(cache_or_caches, 0):
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len[i % 2]
kv_caches_base_addr.append(base_addr)
ptrs.append(base_addr)
lengths.append(region_len)
elif self.use_sparse:
for i, cache in enumerate(cache_or_caches, 0):
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len[i % 3]
kv_caches_base_addr.append(base_addr)
ptrs.append(base_addr)
lengths.append(region_len)
else:
cache_list = [cache_or_caches] if self.use_mla or self.use_sparse else cache_or_caches
for cache in cache_list:
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len[0]
kv_caches_base_addr.append(base_addr)
ptrs.append(base_addr)
lengths.append(region_len)
for i, cache in enumerate(cache_or_caches, 0):
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len[i % length]
kv_caches_base_addr.append(base_addr)
ptrs.append(base_addr)
lengths.append(region_len)
global_te.register_buffer(ptrs, lengths)
self.kv_caches_base_addr = kv_caches_base_addr

Expand Down