[NIXL] refine decoder side post process for heterogeneous BlockSize and kv_layout #30275
Changes from all commits: 002a105, bf0feba, f0befd7, 7c8104c, a7f6524, bff8eaa, 15ff574
```diff
@@ -26,6 +26,9 @@
 from vllm.distributed.kv_transfer.kv_connector.utils import (
     EngineId,
     TpKVTopology,
+    kv_postprocess_blksize_and_layout_on_receive,
+    kv_postprocess_blksize_on_receive,
+    kv_postprocess_layout_on_receive,
     yield_req_data,
 )
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
```
```diff
@@ -1749,88 +1752,62 @@ def save_kv_to_host(self, metadata: NixlConnectorMetadata):
             "d2h",
         )
 
-    def permute_device_kv(self, block_ids: list[int]):
-        """Transforms the layout of received KV cache blocks to the local format.
-
-        This method corrects layout mismatches from direct memory copies by
-        permuting the tensor dimensions.
-
-        - **Source Layout:** `[num_blocks, n_kv_head, block_size, head_dim]`
-        - **Target Layout:** `[num_blocks, block_size, n_kv_head, head_dim]`
-
-        Args:
-            block_ids: A list of block IDs to update and permute.
-
-        Implementation:
-        - x = blocks_to_update.reshape(src_shape) # view local kv with sender layout
-        - permuted_blocks = x.permute(*inv_order) # transpose n_kv_heads, block_size
-        - cache.index_copy_(0, indices, permuted_blocks) # copy permuted kv back
-        """
-        split_k_and_v = self.kv_topo.split_k_and_v
-        inv_order = [0, 2, 1, 3]
-        sample_cache = list(self.device_kv_caches.values())[0][0]
-        target_shape = list(sample_cache.shape)
-        target_shape[0] = -1
-        src_shape = tuple(target_shape[i] for i in inv_order)
-        indices = torch.tensor(block_ids, device=sample_cache.device)
-
-        for _, cache_or_caches in self.device_kv_caches.items():
-            cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
-            for cache in cache_list:
-                blocks_to_update = cache.index_select(0, indices)
-                permuted_blocks = blocks_to_update.reshape(src_shape).permute(
-                    *inv_order
-                )
-                cache.index_copy_(0, indices, permuted_blocks)
```
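For readers unfamiliar with the trick the removed `permute_device_kv` relied on: a raw byte copy from an HND sender into an NHD cache leaves each block's bytes in the sender's order, so the fix is to reinterpret those bytes with the sender's shape and transpose back. Below is a standalone sketch of that reshape → permute → `index_copy_` pattern; the shapes are illustrative assumptions, not the connector's actual configuration.

```python
# Standalone demonstration of the pattern used by the removed permute_device_kv.
import torch

num_blocks, block_size, n_kv_head, head_dim = 8, 16, 4, 128
# Local cache is NHD: [num_blocks, block_size, n_kv_head, head_dim].
cache = torch.randn(num_blocks, block_size, n_kv_head, head_dim)

# Pretend blocks 2 and 5 were filled by a raw copy from an HND sender,
# so their bytes are actually ordered [n_kv_head, block_size, head_dim].
indices = torch.tensor([2, 5])
inv_order = [0, 2, 1, 3]  # swaps the block_size and n_kv_head dims

target_shape = list(cache.shape)
target_shape[0] = -1                                   # [-1, N, H, D]
src_shape = tuple(target_shape[i] for i in inv_order)  # (-1, H, N, D)

blocks_to_update = cache.index_select(0, indices)
# Reinterpret the raw bytes with the sender's layout, then transpose back.
permuted = blocks_to_update.reshape(src_shape).permute(*inv_order)
cache.index_copy_(0, indices, permuted)
```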
```diff
-    def blocksize_post_process(self, block_ids_per_ratio: dict[int, list[list[int]]]):
-        def _process_local_gt_remote(blocks_to_update, block_size_ratio):
-            n_kv_heads, block_size, head_size = blocks_to_update.shape[1:]
-            remote_block_size = block_size // block_size_ratio
-            n_blocks = block_size_ratio
-            # actual permute is to convert
-            # for local blocksize > remote blocksize
-            # ex: local blocksize = 16 tokens, remote blocksize = 4 tokens
-            # local block[0] = remote block[0, 1, 2, 3]
-            # remote is |h0-b0|h1-b0|h2-b0|h3-b0|h0-b1|h1-b1|h2-b1|h3-b1|...
-            # local is  |h0-b0..................|h1-b0..................|...
-            # permute is to:
-            # 1. view => view remote as n_blocks * remote_shape(H,remoteN,D)
-            # 2. permute => (H, nblocks, remoteN, D)
-            # 3. flatten => (H, localN, D)
-            permuted_blocks = (
-                blocks_to_update.reshape(
-                    -1, n_blocks, n_kv_heads, remote_block_size, head_size
-                )
-                .permute(0, 2, 1, 3, 4)
-                .flatten(2, 3)
-            )
-            return permuted_blocks
+    def post_process_device_kv_on_receive(
+        self,
+        block_size_ratio: int,
+        block_ids_list: list[list[int]],
+    ):
+        """
+        Post process device kv cache after receiving from remote.
+
+        3 types of post processing supported:
+        * kv_cache_postprocess_layout => convert from HND to NHD
+        * kv_cache_postprocess_blksize => convert from small block size
+          to large block size
+        * kv_cache_postprocess_blksize_and_layout => convert from small
+          block size to large block size and convert from HND to NHD
+        """
         if len(self.device_kv_caches) == 0:
             return
+        assert block_size_ratio >= 1, "Only nP < nD supported currently."
```
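The ASCII diagram in the removed `_process_local_gt_remote` is easiest to verify with a toy tensor. This self-contained sketch (sizes are made up) runs the same view → permute → flatten sequence that merges `block_size_ratio` remote blocks into one local block per head:

```python
# Toy check of the view -> permute -> flatten merge described in the removed helper.
import torch

n_kv_heads, head_size = 2, 4
local_block_size, remote_block_size = 16, 4
ratio = local_block_size // remote_block_size  # 4 remote blocks per local block
n_local_blocks = 3

# Physical HND view of the received blocks: each local block currently holds
# `ratio` remote blocks back to back, each in [H, remoteN, D] order.
blocks = torch.randn(n_local_blocks, n_kv_heads, local_block_size, head_size)

merged = (
    blocks.reshape(-1, ratio, n_kv_heads, remote_block_size, head_size)
    .permute(0, 2, 1, 3, 4)  # -> (B, H, n_blocks, remoteN, D)
    .flatten(2, 3)           # -> (B, H, localN, D): tokens contiguous per head
)
assert merged.shape == (n_local_blocks, n_kv_heads, local_block_size, head_size)
```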
Reviewer: we could probably use a debug log here stating what's being post-processed.

Author (xuechendi): I used `logger.info_once()`, is that ok?

Reviewer: I think they serve two different purposes; a debug log would provide info on the progress of the transfer operation per request, which I think is ok being debug.
```diff
+        if self.enable_permute_local_kv and block_size_ratio > 1:
+            logger.debug(
+                "Post-processing device kv cache on receive by converting "
+                "block_size with %sx bigger and permuting layout from HND"
+                " to NHD.",
+                block_size_ratio,
+            )
+        elif self.enable_permute_local_kv:
+            logger.debug(
+                "Post-processing device kv cache on receive by permuting layout"
+                " from HND to NHD."
+            )
+        else:
+            logger.debug(
+                "Post-processing device kv cache on receive by converting "
+                "block_size with %sx bigger.",
+                block_size_ratio,
+            )
```
```diff
 
         split_k_and_v = not (self.use_mla or self.kv_topo.is_kv_layout_blocks_first)
-        sample_cache = list(self.device_kv_caches.values())[0][0]
-        for block_size_ratio, block_ids_list in block_ids_per_ratio.items():
-            assert block_size_ratio > 1, "Only nP < nD supported currently."
-            block_ids_list = [[item for sublist in block_ids_list for item in sublist]]
-
-            for block_ids in block_ids_list:
-                indices = torch.tensor(block_ids, device=sample_cache.device)
-
-                for _, cache_or_caches in self.device_kv_caches.items():
-                    cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
-                    for cache in cache_list:
-                        blocks_to_update = cache.index_select(0, indices)
-                        # because kv_cache is always using original layout NHD as
-                        # virtual shape while stride can be either HND / NHD at
-                        # initialization.
-                        # we need to firstly get physical view of the tensor
-                        permuted_blocks = _process_local_gt_remote(
-                            blocks_to_update.permute(0, 2, 1, 3), block_size_ratio
-                        ).permute(0, 2, 1, 3)
-                        cache.index_copy_(0, indices, permuted_blocks)
+        for block_ids in block_ids_list:
+            indices = torch.tensor(block_ids, device=self.device_type, dtype=torch.long)
+
+            for _, cache_or_caches in self.device_kv_caches.items():
+                cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
+                for cache in cache_list:
+                    if self.enable_permute_local_kv and block_size_ratio > 1:
+                        kv_postprocess_blksize_and_layout_on_receive(
+                            cache, indices, block_size_ratio
+                        )
+                    elif self.enable_permute_local_kv:
+                        kv_postprocess_layout_on_receive(cache, indices)
+                    else:
+                        kv_postprocess_blksize_on_receive(
+                            cache, indices, block_size_ratio
+                        )
 
     def get_finished(self) -> tuple[set[str], set[str]]:
         """
```

xuechendi marked this conversation as resolved.
```diff
@@ -1854,7 +1831,6 @@ def get_finished(self) -> tuple[set[str], set[str]]:
                 len(done_recving),
             )
 
-        block_ids_to_permute = []
         block_ids_for_blocksize_post_process = defaultdict(list)
         for req_id in done_recving:
             # clean up metadata for completed requests
```
```diff
@@ -1863,24 +1839,22 @@
             assert meta.remote is not None
             if self.use_host_buffer:
                 self.sync_recved_kv_to_device(req_id, meta)
-            if self.enable_permute_local_kv:
-                block_ids_to_permute += meta.local_physical_block_ids
 
             # post processing for heteroblocksize
             block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(
                 meta.remote.engine_id
             )
-            if (
-                not self.use_mla
-                and block_size_ratio > 1
-                and self.kv_cache_layout == "HND"
+            if not self.use_mla and (
+                block_size_ratio > 1 or self.enable_permute_local_kv
             ):
                 block_ids_for_blocksize_post_process[block_size_ratio].append(
-                    meta.local_block_ids
+                    meta.local_physical_block_ids
```
Reviewer: ok, this was a bug then, right?

Author (xuechendi): yes, I missed that in the previous PR.
```diff
                 )
-        self.blocksize_post_process(block_ids_for_blocksize_post_process)
-        if len(block_ids_to_permute) > 0:
-            self.permute_device_kv(block_ids_to_permute)
+        for (
+            block_size_ratio,
+            block_ids_list,
+        ) in block_ids_for_blocksize_post_process.items():
+            self.post_process_device_kv_on_receive(block_size_ratio, block_ids_list)
 
         # Handle timeout to avoid stranding blocks on remote.
         now = time.perf_counter()
```
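As a toy walk-through of the grouping and dispatch above — engine ids and block sizes here are invented; the real code derives the ratio via `self.kv_topo.block_size_ratio_from_engine_id` and reads block ids from request metadata:

```python
# Toy version of the per-ratio grouping in get_finished.
from collections import defaultdict

local_block_size = 16
remote_block_size = {"prefill-a": 4, "prefill-b": 16}  # hypothetical engines

done_recving = {
    "req-0": ("prefill-a", [7, 9]),  # 4x smaller remote blocks -> needs merge
    "req-1": ("prefill-b", [3]),     # same block size -> ratio 1, skipped here
}

block_ids_for_post_process = defaultdict(list)
for req_id, (engine_id, local_physical_block_ids) in done_recving.items():
    ratio = local_block_size // remote_block_size[engine_id]
    if ratio > 1:  # the real check also ORs in enable_permute_local_kv
        block_ids_for_post_process[ratio].append(local_physical_block_ids)

for ratio, block_ids_list in block_ids_for_post_process.items():
    print(ratio, block_ids_list)  # -> 4 [[7, 9]]
```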