diff --git a/tests/quantization/test_turboquant.py b/tests/quantization/test_turboquant.py
index b9567195b3a8..87301bc20249 100644
--- a/tests/quantization/test_turboquant.py
+++ b/tests/quantization/test_turboquant.py
@@ -6,6 +6,7 @@
 """
 
 import math
+from types import SimpleNamespace
 
 import pytest
 import torch
@@ -276,6 +277,178 @@ def test_no_hybrid_hints_returns_empty(self):
         assert _get_full_attention_layer_indices(mc) == []
 
 
+class TestTurboQuantDecodeWorkspace:
+    def test_decode_acquires_workspace_when_manager_is_initialized(self, monkeypatch):
+        from vllm.v1.attention.ops import triton_turboquant_decode as decode
+
+        captured = {}
+
+        class FakeWorkspaceManager:
+            def get_simultaneous(self, *shapes_and_dtypes):
+                captured["shapes_and_dtypes"] = shapes_and_dtypes
+                return (
+                    torch.empty(2, 8, 4, 129),
+                    torch.empty(2, 8, 128, dtype=torch.float16),
+                    torch.empty(2, 8),
+                )
+
+        monkeypatch.setattr(decode, "_get_layout", lambda *args: {"unused": True})
+        monkeypatch.setattr(
+            "vllm.v1.worker.workspace.is_workspace_manager_initialized",
+            lambda: True,
+        )
+        monkeypatch.setattr(
+            "vllm.v1.worker.workspace.current_workspace_manager",
+            lambda: FakeWorkspaceManager(),
+        )
+
+        def stop_after_workspace(*args, **kwargs):
+            raise RuntimeError("workspace acquired")
+
+        monkeypatch.setattr(decode, "_use_fp8_e4b15", stop_after_workspace)
+
+        with pytest.raises(RuntimeError, match="workspace acquired"):
+            decode.triton_turboquant_decode_attention(
+                query=torch.empty(2, 8, 128, dtype=torch.float16),
+                kv_cache=torch.empty(1, 16, 8, 102, dtype=torch.uint8),
+                block_table=torch.zeros(2, 1, dtype=torch.int32),
+                seq_lens=torch.ones(2, dtype=torch.int32),
+                Pi=torch.empty(128, 128),
+                centroids=torch.empty(8),
+                scale=1.0,
+                mse_bits=3,
+                key_packed_size=50,
+                value_quant_bits=3,
+                PiT=torch.empty(128, 128),
+                max_num_kv_splits=4,
+            )
+
+        assert captured["shapes_and_dtypes"] == (
+            ((2, 8, 4, 129), torch.float32),
+            ((2, 8, 128), torch.float16),
+            ((2, 8), torch.float32),
+        )
+
+    def test_capture_model_reserves_turboquant_workspace_before_early_return(self):
+        from vllm.config import CUDAGraphMode
+        from vllm.v1.worker.gpu_model_runner import GPUModelRunner
+
+        calls = []
+        runner = GPUModelRunner.__new__(GPUModelRunner)
+        runner.compilation_config = SimpleNamespace(cudagraph_mode=CUDAGraphMode.NONE)
+        runner._reserve_turboquant_decode_workspace = lambda: calls.append("reserve")
+
+        assert runner.capture_model() == 0
+        assert calls == ["reserve"]
+
+    def test_reserve_turboquant_workspace_checks_all_attention_groups(
+        self, monkeypatch
+    ):
+        from vllm.v1.worker.gpu_model_runner import GPUModelRunner
+
+        calls = []
+
+        class FakeWorkspaceManager:
+            def get_simultaneous(self, *shapes_and_dtypes):
+                calls.append(shapes_and_dtypes)
+
+        runner = GPUModelRunner.__new__(GPUModelRunner)
+        runner.cache_config = SimpleNamespace(
+            cache_dtype="turboquant_3bit_nc", block_size=16
+        )
+        runner.scheduler_config = SimpleNamespace(
+            max_num_seqs=16,
+            max_num_batched_tokens=4096,
+            enable_chunked_prefill=True,
+        )
+        runner.parallel_config = SimpleNamespace(tensor_parallel_size=2)
+        runner.dtype = torch.float16
+        runner.model_config = SimpleNamespace(
+            max_model_len=8192,
+            get_num_attention_heads=lambda parallel_config: 8,
+            get_num_kv_heads=lambda parallel_config: 4,
+            get_head_size=lambda: 128,
+        )
+        runner.vllm_config = SimpleNamespace(
+            attention_config=SimpleNamespace(tq_max_kv_splits_for_cuda_graph=4)
+        )
+        flash_group = SimpleNamespace(
+            backend=SimpleNamespace(get_name=lambda: "FLASH_ATTN")
+        )
+        tq_group = SimpleNamespace(backend=SimpleNamespace(get_name=lambda: "TURBOQUANT"))
+        runner.attn_groups = [[flash_group], [tq_group]]
+        tq_group.kv_cache_group_id = 1
+        runner._kernel_block_sizes = [16, 32]
+
+        monkeypatch.setattr(
+            "vllm.v1.worker.gpu_model_runner.current_workspace_manager",
+            lambda: FakeWorkspaceManager(),
+        )
+
+        runner._reserve_turboquant_decode_workspace()
+
+        assert calls == [
+            (
+                ((16, 8, 4, 129), torch.float32),
+                ((16, 8, 128), torch.float16),
+                ((16, 8), torch.float32),
+            ),
+            (
+                ((1, 4, 8192, 128), torch.float16),
+                ((1, 4, 8192, 128), torch.float16),
+            ),
+        ]
+
+    def test_reserve_turboquant_workspace_skips_continuation_prefill_when_disabled(
+        self, monkeypatch
+    ):
+        from vllm.v1.worker.gpu_model_runner import GPUModelRunner
+
+        calls = []
+
+        class FakeWorkspaceManager:
+            def get_simultaneous(self, *shapes_and_dtypes):
+                calls.append(shapes_and_dtypes)
+
+        runner = GPUModelRunner.__new__(GPUModelRunner)
+        runner.cache_config = SimpleNamespace(
+            cache_dtype="turboquant_3bit_nc", block_size=16
+        )
+        runner.scheduler_config = SimpleNamespace(
+            max_num_seqs=16,
+            max_num_batched_tokens=4096,
+            enable_chunked_prefill=False,
+        )
+        runner.parallel_config = SimpleNamespace(tensor_parallel_size=2)
+        runner.dtype = torch.float16
+        runner.model_config = SimpleNamespace(
+            max_model_len=8192,
+            get_num_attention_heads=lambda parallel_config: 8,
+            get_num_kv_heads=lambda parallel_config: 4,
+            get_head_size=lambda: 128,
+        )
+        runner.vllm_config = SimpleNamespace(
+            attention_config=SimpleNamespace(tq_max_kv_splits_for_cuda_graph=4)
+        )
+        runner.attn_groups = [
+            [SimpleNamespace(backend=SimpleNamespace(get_name=lambda: "TURBOQUANT"))]
+        ]
+
+        monkeypatch.setattr(
+            "vllm.v1.worker.gpu_model_runner.current_workspace_manager",
+            lambda: FakeWorkspaceManager(),
+        )
+
+        runner._reserve_turboquant_decode_workspace()
+
+        assert calls == [
+            (
+                ((16, 8, 4, 129), torch.float32),
+                ((16, 8, 128), torch.float16),
+                ((16, 8), torch.float32),
+            )
+        ]
+
 # ============================================================================
 # Centroids tests (CPU-only)
 # ============================================================================
diff --git a/vllm/v1/attention/backends/turboquant_attn.py b/vllm/v1/attention/backends/turboquant_attn.py
index af2d0fb0830f..2eebd4dc4668 100644
--- a/vllm/v1/attention/backends/turboquant_attn.py
+++ b/vllm/v1/attention/backends/turboquant_attn.py
@@ -872,7 +872,6 @@ def _decode_attention(
             mid_o_buf=mid_o_buf,
             output_buf=output_buf,
             lse_buf=lse_buf,
-            buf_holder=layer,
             max_num_kv_splits=self.max_num_kv_splits,
         )
         return result
diff --git a/vllm/v1/attention/ops/triton_turboquant_decode.py b/vllm/v1/attention/ops/triton_turboquant_decode.py
index 3adaf2610d8d..babbf7e1ba24 100644
--- a/vllm/v1/attention/ops/triton_turboquant_decode.py
+++ b/vllm/v1/attention/ops/triton_turboquant_decode.py
@@ -9,7 +9,6 @@
 """
 
 import math
