diff --git a/tests/quantization/test_turboquant.py b/tests/quantization/test_turboquant.py
index b9567195b3a8..f91ed1081758 100644
--- a/tests/quantization/test_turboquant.py
+++ b/tests/quantization/test_turboquant.py
@@ -6,6 +6,7 @@
 """
 
 import math
+from types import SimpleNamespace
 
 import pytest
 import torch
@@ -211,6 +212,238 @@ def test_boundary_skip_layers_cap_at_half(self):
         assert len(layers) == 8
 
 
+class TestTurboQuantWorkspaceReservation:
+    def test_decode_uses_layer_fallback_when_workspace_unavailable(self, monkeypatch):
+        from vllm.v1.attention.backends.turboquant_attn import (
+            TurboQuantAttentionImpl,
+        )
+
+        impl = TurboQuantAttentionImpl.__new__(TurboQuantAttentionImpl)
+        impl.num_heads = 8
+        impl.head_size = 256
+        impl.scale = 1.0
+        impl.max_num_kv_splits = 4
+        impl.tq_config = SimpleNamespace(
+            key_mse_bits=4,
+            key_packed_size=66,
+            effective_value_quant_bits=4,
+            key_fp8=False,
+            norm_correction=True,
+        )
+
+        monkeypatch.setattr(
+            impl,
+            "_get_decode_workspace",
+            lambda batch_size, output_dtype: (None, None, None),
+        )
+
+        captured = {}
+
+        def fake_decode_attention(**kwargs):
+            captured["buf_holder"] = kwargs["buf_holder"]
+            return torch.empty_like(kwargs["query"])
+
+        monkeypatch.setattr(
+            "vllm.v1.attention.backends.turboquant_attn."
+            "triton_turboquant_decode_attention",
+            fake_decode_attention,
+        )
+
+        layer = torch.nn.Module()
+        query = torch.randn(1, 8, 256)
+        kv_cache = torch.empty(1, 1, 8, 134, dtype=torch.uint8)
+        attn_metadata = SimpleNamespace(
+            block_table=torch.zeros(1, 1, dtype=torch.int32),
+            seq_lens=torch.ones(1, dtype=torch.int32),
+        )
+        Pi = torch.eye(256, dtype=torch.float32)
+        centroids = torch.linspace(-1.0, 1.0, 16, dtype=torch.float32)
+
+        output = impl._decode_attention(
+            query, kv_cache, attn_metadata, Pi, centroids, PiT=Pi, layer=layer
+        )
+
+        assert captured["buf_holder"] is layer
+        assert output.shape == query.shape
+
+    def test_decode_uses_workspace_without_layer_fallback(self, monkeypatch):
+        from vllm.v1.attention.backends.turboquant_attn import (
+            TurboQuantAttentionImpl,
+        )
+
+        impl = TurboQuantAttentionImpl.__new__(TurboQuantAttentionImpl)
+        impl.num_heads = 8
+        impl.head_size = 256
+        impl.scale = 1.0
+        impl.max_num_kv_splits = 4
+        impl.tq_config = SimpleNamespace(
+            key_mse_bits=4,
+            key_packed_size=66,
+            effective_value_quant_bits=4,
+            key_fp8=False,
+            norm_correction=True,
+        )
+
+        query = torch.randn(2, 8, 256, dtype=torch.float16)
+        workspace = (
+            torch.empty(2, 8, 4, 257, dtype=torch.float32),
+            torch.empty(2, 8, 256, dtype=torch.float16),
+            torch.empty(2, 8, dtype=torch.float32),
+        )
+        monkeypatch.setattr(
+            impl,
+            "_get_decode_workspace",
+            lambda batch_size, output_dtype: workspace,
+        )
+
+        captured = {}
+
+        def fake_decode_attention(**kwargs):
+            captured["buf_holder"] = kwargs["buf_holder"]
+            captured["mid_o_buf"] = kwargs["mid_o_buf"]
+            captured["output_buf"] = kwargs["output_buf"]
+            captured["lse_buf"] = kwargs["lse_buf"]
+            return torch.empty_like(kwargs["query"])
+
+        monkeypatch.setattr(
+            "vllm.v1.attention.backends.turboquant_attn."
+            "triton_turboquant_decode_attention",
+            fake_decode_attention,
+        )
+
+        layer = torch.nn.Module()
+        kv_cache = torch.empty(1, 1, 8, 134, dtype=torch.uint8)
+        attn_metadata = SimpleNamespace(
+            block_table=torch.zeros(2, 1, dtype=torch.int32),
+            seq_lens=torch.ones(2, dtype=torch.int32),
+        )
+        Pi = torch.eye(256, dtype=torch.float32)
+        centroids = torch.linspace(-1.0, 1.0, 16, dtype=torch.float32)
+
+        output = impl._decode_attention(
+            query, kv_cache, attn_metadata, Pi, centroids, PiT=Pi, layer=layer
+        )
+
+        assert captured["buf_holder"] is None
+        assert captured["mid_o_buf"] is workspace[0]
+        assert captured["output_buf"] is workspace[1]
+        assert captured["lse_buf"] is workspace[2]
+        assert output.shape == query.shape
+
+    def test_reservation_noops_without_workspace_manager(self, default_vllm_config):
+        from vllm.v1.attention.backends.turboquant_attn import (
+            reserve_turboquant_decode_workspace,
+        )
+
+        assert (
+            reserve_turboquant_decode_workspace(
+                vllm_config=default_vllm_config,
+                num_heads=8,
+                head_size=256,
+            )
+            is False
+        )
+
+    def test_workspace_reservation_uses_max_not_sum_for_heterogeneous_heads(
+        self, default_vllm_config, workspace_init
+    ):
+        from vllm.v1.attention.backends.turboquant_attn import (
+            _get_turboquant_decode_workspace_shapes,
+            reserve_turboquant_decode_workspace,
+        )
+        from vllm.v1.worker.workspace import current_workspace_manager
+
+        if not torch.cuda.is_available():
+            pytest.skip("CUDA required for workspace manager tests")
+
+        default_vllm_config.scheduler_config.max_num_seqs = 8
+        default_vllm_config.attention_config.tq_max_kv_splits_for_cuda_graph = 4
+
+        batch_size = 128
+        num_heads = 8
+        splits = default_vllm_config.attention_config.tq_max_kv_splits_for_cuda_graph
+
+        reserve_turboquant_decode_workspace(
+            vllm_config=default_vllm_config,
+            num_heads=num_heads,
+            head_size=256,
+        )
+        workspace_bytes_256 = current_workspace_manager()._current_workspaces[0].numel()
+        raw_bytes_256 = sum(
+            math.prod(shape) * dtype.itemsize
+            for shape, dtype in _get_turboquant_decode_workspace_shapes(
+                batch_size=batch_size,
+                num_heads=num_heads,
+                head_size=256,
+                max_num_kv_splits=splits,
+            )
+        )
+
+        reserve_turboquant_decode_workspace(
+            vllm_config=default_vllm_config,
+            num_heads=num_heads,
+            head_size=512,
+        )
+        workspace_bytes_512 = current_workspace_manager()._current_workspaces[0].numel()
+        raw_bytes_512 = sum(
+            math.prod(shape) * dtype.itemsize
+            for shape, dtype in _get_turboquant_decode_workspace_shapes(
+                batch_size=batch_size,
+                num_heads=num_heads,
+                head_size=512,
+                max_num_kv_splits=splits,
+            )
+        )
+
+        assert workspace_bytes_256 >= raw_bytes_256
+        assert workspace_bytes_512 >= raw_bytes_512
+        assert workspace_bytes_512 < raw_bytes_256 + raw_bytes_512
+
+    def test_workspace_acquire_after_lock_no_growth(
+        self, default_vllm_config, workspace_init
+    ):
+        from vllm.v1.attention.backends.turboquant_attn import (
+            _get_turboquant_decode_workspace_shapes,
+            reserve_turboquant_decode_workspace,
+        )
+        from vllm.v1.worker.workspace import (
+            current_workspace_manager,
+            lock_workspace,
+        )
+
+        if not torch.cuda.is_available():
+            pytest.skip("CUDA required for workspace manager tests")
+
+        default_vllm_config.scheduler_config.max_num_seqs = 8
+        default_vllm_config.attention_config.tq_max_kv_splits_for_cuda_graph = 4
+
+        reserve_turboquant_decode_workspace(
+            vllm_config=default_vllm_config,
+            num_heads=8,
+            head_size=256,
+        )
+        ws_before = current_workspace_manager()._current_workspaces[0].numel()
+
+        lock_workspace()
+
+        shapes = _get_turboquant_decode_workspace_shapes(
+            batch_size=1,
+            num_heads=8,
+            head_size=256,
+            max_num_kv_splits=(
+                default_vllm_config.attention_config.tq_max_kv_splits_for_cuda_graph
+            ),
+            output_dtype=torch.float16,
+        )
+        mid_o, output, lse = current_workspace_manager().get_simultaneous(*shapes)
+        ws_after = current_workspace_manager()._current_workspaces[0].numel()
+
+        assert ws_before == ws_after, "workspace grew after lock"
+        assert mid_o.shape == (1, 8, 4, 257)
+        assert output.shape == (1, 8, 256)
+        assert lse.shape == (1, 8)
+
+
 class TestHybridAttentionIndices:
     """Regression tests for boundary protection on hybrid models.
 
