Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions vllm/distributed/kv_transfer/kv_connector/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,84 @@ def copy_kv_blocks(
copy_fn(src_tensor, dst_tensor, src_indices, dst_indices)


def kv_postprocess_blksize_on_receive(cache, indices, block_size_ratio):
    """Repack received KV blocks from the remote block_size into the local one.

    Only supports a local block size that is an integer multiple
    (``block_size_ratio``) of the remote block size.

    Example with local blocksize = 16 tokens and remote blocksize = 4 tokens:
    local block[0] holds remote blocks [0, 1, 2, 3], whose per-head token runs
    must be re-interleaved:

        remote: |h0-b0|h1-b0|h2-b0|h3-b0|h0-b1|h1-b1|h2-b1|h3-b1|...
        local:  |h0-b0..................|h1-b0..................|...

    Steps:
    1. view   => see each local block as n_blocks remote blocks (H, remoteN, D)
    2. swap   => (H, n_blocks, remoteN, D)
    3. flatten=> (H, localN, D)
    """
    selected = cache.index_select(0, indices)
    # The cache is addressed as NHD; switch to the physical HND order first.
    selected = selected.transpose(1, 2)
    num_heads, local_block_size, head_dim = selected.shape[1:]
    remote_block_size = local_block_size // block_size_ratio

    # (B, ratio, H, rN, D) -> (B, H, ratio, rN, D) -> (B, H, localN, D)
    grouped = selected.reshape(
        -1, block_size_ratio, num_heads, remote_block_size, head_dim
    )
    repacked = grouped.transpose(1, 2).flatten(2, 3)
    # Return to the NHD addressing order before writing back in place.
    cache.index_copy_(0, indices, repacked.transpose(1, 2))


def kv_postprocess_layout_on_receive(cache, indices):
    """Transforms the layout of received KV cache blocks to the local format.

    This method corrects layout mismatches from direct memory copies by
    permuting the tensor dimensions.

    - **Source Layout:** `[num_blocks, n_kv_head, block_size, head_dim]`
    - **Target Layout:** `[num_blocks, block_size, n_kv_head, head_dim]`

    Implementation:
    - x = blocks_to_update.reshape(src_shape) # view local kv with sender layout
    - permuted_blocks = x.permute(*inv_order) # transpose n_kv_heads, block_size
    - cache.index_copy_(0, indices, permuted_blocks) # copy permuted kv back

    Args:
        cache: local KV cache tensor, addressed as
            `[num_blocks, block_size, n_kv_head, head_dim]`.
        indices: 1-D integer tensor of block ids to fix up in place.
    """
    # Select the affected blocks once; a second, redundant index_select of
    # the same blocks was removed here.
    blocks_to_update = cache.index_select(0, indices)
    target_shape = list(blocks_to_update.shape)
    target_shape[0] = -1
    # Swapping dims 1 and 2 is its own inverse, so the same order maps the
    # local shape to the sender shape and the sender layout back to local.
    inv_order = [0, 2, 1, 3]
    src_shape = tuple(target_shape[i] for i in inv_order)
    permuted_blocks = blocks_to_update.reshape(src_shape).permute(*inv_order)
    cache.index_copy_(0, indices, permuted_blocks)


def kv_postprocess_blksize_and_layout_on_receive(cache, indices, block_size_ratio):
    """Repack received KV cache into the local block_size and layout.

    Combined case: the prefill side sends HND blocks with a smaller
    block_size, while the local (decode) cache is NHD with a larger one.
    Only local blocksize > remote blocksize is supported.
    """
    selected = cache.index_select(0, indices)

    local_block_size, num_heads, head_dim = selected.shape[1:]
    remote_block_size = local_block_size // block_size_ratio

    # Reinterpret each local block as `block_size_ratio` remote blocks,
    # move the head axis behind the remote-token axis, then merge the
    # remote blocks back into one local block:
    # (B, ratio, H, rN, D) -> (B, ratio, rN, H, D) -> (B, localN, H, D)
    grouped = selected.reshape(
        -1, block_size_ratio, num_heads, remote_block_size, head_dim
    )
    repacked = grouped.transpose(2, 3).flatten(1, 2)
    cache.index_copy_(0, indices, repacked)


def yield_req_data(
scheduler_output,
) -> Iterator[tuple[str, tuple[list[int], ...], bool]]:
Expand Down
148 changes: 61 additions & 87 deletions vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
from vllm.distributed.kv_transfer.kv_connector.utils import (
EngineId,
TpKVTopology,
kv_postprocess_blksize_and_layout_on_receive,
kv_postprocess_blksize_on_receive,
kv_postprocess_layout_on_receive,
yield_req_data,
)
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
Expand Down Expand Up @@ -1749,88 +1752,62 @@ def save_kv_to_host(self, metadata: NixlConnectorMetadata):
"d2h",
)

