233 changes: 233 additions & 0 deletions tests/quantization/test_turboquant.py
@@ -6,6 +6,7 @@
"""

import math
from types import SimpleNamespace

import pytest
import torch
@@ -211,6 +212,238 @@ def test_boundary_skip_layers_cap_at_half(self):
assert len(layers) == 8


class TestTurboQuantWorkspaceReservation:
def test_decode_uses_layer_fallback_when_workspace_unavailable(self, monkeypatch):
from vllm.v1.attention.backends.turboquant_attn import (
TurboQuantAttentionImpl,
)

impl = TurboQuantAttentionImpl.__new__(TurboQuantAttentionImpl)
impl.num_heads = 8
impl.head_size = 256
impl.scale = 1.0
impl.max_num_kv_splits = 4
impl.tq_config = SimpleNamespace(
key_mse_bits=4,
key_packed_size=66,
effective_value_quant_bits=4,
key_fp8=False,
norm_correction=True,
)

monkeypatch.setattr(
impl,
"_get_decode_workspace",
lambda batch_size, output_dtype: (None, None, None),
)

captured = {}

def fake_decode_attention(**kwargs):
captured["buf_holder"] = kwargs["buf_holder"]
return torch.empty_like(kwargs["query"])

monkeypatch.setattr(
"vllm.v1.attention.backends.turboquant_attn."
"triton_turboquant_decode_attention",
fake_decode_attention,
)

layer = torch.nn.Module()
query = torch.randn(1, 8, 256)
kv_cache = torch.empty(1, 1, 8, 134, dtype=torch.uint8)
attn_metadata = SimpleNamespace(
block_table=torch.zeros(1, 1, dtype=torch.int32),
seq_lens=torch.ones(1, dtype=torch.int32),
)
Pi = torch.eye(256, dtype=torch.float32)
centroids = torch.linspace(-1.0, 1.0, 16, dtype=torch.float32)

output = impl._decode_attention(
query, kv_cache, attn_metadata, Pi, centroids, PiT=Pi, layer=layer
)

assert captured["buf_holder"] is layer
assert output.shape == query.shape

def test_decode_uses_workspace_without_layer_fallback(self, monkeypatch):
from vllm.v1.attention.backends.turboquant_attn import (
TurboQuantAttentionImpl,
)

impl = TurboQuantAttentionImpl.__new__(TurboQuantAttentionImpl)
impl.num_heads = 8
impl.head_size = 256
impl.scale = 1.0
impl.max_num_kv_splits = 4
impl.tq_config = SimpleNamespace(
key_mse_bits=4,
key_packed_size=66,
effective_value_quant_bits=4,
key_fp8=False,
norm_correction=True,
)

query = torch.randn(2, 8, 256, dtype=torch.float16)
workspace = (
torch.empty(2, 8, 4, 257, dtype=torch.float32),
torch.empty(2, 8, 256, dtype=torch.float16),
torch.empty(2, 8, dtype=torch.float32),
)
monkeypatch.setattr(
impl,
"_get_decode_workspace",
lambda batch_size, output_dtype: workspace,
)

captured = {}

def fake_decode_attention(**kwargs):
captured["buf_holder"] = kwargs["buf_holder"]
captured["mid_o_buf"] = kwargs["mid_o_buf"]
captured["output_buf"] = kwargs["output_buf"]
captured["lse_buf"] = kwargs["lse_buf"]
return torch.empty_like(kwargs["query"])

monkeypatch.setattr(
"vllm.v1.attention.backends.turboquant_attn."
"triton_turboquant_decode_attention",
fake_decode_attention,
)

layer = torch.nn.Module()
kv_cache = torch.empty(1, 1, 8, 134, dtype=torch.uint8)
attn_metadata = SimpleNamespace(
block_table=torch.zeros(2, 1, dtype=torch.int32),
seq_lens=torch.ones(2, dtype=torch.int32),
)
Pi = torch.eye(256, dtype=torch.float32)
centroids = torch.linspace(-1.0, 1.0, 16, dtype=torch.float32)

output = impl._decode_attention(
query, kv_cache, attn_metadata, Pi, centroids, PiT=Pi, layer=layer
)

assert captured["buf_holder"] is None
assert captured["mid_o_buf"] is workspace[0]
assert captured["output_buf"] is workspace[1]
assert captured["lse_buf"] is workspace[2]
assert output.shape == query.shape

def test_reservation_noops_without_workspace_manager(self, default_vllm_config):
from vllm.v1.attention.backends.turboquant_attn import (
reserve_turboquant_decode_workspace,
)

assert (
reserve_turboquant_decode_workspace(
vllm_config=default_vllm_config,
num_heads=8,
head_size=256,
)
is False
)

def test_workspace_reservation_uses_max_not_sum_for_heterogeneous_heads(
self, default_vllm_config, workspace_init
):
from vllm.v1.attention.backends.turboquant_attn import (
_get_turboquant_decode_workspace_shapes,
reserve_turboquant_decode_workspace,
)
from vllm.v1.worker.workspace import current_workspace_manager

if not torch.cuda.is_available():
pytest.skip("CUDA required for workspace manager tests")

default_vllm_config.scheduler_config.max_num_seqs = 8
default_vllm_config.attention_config.tq_max_kv_splits_for_cuda_graph = 4

batch_size = 128
num_heads = 8
splits = default_vllm_config.attention_config.tq_max_kv_splits_for_cuda_graph

reserve_turboquant_decode_workspace(
vllm_config=default_vllm_config,
num_heads=num_heads,
head_size=256,
)
workspace_bytes_256 = current_workspace_manager()._current_workspaces[0].numel()
raw_bytes_256 = sum(
math.prod(shape) * dtype.itemsize
for shape, dtype in _get_turboquant_decode_workspace_shapes(
batch_size=batch_size,
num_heads=num_heads,
head_size=256,
max_num_kv_splits=splits,
)
)

reserve_turboquant_decode_workspace(
vllm_config=default_vllm_config,
num_heads=num_heads,
head_size=512,
)
workspace_bytes_512 = current_workspace_manager()._current_workspaces[0].numel()
raw_bytes_512 = sum(
math.prod(shape) * dtype.itemsize
for shape, dtype in _get_turboquant_decode_workspace_shapes(
batch_size=batch_size,
num_heads=num_heads,
head_size=512,
max_num_kv_splits=splits,
)
)

assert workspace_bytes_256 >= raw_bytes_256
assert workspace_bytes_512 >= raw_bytes_512
assert workspace_bytes_512 < raw_bytes_256 + raw_bytes_512

def test_workspace_acquire_after_lock_no_growth(
self, default_vllm_config, workspace_init
):
from vllm.v1.attention.backends.turboquant_attn import (
_get_turboquant_decode_workspace_shapes,
reserve_turboquant_decode_workspace,
)
from vllm.v1.worker.workspace import (
current_workspace_manager,
lock_workspace,
)

if not torch.cuda.is_available():
pytest.skip("CUDA required for workspace manager tests")

default_vllm_config.scheduler_config.max_num_seqs = 8
default_vllm_config.attention_config.tq_max_kv_splits_for_cuda_graph = 4

reserve_turboquant_decode_workspace(
vllm_config=default_vllm_config,
num_heads=8,
head_size=256,
)
ws_before = current_workspace_manager()._current_workspaces[0].numel()

lock_workspace()

shapes = _get_turboquant_decode_workspace_shapes(
batch_size=1,
num_heads=8,
head_size=256,
max_num_kv_splits=(
default_vllm_config.attention_config.tq_max_kv_splits_for_cuda_graph
),
output_dtype=torch.float16,
)
mid_o, output, lse = current_workspace_manager().get_simultaneous(*shapes)
ws_after = current_workspace_manager()._current_workspaces[0].numel()

assert ws_before == ws_after, "workspace grew after lock"
assert mid_o.shape == (1, 8, 4, 257)
assert output.shape == (1, 8, 256)
assert lse.shape == (1, 8)


class TestHybridAttentionIndices:
"""Regression tests for boundary protection on hybrid models.

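Reviewer note on the max-not-sum assertion in `test_workspace_reservation_uses_max_not_sum_for_heterogeneous_heads`: a back-of-the-envelope sketch of the byte math follows. It assumes float32 views, `batch_size = max(max_num_seqs, 128) = 128`, `num_heads = 8`, and `max_num_kv_splits = 4`, exactly as the test configures them; the workspace-manager internals (a single buffer grown to the largest simultaneous request) are an assumption for illustration, not something visible in this diff.

# Sketch only: byte counts behind the max-not-sum assertion above.
import math

import torch

def raw_bytes(head_size: int) -> int:
    # Same three views as _get_turboquant_decode_workspace_shapes with the
    # default float32 output dtype.
    shapes = (
        ((128, 8, 4, head_size + 1), torch.float32),  # mid_o (per-split partials)
        ((128, 8, head_size), torch.float32),         # output
        ((128, 8), torch.float32),                    # lse
    )
    return sum(math.prod(shape) * dtype.itemsize for shape, dtype in shapes)

# head_size=256 -> 5,263,360 bytes (~5.0 MiB); head_size=512 -> 10,506,240 bytes (~10.0 MiB).
# A manager that grows to the largest request stays near ~10 MiB, comfortably below
# the ~15 MiB a sum-of-reservations policy would need, which is what the final
# `workspace_bytes_512 < raw_bytes_256 + raw_bytes_512` assertion checks.
print(raw_bytes(256), raw_bytes(512))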
97 changes: 85 additions & 12 deletions vllm/v1/attention/backends/turboquant_attn.py
@@ -69,6 +69,54 @@
_CONTINUATION_DECODE_THRESHOLD = 128


def _get_turboquant_decode_workspace_shapes(
batch_size: int,
num_heads: int,
head_size: int,
max_num_kv_splits: int,
output_dtype: torch.dtype = torch.float32,
) -> tuple[tuple[tuple[int, ...], torch.dtype], ...]:
"""Workspace views required by one TurboQuant decode kernel call."""
return (
((batch_size, num_heads, max_num_kv_splits, head_size + 1), torch.float32),
((batch_size, num_heads, head_size), output_dtype),
((batch_size, num_heads), torch.float32),
)


def reserve_turboquant_decode_workspace(
*,
vllm_config: Any,
num_heads: int,
head_size: int,
) -> bool:
"""Pre-grow WorkspaceManager for TurboQuant decode scratch buffers.

WorkspaceManager has no separate reservation API; the supported way to
size it is to request the largest views during model initialization or
profiling, before ``lock_workspace()`` freezes workspace size. The output
buffer is reserved as float32 so runtime requests with fp16/bf16 query
dtypes cannot exceed the reservation.
"""
if not is_workspace_manager_initialized():
return False

batch_size = max(
vllm_config.scheduler_config.max_num_seqs,
_CONTINUATION_DECODE_THRESHOLD,
)
max_num_kv_splits = vllm_config.attention_config.tq_max_kv_splits_for_cuda_graph
current_workspace_manager().get_simultaneous(
*_get_turboquant_decode_workspace_shapes(
batch_size=batch_size,
num_heads=num_heads,
head_size=head_size,
max_num_kv_splits=max_num_kv_splits,
)
)
return True


def _build_hadamard(d: int, device_str: str) -> torch.Tensor:
"""Orthonormal Hadamard matrix (Sylvester construction), cached per (d, device).

@@ -294,6 +342,11 @@ def __init__(
self.max_num_kv_splits = (
vllm_config.attention_config.tq_max_kv_splits_for_cuda_graph
)
reserve_turboquant_decode_workspace(
vllm_config=vllm_config,
num_heads=self.num_heads,
head_size=self.head_size,
)

def _flash_attn_varlen(
self,
@@ -360,6 +413,26 @@ def _ensure_on_device(self, layer, device):
layer._tq_midpoints = (c_sorted[:-1] + c_sorted[1:]) / 2
layer._tq_cached = True

def _get_decode_workspace(
self,
batch_size: int,
output_dtype: torch.dtype,
) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
if not is_workspace_manager_initialized():
return None, None, None

return tuple(
current_workspace_manager().get_simultaneous(
*_get_turboquant_decode_workspace_shapes(
batch_size=batch_size,
num_heads=self.num_heads,
head_size=self.head_size,
max_num_kv_splits=self.max_num_kv_splits,
output_dtype=output_dtype,
)
)
)

def do_kv_cache_update(
self,
layer: torch.nn.Module,
@@ -870,18 +943,18 @@ def _decode_attention(
# Falls back to kernel-internal allocation if workspace unavailable.
B = query.shape[0]
D = self.head_size
S = self.max_num_kv_splits
Hq = self.num_heads
mid_o_buf = output_buf = lse_buf = None
if is_workspace_manager_initialized():
# output_buf in query dtype — matches the in-kernel fp16 cast in stage2.
mid_o_buf, output_buf, lse_buf = (
current_workspace_manager().get_simultaneous(
((B, Hq, S, D + 1), torch.float32),
((B, Hq, D), query.dtype),
((B, Hq), torch.float32),
)
)
assert query.shape[-1] == D
assert Hq == query.shape[1]

# output_buf in query dtype — matches the in-kernel cast in stage2.
mid_o_buf, output_buf, lse_buf = self._get_decode_workspace(B, query.dtype)
buf_holder = (
layer
if layer is not None
and (mid_o_buf is None or output_buf is None or lse_buf is None)
else None
)

result = triton_turboquant_decode_attention(
query=query,
@@ -900,7 +973,7 @@
mid_o_buf=mid_o_buf,
output_buf=output_buf,
lse_buf=lse_buf,
buf_holder=layer,
buf_holder=buf_holder,
max_num_kv_splits=self.max_num_kv_splits,
)
return result
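For context on the reservation docstring above (`reserve_turboquant_decode_workspace`), a minimal lifecycle sketch follows. The reserve and lock calls are the ones added or exercised in this PR; the `init_and_lock` wrapper and the point at which a worker would call it are assumptions for illustration only.

# Assumed wiring, not part of this diff: grow, then freeze, then acquire.
from vllm.v1.attention.backends.turboquant_attn import (
    reserve_turboquant_decode_workspace,
)
from vllm.v1.worker.workspace import lock_workspace

def init_and_lock(vllm_config, num_heads: int, head_size: int) -> None:
    # Growth phase: sized for max(max_num_seqs, continuation threshold) decode
    # requests, with the output view reserved as float32 so later fp16/bf16
    # queries can never need more bytes than were reserved.
    reserve_turboquant_decode_workspace(
        vllm_config=vllm_config,
        num_heads=num_heads,
        head_size=head_size,
    )
    # Freeze phase: after this, _get_decode_workspace() requests must fit the
    # reservation, which test_workspace_acquire_after_lock_no_growth verifies.
    lock_workspace()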