vllm/model_executor/layers/attention/attention.py (48 changes: 0 additions & 48 deletions)

@@ -386,10 +386,6 @@ def __init__(
         # Initialize KV cache quantization attributes
         _init_kv_cache_quant(self, quant_config, prefix)
 
-        # Initialize TurboQuant buffers (Pi, S, centroids) if tq cache dtype
-        if kv_cache_dtype.startswith("turboquant_"):
-            self._init_turboquant_buffers(kv_cache_dtype, head_size, prefix)
-
         # for attn backends supporting query quantization
         self.query_quant = None
         if (
@@ -410,50 +406,6 @@
             else GroupShape.PER_TENSOR,
         )
 
-    def _init_turboquant_buffers(
-        self, cache_dtype: str, head_size: int, prefix: str
-    ) -> None:
-        """Initialize TurboQuant centroids for Lloyd-Max quantization."""
-        from vllm.model_executor.layers.quantization.turboquant.centroids import (
-            get_centroids,
-        )
-        from vllm.model_executor.layers.quantization.turboquant.config import (
-            TurboQuantConfig,
-        )
-
-        tq_config = TurboQuantConfig.from_cache_dtype(cache_dtype, head_size)
-
-        self.register_buffer(
-            "_tq_centroids",
-            get_centroids(head_size, tq_config.centroid_bits),
-        )
-        self._tq_config = tq_config
-
-        # Pre-allocate decode intermediate buffers so model.to(device) moves
-        # them to GPU *before* the memory profiler runs. Without this the
-        # profiler gives all free memory to KV cache blocks and the first
-        # decode OOMs when these buffers are lazily allocated.
-        _vllm_cfg = get_current_vllm_config()
-        B = _vllm_cfg.scheduler_config.max_num_seqs
-        Hq = self.num_heads
-        S = _vllm_cfg.attention_config.tq_max_kv_splits_for_cuda_graph
-        D = head_size
-        self.register_buffer(
-            "_tq_mid_o_buf",
-            torch.empty(B, Hq, S, D + 1, dtype=torch.float32),
-            persistent=False,
-        )
-        self.register_buffer(
-            "_tq_output_buf",
-            torch.empty(B, Hq, D, dtype=torch.float32),
-            persistent=False,
-        )
-        self.register_buffer(
-            "_tq_lse_buf",
-            torch.empty(B, Hq, dtype=torch.float32),
-            persistent=False,
-        )
-
     def forward(
         self,
         query: torch.Tensor,
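For context on the "Lloyd-Max quantization" mentioned in the deleted docstring above and in the turboquant_attn.py hunk below: _tq_midpoints holds the averages of adjacent sorted centroids, which act as the decision boundaries of a nearest-centroid scalar quantizer. A minimal sketch of that encode/decode step, assuming a 1-D centroid tensor such as the one returned by get_centroids; the helper is hypothetical and not part of this PR:

import torch

def quantize_to_centroids(x: torch.Tensor, centroids: torch.Tensor):
    # Hypothetical helper (not part of this PR): encode each element of x as
    # the index of its nearest centroid, i.e. Lloyd-Max scalar quantization.
    c_sorted, _ = centroids.sort()
    midpoints = (c_sorted[:-1] + c_sorted[1:]) / 2  # decision boundaries, as in _ensure_on_device
    codes = torch.bucketize(x, midpoints)           # nearest-centroid indices
    dequant = c_sorted[codes]                       # reconstruction by table lookup
    return codes, dequant

With 2**centroid_bits centroids, each code fits in centroid_bits bits and dequantization is a single table lookup into the sorted centroids.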
vllm/v1/attention/backends/turboquant_attn.py (35 changes: 28 additions & 7 deletions)

@@ -26,6 +26,9 @@
 
 from vllm.config import get_current_vllm_config
 from vllm.config.cache import CacheDType
+from vllm.model_executor.layers.quantization.turboquant.centroids import (
+    get_centroids,
+)
 from vllm.triton_utils import triton
 from vllm.v1.attention.backend import (
     AttentionBackend,
@@ -49,6 +52,10 @@
     triton_turboquant_decode_attention,
 )
 from vllm.v1.attention.ops.triton_turboquant_store import triton_turboquant_store
+from vllm.v1.worker.workspace import (
+    current_workspace_manager,
+    is_workspace_manager_initialized,
+)
 
 _HAS_FLASH_ATTN = is_flash_attn_varlen_func_available()
 if _HAS_FLASH_ATTN:
@@ -336,8 +343,12 @@ def _ensure_on_device(self, layer, device):
         layer._tq_PiT = H
         layer._tq_Pi = H
 
-        c = layer._tq_centroids.to(device=device, dtype=torch.float32)
-        c_sorted, _ = c.sort()
+        # Centroids for Lloyd-Max quantization.
+        layer._tq_centroids = get_centroids(D, self.tq_config.centroid_bits).to(
+            device=device, dtype=torch.float32
+        )
+
+        c_sorted, _ = layer._tq_centroids.sort()
         layer._tq_midpoints = (c_sorted[:-1] + c_sorted[1:]) / 2
         layer._tq_cached = True
 
@@ -805,12 +816,22 @@ def _decode_attention(
         PiT: torch.Tensor | None = None,
         layer: torch.nn.Module | None = None,
     ) -> torch.Tensor:
-        # Grab cached decode buffers from the layer (lazily allocated).
+        # Acquire shared decode scratch buffers from WorkspaceManager.
+        # Layers execute sequentially so one set of buffers is sufficient.
+        # Falls back to kernel-internal allocation if workspace unavailable.
+        B = query.shape[0]
+        D = self.head_size
+        S = self.max_num_kv_splits
+        Hq = self.num_heads
         mid_o_buf = output_buf = lse_buf = None
-        if layer is not None:
-            mid_o_buf = getattr(layer, "_tq_mid_o_buf", None)
-            output_buf = getattr(layer, "_tq_output_buf", None)
-            lse_buf = getattr(layer, "_tq_lse_buf", None)
+        if is_workspace_manager_initialized():
+            mid_o_buf, output_buf, lse_buf = (
+                current_workspace_manager().get_simultaneous(
+                    ((B, Hq, S, D + 1), torch.float32),
+                    ((B, Hq, D), torch.float32),
+                    ((B, Hq), torch.float32),
+                )
+            )
 
         result = triton_turboquant_decode_attention(
             query=query,
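The _decode_attention hunk above replaces the per-layer _tq_mid_o_buf / _tq_output_buf / _tq_lse_buf buffers (deleted from attention.py) with scratch tensors of shapes (B, Hq, S, D + 1), (B, Hq, D) and (B, Hq) borrowed from a shared workspace; because layers run sequentially, one set of buffers serves every layer and is accounted for once by the memory profiler. A toy sketch of that shared-scratch pattern, assuming only what the call site shows (get_simultaneous returns one tensor per requested (shape, dtype) pair); this is an illustration, not vLLM's actual WorkspaceManager:

import torch

class ToyWorkspace:
    """Toy illustration of the shared-scratch pattern; not vLLM's WorkspaceManager."""

    def __init__(self, num_bytes: int, device: str = "cpu") -> None:
        # One arena that every layer borrows views from, so its memory is
        # reserved (and profiled) exactly once.
        self._arena = torch.empty(num_bytes, dtype=torch.uint8, device=device)

    def get_simultaneous(self, *requests):
        # Hand out non-overlapping, 256-byte-aligned views of the arena,
        # one per requested (shape, dtype) pair.
        views, offset = [], 0
        for shape, dtype in requests:
            numel = 1
            for dim in shape:
                numel *= dim
            nbytes = numel * torch.empty((), dtype=dtype).element_size()
            assert offset + nbytes <= self._arena.numel(), "workspace too small"
            views.append(self._arena[offset : offset + nbytes].view(dtype).view(*shape))
            offset += -(-nbytes // 256) * 256  # round up to the next 256-byte boundary
        return views

# Usage mirroring the call site above (shapes are illustrative only):
ws = ToyWorkspace(num_bytes=1 << 20)
mid_o_buf, output_buf, lse_buf = ws.get_simultaneous(
    ((4, 8, 16, 65), torch.float32),
    ((4, 8, 64), torch.float32),
    ((4, 8), torch.float32),
)

If the workspace manager is not initialized, the buffers stay None and, per the comment in the hunk, triton_turboquant_decode_attention falls back to allocating them internally.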