diff --git a/tests/model_executor/test_turboquant_warmup.py b/tests/model_executor/test_turboquant_warmup.py
new file mode 100644
index 000000000000..36da95590b7d
--- /dev/null
+++ b/tests/model_executor/test_turboquant_warmup.py
@@ -0,0 +1,259 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import sys
+from types import ModuleType, SimpleNamespace
+from typing import Any
+
+import pytest
+import torch
+
+fake_flash_attn: Any = ModuleType("vllm.vllm_flash_attn")
+fake_flash_attn.flash_attn_varlen_func = lambda *args, **kwargs: None
+fake_flash_attn.get_scheduler_metadata = lambda *args, **kwargs: None
+sys.modules.setdefault("vllm.vllm_flash_attn", fake_flash_attn)
+
+fake_flash_attn_interface: Any = ModuleType("vllm.vllm_flash_attn.flash_attn_interface")
+fake_flash_attn_interface.is_fa_version_supported = lambda fa_version: False
+fake_flash_attn_interface.fa_version_unsupported_reason = lambda fa_version: "test"
+sys.modules.setdefault(
+    "vllm.vllm_flash_attn.flash_attn_interface",
+    fake_flash_attn_interface,
+)
+
+from vllm.model_executor.warmup import kernel_warmup, turboquant_warmup  # noqa: E402
+
+
+class _FakeTQConfig:
+    key_mse_bits = 4
+    key_packed_size = 10
+    effective_value_quant_bits = 4
+    key_fp8 = False
+    norm_correction = True
+    slot_size_aligned = 24
+
+
+class _FakeTurboQuantAttentionImpl:
+    def __init__(
+        self,
+        *,
+        num_heads: int = 4,
+        head_size: int = 8,
+        num_kv_heads: int = 2,
+        max_num_kv_splits: int = 32,
+        scale: float = 0.125,
+        tq_config: _FakeTQConfig | None = None,
+    ) -> None:
+        self.num_heads = num_heads
+        self.head_size = head_size
+        self.num_kv_heads = num_kv_heads
+        self.num_kv_groups = num_heads // num_kv_heads
+        self.max_num_kv_splits = max_num_kv_splits
+        self.scale = scale
+        self.tq_config = tq_config or _FakeTQConfig()
+        self.ensure_calls = 0
+        self.decode_calls: list[dict[str, Any]] = []
+
+    def _ensure_on_device(self, layer: torch.nn.Module, device: torch.device) -> None:
+        self.ensure_calls += 1
+        layer._tq_Pi = torch.eye(self.head_size, dtype=torch.float32, device=device)
+        layer._tq_PiT = torch.eye(self.head_size, dtype=torch.float32, device=device)
+        layer._tq_centroids = torch.zeros(16, dtype=torch.float32, device=device)
+
+    def _decode_attention(
+        self,
+        query: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: Any,
+        Pi: torch.Tensor,
+        centroids: torch.Tensor,
+        PiT: torch.Tensor | None = None,
+        layer: torch.nn.Module | None = None,
+    ) -> torch.Tensor:
+        self.decode_calls.append(
+            {
+                "query": query,
+                "kv_cache": kv_cache,
+                "attn_metadata": attn_metadata,
+                "Pi": Pi,
+                "centroids": centroids,
+                "PiT": PiT,
+                "layer": layer,
+            }
+        )
+        return torch.empty_like(query)
+
+
+class _FakeAttention(torch.nn.Module):
+    def __init__(
+        self,
+        *,
+        kv_cache_dtype: str = "turboquant_4bit_nc",
+        impl: _FakeTurboQuantAttentionImpl | None = None,
+    ) -> None:
+        super().__init__()
+        self.kv_cache_dtype = kv_cache_dtype
+        self.impl = impl or _FakeTurboQuantAttentionImpl()
+
+
+@pytest.fixture(autouse=True)
+def patch_turboquant_types(monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setattr(turboquant_warmup, "Attention", _FakeAttention)
+    monkeypatch.setattr(
+        turboquant_warmup,
+        "TurboQuantAttentionImpl",
+        _FakeTurboQuantAttentionImpl,
+    )
+    monkeypatch.setattr(
+        turboquant_warmup.torch.accelerator,
+        "synchronize",
+        lambda: None,
+    )
+
+
+def test_turboquant_decode_warmup_skips_non_tq_layers() -> None:
+    layer = _FakeAttention(kv_cache_dtype="fp8")
+    model = torch.nn.Sequential(layer)
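+    # This layer's kv_cache_dtype is not turboquant_*, so the warmup below
+    # must skip it entirely: no _ensure_on_device call, no decode invocation.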
+
+    turboquant_warmup.turboquant_decode_warmup(
+        model,
+        device=torch.device("cpu"),
+        block_size=16,
+        block_table_stride=8,
+        max_num_decode_tokens=4,
+        model_dtype=torch.bfloat16,
+    )
+
+    assert layer.impl.decode_calls == []
+    assert layer.impl.ensure_calls == 0
+
+
+def test_turboquant_decode_warmup_builds_runtime_shaped_inputs() -> None:
+    impl = _FakeTurboQuantAttentionImpl(max_num_kv_splits=64)
+    model = torch.nn.Sequential(_FakeAttention(impl=impl))
+
+    turboquant_warmup.turboquant_decode_warmup(
+        model,
+        device=torch.device("cpu"),
+        block_size=32,
+        block_table_stride=17,
+        max_num_decode_tokens=4,
+        model_dtype=torch.bfloat16,
+    )
+
+    calls = impl.decode_calls
+    assert len(calls) == 1
+    call = calls[0]
+    assert call["query"].shape == (4, impl.num_heads, impl.head_size)
+    assert call["query"].dtype == torch.bfloat16
+    assert call["kv_cache"].shape == (
+        2,
+        32,
+        impl.num_kv_heads,
+        impl.tq_config.slot_size_aligned,
+    )
+    assert call["kv_cache"].dtype == torch.uint8
+    metadata = call["attn_metadata"]
+    assert metadata.block_table.shape == (4, 17)
+    assert metadata.block_table.tolist()[0][:2] == [1, 0]
+    assert metadata.block_table.tolist()[3][:2] == [1, 0]
+    assert metadata.seq_lens.tolist() == [1, 1, 1, 1]
+    assert metadata.query_start_loc.tolist() == [0, 1, 2, 3, 4]
+    assert metadata.num_decodes == 4
+    assert metadata.num_decode_tokens == 4
+    assert call["Pi"].shape == (impl.head_size, impl.head_size)
+    assert call["layer"] is model[0]
+    assert impl.ensure_calls == 1
+
+
+def test_turboquant_decode_warmup_deduplicates_compile_key() -> None:
+    first = _FakeAttention(impl=_FakeTurboQuantAttentionImpl())
+    second = _FakeAttention(impl=_FakeTurboQuantAttentionImpl())
+    model = torch.nn.Sequential(first, second)
+
+    turboquant_warmup.turboquant_decode_warmup(
+        model,
+        device=torch.device("cpu"),
+        block_size=16,
+        block_table_stride=8,
+        max_num_decode_tokens=4,
+        model_dtype=torch.float16,
+    )
+
+    assert len(first.impl.decode_calls) == 1
+    assert second.impl.decode_calls == []
+    assert first.impl.ensure_calls == 1
+    assert second.impl.ensure_calls == 0
+
+
+def test_turboquant_decode_warmup_keeps_distinct_compile_keys() -> None:
+    first = _FakeAttention(
+        impl=_FakeTurboQuantAttentionImpl(num_heads=4, num_kv_heads=2)
+    )
+    second = _FakeAttention(
+        impl=_FakeTurboQuantAttentionImpl(num_heads=8, num_kv_heads=2)
+    )
+    model = torch.nn.Sequential(first, second)
+
+    turboquant_warmup.turboquant_decode_warmup(
+        model,
+        device=torch.device("cpu"),
+        block_size=16,
+        block_table_stride=8,
+        max_num_decode_tokens=4,
+        model_dtype=torch.float16,
+    )
+
+    assert len(first.impl.decode_calls) == 1
+    assert len(second.impl.decode_calls) == 1
+
+
+def test_kernel_warmup_passes_turboquant_runtime_constants(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    calls = []
+
+    def fake_tq_warmup(model_arg, **kwargs):
+        calls.append({"model": model_arg, **kwargs})
+
+    monkeypatch.setattr(kernel_warmup, "turboquant_decode_warmup", fake_tq_warmup)
+    monkeypatch.setattr(kernel_warmup, "has_flashinfer", lambda: False)
+
+    model = torch.nn.Linear(1, 1)
+    worker = SimpleNamespace(
+        get_model=lambda: model,
+        scheduler_config=SimpleNamespace(
+            max_num_batched_tokens=1024,
+            max_num_seqs=7,
+        ),
+        cache_config=SimpleNamespace(block_size=48),
+        model_runner=SimpleNamespace(
+            device=torch.device("cpu"),
+            dtype=torch.bfloat16,
+            input_batch=SimpleNamespace(
+                block_table=SimpleNamespace(
+                    block_tables=[
+                        SimpleNamespace(block_size=16, max_num_blocks_per_req=257)
+                    ]
+                )
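+                # The runtime BlockTable above (block_size=16, stride 257)
+                # must win over cache_config.block_size=48 when kernel_warmup
+                # forwards its constants to turboquant_decode_warmup.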
+            ),
+            is_pooling_model=False,
+            attn_groups=[],
+        ),
+        vllm_config=SimpleNamespace(
+            kernel_config=SimpleNamespace(enable_flashinfer_autotune=False)
+        ),
+    )
+
+    kernel_warmup.kernel_warmup(worker)
+
+    assert calls == [
+        {
+            "model": model,
+            "device": torch.device("cpu"),
+            "block_size": 16,
+            "block_table_stride": 257,
+            "max_num_decode_tokens": 7,
+            "model_dtype": torch.bfloat16,
+        }
+    ]
diff --git a/vllm/model_executor/warmup/kernel_warmup.py b/vllm/model_executor/warmup/kernel_warmup.py
index 70abd8a6c503..c8fc47d28fe6 100644
--- a/vllm/model_executor/warmup/kernel_warmup.py
+++ b/vllm/model_executor/warmup/kernel_warmup.py
@@ -13,6 +13,7 @@
 import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.model_executor.warmup.deep_gemm_warmup import deep_gemm_warmup
+from vllm.model_executor.warmup.turboquant_warmup import turboquant_decode_warmup
 from vllm.platforms import current_platform
 from vllm.utils.deep_gemm import is_deep_gemm_supported
 from vllm.utils.flashinfer import has_flashinfer
@@ -25,6 +26,8 @@ def kernel_warmup(worker: "Worker"):
+    model = worker.get_model()
+
     # Deep GEMM warmup
     do_deep_gemm_warmup = (
         envs.VLLM_USE_DEEP_GEMM
@@ -32,10 +35,32 @@ def kernel_warmup(worker: "Worker"):
         and is_deep_gemm_supported()
         and envs.VLLM_DEEP_GEMM_WARMUP != "skip"
     )
     if do_deep_gemm_warmup:
-        model = worker.get_model()
         max_tokens = worker.scheduler_config.max_num_batched_tokens
         deep_gemm_warmup(model, max_tokens)
 
+    block_size = worker.cache_config.block_size
+    block_table_stride = 1
+    block_tables = worker.model_runner.input_batch.block_table.block_tables
+    if block_tables:
+        block_table = block_tables[0]
+        # V1 may split KV manager blocks into smaller attention-kernel blocks.
+        # Warmup must use the runtime BlockTable constants, or Triton will
+        # compile a different BLOCK_SIZE/stride variant than real decode uses.
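+        # For example, the unit test in this PR models a cache_config that
+        # reports 48-token blocks while the runtime BlockTable uses
+        # block_size=16 and max_num_blocks_per_req=257; warmup must pick
+        # 16 and 257, not 48 and the placeholder stride of 1.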
+        block_size = block_table.block_size
+        block_table_stride = block_table.max_num_blocks_per_req
+    max_num_decode_tokens = min(
+        worker.scheduler_config.max_num_seqs,
+        worker.scheduler_config.max_num_batched_tokens,
+    )
+    turboquant_decode_warmup(
+        model,
+        device=worker.model_runner.device,
+        block_size=block_size,
+        block_table_stride=block_table_stride,
+        max_num_decode_tokens=max_num_decode_tokens,
+        model_dtype=worker.model_runner.dtype,
+    )
+
     enable_flashinfer_autotune = (
         worker.vllm_config.kernel_config.enable_flashinfer_autotune
     )
diff --git a/vllm/model_executor/warmup/turboquant_warmup.py b/vllm/model_executor/warmup/turboquant_warmup.py
new file mode 100644
index 000000000000..6dd5df36ed01
--- /dev/null
+++ b/vllm/model_executor/warmup/turboquant_warmup.py
@@ -0,0 +1,185 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Warm up TurboQuant decode kernels before serving requests."""
+
+from __future__ import annotations
+
+from collections.abc import Iterable
+from dataclasses import dataclass
+
+import torch
+
+from vllm.logger import init_logger
+from vllm.model_executor.layers.attention import Attention
+from vllm.v1.attention.backends.turboquant_attn import (
+    TurboQuantAttentionImpl,
+    TurboQuantMetadata,
+)
+
+logger = init_logger(__name__)
+
+
+@dataclass(frozen=True)
+class _TurboQuantDecodeWarmupKey:
+    num_kv_heads: int
+    head_dim: int
+    block_size: int
+    block_table_stride: int
+    num_kv_splits: int
+    kv_group_size: int
+    scale: float
+    mse_bits: int
+    key_packed_size: int
+    value_quant_bits: int
+    key_fp8: bool
+    norm_correction: bool
+    output_fp16: bool
+
+
+def _iter_turboquant_attention_layers(
+    model: torch.nn.Module,
+) -> Iterable[tuple[Attention, TurboQuantAttentionImpl]]:
+    for layer in model.modules():
+        if not isinstance(layer, Attention):
+            continue
+        if not layer.kv_cache_dtype.startswith("turboquant_"):
+            continue
+        if not isinstance(layer.impl, TurboQuantAttentionImpl):
+            continue
+        yield layer, layer.impl
+
+
+def _make_warmup_key(
+    impl: TurboQuantAttentionImpl,
+    *,
+    block_size: int,
+    block_table_stride: int,
+    model_dtype: torch.dtype,
+) -> _TurboQuantDecodeWarmupKey:
+    return _TurboQuantDecodeWarmupKey(
+        num_kv_heads=impl.num_kv_heads,
+        head_dim=impl.head_size,
+        block_size=block_size,
+        # Triton specializes regular scalar stride arguments too. Keep the
+        # synthetic block table stride equal to the runtime block table stride
+        # and include it in the dedupe key so warmup covers the same variant.
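+        # (Triton's JIT keys integer arguments on properties such as
+        # "equal to 1" and "divisible by 16", so stride values that look
+        # interchangeable can still map to distinct compiled kernels.)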
+        block_table_stride=block_table_stride,
+        num_kv_splits=impl.max_num_kv_splits,
+        kv_group_size=impl.num_kv_groups,
+        scale=impl.scale,
+        mse_bits=impl.tq_config.key_mse_bits,
+        key_packed_size=impl.tq_config.key_packed_size,
+        value_quant_bits=impl.tq_config.effective_value_quant_bits,
+        key_fp8=impl.tq_config.key_fp8,
+        norm_correction=impl.tq_config.norm_correction,
+        output_fp16=model_dtype == torch.float16,
+    )
+
+
+def _warmup_turboquant_decode_layer(
+    layer: Attention,
+    impl: TurboQuantAttentionImpl,
+    *,
+    device: torch.device,
+    block_size: int,
+    block_table_stride: int,
+    max_num_decode_tokens: int,
+    model_dtype: torch.dtype,
+) -> None:
+    impl._ensure_on_device(layer, device)
+
+    batch_size = max_num_decode_tokens
+    query = torch.zeros(
+        (batch_size, impl.num_heads, impl.head_size),
+        dtype=model_dtype,
+        device=device,
+    )
+    kv_cache = torch.zeros(
+        (
+            2,
+            block_size,
+            impl.num_kv_heads,
+            impl.tq_config.slot_size_aligned,
+        ),
+        dtype=torch.uint8,
+        device=device,
+    )
+    block_table = torch.zeros(
+        (batch_size, block_table_stride), dtype=torch.int32, device=device
+    )
+    block_table[:, 0] = 1
+    seq_lens = torch.ones(batch_size, dtype=torch.int32, device=device)
+    attn_metadata = TurboQuantMetadata(
+        seq_lens=seq_lens,
+        slot_mapping=torch.zeros(batch_size, dtype=torch.long, device=device),
+        block_table=block_table,
+        query_start_loc=torch.arange(batch_size + 1, dtype=torch.int32, device=device),
+        num_actual_tokens=batch_size,
+        max_query_len=1,
+        max_seq_len=1,
+        is_prefill=False,
+        num_decodes=batch_size,
+        num_decode_tokens=batch_size,
+    )
+
+    # Use the runtime decode helper instead of calling the Triton launcher
+    # directly. This warms both the decode kernels and the WorkspaceManager
+    # allocation path before the workspace is locked after CUDA graph capture.
+    impl._decode_attention(
+        query=query,
+        kv_cache=kv_cache,
+        attn_metadata=attn_metadata,
+        Pi=layer._tq_Pi,
+        centroids=layer._tq_centroids,
+        PiT=layer._tq_PiT,
+        layer=layer,
+    )
+
+
+@torch.inference_mode()
+def turboquant_decode_warmup(
+    model: torch.nn.Module,
+    *,
+    device: torch.device,
+    block_size: int,
+    block_table_stride: int,
+    max_num_decode_tokens: int,
+    model_dtype: torch.dtype,
+) -> None:
+    """Compile TurboQuant decode kernels without running a model forward pass.
+
+    V1 dummy/profile warmup can skip the TurboQuant decode path, which leaves
+    `_tq_decode_stage1` and `_tq_decode_stage2` to compile on the first real
+    decode request. This warmup calls the backend decode path with synthetic
+    tensors whose launch-time constants match the runtime attention layer.
+    """
+    if max_num_decode_tokens <= 0:
+        return
+
+    seen: set[_TurboQuantDecodeWarmupKey] = set()
+    num_warmups = 0
+
+    for layer, impl in _iter_turboquant_attention_layers(model):
+        key = _make_warmup_key(
+            impl,
+            block_size=block_size,
+            block_table_stride=block_table_stride,
+            model_dtype=model_dtype,
+        )
+        if key in seen:
+            continue
+        seen.add(key)
+        _warmup_turboquant_decode_layer(
+            layer,
+            impl,
+            device=device,
+            block_size=block_size,
+            block_table_stride=block_table_stride,
+            max_num_decode_tokens=max_num_decode_tokens,
+            model_dtype=model_dtype,
+        )
+        num_warmups += 1
+
+    if num_warmups > 0:
+        torch.accelerator.synchronize()
+        logger.info("Warmed up %d TurboQuant decode kernel variant(s).", num_warmups)
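A minimal sketch of invoking the new hook directly, outside kernel_warmup(), for anyone who wants to exercise it in isolation. The device and constants below are illustrative assumptions, not values this PR prescribes; real callers derive them from the worker's scheduler_config, cache_config, and runtime BlockTable exactly as the kernel_warmup.py hunk above does:

    import torch

    from vllm.model_executor.warmup.turboquant_warmup import turboquant_decode_warmup

    # `model` is any torch.nn.Module; only Attention layers whose
    # kv_cache_dtype starts with "turboquant_" are warmed, so a model
    # without such layers makes this call a no-op.
    turboquant_decode_warmup(
        model,
        device=torch.device("cuda"),  # assumed CUDA worker
        block_size=16,                # attention-kernel block size
        block_table_stride=257,       # BlockTable.max_num_blocks_per_req
        max_num_decode_tokens=8,      # min(max_num_seqs, max_num_batched_tokens)
        model_dtype=torch.bfloat16,
    )

Because the dedupe key is a frozen dataclass over every launch-time constant, repeated layers with identical shapes trigger exactly one warmup pass, while a layer with, say, a different head count gets its own.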