def permute_device_kv(self, block_ids: list[int]):
"""Transforms the layout of received KV cache blocks to the local format.

This method corrects layout mismatches from direct memory copies by
permuting the tensor dimensions.

- **Source Layout:** `[num_blocks, n_kv_head, block_size, head_dim]`
- **Target Layout:** `[num_blocks, block_size, n_kv_head, head_dim]`

Args:
block_ids: A list of block IDs to update and permute.
def post_process_device_kv_on_receive(
self,
block_size_ratio: int,
block_ids_list: list[list[int]],
):
"""
Post process device kv cache after receiving from remote.

Implementation:
- x = blocks_to_update.reshape(src_shape) # view local kv with sender layout
- permuted_blocks = x.permute(*inv_order) # transpose n_kv_heads, block_size
- cache.index_copy_(0, indices, permuted_blocks) # copy permuted kv back
3 types of post processing supported:
* kv_cache_postprocess_layout => convert from HND to NHD
* kv_cache_postprocess_blksize => convert from small block size
to large block size
* kv_cache_postprocess_blksize_and_layout => convert from small
block size to large block size and convert from HND to NHD

"""
split_k_and_v = self.kv_topo.split_k_and_v
inv_order = [0, 2, 1, 3]
sample_cache = list(self.device_kv_caches.values())[0][0]
target_shape = list(sample_cache.shape)
target_shape[0] = -1
src_shape = tuple(target_shape[i] for i in inv_order)
indices = torch.tensor(block_ids, device=sample_cache.device)

for _, cache_or_caches in self.device_kv_caches.items():
cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
for cache in cache_list:
blocks_to_update = cache.index_select(0, indices)
permuted_blocks = blocks_to_update.reshape(src_shape).permute(
*inv_order
)
cache.index_copy_(0, indices, permuted_blocks)

def blocksize_post_process(self, block_ids_per_ratio: dict[int, list[list[int]]]):
def _process_local_gt_remote(blocks_to_update, block_size_ratio):
n_kv_heads, block_size, head_size = blocks_to_update.shape[1:]
remote_block_size = block_size // block_size_ratio
n_blocks = block_size_ratio
# actual permute is to convert
# for local blocksize > remote blocksize
# ex: local blocksize = 16 tokens, remote blocksize = 4 tokens
# local block[0] = remote block[0, 1, 2, 3]
# remote is |h0-b0|h1-b0|h2-b0|h3-b0|h0-b1|h1-b1|h2-b1|h3-b1|...
# local is |h0-b0..................|h1-b0..................|...
# permute is to:
# 1. view => view remote as n_blocks * remote_shape(H,remoteN,D)
# 2. permute => (H, nblocks, remoteN, D)
# 3. flatten => (H, localN, D)
permuted_blocks = (
blocks_to_update.reshape(
-1, n_blocks, n_kv_heads, remote_block_size, head_size
)
.permute(0, 2, 1, 3, 4)
.flatten(2, 3)
)
return permuted_blocks

if len(self.device_kv_caches) == 0:
return
assert block_size_ratio >= 1, "Only nP < nD supported currently."
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we could probably use debug log here stating what's being post-processed

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I used logger.info_once(), is that ok?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think they serve two different purposes: a debug log would report the progress of the transfer operation per request, which I think is fine at debug level.
info_once may still be useful for the end user, although in theory we could later allow deployments where P1 block_size != P2 block_size = D block_size, in which case a single info_once log would fall short of reporting that.

if self.enable_permute_local_kv and block_size_ratio > 1:
logger.debug(
"Post-processing device kv cache on receive by converting "
"block_size with %sx bigger and permuting layout from HND"
" to NHD.",
block_size_ratio,
)
elif self.enable_permute_local_kv:
logger.debug(
"Post-processing device kv cache on receive by permuting layout"
"from HND to NHD."
)
else:
logger.debug(
"Post-processing device kv cache on receive by converting "
"block_size with %sx bigger.",
block_size_ratio,
)

split_k_and_v = not (self.use_mla or self.kv_topo.is_kv_layout_blocks_first)
sample_cache = list(self.device_kv_caches.values())[0][0]
for block_size_ratio, block_ids_list in block_ids_per_ratio.items():
assert block_size_ratio > 1, "Only nP < nD supported currently."
block_ids_list = [[item for sublist in block_ids_list for item in sublist]]

for block_ids in block_ids_list:
indices = torch.tensor(block_ids, device=sample_cache.device)

for _, cache_or_caches in self.device_kv_caches.items():
cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
for cache in cache_list:
blocks_to_update = cache.index_select(0, indices)
# because kv_cache is always using original layout NHD as
# virtual shape while stride can be either HND / NHD at
# initialization.
# we need to firstly get physical view of the tensor
permuted_blocks = _process_local_gt_remote(
blocks_to_update.permute(0, 2, 1, 3), block_size_ratio
).permute(0, 2, 1, 3)
cache.index_copy_(0, indices, permuted_blocks)

for block_ids in block_ids_list:
indices = torch.tensor(block_ids, device=self.device_type, dtype=torch.long)

for _, cache_or_caches in self.device_kv_caches.items():
cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
for cache in cache_list:
if self.enable_permute_local_kv and block_size_ratio > 1:
kv_postprocess_blksize_and_layout_on_receive(
cache, indices, block_size_ratio
)
elif self.enable_permute_local_kv:
kv_postprocess_layout_on_receive(cache, indices)
else:
kv_postprocess_blksize_on_receive(
cache, indices, block_size_ratio
)
Comment thread
xuechendi marked this conversation as resolved.

def get_finished(self) -> tuple[set[str], set[str]]:
"""
Expand All @@ -1854,7 +1831,6 @@ def get_finished(self) -> tuple[set[str], set[str]]:
len(done_recving),
)

