Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 121 additions & 0 deletions tests/v1/kv_connector/unit/offloading_connector/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1041,3 +1041,124 @@ def test_reset_cache(request_runner, async_scheduling: bool):
for req_status in runner.connector_scheduler._req_status.values():
for group_state in req_status.group_states:
assert group_state.next_stored_block_idx == 0


@pytest.mark.parametrize("async_scheduling", [True, False])
def test_swa_alignment_skip(request_runner, async_scheduling: bool):
    """Stores omit SWA blocks that an aligned load can never reach.

    Models a hybrid layout (DeepSeek V4-like) in which the sliding-window
    group uses a much smaller block size than the full-attention (MLA-like)
    group, so most SWA blocks fall outside what the alignment-based load
    path can hit.

    Layout under test:
    - Group 0: full attention, block_size=16
    - Group 1: sliding window, block_size=4, sliding_window=8

    Expected derived config for group 1:
    - alignment_block_count = 16 / 4 = 4 SWA blocks per alignment segment
    - sliding_window_size_in_blocks = 2
    so only the last 2 SWA blocks of every 4-block segment get stored.

    A 32-token prompt spans 2 full-attention blocks and 8 SWA blocks:
    - Group 0 stores blocks 0 and 1 (everything)
    - Group 1 stores blocks 2, 3, 6, 7 (0, 1, 4, 5 are skipped)

    For real DeepSeek V4 (100K tokens), this reduces SWA stores by ~78%.
    """
    mla_block_tokens = 16
    swa_block_tokens = 4
    window_tokens = 8
    prompt_len = 32

    def _store_everything(keys, req_context):
        # Accept every key offered to prepare_store.
        return generate_store_output(keys)

    def _two_block_prefix_hit(key, req_context):
        # Pretend the full-attention prefix lookup always hits 2 blocks.
        return 2

    full_attn_group = KVCacheGroupSpec(
        ["layer0"],
        FullAttentionSpec(
            block_size=mla_block_tokens,
            num_kv_heads=1,
            head_size=1,
            dtype=torch.float32,
        ),
    )
    swa_group = KVCacheGroupSpec(
        ["layer1"],
        SlidingWindowSpec(
            block_size=swa_block_tokens,
            num_kv_heads=1,
            head_size=1,
            dtype=torch.float32,
            sliding_window=window_tokens,
        ),
    )

    runner = request_runner(
        block_size=swa_block_tokens,
        num_gpu_blocks=200,
        async_scheduling=async_scheduling,
        kv_cache_groups=[full_attn_group, swa_group],
    )

    # --- Derived per-group offload configuration -------------------------
    group_cfgs = runner.connector_scheduler.config.kv_group_configs
    assert len(group_cfgs) == 2
    full_cfg, swa_cfg = group_cfgs
    # Full attention: the alignment-skip optimization must stay disabled.
    assert full_cfg.alignment_block_count is None
    assert full_cfg.sliding_window_size_in_blocks is None
    assert full_cfg.offloaded_block_size == mla_block_tokens
    # SWA: 16 / 4 = 4 blocks per segment, trailing 2 are kept.
    assert swa_cfg.alignment_block_count == 4
    assert swa_cfg.sliding_window_size_in_blocks == 2
    assert swa_cfg.offloaded_block_size == swa_block_tokens

    # --- Store path --------------------------------------------------------
    # Submit the 32-token prompt and decode one token; the stores for the
    # prompt blocks are deferred to the following step.
    runner.new_request(token_ids=[0] * prompt_len)
    runner.manager.prepare_store.side_effect = _store_everything
    runner.run(decoded_tokens=[0])

    # One more decode step flushes the deferred stores.
    # Group 0 contributes both of its blocks: (0, 0), (0, 1).
    # Group 1 drops the first 2 blocks of each 4-block segment:
    #   segment 0 (blocks 0-3) -> keeps (1, 2), (1, 3)
    #   segment 1 (blocks 4-7) -> keeps (1, 6), (1, 7)
    runner.manager.prepare_store.side_effect = _store_everything
    runner.run(
        decoded_tokens=[EOS_TOKEN_ID],
        expected_stored=((0, 0), (0, 1), (1, 2), (1, 3), (1, 6), (1, 7)),
    )

    # --- Load path ---------------------------------------------------------
    # After dropping the GPU prefix cache, a request sharing the same
    # 32-token prefix must still be able to load what was stored.
    runner.scheduler.reset_prefix_cache()
    runner.new_request(token_ids=[0] * prompt_len + [1])
    runner.manager.lookup.return_value = True
    runner.connector_scheduler._maximal_prefix_lookup = _two_block_prefix_hit
    # Group 0 loads its 2 prefix blocks: (0, 0), (0, 1).
    # Group 1 loads the trailing window of the last segment, i.e. the
    # stored blocks (1, 6), (1, 7).
    runner.run(
        decoded_tokens=[EOS_TOKEN_ID],
        expected_loaded=((0, 0), (0, 1), (1, 6), (1, 7)),
    )
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ class GroupOffloadConfig(NamedTuple):
hash_block_size_factor: int
# None below means full attention
sliding_window_size_in_blocks: int | None
# Number of this group's offloaded blocks per full-attention alignment
# segment. Used to skip storing SWA blocks that can never serve a load
# hit (e.g. DeepSeek V4 where SWA groups have much smaller block sizes
# than the MLA full-attention group).
# None for full-attention groups or when the optimization doesn't apply.
alignment_block_count: int | None = None


