diff --git a/tests/evals/gsm8k/configs/Qwen3-4B-TQ-k3v4nc.yaml b/tests/evals/gsm8k/configs/Qwen3-4B-TQ-k3v4nc.yaml
index fedb7416960..b9f9a7944f2 100644
--- a/tests/evals/gsm8k/configs/Qwen3-4B-TQ-k3v4nc.yaml
+++ b/tests/evals/gsm8k/configs/Qwen3-4B-TQ-k3v4nc.yaml
@@ -2,4 +2,4 @@ model_name: "Qwen/Qwen3-4B"
 accuracy_threshold: 0.78
 num_questions: 1319
 num_fewshot: 5
-server_args: "--kv-cache-dtype turboquant_k3v4_nc --enforce-eager --max-model-len 4096"
+server_args: "--kv-cache-dtype turboquant_k3v4_nc --max-model-len 4096"
diff --git a/tests/evals/gsm8k/configs/Qwen3-4B-TQ-k8v4.yaml b/tests/evals/gsm8k/configs/Qwen3-4B-TQ-k8v4.yaml
index 9717333582b..200b570e23d 100644
--- a/tests/evals/gsm8k/configs/Qwen3-4B-TQ-k8v4.yaml
+++ b/tests/evals/gsm8k/configs/Qwen3-4B-TQ-k8v4.yaml
@@ -2,4 +2,4 @@ model_name: "Qwen/Qwen3-4B"
 accuracy_threshold: 0.80
 num_questions: 1319
 num_fewshot: 5
-server_args: "--kv-cache-dtype turboquant_k8v4 --enforce-eager --max-model-len 4096"
+server_args: "--kv-cache-dtype turboquant_k8v4 --max-model-len 4096"
diff --git a/tests/evals/gsm8k/configs/Qwen3-4B-TQ-t3nc.yaml b/tests/evals/gsm8k/configs/Qwen3-4B-TQ-t3nc.yaml
index 8ece1852625..1c833fe7bf2 100644
--- a/tests/evals/gsm8k/configs/Qwen3-4B-TQ-t3nc.yaml
+++ b/tests/evals/gsm8k/configs/Qwen3-4B-TQ-t3nc.yaml
@@ -2,4 +2,4 @@ model_name: "Qwen/Qwen3-4B"
 accuracy_threshold: 0.75
 num_questions: 1319
 num_fewshot: 5
-server_args: "--kv-cache-dtype turboquant_3bit_nc --enforce-eager --max-model-len 4096"
+server_args: "--kv-cache-dtype turboquant_3bit_nc --max-model-len 4096"
diff --git a/tests/evals/gsm8k/configs/Qwen3-4B-TQ-t4nc.yaml b/tests/evals/gsm8k/configs/Qwen3-4B-TQ-t4nc.yaml
index 9b3a14f9b95..6a7f82b6609 100644
--- a/tests/evals/gsm8k/configs/Qwen3-4B-TQ-t4nc.yaml
+++ b/tests/evals/gsm8k/configs/Qwen3-4B-TQ-t4nc.yaml
@@ -2,4 +2,4 @@ model_name: "Qwen/Qwen3-4B"
 accuracy_threshold: 0.80
 num_questions: 1319
 num_fewshot: 5
-server_args: "--kv-cache-dtype turboquant_4bit_nc --enforce-eager --max-model-len 4096"
+server_args: "--kv-cache-dtype turboquant_4bit_nc --max-model-len 4096"
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 4926851903b..19bcdfdc98e 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -255,11 +255,16 @@ class FlashAttentionMetadata:
 def _get_sliding_window_configs(
     vllm_config: VllmConfig,
 ) -> set[tuple[int, int] | None]:
-    """Get the set of all sliding window configs used in the model."""
+    """Get the set of all sliding window configs used in the model.
+
+    Only inspects FlashAttentionImpl layers. Other backends (e.g.
+    TurboQuant, MLA) use their own metadata builders and are skipped.
+    """
     sliding_window_configs: set[tuple[int, int] | None] = set()
     layers = get_layers_from_vllm_config(vllm_config, Attention)
     for layer in layers.values():
-        assert isinstance(layer.impl, FlashAttentionImpl)
+        if not isinstance(layer.impl, FlashAttentionImpl):
+            continue
         sliding_window_configs.add(layer.impl.sliding_window)
     return sliding_window_configs

diff --git a/vllm/v1/attention/backends/turboquant_attn.py b/vllm/v1/attention/backends/turboquant_attn.py
index e7baff9d899..a0bcc252d85 100644
--- a/vllm/v1/attention/backends/turboquant_attn.py
+++ b/vllm/v1/attention/backends/turboquant_attn.py
@@ -39,6 +39,7 @@
     MultipleOf,
 )
 from vllm.v1.attention.backends.fa_utils import (
+    get_flash_attn_version,
     is_flash_attn_varlen_func_available,
 )
 from vllm.v1.attention.backends.utils import split_decodes_and_prefills
@@ -271,6 +272,9 @@ def __init__(
         self._val_data_bytes = math.ceil(head_size * cfg.effective_value_quant_bits / 8)
         self._n_centroids = cfg.n_centroids if not cfg.key_fp8 else 1

+        # Detect flash-attn version (FA2/3/4) for prefill paths.
+        self.fa_version = get_flash_attn_version(head_size=head_size)
+
         # Fixed NUM_KV_SPLITS (grid dims must be constant for cudagraph,
         # and benchmarks show no regression vs dynamic in eager mode).
         vllm_config = get_current_vllm_config()
@@ -278,6 +282,43 @@
             vllm_config.attention_config.tq_max_kv_splits_for_cuda_graph
         )

+    def _flash_attn_varlen(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        cu_seqlens_q: torch.Tensor,
+        cu_seqlens_k: torch.Tensor,
+        max_seqlen_q: int,
+        max_seqlen_k: int,
+    ) -> torch.Tensor:
+        # fa_utils.get_flash_attn_version() returns None on backends that
+        # should not pass an explicit fa_version kwarg.
+        if self.fa_version is None:
+            return flash_attn_varlen_func(
+                q=q,
+                k=k,
+                v=v,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_q=max_seqlen_q,
+                max_seqlen_k=max_seqlen_k,
+                softmax_scale=self.scale,
+                causal=True,
+            )
+        return flash_attn_varlen_func(
+            q=q,
+            k=k,
+            v=v,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k=cu_seqlens_k,
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_k=max_seqlen_k,
+            softmax_scale=self.scale,
+            causal=True,
+            fa_version=self.fa_version,
+        )
+
     def _ensure_on_device(self, layer, device):
         """One-time derivation of TQ buffers (rotation matrix, midpoints).

@@ -503,7 +544,7 @@ def _prefill_attention(
         # max_query_len == max_seq_len means no request has prior cached KV.
         # Both are Python ints — no GPU sync.
         if _HAS_FLASH_ATTN and attn_metadata.max_query_len == attn_metadata.max_seq_len:
-            return flash_attn_varlen_func(
+            return self._flash_attn_varlen(
                 q=query,
                 k=key,
                 v=value,
@@ -511,8 +552,6 @@
                 cu_seqlens_k=attn_metadata.query_start_loc,
                 max_seqlen_q=attn_metadata.max_query_len,
                 max_seqlen_k=attn_metadata.max_query_len,
-                softmax_scale=self.scale,
-                causal=True,
             )

         # Continuation or no flash_attn: per-request attention.
@@ -552,7 +591,7 @@ def _prefill_attention(
             if _HAS_FLASH_ATTN:
                 _cu_2[1] = q_len
                 cu = _cu_2
-                out = flash_attn_varlen_func(
+                out = self._flash_attn_varlen(
                     q=q_seq,
                     k=k_seq,
                     v=v_seq,
@@ -560,8 +599,6 @@
                     cu_seqlens_k=cu,
                     max_seqlen_q=q_len,
                     max_seqlen_k=q_len,
-                    softmax_scale=self.scale,
-                    causal=True,
                 )
             else:
                 q_t = q_seq.transpose(0, 1).contiguous()
@@ -726,7 +763,7 @@ def _continuation_prefill(
         if _HAS_FLASH_ATTN:
             cu_seqlens_q = torch.tensor([0, q_len], device=device, dtype=torch.int32)
             cu_seqlens_k = torch.tensor([0, seq_len], device=device, dtype=torch.int32)
-            return flash_attn_varlen_func(
+            return self._flash_attn_varlen(
                 q=query,
                 k=k_full,
                 v=v_full,
@@ -734,8 +771,6 @@
                 cu_seqlens_k=cu_seqlens_k,
                 max_seqlen_q=q_len,
                 max_seqlen_k=seq_len,
-                softmax_scale=self.scale,
-                causal=True,
             )
         else:
             # SDPA fallback: expand KV for GQA, build causal mask
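
Reviewer note on the _flash_attn_varlen wrapper added in turboquant_attn.py: it duplicates the flash_attn_varlen_func call solely so that the fa_version kwarg is omitted when fa_utils.get_flash_attn_version() returns None. The same dispatch can be written as a single call that builds its kwargs conditionally. A minimal sketch under that assumption; call_varlen and its parameters are illustrative stand-ins, not vllm APIs:

    from typing import Any, Callable

    def call_varlen(
        func: Callable[..., Any],
        fa_version: int | None,
        scale: float,
        **tensor_args: Any,
    ) -> Any:
        # Build kwargs once; attach fa_version only when a version was
        # detected, so a flash_attn_varlen_func that lacks the parameter
        # never sees the kwarg.
        kwargs: dict[str, Any] = dict(softmax_scale=scale, causal=True, **tensor_args)
        if fa_version is not None:
            kwargs["fa_version"] = fa_version
        return func(**kwargs)

Either shape behaves the same; the two-branch form in the patch keeps both call signatures literal and greppable.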
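
The flash_attn.py hunk converts a hard assert into a skip, which is what lets models that mix attention backends (e.g. TurboQuant or MLA layers alongside FlashAttention ones) pass through _get_sliding_window_configs without crashing. A toy before/after illustration; both *Impl classes below are stand-ins, not the real vllm implementations:

    class FlashAttentionImpl:
        """Stand-in: only FA layers expose a sliding-window config."""

        def __init__(self, sliding_window: tuple[int, int] | None):
            self.sliding_window = sliding_window

    class TurboQuantImpl:
        """Stand-in: TQ layers build their own attention metadata."""

    def sliding_window_configs(impls) -> set[tuple[int, int] | None]:
        configs: set[tuple[int, int] | None] = set()
        for impl in impls:
            # Previously: assert isinstance(impl, FlashAttentionImpl),
            # which raised as soon as a non-FA layer was registered.
            if not isinstance(impl, FlashAttentionImpl):
                continue  # other backends are skipped, not fatal
            configs.add(impl.sliding_window)
        return configs

    # Mixed-backend model: the TQ layer no longer trips an assertion.
    assert sliding_window_configs(
        [FlashAttentionImpl((4096, 0)), TurboQuantImpl()]
    ) == {(4096, 0)}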