Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 31 additions & 9 deletions tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,37 @@ if [[ -n "$ATTENTION_BACKEND" ]]; then
echo "Using attention backend: $ATTENTION_BACKEND"
fi

PREFILL_KV_LAYOUT=${PREFILL_KV_LAYOUT:-"HND"}
DECODER_KV_LAYOUT=${DECODER_KV_LAYOUT:-"HND"} # Default to HND, optional NHD
AGREED_BLOCK_SIZE=${AGREED_BLOCK_SIZE:-""}
PREFILL_BLOCK_SIZE=${PREFILL_BLOCK_SIZE:-128}
DECODE_BLOCK_SIZE=${DECODE_BLOCK_SIZE:-128}
if [[ -n "$AGREED_BLOCK_SIZE" && "$AGREED_BLOCK_SIZE" != "$PREFILL_BLOCK_SIZE" ]]; then
PREFILL_HETERO_BLOCK_SIZE=1
else
PREFILL_HETERO_BLOCK_SIZE=0
fi
if [[ "$PREFILL_KV_LAYOUT" == "NHD" || $PREFILL_HETERO_BLOCK_SIZE -eq 1 ]]; then
PREFILL_KV_CONFIG_HETERO_LAYOUT=',"enable_permute_local_kv":"True"'
else
PREFILL_KV_CONFIG_HETERO_LAYOUT=''
fi
if [[ "$DECODER_KV_LAYOUT" == "NHD" ]]; then
KV_CONFIG_HETERO_LAYOUT=',"enable_permute_local_kv":"True"'
DECODE_KV_CONFIG_HETERO_LAYOUT=',"enable_permute_local_kv":"True"'
else
KV_CONFIG_HETERO_LAYOUT=''
DECODE_KV_CONFIG_HETERO_LAYOUT=''
fi
if [[ "$AGREED_BLOCK_SIZE" != "" ]]; then
EXTRA_KV_CONFIG='"agreed_block_size":'"$AGREED_BLOCK_SIZE"
fi

# Build the kv-transfer-config once
if [[ "$KV_BUFFER_DEVICE" == "cuda" ]]; then
KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"'${KV_CONFIG_HETERO_LAYOUT}'}'
PREFILL_KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"'${PREFILL_KV_CONFIG_HETERO_LAYOUT}',"kv_connector_extra_config":{'${EXTRA_KV_CONFIG}'}}'
DECODE_KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"'${DECODE_KV_CONFIG_HETERO_LAYOUT}',"kv_connector_extra_config":{'${EXTRA_KV_CONFIG}'}}'
else
KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}"}"
PREFILL_KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${PREFILL_KV_CONFIG_HETERO_LAYOUT}",\"kv_connector_extra_config\":{"${EXTRA_KV_CONFIG}"}}"
DECODE_KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${DECODE_KV_CONFIG_HETERO_LAYOUT}",\"kv_connector_extra_config\":{"${EXTRA_KV_CONFIG}"}}"
fi

# Models to run
Expand All @@ -57,8 +76,7 @@ NUM_DECODE_INSTANCES=${NUM_DECODE_INSTANCES:-1} # Default to 1
PREFILLER_TP_SIZE=${PREFILLER_TP_SIZE:-1}
DECODER_TP_SIZE=${DECODER_TP_SIZE:-1}
GPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION:-0.2}
PREFILL_BLOCK_SIZE=${PREFILL_BLOCK_SIZE:-128}
DECODE_BLOCK_SIZE=${DECODE_BLOCK_SIZE:-128}
DISABLE_PREFIX_CACHE=${DISABLE_PREFIX_CACHE:-false}

# Find the git repository root directory
GIT_ROOT=$(git rev-parse --show-toplevel)
Expand Down Expand Up @@ -93,6 +111,10 @@ get_model_args() {
extra_args="--hf_overrides '{\"architectures\": [\"DeepseekVLV2ForCausalLM\"]}' --trust-remote-code"
fi

if [[ "$DISABLE_PREFIX_CACHE" == "true" ]]; then
extra_args="${extra_args} --no-enable-prefix-caching"
fi

echo "$extra_args"
}