block_ids_to_permute = []
block_ids_for_blocksize_post_process = defaultdict(list)
for req_id in done_recving:
# clean up metadata for completed requests
Expand All @@ -1863,24 +1839,22 @@ def get_finished(self) -> tuple[set[str], set[str]]:
assert meta.remote is not None
if self.use_host_buffer:
self.sync_recved_kv_to_device(req_id, meta)
if self.enable_permute_local_kv:
block_ids_to_permute += meta.local_physical_block_ids

# post processing for heteroblocksize
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(
meta.remote.engine_id
)
if (
not self.use_mla
and block_size_ratio > 1
and self.kv_cache_layout == "HND"
if not self.use_mla and (
block_size_ratio > 1 or self.enable_permute_local_kv
):
block_ids_for_blocksize_post_process[block_size_ratio].append(
meta.local_block_ids
meta.local_physical_block_ids
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok this was a bug then right

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, I missed that in previous PR

)
self.blocksize_post_process(block_ids_for_blocksize_post_process)
if len(block_ids_to_permute) > 0:
self.permute_device_kv(block_ids_to_permute)
for (
block_size_ratio,
block_ids_list,
) in block_ids_for_blocksize_post_process.items():
self.post_process_device_kv_on_receive(block_size_ratio, block_ids_list)

# Handle timeout to avoid stranding blocks on remote.
now = time.perf_counter()
Expand Down