diff --git a/docs/features/reasoning_outputs.md b/docs/features/reasoning_outputs.md
index 9d41185f1ae1..5c99eb20bc7d 100644
--- a/docs/features/reasoning_outputs.md
+++ b/docs/features/reasoning_outputs.md
@@ -249,7 +249,7 @@ Token counting starts from `reasoning_start_str`. Once the reasoning token count
To use this feature:
- `--reasoning-parser` enables reasoning extraction.
-- `--reasoning-config` defines the reasoning boundary tokens (e.g., `reasoning_start_str`, `reasoning_end_str`).
+- `--reasoning-config` defines the reasoning boundary tokens (e.g., `reasoning_start_str`, `reasoning_end_str`). If not set, vLLM attempts to derive these values automatically from the reasoning parser.
- `thinking_token_budget` (a sampling parameter) sets the per-request reasoning token limit.
If `thinking_token_budget` is not specified, no explicit reasoning limit is applied beyond normal generation constraints such as `max_tokens`.
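+
+For example, the per-request budget can be passed through the OpenAI client's `extra_body` (an illustrative sketch; the model name, prompt, and budget value are placeholders):
+
+```python
+from openai import OpenAI
+
+client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
+
+response = client.chat.completions.create(
+    model="Qwen/Qwen3-8B",  # placeholder: any model served with --reasoning-parser
+    messages=[{"role": "user", "content": "Why is the sky blue?"}],
+    extra_body={"thinking_token_budget": 128},  # per-request reasoning token limit
+)
+print(response.choices[0].message.reasoning_content)
+```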
diff --git a/tests/v1/entrypoints/openai/test_thinking_token_budget.py b/tests/entrypoints/openai/chat_completion/test_thinking_token_budget.py
similarity index 72%
rename from tests/v1/entrypoints/openai/test_thinking_token_budget.py
rename to tests/entrypoints/openai/chat_completion/test_thinking_token_budget.py
index b3f6d53ab1d4..d2db50082a55 100644
--- a/tests/v1/entrypoints/openai/test_thinking_token_budget.py
+++ b/tests/entrypoints/openai/chat_completion/test_thinking_token_budget.py
@@ -24,6 +24,24 @@ def server():
"--max-model-len",
"2048",
"--enforce-eager",
+ "--gpu-memory-utilization",
+ "0.4",
+ "--no-async-scheduling",
+ ]
+ with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+ yield remote_server
+
+
+@pytest.fixture(scope="module")
+def server_with_auto_reasoning_config():
+ args = [
+ "--reasoning-parser",
+ "qwen3",
+ "--max-model-len",
+ "2048",
+ "--enforce-eager",
+ "--gpu-memory-utilization",
+ "0.4",
"--no-async-scheduling",
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
@@ -31,12 +49,18 @@ def server():
@pytest_asyncio.fixture
-async def client(server):
- async with server.get_async_client() as async_client:
+async def client(request, server, server_with_auto_reasoning_config):
+ server_map = {
+ "default": server,
+ "auto_config": server_with_auto_reasoning_config,
+ }
+ target_server = server_map[request.param]
+ async with target_server.get_async_client() as async_client:
yield async_client
@pytest.mark.asyncio
+@pytest.mark.parametrize("client", ["default", "auto_config"], indirect=True)
async def test_thinking_token_budget_mixed_requests(client: openai.AsyncOpenAI):
"""Test that mixed requests (some with thinking_token_budget, some without)
complete successfully without errors."""
@@ -61,6 +85,7 @@ async def test_thinking_token_budget_mixed_requests(client: openai.AsyncOpenAI):
@pytest.mark.asyncio
+@pytest.mark.parametrize("client", ["default", "auto_config"], indirect=True)
async def test_thinking_token_budget_limits_reasoning(client: openai.AsyncOpenAI):
"""Test that thinking_token_budget limits the number of reasoning tokens.
@@ -82,6 +107,6 @@ async def test_thinking_token_budget_limits_reasoning(client: openai.AsyncOpenAI
reasoning_token_count += 1
assert reasoning_token_count == THINK_BUDGET, (
- f"reasoning tokens ({reasoning_token_count}) != "
+            f"reasoning tokens ({reasoning_token_count}) do not match "
f"thinking_token_budget ({THINK_BUDGET})"
)
diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py
index 3ebf9cc3713a..3da49f3f7574 100644
--- a/tests/kernels/attention/test_attention_selector.py
+++ b/tests/kernels/attention/test_attention_selector.py
@@ -368,6 +368,23 @@ def test_auto_backend_selection_behavior():
assert backend_auto.get_name() == backend_none.get_name()
+def test_flash_attn_rejects_int4_kv_cache(monkeypatch: pytest.MonkeyPatch):
+ try:
+ from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
+ except ImportError:
+ pytest.skip("vllm_flash_attn extension is not available in this env")
+
+ monkeypatch.setattr(
+ "vllm.v1.attention.backends.flash_attn.flash_attn_supports_fp8",
+ lambda: True,
+ )
+
+ assert FlashAttentionBackend.supports_kv_cache_dtype("fp8")
+ assert not FlashAttentionBackend.supports_kv_cache_dtype(
+ "int4_per_token_head"
+ )
+
+
@pytest.mark.parametrize(
"backend_name,flash_attn_version,should_succeed",
[
diff --git a/tests/kernels/attention/test_triton_int4_kv_cache.py b/tests/kernels/attention/test_triton_int4_kv_cache.py
new file mode 100644
index 000000000000..50d56cc47125
--- /dev/null
+++ b/tests/kernels/attention/test_triton_int4_kv_cache.py
@@ -0,0 +1,391 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+import torch
+
+from vllm.platforms import current_platform
+from vllm.utils.torch_utils import set_random_seed
+from vllm.v1.attention.backend import AttentionType
+from vllm.v1.attention.backends.triton_attn import (
+ TritonAttentionBackend,
+ TritonAttentionImpl,
+)
+from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
+ triton_reshape_and_cache_flash,
+)
+from vllm.v1.attention.ops.triton_unified_attention import unified_attention
+from vllm.v1.kv_cache_interface import FullAttentionSpec, KVQuantMode, get_kv_quant_mode
+
+
+def _expand_scale(scale: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
+ return scale.unsqueeze(-1) if scale.ndim == x.ndim - 1 else scale
+
+
+def _quantize_int4(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
+ return torch.round(x.float() / _expand_scale(scale, x)).clamp(-8, 7).to(
+ torch.int32
+ )
+
+
+def _dequantize_int4_values(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
+ return q.to(torch.float32) * _expand_scale(scale, q)
+
+
+def _pack_int4(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
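+    # Pack two's-complement int4 values two per byte: even indices go to the
+    # low nibble, odd indices to the high nibble, matching the cache kernel's
+    # nibble layout.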
+ q = _quantize_int4(x, scale)
+ q_u4 = torch.where(q < 0, q + 16, q).to(torch.uint8)
+ if q_u4.shape[-1] % 2:
+ q_u4 = torch.nn.functional.pad(q_u4, (0, 1))
+ return q_u4[..., 0::2] | (q_u4[..., 1::2] << 4)
+
+
+def _dequantize_packed_int4(
+ packed: torch.Tensor,
+ scale: torch.Tensor,
+ head_size: int,
+) -> torch.Tensor:
+ low = (packed & 0x0F).to(torch.int16)
+ high = (packed >> 4).to(torch.int16)
+ low = torch.where(low >= 8, low - 16, low)
+ high = torch.where(high >= 8, high - 16, high)
+
+ out = torch.empty(
+ *packed.shape[:-1],
+ packed.shape[-1] * 2,
+ device=packed.device,
+ dtype=torch.float32,
+ )
+ scale_expanded = _expand_scale(scale.to(torch.float32), out[..., 0::2])
+ out[..., 0::2] = low.to(torch.float32) * scale_expanded
+ out[..., 1::2] = high.to(torch.float32) * scale_expanded
+ return out[..., :head_size]
+
+
+def _int4_scales_per_token_head(x: torch.Tensor) -> torch.Tensor:
+ return x.abs().amax(dim=-1).float().clamp(min=1e-6) / 7.0
+
+
+def _get_int4_inline_scale_views(
+ kv_cache: torch.Tensor,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
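+    """Build float32 views over the inline per-(token, head) scale slots of an
+    int4-packed KV cache whose last dim is the padded packed head size plus 4
+    scale bytes, mirroring how the backend derives its scale caches. The scale
+    views are reset to 1.0."""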
+ key_cache, value_cache = kv_cache.unbind(1)
+ _, block_size, num_heads, padded_hs = key_cache.shape
+ scale_pad = 4
+ data_hs_padded = padded_hs - scale_pad
+ assert data_hs_padded % 4 == 0
+
+ raw = kv_cache.untyped_storage()
+ base_f32 = torch.tensor([], dtype=torch.float32, device=kv_cache.device).set_(raw)
+
+ kv_half_bytes = block_size * num_heads * padded_hs
+ full_block_f32 = 2 * kv_half_bytes // 4
+ slot_f32 = num_heads * padded_hs // 4
+ head_f32 = padded_hs // 4
+ scale_off_f32 = data_hs_padded // 4
+
+ k_scale_cache = torch.as_strided(
+ base_f32,
+ size=(kv_cache.shape[0], block_size, num_heads),
+ stride=(full_block_f32, slot_f32, head_f32),
+ storage_offset=scale_off_f32,
+ )
+ v_scale_cache = torch.as_strided(
+ base_f32,
+ size=(kv_cache.shape[0], block_size, num_heads),
+ stride=(full_block_f32, slot_f32, head_f32),
+ storage_offset=(kv_half_bytes // 4) + scale_off_f32,
+ )
+ k_scale_cache.fill_(1.0)
+ v_scale_cache.fill_(1.0)
+ return key_cache, value_cache, k_scale_cache, v_scale_cache
+
+
+def _ref_paged_attn(
+ query: torch.Tensor,
+ key_cache: torch.Tensor,
+ value_cache: torch.Tensor,
+ query_lens: list[int],
+ kv_lens: list[int],
+ block_tables: torch.Tensor,
+ scale: float,
+) -> torch.Tensor:
+ outputs: list[torch.Tensor] = []
+ start_idx = 0
+ block_tables = block_tables.cpu().numpy()
+ _, block_size, num_kv_heads, head_size = key_cache.shape
+
+ for i, query_len in enumerate(query_lens):
+ kv_len = kv_lens[i]
+ q = query[start_idx : start_idx + query_len] * scale
+ num_kv_blocks = (kv_len + block_size - 1) // block_size
+ block_indices = block_tables[i, :num_kv_blocks]
+
+ k = key_cache[block_indices].view(-1, num_kv_heads, head_size)[:kv_len]
+ v = value_cache[block_indices].view(-1, num_kv_heads, head_size)[:kv_len]
+ if q.shape[1] != k.shape[1]:
+ k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
+ v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)
+
+ scores = torch.einsum("qhd,khd->hqk", q, k).float()
+ mask = torch.triu(
+ torch.ones(query_len, kv_len, device=q.device),
+ diagonal=kv_len - query_len + 1,
+ ).bool()
+ scores.masked_fill_(mask, float("-inf"))
+ probs = torch.softmax(scores, dim=-1).to(v.dtype)
+ outputs.append(torch.einsum("hqk,khd->qhd", probs, v))
+ start_idx += query_len
+
+ return torch.cat(outputs, dim=0)
+
+
+def test_int4_kv_cache_shape_and_page_size():
+ assert get_kv_quant_mode("int4_per_token_head") == KVQuantMode.INT4_PER_TOKEN_HEAD
+ assert TritonAttentionBackend.get_kv_cache_shape(
+ 8, 16, 4, 128, "int4_per_token_head"
+ ) == (
+ 8,
+ 2,
+ 16,
+ 4,
+ 68,
+ )
+
+ spec = FullAttentionSpec(
+ block_size=16,
+ num_kv_heads=4,
+ head_size=128,
+ dtype=torch.uint8,
+ kv_quant_mode=KVQuantMode.INT4_PER_TOKEN_HEAD,
+ )
+ assert spec.page_size_bytes == 16 * 4 * (68 + 68)
+
+
+@pytest.mark.skipif(not current_platform.is_cuda(), reason="requires CUDA")
+@torch.inference_mode()
+def test_int4_inline_scale_views_respect_padded_block_stride():
+ device = "cuda"
+ num_blocks = 3
+ block_size = 16
+ num_kv_heads = 4
+ head_size = 128
+
+ kv_cache_shape = TritonAttentionBackend.get_kv_cache_shape(
+ num_blocks, block_size, num_kv_heads, head_size, "int4_per_token_head"
+ )
+    contiguous_block_elems = (
+        kv_cache_shape[1] * kv_cache_shape[2] * kv_cache_shape[3] * kv_cache_shape[4]
+    )
+ padded_block_elems = contiguous_block_elems + 128
+ raw = torch.zeros(num_blocks * padded_block_elems, dtype=torch.uint8, device=device)
+
+ half_stride = kv_cache_shape[2] * kv_cache_shape[3] * kv_cache_shape[4]
+ slot_stride = kv_cache_shape[3] * kv_cache_shape[4]
+ head_stride = kv_cache_shape[4]
+ kv_cache = torch.as_strided(
+ raw,
+ size=kv_cache_shape,
+ stride=(padded_block_elems, half_stride, slot_stride, head_stride, 1),
+ )
+
+ impl = TritonAttentionImpl(
+ num_heads=num_kv_heads,
+ head_size=head_size,
+ scale=head_size**-0.5,
+ num_kv_heads=num_kv_heads,
+ alibi_slopes=None,
+ sliding_window=None,
+ kv_cache_dtype="int4_per_token_head",
+ attn_type=AttentionType.DECODER,
+ )
+ impl._ensure_scale_caches(kv_cache)
+ assert impl._k_scale_cache is not None
+ assert impl._v_scale_cache is not None
+
+ for blk in range(num_blocks):
+ impl._k_scale_cache[blk].fill_(10.0 + blk)
+ impl._v_scale_cache[blk].fill_(20.0 + blk)
+
+    base_f32 = torch.tensor([], dtype=torch.float32, device=device).set_(
+        kv_cache.untyped_storage()
+    )
+ block_f32 = padded_block_elems // 4
+ half_f32 = half_stride // 4
+ slot_f32 = slot_stride // 4
+ head_f32 = head_stride // 4
+ scale_off_f32 = (kv_cache_shape[-1] - 4) // 4
+
+ for blk in range(num_blocks):
+ for slot in (0, block_size - 1):
+ for head in (0, num_kv_heads - 1):
+                k_idx = (
+                    blk * block_f32 + slot * slot_f32 + head * head_f32 + scale_off_f32
+                )
+                v_idx = (
+                    blk * block_f32
+                    + half_f32
+                    + slot * slot_f32
+                    + head * head_f32
+                    + scale_off_f32
+                )
+ assert base_f32[k_idx].item() == pytest.approx(10.0 + blk)
+ assert base_f32[v_idx].item() == pytest.approx(20.0 + blk)
+
+
+@pytest.mark.skipif(not current_platform.is_cuda(), reason="requires CUDA")
+@torch.inference_mode()
+def test_triton_reshape_and_cache_flash_int4():
+ device = "cuda"
+ set_random_seed(0)
+ torch.set_default_device(device)
+
+ num_tokens = 23
+ num_heads = 4
+ head_size = 97
+ block_size = 16
+ num_blocks = 8
+ packed_head_size = (head_size + 1) // 2
+ packed_head_size_padded = ((packed_head_size + 3) // 4) * 4
+
+ key = torch.randn(num_tokens, num_heads, head_size, dtype=torch.bfloat16)
+ value = torch.randn_like(key)
+ slot_mapping = torch.randperm(num_blocks * block_size, device=device)[:num_tokens]
+
+ kv_cache = torch.zeros(
+ num_blocks,
+ 2,
+ block_size,
+ num_heads,
+ packed_head_size_padded + 4,
+ dtype=torch.uint8,
+ device=device,
+ )
+ key_cache, value_cache, k_scale_cache, v_scale_cache = _get_int4_inline_scale_views(
+ kv_cache
+ )
+
+ triton_reshape_and_cache_flash(
+ key,
+ value,
+ key_cache,
+ value_cache,
+ slot_mapping,
+ "int4_per_token_head",
+ torch.tensor(1.0, device=device),
+ torch.tensor(1.0, device=device),
+ k_scale_cache,
+ v_scale_cache,
+ )
+
+ got_key = _dequantize_packed_int4(
+ key_cache[..., :packed_head_size], k_scale_cache, head_size
+ )
+ got_value = _dequantize_packed_int4(
+ value_cache[..., :packed_head_size], v_scale_cache, head_size
+ )
+
+ expected_key = torch.zeros(
+ num_blocks, block_size, num_heads, head_size, device=device, dtype=torch.float32
+ )
+ expected_value = torch.zeros_like(expected_key)
+ expected_k_scales = torch.ones(
+ num_blocks, block_size, num_heads, device=device, dtype=torch.float32
+ )
+ expected_v_scales = torch.ones_like(expected_k_scales)
+ key_scales = _int4_scales_per_token_head(key)
+ value_scales = _int4_scales_per_token_head(value)
+ key_roundtrip = _dequantize_int4_values(
+ _quantize_int4(key, key_scales), key_scales
+ )
+ value_roundtrip = _dequantize_int4_values(
+ _quantize_int4(value, value_scales), value_scales
+ )
+
+ for token_idx, slot in enumerate(slot_mapping.cpu().tolist()):
+ block_idx = slot // block_size
+ block_offset = slot % block_size
+ expected_key[block_idx, block_offset] = key_roundtrip[token_idx]
+ expected_value[block_idx, block_offset] = value_roundtrip[token_idx]
+ expected_k_scales[block_idx, block_offset] = key_scales[token_idx]
+ expected_v_scales[block_idx, block_offset] = value_scales[token_idx]
+
+ atol = max(float(key_scales.max().item()), float(value_scales.max().item())) + 1e-4
+ torch.testing.assert_close(got_key, expected_key, atol=atol, rtol=0.0)
+ torch.testing.assert_close(got_value, expected_value, atol=atol, rtol=0.0)
+ torch.testing.assert_close(k_scale_cache, expected_k_scales, atol=1e-4, rtol=1e-4)
+ torch.testing.assert_close(v_scale_cache, expected_v_scales, atol=1e-4, rtol=1e-4)
+
+
+@pytest.mark.skipif(not current_platform.is_cuda(), reason="requires CUDA")
+@torch.inference_mode()
+def test_triton_unified_attention_int4():
+ device = "cuda"
+ set_random_seed(1)
+ torch.set_default_device(device)
+
+ query_lens = [1, 5, 3]
+ kv_lens = [33, 18, 27]
+ num_blocks = 64
+ block_size = 16
+ num_query_heads = 4
+ num_kv_heads = 2
+ head_size = 128
+
+ query = torch.randn(
+ sum(query_lens),
+ num_query_heads,
+ head_size,
+ dtype=torch.bfloat16,
+ device=device,
+ )
+ key_cache = torch.randn(
+ num_blocks,
+ block_size,
+ num_kv_heads,
+ head_size,
+ dtype=torch.bfloat16,
+ device=device,
+ )
+ value_cache = torch.randn_like(key_cache)
+
+ k_scale = _int4_scales_per_token_head(key_cache)
+ v_scale = _int4_scales_per_token_head(value_cache)
+ key_cache_int4 = _pack_int4(key_cache, k_scale)
+ value_cache_int4 = _pack_int4(value_cache, v_scale)
+
+ cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32, device=device)
+ cu_query_lens = cu_query_lens.cumsum(dim=0, dtype=torch.int32)
+ kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32, device=device)
+
+ max_kv_len = max(kv_lens)
+ max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
+ block_tables = torch.randint(
+ 0,
+ num_blocks,
+ (len(query_lens), max_num_blocks_per_seq),
+ dtype=torch.int32,
+ device=device,
+ )
+
+ output = torch.empty_like(query)
+ unified_attention(
+ q=query,
+ k=key_cache_int4,
+ v=value_cache_int4,
+ out=output,
+ cu_seqlens_q=cu_query_lens,
+ seqused_k=kv_lens_tensor,
+ max_seqlen_q=max(query_lens),
+ max_seqlen_k=max_kv_len,
+ softmax_scale=head_size**-0.5,
+ causal=True,
+ window_size=(-1, -1),
+ block_table=block_tables,
+ softcap=0.0,
+ q_descale=None,
+ k_descale=None,
+ v_descale=None,
+ kv_quant_mode=KVQuantMode.INT4_PER_TOKEN_HEAD,
+ k_scale_cache=k_scale,
+ v_scale_cache=v_scale,
+ )
+
+ ref_output = _ref_paged_attn(
+ query=query.float(),
+ key_cache=_dequantize_packed_int4(key_cache_int4, k_scale, head_size),
+ value_cache=_dequantize_packed_int4(value_cache_int4, v_scale, head_size),
+ query_lens=query_lens,
+ kv_lens=kv_lens,
+ block_tables=block_tables,
+ scale=head_size**-0.5,
+ )
+ torch.testing.assert_close(output.float(), ref_output.float(), atol=2e-2, rtol=2e-2)
diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py
index d8ecf28cbed1..fafb2952d51a 100644
--- a/tests/v1/core/test_kv_cache_utils.py
+++ b/tests/v1/core/test_kv_cache_utils.py
@@ -24,6 +24,7 @@
BlockHash,
FreeKVCacheBlockQueue,
KVCacheBlock,
+ estimate_token_capacity_for_kv_cache_config,
estimate_max_model_len,
generate_block_hash_extra_keys,
generate_scheduler_kv_cache_config,
@@ -43,6 +44,7 @@
KVCacheGroupSpec,
KVCacheSpec,
KVCacheTensor,
+ KVQuantMode,
MambaSpec,
MLAAttentionSpec,
SlidingWindowSpec,
@@ -107,6 +109,7 @@ def new_kv_cache_spec(
num_kv_heads=2,
head_size=64,
dtype=torch.float32,
+ kv_quant_mode=KVQuantMode.NONE,
page_size_padded=None,
sliding_window=None,
attention_chunk_size=None,
@@ -116,6 +119,7 @@ def new_kv_cache_spec(
num_kv_heads=num_kv_heads,
head_size=head_size,
dtype=dtype,
+ kv_quant_mode=kv_quant_mode,
page_size_padded=page_size_padded,
sliding_window=sliding_window,
attention_chunk_size=attention_chunk_size,
@@ -127,6 +131,7 @@ def new_sliding_window_spec(
num_kv_heads=2,
head_size=64,
dtype=torch.float32,
+ kv_quant_mode=KVQuantMode.NONE,
page_size_padded=None,
sliding_window=1,
):
@@ -135,6 +140,7 @@ def new_sliding_window_spec(
num_kv_heads=num_kv_heads,
head_size=head_size,
dtype=dtype,
+ kv_quant_mode=kv_quant_mode,
page_size_padded=page_size_padded,
sliding_window=sliding_window,
)
@@ -1418,6 +1424,73 @@ def test_get_max_concurrency_for_kv_cache_config():
assert max_concurrency_hybrid_model == 3
+def test_estimate_token_capacity_for_kv_cache_config_reports_total_cached_tokens():
+ model_id = "Qwen/Qwen1.5-7B"
+ max_model_len = 32768
+ model_config = ModelConfig(
+ model_id,
+ runner="generate",
+ dtype="float16",
+ max_model_len=max_model_len,
+ )
+ scheduler_config = SchedulerConfig(
+ max_num_batched_tokens=1024,
+ enable_chunked_prefill=True,
+ max_model_len=model_config.max_model_len,
+ is_encoder_decoder=model_config.is_encoder_decoder,
+ )
+ vllm_config = VllmConfig(
+ model_config=model_config,
+ scheduler_config=scheduler_config,
+ )
+
+ full_attention_spec = FullAttentionSpec(
+ block_size=16,
+ num_kv_heads=4,
+ head_size=128,
+ dtype=torch.uint8,
+ kv_quant_mode=KVQuantMode.INT4_PER_TOKEN_HEAD,
+ )
+ sliding_window_spec = SlidingWindowSpec(
+ block_size=16,
+ num_kv_heads=4,
+ head_size=128,
+ dtype=torch.uint8,
+ kv_quant_mode=KVQuantMode.INT4_PER_TOKEN_HEAD,
+ sliding_window=1024,
+ )
+
+ num_blocks = (1024 + 64) * 3
+ group_size = 32
+ page_size = full_attention_spec.page_size_bytes
+ kv_cache_config = KVCacheConfig(
+ num_blocks=num_blocks,
+ kv_cache_tensors=[
+ KVCacheTensor(size=page_size * num_blocks, shared_by=[])
+ for _ in range(group_size)
+ ],
+ kv_cache_groups=[
+            KVCacheGroupSpec(
+                [f"full_{i}" for i in range(group_size)],
+                full_attention_spec,
+            ),
+ KVCacheGroupSpec(
+ [f"sw_{i}" for i in range(group_size)],
+ sliding_window_spec,
+ ),
+ ],
+ )
+
+ expected_capacity = int(
+ get_max_concurrency_for_kv_cache_config(vllm_config, kv_cache_config)
+ * vllm_config.model_config.max_model_len
+ )
+
+ assert (
+ estimate_token_capacity_for_kv_cache_config(vllm_config, kv_cache_config)
+ == expected_capacity
+ )
+
+ assert expected_capacity > vllm_config.model_config.max_model_len
+
+
def test_allocate_with_lookahead():
"""Verify that lookahead tokens correctly affect block allocation"""
block_size = 4
@@ -1752,6 +1825,30 @@ def test_get_kv_cache_config_one_worker():
vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 2 * 32]
)[0]
+ # The same non-divisible hybrid layout can be supported for runtime-scale
+ # quantized KV cache by padding the smaller page up to the larger one.
+ kv_cache_specs_hybrid_runtime_quant = {
+ "layer_1": new_kv_cache_spec(
+            head_size=64,
+            dtype=torch.uint8,
+            kv_quant_mode=KVQuantMode.INT4_PER_TOKEN_HEAD,
+ ),
+ "layer_2": new_sliding_window_spec(
+            head_size=96,
+            dtype=torch.uint8,
+            kv_quant_mode=KVQuantMode.INT4_PER_TOKEN_HEAD,
+ ),
+ }
+ kv_cache_config_hybrid_runtime_quant = get_kv_cache_configs(
+ vllm_config,
+ [kv_cache_specs_hybrid_runtime_quant],
+ [mem_per_block_per_layer * 2 * 32],
+ )[0]
+ assert len(kv_cache_config_hybrid_runtime_quant.kv_cache_groups) == 2
+ assert all(
+ group.kv_cache_spec.block_size == 16
+ for group in kv_cache_config_hybrid_runtime_quant.kv_cache_groups
+ )
+    page_sizes = {
+        group.kv_cache_spec.page_size_bytes
+        for group in kv_cache_config_hybrid_runtime_quant.kv_cache_groups
+    }
+    assert len(page_sizes) == 1
+
# Test num_gpu_blocks_override
vllm_config.cache_config.num_gpu_blocks_override = 16
kv_cache_config_override_blocks = get_kv_cache_configs(
diff --git a/tests/v1/logits_processors/test_correctness.py b/tests/v1/logits_processors/test_correctness.py
index bf29793710a4..9ee6a70abe4c 100644
--- a/tests/v1/logits_processors/test_correctness.py
+++ b/tests/v1/logits_processors/test_correctness.py
@@ -106,6 +106,7 @@ class MockReasoningConfig:
reasoning_start_token_ids = [THINK_START_TOKEN_ID]
reasoning_end_token_ids = [THINK_END_TOKEN_ID]
+ enabled = True
def _generate_fake_sampling_metadata(
diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py
index d59b74782be2..8c2659b9c7e6 100644
--- a/vllm/_aiter_ops.py
+++ b/vllm/_aiter_ops.py
@@ -336,9 +336,13 @@ def _rocm_aiter_fused_topk_fake(
router_logits: torch.Tensor,
top_k: int,
gate_up: bool,
-) -> None:
- # tuple[torch.Tensor, torch.Tensor]:
- pass
+) -> tuple[torch.Tensor, torch.Tensor]:
+ num_tokens = x.shape[0]
+ topk_weights = torch.empty(
+ (num_tokens, top_k), dtype=torch.float32, device=x.device
+ )
+ topk_indices = torch.empty((num_tokens, top_k), dtype=torch.int32, device=x.device)
+ return topk_weights, topk_indices
# Cache whether aiter supports FP8 MLA parameters
@@ -1918,7 +1922,7 @@ def is_triton_gemm_afp4wfp4_presh_ws_tuned(n: int, k: int) -> bool:
@staticmethod
def shuffle_weight(
- self, tensor: torch.Tensor, layout: tuple[int, int] = (16, 16)
+ tensor: torch.Tensor, layout: tuple[int, int] = (16, 16)
) -> torch.Tensor:
from aiter.ops.shuffle import shuffle_weight
diff --git a/vllm/config/cache.py b/vllm/config/cache.py
index cd1554590ea3..e1d5e3363492 100644
--- a/vllm/config/cache.py
+++ b/vllm/config/cache.py
@@ -11,6 +11,7 @@
from vllm.utils.torch_utils import (
is_quantized_kv_cache,
kv_cache_uses_per_token_head_scales,
+ kv_cache_uses_runtime_scale_cache,
)
logger = init_logger(__name__)
@@ -19,6 +20,7 @@
"auto",
"float16",
"bfloat16",
+ "int4_per_token_head",
"fp8",
"fp8_e4m3",
"fp8_e5m2",
@@ -242,11 +244,11 @@ def _warn_deprecated_calculate_kv_scales(cls, calculate_kv_scales: bool) -> bool
@field_validator("cache_dtype", mode="after")
@classmethod
def _validate_cache_dtype(cls, cache_dtype: CacheDType) -> CacheDType:
- if kv_cache_uses_per_token_head_scales(cache_dtype):
+ if kv_cache_uses_runtime_scale_cache(cache_dtype):
logger.info(
"Using %s data type to store kv cache. It reduces the GPU "
"memory footprint and boosts the performance. "
- "Dynamic per-token-head scales will be computed at runtime.",
+ "Runtime scales will be computed dynamically.",
str(cache_dtype),
)
elif is_quantized_kv_cache(cache_dtype):
diff --git a/vllm/config/reasoning.py b/vllm/config/reasoning.py
index be1e2b6da58d..ff5546e05ebf 100644
--- a/vllm/config/reasoning.py
+++ b/vllm/config/reasoning.py
@@ -5,6 +5,7 @@
from vllm.config.model import ModelConfig
from vllm.config.utils import config
+from vllm.reasoning import ReasoningParserManager
from vllm.tokenizers import cached_tokenizer_from_config
@@ -18,11 +19,11 @@ class ReasoningConfig:
`initialize_token_ids` and are not intended to be set directly.
"""
- # NOTE: These parameters are temporary, the intent is to derive them
- # automatically from the reasoning parser in a future version.
- reasoning_start_str: str = ""
+ reasoning_parser: str = ""
+ """The name of the ReasoningParser to use for this model."""
+ reasoning_start_str: str = ""
"""String that indicates the start of reasoning."""
- reasoning_end_str: str = ""
+ reasoning_end_str: str = ""
"""String that indicates the end of reasoning content."""
_reasoning_start_token_ids: list[int] | None = field(
@@ -36,6 +37,16 @@ class ReasoningConfig:
"""Private backing field for `reasoning_end_token_ids`. Set by
`initialize_token_ids`. Not intended to be configured directly."""
+ _enabled: bool = field(default=False, init=False, repr=False)
+ """Private field indicating whether reasoning token IDs have been initialized.
+ Set to True by `initialize_token_ids` once token IDs are initialized."""
+
+ @property
+ def enabled(self) -> bool:
+ """Returns True if reasoning is enabled (i.e. if token IDs have been
+ initialized), False otherwise."""
+ return self._enabled
+
@property
def reasoning_start_token_ids(self) -> list[int] | None:
"""Token IDs derived from `reasoning_start_str`. Set automatically by
@@ -54,15 +65,36 @@ def initialize_token_ids(self, model_config: ModelConfig) -> None:
self._reasoning_start_token_ids is not None
and self._reasoning_end_token_ids is not None
):
- return
+ self._enabled = True
+ return # Already initialized
tokenizer = cached_tokenizer_from_config(model_config=model_config)
+ reasoning_start_str = self.reasoning_start_str
+ reasoning_end_str = self.reasoning_end_str
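+        # Fall back to the parser-provided delimiters only for strings that were
+        # not set explicitly on this config.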
+        if self.reasoning_parser and (
+ not reasoning_start_str or not reasoning_end_str
+ ):
+ parser_cls = ReasoningParserManager.get_reasoning_parser(
+ self.reasoning_parser
+ )
+ reasoning_parser = parser_cls(tokenizer)
+ start_token = reasoning_parser.reasoning_start_str
+ if start_token and not reasoning_start_str:
+ reasoning_start_str = start_token
+ end_token = reasoning_parser.reasoning_end_str
+ if end_token and not reasoning_end_str:
+ reasoning_end_str = end_token
+
+ if not reasoning_start_str or not reasoning_end_str:
+ # If we don't have valid strings to tokenize,
+ # we can't initialize the token IDs.
+ return
self._reasoning_start_token_ids = tokenizer.encode(
- self.reasoning_start_str, add_special_tokens=False
+ reasoning_start_str, add_special_tokens=False
)
self._reasoning_end_token_ids = tokenizer.encode(
- self.reasoning_end_str, add_special_tokens=False
+ reasoning_end_str, add_special_tokens=False
)
if not self._reasoning_start_token_ids or not self._reasoning_end_token_ids:
@@ -72,3 +104,4 @@ def initialize_token_ids(self, model_config: ModelConfig) -> None:
f"reasoning_end_str='{self.reasoning_end_str}'. "
"Ensure the strings are valid tokens in the model's vocabulary."
)
+ self._enabled = True
diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index 6229b44d52a8..3b8431e9530a 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -1210,6 +1210,12 @@ def has_blocked_weights():
if self.reasoning_config is not None and self.model_config is not None:
self.reasoning_config.initialize_token_ids(self.model_config)
+ if not self.reasoning_config.enabled:
+ logger.warning_once(
+                    "Could not automatically initialize reasoning token IDs. "
+                    "Check that your reasoning parser implements the "
+                    "`reasoning_start_str` and `reasoning_end_str` properties, "
+                    "or set them explicitly via `--reasoning-config`."
+ )
# Hybrid KV cache manager (HMA) runtime rules:
# - Explicit enable (--no-disable-kv-cache-manager): error if runtime
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 55c87bf356c5..c9b90848ff04 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1591,7 +1591,7 @@ def create_engine_config(
self._set_default_max_num_seqs_and_batched_tokens_args(
usage_context, model_config
)
-
+ self._set_default_reasoning_config_args()
sliding_window: int | None = None
if not is_interleaved(model_config.hf_text_config):
# Only set CacheConfig.sliding_window if the model is all sliding
@@ -2233,6 +2233,13 @@ def _set_default_chunked_prefill_and_prefix_caching_args(
)
self.enable_prefix_caching = False
+ def _set_default_reasoning_config_args(self):
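+        """Propagate --reasoning-parser into the reasoning config so that the
+        boundary strings can be auto-derived when --reasoning-config is unset."""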
+ if not self.reasoning_parser:
+ return
+ if self.reasoning_config is None:
+ self.reasoning_config = ReasoningConfig()
+ self.reasoning_config.reasoning_parser = self.reasoning_parser
+
def _set_default_max_num_seqs_and_batched_tokens_args(
self,
usage_context: UsageContext | None,
diff --git a/vllm/model_executor/kernels/linear/__init__.py b/vllm/model_executor/kernels/linear/__init__.py
index f718c239984b..6a026eef839d 100644
--- a/vllm/model_executor/kernels/linear/__init__.py
+++ b/vllm/model_executor/kernels/linear/__init__.py
@@ -177,6 +177,21 @@
],
}
+_POSSIBLE_WFP8A16_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] = {
+ PlatformEnum.CUDA: [
+ MarlinFP8ScaledMMLinearKernel,
+ ],
+ PlatformEnum.ROCM: [
+ # To be added
+ ],
+ PlatformEnum.CPU: [
+ # To be added
+ ],
+ PlatformEnum.XPU: [
+ # To be added
+ ],
+}
+
# in priority/performance order (when available)
_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = {
PlatformEnum.CUDA: [
@@ -463,6 +478,41 @@ def choose_mp_linear_kernel(
)
+def init_wfp8_a16_linear_kernel(
+ weight_quant_key: QuantKey,
+ activation_quant_key: QuantKey,
+ weight_shape: tuple[int, int],
+ input_dtype: torch.dtype,
+ out_dtype: torch.dtype,
+ force_kernel: type[FP8ScaledMMLinearKernel] | None = None,
+ module_name: str | None = None,
+) -> FP8ScaledMMLinearKernel:
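+    """Select and instantiate an FP8-weight / 16-bit-activation scaled-MM
+    linear kernel for the given quantization config from the per-platform
+    candidate list."""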
+ config = FP8ScaledMMLinearLayerConfig(
+ weight_quant_key=weight_quant_key,
+ activation_quant_key=activation_quant_key,
+ weight_shape=weight_shape,
+ input_dtype=input_dtype,
+ out_dtype=out_dtype,
+ )
+
+ kernel_type = choose_scaled_mm_linear_kernel(
+ config, _POSSIBLE_WFP8A16_KERNELS, force_kernel=force_kernel
+ )
+
+ if module_name:
+ logger.info_once(
+ "Selected %s for %s",
+ kernel_type.__name__,
+ module_name,
+ scope="global",
+ )
+
+ return kernel_type(
+ config,
+ layer_param_names=["weight", "weight_scale", "input_scale", "input_scale_ub"],
+ )
+
+
# Maps VLLM_NVFP4_GEMM_BACKEND env var values to kernel classes.
_NVFP4_BACKEND_TO_KERNEL: dict[str, type[NvFp4LinearKernel]] = {
"flashinfer-cutlass": FlashInferCutlassNvFp4LinearKernel,
@@ -588,6 +638,7 @@ def register_linear_kernel(
"init_nvfp4_linear_kernel",
"choose_mp_linear_kernel",
"register_linear_kernel",
+ "init_wfp8_a16_linear_kernel",
"FP8ScaledMMLinearKernel",
"Int8ScaledMMLinearKernel",
"ScaledMMLinearKernel",
diff --git a/vllm/model_executor/layers/mamba/gdn_linear_attn.py b/vllm/model_executor/layers/mamba/gdn_linear_attn.py
index aec855d9aeb1..3ac90c942e73 100644
--- a/vllm/model_executor/layers/mamba/gdn_linear_attn.py
+++ b/vllm/model_executor/layers/mamba/gdn_linear_attn.py
@@ -702,19 +702,33 @@ def _warmup_prefill_kernels(self, mixed_qkv: torch.Tensor) -> None:
num_v_heads = self.num_v_heads // self.tp_size
_, state_dtype = self.get_state_dtype()
- # All kernels use BT = chunk_size (FLA_CHUNK_SIZE4), so a single pass with
- # T = chunk_size is sufficient to populate every autotuner cache.
+ # All kernels use BT = chunk_size, so a single pass with T = chunk_size
+ # is sufficient to populate every autotuner cache. Mirror the real
+ # prefill path here: build q/k/v/g/beta via fused_post_conv_prep and
+ # then run chunk_gated_delta_rule with in-kernel L2 norm disabled.
T = FLA_CHUNK_SIZE
- q = torch.randn(1, T, num_k_heads, self.head_k_dim, device=device, dtype=dtype)
- k = torch.randn(1, T, num_k_heads, self.head_k_dim, device=device, dtype=dtype)
- v = torch.randn(1, T, num_v_heads, self.head_v_dim, device=device, dtype=dtype)
- # NOTE: g and beta must have the same dtypes as during
- # inference, so we construct them with the same function
- # (fused_gdn_gating). dummy_a and dummy_b are throwaway
- # inputs required by that function.
+ dummy_mixed_qkv = torch.randn(
+ T, mixed_qkv.shape[-1], device=device, dtype=dtype
+ )
dummy_a = torch.randn(T, num_v_heads, device=device, dtype=dtype)
dummy_b = torch.randn(T, num_v_heads, device=device, dtype=dtype)
- g, beta = fused_gdn_gating(self.A_log, dummy_a, dummy_b, self.dt_bias)
+ q, k, v, g, beta = fused_post_conv_prep(
+ conv_output=dummy_mixed_qkv,
+ a=dummy_a,
+ b=dummy_b,
+ A_log=self.A_log,
+ dt_bias=self.dt_bias,
+ num_k_heads=num_k_heads,
+ head_k_dim=self.head_k_dim,
+ head_v_dim=self.head_v_dim,
+ apply_l2norm=True,
+ output_g_exp=False,
+ )
+ q = q.unsqueeze(0)
+ k = k.unsqueeze(0)
+ v = v.unsqueeze(0)
+ g = g.unsqueeze(0)
+ beta = beta.unsqueeze(0)
state = torch.zeros(
1,
num_v_heads,
@@ -735,7 +749,7 @@ def _warmup_prefill_kernels(self, mixed_qkv: torch.Tensor) -> None:
initial_state=state,
output_final_state=True,
cu_seqlens=cu_seqlens,
- use_qk_l2norm_in_kernel=True,
+ use_qk_l2norm_in_kernel=False,
)
except Exception:
logger.warning(
@@ -753,7 +767,7 @@ def _warmup_prefill_kernels(self, mixed_qkv: torch.Tensor) -> None:
self.prefix,
)
finally:
- del q, k, v, dummy_a, dummy_b, g, beta, state, cu_seqlens
+ del dummy_mixed_qkv, q, k, v, dummy_a, dummy_b, g, beta, state, cu_seqlens
torch.accelerator.empty_cache()
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py
index 7bffc3218b42..42b35a420cab 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py
@@ -6,45 +6,49 @@
import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
+from vllm.config import get_current_vllm_config
+from vllm.model_executor.kernels.linear import (
+ init_wfp8_a16_linear_kernel,
+)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme,
)
+from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
+ STRATEGY_TO_PARAMETER_TYPE,
+ STRATEGY_TO_WEIGHT_QUANT_KEY,
+)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
create_fp8_scale_parameter,
create_fp8_weight_parameter,
- process_fp8_weight_block_strategy,
validate_fp8_block_shape,
)
-from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
- apply_fp8_marlin_linear,
- prepare_fp8_layer_for_marlin,
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+ kFp8DynamicTensorSym,
+ kFp8StaticTensorSym,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise,
)
-from vllm.model_executor.parameter import (
- BlockQuantScaleParameter,
- ChannelQuantScaleParameter,
- PerTensorScaleParameter,
-)
+from vllm.model_executor.parameter import PerTensorScaleParameter
from vllm.model_executor.utils import replace_parameter
__all__ = ["CompressedTensorsW8A16Fp8"]
-strategy_to_parameter_type = {
- QuantizationStrategy.BLOCK: BlockQuantScaleParameter,
- QuantizationStrategy.CHANNEL: ChannelQuantScaleParameter,
- QuantizationStrategy.TENSOR: PerTensorScaleParameter,
-}
-
class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool):
self.weight_quant = weight_quant
self.strategy = weight_quant.strategy
+ self.out_dtype = torch.get_default_dtype()
+ self.input_dtype = get_current_vllm_config().model_config.dtype
self.is_static_input_scheme = is_static_input_scheme
self.weight_block_size = self.weight_quant.block_structure
+ self.weight_quant_key = STRATEGY_TO_WEIGHT_QUANT_KEY[self.strategy]
+ self.activation_quant_key = (
+ kFp8StaticTensorSym if is_static_input_scheme else kFp8DynamicTensorSym
+ )
+
@classmethod
def get_min_capability(cls) -> int:
# turing and up
@@ -89,7 +93,7 @@ def create_weights(
# WEIGHT SCALE
weight_scale = create_fp8_scale_parameter(
- strategy_to_parameter_type[self.strategy],
+ STRATEGY_TO_PARAMETER_TYPE[self.strategy],
output_partition_sizes,
input_size_per_partition,
layer.weight_block_size,
@@ -105,32 +109,36 @@ def create_weights(
)
layer.register_parameter("input_scale", input_scale)
+ self.linear_kernel = init_wfp8_a16_linear_kernel(
+ weight_quant_key=self.weight_quant_key,
+ activation_quant_key=self.activation_quant_key,
+ weight_shape=layer.weight.shape,
+ input_dtype=self.input_dtype,
+ out_dtype=self.out_dtype,
+ )
+
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
- weight = layer.weight
- weight_scale = layer.weight_scale
- size_k_first = True
- # TODO(rob): refactor block quant into separate class.
if self.strategy == QuantizationStrategy.BLOCK:
assert self.is_static_input_scheme is False
- size_k_first = False
- weight, weight_scale = process_fp8_weight_block_strategy(
- weight, weight_scale
- )
+ # MarlinFP8ScaledMMLinearKernel uses "weight_scale_inv" for block
+ # quant, while CT registers the scale as "weight_scale".
+ # Rename by deleting the old parameter and adding the new one so
+ # that prepare_fp8_layer_for_marlin (which prefers "weight_scale"
+ # over "weight_scale_inv") picks up "weight_scale_inv" correctly.
+ weight_scale_data = layer.weight_scale.data
+ del layer._parameters["weight_scale"]
+ replace_parameter(layer, "weight_scale_inv", weight_scale_data)
else:
- # Weights must be transposed for marlin
- weight = weight.t()
if self.strategy == QuantizationStrategy.TENSOR:
- # If we have a fused module (QKV, MLP) with per tensor scales,
- # we expand each scale to its shard's channels.
- weight_scale = convert_to_channelwise(
- weight_scale, layer.logical_widths
+ # For fused modules with per-tensor scales, expand each scale
+ # to its shard's channels.
+ replace_parameter(
+ layer,
+ "weight_scale",
+ convert_to_channelwise(layer.weight_scale, layer.logical_widths),
)
- # Update layer with new values
- replace_parameter(layer, "weight", weight.data)
- replace_parameter(layer, "weight_scale", weight_scale.data)
-
- prepare_fp8_layer_for_marlin(layer, size_k_first=size_k_first)
+ self.linear_kernel.process_weights_after_loading(layer)
def apply_weights(
self,
@@ -138,12 +146,4 @@ def apply_weights(
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
- return apply_fp8_marlin_linear(
- input=x,
- weight=layer.weight,
- weight_scale=layer.weight_scale,
- workspace=layer.workspace,
- size_n=layer.output_size_per_partition,
- size_k=layer.input_size_per_partition,
- bias=bias,
- )
+ return self.linear_kernel.apply_weights(layer, x, bias)
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
index c6b810eb9679..3bf606ddb332 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
@@ -16,6 +16,9 @@
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme,
)
+from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
+ STRATEGY_TO_PARAMETER_TYPE,
+)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
create_fp8_input_scale,
create_fp8_scale_parameter,
@@ -34,20 +37,9 @@
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_block_fp8_supported,
)
-from vllm.model_executor.parameter import (
- BlockQuantScaleParameter,
- ChannelQuantScaleParameter,
- PerTensorScaleParameter,
-)
__all__ = ["CompressedTensorsW8A8Fp8"]
-strategy_to_parameter_type = {
- QuantizationStrategy.BLOCK: BlockQuantScaleParameter,
- QuantizationStrategy.CHANNEL: ChannelQuantScaleParameter,
- QuantizationStrategy.TENSOR: PerTensorScaleParameter,
-}
-
STATIC_QUANT = True
DYNAMIC_QUANT = False
activation_quant_key_mapping = {
@@ -130,7 +122,7 @@ def create_weights(
# WEIGHT SCALE
weight_scale = create_fp8_scale_parameter(
- strategy_to_parameter_type[self.strategy],
+ STRATEGY_TO_PARAMETER_TYPE[self.strategy],
output_partition_sizes,
input_size_per_partition,
layer.weight_block_size,
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py
index f88092169110..04c64d9bd56f 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py
@@ -6,8 +6,36 @@
import regex as re
from compressed_tensors import CompressionFormat
+from compressed_tensors.quantization import QuantizationStrategy
from torch.nn import Module
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+ kFp8Static128BlockSym,
+ kFp8StaticChannelSym,
+ kFp8StaticTensorSym,
+)
+from vllm.model_executor.parameter import (
+ BlockQuantScaleParameter,
+ ChannelQuantScaleParameter,
+ PerTensorScaleParameter,
+)
+
+# Maps quantization strategy to the corresponding scale parameter type.
+# Shared across compressed-tensor scheme classes (w8a16_fp8, w8a8_fp8, …).
+STRATEGY_TO_PARAMETER_TYPE = {
+ QuantizationStrategy.BLOCK: BlockQuantScaleParameter,
+ QuantizationStrategy.CHANNEL: ChannelQuantScaleParameter,
+ QuantizationStrategy.TENSOR: PerTensorScaleParameter,
+}
+
+# Maps quantization strategy to the vLLM weight-quant key used for
+# kernel selection. Shared across compressed-tensor scheme classes.
+STRATEGY_TO_WEIGHT_QUANT_KEY = {
+ QuantizationStrategy.BLOCK: kFp8Static128BlockSym,
+ QuantizationStrategy.CHANNEL: kFp8StaticChannelSym,
+ QuantizationStrategy.TENSOR: kFp8StaticTensorSym,
+}
+
def is_activation_quantization_format(format: str) -> bool:
_ACTIVATION_QUANTIZATION_FORMATS = [
diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py
index 726ac2232af9..3ed14ae808df 100644
--- a/vllm/model_executor/layers/quantization/kv_cache.py
+++ b/vllm/model_executor/layers/quantization/kv_cache.py
@@ -10,7 +10,7 @@
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import is_quantized_kv_cache
-from vllm.v1.kv_cache_interface import kv_cache_uses_per_token_head_scales
+from vllm.v1.kv_cache_interface import kv_cache_uses_runtime_scale_cache
logger = init_logger(__name__)
@@ -54,14 +54,14 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
assert not hasattr(layer, "prob_scale")
return
- # Per-token-head quantized KV cache: scales are computed dynamically
- # per (token, head) in the kernel at cache-write time. Checkpoint
- # scales are never used regardless of calculate_kv_scales.
- if kv_cache_uses_per_token_head_scales(layer.kv_cache_dtype):
+ # Runtime-scale quantized KV cache: scales are computed dynamically
+ # at cache-write time. Checkpoint scales are never used.
+        if kv_cache_uses_runtime_scale_cache(layer.kv_cache_dtype):
layer._k_scale.copy_(1.0)
layer._v_scale.copy_(1.0)
layer._k_scale_float = 1.0
layer._v_scale_float = 1.0
+ layer.calculate_kv_scales = False
del layer.k_scale
del layer.v_scale
del layer.q_scale
diff --git a/vllm/reasoning/abs_reasoning_parsers.py b/vllm/reasoning/abs_reasoning_parsers.py
index 7244b9fac84f..2ff72ea2f5e0 100644
--- a/vllm/reasoning/abs_reasoning_parsers.py
+++ b/vllm/reasoning/abs_reasoning_parsers.py
@@ -39,6 +39,20 @@ def vocab(self) -> dict[str, int]:
# whereas all tokenizers have .get_vocab()
return self.model_tokenizer.get_vocab()
+ @property
+ def reasoning_start_str(self) -> str | None:
+ """Set `reasoning_start_str` to the strings that delimit
+ the reasoning block (e.g. `""""` and `""`).
+ """
+ return None
+
+ @property
+ def reasoning_end_str(self) -> str | None:
+ """Set `reasoning_end_str` to the strings that delimit
+ the reasoning block (e.g. `""""` and `""`).
+ """
+ return None
+
@abstractmethod
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
"""
diff --git a/vllm/reasoning/basic_parsers.py b/vllm/reasoning/basic_parsers.py
index a8bb33d2c9cd..938b7f736b2c 100644
--- a/vllm/reasoning/basic_parsers.py
+++ b/vllm/reasoning/basic_parsers.py
@@ -39,6 +39,14 @@ def end_token(self) -> str:
"""The token that ends reasoning content."""
raise NotImplementedError
+ @property
+ def reasoning_start_str(self) -> str:
+ return self.start_token
+
+ @property
+ def reasoning_end_str(self) -> str:
+ return self.end_token
+
def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
super().__init__(tokenizer, *args, **kwargs)
diff --git a/vllm/utils/torch_utils.py b/vllm/utils/torch_utils.py
index 94f8c096e313..e1efbe58a842 100644
--- a/vllm/utils/torch_utils.py
+++ b/vllm/utils/torch_utils.py
@@ -33,6 +33,7 @@
"float16": torch.float16,
"bfloat16": torch.bfloat16,
"float": torch.float,
+ "int4_per_token_head": torch.uint8,
"fp8": torch.uint8,
"fp8_e4m3": torch.uint8,
"fp8_e5m2": torch.uint8,
@@ -64,7 +65,11 @@
def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
- return kv_cache_dtype.startswith("fp8") or kv_cache_dtype.endswith("per_token_head")
+ return (
+ kv_cache_dtype == "int4_per_token_head"
+ or kv_cache_dtype.startswith("fp8")
+ or kv_cache_dtype.endswith("per_token_head")
+ )
def kv_cache_uses_per_token_head_scales(kv_cache_dtype: str) -> bool:
@@ -72,6 +77,21 @@ def kv_cache_uses_per_token_head_scales(kv_cache_dtype: str) -> bool:
return kv_cache_dtype.endswith("per_token_head")
+def kv_cache_uses_runtime_scale_cache(kv_cache_dtype: str) -> bool:
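+    """True if the KV cache layout carries dynamically computed scales inline
+    (int4 or per-token-head quantized caches)."""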
+ return (
+ kv_cache_dtype == "int4_per_token_head"
+ or kv_cache_uses_per_token_head_scales(kv_cache_dtype)
+ )
+
+
+def kv_cache_uses_int4_packing(kv_cache_dtype: str) -> bool:
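+    """True if the KV cache packs two int4 values per byte."""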
+ return kv_cache_dtype == "int4_per_token_head"
+
+
+def kv_cache_uses_fp8_storage(kv_cache_dtype: str) -> bool:
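+    """True if the KV cache stores entries in an FP8 format."""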
+ return kv_cache_dtype.startswith("fp8")
+
+
def is_strictly_contiguous(t: torch.Tensor) -> bool:
"""
Check if tensor is contiguous AND has no degenerate strides.
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index b9ccd4fce1c3..90139d3f3315 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -11,7 +11,10 @@
from vllm.model_executor.layers.attention import Attention
from vllm.platforms import current_platform
-from vllm.utils.torch_utils import is_quantized_kv_cache
+from vllm.utils.torch_utils import (
+ is_quantized_kv_cache,
+ kv_cache_uses_fp8_storage,
+)
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionImpl,
@@ -184,7 +187,10 @@ def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool:
if kv_cache_dtype is None:
return True
if is_quantized_kv_cache(kv_cache_dtype):
- return flash_attn_supports_fp8()
+ return (
+ kv_cache_uses_fp8_storage(kv_cache_dtype)
+ and flash_attn_supports_fp8()
+ )
return kv_cache_dtype in ["auto", "float16", "bfloat16"]
@classmethod
diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index bd8ec29bc33a..3598f77b137f 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -18,7 +18,12 @@
from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
from vllm.utils.math_utils import next_power_of_2
-from vllm.utils.torch_utils import is_quantized_kv_cache
+from vllm.utils.torch_utils import (
+ is_quantized_kv_cache,
+ kv_cache_uses_fp8_storage,
+ kv_cache_uses_int4_packing,
+ kv_cache_uses_runtime_scale_cache,
+)
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
@@ -271,6 +276,7 @@ class TritonAttentionBackend(AttentionBackend):
"auto",
"float16",
"bfloat16",
+ "int4_per_token_head",
"fp8",
"fp8_e4m3",
"fp8_e5m2",
@@ -308,6 +314,19 @@ def get_kv_cache_shape(
) -> tuple[int, ...]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
+ if kv_cache_uses_int4_packing(cache_dtype_str):
+ packed_head_size = (head_size + 1) // 2
+ packed_head_size_padded = ((packed_head_size + 3) // 4) * 4
+ # Reserve 4 inline bytes per (token, head) for the float32 runtime
+ # scale. This keeps int4 accounting honest without needing a second
+ # allocation path for scale caches.
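+            # e.g. head_size=128 -> 64 packed bytes -> padded to 64 -> 68 bytes
+            # per (token, head) including the inline float32 scale.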
+ return (
+ num_blocks,
+ 2,
+ block_size,
+ num_kv_heads,
+ packed_head_size_padded + 4,
+ )
if kv_cache_uses_per_token_head_scales(cache_dtype_str):
# Pad head_size by sizeof(float32)/sizeof(cache_dtype) so
# the per-head scale fits inline. The backend extracts
@@ -391,19 +410,18 @@ def _ensure_scale_caches(self, kv_cache: torch.Tensor) -> None:
"""Extract per-head scale views from the padded head dimension.
The KV cache shape is ``(num_blocks, 2, block_size, nkv, hs+pad)``
- where ``pad = sizeof(float32) / sizeof(cache_dtype)``. The last
- ``pad`` elements of each head hold one float32 scale. We create
- strided float32 views over those bytes.
+ where the trailing bytes in each head region hold one float32 scale.
+ For int8/fp8 per-token-head cache, the scale sits after the dense data.
+ For int4 cache, the scale sits after the packed-bytes region padded to a
+ 4-byte boundary. We create strided float32 views over those bytes.
Scale shape: ``(num_blocks, block_size, num_kv_heads)``
"""
if self._k_scale_cache is not None:
return
- from vllm.utils.torch_utils import get_dtype_size
-
num_blocks, _, block_size, nkv, padded_hs = kv_cache.shape
dtype_sz = kv_cache.element_size()
- scale_pad = get_dtype_size(torch.float32) // dtype_sz # e.g. 4
+        scale_pad = torch.float32.itemsize // dtype_sz  # e.g. 4 for uint8 caches
hs = padded_hs - scale_pad
raw = kv_cache.untyped_storage()
@@ -411,30 +429,51 @@ def _ensure_scale_caches(self, kv_cache: torch.Tensor) -> None:
raw
)
- # In the raw bytes, each (block, kv_half, slot, head) occupies
- # padded_hs * dtype_sz bytes. The scale float32 sits at byte
- # offset hs * dtype_sz within that region.
- kv_half_bytes = block_size * nkv * padded_hs * dtype_sz
- full_block_f32 = 2 * kv_half_bytes // 4 # stride between blocks
- slot_f32 = nkv * padded_hs * dtype_sz // 4 # stride between slots
- head_f32 = padded_hs * dtype_sz // 4 # stride between heads
- scale_off_f32 = hs * dtype_sz // 4 # offset to scale within head
+ # Use the actual tensor strides rather than re-deriving contiguous
+ # ones from shape. Hybrid page padding can inflate the leading block
+ # stride without changing the semantic cache shape.
+        scale_size = torch.float32.itemsize
+ base_byte_offset = kv_cache.storage_offset() * dtype_sz
+ block_byte_stride = kv_cache.stride(0) * dtype_sz
+ kv_half_byte_stride = kv_cache.stride(1) * dtype_sz
+ slot_byte_stride = kv_cache.stride(2) * dtype_sz
+ head_byte_stride = kv_cache.stride(3) * dtype_sz
+ last_dim_byte_stride = kv_cache.stride(4) * dtype_sz
+ scale_off_bytes = hs * last_dim_byte_stride
+
+ for value in (
+ base_byte_offset,
+ block_byte_stride,
+ kv_half_byte_stride,
+ slot_byte_stride,
+ head_byte_stride,
+ scale_off_bytes,
+ ):
+ assert value % scale_size == 0, (
+ "Scale cache view requires float32-aligned byte offsets/strides."
+ )
+
+ block_f32 = block_byte_stride // scale_size
+ kv_half_f32 = kv_half_byte_stride // scale_size
+ slot_f32 = slot_byte_stride // scale_size
+ head_f32 = head_byte_stride // scale_size
+ scale_off_f32 = (base_byte_offset + scale_off_bytes) // scale_size
# K scales: kv_half=0
self._k_scale_cache = torch.as_strided(
base_f32,
size=(num_blocks, block_size, nkv),
- stride=(full_block_f32, slot_f32, head_f32),
+ stride=(block_f32, slot_f32, head_f32),
storage_offset=scale_off_f32,
)
self._k_scale_cache.fill_(1.0)
- # V scales: kv_half=1, offset by kv_half_bytes
- v_base_f32 = kv_half_bytes // 4
+ # V scales: kv_half=1, offset by the actual kv-half stride
+ v_base_f32 = kv_half_f32
self._v_scale_cache = torch.as_strided(
base_f32,
size=(num_blocks, block_size, nkv),
- stride=(full_block_f32, slot_f32, head_f32),
+ stride=(block_f32, slot_f32, head_f32),
storage_offset=v_base_f32 + scale_off_f32,
)
self._v_scale_cache.fill_(1.0)
@@ -494,6 +533,11 @@ def __init__(
self._kv_quant_mode = get_kv_quant_mode(kv_cache_dtype)
self._is_per_token_head_quant = self._kv_quant_mode.is_per_token_head
+ self._uses_int4_kv_cache = kv_cache_uses_int4_packing(kv_cache_dtype)
+ self._uses_fp8_kv_cache = kv_cache_uses_fp8_storage(kv_cache_dtype)
+ self._uses_runtime_scale_cache = kv_cache_uses_runtime_scale_cache(
+ kv_cache_dtype
+ )
def forward(
self,
@@ -555,21 +599,22 @@ def forward(
layer,
)
- # Per-token-head quantized KV cache: use separate scale caches.
- if self._is_per_token_head_quant:
+ # Runtime-scale quantized KV cache: use scale caches extracted from the
+ # inline cache storage.
+ if self._uses_runtime_scale_cache:
self._ensure_scale_caches(kv_cache)
key_cache, value_cache = kv_cache.unbind(1)
- if key_cache.dtype == torch.uint8:
+ if self._is_per_token_head_quant and key_cache.dtype == torch.uint8:
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
k_descale = None
v_descale = None
k_scale_cache = self._k_scale_cache
v_scale_cache = self._v_scale_cache
- # FP8 per-tensor / auto path (original flow).
+ # Per-tensor quantized path / auto path.
else:
key_cache, value_cache = kv_cache.unbind(1)
- if is_quantized_kv_cache(self.kv_cache_dtype):
+ if self._uses_fp8_kv_cache:
if key_cache.dtype != self.fp8_dtype:
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
@@ -692,25 +737,39 @@ def do_kv_cache_update(
# we use direct Q, K, V tensors without caching
return
# Reshape the input keys and values and store them in the cache.
- if self._is_per_token_head_quant:
+ if self._uses_runtime_scale_cache:
self._ensure_scale_caches(kv_cache)
key_cache, value_cache = kv_cache.unbind(1)
- if key_cache.dtype == torch.uint8:
+ if self._is_per_token_head_quant and key_cache.dtype == torch.uint8:
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
- triton_reshape_and_cache_flash_per_token_head_quant(
- key,
- value,
- key_cache,
- value_cache,
- self._k_scale_cache,
- self._v_scale_cache,
- slot_mapping,
- )
+ if self._is_per_token_head_quant:
+ triton_reshape_and_cache_flash_per_token_head_quant(
+ key,
+ value,
+ key_cache,
+ value_cache,
+ self._k_scale_cache,
+ self._v_scale_cache,
+ slot_mapping,
+ )
+ else:
+ triton_reshape_and_cache_flash(
+ key,
+ value,
+ key_cache,
+ value_cache,
+ slot_mapping,
+ self.kv_cache_dtype,
+ layer._k_scale,
+ layer._v_scale,
+ self._k_scale_cache,
+ self._v_scale_cache,
+ )
return
# For decoder and cross-attention, use KV cache as before.
key_cache, value_cache = kv_cache.unbind(1)
- if is_quantized_kv_cache(self.kv_cache_dtype):
+ if self._uses_fp8_kv_cache:
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
triton_reshape_and_cache_flash(
@@ -725,7 +784,7 @@ def do_kv_cache_update(
)
def fused_rope_kvcache_supported(self):
- if self._is_per_token_head_quant:
+ if self._is_per_token_head_quant or self._uses_int4_kv_cache:
return False
return rocm_aiter_ops.is_enabled()
@@ -744,7 +803,12 @@ def do_rope_and_kv_cache_update(
key_cache, value_cache = kv_cache.unbind(1)
flash_layout = True
- is_fp8_kv_cache = is_quantized_kv_cache(self.kv_cache_dtype)
+ if self._uses_int4_kv_cache:
+ raise NotImplementedError(
+ "fused rope + int4 KV cache is not supported yet"
+ )
+
+ is_fp8_kv_cache = self._uses_fp8_kv_cache
if is_fp8_kv_cache:
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
diff --git a/vllm/v1/attention/ops/triton_reshape_and_cache_flash.py b/vllm/v1/attention/ops/triton_reshape_and_cache_flash.py
index 6e696fdb5135..0ff3db5edad3 100644
--- a/vllm/v1/attention/ops/triton_reshape_and_cache_flash.py
+++ b/vllm/v1/attention/ops/triton_reshape_and_cache_flash.py
@@ -9,11 +9,20 @@
)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
-from vllm.utils.torch_utils import is_quantized_kv_cache
+from vllm.utils.torch_utils import (
+ is_quantized_kv_cache,
+ kv_cache_uses_fp8_storage,
+ kv_cache_uses_int4_packing,
+)
FP8_MIN, FP8_MAX = get_fp8_min_max()
+@triton.jit
+def _round_to_int32(x):
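+    # Round half away from zero before casting to int32.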
+ return tl.where(x >= 0, tl.floor(x + 0.5), tl.ceil(x - 0.5)).to(tl.int32)
+
+
@triton.jit
def reshape_and_cache_kernel_flash(
key_ptr, # [num_tokens, num_heads, head_size]
@@ -38,6 +47,7 @@ def reshape_and_cache_kernel_flash(
USE_HEAD_MAJOR_LAYOUT: tl.constexpr,
# FP8 flags
FP8_KV_CACHE: tl.constexpr,
+ INT4_KV_CACHE: tl.constexpr,
# tune parameters
TILE_SIZE: tl.constexpr,
):
@@ -57,6 +67,8 @@ def reshape_and_cache_kernel_flash(
src_value_idx = token_idx * value_stride
if USE_HEAD_MAJOR_LAYOUT:
+ if INT4_KV_CACHE:
+ tl.device_assert(False, "int4 KV cache does not support head-major layout")
# Decompose the tile index back into head and dim coordinates.
cur_head = tile_pos // head_size
cur_dim = tile_pos % head_size
@@ -76,8 +88,9 @@ def reshape_and_cache_kernel_flash(
+ (cur_dim % x)
)
else:
- cur_head = tile_pos // head_size
- cur_dim = tile_pos % head_size
+ head_size_physical = (head_size + 1) // 2 if INT4_KV_CACHE else head_size
+ cur_head = tile_pos // head_size_physical
+ cur_dim = tile_pos % head_size_physical
tgt_idx_k = (
block_idx * block_stride
+ block_offset * page_stride
@@ -87,39 +100,80 @@ def reshape_and_cache_kernel_flash(
tgt_idx_v = tgt_idx_k
# [TILE_SIZE]
- key_load = tl.load(
- key_ptr + src_key_idx + tile_pos, mask=tile_pos < (num_heads * head_size)
- )
- if FP8_KV_CACHE:
- # tl.store will do the correct implicit cast to fp8,
- # based on the key_cache_ptr.dtype.element_ty
- key_tile = key_load if key_load.dtype.is_fp8() else key_load / tl.load(k_scale)
+ if INT4_KV_CACHE:
+ packed_src_pos = cur_head * head_size + cur_dim * 2
+ key_lo = tl.load(
+ key_ptr + src_key_idx + packed_src_pos,
+ mask=tile_pos < (num_heads * head_size_physical),
+ other=0.0,
+ )
+ key_hi = tl.load(
+ key_ptr + src_key_idx + packed_src_pos + 1,
+ mask=(tile_pos < (num_heads * head_size_physical))
+ & ((cur_dim * 2 + 1) < head_size),
+ other=0.0,
+ )
+ q_key_lo = _round_to_int32(key_lo / tl.load(k_scale))
+ q_key_hi = _round_to_int32(key_hi / tl.load(k_scale))
+ q_key_lo = tl.maximum(tl.minimum(q_key_lo, 7), -8)
+ q_key_hi = tl.maximum(tl.minimum(q_key_hi, 7), -8)
+ key_tile = ((q_key_lo & 0xF) | ((q_key_hi & 0xF) << 4)).to(tl.uint8)
else:
- key_tile = key_load
+ key_load = tl.load(
+ key_ptr + src_key_idx + tile_pos, mask=tile_pos < (num_heads * head_size)
+ )
+ if FP8_KV_CACHE:
+ # tl.store will do the correct implicit cast to fp8,
+ # based on the key_cache_ptr.dtype.element_ty
+ key_tile = (
+ key_load if key_load.dtype.is_fp8() else key_load / tl.load(k_scale)
+ )
+ else:
+ key_tile = key_load
# [TILE_SIZE]
- value_load = tl.load(
- value_ptr + src_value_idx + tile_pos, mask=tile_pos < (num_heads * head_size)
- )
- if FP8_KV_CACHE:
- if value_load.dtype.is_fp8():
- value_tile = value_load
- else:
- # tl.store will do the correct implicit cast to fp8,
- # based on the value_cache_ptr.dtype.element_ty
- value_tile = value_load / tl.load(v_scale)
+ if INT4_KV_CACHE:
+ packed_src_pos = cur_head * head_size + cur_dim * 2
+ value_lo = tl.load(
+ value_ptr + src_value_idx + packed_src_pos,
+ mask=tile_pos < (num_heads * head_size_physical),
+ other=0.0,
+ )
+ value_hi = tl.load(
+ value_ptr + src_value_idx + packed_src_pos + 1,
+ mask=(tile_pos < (num_heads * head_size_physical))
+ & ((cur_dim * 2 + 1) < head_size),
+ other=0.0,
+ )
+ q_value_lo = _round_to_int32(value_lo / tl.load(v_scale))
+ q_value_hi = _round_to_int32(value_hi / tl.load(v_scale))
+ q_value_lo = tl.maximum(tl.minimum(q_value_lo, 7), -8)
+ q_value_hi = tl.maximum(tl.minimum(q_value_hi, 7), -8)
+ value_tile = ((q_value_lo & 0xF) | ((q_value_hi & 0xF) << 4)).to(tl.uint8)
else:
- value_tile = value_load
+ value_load = tl.load(
+ value_ptr + src_value_idx + tile_pos,
+ mask=tile_pos < (num_heads * head_size),
+ )
+ if FP8_KV_CACHE:
+ if value_load.dtype.is_fp8():
+ value_tile = value_load
+ else:
+ # tl.store will do the correct implicit cast to fp8,
+ # based on the value_cache_ptr.dtype.element_ty
+ value_tile = value_load / tl.load(v_scale)
+ else:
+ value_tile = value_load
tl.store(
key_cache_ptr + tgt_idx_k,
key_tile,
- mask=tile_pos < (num_heads * head_size),
+ mask=tile_pos < (num_heads * ((head_size + 1) // 2 if INT4_KV_CACHE else head_size)),
)
tl.store(
value_cache_ptr + tgt_idx_v,
value_tile,
- mask=tile_pos < (num_heads * head_size),
+ mask=tile_pos < (num_heads * ((head_size + 1) // 2 if INT4_KV_CACHE else head_size)),
)
return
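
The INT4_KV_CACHE branch above packs two signed 4-bit values into each cache byte using the per-tensor k_scale/v_scale. A self-contained NumPy reference of the same pack/unpack scheme, including the round-half-away-from-zero helper and the clamp to [-8, 7]; this is illustrative only, not the kernel itself:

import numpy as np

def round_away_from_zero(x):
    # Mirrors _round_to_int32: halves round away from zero, not to even.
    return np.where(x >= 0, np.floor(x + 0.5), np.ceil(x - 0.5)).astype(np.int32)

def pack_int4(values, scale):
    q = np.clip(round_away_from_zero(values / scale), -8, 7)
    lo, hi = q[0::2], q[1::2]
    return ((lo & 0xF) | ((hi & 0xF) << 4)).astype(np.uint8)

def unpack_int4(packed, scale):
    lo = (packed & 0xF).astype(np.int32)
    hi = (packed >> 4).astype(np.int32)
    lo = np.where(lo >= 8, lo - 16, lo)  # sign-extend the low nibble
    hi = np.where(hi >= 8, hi - 16, hi)  # sign-extend the high nibble
    out = np.empty(packed.size * 2, dtype=np.float32)
    out[0::2], out[1::2] = lo * scale, hi * scale
    return out

x = np.linspace(-1.0, 1.0, 128, dtype=np.float32)
scale = float(np.abs(x).max() / 7.0)
restored = unpack_int4(pack_int4(x, scale), scale)
assert np.max(np.abs(restored - x)) <= scale / 2 + 1e-6
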
@@ -244,6 +298,200 @@ def _reshape_cache_per_token_head(
}
+@triton.jit
+def _reshape_cache_int4_per_token_head(
+ key_ptr, # [num_tokens, num_kv_heads, head_size]
+ value_ptr, # [num_tokens, num_kv_heads, head_size_v]
+ key_cache_ptr, # [num_blocks, block_size, num_kv_heads, packed_head_size+pad]
+ value_cache_ptr, # [num_blocks, block_size, num_kv_heads, packed_head_size_v+pad]
+ k_scale_cache_ptr, # [num_blocks, block_size, num_kv_heads] float32
+ v_scale_cache_ptr, # [num_blocks, block_size, num_kv_heads] float32
+ slot_mapping_ptr, # [num_tokens]
+ stride_key_tok: tl.int64,
+ stride_key_head: tl.int64,
+ stride_val_tok: tl.int64,
+ stride_val_head: tl.int64,
+ stride_kc_blk: tl.int64,
+ stride_kc_slot: tl.int64,
+ stride_kc_head: tl.int64,
+ stride_vc_blk: tl.int64,
+ stride_vc_slot: tl.int64,
+ stride_vc_head: tl.int64,
+ stride_ks_blk: tl.int64,
+ stride_ks_slot: tl.int64,
+ stride_ks_head: tl.int64,
+ stride_vs_blk: tl.int64,
+ stride_vs_slot: tl.int64,
+ stride_vs_head: tl.int64,
+ block_size: tl.constexpr,
+ head_size: tl.constexpr,
+ head_size_v: tl.constexpr,
+ HEAD_SIZE_PADDED: tl.constexpr,
+ PACKED_HEAD_SIZE: tl.constexpr,
+ PACKED_HEAD_SIZE_V: tl.constexpr,
+ PACKED_HEAD_SIZE_PADDED: tl.constexpr,
+ PACKED_HEAD_SIZE_V_PADDED: tl.constexpr,
+):
+ tok = tl.program_id(0)
+ head = tl.program_id(1)
+
+ slot = tl.load(slot_mapping_ptr + tok).to(tl.int64)
+ if slot < 0:
+ return
+
+ blk = slot // block_size
+ slot_in_blk = slot % block_size
+
+ dim_offs = tl.arange(0, HEAD_SIZE_PADDED)
+
+ # ---- Key head -> absmax -> scale --------------------------------------
+ k_mask = dim_offs < head_size
+ k_h = tl.load(
+ key_ptr + tok * stride_key_tok + head * stride_key_head + dim_offs,
+ mask=k_mask,
+ other=0.0,
+ ).to(tl.float32)
+ k_scale = tl.maximum(tl.max(tl.abs(k_h)) / 7.0, 1e-6)
+ tl.store(
+ k_scale_cache_ptr
+ + blk * stride_ks_blk
+ + slot_in_blk * stride_ks_slot
+ + head * stride_ks_head,
+ k_scale,
+ )
+
+ packed_offs = tl.arange(0, PACKED_HEAD_SIZE_PADDED)
+ k_elem0 = packed_offs * 2
+ k_elem1 = k_elem0 + 1
+ k_lo = tl.load(
+ key_ptr + tok * stride_key_tok + head * stride_key_head + k_elem0,
+ mask=k_elem0 < head_size,
+ other=0.0,
+ ).to(tl.float32)
+ k_hi = tl.load(
+ key_ptr + tok * stride_key_tok + head * stride_key_head + k_elem1,
+ mask=k_elem1 < head_size,
+ other=0.0,
+ ).to(tl.float32)
+ q_key_lo = _round_to_int32(k_lo / k_scale)
+ q_key_hi = _round_to_int32(k_hi / k_scale)
+ q_key_lo = tl.maximum(tl.minimum(q_key_lo, 7), -8)
+ q_key_hi = tl.maximum(tl.minimum(q_key_hi, 7), -8)
+ key_packed = ((q_key_lo & 0xF) | ((q_key_hi & 0xF) << 4)).to(tl.uint8)
+ tl.store(
+ key_cache_ptr
+ + blk * stride_kc_blk
+ + slot_in_blk * stride_kc_slot
+ + head * stride_kc_head
+ + packed_offs,
+ key_packed,
+ mask=packed_offs < PACKED_HEAD_SIZE,
+ )
+
+ # ---- Value head -> absmax -> scale ------------------------------------
+ v_mask = dim_offs < head_size_v
+ v_h = tl.load(
+ value_ptr + tok * stride_val_tok + head * stride_val_head + dim_offs,
+ mask=v_mask,
+ other=0.0,
+ ).to(tl.float32)
+ v_scale = tl.maximum(tl.max(tl.abs(v_h)) / 7.0, 1e-6)
+ tl.store(
+ v_scale_cache_ptr
+ + blk * stride_vs_blk
+ + slot_in_blk * stride_vs_slot
+ + head * stride_vs_head,
+ v_scale,
+ )
+
+ packed_offs_v = tl.arange(0, PACKED_HEAD_SIZE_V_PADDED)
+ v_elem0 = packed_offs_v * 2
+ v_elem1 = v_elem0 + 1
+ v_lo = tl.load(
+ value_ptr + tok * stride_val_tok + head * stride_val_head + v_elem0,
+ mask=v_elem0 < head_size_v,
+ other=0.0,
+ ).to(tl.float32)
+ v_hi = tl.load(
+ value_ptr + tok * stride_val_tok + head * stride_val_head + v_elem1,
+ mask=v_elem1 < head_size_v,
+ other=0.0,
+ ).to(tl.float32)
+ q_value_lo = _round_to_int32(v_lo / v_scale)
+ q_value_hi = _round_to_int32(v_hi / v_scale)
+ q_value_lo = tl.maximum(tl.minimum(q_value_lo, 7), -8)
+ q_value_hi = tl.maximum(tl.minimum(q_value_hi, 7), -8)
+ value_packed = ((q_value_lo & 0xF) | ((q_value_hi & 0xF) << 4)).to(tl.uint8)
+ tl.store(
+ value_cache_ptr
+ + blk * stride_vc_blk
+ + slot_in_blk * stride_vc_slot
+ + head * stride_vc_head
+ + packed_offs_v,
+ value_packed,
+ mask=packed_offs_v < PACKED_HEAD_SIZE_V,
+ )
+
+
+def triton_reshape_and_cache_flash_int4_per_token_head(
+ key: torch.Tensor,
+ value: torch.Tensor,
+ key_cache: torch.Tensor,
+ value_cache: torch.Tensor,
+ k_scale_cache: torch.Tensor,
+ v_scale_cache: torch.Tensor,
+ slot_mapping: torch.Tensor,
+):
+ num_tokens, num_kv_heads, head_size = key.shape
+ head_size_v = value.shape[2]
+ head_size_padded = triton.next_power_of_2(max(head_size, head_size_v))
+ packed_head_size = (head_size + 1) // 2
+ packed_head_size_v = (head_size_v + 1) // 2
+ packed_head_size_padded = triton.next_power_of_2(max(1, packed_head_size))
+ packed_head_size_v_padded = triton.next_power_of_2(max(1, packed_head_size_v))
+ block_size = key_cache.shape[1]
+
+ if current_platform.is_rocm() or current_platform.is_xpu():
+ num_warps = 4
+ else:
+ num_warps = min(8, max(1, head_size_padded // 32))
+
+ _reshape_cache_int4_per_token_head[(num_tokens, num_kv_heads)](
+ key_ptr=key,
+ value_ptr=value,
+ key_cache_ptr=key_cache,
+ value_cache_ptr=value_cache,
+ k_scale_cache_ptr=k_scale_cache,
+ v_scale_cache_ptr=v_scale_cache,
+ slot_mapping_ptr=slot_mapping,
+ stride_key_tok=key.stride(0),
+ stride_key_head=key.stride(1),
+ stride_val_tok=value.stride(0),
+ stride_val_head=value.stride(1),
+ stride_kc_blk=key_cache.stride(0),
+ stride_kc_slot=key_cache.stride(1),
+ stride_kc_head=key_cache.stride(2),
+ stride_vc_blk=value_cache.stride(0),
+ stride_vc_slot=value_cache.stride(1),
+ stride_vc_head=value_cache.stride(2),
+ stride_ks_blk=k_scale_cache.stride(0),
+ stride_ks_slot=k_scale_cache.stride(1),
+ stride_ks_head=k_scale_cache.stride(2),
+ stride_vs_blk=v_scale_cache.stride(0),
+ stride_vs_slot=v_scale_cache.stride(1),
+ stride_vs_head=v_scale_cache.stride(2),
+ block_size=block_size,
+ head_size=head_size,
+ head_size_v=head_size_v,
+ HEAD_SIZE_PADDED=head_size_padded,
+ PACKED_HEAD_SIZE=packed_head_size,
+ PACKED_HEAD_SIZE_V=packed_head_size_v,
+ PACKED_HEAD_SIZE_PADDED=packed_head_size_padded,
+ PACKED_HEAD_SIZE_V_PADDED=packed_head_size_v_padded,
+ num_warps=num_warps,
+ )
+
+
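
For the per-token-head int4 path, each (token, head) vector gets its own runtime scale, computed as absmax / 7 with a 1e-6 floor, and odd head sizes leave the final high nibble unused, as in _reshape_cache_int4_per_token_head above. A minimal NumPy reference of that quantization step (illustrative only; the deliberately odd head size is made up):

import numpy as np

def quantize_head_int4(head_vals):
    # Per-(token, head) runtime scale: absmax / 7 with a small floor.
    scale = max(float(np.abs(head_vals).max()) / 7.0, 1e-6)
    x = head_vals / scale
    q = np.where(x >= 0, np.floor(x + 0.5), np.ceil(x - 0.5)).astype(np.int32)
    q = np.clip(q, -8, 7)
    if q.size % 2:  # odd head size: the masked high-nibble load yields 0
        q = np.append(q, 0)
    packed = ((q[0::2] & 0xF) | ((q[1::2] & 0xF) << 4)).astype(np.uint8)
    return packed, np.float32(scale)

head_size = 5
vals = np.array([0.9, -0.4, 0.05, -0.7, 0.2], dtype=np.float32)
packed, scale = quantize_head_int4(vals)
assert packed.size == (head_size + 1) // 2  # ceil(head_size / 2) packed bytes
assert packed[-1] >> 4 == 0                 # unused high nibble stays zero
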
def triton_reshape_and_cache_flash_per_token_head_quant(
key: torch.Tensor, # [num_tokens, num_kv_heads, head_size]
value: torch.Tensor, # [num_tokens, num_kv_heads, head_size_v]
@@ -324,9 +572,11 @@ def triton_reshape_and_cache_flash(
# [num_blocks, block_size, num_heads, head_size]
value_cache: torch.Tensor,
slot_mapping: torch.Tensor, # [num_tokens]
- kv_cache_dtype: str, # "auto", "fp8"
+ kv_cache_dtype: str, # "auto", "fp8", "int4_per_token_head"
k_scale: torch.Tensor, # float32
v_scale: torch.Tensor, # float32
+ k_scale_cache: torch.Tensor | None = None,
+ v_scale_cache: torch.Tensor | None = None,
):
num_heads = key.shape[1]
head_size = key.shape[2]
@@ -353,15 +603,15 @@ def triton_reshape_and_cache_flash(
assert kv_cache_dtype == "auto" or is_quantized_kv_cache(kv_cache_dtype), (
f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}."
)
+ int4_kv_cache = kv_cache_uses_int4_packing(kv_cache_dtype)
+ fp8_kv_cache = kv_cache_uses_fp8_storage(kv_cache_dtype)
kv_cache_torch_dtype = (
current_platform.fp8_dtype()
- if is_quantized_kv_cache(kv_cache_dtype)
- else key_cache.dtype
+ if fp8_kv_cache
+ else torch.uint8 if int4_kv_cache else key_cache.dtype
)
- if key_cache.dtype != kv_cache_torch_dtype and is_quantized_kv_cache(
- kv_cache_dtype
- ):
+ if key_cache.dtype != kv_cache_torch_dtype and fp8_kv_cache:
# to avoid erounous implicit cast in triton kernel (tl.store to uint8)
# (e.g. explicit cast to fp8e4m3fnuz is not supported in triton 3.4)
key_cache = key_cache.view(kv_cache_torch_dtype)
@@ -371,7 +621,7 @@ def triton_reshape_and_cache_flash(
"uint8 is not supported by triton reshape_and_cache_flash"
)
- FP8_KV_CACHE = is_quantized_kv_cache(kv_cache_dtype)
+ FP8_KV_CACHE = fp8_kv_cache
assert (not FP8_KV_CACHE) or kv_cache_torch_dtype in [
torch.float8_e4m3fn,
torch.float8_e5m2,
@@ -384,6 +634,22 @@ def triton_reshape_and_cache_flash(
)
# heuristics instead of autotuning
+ if int4_kv_cache:
+ assert (
+ not use_head_major_layout
+ ), "int4 KV cache only supports NHD layout for now"
+ if k_scale_cache is not None and v_scale_cache is not None:
+ triton_reshape_and_cache_flash_int4_per_token_head(
+ key,
+ value,
+ key_cache,
+ value_cache,
+ k_scale_cache,
+ v_scale_cache,
+ slot_mapping,
+ )
+ return
+ n = num_heads * ((head_size + 1) // 2)
TILE_SIZE = min(2048, triton.next_power_of_2(n))
if current_platform.is_rocm() or current_platform.is_xpu():
num_stages = 4
@@ -423,6 +689,7 @@ def triton_reshape_and_cache_flash(
x=x,
USE_HEAD_MAJOR_LAYOUT=use_head_major_layout,
FP8_KV_CACHE=FP8_KV_CACHE,
+ INT4_KV_CACHE=int4_kv_cache,
# autotune parameters
TILE_SIZE=TILE_SIZE,
num_warps=num_warps,
@@ -519,6 +786,11 @@ def triton_reshape_and_cache_flash_diffkv(
k_scale: torch.Tensor, # float32
v_scale: torch.Tensor, # float32
):
+ if kv_cache_uses_int4_packing(kv_cache_dtype):
+ raise NotImplementedError(
+ "int4 KV cache is not supported for diffkv layout yet"
+ )
+
num_heads = key.shape[1]
head_size_k = key.shape[2]
head_size_v = value.shape[2]
diff --git a/vllm/v1/attention/ops/triton_unified_attention.py b/vllm/v1/attention/ops/triton_unified_attention.py
index 150f022f848c..bde8647ee065 100644
--- a/vllm/v1/attention/ops/triton_unified_attention.py
+++ b/vllm/v1/attention/ops/triton_unified_attention.py
@@ -39,6 +39,7 @@ def _prepare_kv_tile(
Q,
tensor_scale,
scale_cache_ptr,
+ dim_offsets,
physical_block_idx,
seq_offset,
kv_head_idx,
@@ -48,6 +49,7 @@ def _prepare_kv_tile(
tile_mask,
BLOCK_SIZE: tl.constexpr,
KV_QUANT_MODE: tl.constexpr,
+ DIM_IS_LAST: tl.constexpr,
):
"""Prepare a loaded KV tile for attention computation.
@@ -57,25 +59,46 @@ def _prepare_kv_tile(
- ``KV_QUANT_MODE == 0``: cast only (no-op for bf16/fp16).
- ``KV_QUANT_MODE == 1`` (FP8 per-tensor): dequantize inline
using the tensor-wide scale.
- - ``KV_QUANT_MODE >= 2`` (per-token-head int8/fp8): cast to Q's
+ - ``KV_QUANT_MODE == 4`` (int4): unpack signed int4 and dequantize inline
+ using runtime token/head scales.
+ - ``KV_QUANT_MODE == 2 or 3`` (per-token-head int8/fp8): cast to Q's
dtype and return per-head scales separately — the caller applies
them after the dot product for better numerical efficiency.
Returns ``(data, token_head_scales)``. *token_head_scales* is only
- meaningful when ``KV_QUANT_MODE >= 2``; callers gate its use on
+ meaningful when ``KV_QUANT_MODE == 2 or 3``; callers gate its use on
the same constexpr so the compiler eliminates dead code.
"""
- # KV_QUANT_MODE values: 0=none, 1=fp8 per-tensor,
- # 2=int8 per-token-head, 3=fp8 per-token-head
+ # KV_QUANT_MODE values:
+ # 0=none, 1=fp8 per-tensor, 2=int8 per-token-head,
+ # 3=fp8 per-token-head, 4=int4 packed bytes + runtime scales
- # Placeholder scales (float32) — never read when KV_QUANT_MODE < 2.
+ # Placeholder scales (float32) — never read when KV_QUANT_MODE is not
+ # per-token-head.
unused_scales = tile_mask.to(tl.float32)
if KV_QUANT_MODE == 1: # FP8 per-tensor
if Q.dtype.is_fp8():
return data.to(Q.dtype), unused_scales
return (data.to(tl.float32) * tl.load(tensor_scale)).to(Q.dtype), unused_scales
- if KV_QUANT_MODE >= 2: # per-token-head (int8 or fp8)
+ if KV_QUANT_MODE == 4: # int4
+ scale_idx = (
+ physical_block_idx * stride_s_blk
+ + (seq_offset % BLOCK_SIZE) * stride_s_slot
+ + kv_head_idx * stride_s_head
+ )
+ token_head_scales = tl.load(scale_cache_ptr + scale_idx, mask=tile_mask, other=1.0)
+ scale_values = (
+ token_head_scales[:, None] if DIM_IS_LAST else token_head_scales[None, :]
+ )
+ parity = (
+ (dim_offsets[None, :] & 1) if DIM_IS_LAST else (dim_offsets[:, None] & 1)
+ )
+ packed = data.to(tl.uint8).to(tl.int32)
+ nibble = tl.where(parity == 0, packed & 0xF, packed >> 4)
+ signed = tl.where(nibble >= 8, nibble - 16, nibble)
+ return (signed.to(tl.float32) * scale_values).to(Q.dtype), token_head_scales
+ if KV_QUANT_MODE == 2 or KV_QUANT_MODE == 3: # per-token-head (int8 or fp8)
scale_idx = (
physical_block_idx * stride_s_blk
+ (seq_offset % BLOCK_SIZE) * stride_s_slot
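
Mode 4 dequantizes while the KV tile is being prepared: even logical dims read the low nibble of the packed byte, odd dims the high nibble, both sign-extended and scaled by the per-(token, head) scale loaded just above. A tiny stand-alone reference of that parity-based unpack (values are made up for illustration):

def unpack_int4_dim(packed_byte: int, dim: int, scale: float) -> float:
    # Even dims live in the low nibble, odd dims in the high nibble,
    # matching packed_dim = offs_d // 2 plus the parity test in the kernel.
    nibble = (packed_byte & 0xF) if dim % 2 == 0 else (packed_byte >> 4)
    signed = nibble - 16 if nibble >= 8 else nibble  # sign-extend int4
    return signed * scale

# One packed byte holding q_lo = -3 (0xD) and q_hi = 5 (0x5) -> 0x5D.
packed_byte, scale = 0x5D, 0.125
assert unpack_int4_dim(packed_byte, dim=0, scale=scale) == -3 * scale
assert unpack_int4_dim(packed_byte, dim=1, scale=scale) == 5 * scale
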
@@ -167,7 +190,7 @@ def kernel_unified_attention_2d(
KV_QUANT_MODE: tl.constexpr = 0,
FP8_MIN: tl.constexpr = float8_info.min,
FP8_MAX: tl.constexpr = float8_info.max,
- # Per-token-head scale caches (KV_QUANT_MODE >= 2)
+ # Per-token-head scale caches (KV_QUANT_MODE == 2 or 3)
# Shape: [num_blocks, block_size, num_kv_heads]
k_scale_cache_ptr=None,
v_scale_cache_ptr=None,
@@ -308,17 +331,19 @@ def kernel_unified_attention_2d(
block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE
).to(tl.int64)
+ packed_dim = offs_d // 2 if KV_QUANT_MODE == 4 else offs_d
+
v_offset = (
physical_block_idx[:, None] * stride_v_cache_0
+ kv_head_idx * stride_v_cache_2
- + offs_d[None, :] * stride_v_cache_3
+ + packed_dim[None, :] * stride_v_cache_3
+ (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1
)
k_offset = (
physical_block_idx[None, :] * stride_k_cache_0
+ kv_head_idx * stride_k_cache_2
- + offs_d[:, None] * stride_k_cache_3
+ + packed_dim[:, None] * stride_k_cache_3
+ (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1
)
@@ -333,6 +358,7 @@ def kernel_unified_attention_2d(
Q,
k_scale,
k_scale_cache_ptr,
+ offs_d,
physical_block_idx,
seq_offset,
kv_head_idx,
@@ -342,6 +368,7 @@ def kernel_unified_attention_2d(
tile_mask,
BLOCK_SIZE,
KV_QUANT_MODE,
+ False,
)
# V : (TILE_SIZE, HEAD_SIZE)
@@ -355,6 +382,7 @@ def kernel_unified_attention_2d(
Q,
v_scale,
v_scale_cache_ptr,
+ offs_d,
physical_block_idx,
seq_offset,
kv_head_idx,
@@ -364,6 +392,7 @@ def kernel_unified_attention_2d(
tile_mask,
BLOCK_SIZE,
KV_QUANT_MODE,
+ True,
)
# Compute attention mask: causal by default (key <= query)
@@ -404,7 +433,7 @@ def kernel_unified_attention_2d(
# Per-token-head quant: fuse softmax_scale with per-head k_scale
# to avoid a separate BLOCK_M × TILE_SIZE multiply on S.
- if KV_QUANT_MODE >= 2:
+ if KV_QUANT_MODE == 2 or KV_QUANT_MODE == 3:
S += tl.dot(Q, K) * (scale * k_token_head_scales[None, :])
else:
S += scale * tl.dot(Q, K)
@@ -472,7 +501,7 @@ def kernel_unified_attention_2d(
# acc : (BLOCK_M, HEAD_SIZE_PADDED)
# Per-token-head quant: apply v_scale to P instead of V.
- if KV_QUANT_MODE >= 2:
+ if KV_QUANT_MODE == 2 or KV_QUANT_MODE == 3:
P_v = (P * v_token_head_scales[None, :]).to(V.dtype)
acc += tl.dot(P_v, V)
else:
@@ -549,7 +578,7 @@ def kernel_unified_attention_3d(
mm_prefix_range_ptr, # [num_seqs] - prefix length for each sequence
# KV cache quantization: 0=none, 1=fp8, 2=per-token-head
KV_QUANT_MODE: tl.constexpr = 0,
- # Per-token-head scale caches (KV_QUANT_MODE >= 2)
+ # Per-token-head scale caches (KV_QUANT_MODE == 2 or 3)
# Shape: [num_blocks, block_size, num_kv_heads]
k_scale_cache_ptr=None,
v_scale_cache_ptr=None,
@@ -699,17 +728,19 @@ def kernel_unified_attention_3d(
block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE
).to(tl.int64)
+ packed_dim = offs_d // 2 if KV_QUANT_MODE == 4 else offs_d
+
v_offset = (
physical_block_idx[:, None] * stride_v_cache_0
+ kv_head_idx * stride_v_cache_2
- + offs_d[None, :] * stride_v_cache_3
+ + packed_dim[None, :] * stride_v_cache_3
+ (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1
)
k_offset = (
physical_block_idx[None, :] * stride_k_cache_0
+ kv_head_idx * stride_k_cache_2
- + offs_d[:, None] * stride_k_cache_3
+ + packed_dim[:, None] * stride_k_cache_3
+ (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1
)
@@ -724,6 +755,7 @@ def kernel_unified_attention_3d(
Q,
k_scale,
k_scale_cache_ptr,
+ offs_d,
physical_block_idx,
seq_offset,
kv_head_idx,
@@ -733,6 +765,7 @@ def kernel_unified_attention_3d(
tile_mask,
BLOCK_SIZE,
KV_QUANT_MODE,
+ False,
)
# V : (TILE_SIZE, HEAD_SIZE)
@@ -746,6 +779,7 @@ def kernel_unified_attention_3d(
Q,
v_scale,
v_scale_cache_ptr,
+ offs_d,
physical_block_idx,
seq_offset,
kv_head_idx,
@@ -755,6 +789,7 @@ def kernel_unified_attention_3d(
tile_mask,
BLOCK_SIZE,
KV_QUANT_MODE,
+ True,
)
# Compute attention mask: causal by default (key <= query)
@@ -795,7 +830,7 @@ def kernel_unified_attention_3d(
# Per-token-head quant: fuse softmax_scale with per-head k_scale
# to avoid a separate BLOCK_M × TILE_SIZE multiply on S.
- if KV_QUANT_MODE >= 2:
+ if KV_QUANT_MODE == 2 or KV_QUANT_MODE == 3:
S += tl.dot(Q, K) * (scale * k_token_head_scales[None, :])
else:
S += scale * tl.dot(Q, K)
@@ -863,7 +898,7 @@ def kernel_unified_attention_3d(
# acc : (BLOCK_M, HEAD_SIZE_PADDED)
# Per-token-head quant: apply v_scale to P instead of V.
- if KV_QUANT_MODE >= 2:
+ if KV_QUANT_MODE == 2 or KV_QUANT_MODE == 3:
P_v = (P * v_token_head_scales[None, :]).to(V.dtype)
acc += tl.dot(P_v, V)
else:
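
For modes 2 and 3 the kernels keep K quantized through the dot product and fold the per-token scales into the softmax scale afterwards; the int4 path does not need this because _prepare_kv_tile already dequantized the tile. The identity relied on is Q @ (K * s) == (Q @ K) * s when s varies only per key token. A quick NumPy check with made-up shapes:

import numpy as np

rng = np.random.default_rng(0)
Q = rng.standard_normal((4, 64)).astype(np.float32)              # (BLOCK_M, HEAD_SIZE)
K_q = rng.integers(-127, 128, size=(64, 8)).astype(np.float32)   # quantized K tile
k_scales = rng.uniform(0.01, 0.1, size=8).astype(np.float32)     # per key-token scales
softmax_scale = 0.125

fused = (Q @ K_q) * (softmax_scale * k_scales[None, :])           # as in the kernel
dequant_first = softmax_scale * (Q @ (K_q * k_scales[None, :]))
assert np.allclose(fused, dequant_first, rtol=1e-4, atol=1e-5)
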
diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py
index 9ab5af0f6fb0..41a8f0074fac 100644
--- a/vllm/v1/core/kv_cache_utils.py
+++ b/vllm/v1/core/kv_cache_utils.py
@@ -18,6 +18,7 @@
from vllm.utils.math_utils import cdiv
from vllm.utils.mem_utils import format_gib
from vllm.v1.kv_cache_interface import (
+ AttentionSpec,
ChunkedLocalAttentionSpec,
FullAttentionSpec,
KVCacheConfig,
@@ -932,6 +933,40 @@ def unify_kv_cache_spec_page_size(
return kv_cache_spec
max_page_size = max(page_sizes)
+
+ # Runtime-scale quantized attention can add a small fixed per-head overhead
+ # that breaks exact divisibility between hybrid attention types even when
+ # the underlying KV topology is otherwise compatible. In that case we can
+ # preserve the existing block_size and pad the smaller attention pages up to
+ # the maximum page size instead of hard-failing during grouping.
+ if (
+ all(isinstance(spec, AttentionSpec) for spec in kv_cache_spec.values())
+ and any(
+ spec.kv_quant_mode.uses_runtime_scales
+ for spec in kv_cache_spec.values()
+ )
+ ):
+ block_sizes = {spec.block_size for spec in kv_cache_spec.values()}
+ has_non_divisible_page = any(
+ spec.page_size_bytes != max_page_size
+ and max_page_size % spec.page_size_bytes != 0
+ for spec in kv_cache_spec.values()
+ )
+ if len(block_sizes) == 1 and has_non_divisible_page:
+ logger.warning(
+ "Padding smaller hybrid KV pages up to %d bytes to keep page "
+ "sizes compatible across attention types.",
+ max_page_size,
+ )
+ return {
+ layer_name: (
+ layer_spec
+ if layer_spec.page_size_bytes == max_page_size
+ else replace(layer_spec, page_size_padded=max_page_size)
+ )
+ for layer_name, layer_spec in kv_cache_spec.items()
+ }
+
new_kv_cache_spec = {}
for layer_name, layer_spec in kv_cache_spec.items():
if layer_spec.page_size_bytes == max_page_size:
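
The padding fallback above only applies when every spec is an AttentionSpec, at least one uses runtime scales, all block sizes agree, and some page size does not evenly divide the maximum. A small self-contained check of that last condition; the byte counts are made up for illustration, not taken from a real model:

# Hypothetical page sizes (bytes) for two hybrid attention layer types with
# the same block_size, where a runtime-scale overhead breaks divisibility.
page_sizes = {"full_attn": 17_408, "sliding_attn": 16_384}
max_page_size = max(page_sizes.values())

needs_padding = any(
    size != max_page_size and max_page_size % size != 0
    for size in page_sizes.values()
)
assert needs_padding  # 17408 % 16384 != 0, so pad the smaller page up to 17408

padded = {name: max_page_size for name in page_sizes}
assert set(padded.values()) == {max_page_size}
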
@@ -1294,27 +1329,10 @@ def _report_kv_cache_config(
vllm_config: The global VllmConfig
kv_cache_config: The resolved KV cache configuration
"""
- min_block_size = min(
- [group.kv_cache_spec.block_size for group in kv_cache_config.kv_cache_groups]
- )
-
# Log the KV cache size and maximum concurrency.
- num_tokens = (
- kv_cache_config.num_blocks
- // len(kv_cache_config.kv_cache_groups)
- * min_block_size
+ num_tokens = estimate_token_capacity_for_kv_cache_config(
+ vllm_config, kv_cache_config
)
- dcp_size = vllm_config.parallel_config.decode_context_parallel_size
- pcp_size = vllm_config.parallel_config.prefill_context_parallel_size
- if pcp_size * dcp_size > 1:
- num_tokens *= pcp_size * dcp_size
- logger.info(
- "Multiplying the GPU KV cache size by the cp_world_size %d "
- "(pcp_world_size %d * dcp_world_size %d).",
- pcp_size * dcp_size,
- pcp_size,
- dcp_size,
- )
num_tokens_str = f"{num_tokens:,}"
logger.info_once("GPU KV cache size: %s tokens", num_tokens_str, scope="local")
max_model_len_str = f"{vllm_config.model_config.max_model_len:,}"
@@ -1329,6 +1347,30 @@ def _report_kv_cache_config(
)
+def estimate_token_capacity_for_kv_cache_config(
+ vllm_config: VllmConfig, kv_cache_config: KVCacheConfig
+) -> int:
+ """Estimate total KV tokens that fit in the current GPU cache allocation.
+
+ The log line reports the total number of tokens that can be cached across
+ all concurrent requests, not the maximum request length for a single
+ sequence. Reuse the concurrency calculation so the reported token capacity
+ can exceed ``max_model_len`` when the cache holds multiple full-length
+ requests.
+ """
+ if not kv_cache_config.kv_cache_groups:
+ return 0
+
+ max_model_len = vllm_config.model_config.max_model_len
+ if max_model_len <= 0:
+ return 0
+
+ return int(
+ get_max_concurrency_for_kv_cache_config(vllm_config, kv_cache_config)
+ * max_model_len
+ )
+
+
def _max_memory_usage_bytes_from_groups(
vllm_config: VllmConfig,
kv_cache_groups: list[KVCacheGroupSpec],
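
The new helper reports token capacity as maximum concurrency times max_model_len, so the logged figure can exceed the model's context length when several full-length requests fit at once. A back-of-the-envelope version under a simplified single-group assumption; the concurrency formula below is an approximation for illustration, not the real get_max_concurrency_for_kv_cache_config:

# Simplified single-group estimate with made-up sizes.
num_blocks, block_size, max_model_len = 8192, 16, 2048

cache_tokens = num_blocks * block_size           # 131,072 tokens of KV storage
max_concurrency = cache_tokens / max_model_len   # 64 full-length requests
token_capacity = int(max_concurrency * max_model_len)

assert token_capacity == cache_tokens == 131_072
assert token_capacity > max_model_len            # capacity spans many requests
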
diff --git a/vllm/v1/engine/input_processor.py b/vllm/v1/engine/input_processor.py
index 323862d77ba1..1aab90c14581 100644
--- a/vllm/v1/engine/input_processor.py
+++ b/vllm/v1/engine/input_processor.py
@@ -98,9 +98,9 @@ def _validate_params(
self.tokenizer,
)
- if (
- params.thinking_token_budget is not None
- and self.vllm_config.reasoning_config is None
+ if params.thinking_token_budget is not None and (
+ self.vllm_config.reasoning_config is None
+ or not self.vllm_config.reasoning_config.enabled
):
raise ValueError(
"thinking_token_budget is set but reasoning_config is "
diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py
index 6f8ad8e7d8ef..3dbb76eb3c8b 100644
--- a/vllm/v1/kv_cache_interface.py
+++ b/vllm/v1/kv_cache_interface.py
@@ -38,15 +38,29 @@ class KVQuantMode(IntEnum):
FP8_PER_TENSOR = 1 # per-tensor scales (current fp8 path)
INT8_PER_TOKEN_HEAD = 2 # per-token-head dynamic scales for int8
FP8_PER_TOKEN_HEAD = 3 # per-token-head dynamic scales for fp8
+ INT4_PER_TOKEN_HEAD = 4 # packed signed int4 with runtime token/head scales
@property
def is_per_token_head(self) -> bool:
"""True for any per-token-head quantization mode."""
- return self >= 2
+ return self in (
+ KVQuantMode.INT8_PER_TOKEN_HEAD,
+ KVQuantMode.FP8_PER_TOKEN_HEAD,
+ )
+
+ @property
+ def is_int4_packed(self) -> bool:
+ return self == KVQuantMode.INT4_PER_TOKEN_HEAD
+
+ @property
+ def uses_runtime_scales(self) -> bool:
+ return self.is_per_token_head or self.is_int4_packed
def get_kv_quant_mode(kv_cache_dtype: str) -> KVQuantMode:
"""Map a ``kv_cache_dtype`` string to a :class:`KVQuantMode`."""
+ if kv_cache_dtype == "int4_per_token_head":
+ return KVQuantMode.INT4_PER_TOKEN_HEAD
if kv_cache_dtype == "int8_per_token_head":
return KVQuantMode.INT8_PER_TOKEN_HEAD
if kv_cache_dtype == "fp8_per_token_head":
@@ -65,6 +79,11 @@ def kv_cache_uses_per_token_head_scales(kv_cache_dtype: str) -> bool:
return get_kv_quant_mode(kv_cache_dtype).is_per_token_head
+def kv_cache_uses_runtime_scales(kv_cache_dtype: str) -> bool:
+ """Return True if *kv_cache_dtype* computes scales at cache-write time."""
+ return get_kv_quant_mode(kv_cache_dtype).uses_runtime_scales
+
+
@dataclass(frozen=True)
class KVCacheSpec:
"""
@@ -135,11 +154,17 @@ def page_size_bytes(self) -> int:
@property
def real_page_size_bytes(self) -> int:
+ head_size_physical = self.head_size
+ if self.kv_quant_mode.is_int4_packed:
+ packed_head_size = cdiv(self.head_size, 2)
+ packed_head_size_padded = cdiv(packed_head_size, 4) * 4
+ # int4 stores packed nibble bytes (rounded up to a 4-byte multiple)
+ # plus one float32 runtime scale per head.
+ head_size_physical = packed_head_size_padded + get_dtype_size(torch.float32)
return (
2
* self.block_size
* self.num_kv_heads
- * self.head_size
+ * head_size_physical
* get_dtype_size(self.dtype)
)
@@ -237,10 +262,21 @@ def merge(cls, specs: list[Self]) -> Self:
@property
def real_page_size_bytes(self) -> int:
+ head_size_k = self.head_size
+ head_size_v = self.head_size_v
+ if self.kv_quant_mode.is_int4_packed:
+ packed_head_size_k = cdiv(self.head_size, 2)
+ packed_head_size_v = cdiv(self.head_size_v, 2)
+ head_size_k = cdiv(packed_head_size_k, 4) * 4 + get_dtype_size(
+ torch.float32
+ )
+ head_size_v = cdiv(packed_head_size_v, 4) * 4 + get_dtype_size(
+ torch.float32
+ )
return (
self.block_size
* self.num_kv_heads
- * (self.head_size + self.head_size_v)
+ * (head_size_k + head_size_v)
* get_dtype_size(self.dtype)
)
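
Worked numbers for the int4 page-size accounting above, using assumed dimensions (illustrative, not taken from a specific model): the packed head is rounded up to a 4-byte multiple, one float32 scale is charged per head, and the cache dtype is one byte wide.

def cdiv(a, b):
    return -(-a // b)

block_size, num_kv_heads, head_size = 16, 8, 128
dtype_size = 1                                    # int4 cache stored as uint8

packed_head_size = cdiv(head_size, 2)             # 64 bytes of packed nibbles
packed_head_size_padded = cdiv(packed_head_size, 4) * 4  # already a multiple of 4
head_size_physical = packed_head_size_padded + 4  # plus one float32 runtime scale

int4_page = 2 * block_size * num_kv_heads * head_size_physical * dtype_size
fp8_page = 2 * block_size * num_kv_heads * head_size * 1

assert int4_page == 17_408   # roughly 53% of the fp8 page size below
assert fp8_page == 32_768
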
diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py
index 0d9b67017fd1..1739452b44a0 100644
--- a/vllm/v1/sample/logits_processor/builtin.py
+++ b/vllm/v1/sample/logits_processor/builtin.py
@@ -301,7 +301,7 @@ def __init__(
max_num_reqs = vllm_config.scheduler_config.max_num_seqs
# Check if thinking is enabled
- self.is_enabled = reasoning_config is not None
+ self.is_enabled = reasoning_config is not None and reasoning_config.enabled
self.reasoning_start_token_ids = getattr(
reasoning_config, "reasoning_start_token_ids", []
diff --git a/vllm/v1/worker/gpu/attn_utils.py b/vllm/v1/worker/gpu/attn_utils.py
index 34089a67b3be..53629020c901 100644
--- a/vllm/v1/worker/gpu/attn_utils.py
+++ b/vllm/v1/worker/gpu/attn_utils.py
@@ -149,7 +149,39 @@ def _reshape_kv_cache(
dtype = kv_cache_spec.dtype
raw_tensor = raw_tensor.view(dtype)
- raw_tensor = raw_tensor.view(kv_cache_shape)
+ padded_page_size_bytes = getattr(kv_cache_spec, "page_size_padded", None)
+ if (
+ padded_page_size_bytes is not None
+ and padded_page_size_bytes > kv_cache_spec.real_page_size_bytes
+ ):
+ block_dim = attn_backend.get_kv_cache_block_dim(
+ kv_cache_spec.block_size,
+ kv_cache_spec.num_kv_heads,
+ kv_cache_spec.head_size,
+ cache_dtype_str=cache_dtype,
+ )
+ physical_block_dim = kv_cache_stride_order.index(block_dim)
+ if physical_block_dim != 0:
+ raise NotImplementedError(
+ "Padded KV pages are only supported when the backend "
+ "block dimension is leading."
+ )
+ dtype_size = raw_tensor.element_size()
+ padded_elems_per_block = padded_page_size_bytes // dtype_size
+ strides = [0] * len(kv_cache_shape)
+ strides[-1] = 1
+ for i in range(len(kv_cache_shape) - 2, -1, -1):
+ if i == physical_block_dim:
+ strides[i] = padded_elems_per_block
+ else:
+ strides[i] = strides[i + 1] * kv_cache_shape[i + 1]
+ raw_tensor = torch.as_strided(
+ raw_tensor,
+ size=kv_cache_shape,
+ stride=tuple(strides),
+ )
+ else:
+ raw_tensor = raw_tensor.view(kv_cache_shape)
kv_caches[layer_name] = raw_tensor.permute(*inv_order)
return kv_caches
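
When a page is padded, the reshape path above keeps the semantic shape but gives the leading block dimension a larger physical stride via torch.as_strided, so each logical block starts at the padded offset and the tail of every page is simply unused. A minimal stand-alone demonstration of that trick; the shapes and padding below are assumptions for illustration:

import torch

num_blocks, block_size, head_size = 4, 16, 8
block_elems = block_size * head_size              # semantic elements per block
pad_elems = 32                                    # padding carried by each page
padded_elems_per_block = block_elems + pad_elems

flat = torch.zeros(num_blocks * padded_elems_per_block)
view = torch.as_strided(
    flat,
    size=(num_blocks, block_size, head_size),
    stride=(padded_elems_per_block, head_size, 1),  # only the block stride grows
)
view[2, 0, 0] = 1.0
# The write lands at the start of the third padded page in the flat buffer.
assert flat[2 * padded_elems_per_block].item() == 1.0
assert view.shape == (num_blocks, block_size, head_size)
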
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 872cb83d2401..c848d3962f90 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -6658,12 +6658,44 @@ def _reshape_kv_cache_tensors(
kv_cache_stride_order.index(i)
for i in range(len(kv_cache_stride_order))
]
- kv_caches[layer_name] = (
- kv_cache_raw_tensors[layer_name]
- .view(dtype)
- .view(kv_cache_shape)
- .permute(*inv_order)
+ raw_view = kv_cache_raw_tensors[layer_name].view(dtype)
+ padded_page_size_bytes = getattr(
+ kv_cache_spec, "page_size_padded", None
)
+ if (
+ padded_page_size_bytes is not None
+ and padded_page_size_bytes > kv_cache_spec.real_page_size_bytes
+ ):
+ block_dim = attn_backend.get_kv_cache_block_dim(
+ kernel_block_size,
+ kv_cache_spec.num_kv_heads,
+ kv_cache_spec.head_size,
+ cache_dtype_str=self.cache_config.cache_dtype,
+ )
+ physical_block_dim = kv_cache_stride_order.index(block_dim)
+ if physical_block_dim != 0:
+ raise NotImplementedError(
+ "Padded KV pages are only supported when the "
+ "backend block dimension is leading."
+ )
+ dtype_size = get_dtype_size(dtype)
+ padded_elems_per_block = padded_page_size_bytes // dtype_size
+ strides = [0] * len(kv_cache_shape)
+ strides[-1] = 1
+ for i in range(len(kv_cache_shape) - 2, -1, -1):
+ if i == physical_block_dim:
+ strides[i] = padded_elems_per_block
+ else:
+ strides[i] = strides[i + 1] * kv_cache_shape[i + 1]
+ raw_view = torch.as_strided(
+ raw_view,
+ size=kv_cache_shape,
+ stride=tuple(strides),
+ )
+ else:
+ raw_view = raw_view.view(kv_cache_shape)
+ kv_caches[layer_name] = raw_view.permute(*inv_order)
elif isinstance(kv_cache_spec, MambaSpec):
has_mamba = True
raw_tensor = kv_cache_raw_tensors[layer_name]
diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py
index 83fc12cb5c3b..7513bc0f300f 100644
--- a/vllm/v1/worker/utils.py
+++ b/vllm/v1/worker/utils.py
@@ -112,7 +112,10 @@ def init_meta(
PAGE_SIZE_EL accounts for this ratio so that
``block_id * PAGE_SIZE_EL`` lands at the correct offset.
- Only AttentionSpec layers are processed; Mamba layers are skipped.
+ Only decoder-style AttentionSpec layers are processed here.
+ Mamba layers are skipped, and encoder-only attention is skipped
+ because it does not participate in the decode-time KV block reuse
+ path handled here.
"""
seen_ptrs: set[int] = set()
seg_addrs: list[int] = []
@@ -120,7 +123,9 @@ def init_meta(
for group in attn_groups_iter:
spec = group.kv_cache_spec
- if type(spec) is not FullAttentionSpec:
+ if not isinstance(spec, AttentionSpec) or isinstance(
+ spec, EncoderOnlyAttentionSpec
+ ):
continue
if group.kv_cache_group_id >= len(kernel_block_sizes):
continue