259 changes: 259 additions & 0 deletions tests/model_executor/test_turboquant_warmup.py
@@ -0,0 +1,259 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import sys
from types import ModuleType, SimpleNamespace
from typing import Any

import pytest
import torch

fake_flash_attn: Any = ModuleType("vllm.vllm_flash_attn")
fake_flash_attn.flash_attn_varlen_func = lambda *args, **kwargs: None
fake_flash_attn.get_scheduler_metadata = lambda *args, **kwargs: None
sys.modules.setdefault("vllm.vllm_flash_attn", fake_flash_attn)

fake_flash_attn_interface: Any = ModuleType("vllm.vllm_flash_attn.flash_attn_interface")
fake_flash_attn_interface.is_fa_version_supported = lambda fa_version: False
fake_flash_attn_interface.fa_version_unsupported_reason = lambda fa_version: "test"
sys.modules.setdefault(
"vllm.vllm_flash_attn.flash_attn_interface",
fake_flash_attn_interface,
)

from vllm.model_executor.warmup import kernel_warmup, turboquant_warmup # noqa: E402


class _FakeTQConfig:
key_mse_bits = 4
key_packed_size = 10
effective_value_quant_bits = 4
key_fp8 = False
norm_correction = True
slot_size_aligned = 24


class _FakeTurboQuantAttentionImpl:
def __init__(
self,
*,
num_heads: int = 4,
head_size: int = 8,
num_kv_heads: int = 2,
max_num_kv_splits: int = 32,
scale: float = 0.125,
tq_config: _FakeTQConfig | None = None,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.num_kv_heads = num_kv_heads
self.num_kv_groups = num_heads // num_kv_heads
self.max_num_kv_splits = max_num_kv_splits
self.scale = scale
self.tq_config = tq_config or _FakeTQConfig()
self.ensure_calls = 0
self.decode_calls: list[dict[str, Any]] = []

def _ensure_on_device(self, layer: torch.nn.Module, device: torch.device) -> None:
self.ensure_calls += 1
layer._tq_Pi = torch.eye(self.head_size, dtype=torch.float32, device=device)
layer._tq_PiT = torch.eye(self.head_size, dtype=torch.float32, device=device)
layer._tq_centroids = torch.zeros(16, dtype=torch.float32, device=device)

def _decode_attention(
self,
query: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: Any,
Pi: torch.Tensor,
centroids: torch.Tensor,
PiT: torch.Tensor | None = None,
layer: torch.nn.Module | None = None,
) -> torch.Tensor:
self.decode_calls.append(
{
"query": query,
"kv_cache": kv_cache,
"attn_metadata": attn_metadata,
"Pi": Pi,
"centroids": centroids,
"PiT": PiT,
"layer": layer,
}
)
return torch.empty_like(query)


class _FakeAttention(torch.nn.Module):
def __init__(
self,
*,
kv_cache_dtype: str = "turboquant_4bit_nc",
impl: _FakeTurboQuantAttentionImpl | None = None,
) -> None:
super().__init__()
self.kv_cache_dtype = kv_cache_dtype
self.impl = impl or _FakeTurboQuantAttentionImpl()


@pytest.fixture(autouse=True)
def patch_turboquant_types(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(turboquant_warmup, "Attention", _FakeAttention)
monkeypatch.setattr(
turboquant_warmup,
"TurboQuantAttentionImpl",
_FakeTurboQuantAttentionImpl,
)
monkeypatch.setattr(
turboquant_warmup.torch.accelerator,
"synchronize",
lambda: None,
)


def test_turboquant_decode_warmup_skips_non_tq_layers() -> None:
layer = _FakeAttention(kv_cache_dtype="fp8")
model = torch.nn.Sequential(layer)

turboquant_warmup.turboquant_decode_warmup(
model,
device=torch.device("cpu"),
block_size=16,
block_table_stride=8,
max_num_decode_tokens=4,
model_dtype=torch.bfloat16,
)

assert layer.impl.decode_calls == []
assert layer.impl.ensure_calls == 0


def test_turboquant_decode_warmup_builds_runtime_shaped_inputs() -> None:
impl = _FakeTurboQuantAttentionImpl(max_num_kv_splits=64)
model = torch.nn.Sequential(_FakeAttention(impl=impl))

turboquant_warmup.turboquant_decode_warmup(
model,
device=torch.device("cpu"),
block_size=32,
block_table_stride=17,
max_num_decode_tokens=4,
model_dtype=torch.bfloat16,
)

calls = impl.decode_calls
assert len(calls) == 1
call = calls[0]
assert call["query"].shape == (4, impl.num_heads, impl.head_size)
assert call["query"].dtype == torch.bfloat16
assert call["kv_cache"].shape == (
2,
32,
impl.num_kv_heads,
impl.tq_config.slot_size_aligned,
)
assert call["kv_cache"].dtype == torch.uint8
metadata = call["attn_metadata"]
assert metadata.block_table.shape == (4, 17)
assert metadata.block_table.tolist()[0][:2] == [1, 0]
assert metadata.block_table.tolist()[3][:2] == [1, 0]
assert metadata.seq_lens.tolist() == [1, 1, 1, 1]
assert metadata.query_start_loc.tolist() == [0, 1, 2, 3, 4]
assert metadata.num_decodes == 4
assert metadata.num_decode_tokens == 4
assert call["Pi"].shape == (impl.head_size, impl.head_size)
assert call["layer"] is model[0]
assert impl.ensure_calls == 1


def test_turboquant_decode_warmup_deduplicates_compile_key() -> None:
first = _FakeAttention(impl=_FakeTurboQuantAttentionImpl())
second = _FakeAttention(impl=_FakeTurboQuantAttentionImpl())
model = torch.nn.Sequential(first, second)

turboquant_warmup.turboquant_decode_warmup(
model,
device=torch.device("cpu"),
block_size=16,
block_table_stride=8,
max_num_decode_tokens=4,
model_dtype=torch.float16,
)

assert len(first.impl.decode_calls) == 1
assert second.impl.decode_calls == []
assert first.impl.ensure_calls == 1
assert second.impl.ensure_calls == 0


def test_turboquant_decode_warmup_keeps_distinct_compile_keys() -> None:
first = _FakeAttention(
impl=_FakeTurboQuantAttentionImpl(num_heads=4, num_kv_heads=2)
)
second = _FakeAttention(
impl=_FakeTurboQuantAttentionImpl(num_heads=8, num_kv_heads=2)
)
model = torch.nn.Sequential(first, second)

turboquant_warmup.turboquant_decode_warmup(
model,
device=torch.device("cpu"),
block_size=16,
block_table_stride=8,
max_num_decode_tokens=4,
model_dtype=torch.float16,
)

assert len(first.impl.decode_calls) == 1
assert len(second.impl.decode_calls) == 1


def test_kernel_warmup_passes_turboquant_runtime_constants(
monkeypatch: pytest.MonkeyPatch,
) -> None:
calls = []

def fake_tq_warmup(model_arg, **kwargs):
calls.append({"model": model_arg, **kwargs})

monkeypatch.setattr(kernel_warmup, "turboquant_decode_warmup", fake_tq_warmup)
monkeypatch.setattr(kernel_warmup, "has_flashinfer", lambda: False)

model = torch.nn.Linear(1, 1)
worker = SimpleNamespace(
get_model=lambda: model,
scheduler_config=SimpleNamespace(
max_num_batched_tokens=1024,
max_num_seqs=7,
),
cache_config=SimpleNamespace(block_size=48),
model_runner=SimpleNamespace(
device=torch.device("cpu"),
dtype=torch.bfloat16,
input_batch=SimpleNamespace(
block_table=SimpleNamespace(
block_tables=[
SimpleNamespace(block_size=16, max_num_blocks_per_req=257)
]
)
),
is_pooling_model=False,
attn_groups=[],
),
vllm_config=SimpleNamespace(
kernel_config=SimpleNamespace(enable_flashinfer_autotune=False)
),
)

kernel_warmup.kernel_warmup(worker)

assert calls == [
{
"model": model,
"device": torch.device("cpu"),
"block_size": 16,
"block_table_stride": 257,
"max_num_decode_tokens": 7,
"model_dtype": torch.bfloat16,
}
]
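
The two compile-key tests above imply that turboquant_decode_warmup groups attention layers by a shape-derived key and runs the dummy decode only once per key. A minimal sketch of that grouping, assuming the key is built from the shape-affecting attributes the fakes expose; the real key in turboquant_warmup.py may include additional fields:

# Illustrative only: one plausible compile-key grouping, inferred from the
# dedup tests above; attribute names mirror _FakeTurboQuantAttentionImpl.
def _compile_key(impl) -> tuple:
    return (
        impl.num_heads,
        impl.num_kv_heads,
        impl.head_size,
        impl.max_num_kv_splits,
        impl.tq_config.slot_size_aligned,
    )


def _layers_to_warm(layers):
    seen: set[tuple] = set()
    for layer in layers:
        key = _compile_key(layer.impl)
        if key in seen:
            continue  # second layer with an identical key is skipped
        seen.add(key)
        yield layer

Under this sketch, the second layer in test_turboquant_decode_warmup_deduplicates_compile_key is skipped, while the num_heads=8 layer in test_turboquant_decode_warmup_keeps_distinct_compile_keys still gets its own warmup call.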
27 changes: 26 additions & 1 deletion vllm/model_executor/warmup/kernel_warmup.py
@@ -13,6 +13,7 @@
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.warmup.deep_gemm_warmup import deep_gemm_warmup
from vllm.model_executor.warmup.turboquant_warmup import turboquant_decode_warmup
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import is_deep_gemm_supported
from vllm.utils.flashinfer import has_flashinfer
@@ -25,17 +26,41 @@


def kernel_warmup(worker: "Worker"):
model = worker.get_model()

# Deep GEMM warmup
do_deep_gemm_warmup = (
envs.VLLM_USE_DEEP_GEMM
and is_deep_gemm_supported()
and envs.VLLM_DEEP_GEMM_WARMUP != "skip"
)
if do_deep_gemm_warmup:
model = worker.get_model()
max_tokens = worker.scheduler_config.max_num_batched_tokens
deep_gemm_warmup(model, max_tokens)

block_size = worker.cache_config.block_size
block_table_stride = 1
block_tables = worker.model_runner.input_batch.block_table.block_tables
if block_tables:
block_table = block_tables[0]
# V1 may split KV manager blocks into smaller attention-kernel blocks.
# Warmup must use the runtime BlockTable constants or Triton will
# compile a different BLOCK_SIZE/stride variant from real decode.
block_size = block_table.block_size
block_table_stride = block_table.max_num_blocks_per_req
max_num_decode_tokens = min(
worker.scheduler_config.max_num_seqs,
worker.scheduler_config.max_num_batched_tokens,
)
turboquant_decode_warmup(
model,
device=worker.model_runner.device,
block_size=block_size,
block_table_stride=block_table_stride,
max_num_decode_tokens=max_num_decode_tokens,
model_dtype=worker.model_runner.dtype,
)

enable_flashinfer_autotune = (
worker.vllm_config.kernel_config.enable_flashinfer_autotune
)
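
For context on the block-table comment above: the worker stub in test_kernel_warmup_passes_turboquant_runtime_constants sets cache_config.block_size to 48 but exposes a runtime BlockTable with block_size=16 and max_num_blocks_per_req=257, and the test asserts that 16 and 257 (not 48) reach turboquant_decode_warmup, together with max_num_decode_tokens = min(7, 1024) = 7. Below is a minimal sketch of the dummy decode inputs the warmup is then expected to build, inferred from the shape assertions in test_turboquant_decode_warmup_builds_runtime_shaped_inputs; it is not the actual implementation, and the dtypes not asserted by the tests (e.g. int32 indices) are assumptions.

import torch

# Shapes only; values are arbitrary. Constants mirror the runtime-shaped test:
# 4 decode tokens, 4 heads, head size 8, 2 KV heads, block size 32, stride 17.
num_decode_tokens, num_heads, head_size = 4, 4, 8
num_kv_heads, block_size, block_table_stride = 2, 32, 17
slot_size_aligned = 24

query = torch.zeros(num_decode_tokens, num_heads, head_size, dtype=torch.bfloat16)
kv_cache = torch.zeros(2, block_size, num_kv_heads, slot_size_aligned, dtype=torch.uint8)
block_table = torch.zeros(num_decode_tokens, block_table_stride, dtype=torch.int32)
block_table[:, 0] = 1  # each dummy request points at block 1, matching [1, 0] rows
seq_lens = torch.ones(num_decode_tokens, dtype=torch.int32)
query_start_loc = torch.arange(num_decode_tokens + 1, dtype=torch.int32)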