diff --git a/vllm/v1/attention/backends/turboquant_attn.py b/vllm/v1/attention/backends/turboquant_attn.py
index 53684b4360f7..a252b4546ee6 100644
--- a/vllm/v1/attention/backends/turboquant_attn.py
+++ b/vllm/v1/attention/backends/turboquant_attn.py
@@ -69,6 +69,54 @@
 _CONTINUATION_DECODE_THRESHOLD = 128
 
 
+def _get_turboquant_decode_workspace_shapes(
+    batch_size: int,
+    num_heads: int,
+    head_size: int,
+    max_num_kv_splits: int,
+    output_dtype: torch.dtype = torch.float32,
+) -> tuple[tuple[tuple[int, ...], torch.dtype], ...]:
+    """Workspace views required by one TurboQuant decode kernel call."""
+    return (
+        ((batch_size, num_heads, max_num_kv_splits, head_size + 1), torch.float32),
+        ((batch_size, num_heads, head_size), output_dtype),
+        ((batch_size, num_heads), torch.float32),
+    )
+
+
+def reserve_turboquant_decode_workspace(
+    *,
+    vllm_config: Any,
+    num_heads: int,
+    head_size: int,
+) -> bool:
+    """Pre-grow WorkspaceManager for TurboQuant decode scratch buffers.
+
+    WorkspaceManager has no separate reservation API; the supported way to
+    size it is to request the largest views during model initialization or
+    profiling, before ``lock_workspace()`` freezes workspace size. The output
+    buffer is reserved as float32 so runtime requests with fp16/bf16 query
+    dtypes cannot exceed the reservation.
+    """
+    if not is_workspace_manager_initialized():
+        return False
+
+    batch_size = max(
+        vllm_config.scheduler_config.max_num_seqs,
+        _CONTINUATION_DECODE_THRESHOLD,
+    )
+    max_num_kv_splits = vllm_config.attention_config.tq_max_kv_splits_for_cuda_graph
+    current_workspace_manager().get_simultaneous(
+        *_get_turboquant_decode_workspace_shapes(
+            batch_size=batch_size,
+            num_heads=num_heads,
+            head_size=head_size,
+            max_num_kv_splits=max_num_kv_splits,
+        )
+    )
+    return True
+
+
 def _build_hadamard(d: int, device_str: str) -> torch.Tensor:
     """Orthonormal Hadamard matrix (Sylvester construction), cached per
     (d, device).
@@ -294,6 +342,11 @@ def __init__(
         self.max_num_kv_splits = (
            vllm_config.attention_config.tq_max_kv_splits_for_cuda_graph
         )
+        reserve_turboquant_decode_workspace(
+            vllm_config=vllm_config,
+            num_heads=self.num_heads,
+            head_size=self.head_size,
+        )
 
     def _flash_attn_varlen(
         self,
@@ -360,6 +413,26 @@ def _ensure_on_device(self, layer, device):
         layer._tq_midpoints = (c_sorted[:-1] + c_sorted[1:]) / 2
         layer._tq_cached = True
 
+    def _get_decode_workspace(
+        self,
+        batch_size: int,
+        output_dtype: torch.dtype,
+    ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
+        if not is_workspace_manager_initialized():
+            return None, None, None
+
+        return tuple(
+            current_workspace_manager().get_simultaneous(
+                *_get_turboquant_decode_workspace_shapes(
+                    batch_size=batch_size,
+                    num_heads=self.num_heads,
+                    head_size=self.head_size,
+                    max_num_kv_splits=self.max_num_kv_splits,
+                    output_dtype=output_dtype,
+                )
+            )
+        )
+
     def do_kv_cache_update(
         self,
         layer: torch.nn.Module,
@@ -870,18 +943,18 @@ def _decode_attention(
         # Falls back to kernel-internal allocation if workspace unavailable.
         B = query.shape[0]
         D = self.head_size
-        S = self.max_num_kv_splits
         Hq = self.num_heads
-        mid_o_buf = output_buf = lse_buf = None
-        if is_workspace_manager_initialized():
-            # output_buf in query dtype — matches the in-kernel fp16 cast in stage2.
-            mid_o_buf, output_buf, lse_buf = (
-                current_workspace_manager().get_simultaneous(
-                    ((B, Hq, S, D + 1), torch.float32),
-                    ((B, Hq, D), query.dtype),
-                    ((B, Hq), torch.float32),
-                )
-            )
+        assert query.shape[-1] == D
+        assert Hq == query.shape[1]
+
+        # output_buf in query dtype — matches the in-kernel cast in stage2.
+        mid_o_buf, output_buf, lse_buf = self._get_decode_workspace(B, query.dtype)
+        buf_holder = (
+            layer
+            if layer is not None
+            and (mid_o_buf is None or output_buf is None or lse_buf is None)
+            else None
+        )
 
         result = triton_turboquant_decode_attention(
             query=query,
@@ -900,7 +973,7 @@ def _decode_attention(
             mid_o_buf=mid_o_buf,
             output_buf=output_buf,
             lse_buf=lse_buf,
-            buf_holder=layer,
+            buf_holder=buf_holder,
             max_num_kv_splits=self.max_num_kv_splits,
         )
         return result