Changes from all commits
35 commits
bcd4fbf
Support TurboQuant for YOCO + sliding-window models (Gemma 4 E4B)
ctao456 Apr 17, 2026
8f4cee1
Update vllm/engine/arg_utils.py
ctao456 Apr 17, 2026
2905cc0
Merge branch 'vllm-project:main' into feature/turboquant-yoco-sliding…
ctao456 Apr 23, 2026
5974fb7
Safely remove head dim > 256 layer skipping for TQ, as latest vllm_xp…
ctao456 Apr 23, 2026
4df87c3
Merge remote-tracking branch 'upstream/main' into feature/turboquant-…
ctao456 Apr 23, 2026
55a4913
Arrange layer type variable init in vllm engine
ctao456 Apr 23, 2026
dc4d856
Merge branch 'main' into feature/turboquant-yoco-sliding-window
ctao456 Apr 23, 2026
90df2b0
Merge branch 'main' into feature/turboquant-yoco-sliding-window
ctao456 Apr 24, 2026
713d1bd
Merge branch 'main' into feature/turboquant-yoco-sliding-window
ctao456 Apr 24, 2026
d706819
Merge branch 'main' into feature/turboquant-yoco-sliding-window
ctao456 Apr 24, 2026
0e30d21
Update attention.py
ctao456 Apr 24, 2026
6223fab
Moved the skipping layer logic for TQ to a static method in config.py
ctao456 Apr 24, 2026
f5c03e7
Proposing new page size grouping spec, now checks whether page sizes …
ctao456 Apr 24, 2026
cc48bc6
Merge branch 'main' into feature/turboquant-yoco-sliding-window
ctao456 Apr 24, 2026
4465986
Merge branch 'main' into feature/turboquant-yoco-sliding-window
ctao456 Apr 24, 2026
d0539d8
Fix mypy type error and apply ruff formatting
ctao456 Apr 24, 2026
15036ec
After unify_hybrid_kv_cache_specs succeeds, re-check page sizes. If t…
ctao456 Apr 24, 2026
2ffb99a
Merge branch 'main' into feature/turboquant-yoco-sliding-window
ctao456 Apr 24, 2026
8626962
removed the page-size re-check
ctao456 Apr 25, 2026
d0c0189
Test now expects UniformTypeKVCacheSpecs grouping for head_size=64 + …
ctao456 Apr 25, 2026
02ed44e
Pre-commit ruff format fix
ctao456 Apr 25, 2026
1a7c5fd
Fix sliding_window=1 in CI test case
ctao456 Apr 25, 2026
c96b180
Merge branch 'main' into feature/turboquant-yoco-sliding-window
ctao456 Apr 25, 2026
c414659
Merge branch 'main' into feature/turboquant-yoco-sliding-window
ctao456 Apr 27, 2026
2f3d34c
Merge branch 'main' into feature/turboquant-yoco-sliding-window
ctao456 Apr 27, 2026
1934291
Merge branch 'main' into feature/turboquant-yoco-sliding-window
ctao456 Apr 28, 2026
279ae83
Merge branch 'main' into feature/turboquant-yoco-sliding-window
ctao456 May 4, 2026
80a74b7
Merge branch 'main' into feature/turboquant-yoco-sliding-window
ctao456 May 4, 2026
1cf334e
Merge upstream main and resolve conflicts
ctao456 May 6, 2026
f596700
Fix Pre-commit
ctao456 May 6, 2026
306296f
Fix one line for pre-commit
ctao456 May 6, 2026
0b28a7a
Merge branch 'main' into feature/turboquant-yoco-sliding-window
ctao456 May 6, 2026
2d7e2a3
Merge branch 'main' into feature/turboquant-yoco-sliding-window
ctao456 May 7, 2026
8202145
Merge branch 'main' into feature/turboquant-yoco-sliding-window
ctao456 May 7, 2026
68afe62
Merge branch 'main' into feature/turboquant-yoco-sliding-window
ctao456 May 7, 2026
37 changes: 31 additions & 6 deletions tests/v1/core/test_kv_cache_utils.py
@@ -1742,16 +1742,41 @@ def test_get_kv_cache_config_one_worker():
         ],
     )
 
-    # different hidden size that cannot be aligned by using different block size
+    # different hidden size and different type: the page-size guard converts
+    # SlidingWindowSpec → FullAttentionSpec, then UniformTypeKVCacheSpecs
+    # handles the two FullAttentionSpecs with different head sizes.
     kv_cache_specs_hybrid = {
         "layer_1": new_kv_cache_spec(head_size=64),
         "layer_2": new_sliding_window_spec(head_size=96),
     }
 
-    with pytest.raises(NotImplementedError):
-        get_kv_cache_configs(
-            vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 2 * 32]
-        )[0]
+    kv_cache_config_hybrid = get_kv_cache_configs(
+        vllm_config,
+        [kv_cache_specs_hybrid],
+        [mem_per_block_per_layer * 2 * 32],
+    )[0]
+    expected_specs = {
+        "layer_1": new_kv_cache_spec(head_size=64),
+        "layer_2": new_kv_cache_spec(head_size=96, sliding_window=1),
+    }
+    assert kv_cache_config_hybrid == KVCacheConfig(
+        num_blocks=25,
+        kv_cache_tensors=[
+            KVCacheTensor(
+                size=mem_per_block_per_layer * 25,
+                shared_by=["layer_1"],
+            ),
+            KVCacheTensor(
+                size=new_kv_cache_spec(head_size=96).page_size_bytes * 25,
+                shared_by=["layer_2"],
+            ),
+        ],
+        kv_cache_groups=[
+            KVCacheGroupSpec(
+                ["layer_1", "layer_2"],
+                UniformTypeKVCacheSpecs(block_size=16, kv_cache_specs=expected_specs),
+            ),
+        ],
+    )
 
     # Test num_gpu_blocks_override
     vllm_config.cache_config.num_gpu_blocks_override = 16
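In case the guard's effect is unclear from the diff alone, here is a minimal, self-contained sketch of the promote-then-group flow this test now expects. The Spec class and promote_for_uniform_grouping helper are hypothetical stand-ins, not vLLM APIs; only the sliding_window=1 marker mirrors the expected_specs above.

from dataclasses import dataclass, replace

@dataclass(frozen=True)
class Spec:
    head_size: int
    sliding_window: int | None = None  # None = full attention

def promote_for_uniform_grouping(specs: dict[str, Spec]) -> dict[str, Spec]:
    # Hypothetical page-size guard: when head sizes (hence page sizes) differ
    # across layers, convert sliding-window specs to full attention, marking
    # the conversion with sliding_window=1, so that all layers share one spec
    # type and can land in a single uniform KV-cache group.
    if len({s.head_size for s in specs.values()}) > 1:
        return {
            name: (replace(s, sliding_window=1) if s.sliding_window else s)
            for name, s in specs.items()
        }
    return specs

specs = {"layer_1": Spec(64), "layer_2": Spec(96, sliding_window=1024)}
assert promote_for_uniform_grouping(specs)["layer_2"] == Spec(96, sliding_window=1)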
6 changes: 4 additions & 2 deletions vllm/config/model.py
@@ -695,8 +695,10 @@ def __post_init__(
 
         if self.disable_sliding_window:
             # Set after get_and_verify_max_len to ensure that max_model_len
-            # can be correctly capped to sliding window size
-            self.hf_text_config.sliding_window = None
+            # can be correctly capped to sliding window size.
+            # Use object.__setattr__ to bypass huggingface_hub strict
+            # dataclass validation which rejects None for int-typed fields.
+            object.__setattr__(self.hf_text_config, "sliding_window", None)
 
         # Avoid running try_verify_and_update_config multiple times
         self.config_updated = False
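For readers unfamiliar with the trick: a standalone illustration of why object.__setattr__ sidesteps a validating __setattr__. StrictConfig is invented for this demo and merely mimics huggingface_hub-style field validation; it is not the actual library code.

from dataclasses import dataclass

@dataclass
class StrictConfig:
    sliding_window: int = 4096

    def __setattr__(self, name, value):
        # Mimics strict validation: reject None for an int-typed field.
        if name == "sliding_window" and value is None:
            raise TypeError("sliding_window must be an int, not None")
        super().__setattr__(name, value)

cfg = StrictConfig()
# cfg.sliding_window = None  # would raise TypeError via __setattr__
object.__setattr__(cfg, "sliding_window", None)  # bypasses the validation
assert cfg.sliding_window is None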
18 changes: 16 additions & 2 deletions vllm/engine/arg_utils.py
@@ -1704,10 +1704,24 @@ def create_engine_config(
                 TurboQuantConfig,
             )
 
+            num_layers = model_config.hf_text_config.num_hidden_layers
             boundary = TurboQuantConfig.get_boundary_skip_layers(model_config)
             existing = set(cache_config.kv_cache_dtype_skip_layers)
-            cache_config.kv_cache_dtype_skip_layers = sorted(
-                existing | set(boundary), key=int
+            merged = sorted(existing | set(boundary), key=lambda x: int(x))
+
+            hf_cfg = model_config.hf_text_config
+            merged = TurboQuantConfig.apply_yoco_skip_alignment(
+                merged=merged,
+                num_layers=num_layers,
+                layer_types=getattr(hf_cfg, "layer_types", None) or [],
+                num_kv_shared=getattr(hf_cfg, "num_kv_shared_layers", 0),
             )
+
+            cache_config.kv_cache_dtype_skip_layers = merged
+            logger.info(
+                "TQ: skipping layers %s for boundary protection (num_layers=%d)",
+                merged,
+                num_layers,
+            )
 
         ray_runtime_env = None
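The key=lambda x: int(x) matters because the skip list stores layer indices as strings, which would otherwise sort lexicographically; a tiny worked example:

existing = {"0", "10"}
boundary = ["2", "10", "31"]
merged = sorted(existing | set(boundary), key=lambda x: int(x))
assert merged == ["0", "2", "10", "31"]
# Plain sorted(...) would give ["0", "10", "2", "31"].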
29 changes: 15 additions & 14 deletions vllm/model_executor/layers/attention/attention.py
@@ -541,20 +541,7 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
         # Should not be called for enc-dec or encoder-only attention.
         assert self.attn_type == AttentionType.DECODER
         quant_mode = get_kv_quant_mode(self.kv_cache_dtype)
-        if self.sliding_window is not None:
-            assert not vllm_config.model_config.use_mla, (
-                "MLA is not supported for slidingwindow"
-            )
-            return SlidingWindowSpec(
-                block_size=block_size,
-                num_kv_heads=self.num_kv_heads,
-                head_size=self.head_size,
-                head_size_v=self.head_size_v,
-                dtype=self.kv_cache_torch_dtype,
-                kv_quant_mode=quant_mode,
-                sliding_window=self.sliding_window,
-            )
-        elif self.kv_cache_dtype.startswith("turboquant_"):
+        if self.kv_cache_dtype.startswith("turboquant_"):
             from vllm.model_executor.layers.quantization.turboquant.config import (
                 TurboQuantConfig,
             )
@@ -570,6 +557,20 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
                 head_size_v=self.head_size,
                 dtype=self.kv_cache_torch_dtype,
                 tq_slot_size=tq_config.slot_size_aligned,
+                sliding_window=self.sliding_window,
             )
+        elif self.sliding_window is not None:
+            assert not vllm_config.model_config.use_mla, (
+                "MLA is not supported for slidingwindow"
+            )
+            return SlidingWindowSpec(
+                block_size=block_size,
+                num_kv_heads=self.num_kv_heads,
+                head_size=self.head_size,
+                head_size_v=self.head_size_v,
+                dtype=self.kv_cache_torch_dtype,
+                kv_quant_mode=quant_mode,
+                sliding_window=self.sliding_window,
+            )
         else:
             return FullAttentionSpec(
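The substance of this move is dispatch priority: the TurboQuant branch is now checked before the sliding-window branch, so a windowed layer under a turboquant_* cache dtype gets a TQ spec (which now carries sliding_window) instead of a SlidingWindowSpec. A toy distillation; the dtype string below is invented, and only the startswith("turboquant_") prefix check mirrors the real code:

def spec_kind(kv_cache_dtype: str, sliding_window: int | None) -> str:
    if kv_cache_dtype.startswith("turboquant_"):
        return "turboquant"  # wins even when a sliding window is configured
    if sliding_window is not None:
        return "sliding_window"
    return "full_attention"

assert spec_kind("turboquant_i4", 512) == "turboquant"
assert spec_kind("auto", 512) == "sliding_window"
assert spec_kind("auto", None) == "full_attention"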
76 changes: 76 additions & 0 deletions vllm/model_executor/layers/quantization/turboquant/config.py
@@ -205,6 +205,82 @@ def get_boundary_skip_layers(
         indices = sorted(set(first + last))
         return [str(i) for i in indices]
 
+    @staticmethod
+    def apply_yoco_skip_alignment(
+        merged: list[str],
+        num_layers: int,
+        layer_types: list,
+        num_kv_shared: int,
+    ) -> list[str]:
+        """Align the TQ skip list for YOCO (You Only Cache Once) architectures.
+
+        KV-shared layers reuse their target's cache tensor, so the
+        kv_cache_dtype of a shared layer MUST match its target's.
+        This method:
+        1. Skips all KV-sharing target layers (to prevent quantization
+           error amplification across every consumer layer).
+        2. Propagates the skip/no-skip decision from each target to its
+           corresponding shared layers so the layouts stay compatible.
+
+        Args:
+            merged: Current sorted skip-layer list (strings of layer indices).
+            num_layers: Total number of hidden layers.
+            layer_types: Per-layer type list from hf_text_config.layer_types.
+            num_kv_shared: Number of KV-sharing layers
+                (hf_text_config.num_kv_shared_layers).
+
+        Returns:
+            Updated sorted skip-layer list as strings.
+        """
+        import logging
+
+        _logger = logging.getLogger(__name__)
+
+        if num_kv_shared <= 0 or not layer_types:
+            return merged
+
+        first_shared = num_layers - num_kv_shared
+        skip_set = set(merged)
+
+        # 1) Find all unique KV-sharing target layers and skip them
+        #    to prevent error amplification through YOCO sharing.
+        target_set: set[str] = set()
+        for shared_idx in range(first_shared, num_layers):
+            current_type = layer_types[shared_idx]
+            for t in range(first_shared - 1, -1, -1):
+                if layer_types[t] == current_type:
+                    target_set.add(str(t))
+                    break
+        new_targets = target_set - skip_set
+        if new_targets:
+            skip_set |= new_targets
+            _logger.info(
+                "TQ: skipping KV-sharing target layers %s to "
+                "prevent error amplification in YOCO architecture",
+                sorted(new_targets, key=lambda x: int(x)),
+            )
+
+        # 2) Propagate skip/no-skip from target → shared layer.
+        for shared_idx in range(first_shared, num_layers):
+            current_type = layer_types[shared_idx]
+            target_idx = None
+            for t in range(first_shared - 1, -1, -1):
+                if layer_types[t] == current_type:
+                    target_idx = t
+                    break
+            if target_idx is None:
+                continue
+            target_skipped = str(target_idx) in skip_set
+            shared_skipped = str(shared_idx) in skip_set
+            if target_skipped and not shared_skipped:
+                skip_set.add(str(shared_idx))
+            elif not target_skipped and shared_skipped:
+                skip_set.discard(str(shared_idx))
+
+        result = sorted(skip_set, key=lambda x: int(x))
+        _logger.info("TQ: after KV-sharing alignment, skip list: %s", result)
+        return result
+
     @staticmethod
     def from_cache_dtype(cache_dtype: str, head_dim: int) -> TurboQuantConfig:
         """Create config from a named preset.
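A usage sketch for the new static method on a toy 6-layer YOCO model. The import path is the one this PR adds; the layer_types labels are arbitrary, since the method only compares them for equality. Layers 4 and 5 share KV with layers 2 and 3, so those targets join the skip list and the decision propagates forward:

from vllm.model_executor.layers.quantization.turboquant.config import (
    TurboQuantConfig,
)

skips = TurboQuantConfig.apply_yoco_skip_alignment(
    merged=["0", "5"],
    num_layers=6,
    layer_types=[
        "full_attention", "sliding_attention",
        "full_attention", "sliding_attention",
        "full_attention", "sliding_attention",
    ],
    num_kv_shared=2,
)
# Targets 2 and 3 are skipped (they feed shared layers 4 and 5), and the
# skip decision then propagates to the shared layers themselves.
assert skips == ["0", "2", "3", "4", "5"]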
47 changes: 39 additions & 8 deletions vllm/v1/attention/backends/turboquant_attn.py
@@ -262,6 +262,12 @@ def __init__(
         self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
         self.num_kv_groups = num_heads // self.num_kv_heads
         self.kv_cache_dtype = kv_cache_dtype
+        self.sliding_window = sliding_window
+        # window_size for flash_attn: [left, right]
+        if sliding_window is None:
+            self._fa_window_size: list[int] = [-1, -1]
+        else:
+            self._fa_window_size = [sliding_window - 1, 0]
 
         from vllm.model_executor.layers.quantization.turboquant.config import (
             TurboQuantConfig,
@@ -312,6 +318,7 @@ def _flash_attn_varlen(
             max_seqlen_k=max_seqlen_k,
             softmax_scale=self.scale,
             causal=True,
+            window_size=self._fa_window_size,
         )
         return flash_attn_varlen_func(
             q=q,
@@ -323,6 +330,7 @@ def _flash_attn_varlen(
             max_seqlen_k=max_seqlen_k,
             softmax_scale=self.scale,
             causal=True,
+            window_size=self._fa_window_size,
             fa_version=self.fa_version,
         )
@@ -627,14 +635,31 @@ def _prefill_attention(
                 q_t = q_seq.transpose(0, 1).contiguous()
                 k_t = k_seq.transpose(0, 1).contiguous()
                 v_t = v_seq.transpose(0, 1).contiguous()
-                out = F.scaled_dot_product_attention(
-                    q_t,
-                    k_t,
-                    v_t,
-                    is_causal=True,
-                    scale=self.scale,
-                    enable_gqa=use_gqa,
-                ).transpose(0, 1)
+                # Build sliding-window causal mask if needed
+                sw = self.sliding_window
+                if sw is not None:
+                    q_pos = torch.arange(q_len, device=query.device)
+                    k_pos = torch.arange(q_len, device=query.device)
+                    mask = (k_pos.unsqueeze(0) <= q_pos.unsqueeze(1)) & (
+                        q_pos.unsqueeze(1) - k_pos.unsqueeze(0) < sw
+                    )
+                    out = F.scaled_dot_product_attention(
+                        q_t,
+                        k_t,
+                        v_t,
+                        attn_mask=mask,
+                        scale=self.scale,
+                        enable_gqa=use_gqa,
+                    ).transpose(0, 1)
+                else:
+                    out = F.scaled_dot_product_attention(
+                        q_t,
+                        k_t,
+                        v_t,
+                        is_causal=True,
+                        scale=self.scale,
+                        enable_gqa=use_gqa,
+                    ).transpose(0, 1)
                 output[q_start:q_end] = out.to(query.dtype)
             else:
                 # Continuation chunk: tokens already stored to TQ cache
@@ -662,6 +687,7 @@ def _prefill_attention(
                     key_fp8=self.tq_config.key_fp8,
                     norm_correction=self.tq_config.norm_correction,
                     PiT=PiT,
+                    sliding_window=self.sliding_window,
                 )
             else:
                 # Large continuation: dequant cached K/V and use
@@ -814,6 +840,10 @@ def _continuation_prefill(
         q_pos = torch.arange(q_len, device=device).unsqueeze(1) + cached_len
         k_pos = torch.arange(seq_len, device=device).unsqueeze(0)
         mask = k_pos <= q_pos  # (q_len, seq_len)
+        # Apply sliding window constraint
+        sw = self.sliding_window
+        if sw is not None:
+            mask = mask & (q_pos - k_pos < sw)
         out = F.scaled_dot_product_attention(
             q_t,
             k_t,
@@ -874,5 +904,6 @@ def _decode_attention(
             lse_buf=lse_buf,
             buf_holder=layer,
             max_num_kv_splits=self.max_num_kv_splits,
+            sliding_window=self.sliding_window,
         )
         return result
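A standalone sanity check (plain PyTorch, written for this review) that the mask built in _prefill_attention matches FlashAttention's window_size=[sw - 1, 0] convention, i.e. query i attends key j iff 0 <= i - j <= sw - 1:

import torch

q_len, sw = 5, 3
q_pos = torch.arange(q_len)
k_pos = torch.arange(q_len)
mask = (k_pos.unsqueeze(0) <= q_pos.unsqueeze(1)) & (
    q_pos.unsqueeze(1) - k_pos.unsqueeze(0) < sw
)
# Row i is True exactly for the last `sw` keys up to and including i.
assert mask[4].tolist() == [False, False, True, True, True]
assert mask[1].tolist() == [True, True, False, False, False]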
17 changes: 14 additions & 3 deletions vllm/v1/attention/ops/triton_turboquant_decode.py
@@ -83,6 +83,7 @@ def _tq_decode_stage1(
     KEY_FP8: tl.constexpr,  # 1 if K is stored as FP8
     NORM_CORRECTION: tl.constexpr = 0,  # 1 = re-normalize centroids
     FP8_E4B15: tl.constexpr = 0,  # 1 = use e4b15 (Ampere/Ada), 0 = e4nv (Hopper+)
+    SLIDING_WINDOW: tl.constexpr = 0,  # 0 = full attention, >0 = window size
 ):
     bid = tl.program_id(0)  # batch index
     hid = tl.program_id(1)  # q_head index
@@ -93,9 +94,16 @@ def _tq_decode_stage1(
     # Sequence length for this batch
     seq_len = tl.load(Seq_lens_ptr + bid)
 
-    # KV split range
-    split_len = tl.cdiv(seq_len, NUM_KV_SPLITS)
-    split_start = split_len * sid
+    # Sliding window: only attend to the last SLIDING_WINDOW tokens
+    if SLIDING_WINDOW > 0:
+        effective_start = tl.maximum(0, seq_len - SLIDING_WINDOW)
+    else:
+        effective_start = 0
+    effective_len = seq_len - effective_start
+
+    # KV split range (over the effective window only)
+    split_len = tl.cdiv(effective_len, NUM_KV_SPLITS)
+    split_start = effective_start + split_len * sid
     split_end = tl.minimum(split_start + split_len, seq_len)
 
     if split_start >= split_end:
@@ -503,6 +511,7 @@ def triton_turboquant_decode_attention(
     lse_buf: torch.Tensor | None = None,
     buf_holder: Any = None,
     max_num_kv_splits: int = 32,  # fixed split count (must be constant for cudagraph)
+    sliding_window: int | None = None,
 ) -> torch.Tensor:
     """Launch fused TQ decode attention (Triton stage1 + stage2).
@@ -550,6 +559,7 @@ def triton_turboquant_decode_attention(
     # Stage 1: split-KV tiled attention scoring + value accumulation
     fp8_e4b15 = _use_fp8_e4b15(device.index or 0)
     BLOCK_KV = 4
+    _sliding_window = sliding_window if sliding_window is not None else 0
     grid = (B, Hq, NUM_KV_SPLITS)
     _tq_decode_stage1[grid](
         q_rot,
@@ -583,6 +593,7 @@ def triton_turboquant_decode_attention(
         KEY_FP8=1 if key_fp8 else 0,
         NORM_CORRECTION=1 if norm_correction else 0,
         FP8_E4B15=fp8_e4b15,
+        SLIDING_WINDOW=_sliding_window,
         num_warps=1,
         num_stages=1,
     )
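The split arithmetic in _tq_decode_stage1, replayed in plain Python with concrete numbers: a 32-token window over a 100-token sequence with NUM_KV_SPLITS=4 gives each split an 8-token slice of the window, never of the full sequence:

import math

seq_len, sliding_window, num_kv_splits = 100, 32, 4
effective_start = max(0, seq_len - sliding_window)    # 68
effective_len = seq_len - effective_start             # 32
split_len = math.ceil(effective_len / num_kv_splits)  # 8
ranges = [
    (effective_start + split_len * sid,
     min(effective_start + split_len * (sid + 1), seq_len))
    for sid in range(num_kv_splits)
]
assert ranges == [(68, 76), (76, 84), (84, 92), (92, 100)]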