2 changes: 1 addition & 1 deletion tests/evals/gsm8k/configs/Qwen3-4B-TQ-k3v4nc.yaml
@@ -2,4 +2,4 @@ model_name: "Qwen/Qwen3-4B"
accuracy_threshold: 0.78
num_questions: 1319
num_fewshot: 5
server_args: "--kv-cache-dtype turboquant_k3v4_nc --enforce-eager --max-model-len 4096"
server_args: "--kv-cache-dtype turboquant_k3v4_nc --max-model-len 4096"
2 changes: 1 addition & 1 deletion tests/evals/gsm8k/configs/Qwen3-4B-TQ-k8v4.yaml
@@ -2,4 +2,4 @@ model_name: "Qwen/Qwen3-4B"
accuracy_threshold: 0.80
num_questions: 1319
num_fewshot: 5
server_args: "--kv-cache-dtype turboquant_k8v4 --enforce-eager --max-model-len 4096"
server_args: "--kv-cache-dtype turboquant_k8v4 --max-model-len 4096"
2 changes: 1 addition & 1 deletion tests/evals/gsm8k/configs/Qwen3-4B-TQ-t3nc.yaml
@@ -2,4 +2,4 @@ model_name: "Qwen/Qwen3-4B"
accuracy_threshold: 0.75
num_questions: 1319
num_fewshot: 5
server_args: "--kv-cache-dtype turboquant_3bit_nc --enforce-eager --max-model-len 4096"
server_args: "--kv-cache-dtype turboquant_3bit_nc --max-model-len 4096"
2 changes: 1 addition & 1 deletion tests/evals/gsm8k/configs/Qwen3-4B-TQ-t4nc.yaml
@@ -2,4 +2,4 @@ model_name: "Qwen/Qwen3-4B"
accuracy_threshold: 0.80
num_questions: 1319
num_fewshot: 5
server_args: "--kv-cache-dtype turboquant_4bit_nc --enforce-eager --max-model-len 4096"
server_args: "--kv-cache-dtype turboquant_4bit_nc --max-model-len 4096"
9 changes: 7 additions & 2 deletions vllm/v1/attention/backends/flash_attn.py
@@ -255,11 +255,16 @@ class FlashAttentionMetadata:
def _get_sliding_window_configs(
vllm_config: VllmConfig,
) -> set[tuple[int, int] | None]:
"""Get the set of all sliding window configs used in the model."""
"""Get the set of all sliding window configs used in the model.

Only inspects FlashAttentionImpl layers. Other backends (e.g.
TurboQuant, MLA) use their own metadata builders and are skipped.
"""
sliding_window_configs: set[tuple[int, int] | None] = set()
layers = get_layers_from_vllm_config(vllm_config, Attention)
for layer in layers.values():
assert isinstance(layer.impl, FlashAttentionImpl)
if not isinstance(layer.impl, FlashAttentionImpl):
continue
sliding_window_configs.add(layer.impl.sliding_window)
return sliding_window_configs

53 changes: 44 additions & 9 deletions vllm/v1/attention/backends/turboquant_attn.py
@@ -39,6 +39,7 @@
MultipleOf,
)
from vllm.v1.attention.backends.fa_utils import (
get_flash_attn_version,
is_flash_attn_varlen_func_available,
)
from vllm.v1.attention.backends.utils import split_decodes_and_prefills
@@ -271,13 +272,53 @@ def __init__(
self._val_data_bytes = math.ceil(head_size * cfg.effective_value_quant_bits / 8)
self._n_centroids = cfg.n_centroids if not cfg.key_fp8 else 1

# Detect flash-attn version (FA2/3/4) for prefill paths.
self.fa_version = get_flash_attn_version(head_size=head_size)
Contributor review comment (severity: high):

The call to get_flash_attn_version should include the requires_alibi argument. Passing requires_alibi=alibi_slopes is not None ensures that the backend correctly falls back to FlashAttention 2 if ALiBi slopes are present, as FA3 and FA4 do not currently support them. This maintains consistency with the version detection logic used in FlashAttentionImpl.

Suggested change:
-        self.fa_version = get_flash_attn_version(head_size=head_size)
+        self.fa_version = get_flash_attn_version(
+            requires_alibi=alibi_slopes is not None, head_size=head_size)

Comment on lines +275 to +276 (P1): Mirror SM90 head_dim>256 FA4 override in TurboQuant

This new FA-version selection path only calls get_flash_attn_version(head_size=...), but it does not apply the SM90 head_size > 256 upgrade to FA4 that FlashAttentionImpl already uses. On Hopper, get_flash_attn_version still defaults to FA3, so TurboQuant prefill can be routed into FA3 with unsupported large head dimensions and fail at runtime for those models. Please mirror the same SM90/head-size override logic before assigning self.fa_version.

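A minimal sketch of what mirroring that override could look like in the TurboQuant __init__, folding in the requires_alibi suggestion from the comment above. The SM90/head-size condition and the use of current_platform (from vllm.platforms) are assumptions about how FlashAttentionImpl handles it; the exact guard, including any FA4-availability probe, should be copied from FlashAttentionImpl rather than from this sketch.

        # Hedged sketch, not the actual FlashAttentionImpl code: upgrade to FA4
        # on Hopper when FA3 cannot handle the head dim. The FA4-availability
        # check used by FlashAttentionImpl should also gate this branch; it is
        # omitted here because its exact name is not shown in this diff.
        from vllm.platforms import current_platform  # module-level import in practice

        self.fa_version = get_flash_attn_version(
            requires_alibi=alibi_slopes is not None, head_size=head_size
        )
        if (
            self.fa_version == 3
            and current_platform.is_device_capability(90)
            and head_size > 256
        ):
            # Assumption: FA3 does not support head_dim > 256 on SM90, while FA4 does.
            self.fa_version = 4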


# Fixed NUM_KV_SPLITS (grid dims must be constant for cudagraph,
# and benchmarks show no regression vs dynamic in eager mode).
vllm_config = get_current_vllm_config()
self.max_num_kv_splits = (
vllm_config.attention_config.tq_max_kv_splits_for_cuda_graph
)

def _flash_attn_varlen(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
) -> torch.Tensor:
# fa_utils.get_flash_attn_version() returns None on backends that
# should not pass an explicit fa_version kwarg.
if self.fa_version is None:
return flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=True,
)
return flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=True,
fa_version=self.fa_version,
)

def _ensure_on_device(self, layer, device):
"""One-time derivation of TQ buffers (rotation matrix, midpoints).

@@ -503,16 +544,14 @@ def _prefill_attention(
# max_query_len == max_seq_len means no request has prior cached KV.
# Both are Python ints — no GPU sync.
if _HAS_FLASH_ATTN and attn_metadata.max_query_len == attn_metadata.max_seq_len:
return flash_attn_varlen_func(
return self._flash_attn_varlen(
q=query,
k=key,
v=value,
cu_seqlens_q=attn_metadata.query_start_loc,
cu_seqlens_k=attn_metadata.query_start_loc,
max_seqlen_q=attn_metadata.max_query_len,
max_seqlen_k=attn_metadata.max_query_len,
softmax_scale=self.scale,
causal=True,
)

# Continuation or no flash_attn: per-request attention.
@@ -552,16 +591,14 @@ def _prefill_attention(
if _HAS_FLASH_ATTN:
_cu_2[1] = q_len
cu = _cu_2
out = flash_attn_varlen_func(
out = self._flash_attn_varlen(
q=q_seq,
k=k_seq,
v=v_seq,
cu_seqlens_q=cu,
cu_seqlens_k=cu,
max_seqlen_q=q_len,
max_seqlen_k=q_len,
softmax_scale=self.scale,
causal=True,
)
else:
q_t = q_seq.transpose(0, 1).contiguous()
@@ -726,16 +763,14 @@ def _continuation_prefill(
if _HAS_FLASH_ATTN:
cu_seqlens_q = torch.tensor([0, q_len], device=device, dtype=torch.int32)
cu_seqlens_k = torch.tensor([0, seq_len], device=device, dtype=torch.int32)
return flash_attn_varlen_func(
return self._flash_attn_varlen(
q=query,
k=k_full,
v=v_full,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=q_len,
max_seqlen_k=seq_len,
softmax_scale=self.scale,
causal=True,
)
else:
# SDPA fallback: expand KV for GQA, build causal mask