diff --git a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_connector.py b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_connector.py index ca164ad8cd4..5fc64590765 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_connector.py +++ b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_connector.py @@ -333,7 +333,6 @@ def __init__( self.block_len = block_len # TODO(jianzs): find a better way to detect MLA. self.use_mla = len(block_len) == 2 - self.use_sparse = len(block_len) == 3 self.request_queue: queue.Queue[Any] = queue.Queue() self.executor = ThreadPoolExecutor(max_workers=32) @@ -529,15 +528,11 @@ def _transfer_kv_cache(self, req_meta: dict[str, Any]): req_start_time = time.perf_counter() src_list, dst_list, length_list = [], [], [] + block_length = len(self.block_len) for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate( zip(local_kv_caches_base_addrs, remote_kv_caches_base_addrs) ): - if self.use_mla: - block_len = self.block_len[k % 2] - elif self.use_sparse: - block_len = self.block_len[k % 3] - else: - block_len = self.block_len[0] + block_len = self.block_len[k % block_length] inner_block_len = block_len // tp_num_need_pulls for remote_block_id, local_block_id in zip(grouped_remote_block_ids, grouped_local_block_ids): src = src_layer_base_addr + local_block_id[0] * block_len + inner_offset * inner_block_len @@ -1196,51 +1191,25 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): first_kv_cache_tuple[0].size(-1) != first_kv_cache_tuple[1].size(-1) and len(first_kv_cache_tuple) == 2 ) self.use_sparse = len(first_kv_cache_tuple) == 3 - if self.use_mla: - # MLA case.[num_block, block_size, 1, hidden_dim] - self.num_blocks = first_kv_cache.shape[0] - block_rank = 3 # [block_size, latent_dim] - block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:] - block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:] - self.block_len = [ - first_kv_cache[0].element_size() * math.prod(block_shape_norm), - first_kv_cache[1].element_size() * math.prod(block_shape_pe), - ] - logger.info( - "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s", - self.num_blocks, - block_shape_norm, - block_shape_pe, - ) - elif self.use_sparse: - self.num_blocks = first_kv_cache.shape[0] + + self.num_blocks = first_kv_cache.shape[0] + logger.info("num_blocks: %s", self.num_blocks) + self.block_len = [] + if self.use_mla or self.use_sparse: block_rank = 3 # [block_size, latent_dim] - block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:] - block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:] - block_shape_k = first_kv_cache_tuple[2].shape[-block_rank:] - self.block_len = [ - first_kv_cache[0].element_size() * math.prod(block_shape_norm), - first_kv_cache[1].element_size() * math.prod(block_shape_pe), - first_kv_cache[2].element_size() * math.prod(block_shape_k), - ] - logger.info( - "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s, block_shape_k: %s", - self.num_blocks, - block_shape_norm, - block_shape_pe, - block_shape_k, - ) + for i in range(len(first_kv_cache_tuple)): + block_shape = first_kv_cache_tuple[i].shape[-block_rank:] + logger.info("block_shape: %s", block_shape) + self.block_len.append(first_kv_cache[i].element_size() * math.prod(block_shape)) else: # eager:[num_block, block_size, num_head, hidden_dim] - # torchair:[num_block, block_size, num_head*hidden_dim] - self.num_blocks = first_kv_cache.shape[0] - kv_elem_size = first_kv_cache.element_size() block_rank = ( len(first_kv_cache.shape) - 1 ) # [block_size, kv_heads, head_dim] or [block_size, kv_heads*head_dim] block_shape = first_kv_cache.shape[-block_rank:] - self.block_len = [kv_elem_size * math.prod(block_shape)] - logger.info("num_blocks: %s, block_shape: %s", self.num_blocks, block_shape) + logger.info("block_shape: %s", block_shape) + self.block_len = [first_kv_cache.element_size() * math.prod(block_shape)] + logger.info( "Registering KV_Caches. use_mla: %s, use_sparse: %s, shape %s", self.use_mla, @@ -1252,30 +1221,15 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): kv_caches_base_addr = [] ptrs = [] lengths = [] + length = len(self.block_len) for cache_or_caches in kv_caches.values(): # Normalize to always be a list of caches - if self.use_mla: - for i, cache in enumerate(cache_or_caches, 0): - base_addr = cache.data_ptr() - region_len = self.num_blocks * self.block_len[i % 2] - kv_caches_base_addr.append(base_addr) - ptrs.append(base_addr) - lengths.append(region_len) - elif self.use_sparse: - for i, cache in enumerate(cache_or_caches, 0): - base_addr = cache.data_ptr() - region_len = self.num_blocks * self.block_len[i % 3] - kv_caches_base_addr.append(base_addr) - ptrs.append(base_addr) - lengths.append(region_len) - else: - cache_list = [cache_or_caches] if self.use_mla or self.use_sparse else cache_or_caches - for cache in cache_list: - base_addr = cache.data_ptr() - region_len = self.num_blocks * self.block_len[0] - kv_caches_base_addr.append(base_addr) - ptrs.append(base_addr) - lengths.append(region_len) + for i, cache in enumerate(cache_or_caches, 0): + base_addr = cache.data_ptr() + region_len = self.num_blocks * self.block_len[i % length] + kv_caches_base_addr.append(base_addr) + ptrs.append(base_addr) + lengths.append(region_len) global_te.register_buffer(ptrs, lengths) # After KV Caches registered, start the sending or receiving thread. metadata = MooncakeAgentMetadata( diff --git a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py index 8425984810a..03584cd38d0 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py +++ b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py @@ -233,15 +233,11 @@ def get_transfer_meta(self, send_task: SendTask, req_id: str, req_meta: ReqMeta) remote_block_ids, local_block_ids ) + block_length = len(self.block_len) for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate( zip(layer_local_kv_base_addr, layer_remote_kv_base_addr) ): - if self.use_mla: - block_len = self.block_len[k % 2] - elif self.use_sparse: - block_len = self.block_len[k % 3] - else: - block_len = self.block_len[0] + block_len = self.block_len[k % block_length] for group_remote_block_id, group_local_block_id in zip( grouped_remote_block_ids, grouped_local_block_ids ): @@ -925,48 +921,21 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): first_kv_cache_tuple[0].size(-1) != first_kv_cache_tuple[1].size(-1) and len(first_kv_cache_tuple) == 2 ) self.use_sparse = len(first_kv_cache_tuple) == 3 - if self.use_mla: - # MLA case.[num_block, block_size, 1, hidden_dim] - self.num_blocks = first_kv_cache.shape[0] - block_rank = 3 # [block_size, latent_dim] - block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:] - block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:] - self.block_len = [ - first_kv_cache[0].element_size() * math.prod(block_shape_norm), - first_kv_cache[1].element_size() * math.prod(block_shape_pe), - ] - logger.info( - "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s", - self.num_blocks, - block_shape_norm, - block_shape_pe, - ) - elif self.use_sparse: - self.num_blocks = first_kv_cache.shape[0] - block_rank = 3 # [block_size, latent_dim] - block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:] - block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:] - block_shape_k = first_kv_cache_tuple[2].shape[-block_rank:] - self.block_len = [ - first_kv_cache[0].element_size() * math.prod(block_shape_norm), - first_kv_cache[1].element_size() * math.prod(block_shape_pe), - first_kv_cache[2].element_size() * math.prod(block_shape_k), - ] - logger.info( - "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s, block_shape_k: %s", - self.num_blocks, - block_shape_norm, - block_shape_pe, - block_shape_k, - ) + + self.num_blocks = first_kv_cache.shape[0] + logger.info("num_blocks: %s", self.num_blocks) + block_rank = 3 + self.block_len = [] + if self.use_mla or self.use_sparse: + for i in range(len(first_kv_cache_tuple)): + block_shape = first_kv_cache_tuple[i].shape[-block_rank:] + logger.info("block_shape: %s", block_shape) + self.block_len.append(first_kv_cache[i].element_size() * math.prod(block_shape)) else: # [num_block, block_size, num_head, hidden_dim] - self.num_blocks = first_kv_cache.shape[0] - kv_elem_size = first_kv_cache.element_size() - block_rank = 3 # [block_size, kv_heads, head_dim] block_shape = first_kv_cache.shape[-block_rank:] - self.block_len = [kv_elem_size * math.prod(block_shape)] - logger.info("num_blocks: %s, block_shape: %s", self.num_blocks, block_shape) + logger.info("block_shape: %s", block_shape) + self.block_len = [first_kv_cache.element_size() * math.prod(block_shape)] logger.info( "Registering KV_Caches. use_mla: %s, use_sparse: %s, shape %s", @@ -979,30 +948,15 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): kv_caches_base_addr = [] ptrs = [] lengths = [] + length = len(self.block_len) for cache_or_caches in kv_caches.values(): # Normalize to always be a list of caches - if self.use_mla: - for i, cache in enumerate(cache_or_caches, 0): - base_addr = cache.data_ptr() - region_len = self.num_blocks * self.block_len[i % 2] - kv_caches_base_addr.append(base_addr) - ptrs.append(base_addr) - lengths.append(region_len) - elif self.use_sparse: - for i, cache in enumerate(cache_or_caches, 0): - base_addr = cache.data_ptr() - region_len = self.num_blocks * self.block_len[i % 3] - kv_caches_base_addr.append(base_addr) - ptrs.append(base_addr) - lengths.append(region_len) - else: - cache_list = [cache_or_caches] if self.use_mla or self.use_sparse else cache_or_caches - for cache in cache_list: - base_addr = cache.data_ptr() - region_len = self.num_blocks * self.block_len[0] - kv_caches_base_addr.append(base_addr) - ptrs.append(base_addr) - lengths.append(region_len) + for i, cache in enumerate(cache_or_caches, 0): + base_addr = cache.data_ptr() + region_len = self.num_blocks * self.block_len[i % length] + kv_caches_base_addr.append(base_addr) + ptrs.append(base_addr) + lengths.append(region_len) global_te.register_buffer(ptrs, lengths) self.kv_caches_base_addr = kv_caches_base_addr