-from typing import Any
 
 import torch
 
@@ -501,7 +500,6 @@ def triton_turboquant_decode_attention(
     mid_o_buf: torch.Tensor | None = None,
     output_buf: torch.Tensor | None = None,
     lse_buf: torch.Tensor | None = None,
-    buf_holder: Any = None,
     max_num_kv_splits: int = 32,  # fixed split count (must be constant for cudagraph)
 ) -> torch.Tensor:
     """Launch fused TQ decode attention (Triton stage1 + stage2).
@@ -529,6 +527,21 @@ def triton_turboquant_decode_attention(
 
     NUM_KV_SPLITS = max_num_kv_splits
 
+    if mid_o_buf is None or output_buf is None or lse_buf is None:
+        from vllm.v1.worker.workspace import (
+            current_workspace_manager,
+            is_workspace_manager_initialized,
+        )
+
+        if is_workspace_manager_initialized():
+            mid_o_buf, output_buf, lse_buf = (
+                current_workspace_manager().get_simultaneous(
+                    ((B, Hq, NUM_KV_SPLITS, D + 1), torch.float32),
+                    ((B, Hq, D), query.dtype),
+                    ((B, Hq), torch.float32),
+                )
+            )
+
     if (
         mid_o_buf is not None
         and mid_o_buf.shape[0] >= B
@@ -544,8 +557,6 @@
             dtype=torch.float32,
             device=device,
         )
-        if buf_holder is not None:
-            buf_holder._tq_mid_o_buf = mid_o
 
     # Stage 1: split-KV tiled attention scoring + value accumulation
     fp8_e4b15 = _use_fp8_e4b15(device.index or 0)
@@ -598,14 +609,10 @@
         output = output_buf[:B, :Hq, :D]
     else:
         output = torch.empty(B, Hq, D, dtype=out_dtype, device=device)
-        if buf_holder is not None:
-            buf_holder._tq_output_buf = output
     if lse_buf is not None and lse_buf.shape[0] >= B:
         lse = lse_buf[:B, :Hq]
     else:
         lse = torch.empty(B, Hq, dtype=torch.float32, device=device)
-        if buf_holder is not None:
-            buf_holder._tq_lse_buf = lse
 
     grid2 = (B, Hq)
     _fwd_kernel_stage2[grid2](
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 4975077c0562..cef809ae0faa 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -200,7 +200,7 @@
     split_attn_metadata,
 )
 from vllm.v1.worker.utils import is_residual_scattered_for_sp
-from vllm.v1.worker.workspace import lock_workspace
+from vllm.v1.worker.workspace import current_workspace_manager, lock_workspace
 
 from .utils import (
     AttentionGroup,
@@ -218,6 +218,8 @@
 
 logger = init_logger(__name__)
 
+_TURBOQUANT_CONTINUATION_DECODE_THRESHOLD = 128
+
 AttnMetadataDict: TypeAlias = dict[str, AttentionMetadata]
 # list when ubatching is enabled
 PerLayerAttnMetadata: TypeAlias = list[AttnMetadataDict] | AttnMetadataDict
@@ -6105,6 +6107,8 @@ def profile_cudagraph_memory(self) -> int:
 
     @instrument(span_name="Capture model")
     def capture_model(self) -> int:
+        self._reserve_turboquant_decode_workspace()
+
         if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
             logger.warning(
                 "Skipping CUDA graph capture. To turn on CUDA graph capture, "
@@ -6194,6 +6198,58 @@ def capture_model(self) -> int:
         )
         return cuda_graph_size
 
+    def _reserve_turboquant_decode_workspace(self) -> None:
+        if not self.cache_config.cache_dtype.startswith("turboquant_"):
+            return
+        if not self.attn_groups:
+            return
+
+        max_num_reqs = self.scheduler_config.max_num_seqs
+        max_num_tokens = self.scheduler_config.max_num_batched_tokens
+        max_model_len = self.model_config.max_model_len
+        num_heads = self.model_config.get_num_attention_heads(self.parallel_config)
+        num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
+        head_size = self.model_config.get_head_size()
+        max_num_splits = (
+            self.vllm_config.attention_config.tq_max_kv_splits_for_cuda_graph
+        )
+
+        for groups in self.attn_groups:
+            for group in groups:
+                if group.backend.get_name() != "TURBOQUANT":
+                    continue
+
+                current_workspace_manager().get_simultaneous(
+                    (
+                        (max_num_reqs, num_heads, max_num_splits, head_size + 1),
+                        torch.float32,
+                    ),
+                    ((max_num_reqs, num_heads, head_size), self.dtype),
+                    ((max_num_reqs, num_heads), torch.float32),
+                )
+                reserve_continuation_prefill = (
+                    self.scheduler_config.enable_chunked_prefill
+                    and max_num_tokens > _TURBOQUANT_CONTINUATION_DECODE_THRESHOLD
+                )
+                if reserve_continuation_prefill:
+                    kernel_block_sizes = getattr(self, "_kernel_block_sizes", None)
+                    group_id = getattr(group, "kv_cache_group_id", 0)
+                    block_size = (
+                        kernel_block_sizes[group_id]
+                        if kernel_block_sizes is not None
+                        and group_id < len(kernel_block_sizes)
+                        else self.cache_config.block_size
+                    )
+                    if block_size is not None:
+                        max_cached_len = max(0, max_model_len - 1)
+                        alloc_len = round_up(max_cached_len, block_size)
+                        cache_buf_shape = (1, num_kv_heads, alloc_len, head_size)
+                        current_workspace_manager().get_simultaneous(
+                            (cache_buf_shape, torch.float16),
+                            (cache_buf_shape, torch.float16),
+                        )
+                return
+
     def _warmup_and_capture(
         self,
         desc: BatchDescriptor,