48 changes: 0 additions & 48 deletions vllm/model_executor/layers/attention/attention.py
@@ -386,10 +386,6 @@ def __init__(
# Initialize KV cache quantization attributes
_init_kv_cache_quant(self, quant_config, prefix)

# Initialize TurboQuant buffers (Pi, S, centroids) if tq cache dtype
if kv_cache_dtype.startswith("turboquant_"):
self._init_turboquant_buffers(kv_cache_dtype, head_size, prefix)

# for attn backends supporting query quantization
self.query_quant = None
if (
@@ -410,50 +406,6 @@ def __init__(
else GroupShape.PER_TENSOR,
)

def _init_turboquant_buffers(
self, cache_dtype: str, head_size: int, prefix: str
) -> None:
"""Initialize TurboQuant centroids for Lloyd-Max quantization."""
from vllm.model_executor.layers.quantization.turboquant.centroids import (
get_centroids,
)
from vllm.model_executor.layers.quantization.turboquant.config import (
TurboQuantConfig,
)

tq_config = TurboQuantConfig.from_cache_dtype(cache_dtype, head_size)

self.register_buffer(
"_tq_centroids",
get_centroids(head_size, tq_config.centroid_bits),
)
self._tq_config = tq_config

# Pre-allocate decode intermediate buffers so model.to(device) moves
# them to GPU *before* the memory profiler runs. Without this the
# profiler gives all free memory to KV cache blocks and the first
# decode OOMs when these buffers are lazily allocated.
_vllm_cfg = get_current_vllm_config()
B = _vllm_cfg.scheduler_config.max_num_seqs
Hq = self.num_heads
S = _vllm_cfg.attention_config.tq_max_kv_splits_for_cuda_graph
D = head_size
self.register_buffer(
"_tq_mid_o_buf",
torch.empty(B, Hq, S, D + 1, dtype=torch.float32),
persistent=False,
)
self.register_buffer(
"_tq_output_buf",
torch.empty(B, Hq, D, dtype=torch.float32),
persistent=False,
)
self.register_buffer(
"_tq_lse_buf",
torch.empty(B, Hq, dtype=torch.float32),
persistent=False,
)

def forward(
self,
query: torch.Tensor,
123 changes: 83 additions & 40 deletions vllm/v1/attention/backends/turboquant_attn.py
@@ -26,6 +26,9 @@

from vllm.config import get_current_vllm_config
from vllm.config.cache import CacheDType
from vllm.model_executor.layers.quantization.turboquant.centroids import (
get_centroids,
)
from vllm.triton_utils import triton
from vllm.v1.attention.backend import (
AttentionBackend,
@@ -49,6 +52,10 @@
triton_turboquant_decode_attention,
)
from vllm.v1.attention.ops.triton_turboquant_store import triton_turboquant_store
from vllm.v1.worker.workspace import (
current_workspace_manager,
is_workspace_manager_initialized,
)

_HAS_FLASH_ATTN = is_flash_attn_varlen_func_available()
if _HAS_FLASH_ATTN:
@@ -335,9 +342,15 @@ def _ensure_on_device(self, layer, device):
H = _build_hadamard(D, str(device))
layer._tq_PiT = H
layer._tq_Pi = H
# fp16 copy for rotation in continuation prefill path
layer._tq_Pi_half = H.to(torch.float16)

c = layer._tq_centroids.to(device=device, dtype=torch.float32)
c_sorted, _ = c.sort()
# Centroids for Lloyd-Max quantization.
layer._tq_centroids = get_centroids(D, self.tq_config.centroid_bits).to(
device=device, dtype=torch.float32
)

c_sorted, _ = layer._tq_centroids.sort()
layer._tq_midpoints = (c_sorted[:-1] + c_sorted[1:]) / 2
layer._tq_cached = True

@@ -572,7 +585,17 @@ def _prefill_attention(

# Pre-allocate cu_seqlens for single-request flash_attn calls
# to avoid per-request host→device tensor creation.
_cu_2 = torch.zeros(2, device=query.device, dtype=torch.int32)
if not hasattr(self, "_cu_2"):
self._cu_2 = torch.zeros(2, device=query.device, dtype=torch.int32)
# Cache arange on self (avoid per-call kernel launch).
_max_seq = attn_metadata.max_seq_len
_ac: torch.Tensor | None = getattr(self, "_arange_cache", None)
if _ac is None or _ac.shape[0] <= _max_seq:
_ac = torch.arange(
0, _max_seq + 1, device=query.device, dtype=attn_metadata.seq_lens.dtype
)
self._arange_cache = _ac
_arange_cache: torch.Tensor = _ac

for i in range(num_reqs):
q_start = qsl[i]
@@ -589,8 +612,8 @@
if q_len == seq_len:
# First-chunk prefill: all K/V are in the current batch.
if _HAS_FLASH_ATTN:
_cu_2[1] = q_len
cu = _cu_2
self._cu_2[1] = q_len
cu = self._cu_2
out = self._flash_attn_varlen(
q=q_seq,
k=k_seq,
@@ -622,12 +645,8 @@
if q_len <= _CONTINUATION_DECODE_THRESHOLD:
# Fast path: treat each query as a decode request
# with incremental seq_lens for causal masking.
synth_seq_lens = torch.arange(
cached_len + 1,
seq_len + 1,
device=query.device,
dtype=attn_metadata.seq_lens.dtype,
)
# Slice from pre-built arange (no kernel launch)
synth_seq_lens = _arange_cache[cached_len + 1 : seq_len + 1]
synth_bt = attn_metadata.block_table[i : i + 1].expand(q_len, -1)
out = triton_turboquant_decode_attention(
query=q_seq,
@@ -695,16 +714,17 @@ def _continuation_prefill(
# Reuse cached buffers to avoid per-call allocation (~16MB at 8K).
alloc_len = math.ceil(cached_len / block_size) * block_size
buf_shape = (1, Hk, alloc_len, D)
k_buf = getattr(layer, "_tq_k_dequant_buf", None)
if k_buf is None or k_buf.shape[2] < alloc_len:
k_buf = torch.empty(buf_shape, dtype=torch.float16, device=device)
v_buf = torch.empty(buf_shape, dtype=torch.float16, device=device)
layer._tq_k_dequant_buf = k_buf
layer._tq_v_dequant_buf = v_buf
else:
v_buf = layer._tq_v_dequant_buf
k_cached = k_buf[:, :, :alloc_len, :].zero_()
v_cached = v_buf[:, :, :alloc_len, :].zero_()
# Use WorkspaceManager for dequant buffers.
# Shared across all layers — saves 60× memory at long context.
# Required for CUDA Graph capture (per-layer growth incompatible with CG).
k_buf, v_buf = current_workspace_manager().get_simultaneous(
(buf_shape, torch.float16),
(buf_shape, torch.float16),
)
# Skip .zero_() — kernel writes all positions up to cached_len,
# and we only read [:cached_len] afterwards.
k_cached = k_buf[:, :, :alloc_len, :]
v_cached = v_buf[:, :, :alloc_len, :]

grid = (alloc_len, 1 * Hk)
_tq_full_dequant_kv[grid](
@@ -740,29 +760,41 @@

# Inverse-rotate MSE keys back to original space
if not self.tq_config.key_fp8:
k_flat = k_cached[0, :, :cached_len, :].reshape(-1, D).float()
k_flat = k_flat @ Pi
k_cached_trim = (
k_flat.to(torch.float16).reshape(Hk, cached_len, D).transpose(0, 1)
) # (cached_len, Hk, D)
# fp16 matmul for rotation (2× less bandwidth, uses fp16 tensor cores)
Pi_half = layer._tq_Pi_half
k_flat = k_cached[0, :, :cached_len, :].reshape(-1, D)
k_flat = k_flat @ Pi_half
k_cached_trim = k_flat.reshape(Hk, cached_len, D).transpose(
0, 1
) # (cached_len, Hk, D) — already fp16
else:
k_cached_trim = (
k_cached[0, :, :cached_len, :].transpose(0, 1).contiguous()
k_cached_trim = k_cached[0, :, :cached_len, :].transpose(
0, 1
) # (cached_len, Hk, D)

v_cached_trim = (
v_cached[0, :, :cached_len, :].transpose(0, 1).contiguous()
) # (cached_len, Hk, D)
# Skip .contiguous() — the copy into k_full/v_full handles layout
v_cached_trim = v_cached[0, :, :cached_len, :].transpose(0, 1)

# Concatenate cached + current chunk K/V (match query dtype)
# Pre-allocate full K/V buffer, copy into slices (no cat alloc)
qdtype = query.dtype
k_full = torch.cat([k_cached_trim.to(qdtype), key_chunk], dim=0)
v_full = torch.cat([v_cached_trim.to(qdtype), val_chunk], dim=0)
k_full = torch.empty(seq_len, Hk, D, dtype=qdtype, device=device)
v_full = torch.empty(seq_len, Hk, D, dtype=qdtype, device=device)
k_full[:cached_len] = k_cached_trim.to(qdtype)
k_full[cached_len:] = key_chunk
v_full[:cached_len] = v_cached_trim.to(qdtype)
v_full[cached_len:] = val_chunk
Comment on lines +781 to +786
Contributor (severity: high)

While removing torch.cat is a great improvement, allocating k_full and v_full via torch.empty inside the request loop still triggers an 'allocation storm' and increases peak memory usage, especially at long context (e.g., 1M tokens).

Additionally, the explicit .to(qdtype) on lines 783 and 785 creates redundant temporary tensors. If qdtype is bfloat16 (common for many models), this allocates an extra 8GB per tensor at 1M context before copying into the destination.

Consider using current_workspace_manager().get_simultaneous() to allocate k_full and v_full from the shared pool, and use .copy_() for in-place casting to avoid temporary allocations.

        # Concatenate cached + current chunk K/V (match query dtype)
        # Use WorkspaceManager to avoid per-request allocations and fragmentation.
        qdtype = query.dtype
        k_full, v_full = current_workspace_manager().get_simultaneous(
            ((seq_len, Hk, D), qdtype),
            ((seq_len, Hk, D), qdtype),
        )
        k_full[:cached_len].copy_(k_cached_trim)
        k_full[cached_len:].copy_(key_chunk)
        v_full[:cached_len].copy_(v_cached_trim)
        v_full[cached_len:].copy_(val_chunk)
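
(Aside: a minimal standalone sketch of the `.copy_()` point, in plain PyTorch with illustrative shapes; this is not the vLLM code path, only the dtype-cast behavior being discussed.)

import torch

cached_len, q_len, Hk, D = 4096, 16, 8, 128
seq_len = cached_len + q_len
k_cached_trim = torch.randn(cached_len, Hk, D, dtype=torch.float16)
k_full = torch.empty(seq_len, Hk, D, dtype=torch.bfloat16)

# Materializes a bf16 temporary of the whole cached slice before copying it in:
k_full[:cached_len] = k_cached_trim.to(torch.bfloat16)

# Casts element-wise inside the copy kernel; no temporary tensor is created:
k_full[:cached_len].copy_(k_cached_trim)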


# Attention: q_len queries attending to seq_len K/V with causal mask
if _HAS_FLASH_ATTN:
cu_seqlens_q = torch.tensor([0, q_len], device=device, dtype=torch.int32)
cu_seqlens_k = torch.tensor([0, seq_len], device=device, dtype=torch.int32)
# Reuse pre-allocated cu_seqlens (avoid host→device transfer)
if not hasattr(self, "_cu_2_q"):
self._cu_2_q = torch.zeros(2, device=device, dtype=torch.int32)
self._cu_2_k = torch.zeros(2, device=device, dtype=torch.int32)
self._cu_2_q[1] = q_len
self._cu_2_k[1] = seq_len
cu_seqlens_q = self._cu_2_q
cu_seqlens_k = self._cu_2_k
return self._flash_attn_varlen(
q=query,
k=k_full,
@@ -805,12 +837,23 @@ def _decode_attention(
PiT: torch.Tensor | None = None,
layer: torch.nn.Module | None = None,
) -> torch.Tensor:
# Grab cached decode buffers from the layer (lazily allocated).
# Acquire shared decode scratch buffers from WorkspaceManager.
# Layers execute sequentially so one set of buffers is sufficient.
# Falls back to kernel-internal allocation if workspace unavailable.
B = query.shape[0]
D = self.head_size
S = self.max_num_kv_splits
Hq = self.num_heads
mid_o_buf = output_buf = lse_buf = None
if layer is not None:
mid_o_buf = getattr(layer, "_tq_mid_o_buf", None)
output_buf = getattr(layer, "_tq_output_buf", None)
lse_buf = getattr(layer, "_tq_lse_buf", None)
if is_workspace_manager_initialized():
# output_buf in query dtype — matches the in-kernel fp16 cast in stage2.
mid_o_buf, output_buf, lse_buf = (
current_workspace_manager().get_simultaneous(
((B, Hq, S, D + 1), torch.float32),
((B, Hq, D), query.dtype),
((B, Hq), torch.float32),
)
)

result = triton_turboquant_decode_attention(
query=query,
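(For context on the `current_workspace_manager().get_simultaneous(...)` calls in this file: below is an illustrative sketch of the assumed semantics, a single shared pool per device from which simultaneously-live views are carved so that per-layer and per-request scratch allocations are avoided. It is a toy model only, not the actual `vllm.v1.worker.workspace` implementation; the class name and alignment policy are made up for the example.)

import math

import torch


class SharedWorkspaceSketch:
    """Toy workspace manager: every tensor requested in one
    get_simultaneous() call is a view over a single shared byte pool,
    so repeated calls reuse the same storage instead of allocating
    fresh buffers per layer or per request."""

    def __init__(self, device: torch.device, capacity_bytes: int) -> None:
        self._pool = torch.empty(capacity_bytes, dtype=torch.uint8, device=device)

    def get_simultaneous(self, *specs):
        # Each spec is (shape, dtype); all returned views are valid together.
        views, offset = [], 0
        for shape, dtype in specs:
            itemsize = torch.empty((), dtype=dtype).element_size()
            nbytes = itemsize * math.prod(shape)
            offset = (offset + 255) // 256 * 256  # keep 256-byte alignment
            if offset + nbytes > self._pool.numel():
                raise RuntimeError("workspace pool too small for request")
            flat = self._pool[offset : offset + nbytes].view(dtype)
            views.append(flat.view(shape))
            offset += nbytes
        return tuple(views)


# Usage mirroring the decode path above (shapes are placeholders):
# ws = SharedWorkspaceSketch(torch.device("cuda"), 1 << 30)
# mid_o, out, lse = ws.get_simultaneous(
#     ((8, 32, 16, 129), torch.float32),
#     ((8, 32, 128), torch.float16),
#     ((8, 32), torch.float32),
# )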
6 changes: 5 additions & 1 deletion vllm/v1/attention/ops/triton_decode_attention.py
@@ -551,6 +551,7 @@ def _fwd_kernel_stage2(
NUM_KV_SPLITS: tl.constexpr,
BLOCK_DV: tl.constexpr,
Lv: tl.constexpr,
OUTPUT_FP16: tl.constexpr = 0,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
@@ -587,9 +588,12 @@
e_sum = e_sum * old_scale + exp_logic
e_max = n_e_max

result = acc / e_sum
if OUTPUT_FP16:
result = result.to(tl.float16)
tl.store(
o + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
acc / e_sum,
result,
mask=mask_d,
)
lse_val = e_max + tl.log(e_sum)
13 changes: 10 additions & 3 deletions vllm/v1/attention/ops/triton_turboquant_decode.py
@@ -588,10 +588,16 @@ def triton_turboquant_decode_attention(
)

# Stage 2: Reduce across KV splits
if output_buf is not None and output_buf.shape[0] >= B:
# Output in query dtype — eliminates float16_copy kernel after stage2
out_dtype = query.dtype
if (
output_buf is not None
and output_buf.shape[0] >= B
and output_buf.dtype == out_dtype
):
output = output_buf[:B, :Hq, :D]
else:
output = torch.empty(B, Hq, D, dtype=torch.float32, device=device)
output = torch.empty(B, Hq, D, dtype=out_dtype, device=device)
if buf_holder is not None:
buf_holder._tq_output_buf = output
if lse_buf is not None and lse_buf.shape[0] >= B:
@@ -616,8 +622,9 @@
NUM_KV_SPLITS=NUM_KV_SPLITS,
BLOCK_DV=cfg["BLOCK_D"],
Lv=D,
OUTPUT_FP16=1 if out_dtype == torch.float16 else 0,
num_warps=4,
num_stages=2,
)

return output.to(query.dtype)
return output # already in query dtype