-
Notifications
You must be signed in to change notification settings - Fork 1.1k
[P/D] layerwise connector supports DeepSeek-V3.2 sparse attention && Distribute transfer tasks to redundant kv_head cards #5722
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
8009367
581dd00
bc8511a
8d3e1be
bbed10f
81a3fb9
361ffc4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -162,6 +162,7 @@ def __init__( | |
| self.kv_caches_base_addr = kv_cache_base_addr | ||
| self.total_layers = total_layers | ||
| self.use_mla = use_mla | ||
| self.use_sparse = len(block_len) == 3 | ||
| self.block_len = block_len | ||
| self._decode_tp_size = decode_tp_size | ||
| self.resharding_stream = resharding_stream | ||
|
|
@@ -195,17 +196,6 @@ def get_transfer_meta(self, send_task: SendTask, req_id: str, | |
| src_list: list[str] = [] | ||
| dst_list: list[str] = [] | ||
| length_list: list[int] = [] | ||
| # not need to send kv cache | ||
| if self.tp_rank % self.num_head_replica != 0: | ||
| logger.debug( | ||
| f"Cancelling KV cache transfer for request {req_id}. Reason: TP rank excluded from head replication (TP Rank: {self.tp_rank}, Replicas: {self.num_head_replica})." | ||
| ) | ||
| return (src_list, dst_list, length_list) | ||
| if self.use_mla and self.tp_rank >= self._decode_tp_size: | ||
| logger.debug( | ||
| f"Cancelling KV cache transfer for request {req_id}. Reason: MLA mode active and TP rank outside decoding group (TP Rank: {self.tp_rank}, Decode TP Size: {self._decode_tp_size})." | ||
| ) | ||
| return (src_list, dst_list, length_list) | ||
|
|
||
| layer_idx = send_task.layer_idx | ||
| remote_block_ids = req_meta.remote_block_ids | ||
|
|
@@ -214,21 +204,36 @@ def get_transfer_meta(self, send_task: SendTask, req_id: str, | |
| local_block_ids = req_meta.local_block_ids | ||
|
|
||
| if self.pd_head_ratio == 1: | ||
| layer_local_kv_base_addr = [ | ||
| local_kv_base_addr[i] | ||
| for i in [2 * layer_idx, 2 * layer_idx + 1] | ||
| ] | ||
| layer_remote_kv_base_addr = [ | ||
| remote_kv_base_addrs[i] # type:ignore | ||
| for i in [2 * layer_idx, 2 * layer_idx + 1] | ||
| ] | ||
| if self.use_sparse: | ||
| layer_local_kv_base_addr = [ | ||
| local_kv_base_addr[i] for i in | ||
| [3 * layer_idx, 3 * layer_idx + 1, 3 * layer_idx + 2] | ||
| ] | ||
| layer_remote_kv_base_addr = [ | ||
| remote_kv_base_addrs[i] # type:ignore | ||
| for i in | ||
| [3 * layer_idx, 3 * layer_idx + 1, 3 * layer_idx + 2] | ||
| ] | ||
| else: | ||
| layer_local_kv_base_addr = [ | ||
| local_kv_base_addr[i] | ||
| for i in [2 * layer_idx, 2 * layer_idx + 1] | ||
| ] | ||
| layer_remote_kv_base_addr = [ | ||
| remote_kv_base_addrs[i] # type:ignore | ||
| for i in [2 * layer_idx, 2 * layer_idx + 1] | ||
| ] | ||
| grouped_remote_block_ids, grouped_local_block_ids = \ | ||
| group_concurrent_contiguous(remote_block_ids, local_block_ids) | ||
|
|
||
| for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate( | ||
| zip(layer_local_kv_base_addr, layer_remote_kv_base_addr)): | ||
| block_len = self.block_len[ | ||
| k % 2] if self.use_mla else self.block_len[0] | ||
| if self.use_mla: | ||
| block_len = (self.block_len[k % 2]) | ||
| elif self.use_sparse: | ||
| block_len = (self.block_len[k % 3]) | ||
| else: | ||
| block_len = (self.block_len[0]) | ||
| for group_remote_block_id, group_local_block_id in zip( | ||
| grouped_remote_block_ids, grouped_local_block_ids): | ||
| src = src_layer_base_addr + group_local_block_id[ | ||
|
|
@@ -931,7 +936,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): | |
|
|
||
| # TODO(tms): Find a more robust way to detect and handle MLA | ||
| self.use_mla = first_kv_cache_tuple[0].size( | ||
| -1) != first_kv_cache_tuple[1].size(-1) | ||
| -1) != first_kv_cache_tuple[1].size(-1) and len( | ||
| first_kv_cache_tuple) == 2 | ||
| self.use_sparse = len(first_kv_cache_tuple) == 3 | ||
| if self.use_mla: | ||
| # MLA case.[num_block, block_size, 1, hidden_dim] | ||
| self.num_blocks = first_kv_cache.shape[0] | ||
|
|
@@ -945,6 +952,21 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): | |
| logger.info( | ||
| "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s", | ||
| self.num_blocks, block_shape_norm, block_shape_pe) | ||
| elif self.use_sparse: | ||
| self.num_blocks = first_kv_cache.shape[0] | ||
| block_rank = 3 # [block_size, latent_dim] | ||
| block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:] | ||
| block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:] | ||
| block_shape_k = first_kv_cache_tuple[2].shape[-block_rank:] | ||
| self.block_len = [ | ||
| first_kv_cache[0].element_size() * math.prod(block_shape_norm), | ||
| first_kv_cache[1].element_size() * math.prod(block_shape_pe), | ||
| first_kv_cache[2].element_size() * math.prod(block_shape_k) | ||
| ] | ||
| logger.info( | ||
| "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s, block_shape_k: %s", | ||
| self.num_blocks, block_shape_norm, block_shape_pe, | ||
| block_shape_k) | ||
| else: | ||
| # [num_block, block_size, num_head, hidden_dim] | ||
| self.num_blocks = first_kv_cache.shape[0] | ||
|
|
@@ -955,8 +977,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): | |
| logger.info("num_blocks: %s, block_shape: %s", self.num_blocks, | ||
| block_shape) | ||
|
|
||
| logger.info("Registering KV_Caches. use_mla: %s, shape %s", | ||
| self.use_mla, first_kv_cache.shape) | ||
| logger.info( | ||
| "Registering KV_Caches. use_mla: %s, use_sparse: %s, shape %s", | ||
| self.use_mla, self.use_sparse, first_kv_cache.shape) | ||
|
|
||
| self.kv_caches = kv_caches | ||
| kv_caches_base_addr = [] | ||
|
|
@@ -971,9 +994,17 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): | |
| kv_caches_base_addr.append(base_addr) | ||
| ptrs.append(base_addr) | ||
| lengths.append(region_len) | ||
| elif self.use_sparse: | ||
| for i, cache in enumerate(cache_or_caches, 0): | ||
| base_addr = cache.data_ptr() | ||
| region_len = self.num_blocks * self.block_len[i % 3] | ||
| kv_caches_base_addr.append(base_addr) | ||
| ptrs.append(base_addr) | ||
| lengths.append(region_len) | ||
| else: | ||
| cache_list = [cache_or_caches | ||
| ] if self.use_mla else cache_or_caches | ||
| cache_list = [ | ||
| cache_or_caches | ||
| ] if self.use_mla or self.use_sparse else cache_or_caches | ||
|
Comment on lines
+1005
to
+1007
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The logic for creating cache_list = [cache_or_caches] |
||
| for cache in cache_list: | ||
| base_addr = cache.data_ptr() | ||
| region_len = self.num_blocks * self.block_len[0] | ||
|
|
@@ -1046,56 +1077,72 @@ def save_kv_layer(self, layer_name: str, kv_layer: Tuple[torch.Tensor, | |
| if self.vllm_config.kv_transfer_config.is_kv_producer and connector_metadata.requests.keys( | ||
| ): | ||
| # enable decode prefix cache | ||
| if self.use_mla: | ||
| reshape_cache_event = attn_metadata[ | ||
| layer_name].reshape_cache_event | ||
| else: | ||
| reshape_cache_event = attn_metadata.reshape_cache_event | ||
|
|
||
| if self.pd_head_ratio != 1: | ||
| assert self.resharding_stream is not None | ||
| with npu_stream_switch(self.resharding_stream): | ||
| reshape_cache_event.wait() | ||
| rearrange_block_ids = sorted({ | ||
| block_id | ||
| for request in connector_metadata.requests.values() | ||
| for block_id in request.local_block_ids | ||
| }) | ||
|
|
||
| keys = kv_layer[0][rearrange_block_ids].clone() | ||
| values = kv_layer[1][rearrange_block_ids].clone() | ||
| # sort kv caches for each block | ||
| keys = keys.view(keys.size(0), self.pd_head_ratio, -1, | ||
| *keys.shape[2:]).transpose( | ||
| 0, 1).reshape_as(keys) | ||
| values = values.view(values.size(0), self.pd_head_ratio, | ||
| -1, *values.shape[2:]).transpose( | ||
| 0, 1).reshape_as(values) | ||
| # reshard kv cache | ||
| keys = keys.reshape(-1, *kv_layer[0].shape[2:]) | ||
| values = values.reshape(-1, *kv_layer[1].shape[2:]) | ||
| (keys, values) = kv_alltoall_and_rearrange( | ||
| self.pd_head_ratio, keys, values) | ||
| if self.use_mla or self.use_sparse: | ||
| num_kv_head = self._decode_tp_size | ||
| else: | ||
| keys = None | ||
| values = None | ||
| rearrange_block_ids = None | ||
|
|
||
| assert self.kv_send_layer_thread is not None | ||
| assert reshape_cache_event is not None | ||
| send_task = SendTask(wait_event=reshape_cache_event, | ||
| k_cache=keys, | ||
| v_cache=values, | ||
| layer_idx=self.current_layer, | ||
| rearrange_block_ids=rearrange_block_ids) | ||
| for req_id, req_meta in connector_metadata.requests.items(): | ||
| req_meta_update = self.update_decoder_info(req_id, req_meta) | ||
| logger.debug( | ||
| f"Add request {req_id} to kv send layer thread. {req_meta_update=}" | ||
| ) | ||
| send_task.send_request[req_id] = req_meta_update | ||
| num_kv_head = self.vllm_config.model_config.hf_config.num_key_value_heads | ||
| num_replica_groups = self.tp_size // num_kv_head if self.tp_size >= num_kv_head else 1 | ||
| replica_group_idx = self.tp_rank % num_replica_groups | ||
| req_ids = sorted(list(connector_metadata.requests.keys())) | ||
| selected_req_ids = [ | ||
| req_id for i, req_id in enumerate(req_ids) | ||
| if i % num_replica_groups == replica_group_idx | ||
| ] | ||
| if selected_req_ids: | ||
| if self.use_mla or self.use_sparse: | ||
| reshape_cache_event = attn_metadata[ | ||
| layer_name].reshape_cache_event | ||
| else: | ||
| reshape_cache_event = attn_metadata.reshape_cache_event | ||
|
|
||
| if self.pd_head_ratio != 1: | ||
| assert self.resharding_stream is not None | ||
| with npu_stream_switch(self.resharding_stream): | ||
| reshape_cache_event.wait() | ||
| rearrange_block_ids = sorted({ | ||
| block_id | ||
| for req_id in selected_req_ids | ||
| for block_id in | ||
| connector_metadata.requests[req_id].local_block_ids | ||
| }) | ||
|
|
||
| keys = kv_layer[0][rearrange_block_ids].clone() | ||
| values = kv_layer[1][rearrange_block_ids].clone() | ||
| # sort kv caches for each block | ||
| keys = keys.view(keys.size(0), self.pd_head_ratio, -1, | ||
| *keys.shape[2:]).transpose( | ||
| 0, 1).reshape_as(keys) | ||
| values = values.view(values.size(0), | ||
| self.pd_head_ratio, -1, | ||
| *values.shape[2:]).transpose( | ||
| 0, 1).reshape_as(values) | ||
| # reshard kv cache | ||
| keys = keys.reshape(-1, *kv_layer[0].shape[2:]) | ||
| values = values.reshape(-1, *kv_layer[1].shape[2:]) | ||
| (keys, values) = kv_alltoall_and_rearrange( | ||
| self.pd_head_ratio, keys, values) | ||
| else: | ||
| keys = None | ||
| values = None | ||
| rearrange_block_ids = None | ||
|
|
||
| assert self.kv_send_layer_thread is not None | ||
| assert reshape_cache_event is not None | ||
| send_task = SendTask(wait_event=reshape_cache_event, | ||
| k_cache=keys, | ||
| v_cache=values, | ||
| layer_idx=self.current_layer, | ||
| rearrange_block_ids=rearrange_block_ids) | ||
| for req_id, req_meta in connector_metadata.requests.items(): | ||
| if req_id in selected_req_ids: | ||
| req_meta_update = self.update_decoder_info( | ||
| req_id, req_meta) | ||
| logger.debug( | ||
| f"Add request {req_id} to kv send layer thread. {req_meta_update=}" | ||
| ) | ||
| send_task.send_request[req_id] = req_meta_update | ||
|
|
||
| self.kv_send_layer_thread.send_queue.put(send_task) | ||
| self.kv_send_layer_thread.send_queue.put(send_task) | ||
| self.current_layer += 1 | ||
|
|
||
| def _get_remote_socket( | ||
|
|
@@ -1121,8 +1168,13 @@ def _get_remote_socket( | |
|
|
||
| def update_decoder_info(self, req_id, req_meta): | ||
| req_meta_update = copy.deepcopy(req_meta) | ||
| req_meta_update.remote_port = req_meta_update.remote_port + ( | ||
| self.tp_rank // self.pd_tp_ratio) % self._decode_tp_size | ||
| if self.use_mla or self.use_sparse: | ||
| pd_tp_ratio = self.tp_size // self._decode_tp_size | ||
| req_meta_update.remote_port = req_meta_update.remote_port + ( | ||
| self.tp_rank // pd_tp_ratio) % self._decode_tp_size | ||
| else: | ||
| req_meta_update.remote_port = req_meta_update.remote_port + ( | ||
| self.tp_rank // self.pd_tp_ratio) % self._decode_tp_size | ||
| if req_meta_update.remote_engine_id not in self.remote_kv_caches_base_addr or \ | ||
| req_meta_update.remote_port not in self.remote_kv_caches_base_addr[req_meta_update.remote_engine_id]: | ||
| try: | ||
|
|
@@ -1146,14 +1198,16 @@ def update_decoder_info(self, req_id, req_meta): | |
| logger.info( | ||
| f"Query to port and kv base addr for request {req_id} from {req_meta_update.remote_host}:{req_meta_update.remote_port} success {agent_meta.kv_caches_base_addr=} {agent_meta.te_rpc_port=}" | ||
| ) | ||
| session_id = f"{req_meta_update.remote_host}:{agent_meta.te_rpc_port}" | ||
| ret = self.engine.batch_transfer_sync_write( | ||
| session_id, [self.kv_caches_base_addr[0]], | ||
| [agent_meta.kv_caches_base_addr[0]], [128]) | ||
| if ret < 0: | ||
| logger.error( | ||
| f"Mooncake transfer failed to create link to device {session_id}" | ||
| ) | ||
| if self.pd_head_ratio > 1: | ||
| # for tp inequal, pre-create link to prevent alltoall out of memory | ||
| session_id = f"{req_meta_update.remote_host}:{agent_meta.te_rpc_port}" | ||
| ret = self.engine.batch_transfer_sync_write( | ||
| session_id, [self.kv_caches_base_addr[0]], | ||
| [agent_meta.kv_caches_base_addr[0]], [128]) | ||
| if ret < 0: | ||
| logger.error( | ||
| f"Mooncake transfer failed to create link to device {session_id}" | ||
| ) | ||
| req_meta_update.remote_te_rpc_port = self.remote_te_port[ | ||
| req_meta_update.remote_engine_id][req_meta_update.remote_port] | ||
| req_meta_update.remote_kv_caches_base_addr = self.remote_kv_caches_base_addr[ | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's an indexing error when calculating
`block_len` for the sparse attention case. When `self.use_sparse` is true, `layer_local_kv_base_addr` is constructed using the first and third elements of the per-layer cache addresses (indices `3 * layer_idx` and `3 * layer_idx + 2`). However, `self.block_len` is indexed with `k % 3`, which results in using `self.block_len[0]` and `self.block_len[1]`. This is incorrect, as it should be using `self.block_len[0]` and `self.block_len[2]` to match the base addresses. This will cause incorrect transfer lengths.