diff --git a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
index c2c38f51c500..4c0167398a9b 100755
--- a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
+++ b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
@@ -27,18 +27,37 @@ if [[ -n "$ATTENTION_BACKEND" ]]; then
   echo "Using attention backend: $ATTENTION_BACKEND"
 fi
 
+PREFILL_KV_LAYOUT=${PREFILL_KV_LAYOUT:-"HND"}
 DECODER_KV_LAYOUT=${DECODER_KV_LAYOUT:-"HND"} # Default to HND, optional NHD
+AGREED_BLOCK_SIZE=${AGREED_BLOCK_SIZE:-""}
+PREFILL_BLOCK_SIZE=${PREFILL_BLOCK_SIZE:-128}
+DECODE_BLOCK_SIZE=${DECODE_BLOCK_SIZE:-128}
+if [[ -n "$AGREED_BLOCK_SIZE" && "$AGREED_BLOCK_SIZE" != "$PREFILL_BLOCK_SIZE" ]]; then
+    PREFILL_HETERO_BLOCK_SIZE=1
+else
+    PREFILL_HETERO_BLOCK_SIZE=0
+fi
+if [[ "$PREFILL_KV_LAYOUT" == "NHD" || $PREFILL_HETERO_BLOCK_SIZE -eq 1 ]]; then
+    PREFILL_KV_CONFIG_HETERO_LAYOUT=',"enable_permute_local_kv":"True"'
+else
+    PREFILL_KV_CONFIG_HETERO_LAYOUT=''
+fi
 if [[ "$DECODER_KV_LAYOUT" == "NHD" ]]; then
-    KV_CONFIG_HETERO_LAYOUT=',"enable_permute_local_kv":"True"'
+    DECODE_KV_CONFIG_HETERO_LAYOUT=',"enable_permute_local_kv":"True"'
 else
-    KV_CONFIG_HETERO_LAYOUT=''
+    DECODE_KV_CONFIG_HETERO_LAYOUT=''
+fi
+if [[ "$AGREED_BLOCK_SIZE" != "" ]]; then
+    EXTRA_KV_CONFIG='"agreed_block_size":'"$AGREED_BLOCK_SIZE"
 fi
 
 # Build the kv-transfer-config once
 if [[ "$KV_BUFFER_DEVICE" == "cuda" ]]; then
-    KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"'${KV_CONFIG_HETERO_LAYOUT}'}'
+    PREFILL_KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"'${PREFILL_KV_CONFIG_HETERO_LAYOUT}',"kv_connector_extra_config":{'${EXTRA_KV_CONFIG}'}}'
+    DECODE_KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"'${DECODE_KV_CONFIG_HETERO_LAYOUT}',"kv_connector_extra_config":{'${EXTRA_KV_CONFIG}'}}'
 else
-    KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}"}"
+    PREFILL_KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${PREFILL_KV_CONFIG_HETERO_LAYOUT}",\"kv_connector_extra_config\":{"${EXTRA_KV_CONFIG}"}}"
+    DECODE_KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${DECODE_KV_CONFIG_HETERO_LAYOUT}",\"kv_connector_extra_config\":{"${EXTRA_KV_CONFIG}"}}"
 fi
 
 # Models to run
@@ -57,8 +76,7 @@ NUM_DECODE_INSTANCES=${NUM_DECODE_INSTANCES:-1} # Default to 1
 PREFILLER_TP_SIZE=${PREFILLER_TP_SIZE:-1}
 DECODER_TP_SIZE=${DECODER_TP_SIZE:-1}
 GPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION:-0.2}
-PREFILL_BLOCK_SIZE=${PREFILL_BLOCK_SIZE:-128}
-DECODE_BLOCK_SIZE=${DECODE_BLOCK_SIZE:-128}
+DISABLE_PREFIX_CACHE=${DISABLE_PREFIX_CACHE:-false}
 
 # Find the git repository root directory
 GIT_ROOT=$(git rev-parse --show-toplevel)
@@ -93,6 +111,10 @@ get_model_args() {
         extra_args="--hf_overrides '{\"architectures\": [\"DeepseekVLV2ForCausalLM\"]}' --trust-remote-code"
     fi
 
+    if [[ "$DISABLE_PREFIX_CACHE" == "true" ]]; then
+        extra_args="${extra_args} --no-enable-prefix-caching"
+    fi
+
     echo "$extra_args"
 }
 
@@ -145,7 +167,7 @@ run_tests_for_model() {
 
     # Build the command with or without model-specific args
     BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID \
-        VLLM_KV_CACHE_LAYOUT='HND' \
+        VLLM_KV_CACHE_LAYOUT=$PREFILL_KV_LAYOUT \
        UCX_NET_DEVICES=all \
        VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT \
        vllm serve $model_name \
@@ -154,7 +176,7 @@ run_tests_for_model() {
    --block-size ${PREFILL_BLOCK_SIZE} \
    --gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
    --tensor-parallel-size $PREFILLER_TP_SIZE \
-    --kv-transfer-config '$KV_CONFIG'"
+    --kv-transfer-config '$PREFILL_KV_CONFIG'"
 
     # Add attention backend config if specified
     if [[ -n "$ATTENTION_BACKEND" ]]; then
@@ -200,7 +222,7 @@ run_tests_for_model() {
    --enforce-eager \
    --block-size ${DECODE_BLOCK_SIZE} \
    --gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
-    --kv-transfer-config '$KV_CONFIG'"
+    --kv-transfer-config '$DECODE_KV_CONFIG'"
 
     # Add attention backend config if specified
     if [[ -n "$ATTENTION_BACKEND" ]]; then
diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py
index fd833e293938..bcad704f69da 100644
--- a/vllm/distributed/kv_transfer/kv_connector/utils.py
+++ b/vllm/distributed/kv_transfer/kv_connector/utils.py
@@ -282,6 +282,81 @@ def kv_postprocess_blksize_and_layout_on_receive(cache, indices, block_size_ratio
     cache.index_copy_(0, indices, permuted_blocks)
 
 
+def kv_postprocess_blksize_on_save(cache, indices, target_block_size):
+    """
+    Convert the current KV cache blocks to a smaller block size.
+
+    Example:
+    src block size = 16 tokens, target block size = 4 tokens
+    src block[0] = target block[0, 1, 2, 3]
+    src is    |h0-b0..................|h1-b0..................|...
+    target is |h0-b0|h1-b0|h2-b0|h3-b0|h0-b1|h1-b1|h2-b1|h3-b1|...
+    """
+    blocks_to_update = cache.index_select(0, indices)
+    n_blocks, block_size, n_kv_heads, head_size = blocks_to_update.shape
+    ratio = block_size // target_block_size
+    blocks_processed = (
+        blocks_to_update
+        # 1. Split the block dimension: (N, 4, 4, H, D)
+        .view(n_blocks, ratio, target_block_size, n_kv_heads, head_size)
+        # 2. Flatten N and ratio to get the new total blocks: (4N, 4, H, D)
+        .flatten(0, 1)
+        # 3. Swap Head and Block_Size (NHD -> HND): (4N, H, 4, D)
+        .permute(0, 2, 1, 3)
+    )
+    expanded_indices = (
+        indices.unsqueeze(1) * ratio + torch.arange(ratio, device=indices.device)
+    ).flatten()
+    cache_physical = cache.permute(0, 2, 1, 3)
+    cache_resized_view = cache_physical.view(
+        -1, n_kv_heads, target_block_size, head_size
+    )
+    cache_resized_view.index_copy_(0, expanded_indices, blocks_processed)
+
+
+def kv_postprocess_layout_and_blksize_on_save(cache, indices, target_block_size):
+    """
+    Convert KV cache blocks to a smaller block size and permute the KV layout.
+
+    Example:
+    src block size = 16 tokens, target block size = 4 tokens
+    src block[0] = target block[0, 1, 2, 3]
+    src is    |b0-h0..................|b0-h1..................|...
+    target is |h0-b0|h1-b0|h2-b0|h3-b0|h0-b1|h1-b1|h2-b1|h3-b1|...
+    """
+    blocks_to_update = cache.index_select(0, indices)
+    n_blocks, block_size, n_kv_heads, head_size = blocks_to_update.shape
+    ratio = block_size // target_block_size
+    blocks_processed = (
+        blocks_to_update
+        # 1. Split the block dimension: (N, 4, 4, H, D)
+        .view(n_blocks, ratio, target_block_size, n_kv_heads, head_size)
+        # 2. Swap Head and Block_Size (NHD -> HND): (N, 4, H, 4, D)
+        .permute(0, 1, 3, 2, 4)
+        .contiguous()
+        # 3. Collapse back to (4N, 4, H, D); memory now holds HND-ordered data
+        .view(-1, target_block_size, n_kv_heads, head_size)
+    )
+    expanded_indices = (
+        indices.unsqueeze(1) * ratio + torch.arange(ratio, device=indices.device)
+    ).flatten()
+    cache_physical = cache
+    cache_resized_view = cache_physical.view(
+        -1, target_block_size, n_kv_heads, head_size
+    )
+    cache_resized_view.index_copy_(0, expanded_indices, blocks_processed)
+
+
+def kv_postprocess_layout_on_save(cache, indices):
+    blocks_to_update = cache.index_select(0, indices)
+    target_shape = blocks_to_update.shape
+    # NHD => HND
+    blocks_processed = (
+        blocks_to_update.permute(0, 2, 1, 3).contiguous().view(target_shape)
+    )
+    cache.index_copy_(0, indices, blocks_processed)
+
+
 def yield_req_data(
     scheduler_output,
 ) -> Iterator[tuple[str, tuple[list[int], ...], bool]]:
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index 7e7e3ca55481..63381aa61558 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -29,7 +29,10 @@
     get_current_attn_backend,
     kv_postprocess_blksize_and_layout_on_receive,
     kv_postprocess_blksize_on_receive,
+    kv_postprocess_blksize_on_save,
+    kv_postprocess_layout_and_blksize_on_save,
     kv_postprocess_layout_on_receive,
+    kv_postprocess_layout_on_save,
     yield_req_data,
 )
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
@@ -250,6 +253,66 @@ class ReqMeta:
     remote: RemoteMeta | None = None
 
 
+def should_transform_kv_for_transfer(
+    vllm_config, current_block_size, current_kv_cache_layout
+):
+    assert vllm_config.kv_transfer_config.enable_permute_local_kv
+    if vllm_config.cache_config.enable_prefix_caching:
+        logger.warning_once(
+            "KV cache postprocess is not compatible with prefix caching."
+        )
+        return False, current_kv_cache_layout, current_block_size
+    postprocess_kv_caches_on_save = False
+    kv_cache_layout_on_save = "HND"
+    agreed_block_size = int(
+        vllm_config.kv_transfer_config.get_from_extra_config(
+            "agreed_block_size", current_block_size
+        )
+    )
+    # Only allow saving to a smaller block size (larger would need extra allocation)
+    block_size_on_save = (
+        agreed_block_size
+        if agreed_block_size <= current_block_size
+        else current_block_size
+    )
+    if (
+        kv_cache_layout_on_save != current_kv_cache_layout
+        or block_size_on_save != current_block_size
+    ):
+        postprocess_kv_caches_on_save = True
+        logger.info(
+            "KV cache postprocess on save is enabled. "
+            "Local kv cache layout: %s -> %s, "
+            "block size: %d -> %d",
+            current_kv_cache_layout,
+            kv_cache_layout_on_save,
+            current_block_size,
+            block_size_on_save,
+        )
+    return postprocess_kv_caches_on_save, kv_cache_layout_on_save, block_size_on_save
+
+
+def get_mapped_blocks(block_ids, block_size_ratio, num_blocks):
+    """
+    Calculates the new set of block IDs by mapping every element
+    in the (potentially sparse) input array.
+    Example: block_ids=[0, 2], block_size_ratio=2
+        get_mapped_blocks  0  1  [2  3]  4  5
+        # remote is |h0-b0|h1-b0||h0-b1|h1-b1||h0-b2|h1-b2||
+        # local is  |h0-b0......||h1-b0......||h2-b0........
+        local_block_ids    0    [1]      2
+    """
+    if block_ids.size == 0:
+        return []
+
+    start_ids = block_ids * block_size_ratio
+    offsets = np.arange(block_size_ratio)
+    mapped_2d = start_ids[:, None] + offsets[None, :]
+    ret = mapped_2d.flatten().tolist()[:num_blocks]
+
+    return ret
+
+
 class NixlConnectorMetadata(KVConnectorMetadata):
     def __init__(self):
         self.reqs_to_recv: dict[ReqId, ReqMeta] = {}
@@ -460,6 +523,8 @@ def wait_for_save(self):
         assert isinstance(self._connector_metadata, NixlConnectorMetadata)
         if self.connector_worker.use_host_buffer and self.connector_worker.copy_blocks:
             self.connector_worker.save_kv_to_host(self._connector_metadata)
+        elif self.connector_worker.postprocess_kv_caches_on_save:
+            self.connector_worker.kv_caches_postprocess(self._connector_metadata)
 
     def shutdown(self):
         if self.connector_worker is not None:
@@ -487,6 +552,7 @@ class NixlConnectorScheduler:
     def __init__(self, vllm_config: VllmConfig, engine_id: str):
         self.vllm_config = vllm_config
         self.block_size = vllm_config.cache_config.block_size
+        self.kv_cache_layout = get_kv_cache_layout()
         self.engine_id: EngineId = engine_id
         self.side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
         self.side_channel_port = (
@@ -501,6 +567,21 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
             vllm_config.kv_transfer_config.kv_buffer_device == "cpu"
         )
 
+        self.postprocess_kv_caches_on_save = False
+        self.kv_cache_layout_on_save = self.kv_cache_layout
+        self.block_size_on_save = self.block_size
+
+        # Accumulated block ids of in-flight chunked-prefill requests.
+        self.partial_reqs: dict[ReqId, list] = {}
+
+        if vllm_config.kv_transfer_config.enable_permute_local_kv:
+            (
+                self.postprocess_kv_caches_on_save,
+                self.kv_cache_layout_on_save,
+                self.block_size_on_save,
+            ) = should_transform_kv_for_transfer(
+                self.vllm_config, self.block_size, self.kv_cache_layout
+            )
         logger.info("Initializing NIXL Scheduler %s", engine_id)
 
         # Background thread for handling new handshake requests.
@@ -655,7 +736,9 @@ def update_state_after_alloc(
         if params.get("do_remote_decode"):
             self._reqs_in_batch.add(request.request_id)
 
-        if self.use_host_buffer and params.get("do_remote_decode"):
+        if (self.use_host_buffer or self.postprocess_kv_caches_on_save) and params.get(
+            "do_remote_decode"
+        ):
             # NOTE: when accelerator is not directly supported by Nixl,
             # prefilled blocks need to be saved to host memory before transfer.
             self._reqs_need_save[request.request_id] = request
@@ -719,16 +802,27 @@ def build_connector_meta(
             req = req_to_save
             assert req.kv_transfer_params is not None
-            meta.add_new_req_to_save(
-                request_id=req_id,
-                local_block_ids=new_block_id_groups[0],
-                kv_transfer_params=req.kv_transfer_params,
-            )
+            block_ids = new_block_id_groups[0]
             assert scheduler_output.num_scheduled_tokens is not None
             num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
             is_partial = (
                 req.num_computed_tokens + num_scheduled_tokens
             ) < req.num_prompt_tokens
 
+            if self.postprocess_kv_caches_on_save:
+                new_block_ids = self.partial_reqs.get(req_id, [])
+                new_block_ids = new_block_ids + block_ids
+                self.partial_reqs[req_id] = new_block_ids
+                if is_partial:
+                    continue
+            else:
+                new_block_ids = block_ids
+            # For chunked prefill with postprocess-on-save, block ids are
+            # accumulated and the save is only submitted on the final chunk.
+            meta.add_new_req_to_save(
+                request_id=req_id,
+                local_block_ids=new_block_ids,
+                kv_transfer_params=req.kv_transfer_params,
+            )
             if not is_partial:
                 # For non-partial prefills, once new req_meta is scheduled, it
                 # can be removed from _reqs_need_save.
@@ -736,6 +830,7 @@ def build_connector_meta(
                 # _reqs_need_save until all blocks are scheduled with req_meta.
                 # Therefore, only pop if `not is_partial`.
                 self._reqs_need_save.pop(req_id)
+                self.partial_reqs.pop(req_id, None)
 
         meta.reqs_to_send = self._reqs_need_send
         meta.reqs_in_batch = self._reqs_in_batch
@@ -790,6 +885,8 @@ def request_finished(
             self._reqs_not_processed.add(request.request_id)
             # Clear _reqs_need_save if a request is aborted as partial prefill.
             self._reqs_need_save.pop(request.request_id, None)
+            # Clear partial_reqs if a request is aborted as partial prefill.
+            self.partial_reqs.pop(request.request_id, None)
             return False, None
 
         # TODO: check whether block_ids actually ever be 0. If not we could
@@ -808,10 +905,23 @@ def request_finished(
                 time.perf_counter() + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT
             )
 
+        block_size_ratio = self.block_size // self.block_size_on_save
+        block_ids_on_save = block_ids
+        if block_size_ratio > 1:
+            num_blocks = math.ceil((request.num_tokens - 1) / self.block_size_on_save)
+            block_ids_on_save = get_mapped_blocks(
+                np.asarray(block_ids), block_size_ratio, num_blocks
+            )
+            logger.debug(
+                "request.num_tokens is %s, block_ids is %s, block_ids_on_save is %s",
+                request.num_tokens,
+                block_ids,
+                block_ids_on_save,
+            )
         return delay_free_blocks, dict(
             do_remote_prefill=True,
             do_remote_decode=False,
-            remote_block_ids=block_ids,
+            remote_block_ids=block_ids_on_save,
             remote_engine_id=self.engine_id,
             remote_request_id=request.request_id,
             remote_host=self.side_channel_host,
@@ -874,6 +984,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         self.tp_group = get_tp_group()
         self.num_blocks = 0
         self.enable_permute_local_kv = False
+        self.postprocess_kv_caches_on_save = False
 
         # KV Caches and nixl tracking data.
         self.device_type = current_platform.device_type
@@ -990,6 +1101,18 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         self.enforce_compat_hash = self.kv_transfer_config.get_from_extra_config(
             "enforce_handshake_compat", True
         )
+        self.kv_cache_layout_on_save = self.kv_cache_layout
+        self.block_size_on_save = self.block_size
+        if self.kv_transfer_config.enable_permute_local_kv:
+            (
+                self.postprocess_kv_caches_on_save,
+                self.kv_cache_layout_on_save,
+                self.block_size_on_save,
+            ) = should_transform_kv_for_transfer(
+                self.vllm_config, self.block_size, self.kv_cache_layout
+            )
+
+        self.block_size_ratio_on_save = self.block_size // self.block_size_on_save
 
         self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
         self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size}
@@ -1314,6 +1437,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         # Enable different block lengths for different layers when MLA is used.
         self.block_len_per_layer = list[int]()
         self.slot_size_per_layer = list[int]()  # HD bytes in kv terms
+        block_len_per_layer_on_save = list[int]()
 
         for layer_name, cache_or_caches in xfer_buffers.items():
             cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
@@ -1336,6 +1460,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
            )
            self.block_size = kernel_block_size
            self._block_size[self.engine_id] = kernel_block_size
+            self.block_size_ratio_on_save = (
+                self.block_size // self.block_size_on_save
+            )
 
             seen_base_addresses.append(base_addr)
             curr_tensor_size_bytes = cache.numel() * cache.element_size()
@@ -1343,6 +1470,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             if tensor_size_bytes is None:
                 tensor_size_bytes = curr_tensor_size_bytes
                 self.num_blocks = cache.shape[0]
+                self.num_blocks_on_save = (
+                    self.num_blocks * self.block_size_ratio_on_save
+                )
 
             assert cache.shape[0] == self.num_blocks, (
                 "All kv cache tensors must have the same number of blocks"
@@ -1354,6 +1484,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             self.slot_size_per_layer.append(
                 self.block_len_per_layer[-1] // self.block_size
             )
+            block_len_per_layer_on_save.append(
+                self.block_len_per_layer[-1] // self.block_size_ratio_on_save
+            )
 
             if not self.use_mla:
                 # Different kv cache shape is not supported by HeteroTP
@@ -1400,9 +1533,10 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
 
         # Register local/src descr for NIXL xfer.
         self.seen_base_addresses = seen_base_addresses
-        self.src_xfer_handles_by_block_size[self.block_size], self.src_blocks_data = (
-            self.register_local_xfer_handler(self.block_size)
-        )
+        (
+            self.src_xfer_handles_by_block_size[self.block_size_on_save],
+            self.src_blocks_data,
+        ) = self.register_local_xfer_handler(self.block_size_on_save)
 
         # TODO(mgoin): Hybrid memory allocator is currently disabled for
         # models with local attention (Llama 4). Can remove this once enabled.
@@ -1432,12 +1566,12 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             agent_metadata=self.nixl_wrapper.get_agent_metadata(),
             device_id=self.device_id,
             kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id][self.tp_rank],
-            num_blocks=self.num_blocks,
-            block_lens=self.block_len_per_layer,
-            kv_cache_layout=self.kv_cache_layout
+            num_blocks=self.num_blocks_on_save,
+            block_lens=block_len_per_layer_on_save,
+            kv_cache_layout=self.kv_cache_layout_on_save
             if not self.use_host_buffer
             else self.host_buffer_kv_cache_layout,
-            block_size=self.block_size,
+            block_size=self.block_size_on_save,
         )
         # Wrap metadata in payload with hash for defensive decoding
         encoder = msgspec.msgpack.Encoder()
@@ -1470,7 +1604,9 @@ def register_local_xfer_handler(
             self.get_backend_aware_kv_block_len(layer_idx=i) // block_size_ratio
         )
         block_len_per_layer = self.block_len_per_layer[i] // block_size_ratio
-        num_blocks = self.num_blocks * block_size_ratio
+        num_blocks = (
+            self.num_blocks * self.block_len_per_layer[i] // block_len_per_layer
+        )
         for block_id in range(num_blocks):
             block_offset = block_id * block_len_per_layer
             addr = base_addr + block_offset
@@ -1818,6 +1954,51 @@ def save_kv_to_host(self, metadata: NixlConnectorMetadata):
                 "d2h",
             )
 
+    def kv_caches_postprocess(self, metadata: NixlConnectorMetadata):
+        """Post-process the local KV caches before sending them to the remote.
+
+        This includes permuting the layout if needed and handling
+        block size mismatches.
+ """ + block_ids_to_permute = [] + for _, meta in metadata.reqs_to_save.items(): + meta.local_physical_block_ids = self._logical_to_kernel_block_ids( + meta.local_block_ids + ) + block_ids_to_permute.append(meta.local_physical_block_ids) + for block_ids in block_ids_to_permute: + self.post_process_device_kv_on_save(block_ids) + + def post_process_device_kv_on_save(self, block_ids: list[int]): + """Transforms the local KV cache shape to target shape. + + scenario 1. change KV layout from NHD to HND + scenario 2. change block_size to target block_size + scenario 3. change both layout and block_size + """ + + if len(block_ids) == 0: + return + target_block_size = self.block_size_on_save + split_k_and_v = self.kv_topo.split_k_and_v + sample_cache = list(self.device_kv_caches.values())[0][0] + indices = torch.tensor(block_ids, device=sample_cache.device) + + for _, cache_or_caches in self.device_kv_caches.items(): + cache_list = cache_or_caches if split_k_and_v else [cache_or_caches] + for cache in cache_list: + if ( + self.kv_cache_layout_on_save != self.kv_cache_layout + and self.block_size_on_save != self.block_size + ): + kv_postprocess_layout_and_blksize_on_save( + cache, indices, target_block_size + ) + elif self.kv_cache_layout_on_save != self.kv_cache_layout: + kv_postprocess_layout_on_save(cache, indices) + elif self.block_size_on_save != self.block_size: + kv_postprocess_blksize_on_save(cache, indices, target_block_size) + def post_process_device_kv_on_receive( self, block_size_ratio: int, @@ -2183,22 +2364,20 @@ def _read_blocks( """ block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(dst_engine_id) if block_size_ratio > 1: - local_block_ids = self.get_mapped_blocks( - np.asarray(local_block_ids), block_size_ratio + # NOTE: + # get_mapped_blocks will always expand block_ids for n times. + # ex: + # prefill block_ids with block_size as 4: + # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + # Local decode block_ids with block_size as 16: [1, 2, 3] + # expland ecode block_ids with get_mapped_blocks from [1, 2, 3] to + # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] + # Then we clip local to align with prefill + # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] to + # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + local_block_ids = get_mapped_blocks( + np.asarray(local_block_ids), block_size_ratio, len(remote_block_ids) ) - if len(local_block_ids) > len(remote_block_ids): - # NOTE: - # get_mapped_blocks will always expand block_ids for n times. - # ex: - # prefill block_ids with block_size as 4: - # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] - # Local decode block_ids with block_size as 16: [1, 2, 3] - # expland ecode block_ids with get_mapped_blocks from [1, 2, 3] to - # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] - # Then we clip local to align with prefill - # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] to - # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] - local_block_ids = local_block_ids[: len(remote_block_ids)] # NOTE(rob): having the staging blocks be on the READER side is # not going to work well (since we will have to call rearrange tensors). # after we detect the txn is complete (which means we cannot make the @@ -2331,25 +2510,6 @@ def _read_blocks( self.nixl_wrapper.release_xfer_handle(handle) self._failed_recv_reqs.add(request_id) - def get_mapped_blocks(self, block_ids, block_size_ratio): - """ - Calculates the new set of block IDs by mapping every element - in the (potentially sparse) input array. 
-        Example: block_ids=[0, 2], block_size_ratio=2
-            get_mapped_blocks  0  1  [2  3]  4  5
-        # remote is |h0-b0|h1-b0||h0-b1|h1-b1||h0-b1|h1-b1||
-        # local is  |h0-b0......||h1-b0......||h2-b0........
-            local_block_ids    0    [1]      2
-        """
-        if block_ids.size == 0:
-            return np.array([], dtype=np.int64)
-
-        start_ids = block_ids * block_size_ratio
-        offsets = np.arange(block_size_ratio)
-        mapped_2d = start_ids[:, None] + offsets[None, :]
-
-        return mapped_2d.flatten().astype(np.int64)
-
     def _get_block_descs_ids(
         self,
         engine_id: str,
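
A standalone sketch of the block-id remapping that the new module-level get_mapped_blocks() helper performs, as used by request_finished() on the save path and _read_blocks() on the read path. It assumes only numpy; the helper name and the numbers are illustrative, mirroring the in-code example (prefill block size 4, decode block size 16), not part of the patch itself.

import math

import numpy as np


def mapped_blocks(block_ids, block_size_ratio, num_blocks):
    # Expand each local (large) block id into `block_size_ratio` consecutive
    # remote (small) block ids, then clip to the remote block count.
    if block_ids.size == 0:
        return []
    start_ids = block_ids * block_size_ratio
    offsets = np.arange(block_size_ratio)
    return (start_ids[:, None] + offsets[None, :]).flatten().tolist()[:num_blocks]


# Decode side holds 16-token blocks, prefill side saved 4-token blocks.
decode_block_size, prefill_block_size = 16, 4
ratio = decode_block_size // prefill_block_size                      # 4
num_tokens = 38                                                      # prompt length
num_small_blocks = math.ceil((num_tokens - 1) / prefill_block_size)  # 10
local_decode_blocks = np.asarray([0, 1, 2])                          # ceil(38 / 16) = 3

print(mapped_blocks(local_decode_blocks, ratio, num_small_blocks))
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]  (12 expanded ids clipped to 10)

Passing the clip count into the helper is what lets _read_blocks() drop the old post-hoc slice local_block_ids[: len(remote_block_ids)] that the removed method required.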