diff --git a/docs/features/reasoning_outputs.md b/docs/features/reasoning_outputs.md index 9d41185f1ae1..5c99eb20bc7d 100644 --- a/docs/features/reasoning_outputs.md +++ b/docs/features/reasoning_outputs.md @@ -249,7 +249,7 @@ Token counting starts from `reasoning_start_str`. Once the reasoning token count To use this feature: - `--reasoning-parser` enables reasoning extraction. -- `--reasoning-config` defines the reasoning boundary tokens (e.g., `reasoning_start_str`, `reasoning_end_str`). +- `--reasoning-config` defines the reasoning boundary tokens (e.g., `reasoning_start_str`, `reasoning_end_str`). If not set, vLLM will attempt to automatically initialize these tokens from the reasoning parser. - `thinking_token_budget` (a sampling parameter) sets the per-request reasoning token limit. If `thinking_token_budget` is not specified, no explicit reasoning limit is applied beyond normal generation constraints such as `max_tokens`. diff --git a/tests/v1/entrypoints/openai/test_thinking_token_budget.py b/tests/entrypoints/openai/chat_completion/test_thinking_token_budget.py similarity index 72% rename from tests/v1/entrypoints/openai/test_thinking_token_budget.py rename to tests/entrypoints/openai/chat_completion/test_thinking_token_budget.py index b3f6d53ab1d4..d2db50082a55 100644 --- a/tests/v1/entrypoints/openai/test_thinking_token_budget.py +++ b/tests/entrypoints/openai/chat_completion/test_thinking_token_budget.py @@ -24,6 +24,24 @@ def server(): "--max-model-len", "2048", "--enforce-eager", + "--gpu-memory-utilization", + "0.4", + "--no-async-scheduling", + ] + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest.fixture(scope="module") +def server_with_auto_reasoning_config(): + args = [ + "--reasoning-parser", + "qwen3", + "--max-model-len", + "2048", + "--enforce-eager", + "--gpu-memory-utilization", + "0.4", "--no-async-scheduling", ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: @@ -31,12 +49,18 @@ def server(): @pytest_asyncio.fixture -async def client(server): - async with server.get_async_client() as async_client: +async def client(request, server, server_with_auto_reasoning_config): + server_map = { + "default": server, + "auto_config": server_with_auto_reasoning_config, + } + target_server = server_map[request.param] + async with target_server.get_async_client() as async_client: yield async_client @pytest.mark.asyncio +@pytest.mark.parametrize("client", ["default", "auto_config"], indirect=True) async def test_thinking_token_budget_mixed_requests(client: openai.AsyncOpenAI): """Test that mixed requests (some with thinking_token_budget, some without) complete successfully without errors.""" @@ -61,6 +85,7 @@ async def test_thinking_token_budget_mixed_requests(client: openai.AsyncOpenAI): @pytest.mark.asyncio +@pytest.mark.parametrize("client", ["default", "auto_config"], indirect=True) async def test_thinking_token_budget_limits_reasoning(client: openai.AsyncOpenAI): """Test that thinking_token_budget limits the number of reasoning tokens. 
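A hedged usage sketch for the documented `thinking_token_budget` flow above, matching the parametrized server fixtures (one with an explicit reasoning config, one relying on auto-initialization from `--reasoning-parser qwen3`). The model name is a placeholder, and passing the parameter through `extra_body` is an assumption about how vLLM-specific sampling parameters are usually forwarded; it is not something this patch asserts:

import asyncio

import openai


async def main() -> None:
    # Assumes a server is already running, e.g.:
    #   vllm serve <model> --reasoning-parser qwen3
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    resp = await client.chat.completions.create(
        model="Qwen/Qwen3-0.6B",  # placeholder model name
        messages=[{"role": "user", "content": "What is 17 * 23?"}],
        max_tokens=256,
        extra_body={"thinking_token_budget": 64},  # per-request reasoning cap
    )
    print(resp.choices[0].message)


asyncio.run(main())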
@@ -82,6 +107,6 @@ async def test_thinking_token_budget_limits_reasoning(client: openai.AsyncOpenAI reasoning_token_count += 1 assert reasoning_token_count == THINK_BUDGET, ( - f"reasoning tokens ({reasoning_token_count}) != " + f"reasoning tokens ({reasoning_token_count}) exceeded " f"thinking_token_budget ({THINK_BUDGET})" ) diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py index 3ebf9cc3713a..3da49f3f7574 100644 --- a/tests/kernels/attention/test_attention_selector.py +++ b/tests/kernels/attention/test_attention_selector.py @@ -368,6 +368,23 @@ def test_auto_backend_selection_behavior(): assert backend_auto.get_name() == backend_none.get_name() +def test_flash_attn_rejects_int4_kv_cache(monkeypatch: pytest.MonkeyPatch): + try: + from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend + except ImportError: + pytest.skip("vllm_flash_attn extension is not available in this env") + + monkeypatch.setattr( + "vllm.v1.attention.backends.flash_attn.flash_attn_supports_fp8", + lambda: True, + ) + + assert FlashAttentionBackend.supports_kv_cache_dtype("fp8") + assert not FlashAttentionBackend.supports_kv_cache_dtype( + "int4_per_token_head" + ) + + @pytest.mark.parametrize( "backend_name,flash_attn_version,should_succeed", [ diff --git a/tests/kernels/attention/test_triton_int4_kv_cache.py b/tests/kernels/attention/test_triton_int4_kv_cache.py new file mode 100644 index 000000000000..50d56cc47125 --- /dev/null +++ b/tests/kernels/attention/test_triton_int4_kv_cache.py @@ -0,0 +1,391 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +from vllm.platforms import current_platform +from vllm.utils.torch_utils import set_random_seed +from vllm.v1.attention.backend import AttentionType +from vllm.v1.attention.backends.triton_attn import ( + TritonAttentionBackend, + TritonAttentionImpl, +) +from vllm.v1.attention.ops.triton_reshape_and_cache_flash import ( + triton_reshape_and_cache_flash, +) +from vllm.v1.attention.ops.triton_unified_attention import unified_attention +from vllm.v1.kv_cache_interface import FullAttentionSpec, KVQuantMode, get_kv_quant_mode + + +def _expand_scale(scale: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + return scale.unsqueeze(-1) if scale.ndim == x.ndim - 1 else scale + + +def _quantize_int4(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + return torch.round(x.float() / _expand_scale(scale, x)).clamp(-8, 7).to( + torch.int32 + ) + + +def _dequantize_int4_values(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + return q.to(torch.float32) * _expand_scale(scale, q) + + +def _pack_int4(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + q = _quantize_int4(x, scale) + q_u4 = torch.where(q < 0, q + 16, q).to(torch.uint8) + if q_u4.shape[-1] % 2: + q_u4 = torch.nn.functional.pad(q_u4, (0, 1)) + return q_u4[..., 0::2] | (q_u4[..., 1::2] << 4) + + +def _dequantize_packed_int4( + packed: torch.Tensor, + scale: torch.Tensor, + head_size: int, +) -> torch.Tensor: + low = (packed & 0x0F).to(torch.int16) + high = (packed >> 4).to(torch.int16) + low = torch.where(low >= 8, low - 16, low) + high = torch.where(high >= 8, high - 16, high) + + out = torch.empty( + *packed.shape[:-1], + packed.shape[-1] * 2, + device=packed.device, + dtype=torch.float32, + ) + scale_expanded = _expand_scale(scale.to(torch.float32), out[..., 0::2]) + out[..., 0::2] = low.to(torch.float32) * scale_expanded + 
out[..., 1::2] = high.to(torch.float32) * scale_expanded + return out[..., :head_size] + + +def _int4_scales_per_token_head(x: torch.Tensor) -> torch.Tensor: + return x.abs().amax(dim=-1).float().clamp(min=1e-6) / 7.0 + + +def _get_int4_inline_scale_views( + kv_cache: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + key_cache, value_cache = kv_cache.unbind(1) + _, block_size, num_heads, padded_hs = key_cache.shape + scale_pad = 4 + data_hs_padded = padded_hs - scale_pad + assert data_hs_padded % 4 == 0 + + raw = kv_cache.untyped_storage() + base_f32 = torch.tensor([], dtype=torch.float32, device=kv_cache.device).set_(raw) + + kv_half_bytes = block_size * num_heads * padded_hs + full_block_f32 = 2 * kv_half_bytes // 4 + slot_f32 = num_heads * padded_hs // 4 + head_f32 = padded_hs // 4 + scale_off_f32 = data_hs_padded // 4 + + k_scale_cache = torch.as_strided( + base_f32, + size=(kv_cache.shape[0], block_size, num_heads), + stride=(full_block_f32, slot_f32, head_f32), + storage_offset=scale_off_f32, + ) + v_scale_cache = torch.as_strided( + base_f32, + size=(kv_cache.shape[0], block_size, num_heads), + stride=(full_block_f32, slot_f32, head_f32), + storage_offset=(kv_half_bytes // 4) + scale_off_f32, + ) + k_scale_cache.fill_(1.0) + v_scale_cache.fill_(1.0) + return key_cache, value_cache, k_scale_cache, v_scale_cache + + +def _ref_paged_attn( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + query_lens: list[int], + kv_lens: list[int], + block_tables: torch.Tensor, + scale: float, +) -> torch.Tensor: + outputs: list[torch.Tensor] = [] + start_idx = 0 + block_tables = block_tables.cpu().numpy() + _, block_size, num_kv_heads, head_size = key_cache.shape + + for i, query_len in enumerate(query_lens): + kv_len = kv_lens[i] + q = query[start_idx : start_idx + query_len] * scale + num_kv_blocks = (kv_len + block_size - 1) // block_size + block_indices = block_tables[i, :num_kv_blocks] + + k = key_cache[block_indices].view(-1, num_kv_heads, head_size)[:kv_len] + v = value_cache[block_indices].view(-1, num_kv_heads, head_size)[:kv_len] + if q.shape[1] != k.shape[1]: + k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1) + v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1) + + scores = torch.einsum("qhd,khd->hqk", q, k).float() + mask = torch.triu( + torch.ones(query_len, kv_len, device=q.device), + diagonal=kv_len - query_len + 1, + ).bool() + scores.masked_fill_(mask, float("-inf")) + probs = torch.softmax(scores, dim=-1).to(v.dtype) + outputs.append(torch.einsum("hqk,khd->qhd", probs, v)) + start_idx += query_len + + return torch.cat(outputs, dim=0) + + +def test_int4_kv_cache_shape_and_page_size(): + assert get_kv_quant_mode("int4_per_token_head") == KVQuantMode.INT4_PER_TOKEN_HEAD + assert TritonAttentionBackend.get_kv_cache_shape( + 8, 16, 4, 128, "int4_per_token_head" + ) == ( + 8, + 2, + 16, + 4, + 68, + ) + + spec = FullAttentionSpec( + block_size=16, + num_kv_heads=4, + head_size=128, + dtype=torch.uint8, + kv_quant_mode=KVQuantMode.INT4_PER_TOKEN_HEAD, + ) + assert spec.page_size_bytes == 16 * 4 * (68 + 68) + + +@pytest.mark.skipif(not current_platform.is_cuda(), reason="requires CUDA") +@torch.inference_mode() +def test_int4_inline_scale_views_respect_padded_block_stride(): + device = "cuda" + num_blocks = 3 + block_size = 16 + num_kv_heads = 4 + head_size = 128 + + kv_cache_shape = TritonAttentionBackend.get_kv_cache_shape( + num_blocks, block_size, num_kv_heads, head_size, "int4_per_token_head" + 
) + contiguous_block_elems = kv_cache_shape[1] * kv_cache_shape[2] * kv_cache_shape[3] * kv_cache_shape[4] + padded_block_elems = contiguous_block_elems + 128 + raw = torch.zeros(num_blocks * padded_block_elems, dtype=torch.uint8, device=device) + + half_stride = kv_cache_shape[2] * kv_cache_shape[3] * kv_cache_shape[4] + slot_stride = kv_cache_shape[3] * kv_cache_shape[4] + head_stride = kv_cache_shape[4] + kv_cache = torch.as_strided( + raw, + size=kv_cache_shape, + stride=(padded_block_elems, half_stride, slot_stride, head_stride, 1), + ) + + impl = TritonAttentionImpl( + num_heads=num_kv_heads, + head_size=head_size, + scale=head_size**-0.5, + num_kv_heads=num_kv_heads, + alibi_slopes=None, + sliding_window=None, + kv_cache_dtype="int4_per_token_head", + attn_type=AttentionType.DECODER, + ) + impl._ensure_scale_caches(kv_cache) + assert impl._k_scale_cache is not None + assert impl._v_scale_cache is not None + + for blk in range(num_blocks): + impl._k_scale_cache[blk].fill_(10.0 + blk) + impl._v_scale_cache[blk].fill_(20.0 + blk) + + base_f32 = torch.tensor([], dtype=torch.float32, device=device).set_(kv_cache.untyped_storage()) + block_f32 = padded_block_elems // 4 + half_f32 = half_stride // 4 + slot_f32 = slot_stride // 4 + head_f32 = head_stride // 4 + scale_off_f32 = (kv_cache_shape[-1] - 4) // 4 + + for blk in range(num_blocks): + for slot in (0, block_size - 1): + for head in (0, num_kv_heads - 1): + k_idx = blk * block_f32 + slot * slot_f32 + head * head_f32 + scale_off_f32 + v_idx = blk * block_f32 + half_f32 + slot * slot_f32 + head * head_f32 + scale_off_f32 + assert base_f32[k_idx].item() == pytest.approx(10.0 + blk) + assert base_f32[v_idx].item() == pytest.approx(20.0 + blk) + + +@pytest.mark.skipif(not current_platform.is_cuda(), reason="requires CUDA") +@torch.inference_mode() +def test_triton_reshape_and_cache_flash_int4(): + device = "cuda" + set_random_seed(0) + torch.set_default_device(device) + + num_tokens = 23 + num_heads = 4 + head_size = 97 + block_size = 16 + num_blocks = 8 + packed_head_size = (head_size + 1) // 2 + packed_head_size_padded = ((packed_head_size + 3) // 4) * 4 + + key = torch.randn(num_tokens, num_heads, head_size, dtype=torch.bfloat16) + value = torch.randn_like(key) + slot_mapping = torch.randperm(num_blocks * block_size, device=device)[:num_tokens] + + kv_cache = torch.zeros( + num_blocks, + 2, + block_size, + num_heads, + packed_head_size_padded + 4, + dtype=torch.uint8, + device=device, + ) + key_cache, value_cache, k_scale_cache, v_scale_cache = _get_int4_inline_scale_views( + kv_cache + ) + + triton_reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + "int4_per_token_head", + torch.tensor(1.0, device=device), + torch.tensor(1.0, device=device), + k_scale_cache, + v_scale_cache, + ) + + got_key = _dequantize_packed_int4( + key_cache[..., :packed_head_size], k_scale_cache, head_size + ) + got_value = _dequantize_packed_int4( + value_cache[..., :packed_head_size], v_scale_cache, head_size + ) + + expected_key = torch.zeros( + num_blocks, block_size, num_heads, head_size, device=device, dtype=torch.float32 + ) + expected_value = torch.zeros_like(expected_key) + expected_k_scales = torch.ones( + num_blocks, block_size, num_heads, device=device, dtype=torch.float32 + ) + expected_v_scales = torch.ones_like(expected_k_scales) + key_scales = _int4_scales_per_token_head(key) + value_scales = _int4_scales_per_token_head(value) + key_roundtrip = _dequantize_int4_values( + _quantize_int4(key, key_scales), key_scales 
+ ) + value_roundtrip = _dequantize_int4_values( + _quantize_int4(value, value_scales), value_scales + ) + + for token_idx, slot in enumerate(slot_mapping.cpu().tolist()): + block_idx = slot // block_size + block_offset = slot % block_size + expected_key[block_idx, block_offset] = key_roundtrip[token_idx] + expected_value[block_idx, block_offset] = value_roundtrip[token_idx] + expected_k_scales[block_idx, block_offset] = key_scales[token_idx] + expected_v_scales[block_idx, block_offset] = value_scales[token_idx] + + atol = max(float(key_scales.max().item()), float(value_scales.max().item())) + 1e-4 + torch.testing.assert_close(got_key, expected_key, atol=atol, rtol=0.0) + torch.testing.assert_close(got_value, expected_value, atol=atol, rtol=0.0) + torch.testing.assert_close(k_scale_cache, expected_k_scales, atol=1e-4, rtol=1e-4) + torch.testing.assert_close(v_scale_cache, expected_v_scales, atol=1e-4, rtol=1e-4) + + +@pytest.mark.skipif(not current_platform.is_cuda(), reason="requires CUDA") +@torch.inference_mode() +def test_triton_unified_attention_int4(): + device = "cuda" + set_random_seed(1) + torch.set_default_device(device) + + query_lens = [1, 5, 3] + kv_lens = [33, 18, 27] + num_blocks = 64 + block_size = 16 + num_query_heads = 4 + num_kv_heads = 2 + head_size = 128 + + query = torch.randn( + sum(query_lens), + num_query_heads, + head_size, + dtype=torch.bfloat16, + device=device, + ) + key_cache = torch.randn( + num_blocks, + block_size, + num_kv_heads, + head_size, + dtype=torch.bfloat16, + device=device, + ) + value_cache = torch.randn_like(key_cache) + + k_scale = _int4_scales_per_token_head(key_cache) + v_scale = _int4_scales_per_token_head(value_cache) + key_cache_int4 = _pack_int4(key_cache, k_scale) + value_cache_int4 = _pack_int4(value_cache, v_scale) + + cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32, device=device) + cu_query_lens = cu_query_lens.cumsum(dim=0, dtype=torch.int32) + kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32, device=device) + + max_kv_len = max(kv_lens) + max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size + block_tables = torch.randint( + 0, + num_blocks, + (len(query_lens), max_num_blocks_per_seq), + dtype=torch.int32, + device=device, + ) + + output = torch.empty_like(query) + unified_attention( + q=query, + k=key_cache_int4, + v=value_cache_int4, + out=output, + cu_seqlens_q=cu_query_lens, + seqused_k=kv_lens_tensor, + max_seqlen_q=max(query_lens), + max_seqlen_k=max_kv_len, + softmax_scale=head_size**-0.5, + causal=True, + window_size=(-1, -1), + block_table=block_tables, + softcap=0.0, + q_descale=None, + k_descale=None, + v_descale=None, + kv_quant_mode=KVQuantMode.INT4_PER_TOKEN_HEAD, + k_scale_cache=k_scale, + v_scale_cache=v_scale, + ) + + ref_output = _ref_paged_attn( + query=query.float(), + key_cache=_dequantize_packed_int4(key_cache_int4, k_scale, head_size), + value_cache=_dequantize_packed_int4(value_cache_int4, v_scale, head_size), + query_lens=query_lens, + kv_lens=kv_lens, + block_tables=block_tables, + scale=head_size**-0.5, + ) + torch.testing.assert_close(output.float(), ref_output.float(), atol=2e-2, rtol=2e-2) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index d8ecf28cbed1..fafb2952d51a 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -24,6 +24,7 @@ BlockHash, FreeKVCacheBlockQueue, KVCacheBlock, + estimate_token_capacity_for_kv_cache_config, estimate_max_model_len, 
generate_block_hash_extra_keys, generate_scheduler_kv_cache_config, @@ -43,6 +44,7 @@ KVCacheGroupSpec, KVCacheSpec, KVCacheTensor, + KVQuantMode, MambaSpec, MLAAttentionSpec, SlidingWindowSpec, @@ -107,6 +109,7 @@ def new_kv_cache_spec( num_kv_heads=2, head_size=64, dtype=torch.float32, + kv_quant_mode=KVQuantMode.NONE, page_size_padded=None, sliding_window=None, attention_chunk_size=None, @@ -116,6 +119,7 @@ def new_kv_cache_spec( num_kv_heads=num_kv_heads, head_size=head_size, dtype=dtype, + kv_quant_mode=kv_quant_mode, page_size_padded=page_size_padded, sliding_window=sliding_window, attention_chunk_size=attention_chunk_size, @@ -127,6 +131,7 @@ def new_sliding_window_spec( num_kv_heads=2, head_size=64, dtype=torch.float32, + kv_quant_mode=KVQuantMode.NONE, page_size_padded=None, sliding_window=1, ): @@ -135,6 +140,7 @@ def new_sliding_window_spec( num_kv_heads=num_kv_heads, head_size=head_size, dtype=dtype, + kv_quant_mode=kv_quant_mode, page_size_padded=page_size_padded, sliding_window=sliding_window, ) @@ -1418,6 +1424,73 @@ def test_get_max_concurrency_for_kv_cache_config(): assert max_concurrency_hybrid_model == 3 +def test_estimate_token_capacity_for_kv_cache_config_reports_total_cached_tokens(): + model_id = "Qwen/Qwen1.5-7B" + max_model_len = 32768 + model_config = ModelConfig( + model_id, + runner="generate", + dtype="float16", + max_model_len=max_model_len, + ) + scheduler_config = SchedulerConfig( + max_num_batched_tokens=1024, + enable_chunked_prefill=True, + max_model_len=model_config.max_model_len, + is_encoder_decoder=model_config.is_encoder_decoder, + ) + vllm_config = VllmConfig( + model_config=model_config, + scheduler_config=scheduler_config, + ) + + full_attention_spec = FullAttentionSpec( + block_size=16, + num_kv_heads=4, + head_size=128, + dtype=torch.uint8, + kv_quant_mode=KVQuantMode.INT4_PER_TOKEN_HEAD, + ) + sliding_window_spec = SlidingWindowSpec( + block_size=16, + num_kv_heads=4, + head_size=128, + dtype=torch.uint8, + kv_quant_mode=KVQuantMode.INT4_PER_TOKEN_HEAD, + sliding_window=1024, + ) + + num_blocks = (1024 + 64) * 3 + group_size = 32 + page_size = full_attention_spec.page_size_bytes + kv_cache_config = KVCacheConfig( + num_blocks=num_blocks, + kv_cache_tensors=[ + KVCacheTensor(size=page_size * num_blocks, shared_by=[]) + for _ in range(group_size) + ], + kv_cache_groups=[ + KVCacheGroupSpec([f"full_{i}" for i in range(group_size)], full_attention_spec), + KVCacheGroupSpec( + [f"sw_{i}" for i in range(group_size)], + sliding_window_spec, + ), + ], + ) + + expected_capacity = int( + get_max_concurrency_for_kv_cache_config(vllm_config, kv_cache_config) + * vllm_config.model_config.max_model_len + ) + + assert ( + estimate_token_capacity_for_kv_cache_config(vllm_config, kv_cache_config) + == expected_capacity + ) + + assert expected_capacity > vllm_config.model_config.max_model_len + + def test_allocate_with_lookahead(): """Verify that lookahead tokens correctly affect block allocation""" block_size = 4 @@ -1752,6 +1825,30 @@ def test_get_kv_cache_config_one_worker(): vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 2 * 32] )[0] + # The same non-divisible hybrid layout can be supported for runtime-scale + # quantized KV cache by padding the smaller page up to the larger one. 
+ kv_cache_specs_hybrid_runtime_quant = { + "layer_1": new_kv_cache_spec( + head_size=64, dtype=torch.uint8, kv_quant_mode=KVQuantMode.INT4_PER_TOKEN_HEAD + ), + "layer_2": new_sliding_window_spec( + head_size=96, dtype=torch.uint8, kv_quant_mode=KVQuantMode.INT4_PER_TOKEN_HEAD + ), + } + kv_cache_config_hybrid_runtime_quant = get_kv_cache_configs( + vllm_config, + [kv_cache_specs_hybrid_runtime_quant], + [mem_per_block_per_layer * 2 * 32], + )[0] + assert len(kv_cache_config_hybrid_runtime_quant.kv_cache_groups) == 2 + assert all( + group.kv_cache_spec.block_size == 16 + for group in kv_cache_config_hybrid_runtime_quant.kv_cache_groups + ) + assert len( + {group.kv_cache_spec.page_size_bytes for group in kv_cache_config_hybrid_runtime_quant.kv_cache_groups} + ) == 1 + # Test num_gpu_blocks_override vllm_config.cache_config.num_gpu_blocks_override = 16 kv_cache_config_override_blocks = get_kv_cache_configs( diff --git a/tests/v1/logits_processors/test_correctness.py b/tests/v1/logits_processors/test_correctness.py index bf29793710a4..9ee6a70abe4c 100644 --- a/tests/v1/logits_processors/test_correctness.py +++ b/tests/v1/logits_processors/test_correctness.py @@ -106,6 +106,7 @@ class MockReasoningConfig: reasoning_start_token_ids = [THINK_START_TOKEN_ID] reasoning_end_token_ids = [THINK_END_TOKEN_ID] + enabled = True def _generate_fake_sampling_metadata( diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index d59b74782be2..8c2659b9c7e6 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -336,9 +336,13 @@ def _rocm_aiter_fused_topk_fake( router_logits: torch.Tensor, top_k: int, gate_up: bool, -) -> None: - # tuple[torch.Tensor, torch.Tensor]: - pass +) -> tuple[torch.Tensor, torch.Tensor]: + num_tokens = x.shape[0] + topk_weights = torch.empty( + (num_tokens, top_k), dtype=torch.float32, device=x.device + ) + topk_indices = torch.empty((num_tokens, top_k), dtype=torch.int32, device=x.device) + return topk_weights, topk_indices # Cache whether aiter supports FP8 MLA parameters @@ -1918,7 +1922,7 @@ def is_triton_gemm_afp4wfp4_presh_ws_tuned(n: int, k: int) -> bool: @staticmethod def shuffle_weight( - self, tensor: torch.Tensor, layout: tuple[int, int] = (16, 16) + tensor: torch.Tensor, layout: tuple[int, int] = (16, 16) ) -> torch.Tensor: from aiter.ops.shuffle import shuffle_weight diff --git a/vllm/config/cache.py b/vllm/config/cache.py index cd1554590ea3..e1d5e3363492 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -11,6 +11,7 @@ from vllm.utils.torch_utils import ( is_quantized_kv_cache, kv_cache_uses_per_token_head_scales, + kv_cache_uses_runtime_scale_cache, ) logger = init_logger(__name__) @@ -19,6 +20,7 @@ "auto", "float16", "bfloat16", + "int4_per_token_head", "fp8", "fp8_e4m3", "fp8_e5m2", @@ -242,11 +244,11 @@ def _warn_deprecated_calculate_kv_scales(cls, calculate_kv_scales: bool) -> bool @field_validator("cache_dtype", mode="after") @classmethod def _validate_cache_dtype(cls, cache_dtype: CacheDType) -> CacheDType: - if kv_cache_uses_per_token_head_scales(cache_dtype): + if kv_cache_uses_runtime_scale_cache(cache_dtype): logger.info( "Using %s data type to store kv cache. It reduces the GPU " "memory footprint and boosts the performance. 
" - "Dynamic per-token-head scales will be computed at runtime.", + "Runtime scales will be computed dynamically.", str(cache_dtype), ) elif is_quantized_kv_cache(cache_dtype): diff --git a/vllm/config/reasoning.py b/vllm/config/reasoning.py index be1e2b6da58d..ff5546e05ebf 100644 --- a/vllm/config/reasoning.py +++ b/vllm/config/reasoning.py @@ -5,6 +5,7 @@ from vllm.config.model import ModelConfig from vllm.config.utils import config +from vllm.reasoning import ReasoningParserManager from vllm.tokenizers import cached_tokenizer_from_config @@ -18,11 +19,11 @@ class ReasoningConfig: `initialize_token_ids` and are not intended to be set directly. """ - # NOTE: These parameters are temporary, the intent is to derive them - # automatically from the reasoning parser in a future version. - reasoning_start_str: str = "" + reasoning_parser: str = "" + """The name of the ReasoningParser to use for this model.""" + reasoning_start_str: str = "" """String that indicates the start of reasoning.""" - reasoning_end_str: str = "" + reasoning_end_str: str = "" """String that indicates the end of reasoning content.""" _reasoning_start_token_ids: list[int] | None = field( @@ -36,6 +37,16 @@ class ReasoningConfig: """Private backing field for `reasoning_end_token_ids`. Set by `initialize_token_ids`. Not intended to be configured directly.""" + _enabled: bool = field(default=False, init=False, repr=False) + """Private field indicating whether reasoning token IDs have been initialized. + Set to True by `initialize_token_ids` once token IDs are initialized.""" + + @property + def enabled(self) -> bool: + """Returns True if reasoning is enabled (i.e. if token IDs have been + initialized), False otherwise.""" + return self._enabled + @property def reasoning_start_token_ids(self) -> list[int] | None: """Token IDs derived from `reasoning_start_str`. Set automatically by @@ -54,15 +65,36 @@ def initialize_token_ids(self, model_config: ModelConfig) -> None: self._reasoning_start_token_ids is not None and self._reasoning_end_token_ids is not None ): - return + self._enabled = True + return # Already initialized tokenizer = cached_tokenizer_from_config(model_config=model_config) + reasoning_start_str = self.reasoning_start_str + reasoning_end_str = self.reasoning_end_str + if self.reasoning_parser is not None and ( + not reasoning_start_str or not reasoning_end_str + ): + parser_cls = ReasoningParserManager.get_reasoning_parser( + self.reasoning_parser + ) + reasoning_parser = parser_cls(tokenizer) + start_token = reasoning_parser.reasoning_start_str + if start_token and not reasoning_start_str: + reasoning_start_str = start_token + end_token = reasoning_parser.reasoning_end_str + if end_token and not reasoning_end_str: + reasoning_end_str = end_token + + if not reasoning_start_str or not reasoning_end_str: + # If we don't have valid strings to tokenize, + # we can't initialize the token IDs. + return self._reasoning_start_token_ids = tokenizer.encode( - self.reasoning_start_str, add_special_tokens=False + reasoning_start_str, add_special_tokens=False ) self._reasoning_end_token_ids = tokenizer.encode( - self.reasoning_end_str, add_special_tokens=False + reasoning_end_str, add_special_tokens=False ) if not self._reasoning_start_token_ids or not self._reasoning_end_token_ids: @@ -72,3 +104,4 @@ def initialize_token_ids(self, model_config: ModelConfig) -> None: f"reasoning_end_str='{self.reasoning_end_str}'. " "Ensure the strings are valid tokens in the model's vocabulary." 
) + self._enabled = True diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 6229b44d52a8..3b8431e9530a 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -1210,6 +1210,12 @@ def has_blocked_weights(): if self.reasoning_config is not None and self.model_config is not None: self.reasoning_config.initialize_token_ids(self.model_config) + if not self.reasoning_config.enabled: + logger.warning_once( + "Auto-initialization of reasoning token IDs failed. " + "Please check whether your reasoning parser has implemented " + "the `reasoning_start_str` and `reasoning_end_str`." + ) # Hybrid KV cache manager (HMA) runtime rules: # - Explicit enable (--no-disable-kv-cache-manager): error if runtime diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 55c87bf356c5..c9b90848ff04 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1591,7 +1591,7 @@ def create_engine_config( self._set_default_max_num_seqs_and_batched_tokens_args( usage_context, model_config ) - + self._set_default_reasoning_config_args() sliding_window: int | None = None if not is_interleaved(model_config.hf_text_config): # Only set CacheConfig.sliding_window if the model is all sliding @@ -2233,6 +2233,13 @@ def _set_default_chunked_prefill_and_prefix_caching_args( ) self.enable_prefix_caching = False + def _set_default_reasoning_config_args(self): + if not self.reasoning_parser: + return + if self.reasoning_config is None: + self.reasoning_config = ReasoningConfig() + self.reasoning_config.reasoning_parser = self.reasoning_parser + def _set_default_max_num_seqs_and_batched_tokens_args( self, usage_context: UsageContext | None, diff --git a/vllm/model_executor/kernels/linear/__init__.py b/vllm/model_executor/kernels/linear/__init__.py index f718c239984b..6a026eef839d 100644 --- a/vllm/model_executor/kernels/linear/__init__.py +++ b/vllm/model_executor/kernels/linear/__init__.py @@ -177,6 +177,21 @@ ], } +_POSSIBLE_WFP8A16_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] = { + PlatformEnum.CUDA: [ + MarlinFP8ScaledMMLinearKernel, + ], + PlatformEnum.ROCM: [ + # To be added + ], + PlatformEnum.CPU: [ + # To be added + ], + PlatformEnum.XPU: [ + # To be added + ], +} + # in priority/performance order (when available) _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = { PlatformEnum.CUDA: [ @@ -463,6 +478,41 @@ def choose_mp_linear_kernel( ) +def init_wfp8_a16_linear_kernel( + weight_quant_key: QuantKey, + activation_quant_key: QuantKey, + weight_shape: tuple[int, int], + input_dtype: torch.dtype, + out_dtype: torch.dtype, + force_kernel: type[FP8ScaledMMLinearKernel] | None = None, + module_name: str | None = None, +) -> FP8ScaledMMLinearKernel: + config = FP8ScaledMMLinearLayerConfig( + weight_quant_key=weight_quant_key, + activation_quant_key=activation_quant_key, + weight_shape=weight_shape, + input_dtype=input_dtype, + out_dtype=out_dtype, + ) + + kernel_type = choose_scaled_mm_linear_kernel( + config, _POSSIBLE_WFP8A16_KERNELS, force_kernel=force_kernel + ) + + if module_name: + logger.info_once( + "Selected %s for %s", + kernel_type.__name__, + module_name, + scope="global", + ) + + return kernel_type( + config, + layer_param_names=["weight", "weight_scale", "input_scale", "input_scale_ub"], + ) + + # Maps VLLM_NVFP4_GEMM_BACKEND env var values to kernel classes. 
_NVFP4_BACKEND_TO_KERNEL: dict[str, type[NvFp4LinearKernel]] = { "flashinfer-cutlass": FlashInferCutlassNvFp4LinearKernel, @@ -588,6 +638,7 @@ def register_linear_kernel( "init_nvfp4_linear_kernel", "choose_mp_linear_kernel", "register_linear_kernel", + "init_wfp8_a16_linear_kernel", "FP8ScaledMMLinearKernel", "Int8ScaledMMLinearKernel", "ScaledMMLinearKernel", diff --git a/vllm/model_executor/layers/mamba/gdn_linear_attn.py b/vllm/model_executor/layers/mamba/gdn_linear_attn.py index aec855d9aeb1..3ac90c942e73 100644 --- a/vllm/model_executor/layers/mamba/gdn_linear_attn.py +++ b/vllm/model_executor/layers/mamba/gdn_linear_attn.py @@ -702,19 +702,33 @@ def _warmup_prefill_kernels(self, mixed_qkv: torch.Tensor) -> None: num_v_heads = self.num_v_heads // self.tp_size _, state_dtype = self.get_state_dtype() - # All kernels use BT = chunk_size (FLA_CHUNK_SIZE4), so a single pass with - # T = chunk_size is sufficient to populate every autotuner cache. + # All kernels use BT = chunk_size, so a single pass with T = chunk_size + # is sufficient to populate every autotuner cache. Mirror the real + # prefill path here: build q/k/v/g/beta via fused_post_conv_prep and + # then run chunk_gated_delta_rule with in-kernel L2 norm disabled. T = FLA_CHUNK_SIZE - q = torch.randn(1, T, num_k_heads, self.head_k_dim, device=device, dtype=dtype) - k = torch.randn(1, T, num_k_heads, self.head_k_dim, device=device, dtype=dtype) - v = torch.randn(1, T, num_v_heads, self.head_v_dim, device=device, dtype=dtype) - # NOTE: g and beta must have the same dtypes as during - # inference, so we construct them with the same function - # (fused_gdn_gating). dummy_a and dummy_b are throwaway - # inputs required by that function. + dummy_mixed_qkv = torch.randn( + T, mixed_qkv.shape[-1], device=device, dtype=dtype + ) dummy_a = torch.randn(T, num_v_heads, device=device, dtype=dtype) dummy_b = torch.randn(T, num_v_heads, device=device, dtype=dtype) - g, beta = fused_gdn_gating(self.A_log, dummy_a, dummy_b, self.dt_bias) + q, k, v, g, beta = fused_post_conv_prep( + conv_output=dummy_mixed_qkv, + a=dummy_a, + b=dummy_b, + A_log=self.A_log, + dt_bias=self.dt_bias, + num_k_heads=num_k_heads, + head_k_dim=self.head_k_dim, + head_v_dim=self.head_v_dim, + apply_l2norm=True, + output_g_exp=False, + ) + q = q.unsqueeze(0) + k = k.unsqueeze(0) + v = v.unsqueeze(0) + g = g.unsqueeze(0) + beta = beta.unsqueeze(0) state = torch.zeros( 1, num_v_heads, @@ -735,7 +749,7 @@ def _warmup_prefill_kernels(self, mixed_qkv: torch.Tensor) -> None: initial_state=state, output_final_state=True, cu_seqlens=cu_seqlens, - use_qk_l2norm_in_kernel=True, + use_qk_l2norm_in_kernel=False, ) except Exception: logger.warning( @@ -753,7 +767,7 @@ def _warmup_prefill_kernels(self, mixed_qkv: torch.Tensor) -> None: self.prefix, ) finally: - del q, k, v, dummy_a, dummy_b, g, beta, state, cu_seqlens + del dummy_mixed_qkv, q, k, v, dummy_a, dummy_b, g, beta, state, cu_seqlens torch.accelerator.empty_cache() diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py index 7bffc3218b42..42b35a420cab 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py @@ -6,45 +6,49 @@ import torch from compressed_tensors.quantization import QuantizationArgs, 
QuantizationStrategy +from vllm.config import get_current_vllm_config +from vllm.model_executor.kernels.linear import ( + init_wfp8_a16_linear_kernel, +) from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme, ) +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + STRATEGY_TO_PARAMETER_TYPE, + STRATEGY_TO_WEIGHT_QUANT_KEY, +) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( create_fp8_scale_parameter, create_fp8_weight_parameter, - process_fp8_weight_block_strategy, validate_fp8_block_shape, ) -from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( - apply_fp8_marlin_linear, - prepare_fp8_layer_for_marlin, +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + kFp8DynamicTensorSym, + kFp8StaticTensorSym, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( convert_to_channelwise, ) -from vllm.model_executor.parameter import ( - BlockQuantScaleParameter, - ChannelQuantScaleParameter, - PerTensorScaleParameter, -) +from vllm.model_executor.parameter import PerTensorScaleParameter from vllm.model_executor.utils import replace_parameter __all__ = ["CompressedTensorsW8A16Fp8"] -strategy_to_parameter_type = { - QuantizationStrategy.BLOCK: BlockQuantScaleParameter, - QuantizationStrategy.CHANNEL: ChannelQuantScaleParameter, - QuantizationStrategy.TENSOR: PerTensorScaleParameter, -} - class CompressedTensorsW8A16Fp8(CompressedTensorsScheme): def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool): self.weight_quant = weight_quant self.strategy = weight_quant.strategy + self.out_dtype = torch.get_default_dtype() + self.input_dtype = get_current_vllm_config().model_config.dtype self.is_static_input_scheme = is_static_input_scheme self.weight_block_size = self.weight_quant.block_structure + self.weight_quant_key = STRATEGY_TO_WEIGHT_QUANT_KEY[self.strategy] + self.activation_quant_key = ( + kFp8StaticTensorSym if is_static_input_scheme else kFp8DynamicTensorSym + ) + @classmethod def get_min_capability(cls) -> int: # turing and up @@ -89,7 +93,7 @@ def create_weights( # WEIGHT SCALE weight_scale = create_fp8_scale_parameter( - strategy_to_parameter_type[self.strategy], + STRATEGY_TO_PARAMETER_TYPE[self.strategy], output_partition_sizes, input_size_per_partition, layer.weight_block_size, @@ -105,32 +109,36 @@ def create_weights( ) layer.register_parameter("input_scale", input_scale) + self.linear_kernel = init_wfp8_a16_linear_kernel( + weight_quant_key=self.weight_quant_key, + activation_quant_key=self.activation_quant_key, + weight_shape=layer.weight.shape, + input_dtype=self.input_dtype, + out_dtype=self.out_dtype, + ) + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - weight = layer.weight - weight_scale = layer.weight_scale - size_k_first = True - # TODO(rob): refactor block quant into separate class. if self.strategy == QuantizationStrategy.BLOCK: assert self.is_static_input_scheme is False - size_k_first = False - weight, weight_scale = process_fp8_weight_block_strategy( - weight, weight_scale - ) + # MarlinFP8ScaledMMLinearKernel uses "weight_scale_inv" for block + # quant, while CT registers the scale as "weight_scale". + # Rename by deleting the old parameter and adding the new one so + # that prepare_fp8_layer_for_marlin (which prefers "weight_scale" + # over "weight_scale_inv") picks up "weight_scale_inv" correctly. 
+ weight_scale_data = layer.weight_scale.data + del layer._parameters["weight_scale"] + replace_parameter(layer, "weight_scale_inv", weight_scale_data) else: - # Weights must be transposed for marlin - weight = weight.t() if self.strategy == QuantizationStrategy.TENSOR: - # If we have a fused module (QKV, MLP) with per tensor scales, - # we expand each scale to its shard's channels. - weight_scale = convert_to_channelwise( - weight_scale, layer.logical_widths + # For fused modules with per-tensor scales, expand each scale + # to its shard's channels. + replace_parameter( + layer, + "weight_scale", + convert_to_channelwise(layer.weight_scale, layer.logical_widths), ) - # Update layer with new values - replace_parameter(layer, "weight", weight.data) - replace_parameter(layer, "weight_scale", weight_scale.data) - - prepare_fp8_layer_for_marlin(layer, size_k_first=size_k_first) + self.linear_kernel.process_weights_after_loading(layer) def apply_weights( self, @@ -138,12 +146,4 @@ def apply_weights( x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - return apply_fp8_marlin_linear( - input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - workspace=layer.workspace, - size_n=layer.output_size_per_partition, - size_k=layer.input_size_per_partition, - bias=bias, - ) + return self.linear_kernel.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index c6b810eb9679..3bf606ddb332 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -16,6 +16,9 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme, ) +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + STRATEGY_TO_PARAMETER_TYPE, +) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( create_fp8_input_scale, create_fp8_scale_parameter, @@ -34,20 +37,9 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( cutlass_block_fp8_supported, ) -from vllm.model_executor.parameter import ( - BlockQuantScaleParameter, - ChannelQuantScaleParameter, - PerTensorScaleParameter, -) __all__ = ["CompressedTensorsW8A8Fp8"] -strategy_to_parameter_type = { - QuantizationStrategy.BLOCK: BlockQuantScaleParameter, - QuantizationStrategy.CHANNEL: ChannelQuantScaleParameter, - QuantizationStrategy.TENSOR: PerTensorScaleParameter, -} - STATIC_QUANT = True DYNAMIC_QUANT = False activation_quant_key_mapping = { @@ -130,7 +122,7 @@ def create_weights( # WEIGHT SCALE weight_scale = create_fp8_scale_parameter( - strategy_to_parameter_type[self.strategy], + STRATEGY_TO_PARAMETER_TYPE[self.strategy], output_partition_sizes, input_size_per_partition, layer.weight_block_size, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py index f88092169110..04c64d9bd56f 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py @@ -6,8 +6,36 @@ import regex as re from compressed_tensors import CompressionFormat +from compressed_tensors.quantization import QuantizationStrategy from torch.nn import Module +from 
vllm.model_executor.layers.quantization.utils.quant_utils import ( + kFp8Static128BlockSym, + kFp8StaticChannelSym, + kFp8StaticTensorSym, +) +from vllm.model_executor.parameter import ( + BlockQuantScaleParameter, + ChannelQuantScaleParameter, + PerTensorScaleParameter, +) + +# Maps quantization strategy to the corresponding scale parameter type. +# Shared across compressed-tensor scheme classes (w8a16_fp8, w8a8_fp8, …). +STRATEGY_TO_PARAMETER_TYPE = { + QuantizationStrategy.BLOCK: BlockQuantScaleParameter, + QuantizationStrategy.CHANNEL: ChannelQuantScaleParameter, + QuantizationStrategy.TENSOR: PerTensorScaleParameter, +} + +# Maps quantization strategy to the vLLM weight-quant key used for +# kernel selection. Shared across compressed-tensor scheme classes. +STRATEGY_TO_WEIGHT_QUANT_KEY = { + QuantizationStrategy.BLOCK: kFp8Static128BlockSym, + QuantizationStrategy.CHANNEL: kFp8StaticChannelSym, + QuantizationStrategy.TENSOR: kFp8StaticTensorSym, +} + def is_activation_quantization_format(format: str) -> bool: _ACTIVATION_QUANTIZATION_FORMATS = [ diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py index 726ac2232af9..3ed14ae808df 100644 --- a/vllm/model_executor/layers/quantization/kv_cache.py +++ b/vllm/model_executor/layers/quantization/kv_cache.py @@ -10,7 +10,7 @@ ) from vllm.platforms import current_platform from vllm.utils.torch_utils import is_quantized_kv_cache -from vllm.v1.kv_cache_interface import kv_cache_uses_per_token_head_scales +from vllm.v1.kv_cache_interface import kv_cache_uses_runtime_scales logger = init_logger(__name__) @@ -54,14 +54,14 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: assert not hasattr(layer, "prob_scale") return - # Per-token-head quantized KV cache: scales are computed dynamically - # per (token, head) in the kernel at cache-write time. Checkpoint - # scales are never used regardless of calculate_kv_scales. - if kv_cache_uses_per_token_head_scales(layer.kv_cache_dtype): + # Runtime-scale quantized KV cache: scales are computed dynamically + # at cache-write time. Checkpoint scales are never used. + if kv_cache_uses_runtime_scales(layer.kv_cache_dtype): layer._k_scale.copy_(1.0) layer._v_scale.copy_(1.0) layer._k_scale_float = 1.0 layer._v_scale_float = 1.0 + layer.calculate_kv_scales = False del layer.k_scale del layer.v_scale del layer.q_scale diff --git a/vllm/reasoning/abs_reasoning_parsers.py b/vllm/reasoning/abs_reasoning_parsers.py index 7244b9fac84f..2ff72ea2f5e0 100644 --- a/vllm/reasoning/abs_reasoning_parsers.py +++ b/vllm/reasoning/abs_reasoning_parsers.py @@ -39,6 +39,20 @@ def vocab(self) -> dict[str, int]: # whereas all tokenizers have .get_vocab() return self.model_tokenizer.get_vocab() + @property + def reasoning_start_str(self) -> str | None: + """Set `reasoning_start_str` to the strings that delimit + the reasoning block (e.g. `""""` and `""`). + """ + return None + + @property + def reasoning_end_str(self) -> str | None: + """Set `reasoning_end_str` to the strings that delimit + the reasoning block (e.g. `""""` and `""`). 
+ """ + return None + @abstractmethod def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: """ diff --git a/vllm/reasoning/basic_parsers.py b/vllm/reasoning/basic_parsers.py index a8bb33d2c9cd..938b7f736b2c 100644 --- a/vllm/reasoning/basic_parsers.py +++ b/vllm/reasoning/basic_parsers.py @@ -39,6 +39,14 @@ def end_token(self) -> str: """The token that ends reasoning content.""" raise NotImplementedError + @property + def reasoning_start_str(self) -> str: + return self.start_token + + @property + def reasoning_end_str(self) -> str: + return self.end_token + def __init__(self, tokenizer: TokenizerLike, *args, **kwargs): super().__init__(tokenizer, *args, **kwargs) diff --git a/vllm/utils/torch_utils.py b/vllm/utils/torch_utils.py index 94f8c096e313..e1efbe58a842 100644 --- a/vllm/utils/torch_utils.py +++ b/vllm/utils/torch_utils.py @@ -33,6 +33,7 @@ "float16": torch.float16, "bfloat16": torch.bfloat16, "float": torch.float, + "int4_per_token_head": torch.uint8, "fp8": torch.uint8, "fp8_e4m3": torch.uint8, "fp8_e5m2": torch.uint8, @@ -64,7 +65,11 @@ def is_quantized_kv_cache(kv_cache_dtype: str) -> bool: - return kv_cache_dtype.startswith("fp8") or kv_cache_dtype.endswith("per_token_head") + return ( + kv_cache_dtype == "int4_per_token_head" + or kv_cache_dtype.startswith("fp8") + or kv_cache_dtype.endswith("per_token_head") + ) def kv_cache_uses_per_token_head_scales(kv_cache_dtype: str) -> bool: @@ -72,6 +77,21 @@ def kv_cache_uses_per_token_head_scales(kv_cache_dtype: str) -> bool: return kv_cache_dtype.endswith("per_token_head") +def kv_cache_uses_runtime_scale_cache(kv_cache_dtype: str) -> bool: + return ( + kv_cache_dtype == "int4_per_token_head" + or kv_cache_uses_per_token_head_scales(kv_cache_dtype) + ) + + +def kv_cache_uses_int4_packing(kv_cache_dtype: str) -> bool: + return kv_cache_dtype == "int4_per_token_head" + + +def kv_cache_uses_fp8_storage(kv_cache_dtype: str) -> bool: + return kv_cache_dtype.startswith("fp8") + + def is_strictly_contiguous(t: torch.Tensor) -> bool: """ Check if tensor is contiguous AND has no degenerate strides. 
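A standalone sketch of the int4 KV-cache byte layout that the new `kv_cache_uses_int4_packing` / `kv_cache_uses_runtime_scale_cache` predicates gate on, matching the packing convention exercised by the tests earlier in this patch (two signed int4 values per uint8, low nibble first, plus one inline float32 scale per (token, head) padded to a 4-byte boundary). The helper names here are hypothetical:

def pack_int4_pair(lo: int, hi: int) -> int:
    # Clamp to the signed int4 range [-8, 7] and store as two's-complement
    # nibbles: low element in bits 0-3, high element in bits 4-7.
    lo = max(-8, min(7, lo)) & 0xF
    hi = max(-8, min(7, hi)) & 0xF
    return lo | (hi << 4)


def unpack_int4_pair(packed: int) -> tuple[int, int]:
    lo, hi = packed & 0xF, packed >> 4
    # Sign-extend nibbles >= 8 back to negative values.
    return (lo - 16 if lo >= 8 else lo, hi - 16 if hi >= 8 else hi)


def int4_head_bytes(head_size: int) -> int:
    # Mirrors the shape math asserted by test_int4_kv_cache_shape_and_page_size:
    # ceil(head_size / 2) packed bytes, rounded up to a 4-byte boundary, plus
    # 4 inline bytes for the float32 runtime scale.
    packed = (head_size + 1) // 2
    return ((packed + 3) // 4) * 4 + 4


assert int4_head_bytes(128) == 68
assert unpack_int4_pair(pack_int4_pair(-3, 7)) == (-3, 7)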
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index b9ccd4fce1c3..90139d3f3315 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -11,7 +11,10 @@ from vllm.model_executor.layers.attention import Attention from vllm.platforms import current_platform -from vllm.utils.torch_utils import is_quantized_kv_cache +from vllm.utils.torch_utils import ( + is_quantized_kv_cache, + kv_cache_uses_fp8_storage, +) from vllm.v1.attention.backend import ( AttentionBackend, AttentionImpl, @@ -184,7 +187,10 @@ def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool: if kv_cache_dtype is None: return True if is_quantized_kv_cache(kv_cache_dtype): - return flash_attn_supports_fp8() + return ( + kv_cache_uses_fp8_storage(kv_cache_dtype) + and flash_attn_supports_fp8() + ) return kv_cache_dtype in ["auto", "float16", "bfloat16"] @classmethod diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index bd8ec29bc33a..3598f77b137f 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -18,7 +18,12 @@ from vllm.platforms import current_platform from vllm.platforms.interface import DeviceCapability from vllm.utils.math_utils import next_power_of_2 -from vllm.utils.torch_utils import is_quantized_kv_cache +from vllm.utils.torch_utils import ( + is_quantized_kv_cache, + kv_cache_uses_fp8_storage, + kv_cache_uses_int4_packing, + kv_cache_uses_runtime_scale_cache, +) from vllm.v1.attention.backend import ( AttentionBackend, AttentionCGSupport, @@ -271,6 +276,7 @@ class TritonAttentionBackend(AttentionBackend): "auto", "float16", "bfloat16", + "int4_per_token_head", "fp8", "fp8_e4m3", "fp8_e5m2", @@ -308,6 +314,19 @@ def get_kv_cache_shape( ) -> tuple[int, ...]: if block_size % 16 != 0: raise ValueError("Block size must be a multiple of 16.") + if kv_cache_uses_int4_packing(cache_dtype_str): + packed_head_size = (head_size + 1) // 2 + packed_head_size_padded = ((packed_head_size + 3) // 4) * 4 + # Reserve 4 inline bytes per (token, head) for the float32 runtime + # scale. This keeps int4 accounting honest without needing a second + # allocation path for scale caches. + return ( + num_blocks, + 2, + block_size, + num_kv_heads, + packed_head_size_padded + 4, + ) if kv_cache_uses_per_token_head_scales(cache_dtype_str): # Pad head_size by sizeof(float32)/sizeof(cache_dtype) so # the per-head scale fits inline. The backend extracts @@ -391,19 +410,18 @@ def _ensure_scale_caches(self, kv_cache: torch.Tensor) -> None: """Extract per-head scale views from the padded head dimension. The KV cache shape is ``(num_blocks, 2, block_size, nkv, hs+pad)`` - where ``pad = sizeof(float32) / sizeof(cache_dtype)``. The last - ``pad`` elements of each head hold one float32 scale. We create - strided float32 views over those bytes. + where the trailing bytes in each head region hold one float32 scale. + For int8/fp8 per-token-head cache, the scale sits after the dense data. + For int4 cache, the scale sits after the packed-bytes region padded to a + 4-byte boundary. We create strided float32 views over those bytes. Scale shape: ``(num_blocks, block_size, num_kv_heads)`` """ if self._k_scale_cache is not None: return - from vllm.utils.torch_utils import get_dtype_size - num_blocks, _, block_size, nkv, padded_hs = kv_cache.shape dtype_sz = kv_cache.element_size() - scale_pad = get_dtype_size(torch.float32) // dtype_sz # e.g. 
4 + scale_pad = torch.tensor([], dtype=torch.float32).element_size() // dtype_sz hs = padded_hs - scale_pad raw = kv_cache.untyped_storage() @@ -411,30 +429,51 @@ def _ensure_scale_caches(self, kv_cache: torch.Tensor) -> None: raw ) - # In the raw bytes, each (block, kv_half, slot, head) occupies - # padded_hs * dtype_sz bytes. The scale float32 sits at byte - # offset hs * dtype_sz within that region. - kv_half_bytes = block_size * nkv * padded_hs * dtype_sz - full_block_f32 = 2 * kv_half_bytes // 4 # stride between blocks - slot_f32 = nkv * padded_hs * dtype_sz // 4 # stride between slots - head_f32 = padded_hs * dtype_sz // 4 # stride between heads - scale_off_f32 = hs * dtype_sz // 4 # offset to scale within head + # Use the actual tensor strides rather than re-deriving contiguous + # ones from shape. Hybrid page padding can inflate the leading block + # stride without changing the semantic cache shape. + scale_size = torch.tensor([], dtype=torch.float32).element_size() + base_byte_offset = kv_cache.storage_offset() * dtype_sz + block_byte_stride = kv_cache.stride(0) * dtype_sz + kv_half_byte_stride = kv_cache.stride(1) * dtype_sz + slot_byte_stride = kv_cache.stride(2) * dtype_sz + head_byte_stride = kv_cache.stride(3) * dtype_sz + last_dim_byte_stride = kv_cache.stride(4) * dtype_sz + scale_off_bytes = hs * last_dim_byte_stride + + for value in ( + base_byte_offset, + block_byte_stride, + kv_half_byte_stride, + slot_byte_stride, + head_byte_stride, + scale_off_bytes, + ): + assert value % scale_size == 0, ( + "Scale cache view requires float32-aligned byte offsets/strides." + ) + + block_f32 = block_byte_stride // scale_size + kv_half_f32 = kv_half_byte_stride // scale_size + slot_f32 = slot_byte_stride // scale_size + head_f32 = head_byte_stride // scale_size + scale_off_f32 = (base_byte_offset + scale_off_bytes) // scale_size # K scales: kv_half=0 self._k_scale_cache = torch.as_strided( base_f32, size=(num_blocks, block_size, nkv), - stride=(full_block_f32, slot_f32, head_f32), + stride=(block_f32, slot_f32, head_f32), storage_offset=scale_off_f32, ) self._k_scale_cache.fill_(1.0) - # V scales: kv_half=1, offset by kv_half_bytes - v_base_f32 = kv_half_bytes // 4 + # V scales: kv_half=1, offset by the actual kv-half stride + v_base_f32 = kv_half_f32 self._v_scale_cache = torch.as_strided( base_f32, size=(num_blocks, block_size, nkv), - stride=(full_block_f32, slot_f32, head_f32), + stride=(block_f32, slot_f32, head_f32), storage_offset=v_base_f32 + scale_off_f32, ) self._v_scale_cache.fill_(1.0) @@ -494,6 +533,11 @@ def __init__( self._kv_quant_mode = get_kv_quant_mode(kv_cache_dtype) self._is_per_token_head_quant = self._kv_quant_mode.is_per_token_head + self._uses_int4_kv_cache = kv_cache_uses_int4_packing(kv_cache_dtype) + self._uses_fp8_kv_cache = kv_cache_uses_fp8_storage(kv_cache_dtype) + self._uses_runtime_scale_cache = kv_cache_uses_runtime_scale_cache( + kv_cache_dtype + ) def forward( self, @@ -555,21 +599,22 @@ def forward( layer, ) - # Per-token-head quantized KV cache: use separate scale caches. - if self._is_per_token_head_quant: + # Runtime-scale quantized KV cache: use scale caches extracted from the + # inline cache storage. 
+ if self._uses_runtime_scale_cache: self._ensure_scale_caches(kv_cache) key_cache, value_cache = kv_cache.unbind(1) - if key_cache.dtype == torch.uint8: + if self._is_per_token_head_quant and key_cache.dtype == torch.uint8: key_cache = key_cache.view(self.fp8_dtype) value_cache = value_cache.view(self.fp8_dtype) k_descale = None v_descale = None k_scale_cache = self._k_scale_cache v_scale_cache = self._v_scale_cache - # FP8 per-tensor / auto path (original flow). + # Per-tensor quantized path / auto path. else: key_cache, value_cache = kv_cache.unbind(1) - if is_quantized_kv_cache(self.kv_cache_dtype): + if self._uses_fp8_kv_cache: if key_cache.dtype != self.fp8_dtype: key_cache = key_cache.view(self.fp8_dtype) value_cache = value_cache.view(self.fp8_dtype) @@ -692,25 +737,39 @@ def do_kv_cache_update( # we use direct Q, K, V tensors without caching return # Reshape the input keys and values and store them in the cache. - if self._is_per_token_head_quant: + if self._uses_runtime_scale_cache: self._ensure_scale_caches(kv_cache) key_cache, value_cache = kv_cache.unbind(1) - if key_cache.dtype == torch.uint8: + if self._is_per_token_head_quant and key_cache.dtype == torch.uint8: key_cache = key_cache.view(self.fp8_dtype) value_cache = value_cache.view(self.fp8_dtype) - triton_reshape_and_cache_flash_per_token_head_quant( - key, - value, - key_cache, - value_cache, - self._k_scale_cache, - self._v_scale_cache, - slot_mapping, - ) + if self._is_per_token_head_quant: + triton_reshape_and_cache_flash_per_token_head_quant( + key, + value, + key_cache, + value_cache, + self._k_scale_cache, + self._v_scale_cache, + slot_mapping, + ) + else: + triton_reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + self._k_scale_cache, + self._v_scale_cache, + ) return # For decoder and cross-attention, use KV cache as before. 
key_cache, value_cache = kv_cache.unbind(1) - if is_quantized_kv_cache(self.kv_cache_dtype): + if self._uses_fp8_kv_cache: key_cache = key_cache.view(self.fp8_dtype) value_cache = value_cache.view(self.fp8_dtype) triton_reshape_and_cache_flash( @@ -725,7 +784,7 @@ def do_kv_cache_update( ) def fused_rope_kvcache_supported(self): - if self._is_per_token_head_quant: + if self._is_per_token_head_quant or self._uses_int4_kv_cache: return False return rocm_aiter_ops.is_enabled() @@ -744,7 +803,12 @@ def do_rope_and_kv_cache_update( key_cache, value_cache = kv_cache.unbind(1) flash_layout = True - is_fp8_kv_cache = is_quantized_kv_cache(self.kv_cache_dtype) + if self._uses_int4_kv_cache: + raise NotImplementedError( + "fused rope + int4 KV cache is not supported yet" + ) + + is_fp8_kv_cache = self._uses_fp8_kv_cache if is_fp8_kv_cache: key_cache = key_cache.view(self.fp8_dtype) value_cache = value_cache.view(self.fp8_dtype) diff --git a/vllm/v1/attention/ops/triton_reshape_and_cache_flash.py b/vllm/v1/attention/ops/triton_reshape_and_cache_flash.py index 6e696fdb5135..0ff3db5edad3 100644 --- a/vllm/v1/attention/ops/triton_reshape_and_cache_flash.py +++ b/vllm/v1/attention/ops/triton_reshape_and_cache_flash.py @@ -9,11 +9,20 @@ ) from vllm.platforms import current_platform from vllm.triton_utils import tl, triton -from vllm.utils.torch_utils import is_quantized_kv_cache +from vllm.utils.torch_utils import ( + is_quantized_kv_cache, + kv_cache_uses_fp8_storage, + kv_cache_uses_int4_packing, +) FP8_MIN, FP8_MAX = get_fp8_min_max() +@triton.jit +def _round_to_int32(x): + return tl.where(x >= 0, tl.floor(x + 0.5), tl.ceil(x - 0.5)).to(tl.int32) + + @triton.jit def reshape_and_cache_kernel_flash( key_ptr, # [num_tokens, num_heads, head_size] @@ -38,6 +47,7 @@ def reshape_and_cache_kernel_flash( USE_HEAD_MAJOR_LAYOUT: tl.constexpr, # FP8 flags FP8_KV_CACHE: tl.constexpr, + INT4_KV_CACHE: tl.constexpr, # tune parameters TILE_SIZE: tl.constexpr, ): @@ -57,6 +67,8 @@ def reshape_and_cache_kernel_flash( src_value_idx = token_idx * value_stride if USE_HEAD_MAJOR_LAYOUT: + if INT4_KV_CACHE: + tl.device_assert(False, "int4 KV cache does not support head-major layout") # Decompose the tile index back into head and dim coordinates. 
cur_head = tile_pos // head_size cur_dim = tile_pos % head_size @@ -76,8 +88,9 @@ def reshape_and_cache_kernel_flash( + (cur_dim % x) ) else: - cur_head = tile_pos // head_size - cur_dim = tile_pos % head_size + head_size_physical = (head_size + 1) // 2 if INT4_KV_CACHE else head_size + cur_head = tile_pos // head_size_physical + cur_dim = tile_pos % head_size_physical tgt_idx_k = ( block_idx * block_stride + block_offset * page_stride @@ -87,39 +100,80 @@ def reshape_and_cache_kernel_flash( tgt_idx_v = tgt_idx_k # [TILE_SIZE] - key_load = tl.load( - key_ptr + src_key_idx + tile_pos, mask=tile_pos < (num_heads * head_size) - ) - if FP8_KV_CACHE: - # tl.store will do the correct implicit cast to fp8, - # based on the key_cache_ptr.dtype.element_ty - key_tile = key_load if key_load.dtype.is_fp8() else key_load / tl.load(k_scale) + if INT4_KV_CACHE: + packed_src_pos = cur_head * head_size + cur_dim * 2 + key_lo = tl.load( + key_ptr + src_key_idx + packed_src_pos, + mask=tile_pos < (num_heads * head_size_physical), + other=0.0, + ) + key_hi = tl.load( + key_ptr + src_key_idx + packed_src_pos + 1, + mask=(tile_pos < (num_heads * head_size_physical)) + & ((cur_dim * 2 + 1) < head_size), + other=0.0, + ) + q_key_lo = _round_to_int32(key_lo / tl.load(k_scale)) + q_key_hi = _round_to_int32(key_hi / tl.load(k_scale)) + q_key_lo = tl.maximum(tl.minimum(q_key_lo, 7), -8) + q_key_hi = tl.maximum(tl.minimum(q_key_hi, 7), -8) + key_tile = ((q_key_lo & 0xF) | ((q_key_hi & 0xF) << 4)).to(tl.uint8) else: - key_tile = key_load + key_load = tl.load( + key_ptr + src_key_idx + tile_pos, mask=tile_pos < (num_heads * head_size) + ) + if FP8_KV_CACHE: + # tl.store will do the correct implicit cast to fp8, + # based on the key_cache_ptr.dtype.element_ty + key_tile = ( + key_load if key_load.dtype.is_fp8() else key_load / tl.load(k_scale) + ) + else: + key_tile = key_load # [TILE_SIZE] - value_load = tl.load( - value_ptr + src_value_idx + tile_pos, mask=tile_pos < (num_heads * head_size) - ) - if FP8_KV_CACHE: - if value_load.dtype.is_fp8(): - value_tile = value_load - else: - # tl.store will do the correct implicit cast to fp8, - # based on the value_cache_ptr.dtype.element_ty - value_tile = value_load / tl.load(v_scale) + if INT4_KV_CACHE: + packed_src_pos = cur_head * head_size + cur_dim * 2 + value_lo = tl.load( + value_ptr + src_value_idx + packed_src_pos, + mask=tile_pos < (num_heads * head_size_physical), + other=0.0, + ) + value_hi = tl.load( + value_ptr + src_value_idx + packed_src_pos + 1, + mask=(tile_pos < (num_heads * head_size_physical)) + & ((cur_dim * 2 + 1) < head_size), + other=0.0, + ) + q_value_lo = _round_to_int32(value_lo / tl.load(v_scale)) + q_value_hi = _round_to_int32(value_hi / tl.load(v_scale)) + q_value_lo = tl.maximum(tl.minimum(q_value_lo, 7), -8) + q_value_hi = tl.maximum(tl.minimum(q_value_hi, 7), -8) + value_tile = ((q_value_lo & 0xF) | ((q_value_hi & 0xF) << 4)).to(tl.uint8) else: - value_tile = value_load + value_load = tl.load( + value_ptr + src_value_idx + tile_pos, + mask=tile_pos < (num_heads * head_size), + ) + if FP8_KV_CACHE: + if value_load.dtype.is_fp8(): + value_tile = value_load + else: + # tl.store will do the correct implicit cast to fp8, + # based on the value_cache_ptr.dtype.element_ty + value_tile = value_load / tl.load(v_scale) + else: + value_tile = value_load tl.store( key_cache_ptr + tgt_idx_k, key_tile, - mask=tile_pos < (num_heads * head_size), + mask=tile_pos < (num_heads * ((head_size + 1) // 2 if INT4_KV_CACHE else head_size)), ) tl.store( value_cache_ptr 
+ tgt_idx_v, value_tile, - mask=tile_pos < (num_heads * head_size), + mask=tile_pos < (num_heads * ((head_size + 1) // 2 if INT4_KV_CACHE else head_size)), ) return @@ -244,6 +298,200 @@ def _reshape_cache_per_token_head( } +@triton.jit +def _reshape_cache_int4_per_token_head( + key_ptr, # [num_tokens, num_kv_heads, head_size] + value_ptr, # [num_tokens, num_kv_heads, head_size_v] + key_cache_ptr, # [num_blocks, block_size, num_kv_heads, packed_head_size+pad] + value_cache_ptr, # [num_blocks, block_size, num_kv_heads, packed_head_size_v+pad] + k_scale_cache_ptr, # [num_blocks, block_size, num_kv_heads] float32 + v_scale_cache_ptr, # [num_blocks, block_size, num_kv_heads] float32 + slot_mapping_ptr, # [num_tokens] + stride_key_tok: tl.int64, + stride_key_head: tl.int64, + stride_val_tok: tl.int64, + stride_val_head: tl.int64, + stride_kc_blk: tl.int64, + stride_kc_slot: tl.int64, + stride_kc_head: tl.int64, + stride_vc_blk: tl.int64, + stride_vc_slot: tl.int64, + stride_vc_head: tl.int64, + stride_ks_blk: tl.int64, + stride_ks_slot: tl.int64, + stride_ks_head: tl.int64, + stride_vs_blk: tl.int64, + stride_vs_slot: tl.int64, + stride_vs_head: tl.int64, + block_size: tl.constexpr, + head_size: tl.constexpr, + head_size_v: tl.constexpr, + HEAD_SIZE_PADDED: tl.constexpr, + PACKED_HEAD_SIZE: tl.constexpr, + PACKED_HEAD_SIZE_V: tl.constexpr, + PACKED_HEAD_SIZE_PADDED: tl.constexpr, + PACKED_HEAD_SIZE_V_PADDED: tl.constexpr, +): + tok = tl.program_id(0) + head = tl.program_id(1) + + slot = tl.load(slot_mapping_ptr + tok).to(tl.int64) + if slot < 0: + return + + blk = slot // block_size + slot_in_blk = slot % block_size + + dim_offs = tl.arange(0, HEAD_SIZE_PADDED) + + # ---- Key head -> absmax -> scale -------------------------------------- + k_mask = dim_offs < head_size + k_h = tl.load( + key_ptr + tok * stride_key_tok + head * stride_key_head + dim_offs, + mask=k_mask, + other=0.0, + ).to(tl.float32) + k_scale = tl.maximum(tl.max(tl.abs(k_h)) / 7.0, 1e-6) + tl.store( + k_scale_cache_ptr + + blk * stride_ks_blk + + slot_in_blk * stride_ks_slot + + head * stride_ks_head, + k_scale, + ) + + packed_offs = tl.arange(0, PACKED_HEAD_SIZE_PADDED) + k_elem0 = packed_offs * 2 + k_elem1 = k_elem0 + 1 + k_lo = tl.load( + key_ptr + tok * stride_key_tok + head * stride_key_head + k_elem0, + mask=k_elem0 < head_size, + other=0.0, + ).to(tl.float32) + k_hi = tl.load( + key_ptr + tok * stride_key_tok + head * stride_key_head + k_elem1, + mask=k_elem1 < head_size, + other=0.0, + ).to(tl.float32) + q_key_lo = _round_to_int32(k_lo / k_scale) + q_key_hi = _round_to_int32(k_hi / k_scale) + q_key_lo = tl.maximum(tl.minimum(q_key_lo, 7), -8) + q_key_hi = tl.maximum(tl.minimum(q_key_hi, 7), -8) + key_packed = ((q_key_lo & 0xF) | ((q_key_hi & 0xF) << 4)).to(tl.uint8) + tl.store( + key_cache_ptr + + blk * stride_kc_blk + + slot_in_blk * stride_kc_slot + + head * stride_kc_head + + packed_offs, + key_packed, + mask=packed_offs < PACKED_HEAD_SIZE, + ) + + # ---- Value head -> absmax -> scale ------------------------------------ + v_mask = dim_offs < head_size_v + v_h = tl.load( + value_ptr + tok * stride_val_tok + head * stride_val_head + dim_offs, + mask=v_mask, + other=0.0, + ).to(tl.float32) + v_scale = tl.maximum(tl.max(tl.abs(v_h)) / 7.0, 1e-6) + tl.store( + v_scale_cache_ptr + + blk * stride_vs_blk + + slot_in_blk * stride_vs_slot + + head * stride_vs_head, + v_scale, + ) + + packed_offs_v = tl.arange(0, PACKED_HEAD_SIZE_V_PADDED) + v_elem0 = packed_offs_v * 2 + v_elem1 = v_elem0 + 1 + v_lo = tl.load( + value_ptr + 
tok * stride_val_tok + head * stride_val_head + v_elem0, + mask=v_elem0 < head_size_v, + other=0.0, + ).to(tl.float32) + v_hi = tl.load( + value_ptr + tok * stride_val_tok + head * stride_val_head + v_elem1, + mask=v_elem1 < head_size_v, + other=0.0, + ).to(tl.float32) + q_value_lo = _round_to_int32(v_lo / v_scale) + q_value_hi = _round_to_int32(v_hi / v_scale) + q_value_lo = tl.maximum(tl.minimum(q_value_lo, 7), -8) + q_value_hi = tl.maximum(tl.minimum(q_value_hi, 7), -8) + value_packed = ((q_value_lo & 0xF) | ((q_value_hi & 0xF) << 4)).to(tl.uint8) + tl.store( + value_cache_ptr + + blk * stride_vc_blk + + slot_in_blk * stride_vc_slot + + head * stride_vc_head + + packed_offs_v, + value_packed, + mask=packed_offs_v < PACKED_HEAD_SIZE_V, + ) + + +def triton_reshape_and_cache_flash_int4_per_token_head( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + k_scale_cache: torch.Tensor, + v_scale_cache: torch.Tensor, + slot_mapping: torch.Tensor, +): + num_tokens, num_kv_heads, head_size = key.shape + head_size_v = value.shape[2] + head_size_padded = triton.next_power_of_2(max(head_size, head_size_v)) + packed_head_size = (head_size + 1) // 2 + packed_head_size_v = (head_size_v + 1) // 2 + packed_head_size_padded = triton.next_power_of_2(max(1, packed_head_size)) + packed_head_size_v_padded = triton.next_power_of_2(max(1, packed_head_size_v)) + block_size = key_cache.shape[1] + + if current_platform.is_rocm() or current_platform.is_xpu(): + num_warps = 4 + else: + num_warps = min(8, max(1, head_size_padded // 32)) + + _reshape_cache_int4_per_token_head[(num_tokens, num_kv_heads)]( + key_ptr=key, + value_ptr=value, + key_cache_ptr=key_cache, + value_cache_ptr=value_cache, + k_scale_cache_ptr=k_scale_cache, + v_scale_cache_ptr=v_scale_cache, + slot_mapping_ptr=slot_mapping, + stride_key_tok=key.stride(0), + stride_key_head=key.stride(1), + stride_val_tok=value.stride(0), + stride_val_head=value.stride(1), + stride_kc_blk=key_cache.stride(0), + stride_kc_slot=key_cache.stride(1), + stride_kc_head=key_cache.stride(2), + stride_vc_blk=value_cache.stride(0), + stride_vc_slot=value_cache.stride(1), + stride_vc_head=value_cache.stride(2), + stride_ks_blk=k_scale_cache.stride(0), + stride_ks_slot=k_scale_cache.stride(1), + stride_ks_head=k_scale_cache.stride(2), + stride_vs_blk=v_scale_cache.stride(0), + stride_vs_slot=v_scale_cache.stride(1), + stride_vs_head=v_scale_cache.stride(2), + block_size=block_size, + head_size=head_size, + head_size_v=head_size_v, + HEAD_SIZE_PADDED=head_size_padded, + PACKED_HEAD_SIZE=packed_head_size, + PACKED_HEAD_SIZE_V=packed_head_size_v, + PACKED_HEAD_SIZE_PADDED=packed_head_size_padded, + PACKED_HEAD_SIZE_V_PADDED=packed_head_size_v_padded, + num_warps=num_warps, + ) + + def triton_reshape_and_cache_flash_per_token_head_quant( key: torch.Tensor, # [num_tokens, num_kv_heads, head_size] value: torch.Tensor, # [num_tokens, num_kv_heads, head_size_v] @@ -324,9 +572,11 @@ def triton_reshape_and_cache_flash( # [num_blocks, block_size, num_heads, head_size] value_cache: torch.Tensor, slot_mapping: torch.Tensor, # [num_tokens] - kv_cache_dtype: str, # "auto", "fp8" + kv_cache_dtype: str, # "auto", "fp8", "int4_per_token_head" k_scale: torch.Tensor, # float32 v_scale: torch.Tensor, # float32 + k_scale_cache: torch.Tensor | None = None, + v_scale_cache: torch.Tensor | None = None, ): num_heads = key.shape[1] head_size = key.shape[2] @@ -353,15 +603,15 @@ def triton_reshape_and_cache_flash( assert kv_cache_dtype == "auto" or 
is_quantized_kv_cache(kv_cache_dtype), ( f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}." ) + int4_kv_cache = kv_cache_uses_int4_packing(kv_cache_dtype) + fp8_kv_cache = kv_cache_uses_fp8_storage(kv_cache_dtype) kv_cache_torch_dtype = ( current_platform.fp8_dtype() - if is_quantized_kv_cache(kv_cache_dtype) - else key_cache.dtype + if fp8_kv_cache + else torch.uint8 if int4_kv_cache else key_cache.dtype ) - if key_cache.dtype != kv_cache_torch_dtype and is_quantized_kv_cache( - kv_cache_dtype - ): + if key_cache.dtype != kv_cache_torch_dtype and fp8_kv_cache: # to avoid erounous implicit cast in triton kernel (tl.store to uint8) # (e.g. explicit cast to fp8e4m3fnuz is not supported in triton 3.4) key_cache = key_cache.view(kv_cache_torch_dtype) @@ -371,7 +621,7 @@ def triton_reshape_and_cache_flash( "uint8 is not supported by triton reshape_and_cache_flash" ) - FP8_KV_CACHE = is_quantized_kv_cache(kv_cache_dtype) + FP8_KV_CACHE = fp8_kv_cache assert (not FP8_KV_CACHE) or kv_cache_torch_dtype in [ torch.float8_e4m3fn, torch.float8_e5m2, @@ -384,6 +634,22 @@ def triton_reshape_and_cache_flash( ) # heuristics instead of autotuning + if int4_kv_cache: + assert ( + not use_head_major_layout + ), "int4 KV cache only supports NHD layout for now" + if k_scale_cache is not None and v_scale_cache is not None: + triton_reshape_and_cache_flash_int4_per_token_head( + key, + value, + key_cache, + value_cache, + k_scale_cache, + v_scale_cache, + slot_mapping, + ) + return + n = num_heads * ((head_size + 1) // 2) TILE_SIZE = min(2048, triton.next_power_of_2(n)) if current_platform.is_rocm() or current_platform.is_xpu(): num_stages = 4 @@ -423,6 +689,7 @@ def triton_reshape_and_cache_flash( x=x, USE_HEAD_MAJOR_LAYOUT=use_head_major_layout, FP8_KV_CACHE=FP8_KV_CACHE, + INT4_KV_CACHE=int4_kv_cache, # autotune parameters TILE_SIZE=TILE_SIZE, num_warps=num_warps, @@ -519,6 +786,11 @@ def triton_reshape_and_cache_flash_diffkv( k_scale: torch.Tensor, # float32 v_scale: torch.Tensor, # float32 ): + if kv_cache_uses_int4_packing(kv_cache_dtype): + raise NotImplementedError( + "int4 KV cache is not supported for diffkv layout yet" + ) + num_heads = key.shape[1] head_size_k = key.shape[2] head_size_v = value.shape[2] diff --git a/vllm/v1/attention/ops/triton_unified_attention.py b/vllm/v1/attention/ops/triton_unified_attention.py index 150f022f848c..bde8647ee065 100644 --- a/vllm/v1/attention/ops/triton_unified_attention.py +++ b/vllm/v1/attention/ops/triton_unified_attention.py @@ -39,6 +39,7 @@ def _prepare_kv_tile( Q, tensor_scale, scale_cache_ptr, + dim_offsets, physical_block_idx, seq_offset, kv_head_idx, @@ -48,6 +49,7 @@ def _prepare_kv_tile( tile_mask, BLOCK_SIZE: tl.constexpr, KV_QUANT_MODE: tl.constexpr, + DIM_IS_LAST: tl.constexpr, ): """Prepare a loaded KV tile for attention computation. @@ -57,25 +59,46 @@ def _prepare_kv_tile( - ``KV_QUANT_MODE == 0``: cast only (no-op for bf16/fp16). - ``KV_QUANT_MODE == 1`` (FP8 per-tensor): dequantize inline using the tensor-wide scale. - - ``KV_QUANT_MODE >= 2`` (per-token-head int8/fp8): cast to Q's + - ``KV_QUANT_MODE == 4`` (int4): unpack signed int4 and dequantize inline + using runtime token/head scales. + - ``KV_QUANT_MODE == 2 or 3`` (per-token-head int8/fp8): cast to Q's dtype and return per-head scales separately — the caller applies them after the dot product for better numerical efficiency. Returns ``(data, token_head_scales)``. 
*token_head_scales* is only - meaningful when ``KV_QUANT_MODE >= 2``; callers gate its use on + meaningful when ``KV_QUANT_MODE == 2 or 3``; callers gate its use on the same constexpr so the compiler eliminates dead code. """ - # KV_QUANT_MODE values: 0=none, 1=fp8 per-tensor, - # 2=int8 per-token-head, 3=fp8 per-token-head + # KV_QUANT_MODE values: + # 0=none, 1=fp8 per-tensor, 2=int8 per-token-head, + # 3=fp8 per-token-head, 4=int4 packed bytes + runtime scales - # Placeholder scales (float32) — never read when KV_QUANT_MODE < 2. + # Placeholder scales (float32) — never read when KV_QUANT_MODE is not + # per-token-head. unused_scales = tile_mask.to(tl.float32) if KV_QUANT_MODE == 1: # FP8 per-tensor if Q.dtype.is_fp8(): return data.to(Q.dtype), unused_scales return (data.to(tl.float32) * tl.load(tensor_scale)).to(Q.dtype), unused_scales - if KV_QUANT_MODE >= 2: # per-token-head (int8 or fp8) + if KV_QUANT_MODE == 4: # int4 + scale_idx = ( + physical_block_idx * stride_s_blk + + (seq_offset % BLOCK_SIZE) * stride_s_slot + + kv_head_idx * stride_s_head + ) + token_head_scales = tl.load(scale_cache_ptr + scale_idx, mask=tile_mask, other=1.0) + scale_values = ( + token_head_scales[:, None] if DIM_IS_LAST else token_head_scales[None, :] + ) + parity = ( + (dim_offsets[None, :] & 1) if DIM_IS_LAST else (dim_offsets[:, None] & 1) + ) + packed = data.to(tl.uint8).to(tl.int32) + nibble = tl.where(parity == 0, packed & 0xF, packed >> 4) + signed = tl.where(nibble >= 8, nibble - 16, nibble) + return (signed.to(tl.float32) * scale_values).to(Q.dtype), token_head_scales + if KV_QUANT_MODE == 2 or KV_QUANT_MODE == 3: # per-token-head (int8 or fp8) scale_idx = ( physical_block_idx * stride_s_blk + (seq_offset % BLOCK_SIZE) * stride_s_slot @@ -167,7 +190,7 @@ def kernel_unified_attention_2d( KV_QUANT_MODE: tl.constexpr = 0, FP8_MIN: tl.constexpr = float8_info.min, FP8_MAX: tl.constexpr = float8_info.max, - # Per-token-head scale caches (KV_QUANT_MODE >= 2) + # Per-token-head scale caches (KV_QUANT_MODE == 2 or 3) # Shape: [num_blocks, block_size, num_kv_heads] k_scale_cache_ptr=None, v_scale_cache_ptr=None, @@ -308,17 +331,19 @@ def kernel_unified_attention_2d( block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE ).to(tl.int64) + packed_dim = offs_d // 2 if KV_QUANT_MODE == 4 else offs_d + v_offset = ( physical_block_idx[:, None] * stride_v_cache_0 + kv_head_idx * stride_v_cache_2 - + offs_d[None, :] * stride_v_cache_3 + + packed_dim[None, :] * stride_v_cache_3 + (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1 ) k_offset = ( physical_block_idx[None, :] * stride_k_cache_0 + kv_head_idx * stride_k_cache_2 - + offs_d[:, None] * stride_k_cache_3 + + packed_dim[:, None] * stride_k_cache_3 + (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1 ) @@ -333,6 +358,7 @@ def kernel_unified_attention_2d( Q, k_scale, k_scale_cache_ptr, + offs_d, physical_block_idx, seq_offset, kv_head_idx, @@ -342,6 +368,7 @@ def kernel_unified_attention_2d( tile_mask, BLOCK_SIZE, KV_QUANT_MODE, + False, ) # V : (TILE_SIZE, HEAD_SIZE) @@ -355,6 +382,7 @@ def kernel_unified_attention_2d( Q, v_scale, v_scale_cache_ptr, + offs_d, physical_block_idx, seq_offset, kv_head_idx, @@ -364,6 +392,7 @@ def kernel_unified_attention_2d( tile_mask, BLOCK_SIZE, KV_QUANT_MODE, + True, ) # Compute attention mask: causal by default (key <= query) @@ -404,7 +433,7 @@ def kernel_unified_attention_2d( # Per-token-head quant: fuse softmax_scale with per-head k_scale # to avoid a separate BLOCK_M × TILE_SIZE multiply on S. 
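# The KV_QUANT_MODE == 4 read path above fetches the byte at offs_d // 2 and
# selects the low or high nibble by the parity of offs_d. A small torch sketch
# of that gather-style unpack for a single (token, head) row; `packed` and
# `scale` below are illustrative inputs, not tensors from this diff.
import torch

def unpack_int4_row(packed: torch.Tensor, scale: float, head_size: int) -> torch.Tensor:
    offs_d = torch.arange(head_size)
    byte = packed[offs_d // 2].to(torch.int32)               # same byte feeds two dims
    nibble = torch.where(offs_d % 2 == 0, byte & 0xF, byte >> 4)
    signed = torch.where(nibble >= 8, nibble - 16, nibble)   # sign-extend the 4-bit value
    return signed.to(torch.float32) * scale

packed = torch.tensor([0x21, 0xF8], dtype=torch.uint8)  # nibbles in dim order: 1, 2, 8, 15
print(unpack_int4_row(packed, scale=0.5, head_size=4))  # tensor([ 0.5,  1.0, -4.0, -0.5])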
- if KV_QUANT_MODE >= 2: + if KV_QUANT_MODE == 2 or KV_QUANT_MODE == 3: S += tl.dot(Q, K) * (scale * k_token_head_scales[None, :]) else: S += scale * tl.dot(Q, K) @@ -472,7 +501,7 @@ def kernel_unified_attention_2d( # acc : (BLOCK_M, HEAD_SIZE_PADDED) # Per-token-head quant: apply v_scale to P instead of V. - if KV_QUANT_MODE >= 2: + if KV_QUANT_MODE == 2 or KV_QUANT_MODE == 3: P_v = (P * v_token_head_scales[None, :]).to(V.dtype) acc += tl.dot(P_v, V) else: @@ -549,7 +578,7 @@ def kernel_unified_attention_3d( mm_prefix_range_ptr, # [num_seqs] - prefix length for each sequence # KV cache quantization: 0=none, 1=fp8, 2=per-token-head KV_QUANT_MODE: tl.constexpr = 0, - # Per-token-head scale caches (KV_QUANT_MODE >= 2) + # Per-token-head scale caches (KV_QUANT_MODE == 2 or 3) # Shape: [num_blocks, block_size, num_kv_heads] k_scale_cache_ptr=None, v_scale_cache_ptr=None, @@ -699,17 +728,19 @@ def kernel_unified_attention_3d( block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE ).to(tl.int64) + packed_dim = offs_d // 2 if KV_QUANT_MODE == 4 else offs_d + v_offset = ( physical_block_idx[:, None] * stride_v_cache_0 + kv_head_idx * stride_v_cache_2 - + offs_d[None, :] * stride_v_cache_3 + + packed_dim[None, :] * stride_v_cache_3 + (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1 ) k_offset = ( physical_block_idx[None, :] * stride_k_cache_0 + kv_head_idx * stride_k_cache_2 - + offs_d[:, None] * stride_k_cache_3 + + packed_dim[:, None] * stride_k_cache_3 + (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1 ) @@ -724,6 +755,7 @@ def kernel_unified_attention_3d( Q, k_scale, k_scale_cache_ptr, + offs_d, physical_block_idx, seq_offset, kv_head_idx, @@ -733,6 +765,7 @@ def kernel_unified_attention_3d( tile_mask, BLOCK_SIZE, KV_QUANT_MODE, + False, ) # V : (TILE_SIZE, HEAD_SIZE) @@ -746,6 +779,7 @@ def kernel_unified_attention_3d( Q, v_scale, v_scale_cache_ptr, + offs_d, physical_block_idx, seq_offset, kv_head_idx, @@ -755,6 +789,7 @@ def kernel_unified_attention_3d( tile_mask, BLOCK_SIZE, KV_QUANT_MODE, + True, ) # Compute attention mask: causal by default (key <= query) @@ -795,7 +830,7 @@ def kernel_unified_attention_3d( # Per-token-head quant: fuse softmax_scale with per-head k_scale # to avoid a separate BLOCK_M × TILE_SIZE multiply on S. - if KV_QUANT_MODE >= 2: + if KV_QUANT_MODE == 2 or KV_QUANT_MODE == 3: S += tl.dot(Q, K) * (scale * k_token_head_scales[None, :]) else: S += scale * tl.dot(Q, K) @@ -863,7 +898,7 @@ def kernel_unified_attention_3d( # acc : (BLOCK_M, HEAD_SIZE_PADDED) # Per-token-head quant: apply v_scale to P instead of V. - if KV_QUANT_MODE >= 2: + if KV_QUANT_MODE == 2 or KV_QUANT_MODE == 3: P_v = (P * v_token_head_scales[None, :]).to(V.dtype) acc += tl.dot(P_v, V) else: diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 9ab5af0f6fb0..41a8f0074fac 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -18,6 +18,7 @@ from vllm.utils.math_utils import cdiv from vllm.utils.mem_utils import format_gib from vllm.v1.kv_cache_interface import ( + AttentionSpec, ChunkedLocalAttentionSpec, FullAttentionSpec, KVCacheConfig, @@ -932,6 +933,40 @@ def unify_kv_cache_spec_page_size( return kv_cache_spec max_page_size = max(page_sizes) + + # Runtime-scale quantized attention can add a small fixed per-head overhead + # that breaks exact divisibility between hybrid attention types even when + # the underlying KV topology is otherwise compatible. 
In that case we can + # preserve the existing block_size and pad the smaller attention pages up to + # the maximum page size instead of hard-failing during grouping. + if ( + all(isinstance(spec, AttentionSpec) for spec in kv_cache_spec.values()) + and any( + spec.kv_quant_mode.uses_runtime_scales + for spec in kv_cache_spec.values() + ) + ): + block_sizes = {spec.block_size for spec in kv_cache_spec.values()} + has_non_divisible_page = any( + spec.page_size_bytes != max_page_size + and max_page_size % spec.page_size_bytes != 0 + for spec in kv_cache_spec.values() + ) + if len(block_sizes) == 1 and has_non_divisible_page: + logger.warning( + "Padding smaller hybrid KV pages up to %d bytes to keep page " + "sizes compatible across attention types.", + max_page_size, + ) + return { + layer_name: ( + layer_spec + if layer_spec.page_size_bytes == max_page_size + else replace(layer_spec, page_size_padded=max_page_size) + ) + for layer_name, layer_spec in kv_cache_spec.items() + } + new_kv_cache_spec = {} for layer_name, layer_spec in kv_cache_spec.items(): if layer_spec.page_size_bytes == max_page_size: @@ -1294,27 +1329,10 @@ def _report_kv_cache_config( vllm_config: The global VllmConfig kv_cache_config: The resolved KV cache configuration """ - min_block_size = min( - [group.kv_cache_spec.block_size for group in kv_cache_config.kv_cache_groups] - ) - # Log the KV cache size and maximum concurrency. - num_tokens = ( - kv_cache_config.num_blocks - // len(kv_cache_config.kv_cache_groups) - * min_block_size + num_tokens = estimate_token_capacity_for_kv_cache_config( + vllm_config, kv_cache_config ) - dcp_size = vllm_config.parallel_config.decode_context_parallel_size - pcp_size = vllm_config.parallel_config.prefill_context_parallel_size - if pcp_size * dcp_size > 1: - num_tokens *= pcp_size * dcp_size - logger.info( - "Multiplying the GPU KV cache size by the cp_world_size %d " - "(pcp_world_size %d * dcp_world_size %d).", - pcp_size * dcp_size, - pcp_size, - dcp_size, - ) num_tokens_str = f"{num_tokens:,}" logger.info_once("GPU KV cache size: %s tokens", num_tokens_str, scope="local") max_model_len_str = f"{vllm_config.model_config.max_model_len:,}" @@ -1329,6 +1347,30 @@ def _report_kv_cache_config( ) +def estimate_token_capacity_for_kv_cache_config( + vllm_config: VllmConfig, kv_cache_config: KVCacheConfig +) -> int: + """Estimate total KV tokens that fit in the current GPU cache allocation. + + The log line reports the total number of tokens that can be cached across + all concurrent requests, not the maximum request length for a single + sequence. Reuse the concurrency calculation so the reported token capacity + can exceed ``max_model_len`` when the cache holds multiple full-length + requests. 
+ """ + if not kv_cache_config.kv_cache_groups: + return 0 + + max_model_len = vllm_config.model_config.max_model_len + if max_model_len <= 0: + return 0 + + return int( + get_max_concurrency_for_kv_cache_config(vllm_config, kv_cache_config) + * max_model_len + ) + + def _max_memory_usage_bytes_from_groups( vllm_config: VllmConfig, kv_cache_groups: list[KVCacheGroupSpec], diff --git a/vllm/v1/engine/input_processor.py b/vllm/v1/engine/input_processor.py index 323862d77ba1..1aab90c14581 100644 --- a/vllm/v1/engine/input_processor.py +++ b/vllm/v1/engine/input_processor.py @@ -98,9 +98,9 @@ def _validate_params( self.tokenizer, ) - if ( - params.thinking_token_budget is not None - and self.vllm_config.reasoning_config is None + if params.thinking_token_budget is not None and ( + self.vllm_config.reasoning_config is None + or not self.vllm_config.reasoning_config.enabled ): raise ValueError( "thinking_token_budget is set but reasoning_config is " diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 6f8ad8e7d8ef..3dbb76eb3c8b 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -38,15 +38,29 @@ class KVQuantMode(IntEnum): FP8_PER_TENSOR = 1 # per-tensor scales (current fp8 path) INT8_PER_TOKEN_HEAD = 2 # per-token-head dynamic scales for int8 FP8_PER_TOKEN_HEAD = 3 # per-token-head dynamic scales for fp8 + INT4_PER_TOKEN_HEAD = 4 # packed signed int4 with runtime token/head scales @property def is_per_token_head(self) -> bool: """True for any per-token-head quantization mode.""" - return self >= 2 + return self in ( + KVQuantMode.INT8_PER_TOKEN_HEAD, + KVQuantMode.FP8_PER_TOKEN_HEAD, + ) + + @property + def is_int4_packed(self) -> bool: + return self == KVQuantMode.INT4_PER_TOKEN_HEAD + + @property + def uses_runtime_scales(self) -> bool: + return self.is_per_token_head or self.is_int4_packed def get_kv_quant_mode(kv_cache_dtype: str) -> KVQuantMode: """Map a ``kv_cache_dtype`` string to a :class:`KVQuantMode`.""" + if kv_cache_dtype == "int4_per_token_head": + return KVQuantMode.INT4_PER_TOKEN_HEAD if kv_cache_dtype == "int8_per_token_head": return KVQuantMode.INT8_PER_TOKEN_HEAD if kv_cache_dtype == "fp8_per_token_head": @@ -65,6 +79,11 @@ def kv_cache_uses_per_token_head_scales(kv_cache_dtype: str) -> bool: return get_kv_quant_mode(kv_cache_dtype).is_per_token_head +def kv_cache_uses_runtime_scales(kv_cache_dtype: str) -> bool: + """Return True if *kv_cache_dtype* computes scales at cache-write time.""" + return get_kv_quant_mode(kv_cache_dtype).uses_runtime_scales + + @dataclass(frozen=True) class KVCacheSpec: """ @@ -135,11 +154,17 @@ def page_size_bytes(self) -> int: @property def real_page_size_bytes(self) -> int: + head_size_physical = self.head_size + if self.kv_quant_mode.is_int4_packed: + packed_head_size = cdiv(self.head_size, 2) + packed_head_size_padded = cdiv(packed_head_size, 4) * 4 + # int4 stores packed bytes plus one float32 runtime scale inline. 
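# Worked size example for the int4 branch below (illustrative numbers, assuming
# a uint8 storage dtype): with head_size=128, block_size=16, num_kv_heads=8,
#   packed_head_size        = ceil(128 / 2)       = 64
#   packed_head_size_padded = ceil(64 / 4) * 4    = 64
#   head_size_physical      = 64 + 4 (fp32 scale) = 68
#   real_page_size_bytes    = 2 * 16 * 8 * 68 * 1 = 17,408 bytes per block,
# versus 2 * 16 * 8 * 128 * 2 = 65,536 bytes for a bf16 cache of the same shape.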
+ head_size_physical = packed_head_size_padded + get_dtype_size(torch.float32) return ( 2 * self.block_size * self.num_kv_heads - * self.head_size + * head_size_physical * get_dtype_size(self.dtype) ) @@ -237,10 +262,21 @@ def merge(cls, specs: list[Self]) -> Self: @property def real_page_size_bytes(self) -> int: + head_size_k = self.head_size + head_size_v = self.head_size_v + if self.kv_quant_mode.is_int4_packed: + packed_head_size_k = cdiv(self.head_size, 2) + packed_head_size_v = cdiv(self.head_size_v, 2) + head_size_k = cdiv(packed_head_size_k, 4) * 4 + get_dtype_size( + torch.float32 + ) + head_size_v = cdiv(packed_head_size_v, 4) * 4 + get_dtype_size( + torch.float32 + ) return ( self.block_size * self.num_kv_heads - * (self.head_size + self.head_size_v) + * (head_size_k + head_size_v) * get_dtype_size(self.dtype) ) diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py index 0d9b67017fd1..1739452b44a0 100644 --- a/vllm/v1/sample/logits_processor/builtin.py +++ b/vllm/v1/sample/logits_processor/builtin.py @@ -301,7 +301,7 @@ def __init__( max_num_reqs = vllm_config.scheduler_config.max_num_seqs # Check if thinking is enabled - self.is_enabled = reasoning_config is not None + self.is_enabled = reasoning_config is not None and reasoning_config.enabled self.reasoning_start_token_ids = getattr( reasoning_config, "reasoning_start_token_ids", [] diff --git a/vllm/v1/worker/gpu/attn_utils.py b/vllm/v1/worker/gpu/attn_utils.py index 34089a67b3be..53629020c901 100644 --- a/vllm/v1/worker/gpu/attn_utils.py +++ b/vllm/v1/worker/gpu/attn_utils.py @@ -149,7 +149,39 @@ def _reshape_kv_cache( dtype = kv_cache_spec.dtype raw_tensor = raw_tensor.view(dtype) - raw_tensor = raw_tensor.view(kv_cache_shape) + padded_page_size_bytes = getattr(kv_cache_spec, "page_size_padded", None) + if ( + padded_page_size_bytes is not None + and padded_page_size_bytes > kv_cache_spec.real_page_size_bytes + ): + block_dim = attn_backend.get_kv_cache_block_dim( + kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, + kv_cache_spec.head_size, + cache_dtype_str=cache_dtype, + ) + physical_block_dim = kv_cache_stride_order.index(block_dim) + if physical_block_dim != 0: + raise NotImplementedError( + "Padded KV pages are only supported when the backend " + "block dimension is leading." 
+ ) + dtype_size = raw_tensor.element_size() + padded_elems_per_block = padded_page_size_bytes // dtype_size + strides = [0] * len(kv_cache_shape) + strides[-1] = 1 + for i in range(len(kv_cache_shape) - 2, -1, -1): + if i == physical_block_dim: + strides[i] = padded_elems_per_block + else: + strides[i] = strides[i + 1] * kv_cache_shape[i + 1] + raw_tensor = torch.as_strided( + raw_tensor, + size=kv_cache_shape, + stride=tuple(strides), + ) + else: + raw_tensor = raw_tensor.view(kv_cache_shape) kv_caches[layer_name] = raw_tensor.permute(*inv_order) return kv_caches diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 872cb83d2401..c848d3962f90 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -6658,12 +6658,44 @@ def _reshape_kv_cache_tensors( kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order)) ] - kv_caches[layer_name] = ( - kv_cache_raw_tensors[layer_name] - .view(dtype) - .view(kv_cache_shape) - .permute(*inv_order) + raw_view = kv_cache_raw_tensors[layer_name].view(dtype) + padded_page_size_bytes = getattr( + kv_cache_spec, "page_size_padded", None ) + semantic_elems_per_block = int(np.prod(kv_cache_shape[1:])) + if ( + padded_page_size_bytes is not None + and padded_page_size_bytes > kv_cache_spec.real_page_size_bytes + ): + block_dim = attn_backend.get_kv_cache_block_dim( + kernel_block_size, + kv_cache_spec.num_kv_heads, + kv_cache_spec.head_size, + cache_dtype_str=self.cache_config.cache_dtype, + ) + physical_block_dim = kv_cache_stride_order.index(block_dim) + if physical_block_dim != 0: + raise NotImplementedError( + "Padded KV pages are only supported when the " + "backend block dimension is leading." + ) + dtype_size = get_dtype_size(dtype) + padded_elems_per_block = padded_page_size_bytes // dtype_size + strides = [0] * len(kv_cache_shape) + strides[-1] = 1 + for i in range(len(kv_cache_shape) - 2, -1, -1): + if i == physical_block_dim: + strides[i] = padded_elems_per_block + else: + strides[i] = strides[i + 1] * kv_cache_shape[i + 1] + raw_view = torch.as_strided( + raw_view, + size=kv_cache_shape, + stride=tuple(strides), + ) + else: + raw_view = raw_view.view(kv_cache_shape) + kv_caches[layer_name] = raw_view.permute(*inv_order) elif isinstance(kv_cache_spec, MambaSpec): has_mamba = True raw_tensor = kv_cache_raw_tensors[layer_name] diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index 83fc12cb5c3b..7513bc0f300f 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -112,7 +112,10 @@ def init_meta( PAGE_SIZE_EL accounts for this ratio so that ``block_id * PAGE_SIZE_EL`` lands at the correct offset. - Only AttentionSpec layers are processed; Mamba layers are skipped. + Only decoder-style AttentionSpec layers are processed here. + Mamba layers are skipped, and encoder-only attention is skipped + because it does not participate in the decode-time KV block reuse + path handled by this zeroer. """ seen_ptrs: set[int] = set() seg_addrs: list[int] = [] @@ -120,7 +123,9 @@ def init_meta( for group in attn_groups_iter: spec = group.kv_cache_spec - if type(spec) is not FullAttentionSpec: + if not isinstance(spec, AttentionSpec) or isinstance( + spec, EncoderOnlyAttentionSpec + ): continue if group.kv_cache_group_id >= len(kernel_block_sizes): continue
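# A minimal sketch of the padded-page views built in _reshape_kv_cache and
# _reshape_kv_cache_tensors above: each block keeps its semantic shape, but
# consecutive blocks are spaced padded_elems_per_block elements apart, so the
# pad bytes at the end of every page are simply never addressed. Sizes here
# are illustrative, not taken from a real KV cache config.
import torch

num_blocks, semantic_elems_per_block, padded_elems_per_block = 4, 6, 8
raw = torch.arange(num_blocks * padded_elems_per_block, dtype=torch.uint8)
view = torch.as_strided(
    raw,
    size=(num_blocks, semantic_elems_per_block),
    stride=(padded_elems_per_block, 1),
)
print(view[1])  # tensor([ 8,  9, 10, 11, 12, 13], ...) -- skips block 0's 2-byte pad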