diff --git a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/config_data.py b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/config_data.py index 398cc3fbfc8..8cc1bad1ed0 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/config_data.py +++ b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/config_data.py @@ -92,10 +92,9 @@ def to_string(self): class ChunkedTokenDatabase: - def __init__(self, metadata: KeyMetadata, block_size: int, use_mla: bool, partitions: list[int] | None): + def __init__(self, metadata: KeyMetadata, block_size: int, partitions: list[int] | None): self.metadata = metadata self.block_size = block_size - self.use_mla = use_mla self.kv_caches_base_addr: list[int] = [] self.block_len: list[int] = [] self.partitions = partitions @@ -117,29 +116,24 @@ def prepare_value(self, start: int, end: int, block_ids: list[int]): addr_list = [] size_list = [] block_id = block_ids[start // self.block_size] + length = len(self.block_len) for index, base_addr in enumerate(self.kv_caches_base_addr): - block_len = self.block_len[index % 2] if self.use_mla else self.block_len[0] - - addr = base_addr + block_id * block_len - length = int(block_len / self.block_size * (end - start)) + addr = base_addr + block_id * self.block_len[index % length] + size = int(self.block_len[index % length] / self.block_size * (end - start)) addr_list.append(addr) - size_list.append(length) + size_list.append(size) return addr_list, size_list, block_id def prepare_value_layer(self, start: int, end: int, block_ids: list[int], layer_id: int): block_id = block_ids[start // self.block_size] - if self.use_mla: - addr_k = self.kv_caches_base_addr[layer_id * 2] + block_id * self.block_len[0] - addr_v = self.kv_caches_base_addr[layer_id * 2 + 1] + block_id * self.block_len[1] - length_k = int(self.block_len[0] / self.block_size * (end - start)) - length_v = int(self.block_len[1] / self.block_size * (end - start)) - size_list = [length_k, length_v] - 
else: - addr_k = self.kv_caches_base_addr[layer_id * 2] + block_id * self.block_len[0] - addr_v = self.kv_caches_base_addr[layer_id * 2 + 1] + block_id * self.block_len[0] - length = int(self.block_len[0] / self.block_size * (end - start)) - size_list = [length, length] - addr_list = [addr_k, addr_v] + addr_list = [] + size_list = [] + length = len(self.block_len) + for i in range(length): + addr = self.kv_caches_base_addr[layer_id * length + i] + block_id * self.block_len[i] + size = int(self.block_len[i] / self.block_size * (end - start)) + addr_list.append(addr) + size_list.append(size) return addr_list, size_list def process_tokens( diff --git a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_worker.py b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_worker.py index 4e69a112c82..e293f718a25 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_worker.py +++ b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_worker.py @@ -52,6 +52,7 @@ def __init__( self.use_mla = False if hasattr(model_config, "use_mla") and isinstance(model_config.use_mla, bool) and model_config.use_mla: self.use_mla = True + self.use_sparse = hasattr(model_config.hf_text_config, "index_topk") self.use_layerwise = use_layerwize self.tp_rank = get_tensor_model_parallel_rank() self.tp_size = get_tensor_model_parallel_world_size() @@ -123,7 +124,7 @@ def __init__( for i in range(2, remaining_layers + 2): partitions[-i] += 1 - self.token_database = ChunkedTokenDatabase(self.metadata, self.block_size, self.use_mla, partitions) + self.token_database = ChunkedTokenDatabase(self.metadata, self.block_size, partitions) real_backend = backend_map.get(self.backend.lower()) @@ -145,55 +146,42 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): _, first_kv_cache_tuple = next(iter(kv_caches.items())) first_kv_cache = first_kv_cache_tuple[0] - # TODO(tms): Find a more robust way to detect and handle MLA - if self.use_mla: - # MLA 
case.[num_block, block_size, 1, hidden_dim] - self.num_blocks = first_kv_cache.shape[0] - block_rank = 3 # [block_size, latent_dim] - block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:] - block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:] - self.block_len = [ - first_kv_cache[0].element_size() * math.prod(block_shape_norm), - first_kv_cache[1].element_size() * math.prod(block_shape_pe), - ] - logger.info( - "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s", - self.num_blocks, - block_shape_norm, - block_shape_pe, - ) + self.num_blocks = first_kv_cache.shape[0] + logger.info("num_blocks: %s", self.num_blocks) + block_rank = 3 + self.block_len = [] + if self.use_mla or self.use_sparse: + for i in range(len(first_kv_cache_tuple)): + block_shape = first_kv_cache_tuple[i].shape[-block_rank:] + logger.info("block_shape: %s", block_shape) + self.block_len.append(first_kv_cache[i].element_size() * math.prod(block_shape)) else: # [num_block, block_size, num_head, hidden_dim] - self.num_blocks = first_kv_cache.shape[0] - kv_elem_size = first_kv_cache.element_size() - block_rank = 3 # [block_size, kv_heads, head_dim] block_shape = first_kv_cache.shape[-block_rank:] - self.block_len = [kv_elem_size * math.prod(block_shape)] - logger.info("num_blocks: %s, block_shape: %s", self.num_blocks, block_shape) - - logger.info("Registering KV_Caches. use_mla: %s, shape %s", self.use_mla, first_kv_cache.shape) + logger.info("block_shape: %s", block_shape) + self.block_len = [first_kv_cache.element_size() * math.prod(block_shape)] + + logger.info( + "Registering KV_Caches. 
use_mla: %s, use_sparse: %s, shape %s", + self.use_mla, + self.use_sparse, + first_kv_cache.shape, + ) self.kv_caches = kv_caches self.kv_caches_base_addr = [] ptrs = [] lengths = [] + length = len(self.block_len) for cache_or_caches in kv_caches.values(): # Normalize to always be a list of caches - if self.use_mla: - for i, cache in enumerate(cache_or_caches, 0): - base_addr = cache.data_ptr() - self.kv_caches_base_addr.append(base_addr) - region_len = self.num_blocks * self.block_len[i % 2] - ptrs.append(base_addr) - lengths.append(region_len) - else: - cache_list = [cache_or_caches] if self.use_mla else cache_or_caches - for cache in cache_list: - base_addr = cache.data_ptr() - self.kv_caches_base_addr.append(base_addr) - region_len = self.num_blocks * self.block_len[0] - ptrs.append(base_addr) - lengths.append(region_len) + for i, cache in enumerate(cache_or_caches, 0): + base_addr = cache.data_ptr() + region_len = self.num_blocks * self.block_len[i % length] + self.kv_caches_base_addr.append(base_addr) + ptrs.append(base_addr) + lengths.append(region_len) + self.m_store.register_buffer(ptrs, lengths) self.token_database.set_kv_caches_base_addr(self.kv_caches_base_addr) self.token_database.set_block_len(self.block_len)