Expand Down Expand Up @@ -145,7 +167,7 @@ run_tests_for_model() {

# Build the command with or without model-specific args
BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID \
VLLM_KV_CACHE_LAYOUT='HND' \
VLLM_KV_CACHE_LAYOUT=$PREFILL_KV_LAYOUT \
UCX_NET_DEVICES=all \
VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT \
vllm serve $model_name \
Expand All @@ -154,7 +176,7 @@ run_tests_for_model() {
--block-size ${PREFILL_BLOCK_SIZE} \
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
--tensor-parallel-size $PREFILLER_TP_SIZE \
--kv-transfer-config '$KV_CONFIG'"
--kv-transfer-config '$PREFILL_KV_CONFIG'"

# Add attention backend config if specified
if [[ -n "$ATTENTION_BACKEND" ]]; then
Expand Down Expand Up @@ -200,7 +222,7 @@ run_tests_for_model() {
--enforce-eager \
--block-size ${DECODE_BLOCK_SIZE} \
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
--kv-transfer-config '$KV_CONFIG'"
--kv-transfer-config '$DECODE_KV_CONFIG'"

# Add attention backend config if specified
if [[ -n "$ATTENTION_BACKEND" ]]; then
Expand Down
75 changes: 75 additions & 0 deletions vllm/distributed/kv_transfer/kv_connector/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,81 @@ def kv_postprocess_blksize_and_layout_on_receive(cache, indices, block_size_rati
cache.index_copy_(0, indices, permuted_blocks)


def kv_postprocess_blksize_on_save(cache, indices, target_block_size):
    """Split the selected KV cache blocks into smaller target-size blocks,
    rewriting the cache in place.

    example:
    src blocksize = 16 tokens, target blocksize = 4 tokens
    src block[0] = target block[0, 1, 2, 3]
    src is    |h0-b0..................|h1-b0..................|...
    target is |h0-b0|h1-b0|h2-b0|h3-b0|h0-b1|h1-b1|h2-b1|h3-b1|...
    """
    selected = cache.index_select(0, indices)
    num_blocks, block_size, num_heads, head_dim = selected.shape
    split_factor = block_size // target_block_size

    # Carve every source block into `split_factor` sub-blocks of the target
    # size, then move the head axis ahead of the token axis (NHD -> HND).
    sub_blocks = selected.view(
        num_blocks, split_factor, target_block_size, num_heads, head_dim
    ).flatten(0, 1)  # (num_blocks * split_factor, target_block_size, H, D)
    sub_blocks = sub_blocks.permute(0, 2, 1, 3)  # (..., H, target_bs, D)

    # Source block i lands in target blocks [i*split_factor, (i+1)*split_factor).
    offsets = torch.arange(split_factor, device=indices.device)
    target_indices = (indices.unsqueeze(1) * split_factor + offsets).flatten()

    # Reinterpret the whole cache at the smaller block size. NOTE: the
    # permute-then-view only succeeds when the cache's physical layout is
    # HND (logical NHD over HND storage), so permuting restores contiguity.
    small_block_cache = cache.permute(0, 2, 1, 3).view(
        -1, num_heads, target_block_size, head_dim
    )
    small_block_cache.index_copy_(0, target_indices, sub_blocks)


def kv_postprocess_layout_and_blksize_on_save(cache, indices, target_block_size):
    """Split the selected KV cache blocks into smaller target-size blocks and
    permute each sub-block's data from NHD to HND layout, in place.

    example:
    src blocksize = 16 tokens, target blocksize = 4 tokens
    src block[0] = target block[0, 1, 2, 3]
    src is    |b0-h0..................|b0-h1..................|...
    target is |h0-b0|h1-b0|h2-b0|h3-b0|h0-b1|h1-b1|h2-b1|h3-b1|...
    """
    selected = cache.index_select(0, indices)
    num_blocks, block_size, num_heads, head_dim = selected.shape
    split_factor = block_size // target_block_size

    # Carve each block into sub-blocks of the target size and swap the token
    # and head axes so each sub-block's data becomes HND-ordered.
    hnd_sub_blocks = (
        selected.view(
            num_blocks, split_factor, target_block_size, num_heads, head_dim
        )
        .transpose(2, 3)  # (N, split_factor, H, target_block_size, D)
        .contiguous()
    )
    # Deliberate relabeling: the raw HND-ordered data is viewed with
    # NHD-named dims so index_copy_ writes the permuted bytes verbatim.
    source = hnd_sub_blocks.view(-1, target_block_size, num_heads, head_dim)

    # Source block i lands in target blocks [i*split_factor, (i+1)*split_factor).
    offsets = torch.arange(split_factor, device=indices.device)
    target_indices = (indices.unsqueeze(1) * split_factor + offsets).flatten()

    small_block_cache = cache.view(-1, target_block_size, num_heads, head_dim)
    small_block_cache.index_copy_(0, target_indices, source)


def kv_postprocess_layout_on_save(cache, indices):
    """Permute the data of the selected KV cache blocks from NHD to HND
    layout, rewriting the cache in place.

    The selected blocks are physically reordered (token-major -> head-major)
    and then reinterpreted under their original NHD-named shape, so
    index_copy_ writes the head-major bytes back verbatim.
    """
    selected = cache.index_select(0, indices)
    original_shape = selected.shape
    # NHD => HND: swap token/head axes, materialize, relabel as the old shape.
    reordered = selected.transpose(1, 2).contiguous().view(original_shape)
    cache.index_copy_(0, indices, reordered)


def yield_req_data(
scheduler_output,
) -> Iterator[tuple[str, tuple[list[int], ...], bool]]:
Expand Down
Loading