def get_sliding_window_size_in_blocks(
Expand All @@ -89,6 +95,41 @@ class SchedulerOffloadConfig(NamedTuple):

@classmethod
def from_spec(cls, spec: OffloadingSpec) -> "SchedulerOffloadConfig":
# Determine the alignment token count from the full-attention group(s).
# This is the offloaded_block_size of the full-attention group; load
# hits are always aligned to this boundary, so SWA blocks earlier in
# each segment can never serve a load hit. Relevant for hybrid
# architectures like DeepSeek V4 (MLA + SWA groups).
full_attn_offloaded_block_sizes: set[int] = set()
for idx, gpu_block_size in enumerate(spec.gpu_block_size):
kv_spec = spec.kv_cache_config.kv_cache_groups[idx].kv_cache_spec
sw = get_sliding_window_size_in_blocks(
kv_spec, gpu_block_size * spec.block_size_factor
)
if sw is None:
full_attn_offloaded_block_sizes.add(
gpu_block_size * spec.block_size_factor
)

# Only apply the optimization if there's a single consistent
# full-attention alignment size.
alignment_tokens: int | None = None
if len(full_attn_offloaded_block_sizes) == 1:
Comment on lines +114 to +117
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

man this looks like more complexity :(

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree.
We can try to generalize it, but I don't think it's currently worth the effort.
It's only applicable to DSv4, but it's a major optimization, so I don't want to skip it.
Note the same optimization was introduced for the prefix caching.
With connector v2 the connectors will be saved from these complexities.

alignment_tokens = full_attn_offloaded_block_sizes.pop()

def _alignment_block_count(
offloaded_block_size: int,
sliding_window_size_in_blocks: int | None,
) -> int | None:
if alignment_tokens is None or sliding_window_size_in_blocks is None:
return None
if alignment_tokens <= offloaded_block_size:
return None
per_segment = alignment_tokens // offloaded_block_size
if sliding_window_size_in_blocks >= per_segment:
return None
return per_segment

return cls(
num_workers=spec.vllm_config.parallel_config.world_size,
kv_group_configs=tuple(
Expand All @@ -100,9 +141,14 @@ def from_spec(cls, spec: OffloadingSpec) -> "SchedulerOffloadConfig":
(gpu_block_size * spec.block_size_factor)
// spec.hash_block_size
),
sliding_window_size_in_blocks=get_sliding_window_size_in_blocks(
spec.kv_cache_config.kv_cache_groups[idx].kv_cache_spec,
gpu_block_size * spec.block_size_factor,
sliding_window_size_in_blocks=(
sw := get_sliding_window_size_in_blocks(
spec.kv_cache_config.kv_cache_groups[idx].kv_cache_spec,
gpu_block_size * spec.block_size_factor,
)
),
alignment_block_count=_alignment_block_count(
gpu_block_size * spec.block_size_factor, sw
),
)
for idx, gpu_block_size in enumerate(spec.gpu_block_size)
Expand Down Expand Up @@ -639,6 +685,7 @@ def _build_store_jobs(
num_offloadable_tokens = min(num_tokens_after_batch, req.num_tokens)

# Filter out blocks skipped due to sliding window attention / SSM
# or unreachable by the load path's alignment constraints.
new_offload_keys: list[OffloadKey] = []
for group_config, group_state in zip(
self.config.kv_group_configs, req_status.group_states
Expand All @@ -663,9 +710,26 @@ def _build_store_jobs(
]
assert len(offload_keys) == len(offload_block_ids)

for offload_key, block_id in zip(offload_keys, offload_block_ids):
if block_id != 0:
new_offload_keys.append(offload_key)
alignment_block_count = group_config.alignment_block_count
tail = group_config.sliding_window_size_in_blocks

for key_idx, (offload_key, block_id) in enumerate(
zip(offload_keys, offload_block_ids)
):
if block_id == 0:
continue
# Skip SWA blocks that can never serve a load hit:
# within each full-attention alignment segment, only the
# trailing `tail` blocks are reachable by
# _sliding_window_lookup. For DeepSeek V4 with 100K
# tokens this reduces SWA stores by ~78%.
if alignment_block_count is not None:
assert tail is not None
abs_block_idx = start_block_idx + key_idx
pos_in_segment = abs_block_idx % alignment_block_count
if pos_in_segment < alignment_block_count - tail:
continue
new_offload_keys.append(offload_key)

if not new_offload_keys:
req_status.advance_stored_idx(num_offloadable_tokens)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,9 @@ def register_kv_caches(
page_size_bytes[layer_name] = (
layer_kv_cache_spec.page_size_bytes
)
unpadded_page_size_bytes[layer_name] = (
layer_kv_cache_spec.real_page_size_bytes
)
else:
# Flash Attention case: (2, num_blocks, ...)
assert test_shape[0] == 2
Expand All @@ -157,8 +160,9 @@ def register_kv_caches(
tensors_per_block[layer_name] = tuple(raw.unbind(0))

page_size_bytes[layer_name] = half_page_size

unpadded_page_size_bytes[layer_name] = page_size_bytes[layer_name]
unpadded_page_size_bytes[layer_name] = (
layer_kv_cache_spec.real_page_size_bytes // 2
)

elif isinstance(layer_kv_cache_spec, MambaSpec):
state_tensors = kv_caches[layer_name]
Expand Down Expand Up @@ -191,7 +195,16 @@ def register_kv_caches(
block_tensors: list[CanonicalKVCacheTensor] = []
block_data_refs: dict[str, list[CanonicalKVCacheRef]] = defaultdict(list)
for kv_cache_tensor in self.spec.kv_cache_config.kv_cache_tensors:
tensor_layer_names = kv_cache_tensor.shared_by
# Filter to layers that were actually processed above.
# _get_kv_cache_config_deepseek_v4 emits KVCacheTensor entries for
# every (tuple_idx, page_size) slot; slots where no group has a
# layer at that index produce an empty shared_by (reserved memory
# with no corresponding model layer).
tensor_layer_names = [
n for n in kv_cache_tensor.shared_by if n in tensors_per_block
]
if not tensor_layer_names:
continue

# verify all layers in the group reference the exact same tensors
assert len({len(tensors_per_block[n]) for n in tensor_layer_names}) == 1
Expand Down
12 changes: 7 additions & 5 deletions vllm/v1/kv_offload/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import torch

from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import resolve_kv_cache_block_sizes

if TYPE_CHECKING:
from vllm.config import VllmConfig
Expand Down Expand Up @@ -348,17 +349,18 @@ def __init__(self, vllm_config: "VllmConfig", kv_cache_config: "KVCacheConfig"):
* parallel_config.prefill_context_parallel_size
)

# block size used by vLLM for hashing request tokens for the sake
# of enabling prefix caching
self.hash_block_size = (
vllm_config.cache_config.block_size * context_parallel_factor
)
# gpu block size per group
self.gpu_block_size: tuple[int, ...] = tuple(
kv_cache_group.kv_cache_spec.block_size * context_parallel_factor
for kv_cache_group in kv_cache_config.kv_cache_groups
)

# hash_block_size must match what the scheduler uses for
# Request.block_hashes (resolved via resolve_kv_cache_block_sizes).
_, self.hash_block_size = resolve_kv_cache_block_sizes(
kv_cache_config, vllm_config
)

for block_size in self.gpu_block_size:
assert block_size % self.hash_block_size == 0, (
f"gpu_block_size={block_size} not divisible by "
Expand Down
Loading