67 changes: 55 additions & 12 deletions vllm/v1/core/kv_cache_utils.py
@@ -4,6 +4,7 @@
 
 import copy
 import hashlib
+import math
 import os
 from collections import defaultdict
 from collections.abc import Callable, Iterable, Iterator, Sequence
@@ -917,8 +918,12 @@ def unify_kv_cache_spec_page_size(
     """
     Unify the page size of the given KVCacheSpec. If the page size of all layers
     are the same, return the original KVCacheSpec. If not same, unify the page
-    size by increasing the block size of layers with smaller page size. Raise
-    NotImplementedError if failed to unify the page size.
+    size by increasing the block size of layers with smaller page size.
+
+    For hybrid models (e.g. TurboQuant + DeltaNet/Mamba), page sizes may not
+    divide each other evenly. In that case, the largest page size is padded up
+    to the nearest common multiple of the smaller page sizes (their LCM) via
+    page_size_padded. The memory overhead of this padding is typically <0.1%.
 
     Args:
         kv_cache_spec: The KVCacheSpec of each attention layer in the model
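For reference, a minimal sketch of the round-up-to-LCM computation the docstring describes. The byte counts are toy values chosen for easy arithmetic, not taken from the PR or any real model:

import math

# Toy per-layer page sizes in bytes; 98304 is the smaller size, 163840 the max.
page_sizes = [98304, 98304, 163840]
max_page_size = max(page_sizes)
smaller = sorted({ps for ps in page_sizes if ps < max_page_size})

if all(max_page_size % ps == 0 for ps in smaller):
    target = max_page_size  # fast path: smaller pages already divide the max
else:
    lcm = math.lcm(*smaller)                         # lcm(98304) == 98304
    target = (max_page_size + lcm - 1) // lcm * lcm  # round up: 2 * 98304 == 196608

print(target)  # 196608: every smaller page divides it; the max-size layer is padded up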
@@ -932,21 +937,59 @@
         return kv_cache_spec
 
     max_page_size = max(page_sizes)
+
+    # Check if all smaller pages divide max evenly (fast path)
+    smaller_sizes = sorted(ps for ps in page_sizes if ps < max_page_size)
+    all_divide = all(max_page_size % ps == 0 for ps in smaller_sizes)
+
+    if all_divide:
+        target_page_size = max_page_size
+    else:
+        # Hybrid model: page sizes not naturally divisible.
+        # Pad max_page_size UP to nearest multiple of LCM of smaller sizes.
+        smaller_lcm = math.lcm(*smaller_sizes)
+        target_page_size = (
+            (max_page_size + smaller_lcm - 1) // smaller_lcm
+        ) * smaller_lcm
+        logger.info(
+            "Page size unification: padding max %d -> %d (LCM of smaller = %d, "
+            "overhead %.3f%%)",
+            max_page_size,
+            target_page_size,
+            smaller_lcm,
+            (target_page_size - max_page_size) / max_page_size * 100,
+        )
+
     new_kv_cache_spec = {}
     for layer_name, layer_spec in kv_cache_spec.items():
-        if layer_spec.page_size_bytes == max_page_size:
+        layer_page = layer_spec.page_size_bytes
+        if layer_page == target_page_size:
             new_kv_cache_spec[layer_name] = layer_spec
-        else:
-            layer_page_size = layer_spec.page_size_bytes
-            if max_page_size % layer_page_size != 0:
-                raise NotImplementedError(
-                    "The page size of the layer is not divisible by the "
-                    "maximum page size. Cannot unify by adjusting block_size."
-                )
-            ratio = max_page_size // layer_page_size
+        elif layer_page < target_page_size and target_page_size % layer_page == 0:
+            # Scale up block_size so page matches target
+            ratio = target_page_size // layer_page
             new_block_size = layer_spec.block_size * ratio
             new_spec = replace(layer_spec, block_size=new_block_size)
-            assert new_spec.page_size_bytes == max_page_size
+            assert new_spec.page_size_bytes == target_page_size, (
+                f"Page size mismatch after block_size adjust: "
+                f"{new_spec.page_size_bytes} != {target_page_size}"
+            )
             new_kv_cache_spec[layer_name] = new_spec
+        else:
+            # Layer had the original max page size but target was padded up.
+            # Use page_size_padded to pad this layer to target.
+            try:
+                new_spec = replace(layer_spec, page_size_padded=target_page_size)
+            except TypeError as e:
+                raise NotImplementedError(
+                    f"Cannot pad page size for {type(layer_spec).__name__}: "
+                    f"page_size_padded not supported. "
+                    f"Layer page={layer_page}, target={target_page_size}"
+                ) from e
+            assert new_spec.page_size_bytes == target_page_size, (
+                f"Page size mismatch after padding: "
+                f"{new_spec.page_size_bytes} != {target_page_size}"
+            )
+            new_kv_cache_spec[layer_name] = new_spec
     return new_kv_cache_spec
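The per-layer loop above can be exercised in isolation. Below is a rough, self-contained sketch using a stand-in FakeSpec dataclass rather than vLLM's real KVCacheSpec classes, with toy page sizes carried over from the earlier example. It shows the two adjustment paths: scaling block_size for a layer whose page divides the target, and page_size_padded for the max-size layer when the target was rounded up:

from dataclasses import dataclass, replace

@dataclass(frozen=True)
class FakeSpec:
    # Stand-in spec: page_size_bytes is block_size * bytes_per_token unless
    # page_size_padded overrides it (loosely mirroring vLLM's Mamba-style specs).
    block_size: int
    bytes_per_token: int
    page_size_padded: int | None = None

    @property
    def page_size_bytes(self) -> int:
        if self.page_size_padded is not None:
            return self.page_size_padded
        return self.block_size * self.bytes_per_token

# Toy layers: an attention-like layer (16 * 6144 = 98304 bytes per page) and a
# mamba-like layer whose 163840-byte page does not divide into 98304.
specs = {
    "attn.0": FakeSpec(block_size=16, bytes_per_token=6144),
    "mamba.0": FakeSpec(block_size=1, bytes_per_token=163840),
}
target = 196608  # rounded-up target from the previous sketch

unified = {}
for name, spec in specs.items():
    page = spec.page_size_bytes
    if page == target:
        unified[name] = spec
    elif target % page == 0:
        # Grow block_size so this layer's page matches the target.
        unified[name] = replace(spec, block_size=spec.block_size * (target // page))
    else:
        # Max-size layer: pad its page up to the target instead.
        unified[name] = replace(spec, page_size_padded=target)

assert all(s.page_size_bytes == target for s in unified.values())
print({n: s.page_size_bytes for n, s in unified.items()})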
