173 changes: 173 additions & 0 deletions tests/quantization/test_turboquant.py
@@ -6,6 +6,7 @@
"""

import math
from types import SimpleNamespace

import pytest
import torch
@@ -276,6 +277,178 @@ def test_no_hybrid_hints_returns_empty(self):
assert _get_full_attention_layer_indices(mc) == []


class TestTurboQuantDecodeWorkspace:
def test_decode_acquires_workspace_when_manager_is_initialized(self, monkeypatch):
from vllm.v1.attention.ops import triton_turboquant_decode as decode

captured = {}

class FakeWorkspaceManager:
def get_simultaneous(self, *shapes_and_dtypes):
captured["shapes_and_dtypes"] = shapes_and_dtypes
return (
torch.empty(2, 8, 4, 129),
torch.empty(2, 8, 128, dtype=torch.float16),
torch.empty(2, 8),
)

monkeypatch.setattr(decode, "_get_layout", lambda *args: {"unused": True})
monkeypatch.setattr(
"vllm.v1.worker.workspace.is_workspace_manager_initialized",
lambda: True,
)
monkeypatch.setattr(
"vllm.v1.worker.workspace.current_workspace_manager",
lambda: FakeWorkspaceManager(),
)

def stop_after_workspace(*args, **kwargs):
raise RuntimeError("workspace acquired")

monkeypatch.setattr(decode, "_use_fp8_e4b15", stop_after_workspace)

with pytest.raises(RuntimeError, match="workspace acquired"):
decode.triton_turboquant_decode_attention(
query=torch.empty(2, 8, 128, dtype=torch.float16),
kv_cache=torch.empty(1, 16, 8, 102, dtype=torch.uint8),
block_table=torch.zeros(2, 1, dtype=torch.int32),
seq_lens=torch.ones(2, dtype=torch.int32),
Pi=torch.empty(128, 128),
centroids=torch.empty(8),
scale=1.0,
mse_bits=3,
key_packed_size=50,
value_quant_bits=3,
PiT=torch.empty(128, 128),
max_num_kv_splits=4,
)

assert captured["shapes_and_dtypes"] == (
((2, 8, 4, 129), torch.float32),
((2, 8, 128), torch.float16),
((2, 8), torch.float32),
)

def test_capture_model_reserves_turboquant_workspace_before_early_return(self):
from vllm.config import CUDAGraphMode
from vllm.v1.worker.gpu_model_runner import GPUModelRunner

calls = []
runner = GPUModelRunner.__new__(GPUModelRunner)
runner.compilation_config = SimpleNamespace(cudagraph_mode=CUDAGraphMode.NONE)
runner._reserve_turboquant_decode_workspace = lambda: calls.append("reserve")

assert runner.capture_model() == 0
assert calls == ["reserve"]

def test_reserve_turboquant_workspace_checks_all_attention_groups(
self, monkeypatch
):
from vllm.v1.worker.gpu_model_runner import GPUModelRunner

calls = []

class FakeWorkspaceManager:
def get_simultaneous(self, *shapes_and_dtypes):
calls.append(shapes_and_dtypes)

runner = GPUModelRunner.__new__(GPUModelRunner)
runner.cache_config = SimpleNamespace(
cache_dtype="turboquant_3bit_nc", block_size=16
)
runner.scheduler_config = SimpleNamespace(
max_num_seqs=16,
max_num_batched_tokens=4096,
enable_chunked_prefill=True,
)
runner.parallel_config = SimpleNamespace(tensor_parallel_size=2)
runner.dtype = torch.float16
runner.model_config = SimpleNamespace(
max_model_len=8192,
get_num_attention_heads=lambda parallel_config: 8,
get_num_kv_heads=lambda parallel_config: 4,
get_head_size=lambda: 128,
)
runner.vllm_config = SimpleNamespace(
attention_config=SimpleNamespace(tq_max_kv_splits_for_cuda_graph=4)
)
flash_group = SimpleNamespace(
backend=SimpleNamespace(get_name=lambda: "FLASH_ATTN")
)
tq_group = SimpleNamespace(backend=SimpleNamespace(get_name=lambda: "TURBOQUANT"))
runner.attn_groups = [[flash_group], [tq_group]]
tq_group.kv_cache_group_id = 1
runner._kernel_block_sizes = [16, 32]

monkeypatch.setattr(
"vllm.v1.worker.gpu_model_runner.current_workspace_manager",
lambda: FakeWorkspaceManager(),
)

runner._reserve_turboquant_decode_workspace()

assert calls == [
(
((16, 8, 4, 129), torch.float32),
((16, 8, 128), torch.float16),
((16, 8), torch.float32),
),
(
((1, 4, 8224, 128), torch.float16),
((1, 4, 8224, 128), torch.float16),
),
]

def test_reserve_turboquant_workspace_skips_continuation_prefill_when_disabled(
self, monkeypatch
):
from vllm.v1.worker.gpu_model_runner import GPUModelRunner

calls = []

class FakeWorkspaceManager:
def get_simultaneous(self, *shapes_and_dtypes):
calls.append(shapes_and_dtypes)

runner = GPUModelRunner.__new__(GPUModelRunner)
runner.cache_config = SimpleNamespace(
cache_dtype="turboquant_3bit_nc", block_size=16
)
runner.scheduler_config = SimpleNamespace(
max_num_seqs=16,
max_num_batched_tokens=4096,
enable_chunked_prefill=False,
)
runner.parallel_config = SimpleNamespace(tensor_parallel_size=2)
runner.dtype = torch.float16
runner.model_config = SimpleNamespace(
max_model_len=8192,
get_num_attention_heads=lambda parallel_config: 8,
get_num_kv_heads=lambda parallel_config: 4,
get_head_size=lambda: 128,
)
runner.vllm_config = SimpleNamespace(
attention_config=SimpleNamespace(tq_max_kv_splits_for_cuda_graph=4)
)
runner.attn_groups = [
[SimpleNamespace(backend=SimpleNamespace(get_name=lambda: "TURBOQUANT"))]
]

monkeypatch.setattr(
"vllm.v1.worker.gpu_model_runner.current_workspace_manager",
lambda: FakeWorkspaceManager(),
)

runner._reserve_turboquant_decode_workspace()

assert calls == [
(
((16, 8, 4, 129), torch.float32),
((16, 8, 128), torch.float16),
((16, 8), torch.float32),
)
]

# ============================================================================
# Centroids tests (CPU-only)
# ============================================================================
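
Note on the patching pattern in the tests above: the decode module imports the workspace helpers lazily inside the function body (visible in the triton_turboquant_decode.py hunk below), so the tests must patch vllm.v1.worker.workspace.is_workspace_manager_initialized and current_workspace_manager by dotted path on the defining module; setting an attribute on `decode` would not be seen by the lazy import. A minimal stand-alone illustration of that rule, using stdlib stand-ins rather than anything from this PR:

def use_sqrt():
    from math import sqrt  # lazy import, mirroring the workspace helpers in decode
    return sqrt(4.0)

def test_patch_defining_module(monkeypatch):
    # Patching the defining module ("math") is picked up even by the lazy
    # import, because the import re-reads the attribute at call time.
    monkeypatch.setattr("math.sqrt", lambda x: 42.0)
    assert use_sqrt() == 42.0
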
1 change: 0 additions & 1 deletion vllm/v1/attention/backends/turboquant_attn.py
@@ -872,7 +872,6 @@ def _decode_attention(
mid_o_buf=mid_o_buf,
output_buf=output_buf,
lse_buf=lse_buf,
buf_holder=layer,
max_num_kv_splits=self.max_num_kv_splits,
)
return result
23 changes: 15 additions & 8 deletions vllm/v1/attention/ops/triton_turboquant_decode.py
@@ -9,7 +9,6 @@
"""

import math
from typing import Any

import torch

@@ -501,7 +500,6 @@ def triton_turboquant_decode_attention(
mid_o_buf: torch.Tensor | None = None,
output_buf: torch.Tensor | None = None,
lse_buf: torch.Tensor | None = None,
buf_holder: Any = None,
max_num_kv_splits: int = 32, # fixed split count (must be constant for cudagraph)
) -> torch.Tensor:
"""Launch fused TQ decode attention (Triton stage1 + stage2).
@@ -529,6 +527,21 @@

NUM_KV_SPLITS = max_num_kv_splits

if mid_o_buf is None or output_buf is None or lse_buf is None:
from vllm.v1.worker.workspace import (
current_workspace_manager,
is_workspace_manager_initialized,
)

if is_workspace_manager_initialized():
mid_o_buf, output_buf, lse_buf = (
current_workspace_manager().get_simultaneous(
((B, Hq, NUM_KV_SPLITS, D + 1), torch.float32),
((B, Hq, D), query.dtype),
((B, Hq), torch.float32),
)
)

if (
mid_o_buf is not None
and mid_o_buf.shape[0] >= B
@@ -544,8 +557,6 @@
dtype=torch.float32,
device=device,
)
if buf_holder is not None:
buf_holder._tq_mid_o_buf = mid_o

# Stage 1: split-KV tiled attention scoring + value accumulation
fp8_e4b15 = _use_fp8_e4b15(device.index or 0)
@@ -598,14 +609,10 @@
output = output_buf[:B, :Hq, :D]
else:
output = torch.empty(B, Hq, D, dtype=out_dtype, device=device)
if buf_holder is not None:
buf_holder._tq_output_buf = output
if lse_buf is not None and lse_buf.shape[0] >= B:
lse = lse_buf[:B, :Hq]
else:
lse = torch.empty(B, Hq, dtype=torch.float32, device=device)
if buf_holder is not None:
buf_holder._tq_lse_buf = lse

grid2 = (B, Hq)
_fwd_kernel_stage2[grid2](
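Note on the (B, Hq, NUM_KV_SPLITS, D + 1) buffer shape acquired above: the extra slot stores each split's log-sum-exp next to its D accumulated output values, which is what a split-KV stage 2 needs to merge the partial results. Assuming each split's partial output is already normalized within its split, the combine is the standard flash-decoding merge; a plain-PyTorch reference sketch (not the Triton _fwd_kernel_stage2 from this diff):

import torch

def merge_kv_splits(mid_o: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # mid_o: (B, Hq, S, D + 1); [..., :D] holds per-split outputs, [..., D]
    # the per-split log-sum-exp (LSE).
    o = mid_o[..., :-1]                              # (B, Hq, S, D)
    lse = mid_o[..., -1]                             # (B, Hq, S)
    m = lse.max(dim=-1, keepdim=True).values         # (B, Hq, 1), for stability
    w = torch.exp(lse - m)                           # relative split weights
    out = (w.unsqueeze(-1) * o).sum(dim=2) / w.sum(dim=-1, keepdim=True)
    merged_lse = m.squeeze(-1) + torch.log(w.sum(dim=-1))
    return out, merged_lse                           # (B, Hq, D), (B, Hq)
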
58 changes: 57 additions & 1 deletion vllm/v1/worker/gpu_model_runner.py
@@ -200,7 +200,7 @@
split_attn_metadata,
)
from vllm.v1.worker.utils import is_residual_scattered_for_sp
from vllm.v1.worker.workspace import lock_workspace
from vllm.v1.worker.workspace import current_workspace_manager, lock_workspace

from .utils import (
AttentionGroup,
@@ -218,6 +218,8 @@

logger = init_logger(__name__)

_TURBOQUANT_CONTINUATION_DECODE_THRESHOLD = 128

AttnMetadataDict: TypeAlias = dict[str, AttentionMetadata]
# list when ubatching is enabled
PerLayerAttnMetadata: TypeAlias = list[AttnMetadataDict] | AttnMetadataDict
@@ -6105,6 +6107,8 @@ def profile_cudagraph_memory(self) -> int:

@instrument(span_name="Capture model")
def capture_model(self) -> int:
self._reserve_turboquant_decode_workspace()

if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
logger.warning(
"Skipping CUDA graph capture. To turn on CUDA graph capture, "
@@ -6194,6 +6198,58 @@ def capture_model(self) -> int:
)
return cuda_graph_size

def _reserve_turboquant_decode_workspace(self) -> None:
if not self.cache_config.cache_dtype.startswith("turboquant_"):
return
if not self.attn_groups:
return

max_num_reqs = self.scheduler_config.max_num_seqs
max_num_tokens = self.scheduler_config.max_num_batched_tokens
max_model_len = self.model_config.max_model_len
num_heads = self.model_config.get_num_attention_heads(self.parallel_config)
num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
head_size = self.model_config.get_head_size()
max_num_splits = (
self.vllm_config.attention_config.tq_max_kv_splits_for_cuda_graph
)

for groups in self.attn_groups:
for group in groups:
if group.backend.get_name() != "TURBOQUANT":
continue

current_workspace_manager().get_simultaneous(
(
(max_num_reqs, num_heads, max_num_splits, head_size + 1),
torch.float32,
),
((max_num_reqs, num_heads, head_size), self.dtype),
((max_num_reqs, num_heads), torch.float32),
)
reserve_continuation_prefill = (
self.scheduler_config.enable_chunked_prefill
and max_num_tokens > _TURBOQUANT_CONTINUATION_DECODE_THRESHOLD
)
if reserve_continuation_prefill:
kernel_block_sizes = getattr(self, "_kernel_block_sizes", None)
group_id = getattr(group, "kv_cache_group_id", 0)
block_size = (
kernel_block_sizes[group_id]
if kernel_block_sizes is not None
and group_id < len(kernel_block_sizes)
else self.cache_config.block_size
)
if block_size is not None:
max_cached_len = max(0, max_model_len - 1)
alloc_len = round_up(max_cached_len, block_size)
cache_buf_shape = (1, num_kv_heads, alloc_len, head_size)
current_workspace_manager().get_simultaneous(
(cache_buf_shape, torch.float16),
(cache_buf_shape, torch.float16),
)
return

def _warmup_and_capture(
self,
desc: BatchDescriptor,
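Note on the reservation pass: _reserve_turboquant_decode_workspace requests the same three decode buffers that triton_turboquant_decode_attention now acquires at runtime, so the workspace high-water mark is established before CUDA graph capture (and, presumably, before the workspace is locked; lock_workspace is imported alongside the manager). A back-of-envelope footprint under the configuration used in the tests above (max_num_reqs=16, num_heads=8, head_size=128, max_num_splits=4); illustrative arithmetic only:

import torch
from math import prod

def nbytes(shape: tuple[int, ...], dtype: torch.dtype) -> int:
    # Element count times the per-element size for the given dtype.
    return prod(shape) * torch.empty((), dtype=dtype).element_size()

mid_o = nbytes((16, 8, 4, 129), torch.float32)  # stage-1 partials (+1 LSE slot): 264,192 B
out = nbytes((16, 8, 128), torch.float16)       # final decode output: 32,768 B
lse = nbytes((16, 8), torch.float32)            # per-head log-sum-exp: 512 B
print(f"{mid_o + out + lse:,} bytes")           # 297,472 bytes, roughly 290 KiB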